In [None]:
!pip install evaluate tqdm scikit-learn

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import evaluate
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from tqdm import tqdm

In [None]:
from itertools import product

# Give every trinucleotide over the alphabet ACGT an integer id, numbered in the
# order itertools.product emits them ('AAA' -> 0, 'AAC' -> 1, ..., 'TTT' -> 63).
mapping = {
    first + second + third: token_id
    for token_id, (first, second, third) in enumerate(product('ACGT', repeat=3))
}

def reverse_complement(seq):
    """Return the reverse complement of a nucleotide sequence.

    Every base is replaced by its pairing partner (A<->T, C<->G) and the
    order of the bases is reversed.

    Args:
        seq: iterable of single-character bases drawn from 'A', 'C', 'G', 'T'.

    Returns:
        The reverse-complemented sequence as a string.
    """
    partner = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
    complemented = [partner.get(base) for base in seq]
    complemented.reverse()
    return ''.join(complemented)

# Lookup from a trinucleotide id to the id of its reverse-complement trinucleotide.
rev_map = {
    token_id: mapping[reverse_complement(trimer)]
    for trimer, token_id in mapping.items()
}

In [None]:
def nuc2image(seq, flip):
    """One-hot encode a sequence of token ids as a single-channel image tensor.

    Args:
        seq: sequence of integer token ids, each used as a row index into a
            64x64 identity matrix (so ids must lie in 0..63).
        flip: truthy to reverse-complement the sequence first, i.e. translate
            every id through ``rev_map`` and reverse the order.

    Returns:
        A ``torch.uint8`` tensor of shape ``(1, 1, len(seq), 64)``.
    """
    token_ids = [rev_map.get(token) for token in seq][::-1] if flip else seq
    # Row i of the identity matrix is the one-hot vector for id i.
    one_hot = np.eye(64, dtype=np.uint8)[token_ids]
    # Two leading singleton axes: batch and channel.
    return torch.as_tensor(one_hot, dtype=torch.uint8)[None, None]

In [None]:
def gen_collate_fn(n_channels=2, random_flip=0.0):
    """Build a DataLoader ``collate_fn`` that turns rows into image batches.

    Args:
        n_channels: number of channels the concatenated sequence is folded into.
        random_flip: probability that a batch is reverse-complemented.

    Returns:
        A function mapping a list of rows (each with 'upstream', 'downstream'
        and 'label' entries) to ``(features, labels)``, where ``features`` has
        shape ``(-1, n_channels, width, 64)``.
    """
    def collate_fn(batch):
        # A single draw per batch: every row in the batch shares the same decision.
        flip = torch.rand(1) < random_flip
        # Reverse-complementing the full sequence also swaps the two flanks.
        if flip:
            first_key, second_key = 'downstream', 'upstream'
        else:
            first_key, second_key = 'upstream', 'downstream'
        upstream = torch.vstack([nuc2image(row[first_key], flip) for row in batch])
        downstream = torch.vstack([nuc2image(row[second_key], flip) for row in batch])
        features = torch.cat([upstream, downstream], axis=2)

        # Trim positions that do not divide evenly into n_channels, as
        # symmetrically as possible (any odd extra position comes off the end).
        total = features.shape[2]
        width, leftover = divmod(total, n_channels)
        start = leftover // 2
        stop = total - (start + leftover % 2)
        features = features[:, :, start:stop].view(-1, n_channels, width, 64)

        labels = torch.as_tensor([row['label'] for row in batch])
        return (features, labels)
    return collate_fn

In [None]:
# Load the 'train' split and shuffle it once with a fixed seed so that the
# splits below are reproducible.
dataset = load_dataset('dvgodoy/DeepGSR_trinucleotides', split='train').shuffle(seed=13)

# Hold out 25% for testing, then 20% of what remains for validation.
# shuffle=False because the rows were already shuffled above.
train_test = dataset.train_test_split(test_size=0.25, shuffle=False)
train_val = train_test['train'].train_test_split(test_size=0.2, shuffle=False)

dataset = DatasetDict({
    'train': train_val['train'],
    'val': train_val['test'],
    'test': train_test['test'],
})
dataset

In [None]:
# Keep only the rows for one (signal, motif, organism) combination.
signal, motif, organism = 'PAS', 'AATAAA', 'hs'
dataset = dataset.filter(
    lambda row: (row['signal'], row['motif'], row['organism']) == (signal, motif, organism)
)

In [None]:
bsize = 256
n_channels = 2

# 'train' applies random reverse-complement flipping; 'train_base' serves the
# same rows without flipping. 'val' and 'test' are neither flipped nor shuffled.
dataloaders = {
    'train': DataLoader(
        dataset['train'],
        batch_size=bsize,
        shuffle=True,
        collate_fn=gen_collate_fn(n_channels=n_channels, random_flip=0.25),
    ),
    'train_base': DataLoader(
        dataset['train'],
        batch_size=bsize,
        shuffle=True,
        collate_fn=gen_collate_fn(n_channels=n_channels),
    ),
    'val': DataLoader(
        dataset['val'],
        batch_size=bsize,
        shuffle=False,
        collate_fn=gen_collate_fn(n_channels=n_channels),
    ),
    'test': DataLoader(
        dataset['test'],
        batch_size=bsize,
        shuffle=False,
        collate_fn=gen_collate_fn(n_channels=n_channels),
    ),
}

