# Extra Assignment

In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors
from sklearn.model_selection import train_test_split
import jax
import jax.numpy as jnp
import optax

In [16]:
df_full = pd.read_csv('250k_rndm_zinc_drugs_clean_3.csv')
df = df_full.iloc[:10000]
smiles = df['smiles'].tolist()

In [3]:
tokenizer = Tokenizer(models.WordLevel(unk_token='[UNK]'))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.WordLevelTrainer(vocab_size=1000,
    special_tokens=['[PAD]','[UNK]','[CLS]','[SEP]','[MASK]'])
tokenizer.train_from_iterator(smiles, trainer=trainer)
tokenizer.post_processor = processors.TemplateProcessing(
    '[CLS] $A [SEP]', special_tokens=[('[CLS]',1),('[SEP]',2)])
tokenizer.enable_truncation(max_length=64)

In [4]:
smiles0 = smiles[2100]
print(smiles0)
encoding0 = tokenizer.encode(smiles0)
print(encoding0.tokens)
print(encoding0.ids)
print(encoding0.attention_mask)

Cc1occc1C(=O)NC1CCN(C(=O)C(=O)Nc2ccc(F)cc2F)CC1

['[CLS]', 'ĠCc', '1', 'occc', '1', 'C', '(=', 'O', ')', 'NC', '1', 'CCN', '(', 'C', '(=', 'O', ')', 'C', '(=', 'O', ')', 'Nc', '2', 'ccc', '(', 'F', ')', 'cc', '2', 'F', ')', 'CC', '1', 'Ċ', '[SEP]']
[1, 32, 7, 216, 7, 6, 12, 10, 5, 29, 7, 52, 8, 6, 12, 10, 5, 6, 12, 10, 5, 35, 9, 18, 8, 26, 5, 17, 9, 26, 5, 23, 7, 13, 2]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [5]:
enc = {'input_ids':[], 'attention_mask':[]}
for e in tokenizer.encode_batch(smiles):
    enc['input_ids'].append(e.ids)
    enc['attention_mask'].append(e.attention_mask)

ids = enc['input_ids']
masks = enc['attention_mask']

ids_train, ids_test, masks_train, masks_test = train_test_split(
    ids, masks, test_size=0.1, random_state=42
)

train_enc = {'input_ids': ids_train, 'attention_mask': masks_train}
test_enc  = {'input_ids': ids_test,  'attention_mask': masks_test}

In [6]:
class SMILESDataset(Dataset):
    def __init__(self, enc):
        self.ids = enc['input_ids']
        self.mask = enc['attention_mask']
    def __len__(self):
        return len(self.ids)
    def __getitem__(self, i):
        return {
            'input_ids': torch.tensor(self.ids[i], dtype=torch.long),
            'attention_mask': torch.tensor(self.mask[i], dtype=torch.long)
        }

In [7]:
def collate_fn(batch):
    ids = [item['input_ids'] for item in batch]
    masks = [item['attention_mask'] for item in batch]
    pad_id = tokenizer.token_to_id('[PAD]')
    ids_padded = pad_sequence(ids, batch_first=True, padding_value=pad_id)
    masks_padded = pad_sequence(masks, batch_first=True, padding_value=0)
    return {'input_ids': ids_padded, 'attention_mask': masks_padded}

In [8]:
train_loader = DataLoader(
    SMILESDataset(train_enc),
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)
test_loader = DataLoader(
    SMILESDataset(test_enc),
    batch_size=32,
    collate_fn=collate_fn
)

In [9]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, mlp_dim),
                                 nn.GELU(),
                                 nn.Linear(mlp_dim, d_model))
        self.ln2 = nn.LayerNorm(d_model)
    def forward(self, x, mask):
        attn_out, _ = self.attn(x, x, x, key_padding_mask=mask==0)
        x = x + attn_out
        y = self.ln1(x)
        x = x + self.mlp(y)
        return self.ln2(x)

In [10]:
class MaskedTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=4,
                 mlp_dim=128, num_layers=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, mlp_dim)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
    def forward(self, input_ids, mask):
        x = self.embed(input_ids).transpose(0, 1)
        for layer in self.layers:
            x = layer(x, mask)
        return self.fc(x).transpose(0, 1)

