In [1]:
%run utils.py
%run model.py

In [2]:
# Path of the SMILES input file handed to get_smi_list() in the data cell.
SMILES_PATH = "data/ADAGRASIB_SMILES.txt"
# Early stopping: training halts once the validation loss has failed to
# improve for MORE than this many consecutive epochs.
PATIENCE_THRESHOLD = 4

In [3]:
import rdkit 
from rdkit.Chem import MolFromSmiles as get_mol
from rdkit.Chem import rdDistGeom
import numpy as np 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader, random_split
import multiprocessing
from tqdm.auto import tqdm 
import matplotlib.pyplot as plt

In [4]:
# --- Coordinates: load SMILES, compute atom positions in parallel, then
# normalise each molecule and pad every one to the longest molecule.
smi_list = get_smi_list(SMILES_PATH)
coor_list = parallel_f(get_atom_pos, smi_list)
longest_coor = max(map(len, coor_list))
coor_list = [pad(normalize(coords), longest_coor) for coords in coor_list]

# --- Tokens: build a vocabulary from the corpus, integer-encode every SMILES
# string, then pad all encodings to the longest one.
# NOTE(review): the CUDA assertion `srcIndex < srcSelectDimSize` raised during
# training is the usual symptom of an embedding/index_select lookup whose index
# is >= the table size. Confirm that every id produced here, including the
# padding id inserted by pad_smi, is smaller than the vocabulary size Model
# derives from smi_dic.
smi_dic = get_dic(smi_list)
smint_list = [encode_smi(smi, smi_dic) for smi in smi_list]
longest_smint = max(map(len, smint_list))
smint_list = [pad_smi(tokens, longest_smint, smi_dic) for tokens in smint_list]

[23:56:05] UFFTYPER: Unrecognized atom type: Ba (0)


In [5]:
BATCH_SIZE = 128
# Fixed seed so the train/val/test split is identical after every kernel restart.
SPLIT_SEED = 42

dataset = MyDataset(smint_list, coor_list)
train_set, val_set, test_set = random_split(
    dataset,
    [0.9, 0.05, 0.05],
    generator = torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader is shuffled. The evaluation loaders keep a fixed
# order: the reported loss is a mean of per-batch means, so the last (partial)
# batch is weighted differently. Reshuffling every epoch would therefore make
# the validation loss, and hence early stopping, jitter for no real reason.
train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, shuffle = False)
test_loader = DataLoader(test_set, batch_size = BATCH_SIZE, shuffle = False)

In [7]:
def _plot_losses(train_hist, val_hist, title) :
    """Plot per-epoch train/validation loss curves against 1-based epoch numbers."""
    # One x value per recorded epoch. The previous np.linspace(0, num_epoch, epoch)
    # stretched the completed epochs over the whole planned run, which mislabelled
    # the x-axis.
    epochs = np.arange(1, len(train_hist) + 1)
    plt.plot(epochs, train_hist, color = 'blue', label = 'Train Loss')
    plt.plot(epochs, val_hist, color = 'red', label = 'Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Mean L1 loss')
    plt.title(title)
    plt.legend()
    plt.show()


def train(model, num_epoch, lr, train_dl = None, val_dl = None,
          patience_threshold = None, plot_every = 5) :
    """Train ``model`` with Adam and an L1 loss, validating after every epoch.

    The model's forward pass must return a 3-tuple; only its first element
    (the prediction) is compared against the target.

    Parameters
    ----------
    model : nn.Module
        Model to optimise in place.
    num_epoch : int
        Maximum number of epochs to run.
    lr : float
        Adam learning rate.
    train_dl, val_dl : DataLoader, optional
        Loaders yielding ``(input, target)`` batches. They default to the
        notebook-level ``train_loader`` / ``val_loader``.
    patience_threshold : int, optional
        Stop once the validation loss has not improved for more than this many
        consecutive epochs. Defaults to the notebook-level ``PATIENCE_THRESHOLD``.
    plot_every : int
        Draw the loss curves every ``plot_every`` epochs.
    """
    # Globals are resolved at call time, so later re-definitions are picked up.
    train_dl = train_loader if train_dl is None else train_dl
    val_dl = val_loader if val_dl is None else val_dl
    if patience_threshold is None :
        patience_threshold = PATIENCE_THRESHOLD

    # Move every batch to the device that holds the model's parameters.
    # This is a no-op if the dataset already yields tensors on that device.
    device = next(model.parameters()).device
    optim = torch.optim.Adam(model.parameters(), lr = lr)
    loss_fn = nn.L1Loss()
    train_plot, val_plot = [], []
    best_val = float('inf')
    patience = 0

    for epoch in tqdm(range(1, num_epoch + 1), total = num_epoch) :
        # TRAIN
        model.train()
        train_loss = 0.0
        for inputs, targets in train_dl :
            inputs, targets = inputs.to(device), targets.to(device)
            optim.zero_grad()
            prediction, _, _ = model(inputs)
            loss = loss_fn(prediction, targets)
            loss.backward()
            optim.step()
            train_loss += loss.item()

        # VALIDATE
        model.eval()
        val_loss = 0.0
        with torch.no_grad() :
            for inputs, targets in val_dl :
                inputs, targets = inputs.to(device), targets.to(device)
                prediction, _, _ = model(inputs)
                val_loss += loss_fn(prediction, targets).item()

        # Average of the per-batch mean losses.
        train_loss /= len(train_dl)
        val_loss /= len(val_dl)
        # Record the losses before the stopping check so the final plot includes this epoch.
        train_plot.append(train_loss)
        val_plot.append(val_loss)

        # tqdm.write keeps the progress bar intact, whereas print would break it.
        tqdm.write(f'Train loss: {train_loss:.4f} --- Validate loss: {val_loss:.4f}')

        # EARLY STOPPING
        if val_loss < best_val :
            best_val = val_loss
            patience = 0
        else :
            patience += 1

        if patience > patience_threshold :
            tqdm.write("EARLY STOPPING !!!")
            _plot_losses(train_plot, val_plot, f'Early stopping at epoch {epoch}')
            break

        if epoch % plot_every == 0 :
            _plot_losses(train_plot, val_plot, f'Epoch {epoch}')

In [8]:
# Seed torch's global RNG so that weight initialisation, and random draws made
# afterwards such as the train-loader shuffle order, are reproducible across
# kernel restarts.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

# NOTE(review): the positional hyper-parameters (128, 1, 4, 0.5) are passed
# straight to Model from model.py. Confirm what each one means there and
# consider promoting them to named constants in the config cell.
model = Model(128, 1, 4, 0.5, longest_coor, smi_dic).to(device)

In [9]:
# Run up to 100 epochs of Adam at lr = 1e-3; train() may stop early.
train(model, num_epoch = 100, lr = 1e-3)

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([128, 36, 128])


../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [179,0,0], 

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 128 n 4608 k 128 mat1_ld 128 mat2_ld 128 result_ld 128 abcType 0 computeType 68 scaleType 0