In [None]:
class DNACNN(nn.Module):
    """Two-stage 2-D convolutional classifier over one-hot sequence images.

    Pipeline: conv(32 maps, 30x31, 'same') -> ReLU -> max-pool 1x2
    -> conv(64 maps, 10x8, unpadded) -> ReLU -> max-pool 1x2 -> channel dropout
    -> adaptive max-pool to 8x8 -> flatten -> dropout -> linear(4096, 64)
    -> ReLU -> dropout -> linear(64, num_classes).

    Args:
        in_channels: number of channels in the input image.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=2, num_classes=2):
        super().__init__()

        # Stage 1: 32 feature maps, 30x31 kernel; 'same' padding keeps H and W.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=(30, 31), padding='same', bias=True)
        # 1x2 pooling halves the width and leaves the height untouched.
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))

        # Stage 2: 64 feature maps, 10x8 kernel, no padding (H shrinks by 9, W by 7).
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(10, 8), bias=True)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))

        # Reduce every feature map to a fixed 8x8 grid so the classifier head
        # does not depend on the input's spatial size.
        self.global_pool = nn.AdaptiveMaxPool2d((8, 8))
        self.dropout = nn.Dropout(p=0.3)
        self.dropout2d = nn.Dropout2d(p=0.3)

        # 64 channels x 8 x 8 pooled positions = 4096 flattened features.
        self.fc1 = nn.Linear(64 * 8 * 8, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a batch of images (B, in_channels, H, W) to logits (B, num_classes)."""
        x = self.pool1(F.relu(self.conv1(x)))   # (B, 32, H, W // 2)
        x = self.pool2(F.relu(self.conv2(x)))   # (B, 64, H - 9, (W // 2 - 7) // 2)

        x = self.dropout2d(x)

        x = self.global_pool(x).flatten(1)      # (B, 64 * 8 * 8)

        x = self.dropout(x)
        x = self.dropout(F.relu(self.fc1(x)))   # (B, 64)
        return self.fc2(x)                      # (B, num_classes), raw logits

# Seed right before building the model so its initial weights are reproducible.
torch.manual_seed(11)
model = DNACNN(in_channels=n_channels, num_classes=2)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)

# Halve the learning rate when the value passed to scheduler.step() stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=3,
)

In [None]:
# Count only the parameters that receive gradients.
total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print("Trainable parameters:", total_params)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

num_epochs = 100

# Per-epoch loss history. Filled with NaN (rather than left uninitialised) so
# that epochs skipped by early stopping are recognisable instead of holding
# arbitrary memory contents.
losses = torch.full((num_epochs,), float('nan'))
val_losses = torch.full((num_epochs,), float('nan'))

# Early-stopping bookkeeping: lowest validation loss so far, the epoch it was
# seen in, and how many epochs without improvement are tolerated.
best_loss = torch.inf
best_epoch = -1
patience = 10

model.to(device)

progress_bar = tqdm(range(num_epochs))

for epoch in progress_bar:
    ## Training
    # Training mode (enables dropout); set once per epoch, not once per batch.
    model.train()
    batch_losses = []

    for batch in dataloaders['train']:
        # Send batch features and targets to the device
        features = batch[0].float().to(device)
        labels = batch[1].long().to(device)

        # Step 1 - forward pass
        predictions = model(features)

        # Step 2 - computing the loss
        loss = loss_fn(predictions, labels)

        # Step 3 - computing the gradients
        loss.backward()

        batch_losses.append(loss.item())

        # Step 4 - updating parameters and zeroing gradients
        optimizer.step()
        optimizer.zero_grad()

    # Epoch loss is the plain (unweighted) mean of the per-batch losses.
    losses[epoch] = torch.tensor(batch_losses).mean()

    ## Validation
    # Evaluation mode (disables dropout); set once per epoch.
    model.eval()
    with torch.inference_mode():
        batch_losses = []

        for val_batch in dataloaders['val']:
            # Send batch features and targets to the device
            features = val_batch[0].float().to(device)
            labels = val_batch[1].long().to(device)

            # Step 1 - forward pass
            predictions = model(features)

            # Step 2 - computing the loss
            loss = loss_fn(predictions, labels)

            batch_losses.append(loss.item())

        val_losses[epoch] = torch.tensor(batch_losses).mean()

        # The plateau scheduler monitors the validation loss.
        scheduler.step(val_losses[epoch])
        print(losses[epoch], val_losses[epoch])

        if val_losses[epoch] < best_loss:
            # New best: remember it and checkpoint model + optimizer state.
            best_loss = val_losses[epoch]
            best_epoch = epoch
            torch.save({'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       'best_model.pth')
        elif (epoch - best_epoch) > patience:
            # More than `patience` epochs have passed since the best epoch.
            print(f"Early stopping at epoch #{epoch}")
            break


In [None]:
# Restore the weights saved at the epoch with the lowest validation loss.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written in a GPU session can still be opened in a CPU-only one.
states = torch.load('best_model.pth', map_location=device)
model.load_state_dict(states['model'])

In [None]:
metric1 = evaluate.load('precision', average=None)
metric2 = evaluate.load('recall', average=None)
metric3 = evaluate.load('accuracy')

# Evaluation mode (disables dropout).
model.eval()

# No gradients are needed for scoring, so skip autograd bookkeeping entirely.
with torch.inference_mode():
    for split in ['train_base', 'val', 'test']:
        for batch in tqdm(dataloaders[split]):
            features, labels = batch
            features = features.float().to(device)

            predictions = model(features)

            # argmax over dim=1 already yields one class per row, shape (B,).
            # No squeeze(): it would collapse a size-1 batch to a 0-dim tensor,
            # and .tolist() would then return a bare int instead of a list.
            pred_class = predictions.argmax(dim=1).tolist()
            labels = labels.tolist()

            metric1.add_batch(references=labels, predictions=pred_class)
            metric2.add_batch(references=labels, predictions=pred_class)
            metric3.add_batch(references=labels, predictions=pred_class)
        # average=None reports precision/recall per class rather than one aggregate.
        print(split, metric1.compute(average=None), metric2.compute(average=None), metric3.compute())