# Setting up Environment

In [1]:
from dataset import ChemBL35Dataset
from tokenizer import SMILESTokenizer

import torch
from torch.utils.data import DataLoader, random_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

smiles_file = "chembl_35.smi"
tokenizer_dir = "trained_tokenizer"

In [3]:
tokenizer = SMILESTokenizer.from_pretrained(tokenizer_dir)
if tokenizer.mask_token is None:
	tokenizer.add_special_tokens({"mask_token": "<mask>"})

vocab_size = tokenizer.vocab_size

ds = ChemBL35Dataset(smiles_file, tokenizer, max_length=256, noise_prob=0.15)
train_size = int(0.9 * len(ds))
val_size = len(ds) - train_size
train_ds, val_ds = random_split(ds, [train_size, val_size])

train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=10)
val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=10)

# Import

In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange, repeat
import xformers.ops as xops
import numpy as np

from typing import Optional, Tuple

from models.utils import DyT, FeedForward

# Checking some number

In [5]:
vocab_size, torch.log1p(torch.tensor(vocab_size))

(4096, tensor(8.3180))

In [6]:
first_batch = next(iter(train_dl))
first_batch["input_ids"].shape

torch.Size([1, 256])

# Positional Encoding

In [7]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
	freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
	t = torch.arange(end, device=freqs.device)  # type: ignore
	freqs = torch.outer(t, freqs).float()  # type: ignore
	freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

	return freqs_cis

In [8]:
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
	ndim = x.ndim

	assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
		f"freqs_cis shape {freqs_cis.shape} needs to be {(x.shape[1], x.shape[-1])}"
	)

	shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]

	return freqs_cis.view(*shape)

In [9]:
def apply_rotary_emb(
	xq: torch.Tensor,
	xk: torch.Tensor,
	freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
	xq_complex = torch.view_as_complex(
		rearrange(xq.float(), "... (n two) -> ... n two", two=2)
	)
	xk_complex = torch.view_as_complex(
		rearrange(xk.float(), "... (n two) -> ... n two", two=2)
	)

	freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)

	xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3)
	xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3)

	return xq_out.type_as(xq), xk_out.type_as(xk)

# Transformer Block

In [10]:
class MultiHeadAttention(nn.Module):
	def __init__(self, d_model, n_heads, max_seq_len, dropout=0.1):
		super(MultiHeadAttention, self).__init__()
		assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
		
		self.d_model = d_model
		self.n_heads = n_heads
		self.d_k = d_model // n_heads
		
		self.q_proj = nn.Linear(d_model, d_model)
		self.k_proj = nn.Linear(d_model, d_model)
		self.v_proj = nn.Linear(d_model, d_model)
		self.out_proj = nn.Linear(d_model, d_model)

		self.p = dropout
		# self.dropout = nn.Dropout(dropout)

		nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / (2 ** 0.5))
		nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / (2 ** 0.5))
		nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / (2 ** 0.5))
		nn.init.xavier_uniform_(self.out_proj.weight)
		nn.init.zeros_(self.out_proj.bias)
			
	def scaled_dot_product_attention(
		self, 
  	Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
		mask: Optional[torch.Tensor] = None, is_causal: bool = False
	):
		# attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
		# if mask is not None:
		# 	attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

		# attn_probs = torch.softmax(attn_scores, dim=-1)
		# output = torch.matmul(attn_probs, V)
		# output = self.dropout(output)

		output = xops.memory_efficient_attention(
			Q, K, V, 
			p=self.p,
			attn_bias=None if not is_causal else xops.LowerTriangularMask(),
		)

		return output
			
	def split_heads(self, x):
		# return rearrange(x, 'b s (h d) -> b h s d', h=self.n_heads)
		return rearrange(x, 'b s (h d) -> b s h d', h=self.n_heads)
			
	def combine_heads(self, x):
		# return rearrange(x, 'b h s d -> b s (h d)', h=self.n_heads)
		return rearrange(x, 'b s h d -> b s (h d)', h=self.n_heads)
			
	def forward(
		self, 
  	Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, freqs_cis: torch.Tensor,
		mask: Optional[torch.Tensor] = None, is_causal: bool = False
	):
		Q = self.split_heads(self.q_proj(Q))
		K = self.split_heads(self.k_proj(K))
		V = self.split_heads(self.v_proj(V))

		Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)
		
		attn_output = self.scaled_dot_product_attention(Q, K, V, mask, is_causal=is_causal)
		output = self.out_proj(self.combine_heads(attn_output))

		return output

