In [None]:
!pip install evaluate tqdm scikit-learn

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import evaluate
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from tqdm import tqdm

In [None]:
def gen_collate_fn(augment=False):
    """Build a DataLoader collate function that keeps every third token.

    Each row is indexed with the keys 'upstream', 'downstream' and 'label'.
    Both token sequences are sliced as ``seq[offset::3]`` and the two slices
    are concatenated along the sequence axis.

    Args:
        augment: when True, the offset is drawn from {1, 2, 3} once per batch;
            when False it is always 3.

    Returns:
        A function mapping a list of rows to a ``(token_ids, labels)`` tuple
        of tensors.
    """
    def collate_fn(batch):
        # A single offset is shared by the whole batch
        # (original note: this keeps every sequence with 99 tokens).
        if augment:
            offset = np.random.randint(3) + 1
        else:
            offset = 3

        def strided(field):
            # Every third token of `field`, starting at `offset`, for all rows
            return torch.as_tensor([row[field][offset::3] for row in batch])

        token_ids = torch.cat([strided('upstream'), strided('downstream')], dim=1)
        labels = torch.as_tensor([row['label'] for row in batch])
        return (token_ids, labels)
    return collate_fn

In [None]:
# Load the single 'train' split and shuffle it once with a fixed seed; the
# splits below are then taken from that ordering without reshuffling.
dataset = load_dataset('dvgodoy/DeepGSR_trinucleotides', split='train').shuffle(seed=19)

# Hold out 25% for test, then 20% of the remainder for validation.
train_test = dataset.train_test_split(test_size=0.25, shuffle=False)
train_val = train_test['train'].train_test_split(test_size=0.2, shuffle=False)

dataset = DatasetDict({
    'train': train_val['train'],
    'val': train_val['test'],
    'test': train_test['test'],
})
dataset

In [None]:
# Selection criteria: only rows matching all three values are kept.
signal = 'PAS'
motif = 'AATAAA'
organism = 'hs'


def _matches_selection(row):
    """Return True when the row matches the selected signal, motif and organism."""
    return (
        row['signal'] == signal
        and row['motif'] == motif
        and row['organism'] == organism
    )


dataset = dataset.filter(_matches_selection)

In [None]:
import math
import torch
import torch.nn as nn

In [None]:
bsize = 2048

# 'train' uses the augmenting collate function (random offset per batch);
# 'train_base' iterates the same training split with the fixed offset.
dataloaders = {
    'train': DataLoader(
        dataset['train'], batch_size=bsize, shuffle=True, collate_fn=gen_collate_fn(True)
    ),
    'train_base': DataLoader(
        dataset['train'], batch_size=bsize, shuffle=True, collate_fn=gen_collate_fn()
    ),
    'val': DataLoader(
        dataset['val'], batch_size=bsize, shuffle=False, collate_fn=gen_collate_fn()
    ),
    'test': DataLoader(
        dataset['test'], batch_size=bsize, shuffle=False, collate_fn=gen_collate_fn()
    ),
}

In [None]:
def tokens_to_onehot(tokens, vocab_size=64):
    """One-hot encode a 2-D tensor of integer token IDs.

    Args:
        tokens: (B, L) tensor of integer IDs in [0, vocab_size - 1].
        vocab_size: size of the trailing one-hot axis.

    Returns:
        A float32 tensor of shape (B, L, vocab_size) on the same device as
        `tokens`, with a single 1.0 per (row, position).
    """
    n_rows, n_cols = tokens.shape  # unpacking also rejects inputs that are not 2-D
    encoded = tokens.new_zeros((n_rows, n_cols, vocab_size), dtype=torch.float32)
    # Write a 1.0 at the index given by each token ID along the last axis
    return encoded.scatter_(-1, tokens[..., None], 1.0)

In [None]:
class LogisticRegressionMeanPool(nn.Module):
    """
    Bag-of-tokens logistic regression.

    The one-hot sequence is averaged over the sequence axis and the resulting
    vector is fed to a single linear unit. The output is a raw logit; no
    sigmoid is applied here.

    Input shape: (B, seq_len, vocab_size), e.g. (B, 198, 64)
    """
    def __init__(self, vocab_size=64):
        super().__init__()
        self.linear = nn.Linear(in_features=vocab_size, out_features=1)

    def forward(self, x):
        token_freqs = x.mean(dim=1)      # (B, seq_len, vocab) -> (B, vocab)
        return self.linear(token_freqs)  # (B, 1) logits

