In [1]:
import glob
import json
from pathlib import Path

import molgrid
import torch
from gninatorch.gnina import setup_gnina_model

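This notebook re-exports each gnina-torch weight file as a traced TorchScript module that carries its grid settings and atom-type maps in a `metadata` extra file. The `gninatorch` directory referenced below comes from [gnina-torch](https://github.com/RMeli/gnina-torch); paths such as `gninatorch/weights` are relative, so run the notebook from the directory that contains it.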

In [2]:
ls gninatorch/weights

crossdock_default2018_1.pt  dense_2.pt                general_default2018_4.pt
crossdock_default2018_2.pt  dense_3.pt                general_default2018.pt
crossdock_default2018_3.pt  dense_4.pt                redock_default2018_1.pt
crossdock_default2018_4.pt  dense.pt                  redock_default2018_2.pt
crossdock_default2018.pt    general_default2018_1.pt  redock_default2018_3.pt
default2017.pt              general_default2018_2.pt  redock_default2018_4.pt
dense_1.pt                  general_default2018_3.pt  redock_default2018.pt


In [1]:
# Grid settings and atom-type maps written as `metadata` for every exported model
# except default2017. Each line of a map is one grid channel; type names that share
# a line are merged into that channel (14 receptor + 14 ligand lines here).
d = dict(
    resolution=0.5,
    dimension=23.5,
    recmap='''AliphaticCarbonXSHydrophobe 
AliphaticCarbonXSNonHydrophobe 
AromaticCarbonXSHydrophobe 
AromaticCarbonXSNonHydrophobe
Bromine Iodine Chlorine Fluorine
Nitrogen NitrogenXSAcceptor 
NitrogenXSDonor NitrogenXSDonorAcceptor
Oxygen OxygenXSAcceptor 
OxygenXSDonorAcceptor OxygenXSDonor
Sulfur SulfurAcceptor
Phosphorus 
Calcium
Zinc
GenericMetal Boron Manganese Magnesium Iron''',
    ligmap='''AliphaticCarbonXSHydrophobe 
AliphaticCarbonXSNonHydrophobe 
AromaticCarbonXSHydrophobe 
AromaticCarbonXSNonHydrophobe
Bromine Iodine
Chlorine
Fluorine
Nitrogen NitrogenXSAcceptor 
NitrogenXSDonor NitrogenXSDonorAcceptor
Oxygen OxygenXSAcceptor 
OxygenXSDonorAcceptor OxygenXSDonor
Sulfur SulfurAcceptor
Phosphorus
GenericMetal Boron Manganese Magnesium Zinc Calcium Iron''',
)

old = {
        'resolution': 0.5,
    'dimension' : 23.5,
    'recmap' : '''AliphaticCarbonXSHydrophobe
AliphaticCarbonXSNonHydrophobe
AromaticCarbonXSHydrophobe
AromaticCarbonXSNonHydrophobe
Calcium
Iron
Magnesium
Nitrogen
NitrogenXSAcceptor
NitrogenXSDonor
NitrogenXSDonorAcceptor
OxygenXSAcceptor
OxygenXSDonorAcceptor
Phosphorus
Sulfur
Zinc''',
'ligmap':'''AliphaticCarbonXSHydrophobe
AliphaticCarbonXSNonHydrophobe
AromaticCarbonXSHydrophobe
AromaticCarbonXSNonHydrophobe
Bromine
Chlorine
Fluorine
Nitrogen
NitrogenXSAcceptor
NitrogenXSDonor
NitrogenXSDonorAcceptor
Oxygen
OxygenXSAcceptor
OxygenXSDonorAcceptor
Phosphorus
Sulfur
SulfurAcceptor
Iodine
Boron'''
}

In [4]:
for fname in glob.glob('gninatorch/weights/*.pt'):
    newf = fname.split('/')[-1]
    prefix = newf[:-3]
    
    model = setup_gnina_model(prefix)[0]
    model.eval()
    if 'default2017' in fname:
        extra = {'metadata':json.dumps(old)}
        z = torch.zeros((1,35,48,48,48))
    else:
        extra = {'metadata':json.dumps(d)}
        z = torch.zeros((1,28,48,48,48))
        
    script = torch.jit.trace(model, z)
    script.save(newf,_extra_files=extra)    

In [332]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Overlap(nn.Module):
    '''Compute overlap of single channel ligand and receptor'''
    def __init__(self):
        super().__init__()

    def forward(xelf, x):
        lig = x[:,1,:,:,:]
        rec = x[:,0,:,:,:]
        prot = rec * lig
        ave = F.avg_pool3d(prot,48).flatten(1)
        cave = torch.hstack([torch.zeros_like(ave),ave])
        cave = -torch.log(cave)
        cave = torch.nan_to_num(cave,posinf=0)
        return (cave,torch.zeros((1,1),device=x.device))

In [333]:
x = torch.rand((1,2,48,48,48),requires_grad=True)

In [334]:
x.requires_grad = True

In [335]:
omodel = Overlap()

In [336]:
single = {
    'resolution': 0.5,
    'dimension' : 23.5,
    'skip_loss' : True,
    'recmap' : '''AliphaticCarbonXSHydrophobe AliphaticCarbonXSNonHydrophobe AromaticCarbonXSHydrophobe AromaticCarbonXSNonHydrophobe Bromine Iodine Chlorine Fluorine Nitrogen NitrogenXSAcceptor NitrogenXSDonor NitrogenXSDonorAcceptor Oxygen OxygenXSAcceptor OxygenXSDonorAcceptor OxygenXSDonor Sulfur SulfurAcceptor Phosphorus GenericMetal Boron Manganese Magnesium Zinc Calcium Iron''',
    'ligmap': '''AliphaticCarbonXSHydrophobe AliphaticCarbonXSNonHydrophobe AromaticCarbonXSHydrophobe AromaticCarbonXSNonHydrophobe Bromine Iodine Chlorine Fluorine Nitrogen NitrogenXSAcceptor NitrogenXSDonor NitrogenXSDonorAcceptor Oxygen OxygenXSAcceptor OxygenXSDonorAcceptor OxygenXSDonor Sulfur SulfurAcceptor Phosphorus GenericMetal Boron Manganese Magnesium Zinc Calcium Iron'''
}

In [337]:
oscript = torch.jit.trace(omodel, x)

In [338]:
oscript.save('overlap.pt',_extra_files={'metadata':json.dumps(single)})