In [11]:
class EncoderLayer(nn.Module):
	def __init__(
		self, d_model: int, n_heads: int, d_ff: int = 3072,
		dropout: float = 0.2, max_seq_len: int = 1024, use_layerscale: bool = True,
		norm_layer=nn.LayerNorm,
	):
		super().__init__()
		self.self_attn = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
		self.self_attn_norm = norm_layer(d_model)
		self.self_attn_dropout = nn.Dropout(dropout)
		self.attn_layer_scale = nn.Parameter(torch.ones(d_model) * 1e-4) if use_layerscale else None

		self.ff_norm = norm_layer(d_model)
		self.ff = FeedForward(d_model, d_ff, dropout)
		self.ff_dropout = nn.Dropout(dropout)
		self.ff_layer_scale = nn.Parameter(torch.ones(d_model) * 1e-4) if use_layerscale else None
			
	def forward(self, src: torch.Tensor, freqs_cis: torch.Tensor, src_mask: Optional[torch.Tensor] = None):
		norm_src = self.self_attn_norm(src)
		attn_out = self.self_attn(norm_src, norm_src, norm_src, freqs_cis, src_mask)
		attn_out = self.self_attn_dropout(attn_out)
		if self.attn_layer_scale is not None:
			src = src + self.attn_layer_scale * attn_out
		else:
			src = src + attn_out

		norm_src = self.ff_norm(src)
		ff_out = self.ff(norm_src)
		ff_out = self.ff_dropout(ff_out)
		if self.ff_layer_scale is not None:
			src = src + self.ff_layer_scale * ff_out
		else:
			src = src + ff_out

		return src

In [12]:
class DecoderLayer(nn.Module):
	def __init__(
		self, d_model: int, n_heads: int, d_ff: int = 3072,
		dropout: float = 0.2, max_seq_len: int = 1024, use_layerscale: bool = True,
		norm_layer=nn.LayerNorm,
	):
		super().__init__()

		self.self_attn = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
		self.self_attn_norm = norm_layer(d_model)
		self.self_attn_dropout = nn.Dropout(dropout)
		self.self_attn_layer_scale = nn.Parameter(torch.ones(d_model) * 1e-4) if use_layerscale else None

		self.cross_attn = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
		self.cross_attn_norm = norm_layer(d_model)
		self.cross_attn_dropout = nn.Dropout(dropout)
		self.cross_attn_layer_scale = nn.Parameter(torch.ones(d_model) * 1e-4) if use_layerscale else None

		self.ff_norm = norm_layer(d_model)
		self.ff = FeedForward(d_model, d_ff, dropout)
		self.ff_dropout = nn.Dropout(dropout)
		self.ff_layer_scale = nn.Parameter(torch.ones(d_model) * 1e-4) if use_layerscale else None

		self.dropout = nn.Dropout(dropout)

	def forward(
		self, tgt: torch.Tensor, memory: torch.Tensor, freqs_cis: torch.Tensor,
		tgt_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None,
	):
		norm_tgt = self.self_attn_norm(tgt)
		self_attn_out = self.self_attn(norm_tgt, norm_tgt, norm_tgt, freqs_cis, tgt_mask, is_causal=True)
		self_attn_out = self.self_attn_dropout(self_attn_out)
		if self.self_attn_layer_scale is not None:
			tgt = tgt + self.self_attn_layer_scale * self_attn_out
		else:
			tgt = tgt + self_attn_out

		norm_tgt = self.cross_attn_norm(tgt)
		cross_attn_out = self.cross_attn(norm_tgt, memory, memory, freqs_cis, memory_mask)
		cross_attn_out = self.cross_attn_dropout(cross_attn_out)
		if self.cross_attn_layer_scale is not None:
			tgt = tgt + self.cross_attn_layer_scale * cross_attn_out
		else:
			tgt = tgt + cross_attn_out

		norm_tgt = self.ff_norm(tgt)
		ff_out = self.ff(norm_tgt)
		ff_out = self.ff_dropout(ff_out)
		if self.ff_layer_scale is not None:
			tgt = tgt + self.ff_layer_scale * ff_out
		else:
			tgt = tgt + ff_out

		return tgt

