In [None]:
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import rdkit
import numpy as np
import pandas as pd
import torch
from torch import nn
from pathlib import Path

## Get molecular transformer embeddings

In [None]:
## Get transformer embeddings
# Pre-computed molecular-transformer embeddings and the CSV tables they belong to.
# Embedding i of each file is paired with row i of the matching CSV
# (assumes both were written in the same order -- TODO confirm against the
# script that produced the .dat files).
train_emb = torch.load("data/train_sm.dat")
val_emb = torch.load("data/val_sm.dat")
train_data = pd.read_csv('data/train.csv')
val_data = pd.read_csv('data/val.csv')

# Lookup table: SMILES string (the "Drug" column) -> its embedding.
mol_to_emb = {}

sets = [(train_emb, train_data["Drug"]), (val_emb, val_data["Drug"])]
for embs, datas in sets:
    # zip() stops at the shorter sequence, so a size mismatch would silently
    # drop molecules instead of failing; check the lengths explicitly.
    if len(embs) != len(datas):
        raise ValueError(
            f"Embedding count ({len(embs)}) does not match the number of "
            f"molecules in the table ({len(datas)})"
        )
    for emb, drug in zip(embs, datas):
        # A SMILES that appears more than once keeps the last embedding seen.
        mol_to_emb[drug] = emb

## Preprocess Data

In [None]:
import numpy as np
from rdkit import Chem
from rdkit.Chem import DataStructs
from rdkit.Chem import rdMolDescriptors as rdmd


def smiles_to_mols(smiles):
    """Parse each SMILES string in ``smiles`` with ``Chem.MolFromSmiles``.

    Args:
        smiles: Iterable of SMILES strings.

    Returns:
        A list with one parse result per input string, in input order.
    """
    parsed = []
    for smiles_string in smiles:
        parsed.append(Chem.MolFromSmiles(smiles_string))
    return parsed

def get_representations(mols, smiles):
    """Build one feature row per molecule.

    Each row is ``[exact molecular weight, Crippen logP, *embedding]``, where
    the embedding is looked up in the module-level ``mol_to_emb`` dict by
    SMILES string.

    Args:
        mols: RDKit molecules, aligned one-to-one with ``smiles``.
        smiles: SMILES strings used as keys into ``mol_to_emb``.

    Returns:
        2-D ``numpy.ndarray`` with one row per molecule.

    Raises:
        ValueError: If a molecule is ``None`` (its SMILES could not be turned
            into a molecule). ``np.stack`` also raises ``ValueError`` when no
            molecules are supplied.
        KeyError: If a SMILES string has no entry in ``mol_to_emb``.
    """
    descriptors = []
    for mol, smile in zip(mols, smiles):
        # Fail early with the offending SMILES instead of an opaque error
        # from inside the RDKit descriptor call.
        if mol is None:
            raise ValueError(f"Could not build a molecule for SMILES {smile!r}")
        try:
            fp = mol_to_emb[smile]
        except KeyError:
            raise KeyError(f"No transformer embedding stored for SMILES {smile!r}") from None
        # Embeddings loaded with torch.load may be tensors that live on a GPU
        # or track gradients; NumPy cannot consume those directly.
        if hasattr(fp, "detach"):
            fp = fp.detach().cpu().numpy()
        molecular_weight = rdmd.CalcExactMolWt(mol)
        logp = rdmd.CalcCrippenDescriptors(mol)[0]
        arr = np.array([molecular_weight, logp])
        descriptors.append(np.concatenate((arr, fp)))
    return np.stack(descriptors, axis=0)


def get_reps_from_smiles(smiles):
    """Turn SMILES strings into the descriptor matrix used as model input.

    Parses the strings with ``smiles_to_mols`` and hands the molecules,
    together with the original strings, to ``get_representations``.
    """
    return get_representations(smiles_to_mols(smiles), smiles)


# Feature matrix and target vector for each split: features come from the
# "Drug" (SMILES) column, targets from the "Y" column.
X_train, y_train = get_reps_from_smiles(train_data["Drug"]), train_data["Y"].values
X_val, y_val = get_reps_from_smiles(val_data["Drug"]), val_data["Y"].values

