In [None]:
!pip install evaluate tqdm scikit-learn

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import evaluate
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from tqdm import tqdm

In [None]:
def gen_collate_fn(augment=False):
    """Build a collate function that turns a list of rows into model inputs.

    Each row is read through its 'upstream', 'downstream' and 'label' keys.
    Both sequences are subsampled with a stride of 3 starting at a common
    offset, then concatenated along the token dimension.

    Args:
        augment: when True, the offset is drawn once per batch from {1, 2, 3}
            using NumPy's global RNG; when False, the offset is always 3.

    Returns:
        A function mapping a batch (list of rows) to a tuple
        ``(token_ids, labels)`` of tensors.
    """
    def collate_fn(batch):
        # Per the original note, offsets 1..3 keep every half-sequence at 99
        # tokens (this presumes 298-token inputs -- TODO confirm).
        if augment:
            offset = np.random.randint(3) + 1
        else:
            offset = 3

        upstream = torch.as_tensor([example['upstream'][offset::3] for example in batch])
        downstream = torch.as_tensor([example['downstream'][offset::3] for example in batch])
        labels = torch.as_tensor([example['label'] for example in batch])

        return (torch.cat([upstream, downstream], dim=1), labels)

    return collate_fn

In [None]:
# Load the single 'train' split and shuffle it once with a fixed seed; the
# two splits below use shuffle=False so they just cut the shuffled order.
dataset = load_dataset('dvgodoy/DeepGSR_trinucleotides', split='train').shuffle(seed=19)

# Hold out 25% as the test set, then 20% of the remainder as validation.
train_test = dataset.train_test_split(test_size=0.25, shuffle=False)
train_val = train_test['train'].train_test_split(test_size=0.2, shuffle=False)

dataset = DatasetDict({
    'train': train_val['train'],
    'val': train_val['test'],
    'test': train_test['test'],
})
dataset

In [None]:
# Keep only the rows matching one (signal, motif, organism) combination.
signal, motif, organism = 'PAS', 'AATAAA', 'hs'


def _is_selected_row(row):
    """Return True when the row matches the selected signal, motif and organism."""
    return row['signal'] == signal and row['motif'] == motif and row['organism'] == organism


dataset = dataset.filter(_is_selected_row)