In [13]:
class Transformer(nn.Module):
	def __init__(
		self,
		src_vocab_size, tgt_vocab_size, 
		d_model, n_heads, n_layers, d_ff, max_seq_len, dropout,
		norm_layer=nn.LayerNorm,
	):
		super(Transformer, self).__init__()
		self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
		self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)

		self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout, norm_layer=norm_layer) for _ in range(n_layers)])
		self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout, norm_layer=norm_layer) for _ in range(n_layers)])

		self.fc = nn.Linear(d_model, tgt_vocab_size)
		self.dropout = nn.Dropout(dropout)

		self.freqs_cis = precompute_freqs_cis(d_model // n_heads, max_seq_len * 2)

	def forward(self, src, tgt):
		src_embedded = self.dropout(self.encoder_embedding(src))
		tgt_embedded = self.dropout(self.decoder_embedding(tgt))

		_, seq_len = src.shape

		freqs_cis = self.freqs_cis[:seq_len].to(src.device)

		enc_output = src_embedded
		for enc_layer in self.encoder_layers:
			enc_output = enc_layer(enc_output, freqs_cis)

		dec_output = tgt_embedded
		for dec_layer in self.decoder_layers:
			dec_output = dec_layer(dec_output, enc_output, freqs_cis)

		output = self.fc(dec_output)
		return output

In [None]:
class BART(nn.Module):
	def __init__(
		self, vocab_size: int,
		d_model: int = 768, n_heads: int = 12,
		n_enc_layers: int = 6, n_dec_layers: int = 6,
		d_ff: int = 3072, max_seq_len: int = 1024,
		dropout: float = 0.2,
		norm_layer=nn.LayerNorm,
	):
		super().__init__()

		self.vocab_size = vocab_size
		self.d_model = d_model

		self.enc_emb = nn.Embedding(vocab_size, d_model)
		self.dec_emb = nn.Embedding(vocab_size, d_model)

		self.freqs_cis = precompute_freqs_cis(d_model // n_heads, max_seq_len * 2)

		self.enc_layers = nn.ModuleList([
			EncoderLayer(d_model, n_heads, d_ff, dropout, max_seq_len, norm_layer=norm_layer)
			for _ in range(n_enc_layers)
		])

		self.dec_layers = nn.ModuleList([
			DecoderLayer(d_model, n_heads, d_ff, dropout, max_seq_len, norm_layer=norm_layer)
			for _ in range(n_dec_layers)
		])

		self.fc_out = nn.Linear(d_model, vocab_size)
		self.dropout = nn.Dropout(dropout)

	def encode(self, src: torch.Tensor, freqs_cis: torch.Tensor, src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
		# x = self.enc_emb(src) * math.sqrt(self.d_model)

		x = self.enc_emb(src)
		for layer in self.enc_layers:
			x = layer(x, freqs_cis, src_mask)

		return x

	def decode(
		self, tgt: torch.Tensor, memory: torch.Tensor,
		freqs_cis: torch.Tensor,
		tgt_mask: Optional[torch.Tensor] = None,
		memory_mask: Optional[torch.Tensor] = None,
	) -> torch.Tensor:
		# x = self.dec_emb(tgt) * math.sqrt(self.d_model)

		x = self.dec_emb(tgt)
		for layer in self.dec_layers:
			x = layer(x, memory, freqs_cis, tgt_mask, memory_mask)

		return x

	def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None):
		_, seq_len = src.shape
		freqs_cis = self.freqs_cis[:seq_len].to(src.device)

		enc_out = self.encode(src, freqs_cis, src_mask)
		dec_out = self.decode(tgt, enc_out, freqs_cis, tgt_mask)

		out = self.fc_out(dec_out)
		return self.dropout(out)


In [22]:
device = torch.device("cuda")

src_vocab_size = vocab_size
tgt_vocab_size = vocab_size
d_model = 256
num_heads = 4
num_layers = 2
d_ff = 512 
max_seq_length = 100
dropout = 0.1

# transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout, norm_layer=DyT)
transformer = BART(vocab_size, d_model, num_heads, num_layers, num_layers, d_ff, max_seq_length, dropout, norm_layer=DyT)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length)).to(device)  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length)).to(device)  # (batch_size, seq_length)

In [23]:
transformer = transformer.to(device)

In [24]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
	optimizer.zero_grad()
	output = transformer(src_data, tgt_data)
	loss = criterion(output[:, 1:, :].contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
	loss.backward()
	optimizer.step()
	print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 36.806392669677734
Epoch: 2, Loss: 36.32545471191406
Epoch: 3, Loss: 35.86148452758789
Epoch: 4, Loss: 35.36692428588867
Epoch: 5, Loss: 35.016841888427734
Epoch: 6, Loss: 34.54901123046875
Epoch: 7, Loss: 34.08950424194336
Epoch: 8, Loss: 33.622737884521484
Epoch: 9, Loss: 33.1419677734375
Epoch: 10, Loss: 32.739952087402344
Epoch: 11, Loss: 32.20132827758789
Epoch: 12, Loss: 31.93497657775879
Epoch: 13, Loss: 31.454566955566406
Epoch: 14, Loss: 31.023502349853516
Epoch: 15, Loss: 30.607654571533203
Epoch: 16, Loss: 30.193368911743164
Epoch: 17, Loss: 29.816753387451172
Epoch: 18, Loss: 29.486560821533203
Epoch: 19, Loss: 28.978042602539062
Epoch: 20, Loss: 28.741657257080078
Epoch: 21, Loss: 28.21759033203125
Epoch: 22, Loss: 27.857959747314453
Epoch: 23, Loss: 27.46573829650879
Epoch: 24, Loss: 27.04524040222168
Epoch: 25, Loss: 26.657852172851562


KeyboardInterrupt: 