## Setup and imports

In [None]:
# Load the TensorBoard notebook extension into this kernel.
%load_ext tensorboard

In [None]:
%%capture
# Install the third-party packages imported further down.
# NOTE(review): versions are not pinned, so a fresh environment may resolve different releases.
!pip install datasets
!pip install tokenizers
!pip install torchmetrics

In [None]:
%%capture
# Colab only: mount Google Drive at /content/drive; the cells below read and write under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Work from the Drive folder so the relative paths returned by get_config() resolve under it.
%cd /content/drive/MyDrive/Models

/content/drive/MyDrive/Models


In [None]:
# Create the checkpoint and tokenizer directories up front. These absolute paths are the
# 'pytorch-transformer/weights' and 'pytorch-transformer/vocab' locations used in get_config(),
# seen from the /content/drive/MyDrive/Models working directory.
!mkdir -p /content/drive/MyDrive/Models/pytorch-transformer/weights
!mkdir -p /content/drive/MyDrive/Models/pytorch-transformer/vocab

In [None]:
def get_config():
    """Return the hyperparameters and file locations used by this notebook.

    A new dictionary is assembled on every call, so the caller may modify the
    result without affecting later calls. All paths are relative to the
    current working directory.
    """
    # Optimisation and model-size settings.
    training = dict(
        batch_size=16,
        num_epochs=100,
        lr=10**-4,
        seq_len=400,
        d_model=512,
    )
    # Dataset identifier and the names of the source / target fields.
    data = dict(
        datasource='NikitiusIvanov/protein_seq_to_go_bio_process',
        lang_src="sequence",
        lang_tgt="go_process",
    )
    # Where checkpoints, tokenizers and TensorBoard logs are written.
    # "{0}" in tokenizer_file is filled in with the language/field name.
    artifacts = dict(
        model_folder="pytorch-transformer/weights",
        model_basename="tmodel_",
        preload=None,
        tokenizer_file="pytorch-transformer/vocab/tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )
    return {**training, **data, **artifacts}

In [None]:
# Build the configuration once and echo it so the settings appear in the cell output.
cfg = get_config()
cfg

{'batch_size': 16,
 'num_epochs': 100,
 'lr': 0.0001,
 'seq_len': 400,
 'd_model': 512,
 'datasource': 'NikitiusIvanov/protein_seq_to_go_bio_process',
 'lang_src': 'sequence',
 'lang_tgt': 'go_process',
 'model_folder': 'pytorch-transformer/weights',
 'model_basename': 'tmodel_',
 'preload': None,
 'tokenizer_file': 'pytorch-transformer/vocab/tokenizer_{0}.json',
 'experiment_name': 'runs/tmodel'}

In [None]:
#### model imports ####
import torch
import torch.nn as nn
import math

#### dataset imports ####
from torch.utils.data import Dataset


#### train imports ####
# import torchtext.datasets as datasets
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import models
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.models import WordPiece
from tokenizers.trainers import BpeTrainer
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
from torch.utils.tensorboard import SummaryWriter

## Model code

In [None]:

class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply a learnable scale and shift.

    Args:
        features: Size of the last dimension of the inputs; `alpha` and `bias`
            each hold one value per feature.
        eps: Constant added to the standard deviation (not the variance) so the
            division stays finite when the deviation is zero or tiny.
    """

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to 1.
        self.alpha = nn.Parameter(torch.ones(features))
        # Learnable per-feature offset, initialised to 0.
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """Return `alpha * (x - mean) / (std + eps) + bias` over the last dimension."""
        # keepdim=True keeps a trailing axis of size 1 so the statistics broadcast against x.
        centred = x - x.mean(dim=-1, keepdim=True)
        # Tensor.std is called with its default settings here.
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centred / spread + self.bias

class FeedForwardBlock(nn.Module):
    """Two-layer position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # expands d_model -> d_ff
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # projects d_ff -> d_model

    def forward(self, x):
        """Map (..., d_model) to (..., d_model) through the d_ff-wide hidden layer."""
        hidden = self.linear_1(x)
        hidden = torch.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class InputEmbeddings(nn.Module):
    """Token-id embedding lookup scaled by sqrt(d_model).

    Args:
        d_model: Size of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Look up the ids in `x` and multiply the vectors by sqrt(d_model).

        The result has the shape of `x` with a trailing d_model axis added.
        """
        scale = math.sqrt(self.d_model)
        return self.embedding(x) * scale

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    The table is computed once for `seq_len` positions and stored as the
    non-trainable buffer `pe` of shape (1, seq_len, d_model).

    Args:
        d_model: Feature size of the embeddings.
        seq_len: Number of positions to precompute.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Column vector of positions 0..seq_len-1, shape (seq_len, 1).
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature index 2i: exp(-2i * ln(10000) / d_model) = 10000^(-2i / d_model).
        even_dims = torch.arange(0, d_model, 2).float()
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        # Broadcast to one angle per (position, frequency) pair.
        angles = positions * inv_freq

        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices

        # Leading axis of size 1 lets the table broadcast over the batch; a buffer
        # is saved with the module state but is not a parameter.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the first `x.shape[1]` rows of the table to `x` and apply dropout."""
        steps = x.shape[1]
        encoding = self.pe[:, :steps, :].requires_grad_(False)
        return self.dropout(x + encoding)

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: `x + dropout(sublayer(norm(x)))`.

    Args:
        features: Feature size handed to the LayerNormalization.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        """Apply `sublayer` to the normalised input and add the result back to `x`.

        `sublayer` is any callable taking and returning a tensor shaped like `x`.
        """
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    Args:
        d_model: Feature size of the inputs and outputs; must be divisible by `h`.
        h: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    After each forward pass the post-softmax (and post-dropout) weights are kept
    on `self.attention_scores`.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"

        # Each head works on a d_k-wide slice of the projected features.
        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # query projection
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # key projection
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # value projection
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # output projection
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return `(softmax(Q K^T / sqrt(d_k)) V, weights)`.

        `mask` (optional) must broadcast against the score matrix; positions where
        it equals 0 are filled with -1e9 before the softmax. `dropout` (optional)
        is applied to the weights before they multiply `value`.
        """
        scale = math.sqrt(query.shape[-1])
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            # Large negative value stands in for -inf so masked positions get ~0 weight.
            attention_scores.masked_fill_(mask == 0, -1e9)
        weights = attention_scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def _split_heads(self, projected):
        """Reshape (batch, len, d_model) into (batch, h, len, d_k)."""
        batch, length = projected.shape[0], projected.shape[1]
        return projected.view(batch, length, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        """Project q/k/v, attend per head, merge the heads and apply the output projection."""
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        context, self.attention_scores = MultiHeadAttentionBlock.attention(
            query, key, value, mask, self.dropout
        )

        # (batch, h, len, d_k) -> (batch, len, h * d_k); contiguous() is needed before view().
        batch = context.shape[0]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.w_o(merged)

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual connection.

    Args:
        features: Feature size passed to the two residual connections.
        self_attention_block: Attention module called with the same tensor as q, k and v.
        feed_forward_block: Feed-forward module applied after attention.
        dropout: Dropout probability used by the residual connections.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        """Run `x` through both sublayers; `src_mask` is forwarded to the attention block."""
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            # Query, key and value are all the (normalised) layer input.
            return self.self_attention_block(normed, normed, normed, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)

class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation.

    Args:
        features: Feature size passed to the final LayerNormalization.
        layers: Modules called in order as `layer(x, mask)`.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        """Feed `x` through every layer in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention and feed-forward, each with a residual.

    Args:
        features: Feature size passed to the three residual connections.
        self_attention_block: Attention over the decoder input itself.
        cross_attention_block: Attention with the decoder state as query and the
            encoder output as key and value.
        feed_forward_block: Feed-forward module applied last.
        dropout: Dropout probability used by the residual connections.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run the three sublayers; `tgt_mask` guards self-attention, `src_mask` cross-attention."""
        self_residual, cross_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed):
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return feed_forward_residual(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation.

    Args:
        features: Feature size passed to the final LayerNormalization.
        layers: Modules called in order as `layer(x, encoder_output, src_mask, tgt_mask)`.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed `x` through every layer in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from model features to one score per vocabulary entry.

    Args:
        d_model: Feature size of the incoming tensor.
        vocab_size: Number of output scores per position.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        """Return the raw (un-normalised) scores: (..., d_model) --> (..., vocab_size).

        No softmax is applied here.
        """
        return self.proj(x)

class Transformer(nn.Module):
    """Container wiring embeddings, positional encodings, encoder, decoder and projection.

    The class defines no `forward`; callers use `encode`, `decode` and `project`
    as separate steps.
    """

    def __init__(
        self, encoder: Encoder, decoder: Decoder,
        src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding,
        tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed the source ids, add positions and run the encoder stack."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed the target ids, add positions and run the decoder stack against `encoder_output`."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to vocabulary scores through the projection layer."""
        return self.projection_layer(x)



### build_transformer

In [None]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    """Assemble a Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size: Rows in the source embedding table.
        tgt_vocab_size: Rows in the target embedding table and width of the projection output.
        src_seq_len: Positions precomputed for the source positional encoding.
        tgt_seq_len: Positions precomputed for the target positional encoding.
        d_model: Feature size used throughout the model.
        N: Number of encoder layers and number of decoder layers.
        h: Attention heads per attention block.
        dropout: Dropout probability used by every sub-module.
        d_ff: Hidden size of each feed-forward block.

    Returns:
        The assembled model; every parameter with more than one dimension has
        been re-initialised with `nn.init.xavier_uniform_`.
    """
    # Sub-modules are created in a fixed order (embeddings, positions, encoder
    # layers, decoder layers, projection) so seeded runs stay reproducible.
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Arguments are evaluated left to right: attention first, then feed-forward.
    encoder_blocks = [
        EncoderBlock(
            d_model,
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # Self-attention, cross-attention, then feed-forward for each decoder layer.
    decoder_blocks = [
        DecoderBlock(
            d_model,
            MultiHeadAttentionBlock(d_model, h, dropout),
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Only matrices are re-initialised; 1-D parameters (biases, norm gains) keep their defaults.
    for weight in transformer.parameters():
        if weight.dim() > 1:
            nn.init.xavier_uniform_(weight)

    return transformer


## Dataset code

In [None]:
class BilingualDataset(Dataset):
    """Fixed-length encoder/decoder tensors for each source/target pair.

    Every item is tokenised, wrapped with special tokens and padded with [PAD]
    to exactly `seq_len` ids:

    * encoder input: [SOS] + source ids + [EOS] + padding
    * decoder input: [SOS] + target ids + padding
    * label:         target ids + [EOS] + padding

    The [SOS]/[EOS]/[PAD] ids are looked up in `tokenizer_tgt` only and reused
    for the encoder input, so both tokenizers must assign the same ids to these
    tokens.

    Args:
        ds: Indexable collection of rows supporting `len()`.
            NOTE(review): rows are read as row['translation'][lang] - confirm
            the configured data source actually uses this layout.
        tokenizer_src: Tokenizer for source text; `encode(text).ids` is used.
        tokenizer_tgt: Tokenizer for target text; also supplies the special-token ids.
        src_lang: Key of the source text inside row['translation'].
        tgt_lang: Key of the target text inside row['translation'].
        seq_len: Length every returned id tensor is padded to.
    """

    def __init__(
        self,
        ds,
        tokenizer_src,
        tokenizer_tgt,
        src_lang,
        tgt_lang,
        seq_len
    ):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # One-element int64 tensors so they can be concatenated directly with the id tensors.
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def _padding(self, count):
        """Return `count` copies of the [PAD] id as a 1-D int64 tensor (empty when count is 0)."""
        # repeat() builds the run in one tensor op instead of converting a Python
        # list of one-element tensors item by item.
        return self.pad_token.repeat(count)

    def __getitem__(self, idx):
        """Build the padded tensors and masks for row `idx`.

        Returns:
            dict with
              "encoder_input": (seq_len) int64
              "decoder_input": (seq_len) int64
              "encoder_mask":  (1, 1, seq_len), 1 where the encoder input is not [PAD]
              "decoder_mask":  (1, seq_len, seq_len), non-[PAD] positions combined with the causal mask
              "label":         (seq_len) int64
              "src_text", "tgt_text": the raw strings

        Raises:
            ValueError: if either text does not fit in `seq_len` once the
                special tokens are added.
        """
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into token ids
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder side carries both [SOS] and [EOS]; the decoder input carries
        # only [SOS] and the label only [EOS], hence -2 versus -1.
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # A negative padding count means the text does not fit in seq_len.
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        src_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        tgt_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)
        enc_padding = self._padding(enc_num_padding_tokens)
        dec_padding = self._padding(dec_num_padding_tokens)

        # [SOS] + source + [EOS] + padding
        encoder_input = torch.cat([self.sos_token, src_ids, self.eos_token, enc_padding], dim=0)
        # [SOS] + target + padding
        decoder_input = torch.cat([self.sos_token, tgt_ids, dec_padding], dim=0)
        # target + [EOS] + padding (the decoder input shifted left by one position)
        label = torch.cat([tgt_ids, self.eos_token, dec_padding], dim=0)

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal.

    Row i is True for columns 0..i, i.e. position i may look at itself and at
    earlier positions only.
    """
    lower_triangle = torch.tril(torch.ones((1, size, size)))
    return lower_triangle.bool()


## Train model code

### greedy_decode

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Generate target ids one at a time, always taking the highest-scoring token.

    Works on a single sequence (the decoder input is created with batch size 1).
    Decoding stops once [EOS] has been produced or the output holds `max_len`
    ids. `tokenizer_src` is accepted but not used.

    Returns:
        1-D tensor of generated ids, starting with [SOS] and including [EOS]
        when it was produced.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output does not change between steps, so compute it once.
    encoder_output = model.encode(source, source_mask)

    # Start from a (1, 1) sequence holding only [SOS], with the dtype of `source`.
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while decoder_input.size(1) != max_len:
        # Causal mask covering the tokens generated so far.
        steps_so_far = decoder_input.size(1)
        decoder_mask = causal_mask(steps_so_far).type_as(source_mask).to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Only the last position is needed to choose the next token.
        scores = model.project(out[:, -1])
        next_word = torch.max(scores, dim=1).indices
        next_id = next_word.item()

        appended = torch.empty(1, 1).type_as(source).fill_(next_id).to(device)
        decoder_input = torch.cat([decoder_input, appended], dim=1)

        if next_id == eos_idx:
            break

    return decoder_input.squeeze(0)


### run_validation

In [None]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode a few validation examples, print them and log text metrics.

    Args:
        model: Model exposing eval() and the interface greedy_decode expects.
        validation_ds: Iterable of batches with "encoder_input", "encoder_mask",
            "src_text" and "tgt_text"; every batch must have batch size 1.
        tokenizer_src: Passed through to greedy_decode.
        tokenizer_tgt: Used to decode the predicted ids back to text.
        max_len: Maximum number of ids generated per example.
        device: Device the batch tensors are moved to.
        print_msg: Callable receiving each output line (e.g. tqdm.write).
        global_step: Step value attached to the logged scalars.
        writer: TensorBoard-style writer, or a falsy value to skip metric logging.
        num_examples: Number of examples decoded before stopping.

    The model is switched to eval mode and left there; the caller is
    responsible for switching back to train mode.
    """
    import shutil  # local import: only needed to size the separator line

    model.eval()
    count = 0

    expected = []
    predicted = []

    # Width of the separator line. get_terminal_size falls back to 80 columns
    # when no terminal is attached (as in a notebook), without spawning a shell
    # and without a catch-all except that would also swallow KeyboardInterrupt.
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # greedy_decode handles one sequence at a time
            assert encoder_input.size(
                0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{'SOURCE: ':>12}{source_text}")
            print_msg(f"{'TARGET: ':>12}{target_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-'*console_width)
                break

    if writer:
        # Character error rate, word error rate and BLEU over the decoded examples.
        metrics = (
            ('validation cer', torchmetrics.text.CharErrorRate),
            ('validation wer', torchmetrics.text.WordErrorRate),
            ('validation BLEU', torchmetrics.text.BLEUScore),
        )
        for tag, metric_cls in metrics:
            metric = metric_cls()
            writer.add_scalar(tag, metric(predicted, expected), global_step)
            writer.flush()



### get_all_sentences & get_or_build_tokenizer

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield the `lang` text of every row in `ds`.

    Each row is read as row['translation'][lang].
    """
    yield from (row['translation'][lang] for row in ds)


def get_or_build_tokenizer(config, ds, lang):
    """Load the tokenizer for `lang` from disk, training and saving it first if missing.

    Args:
        config: Mapping whose 'tokenizer_file' entry is a path template; "{0}"
            is replaced by `lang`.
        ds: Dataset handed to get_all_sentences to supply the training text.
        lang: Field/language name selecting both the text and the file name.

    Returns:
        The loaded or newly trained tokenizer.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=10,
        show_progress=True,
    )
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

    # Saving fails when the target folder is missing, so create it here rather
    # than relying on a separate setup cell having been run.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    print(str(tokenizer_path))
    tokenizer.save(str(tokenizer_path))
    return tokenizer



### get_ds

In [None]:

def get_ds(config):
    """Load the dataset, build both tokenizers and return train/validation loaders.

    Args:
        config: Mapping providing 'datasource', 'lang_src', 'lang_tgt',
            'seq_len', 'batch_size' and the keys get_or_build_tokenizer needs.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt). The
        validation loader uses batch size 1.
    """
    dataset_name = f"{config['datasource']}"
    src_lang = config['lang_src']
    tgt_lang = config['lang_tgt']

    # Only the 'train' split is requested, so the validation set is carved out of it below.
    ds_raw = load_dataset(dataset_name, f"{src_lang}-{tgt_lang}", split='train')

    # One tokenizer per side, trained on (or loaded for) the full raw dataset.
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, tgt_lang)

    # 90% of the rows go to training, the remainder to validation.
    total_rows = len(ds_raw)
    train_ds_size = int(0.9 * total_rows)
    val_ds_size = total_rows - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    seq_len = config['seq_len']
    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len)
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len)

    # Report the longest tokenised source and target so seq_len can be sanity-checked.
    max_len_src = 0
    max_len_tgt = 0
    for row in ds_raw:
        pair = row['translation']
        src_count = len(tokenizer_src.encode(pair[src_lang]).ids)
        tgt_count = len(tokenizer_tgt.encode(pair[tgt_lang]).ids)
        if src_count > max_len_src:
            max_len_src = src_count
        if tgt_count > max_len_tgt:
            max_len_tgt = tgt_count

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    # Batch size 1 so each validation example can be decoded on its own.
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt



### get_model

In [None]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the transformer for the given vocabulary sizes.

    The same config['seq_len'] is used for the source and target positional
    encodings; config['d_model'] sets the feature size. Every other
    build_transformer argument keeps its default.
    """
    seq_len = config["seq_len"]
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, d_model=config['d_model'])

## Get weights

In [None]:
def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path for `epoch` as a string.

    The file name is config['model_basename'] + epoch + ".pt" inside
    config['model_folder'], relative to the current directory.
    """
    checkpoint_name = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.').joinpath(f"{config['model_folder']}", checkpoint_name))

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the path of the newest checkpoint in config['model_folder'], or None.

    Files matching config['model_basename'] + "*" are considered. The part of
    the file stem after the basename is taken as the epoch; when it is all
    digits the files are ordered by its integer value, so "tmodel_10.pt" ranks
    after "tmodel_9.pt" even without zero padding. Files whose suffix is not
    numeric rank below all numeric ones and are ordered by name.
    """
    import re  # local import: only used to recognise the numeric epoch suffix

    model_folder = f"{config['model_folder']}"
    basename = f"{config['model_basename']}"
    weights_files = list(Path(model_folder).glob(f"{basename}*"))
    if not weights_files:
        return None

    def epoch_order(path):
        suffix = path.stem[len(basename):]
        if re.fullmatch(r"\d+", suffix):
            return (1, int(suffix), path.name)
        return (0, 0, path.name)

    return str(max(weights_files, key=epoch_order))

## Custom loss

In [None]:
def train_model(config):
    """Train the transformer described by ``config``, validating and checkpointing every epoch.

    Steps: pick a device, build the data loaders and the model through
    ``get_ds`` / ``get_model``, optionally resume from a checkpoint, then run
    epochs ``initial_epoch`` .. ``config['num_epochs'] - 1``. The per-step
    training loss is written to TensorBoard under ``config['experiment_name']``.

    Args:
        config: dict read here for ``model_folder``, ``experiment_name``,
            ``lr``, ``preload``, ``num_epochs`` and ``seq_len`` (the helpers
            it is forwarded to read further keys). ``preload`` may be
            ``'latest'``, an epoch tag, or a falsy value to start from scratch.

    Returns:
        None. One checkpoint file per epoch is written to the weights folder.
    """
    # Define the device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Using device:", device)
    if device.type == 'cuda':
        # Pass the torch.device itself: without an explicit index it resolves
        # to the current CUDA device.
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")

    # Make sure the weights folder exists
    Path(f"{config['model_folder']}").mkdir(parents=True, exist_ok=True)
    print('run get_ds')
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None
    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location keeps resuming possible when the checkpoint was saved
        # from a different device than the one selected above.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # The labels hold target-side token ids, so the padding id to ignore must
    # be looked up in the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
            decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
            proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device) # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))

            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1
        print('run validation')
        # Run validation at the end of every epoch
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


## Train model

In [None]:
# Start from the default configuration and override it for this run.
cfg = get_config()
# Uncomment to resume from the newest checkpoint in the weights folder.
# cfg['preload'] = 'latest'
# NOTE(review): seq_len presumably has to cover the longest tokenized
# sequence plus its special tokens (the run below reports a maximum source
# length of 361) - confirm against the dataset class.
cfg['seq_len'] = 370

In [None]:
cfg

{'batch_size': 16,
 'num_epochs': 100,
 'lr': 0.0001,
 'seq_len': 370,
 'd_model': 512,
 'datasource': 'NikitiusIvanov/protein_seq_to_go_bio_process',
 'lang_src': 'sequence',
 'lang_tgt': 'go_process',
 'model_folder': 'pytorch-transformer/weights',
 'model_basename': 'tmodel_',
 'preload': None,
 'tokenizer_file': 'pytorch-transformer/vocab/tokenizer_{0}.json',
 'experiment_name': 'runs/tmodel'}

In [None]:
train_model(cfg)

Using device: cuda
Device name: Tesla T4
Device memory: 14.74786376953125 GB
run get_ds
Max length of source sentence: 361
Max length of target sentence: 181
No model to preload, starting from scratch


Processing Epoch 00: 100%|██████████| 733/733 [10:01<00:00,  1.22it/s, loss=8.157]


run validation
--------------------------------------------------------------------------------
    SOURCE: MAFRRRTKSYPLFSQEFVIHNHADIGFCLVLCVLIGLMFEVTAKTAFLFILPQYNISVPTADSETVHYHYGPKDLVTILFYIFITIILHAVVQEYILDKISKRLHLSKVKHSKFNESGQLVVFHFTSVIWCFYVVVTEGYLTNPRSLWEDYPHVHLPFQVKFFYLCQLAYWLHALPELYFQKVRKEEIPRQLQYICLYLVHIAGAYLLNLSRLGLILLLLQYSTEFLFHTARLFYFADENNEKLFSAWAAVFGVTRLFILTLAVLAIGFGLARMENQAFDPEKGNFNTLFCRLCVLLLVCAAQAWLMWRFIHSQLRHWREYWNEQSAKRRVPATPRLPARLIKRESGYHENGVVKAENGTSPRTKKLKSP
    TARGET: collagen biosynthetic process; protein insertion into ER membrane; SRP-dependent cotranslational protein targeting to membrane, translocation
 PREDICTED: germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal ST germinal germinal germinal germinal ST germinal ST germinal germinal germinal germinal germinal germinal ST germinal ST germinal S

Processing Epoch 01:   0%|          | 1/733 [00:01<16:04,  1.32s/it, loss=8.157]


OutOfMemoryError: ignored

## Training loop debug

In [None]:
config = get_config()

In [None]:
config['preload'] = 'latest'

In [None]:
# def train_model(config):
# Define the device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)
if device.type == 'cuda':
    # Pass the torch.device itself: without an explicit index it resolves to
    # the current CUDA device.
    print(f"Device name: {torch.cuda.get_device_name(device)}")
    print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
elif device.type == 'mps':
    print(f"Device name: <mps>")
else:
    print("NOTE: If you have a GPU, consider using it for training.")
    print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
    print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")


Using device: cuda
Device name: Tesla T4
Device memory: 14.74786376953125 GB


In [None]:
# Make sure the weights folder exists
Path(f"{config['model_folder']}").mkdir(parents=True, exist_ok=True)

In [None]:
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)


Max length of source sentence: 361
Max length of target sentence: 181


In [None]:
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

In [None]:
# Tensorboard
writer = SummaryWriter(config['experiment_name'])

optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)


In [None]:
# If the user specified a model to preload before training, load it
initial_epoch = 0
global_step = 0


In [None]:
preload = config['preload']

In [None]:
preload = None

In [None]:
model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None


In [None]:
model_filename

In [None]:
if model_filename:
    print(f'Preloading model {model_filename}')
    # map_location keeps resuming possible when the checkpoint was saved from
    # a different device than the one selected above.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    initial_epoch = state['epoch'] + 1
    optimizer.load_state_dict(state['optimizer_state_dict'])
    global_step = state['global_step']
else:
    print('No model to preload, starting from scratch')

# The labels hold target-side token ids, so the padding id to ignore must be
# looked up in the target tokenizer.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)


No model to preload, starting from scratch


In [None]:

# for epoch in range(initial_epoch, config['num_epochs']):
#     torch.cuda.empty_cache()
#     model.train()
#     batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
#     for batch in batch_iterator:


In [None]:
# Manual stand-in for the first iteration of the epoch loop: put the model in
# training mode, release cached CUDA memory and open an iterator over the
# training batches.
model.train()
torch.cuda.empty_cache()
epoch = 0
batch_iterator = iter(train_dataloader)

In [None]:
batch = next(batch_iterator)

In [None]:
# Move the four model inputs of the batch to the training device.
# Shapes as noted when the cell was written - TODO confirm:
#   encoder_input (b, seq_len), decoder_input (B, seq_len),
#   encoder_mask (B, 1, 1, seq_len), decoder_mask (B, 1, seq_len, seq_len)
encoder_input, decoder_input, encoder_mask, decoder_mask = (
    batch[field].to(device)
    for field in ('encoder_input', 'decoder_input', 'encoder_mask', 'decoder_mask')
)


In [None]:
# Run the tensors through the encoder, decoder and the projection layer
encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)


In [None]:
# Compare the output with the label
label = batch['label'].to(device) # (B, seq_len)

### Custom loss debug

In [None]:
# Compute the loss using a simple cross entropy
loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
# batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

In [None]:
############ split the target seq ############
# Constants shared by the label / prediction splitting cells.

# Padded sequence length taken from the debug config
SEQ_LEN = config['seq_len']

# Ids of the special tokens and of the ';' token in the target tokenizer.
# NOTE(review): token_to_id presumably returns None for a token that is not
# in the vocabulary - confirm that ';' is a single target token.
EOS_ID = tokenizer_tgt.token_to_id('[EOS]')

SEP_ID = tokenizer_tgt.token_to_id(';')

PAD_ID = tokenizer_tgt.token_to_id('[PAD]')

UNK_ID = tokenizer_tgt.token_to_id('[UNK]')

# Size of the target vocabulary
VOCAB_SIZE = tokenizer_tgt.get_vocab_size()

# Accumulator for the label-splitting cell: one list of sub-sequences per batch item
label_splitted = []

In [None]:
# Rebuild the accumulator so that re-running this cell does not append a
# second copy of every batch item.
label_splitted = []

for i in range(len(label)):
    item_splitted = []
    item = label[i]
    # Positions of every sub-sequence terminator: the ';' separator or [EOS]
    sub_seqs_ends = torch.where(
        (item == SEP_ID)
        |
        (item == EOS_ID)
    )[0]
    start_idx = 0
    for end_idx in sub_seqs_ends:
        subitem = item[start_idx: end_idx]
        # Skip segments that are empty or hold nothing but padding
        if not (subitem == PAD_ID).all():
            # Tail that restores the full length: one [EOS], then [PAD].
            # PAD_ID is used explicitly instead of relying on its value being 1.
            end_seq = torch.full(
                (SEQ_LEN - len(subitem),),
                PAD_ID,
                dtype=torch.int64,
                device=device,
            )
            end_seq[0] = EOS_ID
            subitem = torch.cat([subitem, end_seq])

            item_splitted.append(subitem)

        # Always step past the terminator, even when the segment was skipped;
        # otherwise the next segment would start with the previous separator.
        start_idx = end_idx + 1

    label_splitted.append(item_splitted)


In [None]:
item.shape, SEQ_LEN

(torch.Size([400, 3486]), 400)

In [None]:
eos_ids

tensor(65, device='cuda:0')

In [None]:
############ split the projection layer output ############
proj_output_splitted = []

for i in range(len(proj_output)):

    item = proj_output[i]

    # Predicted token id at every position (index of the row-wise maximum)
    pred_ids = torch.max(item, dim=1)[1]

    # Row of the first predicted [EOS]; SEQ_LEN when none is predicted
    eos_ids = (pred_ids == EOS_ID).nonzero().squeeze(1)
    eos_row_id = int(eos_ids[0]) if len(eos_ids) > 0 else SEQ_LEN

    # Rows up to and including that one where ';' or [EOS] is predicted
    head_ids = pred_ids[:eos_row_id + 1]
    sub_seqs_ends = (
        ((head_ids == SEP_ID) | (head_ids == EOS_ID))
        .nonzero()
        .squeeze(1)
    ).to(device)

    item_splitted = []

    start_idx = 0
    for end_idx in sub_seqs_ends:

        subitem = item[start_idx: end_idx]

        # Tail that restores the full length: rows of ones with a single 2,
        # so the row maximum sits in the [EOS] column for the first tail row
        # and in the [PAD] column for all following rows.
        end_seq = torch.ones(SEQ_LEN - subitem.shape[0], VOCAB_SIZE, device=device)
        end_seq[0, EOS_ID] = 2
        end_seq[1:, PAD_ID] = 2

        subitem = torch.cat([subitem, end_seq])

        item_splitted.append(subitem)

        start_idx = end_idx + 1

    diff_lenght = len(item_splitted) - len(label_splitted[i])

    # Fewer predicted pieces than true ones: add filler predictions whose
    # row maximum sits in the [UNK] column.
    if diff_lenght < 0:
        for n in range(-diff_lenght):
            end_seq = torch.ones(SEQ_LEN, VOCAB_SIZE, device=device)
            end_seq[:, UNK_ID] = 2

            item_splitted.append(end_seq)

    # More predicted pieces than true ones: add all-[UNK] filler labels.
    if diff_lenght > 0:
        for n in range(diff_lenght):
            end_seq = torch.full((SEQ_LEN,), UNK_ID, dtype=torch.int64, device=device)

            label_splitted[i].append(end_seq)

    proj_output_splitted.append(item_splitted)


In [None]:

############ calculate losses for all pairs sublabels ############
def _count_assignments(assignment):
    """Map each predicted sub-sequence index to the number of true sub-labels currently assigned to it."""
    chosen = list(assignment.values())
    return {pred_idx: chosen.count(pred_idx) for pred_idx in set(chosen)}


losses = []
min_losses = {}

for label_idx in range(len(label_splitted)):
    label_losses = {}
    min_losses[label_idx] = {}

    for sub_label_idx in range(len(label_splitted[label_idx])):
        min_loss = torch.inf
        label_losses[sub_label_idx] = {}

        for sub_pred_label_idx in range(len(proj_output_splitted[label_idx])):

            # Bound to its own name so the batch-level `loss` tensor is not
            # overwritten by the last pairwise loss.
            pair_loss = loss_fn(
                proj_output_splitted[label_idx][sub_pred_label_idx],
                label_splitted[label_idx][sub_label_idx]
            )

            if pair_loss < min_loss:
                min_loss = pair_loss
                min_loss_idx = sub_pred_label_idx

            label_losses[sub_label_idx][sub_pred_label_idx] = pair_loss

        # Index of the predicted sub-sequence with the lowest loss for this sub-label
        min_losses[label_idx][sub_label_idx] = min_loss_idx

    losses.append(label_losses)

############ find collisions  ############
optimal_losses = []
checked_items = []
collision_items = []
for item_idx in range(len(label_splitted)):

    min_stats = _count_assignments(min_losses[item_idx])

    for subitem_idx in range(len(label_splitted[item_idx])):
        # the predicted sub-sequence is the best match for exactly one sub-label
        if min_stats[min_losses[item_idx][subitem_idx]] == 1:
            optimal_losses.append(
                losses[item_idx][subitem_idx][min_losses[item_idx][subitem_idx]]
            )

            checked_items.append((item_idx, subitem_idx))

        # the predicted sub-sequence is the best match for several sub-labels
        else:
            collision_items.append((item_idx, subitem_idx))

############ resolve collisions  ############
while len(collision_items) > 0:

    label_idx, sub_label_idx = collision_items[0]

    # Candidate predictions for this sub-label, without the contested one
    order = losses[label_idx][sub_label_idx].copy()
    cur_min = min_losses[label_idx][sub_label_idx]
    order.pop(cur_min, None)

    # Next-best candidate: first key holding the smallest remaining loss
    new_min = min(order, key=order.get)
    min_losses[label_idx][sub_label_idx] = new_min

    # Check the assignment statistics for the chosen sub-sequence
    min_stats = _count_assignments(min_losses[label_idx])

    # If the new choice is still shared with another true sub-sequence,
    # keep moving to the next-best candidate until the pair is exclusive.
    while min_stats[new_min] > 1:

        cur_min = new_min
        order.pop(cur_min, None)

        new_min = min(order, key=order.get)
        min_losses[label_idx][sub_label_idx] = new_min

        min_stats = _count_assignments(min_losses[label_idx])

    # No collisions are left for this batch item: settle all of its
    # pending sub-labels at once.
    if max(list(min_stats.values())) == 1:
        uncollision_items = [
            x for x in collision_items
            if x[0] == label_idx
        ]
        for item in uncollision_items:

            collision_items.remove(item)

            optimal_losses.append(
                losses[item[0]][item[1]][min_losses[item[0]][item[1]]]
            )

    # Another collision remains for this batch item: settle only the
    # current sub-label and continue with the next pending one.
    else:
        collision_items.remove((label_idx, sub_label_idx))

        optimal_losses.append(
            losses[label_idx][sub_label_idx][min_losses[label_idx][sub_label_idx]]
        )

# Stack the selected loss tensors so the mean stays attached to the autograd
# graph. Building a fresh tensor from the list would copy the values into a
# new leaf, and backward() would never reach the model parameters.
custom_loss = torch.stack(optimal_losses).mean()
# retain_grad() raises on a tensor that does not require grad, which happens
# when every selected pair involves a constant filler prediction.
if custom_loss.requires_grad:
    custom_loss.retain_grad()


In [None]:
custom_loss

tensor(8.1428, device='cuda:0', grad_fn=<ToCopyBackward0>)

In [None]:


# Log the loss
# NOTE(review): this cell logs and backpropagates `loss`, not `custom_loss`;
# confirm that is the intended tensor for this debug step.
writer.add_scalar('train loss', loss.item(), global_step)
writer.flush()

# Backpropagate the loss
loss.backward()

# Update the weights
optimizer.step()
# Drop the gradients (set to None rather than zero-filled) before the next step
optimizer.zero_grad(set_to_none=True)

# One optimizer step done; used as the x-axis value for the next logged scalar
global_step += 1

In [None]:
model

Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (self_attention_block): MultiHeadAttentionBlock(
          (w_q): Linear(in_features=512, out_features=512, bias=False)
          (w_k): Linear(in_features=512, out_features=512, bias=False)
          (w_v): Linear(in_features=512, out_features=512, bias=False)
          (w_o): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward_block): FeedForwardBlock(
          (linear_1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear_2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (residual_connections): ModuleList(
          (0-1): 2 x ResidualConnection(
            (dropout): Dropout(p=0.1, inplace=False)
            (norm): LayerNormalization()
          )
        )
      )
    )
    (norm): LayerNormalization

In [None]:
# Run validation at the end of every epoch
# print_msg has to be callable (the training loop passes a function that
# writes through the progress bar). Passing None is presumably what raised
# the TypeError recorded for this cell, so plain print is used here.
run_validation(
    model,
    val_dataloader,
    tokenizer_src,
    tokenizer_tgt,
    config['seq_len'],
    device,
    print_msg=print,
    global_step=None, writer=None)


TypeError: ignored

In [None]:

# Save the model at the end of every epoch
# The file name embeds the zero-padded epoch number.
model_filename = get_weights_file_path(config, f"{epoch:02d}")
# The stored keys are the ones the preload cell reads back:
# 'epoch', 'model_state_dict', 'optimizer_state_dict' and 'global_step'.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'global_step': global_step
}, model_filename)


## Run validation debug

In [None]:
# def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
model.eval()
count = 0


In [None]:
num_examples=2
max_len = config['seq_len']

In [None]:
validation_ds = val_dataloader

In [None]:
import shutil  # stdlib; only this debug cell needs it

# Accumulators for the validation examples
source_texts = []
expected = []
predicted = []

# Width of the separator lines. get_terminal_size falls back to 80 columns
# when no terminal can be queried (e.g. inside a notebook kernel), without
# spawning a shell or hiding unrelated errors behind a bare except.
console_width = shutil.get_terminal_size(fallback=(80, 24)).columns


In [None]:
# Reset the per-run state so the cell can be re-run on its own. With a stale
# `count` the `count == num_examples` stop condition is never met and the
# loop walks the whole validation set.
count = 0
source_texts = []
expected = []
predicted = []

with torch.no_grad():
    for batch in validation_ds:
        count += 1
        encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
        encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

        # check that the batch size is 1
        assert encoder_input.size(
            0) == 1, "Batch size must be 1 for validation"

        model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

        source_text = batch["src_text"][0]
        target_text = batch["tgt_text"][0]
        model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

        source_texts.append(source_text)
        expected.append(target_text)
        predicted.append(model_out_text)

        # Print the source, target and model output
        print('-'*console_width)
        print(f"{f'SOURCE: ':>12}{source_text}")
        print(f"{f'TARGET: ':>12}{target_text}")
        print(f"{f'PREDICTED: ':>12}{model_out_text}")

        # Stop after the requested number of examples
        if count == num_examples:
            print('-'*console_width)
            break



--------------------------------------------------------------------------------
    SOURCE: MSEEIITPVYCTGVSAQVQKQRARELGLGRHENAIKYLGQDYEQLRVRCLQSGTLFRDEAFPPVPQSLGYKDLGPNSSKTYGIKWKRPTELLSNPQFIVDGATRTDICQGALGDCWLLAAIASLTLNDTLLHRVVPHGQSFQNGYAGIFHFQLWQFGEWVDVVVDDLLPIKDGKLVFVHSAEGNEFWSALLEKAYAKVNGSYEALSGGSTSEGFEDFTGGVTEWYELRKAPSDLYQIILKALERGSLLGCSIDISSVLDMEAITFKKLVKGHAYSVTGAKQVNYRGQVVSLIRMRNPWGEVEWTGAWSDSSSEWNNVDPYERDQLRVKMEDGEFWMSFRDFMREFTRLEICNLTPDALKSRTIRKWNTTLYEGTWRRGSTAGGCRNYPATFWVNPQFKIRLDETDDPDDYGDRESGCSFVLALMQKHRRRERRFGRDMETIGFAVYEVPPELVGQPAVHLKRDFFLANASRARSEQFINLREVSTRFRLPPGEYVVVPSTFEPNKEGDFVLRFFSEKSAGTVELDDQIQANLPDEQVLSEEEIDENFKALFRQLAGEDMEISVKELRTILNRIISKHKDLRTKGFSLESCRSMVNLMDRDGNGKLGLVEFNILWNRIRNYLSIFRKFDLDKSGSMSAYEMRMAIESAGFKLNKKLYELIITRYSEPDLAVDFDNFVCCLVRLETMFRFFKTLDTDLDGVVTFDLFKWLQLTMFA
    TARGET: mammary gland involution; positive regulation of cell population proliferation; proteolysis; receptor catabolic process; regulation of catalytic activity; regulation of macroautoph

In [None]:
predicted

['cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle ; regulation of cell cycle',
 'cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle ; cell cycle

In [None]:
expected

['mammary gland involution; positive regulation of cell population proliferation; proteolysis; receptor catabolic process; regulation of catalytic activity; regulation of macroautophagy; regulation of NMDA receptor activity; self proteolysis',
 "CUT catabolic process; DNA deamination; exonucleolytic catabolism of deadenylated mRNA; exonucleolytic trimming to generate mature 3'-end of 5.8S rRNA from tricistronic rRNA transcript (SSU-rRNA, 5.8S rRNA, LSU-rRNA); isotype switching; nuclear polyadenylation-dependent rRNA catabolic process; nuclear polyadenylation-dependent tRNA catabolic process; nuclear-transcribed mRNA catabolic process, exonucleolytic, 3'-5'; polyadenylation-dependent snoRNA 3'-end processing; positive regulation of isotype switching; RNA catabolic process; RNA processing; rRNA processing; U4 snRNA 3'-end processing"]

In [None]:
# if writer:
# Evaluate the character error rate
# Compute the char error rate
metric = torchmetrics.text.CharErrorRate()
cer = metric(predicted, expected)
writer.add_scalar('validation cer', cer, global_step)
writer.flush()

# Compute the word error rate
metric = torchmetrics.text.WordErrorRate()
wer = metric(predicted, expected)
writer.add_scalar('validation wer', wer, global_step)
writer.flush()

# Compute the BLEU metric
# BLEUScore takes, for every prediction, a sequence of reference strings, so
# each expected text is wrapped in its own one-element list. A bare string
# would be iterated character by character as if each were a reference.
metric = torchmetrics.text.BLEUScore()
bleu = metric(predicted, [[reference] for reference in expected])
writer.add_scalar('validation BLEU', bleu, global_step)
writer.flush()

In [None]:
cer, wer, bleu

(tensor(1.2311), tensor(2.7609), tensor(0.))

## Debug greedy decode

In [None]:
model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

In [None]:
# `val_ds` only exists inside get_ds (which returns the data loaders), so take
# one batch from the validation loader instead.
# NOTE(review): this also yields the leading batch dimension the following
# cells appear to expect - confirm.
batch = next(iter(val_dataloader))

In [None]:
encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

In [None]:
# def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
sos_idx = tokenizer_tgt.token_to_id('[SOS]')
eos_idx = tokenizer_tgt.token_to_id('[EOS]')


In [None]:
source = encoder_input
source_mask = encoder_mask

In [None]:
# Precompute the encoder output and reuse it for every step
encoder_output = model.encode(source, source_mask)

In [None]:
encoder_output

tensor([[[ 0.1318,  0.2228, -0.1117,  ...,  0.2391, -0.1218, -0.4987],
         [ 0.1752,  0.1941, -0.0727,  ...,  0.2105, -0.1081, -0.4962],
         [ 0.1814,  0.1903, -0.0675,  ...,  0.2199, -0.1120, -0.5122],
         ...,
         [ 0.1550,  0.2167, -0.1046,  ...,  0.2522, -0.0928, -0.5195],
         [ 0.1506,  0.1858, -0.0776,  ...,  0.2522, -0.0940, -0.5201],
         [ 0.1235,  0.1759, -0.0769,  ...,  0.2546, -0.0882, -0.5206]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
# Initialize the decoder input with the sos token
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

In [None]:
decoder_input

tensor([[2]], device='cuda:0')

In [None]:
max_len = 400

In [None]:

# Inference only: no_grad keeps every decoding step from recording an
# autograd graph over the ever-growing decoder input.
with torch.no_grad():
    while True:
        # Stop once the decoder input has reached the maximum length
        if decoder_input.size(1) == max_len:
            break

        # build mask for target
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])

        _, next_word = torch.max(prob, dim=1)
        print(next_word)
        # Append the chosen token id to the decoder input
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        # Stop as soon as the end-of-sequence token is produced
        if next_word == eos_idx:
            break

# return decoder_input.squeeze(0)


tensor([0], device='cuda:0')
tensor([3], device='cuda:0')


In [None]:
_, ind = torch.sort(model.project(out[:, -1]))


In [None]:
ind

tensor([[ 760, 6245, 2906,  ..., 1023,    0,    3]], device='cuda:0')

In [None]:
next_word = ind[0, -3]
next_word


tensor(1023, device='cuda:0')

In [None]:
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

In [None]:
decoder_input = torch.cat(
        [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
    )

In [None]:
model_out = decoder_input
model_out

tensor([[ 2, 34, 43, 32,  3]], device='cuda:0')

In [None]:
model_out.detach().cpu().numpy()

array([[ 2, 34, 43, 32,  3]])

In [None]:
# decode() needs the token ids: drop the batch dimension of `model_out` and
# pass the ids as a plain list of ints.
tokenizer_tgt.decode(model_out.squeeze(0).tolist())

TypeError: ignored