In [1]:
import sys
sys.path.append('../')
sys.path.insert(0, '../terrace')

%load_ext autoreload
%autoreload 2

In [3]:
from tqdm import tqdm
from copy import deepcopy
import torch
import torch.nn as nn
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from common.cfg_utils import get_config
from common.losses import *
from datasets.make_dataset import *
from models.gnn_bind import *
from models.learnable_ff import LearnableFF
from routines.ai_routine import *
from terrace.batch import *
from terrace.type_data import *
from terrace.comp_node import *

In [2]:
# Load the "default" experiment config, override two data options, and build the
# validation dataloader that the rest of the notebook draws example batches from.
cfg = get_config("../configs", "default")
cfg.data.sna_frac = 1
cfg.data.cache = False  # presumably disables on-disk caching of the processed data — TODO confirm
loader = make_dataloader(cfg, "val")
# Typed input node built from the loader's type data; passed to the model constructors below.
in_node = Input(loader.get_type_data())

In [6]:
# Scratch check: how a ModuleDict holding a single Linear layer is displayed.
nn.ModuleDict(dict(train_metric=nn.Linear(10, 10)))

ModuleDict(
  (train_metric): Linear(in_features=10, out_features=10, bias=True)
)

In [11]:
# Build the validation dataset directly (not via a dataloader) so its fields can be inspected.
dataset = make_dataset(cfg, "val")

In [12]:
# Fraction of validation entries counted as active: sum of the `active` values
# divided by the number of rows. Assumes `active` holds 0/1 or bool values — TODO confirm.
sum(dataset.activities.active)/len(dataset.activities)

0.40773086223719857

In [12]:
# Variance of the collected `energies` values.
# FIXME: `energies` is not defined anywhere in this notebook; this cell only ran
# because of leftover kernel state and fails under Restart & Run All.
torch.tensor(energies).var()

tensor(3721.9956)

In [16]:
from datasets.vina_score import VinaScoreDataset
# NOTE: rebinds `dataset` (previously the "val" dataset from make_dataset) to the
# "train" split of VinaScoreDataset; later cells reading `dataset.lig_files` use this one.
dataset = VinaScoreDataset(cfg, "train")

In [65]:
# Scratch: three standard-normal samples scaled by 2, shown as a plain Python list.
torch.randn(3).mul(2).tolist()

[5.288211345672607, 2.007632255554199, 1.0654451847076416]

In [50]:
# Load the first ligand of the dataset, shift atom 0 by the conformer centroid,
# and print atom 0's x coordinate before and after the update.
lig_file = dataset.lig_files[0]
lig = next(Chem.SDMolSupplier(lig_file, sanitize=True))
# SDMolSupplier yields None (instead of raising) for records that fail to parse or
# sanitize; fail loudly here rather than with an opaque AttributeError below.
if lig is None:
    raise ValueError(f"RDKit could not parse/sanitize the first molecule in {lig_file}")
conf = lig.GetConformer()
lig_center = Chem.rdMolTransforms.ComputeCentroid(conf)
# NOTE(review): this *adds* the centroid to atom 0's position (Point3D + Point3D);
# if the intent was to centre the ligand, subtraction was presumably meant — confirm.
new_pos = lig_center + conf.GetAtomPosition(0)
print(conf.GetAtomPosition(0).x)
conf.SetAtomPosition(0, new_pos)
print(conf.GetAtomPosition(0).x)

-6.0803
-5.756255


In [14]:
# Parse a ligand from a SMILES string (rebinds `lig`; MolFromSmiles returns None on failure).
# FIXME: `smiles` is not defined anywhere in this notebook (leftover kernel state).
lig = Chem.MolFromSmiles(smiles)

In [3]:
# Pull one batch from the validation loader to use as a fixed example input below.
batch = next(iter(loader))

In [21]:
# Build the model selected by the "outer_prod_gnn" config and run one forward pass on `batch`.
# NOTE: rebinds `cfg`, replacing the "default" config loaded at the top of the notebook.
cfg = get_config("../configs", "outer_prod_gnn")
# cfg.model.type
model = make_model(cfg, in_node)
y_pred = model(batch)
y_pred

torch.Size([2, 4096])


tensor([0.5330, 0.2485], grad_fn=<SqueezeBackward1>)

In [8]:
# Build a LearnableFF model from the "learnable_ff" config and run it on the same `batch`.
# NOTE: rebinds `cfg`, `model` and `y_pred`; the get_losses cell uses these bindings.
cfg = get_config("../configs", "learnable_ff")
model = LearnableFF(cfg, in_node)
y_pred = model(batch)

In [10]:
# Evaluate the configured losses for this prediction. The displayed output is a
# tuple of (total loss, dict of named loss terms).
get_losses(cfg, batch, y_pred)

(tensor(1031.6157, grad_fn=<AddBackward0>),
 {'coord_mse': tensor(1031.6157, grad_fn=<MeanBackward0>)})

In [79]:
# Pairwise feature dot products: out[l, r] = sum_f lig_feat[l, f] * rec_feat[r, f].
# FIXME: `lig_feat` and `rec_feat` are not defined in this notebook (leftover kernel state).
torch.einsum('lf,rf->lr', lig_feat, rec_feat).shape