# Seed before building the model so its parameter initialisation is repeatable.
torch.manual_seed(1337)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = LogisticRegressionMeanPool()
model = model.to(device)

# The loss takes the raw logits returned by the model.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

num_epochs = 100

# Per-epoch mean losses. torch.empty leaves the entries uninitialised, so only
# the positions of epochs that actually ran hold meaningful values.
losses = torch.empty(num_epochs)
val_losses = torch.empty(num_epochs)

# Early-stopping state: best validation loss so far and the epoch it occurred in
best_loss = torch.inf
best_epoch = -1
patience = 10

model.to(device)

progress_bar = tqdm(range(num_epochs))

for epoch in progress_bar:
    ## Training
    # The mode only has to be set once per phase, not once per batch
    model.train()
    batch_losses = []

    for batch in dataloaders['train']:
        # Move the integer token IDs to the device first and expand them to
        # one-hot there: this transfers one integer per token instead of a
        # vocab-sized float vector per token
        features = tokens_to_onehot(batch[0].to(device))
        labels = batch[1].view(-1, 1).float().to(device)

        # Step 1 - forward pass
        predictions = model(features)

        # Step 2 - computing the loss
        loss = loss_fn(predictions, labels)

        # Step 3 - computing the gradients
        loss.backward()

        batch_losses.append(loss.item())

        # Step 4 - updating parameters and zeroing gradients
        optimizer.step()
        optimizer.zero_grad()

    losses[epoch] = torch.tensor(batch_losses).mean()

    ## Validation
    model.eval()
    with torch.inference_mode():
        batch_losses = []

        for val_batch in dataloaders['val']:
            features = tokens_to_onehot(val_batch[0].to(device))
            labels = val_batch[1].view(-1, 1).float().to(device)

            # Step 1 - forward pass
            predictions = model(features)

            # Step 2 - computing the loss
            loss = loss_fn(predictions, labels)

            batch_losses.append(loss.item())

        val_losses[epoch] = torch.tensor(batch_losses).mean()

        print(losses[epoch], val_losses[epoch])

        if val_losses[epoch] < best_loss:
            # New best validation loss: checkpoint model and optimizer state
            best_loss = val_losses[epoch]
            best_epoch = epoch
            torch.save({'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       'best_model.pth')
        elif (epoch - best_epoch) > patience:
            # Stop once more than `patience` epochs have passed since the best one
            print(f"Early stopping at epoch #{epoch}")
            break


In [None]:
# `epoch` is the index of the last epoch that ran and its loss was recorded,
# so the slice must include position `epoch` itself.
losses[:epoch + 1]

In [None]:
# `epoch` is the index of the last epoch that ran and its validation loss was
# recorded, so the slice must include position `epoch` itself.
val_losses[:epoch + 1]

In [None]:
# Restore the weights saved at the best validation epoch. map_location keeps
# the load working when the checkpoint was written on a different device
# (e.g. saved on GPU, reloaded in a CPU-only session).
states = torch.load('best_model.pth', map_location=device)
model.load_state_dict(states['model'])

In [None]:
metric1 = evaluate.load('precision', average=None)
metric2 = evaluate.load('recall', average=None)
metric3 = evaluate.load('accuracy')

model.eval()

# Evaluation needs no gradients
with torch.inference_mode():
    for split in ['train_base', 'val', 'test']:
        for batch in tqdm(dataloaders[split]):
            features, labels = batch
            # Expand to one-hot on the target device (only the token IDs are moved)
            features = tokens_to_onehot(features.to(device))

            # The model returns raw logits (it is trained with BCEWithLogitsLoss),
            # so convert them to probabilities before applying the 0.5 threshold
            predictions = torch.sigmoid(model(features))

            # squeeze(-1) rather than squeeze(): a batch holding a single row
            # must still give a 1-D tensor so that tolist() returns a list
            pred_class = (predictions >= 0.5).squeeze(-1).to(int)

            pred_class = pred_class.tolist()
            labels = labels.tolist()

            metric1.add_batch(references=labels, predictions=pred_class)
            metric2.add_batch(references=labels, predictions=pred_class)
            metric3.add_batch(references=labels, predictions=pred_class)
        # One report per split, computed from the batches added above
        print(split, metric1.compute(average=None), metric2.compute(average=None), metric3.compute())