In [None]:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.

    Adds a fixed (non-trainable) table of sines and cosines to the input.
    Produces (batch, seq_len, d_model) given (batch, seq_len, d_model) input.

    Args:
        d_model: size of the last (feature) dimension; may be even or odd.
        max_len: number of positions stored in the table; inputs longer than
            this are rejected in ``forward``.
    """
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to match instead of failing on assignment.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer: follows .to(device) and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch_size, seq_len, d_model)

        Raises:
            ValueError: if seq_len exceeds the max_len the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the addition fails with an opaque broadcast
            # error (or silently broadcasts when max_len == 1).
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional table size {max_len}"
            )
        return x + self.pe[:, :seq_len, :]

class TrinucTransformerClassifier(nn.Module):
    """Transformer-encoder classifier over sequences of integer token IDs.

    Pipeline: token embedding -> positional encoding -> Transformer encoder
    -> pooling (CLS position or mean over positions) -> dropout -> linear head.

    Args:
        vocab_size: number of distinct input token IDs (IDs in [0, vocab_size)).
        seq_len: maximum number of input tokens; sizes the positional table.
        d_model: embedding / hidden width.
        n_heads: number of attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        dim_feedforward: width of the feed-forward sublayer.
        dropout: dropout probability used in the encoder layers and the head.
        num_classes: number of output logits.
        use_cls_token: if True, a CLS token is prepended and its output is
            used as the pooled representation; otherwise outputs are averaged
            over the sequence dimension.
    """
    def __init__(
        self,
        vocab_size: int = 64,
        seq_len: int = 200,
        d_model: int = 128,
        n_heads: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        num_classes: int = 2,  # 2 for binary logit (CrossEntropyLoss)
        use_cls_token: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.d_model = d_model
        self.use_cls_token = use_cls_token

        # One extra embedding row / position is reserved for the CLS token.
        n_extra = 1 if use_cls_token else 0

        # Token embedding; row index `vocab_size` (when present) is the CLS token.
        self.token_emb = nn.Embedding(vocab_size + n_extra, d_model)
        # Positional encoding
        self.pos_enc = PositionalEncoding(d_model, max_len=seq_len + n_extra)

        # Transformer encoder (batch_first=True => (B, L, E))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        # NOTE(review): `self.dropout` is registered but not applied in
        # forward(); it is kept so the module's attributes stay unchanged.
        self.dropout = nn.Dropout(dropout)
        self.dropout_head = nn.Dropout(dropout)

        # Classification head: pooled representation -> logit(s)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch_size, L) of integer token IDs in [0, vocab_size), L <= seq_len
        Returns:
            logits: (batch_size, num_classes)
        Raises:
            ValueError: if the input has more than seq_len tokens.
        """
        if x.size(1) > self.seq_len:
            raise ValueError(
                f"Input has {x.size(1)} tokens but the model was built for at most {self.seq_len}"
            )

        if self.use_cls_token:
            # Prepend the reserved CLS id so that the CLS vector comes from the
            # embedding table (and is therefore trained), instead of being a
            # constant zero vector that leaves the reserved row unused.
            cls_ids = x.new_full((x.size(0), 1), self.vocab_size)
            x = torch.cat([cls_ids, x], dim=1)  # (B, L+1)

        # Embed tokens
        x = self.token_emb(x)  # (B, L or L+1, d_model)

        # Add positional encoding (shape unchanged)
        x = self.pos_enc(x)

        # Transformer encoder (shape unchanged)
        x = self.encoder(x)

        if self.use_cls_token:
            x = x[:, 0, :]       # (B, d_model), CLS position
        else:
            x = x.mean(dim=1)    # (B, d_model), mean over positions

        x = self.dropout_head(x)
        logits = self.fc(x)    # (B, num_classes)
        return logits


In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed torch's RNG before building the model so the initial weights are
# reproducible from run to run.
torch.manual_seed(1337)

model = TrinucTransformerClassifier(
    vocab_size=64,
    seq_len=200,
    d_model=192,
    n_heads=6,
    num_layers=4,
    dim_feedforward=768,
    dropout=0.1,
    num_classes=2,
    use_cls_token=False,  # pool by averaging over positions instead of a CLS token
)
model = model.to(device)

In [None]:
# Cross-entropy on the class logits, with a small amount of label smoothing.
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
)

def lr_lambda(step, total_steps=54 * 5, warmup_frac=0.06):
    """Learning-rate multiplier: linear warm-up followed by cosine decay.

    Args:
        step: number of scheduler steps taken so far (LambdaLR passes this).
        total_steps: length of the schedule in scheduler steps. The default
            keeps the original hard-coded 54*5 (presumably 54 batches x 5
            epochs -- confirm against the actual training length).
        warmup_frac: fraction of ``total_steps`` spent ramping up from 0 to 1.

    Returns:
        A factor in [0, 1] that multiplies the optimizer's base learning rate.
    """
    warmup_steps = warmup_frac * total_steps
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    # cosine decay afterwards
    progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    # Clamp so the factor stays at 0 once the schedule is exhausted; without
    # this the cosine keeps oscillating and the learning rate climbs back up
    # when training runs for more than `total_steps` steps.
    progress = min(1.0, progress)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# Scale the optimizer's base learning rate by lr_lambda(step) on every scheduler step.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [None]:
# Total number of trainable parameters in the model.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
bsize = 256

# 'train' uses the augmenting collate function (random stride offset per
# batch); every other loader uses the fixed-offset collate function.
# 'train_base' serves the training rows without augmentation.
dataloaders = {
    'train': DataLoader(
        dataset['train'], batch_size=bsize, shuffle=True, collate_fn=gen_collate_fn(True)
    ),
    'train_base': DataLoader(
        dataset['train'], batch_size=bsize, shuffle=True, collate_fn=gen_collate_fn()
    ),
    'val': DataLoader(
        dataset['val'], batch_size=bsize, shuffle=False, collate_fn=gen_collate_fn()
    ),
    'test': DataLoader(
        dataset['test'], batch_size=bsize, shuffle=False, collate_fn=gen_collate_fn()
    ),
}

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