In [None]:
import os
from pathlib import Path

def save_data(path, x, y):
    """Featurise ``x`` and save features and targets as one float tensor.

    The saved tensor has one row per molecule: the feature columns produced by
    ``get_reps_from_smiles`` followed by the target value as the last column.

    Args:
        path: Destination file (``pathlib.Path`` or ``str``). Missing parent
            directories are created.
        x: SMILES strings, passed to ``get_reps_from_smiles``.
        y: Target values, one per entry of ``x``.
    """
    path = Path(path)
    features = torch.tensor(get_reps_from_smiles(x), dtype=torch.float)
    # Make the targets a column vector so they can be appended to the features.
    targets = torch.tensor(y, dtype=torch.float).unsqueeze(1)
    together = torch.cat([features, targets], dim=1)
    # parents=True handles nested output directories; exist_ok=True avoids the
    # check-then-create race of a separate is_dir() test.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(together, path)

# Persist both splits under data/<dir_name>/ (validation first, then training).
dir_name = f"trnsfm"
val_path = Path(f"data/{dir_name}/val.dat")
train_path = Path(f"data/{dir_name}/train.dat")
save_data(val_path, val_data["Drug"], val_data["Y"].values)
save_data(train_path, train_data["Drug"], train_data["Y"].values)

## Run Hyperparameter Search

In [None]:
from experiment import search
# Launch the hyperparameter search from the project's `experiment` module.
# The meaning of each argument is defined by `experiment.search` (not shown here).
search(
    num_samples=10,
    max_num_epochs=100,
    gpus_per_trial=1,
    name="mol_prop_pred",
)

## Train Model

In [None]:
# Settings for a single training run; how each key is interpreted is defined
# by `experiment.experiment` (not shown in this notebook).
config = {
    # Run identity
    "run_name": "linear",
    # Optimisation
    "init_learning_rate": 0.1,
    "lr_step_interval": 35,
    "n_epochs": 500,
    # Logging / validation / checkpoint cadence
    "epoch_log_interval": 1,
    "batch_log_interval": float("inf"),
    "val_interval": 1,
    "save_interval": 100,
    # Data: "fp_type" matches the data/trnsfm directory written earlier
    "batch_size": 32,
    "fp_type": "trnsfm",
    # Architecture: same values are passed to PermPredictor at evaluation time
    "dim_seq": (256, 64),
    "dropout_pair": (0, 0.5),
}

from experiment import experiment
experiment(config, checkpoint_dir="checkpoints", data_dir="data")

## Evaluate Model on the Validation Set

In [None]:
from core.model import PermPredictor
from core.dataset import MolData


# Architecture arguments must match the ones the checkpoint was trained with
# (the tuples equal "dim_seq" / "dropout_pair" in the training config).
# 514 input features = 2 RDKit descriptors + the transformer embedding
# (presumably 512-dimensional -- confirm against the embedding files).
model = PermPredictor(514, (256, 64), (0, .5))
# Inference below runs on the CPU (outputs are converted with .numpy()), so
# remap tensors that may have been saved from a GPU.
checkpoint = torch.load("checkpoints/best.chkp", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

# Load the whole validation split as a single batch.
dataset = MolData("data/trnsfm/val.dat")
dataloader = torch.utils.data.DataLoader(dataset, shuffle = False, batch_size = len(dataset))
inputs, targets = next(iter(dataloader))

model.eval()
with torch.no_grad():
    outputs = model(inputs).numpy()

# Coefficient of determination, R^2 = 1 - SS_res / SS_tot, computed directly:
# the original call used `metrics.r2_score`, but `metrics` is never imported
# in this notebook. Both arrays are flattened, which assumes a single
# regression target per molecule.
y_true = np.asarray(targets, dtype=np.float64).reshape(-1)
y_pred = np.asarray(outputs, dtype=np.float64).reshape(-1)
ss_res = float(np.sum((y_true - y_pred) ** 2))
ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
# R^2 is undefined when every target is identical (zero variance).
r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

print(f"Validation R2: {r2}")