This implementation of the transformer in this model is based on the tutorial given at https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

In [3]:
import pandas as pd
from rdkit import Chem
import os
from deepchem.feat.smiles_tokenizer import SmilesTokenizer
import torch
import numpy as np
from itertools import zip_longest
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import MSELoss
from sklearn.model_selection import train_test_split
from torchmetrics import Accuracy
from torchmetrics.regression import MeanSquaredError
from deepchem.feat import HuggingFaceFeaturizer
from transformers import RobertaTokenizerFast

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
  from .autonotebook import tqdm as notebook_tqdm


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'MXMNet' from 'deepchem.models.torch_models' (/Users/gihan/.env/p11/lib/python3.11/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [4]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        if mask is not None:
            # mask = expand_mask(mask)
            mask = mask.unsqueeze(1).unsqueeze(2).to(torch.bool)

        


        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [batch_size, num_heads, seq_length, heads]
        q, k, v = qkv.chunk(3, dim=-1)
        
        d_k = q.size()[-1]
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)


        attn_logits=attn_logits.masked_fill(mask, float('-inf') )
        
        attention = F.softmax(attn_logits, dim=-1)
        values = torch.matmul(attention, v)
        
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, self.embed_dim)


        o = self.o_proj(values)

    
        if return_attention:
            return o, attention
        else:
            return o

class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)

        return x

class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for l in self.layers:
            x = l(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

In [5]:
class TransformerModel(nn.Module):

    def __init__(self, input_dim, model_dim, num_classes, num_heads, num_layers, dropout=0.0, input_dropout=0.0):

        super().__init__()

        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.input_dropout = input_dropout
        self.dropout = dropout


        self.input_net = torch.nn.Embedding(input_dim, self.model_dim, 0)
        
        # self.input_net = nn.Sequential(
        #     nn.Dropout(self.input_dropout),
        #     nn.Linear(self.input_dim, self.model_dim)
        # )
        # Positional encoding for sequences
        self.positional_encoding = PositionalEncoding(d_model=self.model_dim)
        # Transformer
        self.transformer = TransformerEncoder(num_layers=self.num_layers,
                                              input_dim=self.model_dim,
                                              dim_feedforward=2*self.model_dim,
                                              num_heads=self.num_heads,
                                              dropout=self.dropout)
        # Output classifier per sequence lement


        self.dense = nn.Linear(self.model_dim, 128)
        self.activation_fn = nn.ReLU()
        self.dropout = nn.Dropout(p=.2)
        self.out_proj = nn.Linear(128, num_classes)
        
        # self.output_net = nn.Sequential(
        #     nn.Linear(self.model_dim, self.model_dim),
        #     nn.LayerNorm(self.model_dim),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(self.dropout),
        #     nn.Linear(self.model_dim, self.num_classes)
        # )

    def forward(self, x, mask=None, add_positional_encoding=True):


        mask = x.clone().detach()
        mask = mask.eq(0)
        
        x = self.input_net(x)

        x = x * (1 - mask.unsqueeze(-1).type_as(x))

        x = self.transformer(x, mask=mask)

        x = x[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        
        return x

In [10]:
def get_tokens(df):

    smiles = df.smiles.values
    labels = df.target.values

    MAX_LEN = 200
    a = featurizer.featurize(smiles)
    aa = [i['input_ids'][:MAX_LEN] + (MAX_LEN - len(i['input_ids']) )  * [featurizer.tokenizer.pad_token_id] for i in a]
    input_ids = torch.tensor(aa, dtype=torch.int)


    return input_ids, labels

class TextDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, tokens, labels):

        self.tokens = tokens
        self.labels = torch.tensor(labels, dtype=torch.float)
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):

        
        sample = {'tokens': self.tokens[idx], 'labels': self.labels[idx] }
        return sample

#### Get data

In [12]:
df = pd.read_csv('../data/esol/delaney-processed.csv')
df['target'] = df['measured log solubility in mols per litre']
train_df, val_df = train_test_split(df, test_size=.2)
test_df, val_df = train_test_split(val_df, test_size=.5)

#### Create dataloaders

In [None]:

hf_tokenizer = RobertaTokenizerFast.from_pretrained("seyonec/PubChem10M_SMILES_BPE_60k")
featurizer = HuggingFaceFeaturizer(tokenizer=hf_tokenizer)

train_ids, train_labels = get_tokens(train_df)
val_ids, val_labels = get_tokens(val_df)
test_ids, test_labels = get_tokens(test_df)

train_dataset = TextDataset(train_ids, train_labels)
val_dataset = TextDataset(val_ids, val_labels)
test_dataset = TextDataset(test_ids, test_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=False)



#### Create the model

In [None]:

device = torch.device('cpu')
model = TransformerModel(input_dim=featurizer.tokenizer.vocab_size, model_dim=128, num_heads=2, num_classes=1, num_layers=4)
model.to(device);
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=10**-5)

criterion = nn.MSELoss()
accuracy = MeanSquaredError()


def train_fn(train_loader, model, device, optimizer):
    model.train();
    total_loss = total_examples = 0
    for data in train_loader:
        enc = data['tokens']
        y = data['labels']
        enc = enc.to(device)
        y = y.to(device)
        optimizer.zero_grad()

        out = model(enc)
        loss = criterion(out, y)
        
        loss.backward()
        optimizer.step()
        total_loss += float(loss) 
        total_examples += len(data['labels'])
    return total_loss / total_examples



def valid_fn(loader, model, device):
    accs = 0
    n=0
    with torch.no_grad():
        model.eval();
        
        for data in loader:
            enc = data['tokens']
            y = data['labels']
            enc = enc.to(device)
            y = y.to(device)
            out = model(enc)
            
            # predictions = torch.argmax(out, dim=1)
            acc = accuracy(out.reshape(-1,), y)

            accs+=acc
            n+=len(y)

    # return float(torch.cat(mse, dim=0).mean().sqrt())
    return accs/n

#### Train

In [16]:
for e in range(20):
    train_fn(train_loader, model, device, optimizer)
    mse = valid_fn(val_loader, model, device)

    print(mse.item())

  return F.mse_loss(input, target, reduction=self.reduction)


0.07029614597558975
0.06723970174789429
0.06812139600515366
0.06751897186040878
0.0692100003361702
0.06783988326787949
0.06701565533876419
0.06736886501312256
0.07002850621938705
0.06947501748800278
0.06846916675567627
0.06748075038194656
0.0676903948187828
0.06730744242668152
0.06754076480865479
0.06798145920038223
0.06760700792074203
0.0668359026312828
0.06742312014102936
0.06717287003993988