num_epochs = 100

# Per-epoch mean losses. NaN-filled so that epochs which never run (early
# stopping) are recognisable instead of holding uninitialised memory.
losses = torch.full((num_epochs,), float('nan'))
val_losses = torch.full((num_epochs,), float('nan'))

# Early-stopping state: best validation loss so far and the epoch it occurred.
best_loss = torch.inf
best_epoch = -1
patience = 10

model.to(device)

progress_bar = tqdm(range(num_epochs))

for epoch in progress_bar:
    ## Training
    # The mode only needs to be set once per phase, not once per batch.
    model.train()
    batch_losses = []

    for batch in dataloaders['train']:
        # Send batch features and targets to the device
        features = batch[0].to(device)
        labels = batch[1].long().to(device)

        # Forward pass, loss and backpropagation
        predictions = model(features)
        loss = loss_fn(predictions, labels)
        loss.backward()

        batch_losses.append(loss.item())

        # Update parameters, clear gradients for the next batch
        optimizer.step()
        optimizer.zero_grad()

        # The LR schedule advances once per batch, not once per epoch
        scheduler.step()

    losses[epoch] = torch.tensor(batch_losses).mean()

    ## Validation
    model.eval()
    with torch.inference_mode():
        batch_losses = []

        for val_batch in dataloaders['val']:
            # Send batch features and targets to the device
            features = val_batch[0].to(device)
            labels = val_batch[1].long().to(device)

            predictions = model(features)
            loss = loss_fn(predictions, labels)

            batch_losses.append(loss.item())

        val_losses[epoch] = torch.tensor(batch_losses).mean()

        print(losses[epoch], val_losses[epoch])

        if val_losses[epoch] < best_loss:
            # New best validation loss: remember it and checkpoint the model
            best_loss = val_losses[epoch]
            best_epoch = epoch
            torch.save({'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       'best_model.pth')
        elif (epoch - best_epoch) > patience:
            # No improvement for more than `patience` epochs
            print(f"Early stopping at epoch #{epoch}")
            break


In [None]:
# `epoch` is the index of the last epoch that ran (and recorded a loss), so
# the slice must include it.
losses[:epoch + 1]

In [None]:
# `epoch` is the index of the last epoch that ran (and recorded a loss), so
# the slice must include it.
val_losses[:epoch + 1]

In [None]:
# Restore the weights saved at the best validation epoch. map_location remaps
# the stored tensors onto the current device, so the checkpoint still loads if
# it was written on a device that is not available now (e.g. GPU -> CPU).
states = torch.load('best_model.pth', map_location=device)
model.load_state_dict(states['model'])

In [None]:
# Per-class precision and recall (average=None), plus overall accuracy.
metric1 = evaluate.load('precision', average=None)
metric2 = evaluate.load('recall', average=None)
metric3 = evaluate.load('accuracy')

model.eval()

for split in ['train_base', 'val', 'test']:
    # Evaluation needs no gradients; skipping autograd avoids building a graph
    # for every batch.
    with torch.inference_mode():
        for batch in tqdm(dataloaders[split]):
            features, labels = batch
            features = features.to(device)

            predictions = model(features)

            # argmax over the class dimension already has shape (batch,).
            # Squeezing it would turn a one-sample batch into a scalar, and
            # tolist() would then return an int instead of a list.
            pred_class = predictions.argmax(dim=1).tolist()
            labels = labels.tolist()

            metric1.add_batch(references=labels, predictions=pred_class)
            metric2.add_batch(references=labels, predictions=pred_class)
            metric3.add_batch(references=labels, predictions=pred_class)
    # The metrics are computed and printed once per split, after all its batches.
    print(split, metric1.compute(average=None), metric2.compute(average=None), metric3.compute())