torch.Size([63, 54])

In [80]:
# Which device the ligand features live on.
# FIXME: `lig_feat` is not defined in this notebook (leftover kernel state).
lig_feat.device

device(type='cpu')

In [77]:
# Random rigid-body transform (rotation + translation) of the batch's ligand
# coordinates; presumably used with `batch_t` to check model invariance — confirm.
# QR of a Gaussian matrix gives an orthogonal Q, but (a) Q is only uniformly
# distributed once column signs are fixed by sign(diag(R)), and (b) det(Q) can be
# -1, i.e. a reflection that flips chirality. Fix both so `rot` is a uniformly
# random *proper* rotation (det = +1).
q_mat, r_mat = torch.linalg.qr(torch.randn((3, 3)))
rot = q_mat * torch.sign(torch.diagonal(r_mat))
if torch.linalg.det(rot) < 0:
    rot[:, 0] = -rot[:, 0]
trans = torch.randn(3,)
coord = batch.lig.ndata.coord
# Row-wise map x' = R x + t; the einsum requires coord to be (num_atoms, 3).
trans_coord = torch.einsum('ij,bj->bi', rot, coord) + trans

In [78]:
# Independent copy of `batch` whose ligand coordinates are replaced by the randomly
# transformed `trans_coord`; the deep copy leaves `batch` itself untouched.
batch_t = deepcopy(batch)
batch_t.lig.ndata.coord = trans_coord

In [43]:
# Sanity check: after subtracting the centroid, ligand coordinates should have
# (numerically) zero mean. `lig_coord` is bound here so this cell no longer
# depends on a later cell having been run first.
lig_coord = batch.lig.ndata.coord
(lig_coord - lig_coord.mean(0)).mean(0)

tensor([ 1.4381e-07, -4.5413e-08,  0.0000e+00])

In [39]:
# Pairwise Euclidean distances between every ligand atom and every receptor atom in
# the batch: dists[i, j] = ||lig_coord[i] - rec_coord[j]||.
lig_coord = batch.lig.ndata.coord
rec_coord = batch.rec.ndata.coord
# Vectorised instead of an O(N_lig * N_rec) Python double loop. The direct
# compute mode matches torch.linalg.norm(lc - rc) rather than using the faster
# but less precise matrix-multiplication formulation.
dists = torch.cdist(lig_coord, rec_coord, compute_mode="donot_use_mm_for_euclid_dist")
dists

tensor([[24.9056, 26.1221, 32.7037,  ..., 59.9716, 57.3830, 72.9151],
        [25.5268, 26.5024, 33.3077,  ..., 58.4674, 55.8833, 71.4241],
        [26.6782, 27.3604, 34.4925,  ..., 57.7209, 55.0941, 70.8719],
        ...,
        [34.9228, 35.8056, 41.8239,  ..., 46.7359, 44.6136, 58.4612],
        [35.0984, 35.6710, 42.1213,  ..., 46.3110, 44.0902, 58.4083],
        [30.1799, 31.6070, 37.5665,  ..., 54.0163, 51.6679, 66.3070]])

In [97]:
# Block-diagonal mask over (ligand atom, receptor atom) pairs: 1.0 where both atoms
# come from the same graph in the batch, 0.0 otherwise. Blocks are laid out using
# the per-graph node counts of the batched receptor/ligand graphs.
rec_graph = batch.rec.dgl_batch
lig_graph = batch.lig.dgl_batch
mask = torch.zeros((lig_coord.shape[0], rec_coord.shape[0]))
lig_start, rec_start = 0, 0
for n_rec, n_lig in zip(rec_graph.batch_num_nodes(), lig_graph.batch_num_nodes()):
    lig_stop, rec_stop = lig_start + n_lig, rec_start + n_rec
    mask[lig_start:lig_stop, rec_start:rec_stop] = 1.0
    lig_start, rec_start = lig_stop, rec_stop
mask

tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.]])

In [108]:
# Inverse-squared ligand-receptor distances, zeroed wherever `mask` is 0.
mask / torch.cdist(lig_coord, rec_coord) ** 2

tensor([[0.0016, 0.0015, 0.0009,  ..., 0.0000, 0.0000, 0.0000],
        [0.0015, 0.0014, 0.0009,  ..., 0.0000, 0.0000, 0.0000],
        [0.0014, 0.0013, 0.0008,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0005, 0.0005, 0.0003],
        [0.0000, 0.0000, 0.0000,  ..., 0.0005, 0.0005, 0.0003],
        [0.0000, 0.0000, 0.0000,  ..., 0.0003, 0.0004, 0.0002]])

In [83]:
# from datasets.graphs.plot_graph import *
# plot_graph(batch_t.lig[0])

In [23]:
# Batched outer product: each row b becomes t1[b] ⊗ t2[b], flattened to 100 * 50 features.
t1 = torch.randn((10, 100))
t2 = torch.randn((10, 50))
(t1.unsqueeze(2) * t2.unsqueeze(1)).flatten(start_dim=1).shape

torch.Size([10, 5000])