In [1]:
import pandas as pd
import torch
import pytorch_lightning as pl
import numpy as np

from torch.utils.data import DataLoader, Dataset

from rxitect.models.pchembl_val_predictor import PChEMBLValueRegressor, get_tokens, identity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df_train = pd.read_csv("../data/processed/ligand_CHEMBL240_train_splityear=2015.csv")
df_test = pd.read_csv("../data/processed/ligand_CHEMBL240_test_splityear=2015.csv")

# df_full = pd.concat([df_train, df_test])
smiles = df_train.smiles
labels = df_train.pchembl_value.astype('float32')

In [3]:
tokens, _, _ = get_tokens(smiles)
tokens = ''.join(tokens) + ' '

In [4]:
reg = PChEMBLValueRegressor(tokens, 0.005)

In [5]:
reg

PChEMBLValueRegressor(
  (embedding): Embedding(
    (embedding): Embedding(42, 128, padding_idx=41)
  )
  (encoder): LSTMEncoder(
    (rnn): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.8)
  )
  (mlp): MLP(
    (input_layer): Linear(in_features=128, out_features=128, bias=True)
    (out_layer): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [6]:
def seq2tensor(seqs, tokens, flip=True):
    tensor = np.zeros((len(seqs), len(seqs[0])))
    for i in range(len(seqs)):
        for j in range(len(seqs[i])):
            if seqs[i][j] in tokens:
                tensor[i, j] = tokens.index(seqs[i][j])
            else:
                tokens = tokens + seqs[i][j]
                tensor[i, j] = tokens.index(seqs[i][j])
    if flip:
        tensor = np.flip(tensor, axis=1).copy()
    return tensor, tokens


def pad_sequences(seqs, max_length=None, pad_symbol=' '):
    if max_length is None:
        max_length = -1
        for seq in seqs:
            max_length = max(max_length, len(seq))
    lengths = []
    for i in range(len(seqs)):
        cur_len = len(seqs[i])
        lengths.append(cur_len)
        seqs[i] = seqs[i] + pad_symbol * (max_length - cur_len)
    return seqs, lengths


def process_smiles(smiles,
                   sanitized=True,
                   target=None,
                   augment=False,
                   pad=True,
                   tokenize=True,
                   tokens=None,
                   flip=False,
                   allowed_tokens=None):
    if not sanitized:
        # clean_smiles, clean_idx = sanitize_smiles(smiles, allowed_tokens=allowed_tokens)
        # clean_smiles = [clean_smiles[i] for i in clean_idx]
        # if target is not None:
        #     target = target[clean_idx]
        pass
    else:
        clean_smiles = smiles

    length = None
    if augment and target is not None:
        # clean_smiles, target = augment_smiles(clean_smiles, target)
        pass
    if pad:
        clean_smiles, length = pad_sequences(clean_smiles)
    tokens, token2idx, num_tokens = get_tokens(clean_smiles, tokens)
    if tokenize:
        clean_smiles, tokens = seq2tensor(clean_smiles, tokens, flip)

    return clean_smiles, target, length, tokens, token2idx, num_tokens

In [7]:
class SmilesDataset(Dataset):
    def __init__(self, smiles, labels, tokens=None, tokenize=False, sanitized=True, return_smiles=False):
        super(SmilesDataset, self).__init__()
        self.tokenize = tokenize
        self.return_smiles = return_smiles
        self.data, self.target, self.length, self.tokens, self.token2idx, self.num_tokens = process_smiles(
            smiles, sanitized=True, target=labels, augment=False, pad=True,
            tokenize=True, tokens=tokens, flip=True)

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        sample = {}
        if self.return_smiles:
            sample['object'] = np.array([ord(self.tokens[int(i)]) for i in self.data[index]])
        sample['tokenized_smiles'] = self.data[index]
        sample['length'] = self.length[index]
        if self.target is not None:
            sample['labels'] = self.target[index]
        return sample

In [8]:
train_data = SmilesDataset(smiles[:10], labels[:10])

In [9]:
train_data.__dict__['data'] = torch.from_numpy(train_data.__dict__['data'])

In [10]:
train_data.__dict__['target'] = torch.from_numpy(np.array(train_data.__dict__['target']))

In [11]:
train_loader = DataLoader(train_data,
                         batch_size=128,
                         shuffle=True,
                         num_workers=4,
                         pin_memory=True,
                         sampler=None)

In [12]:
trainer = pl.Trainer(max_epochs=1, accelerator='gpu', devices=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
trainer.fit(reg, train_dataloaders=train_loader)

Missing logger folder: /home/catj/thesis/Rxitect/notebooks/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type        | Params
------------------------------------------
0 | embedding | Embedding   | 5.4 K 
1 | encoder   | LSTMEncoder | 264 K 
2 | mlp       | MLP         | 16.6 K
------------------------------------------
286 K     Trainable params
0         Non-trainable params
286 K     Total params
1.145     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|                                                                                                                                              | 0/1 [00:00<?, ?it/s]INP LSTM: [tensor([[ 0.0793,  0.2139,  1.5030,  ..., -0.1169, -0.7543, -0.2788],
        [ 0.0793,  0.2139,  1.5030,  ..., -0.1169, -0.7543, -0.2788],
        [ 0.0793,  0.2139,  1.5030,  ..., -0.1169, -0.7543, -0.2788],
        ...,
        [-0.0879, -0.3917,  0.8128,  ...,  0.7020, -1.1787, -0.4197],
        [ 0.4434, -2.0011, -0.1628,  ..., -1.7191, -0.6452, -0.2746],
        [-1.8020, -0.0064,  0.8562,  ...,  0.7998, -0.8769, -0.6884]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>), tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  7., 14.,  3., 15.,  2., 21., 16., 13., 14., 20., 14., 14.,  7.,
        17.,  3., 18., 12.,  2., 14.,  3., 17.,  2., 21., 16., 13., 14., 20.,
         3.,  7., 22., 22.,  3.,  8., 22., 22.,  3., 15.,  2., 22., 22., 22.,
         8

RuntimeError: Length of all samples has to be greater than 0, but found an element in 'lengths' that is <= 0