In [11]:
def mask_inputs(ids, mask_token_id, mask_prob=0.15):
    labels = ids.clone()
    rand = torch.rand(ids.shape)
    mask_positions = rand < mask_prob
    labels[~mask_positions] = -100
    ids[mask_positions] = mask_token_id
    return ids, labels

In [12]:
vocab_size = tokenizer.get_vocab_size()
model = MaskedTransformer(vocab_size)
opt = optim.Adam(model.parameters(), lr=1e-4)
crit = nn.CrossEntropyLoss(ignore_index=-100)

In [24]:
for epoch in range(10):
    for batch in train_loader:
        ids, labels = mask_inputs(batch['input_ids'], tokenizer.token_to_id('[MASK]'))
        masks = batch['attention_mask']
        opt.zero_grad()
        logits = model(ids, masks)
        loss = crit(logits.reshape(-1, vocab_size), labels.reshape(-1))
        loss.backward()
        opt.step()
        print(f"\t batch, loss = {loss:.3f}")
    print(f"Epoch = {epoch}")

	 batch, loss = 3.855
	 batch, loss = 3.412
	 batch, loss = 2.942
	 batch, loss = 4.062
	 batch, loss = 3.990
	 batch, loss = 3.871
	 batch, loss = 3.676
	 batch, loss = 3.735
	 batch, loss = 3.891
	 batch, loss = 3.903
	 batch, loss = 4.046
	 batch, loss = 4.202
	 batch, loss = 3.666
	 batch, loss = 3.925
	 batch, loss = 3.720
	 batch, loss = 3.430
	 batch, loss = 3.378
	 batch, loss = 3.875
	 batch, loss = 4.021
	 batch, loss = 4.050
	 batch, loss = 3.945
	 batch, loss = 3.628
	 batch, loss = 3.673
	 batch, loss = 3.596
	 batch, loss = 3.416
	 batch, loss = 3.476
	 batch, loss = 3.767
	 batch, loss = 3.784
	 batch, loss = 3.620
	 batch, loss = 3.937
	 batch, loss = 3.806
	 batch, loss = 3.641
	 batch, loss = 3.839
	 batch, loss = 3.933
	 batch, loss = 3.515
	 batch, loss = 4.039
	 batch, loss = 3.707
	 batch, loss = 3.296
	 batch, loss = 4.052
	 batch, loss = 3.856
	 batch, loss = 3.506
	 batch, loss = 3.747
	 batch, loss = 3.718
	 batch, loss = 3.323
	 batch, loss = 3.359
	 batch, l

### Mini Example

In [None]:
smiles_orig = df_full.iloc[10001]['smiles']
print("original:",smiles_orig)

encoding = tokenizer.encode(smiles_orig)
input_ids = torch.tensor(encoding.ids).unsqueeze(0)           # shape [1, L]
attention_mask = torch.tensor(encoding.attention_mask).unsqueeze(0)
mask_id = tokenizer.token_to_id('[MASK]')
ids_masked, _ = mask_inputs(input_ids.clone(), mask_id, mask_prob=0.15)
model.eval()
with torch.no_grad():
    logits = model(ids_masked, attention_mask)               # [1, L, V]
    preds = logits.argmax(dim=-1).squeeze(0).tolist()         # [L]

original: CC(C)c1cc(C(=O)Nc2ccc(C[NH+]3CCCC3)cc2)n[nH]1



In [None]:
tokens_pred = [tokenizer.id_to_token(i) for i in preds]
# strip off [CLS]/[SEP] and join
smiles_pred = ''.join(tok for tok in tokens_pred if tok not in ('[CLS]','[SEP]','[PAD]'))

print("Original: ", smiles_orig)
print("Masked  : ", ''.join(
    tokenizer.id_to_token(i) if i!=mask_id else '[MASK]'
    for i in ids_masked.squeeze(0).tolist()))
print("Predicted:", smiles_pred)

Original:  CC(C)c1cc(C(=O)Nc2ccc(C[NH+]3CCCC3)cc2)n[nH]1

Masked  :  [UNK]ĠCC(C)c1cc(C(=[MASK][MASK][MASK]2ccc(C[[MASK]+]3[MASK]3)cc2)n[[MASK]]1Ċ[CLS]
Predicted: COcCCOc+])(COcCĠCCOCcC(


In [None]:
torch.save(model.state_dict(), 'masked_transformer_weights.pth')

model = MaskedTransformer(vocab_size)
model.load_state_dict(torch.load('masked_transformer_weights.pth'))
model.eval()