In [None]:
from omtra.models.ligand_encoder.vq import LigandVQVAE
from omtra.load.quick import load_cfg, datamodule_from_config

import pytorch_lightning as pl
import dgl

from torch.utils.data import DataLoader, Dataset, Subset

import os
from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt

In [4]:
class PharmitWrapperDataset(Dataset):
    """Expose one graph type of a Pharmit dataset as an integer-indexed ``Dataset``.

    The wrapped dataset is addressed with ``(graph_type, idx)`` tuples; this
    wrapper pins ``graph_type`` so that ``Subset`` / ``DataLoader`` samplers can
    fetch items with plain integer indices.

    Args:
        pharmit_dataset: Dataset supporting ``len()`` and ``[(graph_type, idx)]``.
        graph_type: Graph-type key passed as the first element of every lookup.
    """

    def __init__(self, pharmit_dataset, graph_type):
        self.base_dataset = pharmit_dataset
        self.graph_type = graph_type
        # Length is captured once, at construction time.
        self.length = len(self.base_dataset)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        lookup_key = (self.graph_type, idx)
        return self.base_dataset[lookup_key]

def collate_fn(batch):
    """Merge the list of graphs produced by the DataLoader into one batched DGL graph."""
    batched_graph = dgl.batch(batch)
    return batched_graph

In [6]:
# Root of the Pharmit dev data. Set the PHARMIT_PATH environment variable to point
# elsewhere instead of editing this cell; the default is the original cluster path.
pharmit_path = os.environ.get(
    'PHARMIT_PATH', '/net/galaxy/home/koes/icd3/moldiff/OMTRA/data/pharmit_dev'
)

overrides = [
    f"pharmit_path={pharmit_path}",
    "task_group=no_protein",
]

cfg = load_cfg(overrides=overrides)
datamodule = datamodule_from_config(cfg)

# Only the 'pharmit' component of the training split's datasets is used below.
train_dataset = datamodule.load_dataset("train")
pharmit_dataset = train_dataset.datasets['pharmit']

⚛ Instantiating datamodule <omtra.dataset.data_module.MultiTaskDataModule>


In [7]:
# Smoke-test setup: 'denovo_ligand' graphs only, and at most n_train_samples of them.
n_train_samples = 1000  # upper bound; clamped so Subset never indexes past the dataset end
batch_size = 100

wrapped_dataset = PharmitWrapperDataset(pharmit_dataset, 'denovo_ligand')
n_subset = min(n_train_samples, len(wrapped_dataset))
subset_dataset = Subset(wrapped_dataset, indices=list(range(n_subset)))

training_loader = DataLoader(
    subset_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Seed the Python / NumPy / torch RNGs so weight initialisation and the shuffled
# DataLoader order are reproducible across Restart & Run All.
pl.seed_everything(42)

model = LigandVQVAE(
    a_embed_dim=16,
    c_embed_dim=8,
    e_embed_dim=8,
    scalar_size=128,
    vector_size=4,
    num_gvp_layers=2,
    latent_dim=8,
    num_embeddings=100,
    num_decod_hiddens=128,
    num_bond_decod_hiddens=128,
    commitment_cost=0.25,
)

# NOTE(review): the saved output of this cell is a GPU MisconfigurationException;
# re-run on a fresh kernel to confirm accelerator='auto' resolves it on this machine.
trainer = pl.Trainer(
    max_epochs=5,
    log_every_n_steps=1,  # log every step so the short run yields a dense loss curve
    enable_progress_bar=True,
    accelerator='auto',
    devices=1,
)

trainer.fit(model, train_dataloaders=training_loader)



MisconfigurationException: No supported gpu backend found!

In [None]:
log_base = "lightning_logs"


def _run_sort_key(run_name):
    """Sort key ordering run dirs by their trailing ``_<N>`` number, then by name.

    Plain ``sorted()`` is lexicographic and would rank ``version_9`` above
    ``version_10``; names without a numeric suffix sort before numbered ones.
    """
    _, _, suffix = run_name.rpartition("_")
    return (int(suffix) if suffix.isdigit() else -1, run_name)


run_dirs = [d for d in os.listdir(log_base) if os.path.isdir(os.path.join(log_base, d))]
if not run_dirs:
    raise FileNotFoundError(f"No run directories found under {log_base!r}; run training first.")
latest_run = max(run_dirs, key=_run_sort_key)
log_path = os.path.join(log_base, latest_run)

# Find the actual event file. If the run wrote several, take the last by name:
# presumably names are events.out.tfevents.<unix time>..., so this is the newest — verify.
event_files = sorted(f for f in os.listdir(log_path) if f.startswith("events.out"))
if not event_files:
    raise FileNotFoundError(f"No TensorBoard event file found in {log_path!r}.")
event_file = event_files[-1]
event_path = os.path.join(log_path, event_file)

ea = event_accumulator.EventAccumulator(event_path)
ea.Reload()

# Check what metrics were logged
ea.Tags()['scalars']

In [None]:
def plot_metric(tag, accumulator=None):
    """Plot one logged TensorBoard scalar against its training step.

    Args:
        tag: Scalar tag to plot; should be one of ``accumulator.Tags()['scalars']``.
        accumulator: An already-``Reload()``-ed ``EventAccumulator`` to read from.
            Defaults to the notebook-global ``ea`` so ``plot_metric(tag)`` keeps working.

    Raises:
        Whatever ``accumulator.Scalars`` raises for an unknown tag
        (a ``KeyError`` in current tensorboard versions).
    """
    if accumulator is None:
        accumulator = ea  # global built by the log-loading cell
    events = accumulator.Scalars(tag)
    steps = [e.step for e in events]
    values = [e.value for e in events]

    # Explicit figure/axes so the curve never lands on a stale "current" figure.
    fig, ax = plt.subplots()
    ax.plot(steps, values, label=tag)
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Value')
    ax.set_title(f'{tag} over Steps')
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
# Training-loss curve for the run loaded above.
plot_metric(tag="train_total_loss")