In [1]:
import os

import torch

# The project-local imports below are issued only after switching to the
# parent directory (presumably the repository root, where `models`, `builder`
# and the `configs/` folder live -- TODO confirm).
# The sentinel keeps this cell idempotent: without it, every re-run in the
# same kernel would climb one directory higher and break the relative paths.
if "_PROJECT_ROOT" not in globals():
    os.chdir(os.pardir)
    _PROJECT_ROOT = os.getcwd()

from models import egnn
from builder import build_model, build_dataloader, build_optimizer, build_scheduler

In [2]:
from utils import config_util
from checkpointer import Checkpointer
# Load the run configuration and build the model from its "model" and "data"
# sections.
configs = config_util.load_configs('configs/egnn_denv2_finetune.yml')

model = build_model(
    configs["model"],
    configs["data"]
)

mode = 'test'
default_name = 'model'

# Directory handed to the Checkpointer.
# NOTE(review): machine-specific absolute path -- consider moving it into the
# config file or an environment variable.
CHECKPOINT_DIR = '/p/lustre1/ranganath2/fast.tmp/test/'

# Weights are restored through the checkpointer below rather than by calling
# `load_state_dict` directly.
module = model

optimizer = build_optimizer(
    configs["train"]["optimizer"], module
)
scheduler = build_scheduler(
    configs["train"]["scheduler"],
    optimizer
)

# Build the checkpointables BEFORE constructing the Checkpointer. Previously an
# empty dict was unpacked into the constructor and the optimizer/scheduler
# entries were only assigned afterwards, so they were never registered.
# NOTE(review): whether their state is actually restored depends on how the
# project's Checkpointer treats `mode` and `resume` -- confirm.
checkpointables = {
    "optimizer": optimizer,
    "scheduler": scheduler,
}

checkpointer = Checkpointer(
    mode, model, CHECKPOINT_DIR, default_name, **checkpointables
)

checkpoint = checkpointer.resume_or_load(path=configs["model"]["ckpt_path"], resume=True)



[37m[01m[2024-12-16 00:08:44,543] Loading from /p/lustre1/ranganath2/fast.tmp/PDBBIND2020/EGNN-20240819-174309-871564-clrte403/best_model.pth...[0m


In [3]:
# Build the data-loading object for the configuration held in `configs`;
# its 'val' entry is looked up with `.get` in a later cell.
denv2 = build_dataloader(configs=configs)

[36m[01m[2024-12-16 00:08:48,938] Begin building dataset[0m


In [4]:
# Second data source: load its own configuration file and build the matching
# data-loading object.
# NOTE(review): the file name says "egcnn" while the model config above is
# "egnn" -- confirm this is the intended config.
pdbbind_configs = config_util.load_configs("configs/egcnn_2020.yaml")
pdbbind = build_dataloader(configs=pdbbind_configs)

[36m[01m[2024-12-12 15:48:22,134] Begin building dataset[0m


In [5]:
# Pick the 'val' entry of each data source (None if the key is absent,
# assuming dict-like `.get` semantics -- TODO confirm).
pdbbind_dataset, denv2_dataset = pdbbind.get('val'), denv2.get('val')

In [6]:
# Take the first batch from each validation iterable and keep only the 'data'
# field of every element in it.
denv2_batch = [sample['data'] for sample in next(iter(denv2_dataset))]
pdb_batch = [sample['data'] for sample in next(iter(pdbbind_dataset))]

# Forward passes; each result is later indexed with 'mse' to obtain the loss.
loss1 = model(denv2_batch)
loss2 = model(pdb_batch)

In [7]:
def _gather_flat_grad(model):
    """Concatenate the gradients of all parameters of ``model`` into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order. A parameter whose
    ``.grad`` is ``None`` (for example a frozen parameter, or one that did not
    take part in the last backward pass) contributes ``p.numel()`` zeros, so
    the layout of the result is the same regardless of which parameters
    received a gradient. Sparse gradients are converted to dense first.

    Args:
        model: Object exposing ``parameters()`` that yields tensors, e.g. a
            ``torch.nn.Module``.

    Returns:
        A 1-D tensor holding every parameter gradient back to back. If the
        model has no parameters, an empty 1-D tensor is returned.
    """
    views = []
    for p in model.parameters():
        # No `retain_grad()` here: it is a no-op for leaf parameters and
        # raises for parameters with `requires_grad=False`.
        grad = p.grad
        if grad is None:
            view = p.new_zeros(p.numel())
        elif grad.is_sparse:
            view = grad.to_dense().reshape(-1)
        else:
            # `reshape` also copes with non-contiguous gradients, where
            # `view(-1)` would fail.
            view = grad.reshape(-1)
        views.append(view)
    if not views:
        # `torch.cat` rejects an empty list.
        return torch.zeros(0)
    return torch.cat(views, dim=0)

In [8]:
# Gradient of the DENV2 batch loss with respect to the model parameters.
print(loss1)

model.zero_grad()  # start from a clean slate before back-propagating
loss1['mse'].backward()

flat_grad = _gather_flat_grad(model)
print(flat_grad)

{'mse': tensor(28.6198, grad_fn=<MulBackward0>)}
tensor([ 7.9111e-01,  0.0000e+00,  5.8168e-01,  ...,  2.9554e+02,
         6.5761e+02, -8.1290e+02])


In [9]:
# Magnitude of the flattened DENV2 gradient.
print(flat_grad.norm())

tensor(5086.6436)


In [10]:
# Same procedure for the PDBbind batch: reset gradients, back-propagate its
# loss, then flatten and report the gradient and its norm.
print(loss2)

model.zero_grad()
loss2['mse'].backward()

flat_grad_2 = _gather_flat_grad(model)
print(flat_grad_2)
print(flat_grad_2.norm())

{'mse': tensor(56.0626, grad_fn=<MulBackward0>)}
tensor([ 13.0067,   0.0000,  11.2309,  ..., -40.0471, -89.1088, 110.1522])
tensor(1880.5219)


In [11]:
# Cosine similarity between the two flattened gradients:
# dot product divided by the product of their norms.
print(
    torch.dot(flat_grad, flat_grad_2)
    / (flat_grad.norm() * flat_grad_2.norm())
)

tensor(-0.3576)


In [None]:
# Dengue - denv2  denv3  denv4 
# pdbbind - x >> 0 proteins
# Gradient similarity scoring - x \times denv2, x \times denv3, and x \times denv4 (ligands?)
# Tanimoto (SMILES, ligand side) + ESM3 sequence similarity (sequence, protein side)


# data1 = extract denv2-pdbbind similarity from sequences (ESM3)
# data2 = from data1, extract denv2-pdbbind similar ligands (Tanimoto)

# Repeat the above with West Nile.
