# DeepGSR Hybrid CNN-Transformer Model

This notebook implements a hybrid architecture combining the **CNN** feature extractor (from `DeepGSR_CNN_model.ipynb`) and the **Transformer** encoder (from `DeepGSR_Transformer_model.ipynb`).

**Architecture:**
Input (One-hot Image) -> CNN Layers -> Feature Sequence -> Projection -> Transformer Encoder -> Classifier

**Testing:**
Evaluates precision, recall, and accuracy separately for each organism (hs, bt, dm, mm) on every data split (train, val, test).

In [None]:
!pip install evaluate tqdm scikit-learn

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import evaluate
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from tqdm import tqdm
import math
from itertools import product

In [None]:
# --- CNN Preprocessing Utils ---

# Index of every trinucleotide over the alphabet A, C, G, T, numbered in the
# order itertools.product yields them ('AAA' -> 0, 'AAC' -> 1, ..., 'TTT' -> 63).
mapping = {
    triplet: index
    for index, triplet in enumerate(map(''.join, product('ACGT', repeat=3)))
}

def reverse_complement(seq):
    """Return the reverse complement of a nucleotide string.

    Each base is swapped for its pairing partner (A<->T, C<->G) and the
    result is read back to front.
    """
    complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
    paired = [complement.get(base) for base in seq]
    paired.reverse()
    return ''.join(paired)

# For every trinucleotide index, the index of that trinucleotide's reverse
# complement (used to flip an already-encoded sequence without decoding it).
rev_map = {
    index: mapping[reverse_complement(triplet)]
    for triplet, index in mapping.items()
}

def nuc2image(seq, flip):
    """One-hot encode a sequence of trinucleotide indices as an image tensor.

    Args:
        seq: indices in the range 0..63 (presumably a flat list of ints, as
            stored in the dataset rows -- confirm against the dataset).
        flip: when truthy, every index is replaced by the index of its
            reverse complement (via ``rev_map``) and the order is reversed.

    Returns:
        A ``torch.uint8`` tensor with two leading singleton dimensions, i.e.
        shape ``(1, 1, len(seq), 64)`` for a flat list of indices.
    """
    indices = seq
    if flip:
        indices = [rev_map.get(index) for index in seq]
        indices.reverse()
    # Row i of the 64x64 identity matrix is the one-hot vector for index i.
    one_hot = np.eye(64, dtype=np.uint8)[indices]
    # Indexing with two Nones prepends the two singleton dimensions.
    return torch.as_tensor(one_hot, dtype=torch.uint8)[None, None]

def gen_collate_fn(n_channels=2, random_flip=0.0):
    """Build a DataLoader collate function for the trinucleotide rows.

    Args:
        n_channels: number of channels the concatenated upstream+downstream
            sequence is cut into.
        random_flip: probability that a whole batch is reverse-complemented
            (the decision is drawn once per batch, not per sample).

    Returns:
        A function mapping a list of rows (each with 'upstream', 'downstream'
        and 'label' entries) to a ``(features, labels)`` tuple, where
        ``features`` has shape ``(-1, n_channels, width, 64)``.
    """
    def collate_fn(batch):
        flip = torch.rand(1) < random_flip
        # A reverse-complemented sequence reads in the opposite direction, so
        # the two flanks trade places when the batch is flipped.
        if flip:
            first_key, second_key = 'downstream', 'upstream'
        else:
            first_key, second_key = 'upstream', 'downstream'
        upstream = torch.vstack([nuc2image(row[first_key], flip) for row in batch])
        downstream = torch.vstack([nuc2image(row[second_key], flip) for row in batch])
        features = torch.cat([upstream, downstream], axis=2)

        # Positions that do not divide evenly into n_channels are trimmed,
        # split between the two ends (the extra one comes off the tail).
        total_len = features.shape[2]
        width, leftover = divmod(total_len, n_channels)
        start = leftover // 2
        stop = total_len - (leftover - start)
        features = features[:, :, start:stop].view(-1, n_channels, width, 64)

        labels = torch.as_tensor([row['label'] for row in batch])
        return (features, labels)
    return collate_fn

In [None]:
# Load the 'train' split of the hub dataset and shuffle it once with a fixed
# seed. The splits below use shuffle=False, so they are deterministic slices
# of this shuffled order.
dataset = load_dataset('dvgodoy/DeepGSR_trinucleotides', split='train')
dataset = dataset.shuffle(seed=13)
# Hold out 25% of the rows as the test set.
train_test = dataset.train_test_split(test_size=0.25, shuffle=False)
# Of the remaining 75%, 20% becomes validation: 60% train / 15% val / 25% test overall.
train_val = train_test['train'].train_test_split(test_size=0.2, shuffle=False)
dataset = DatasetDict({'train': train_val['train'], 'val': train_val['test'], 'test': train_test['test']})
# Bare expression so the notebook displays the resulting DatasetDict.
dataset

In [None]:
# Restrict every split to a single signal/motif combination. No organism
# filter is applied here, so rows from all organisms are kept for training.
signal = 'PAS'
motif = 'AATAAA'
# organism = 'hs' # Training on all organisms for better generalization
dataset = dataset.filter(lambda row: row['signal'] == signal and row['motif'] == motif)

In [None]:
bsize = 256
n_channels = 2

# 'train' is the only loader with reverse-complement augmentation (25% of the
# batches); 'train_base' serves the same rows without augmentation.
dataloaders = {
    'train': DataLoader(
        dataset['train'],
        batch_size=bsize,
        shuffle=True,
        collate_fn=gen_collate_fn(n_channels=n_channels, random_flip=0.25),
    ),
    'train_base': DataLoader(
        dataset['train'],
        batch_size=bsize,
        shuffle=True,
        collate_fn=gen_collate_fn(n_channels=n_channels),
    ),
    'val': DataLoader(
        dataset['val'],
        batch_size=bsize,
        shuffle=False,
        collate_fn=gen_collate_fn(n_channels=n_channels),
    ),
    'test': DataLoader(
        dataset['test'],
        batch_size=bsize,
        shuffle=False,
        collate_fn=gen_collate_fn(n_channels=n_channels),
    ),
}

In [None]:
class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.

    A table of shape (1, max_len, d_model) is precomputed and stored as the
    buffer ``pe``: even feature columns hold sines and odd columns hold
    cosines of the position scaled by geometrically spaced frequencies.
    Both even and odd ``d_model`` are supported.

    Args:
        d_model: size of the feature dimension of the inputs.
        max_len: number of positions in the table; inputs passed to
            ``forward`` must not be longer than this.
    """
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine block is cut to the number of odd columns available.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # A buffer follows the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encoding for the first ``x.size(1)`` positions to ``x``.

        ``x`` is indexed as (batch, sequence, feature); the table is broadcast
        over the batch dimension.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


In [None]:
class HybridCNNTransformer(nn.Module):
    """CNN feature extractor feeding a Transformer encoder and a linear classifier.

    Pipeline: two Conv2d/ReLU/MaxPool stages -> each row of the feature map
    becomes one token -> linear projection to ``d_model`` -> sinusoidal
    positional encoding -> Transformer encoder -> mean over tokens -> logits.

    Args:
        cnn_in_channels: channels of the input image.
        num_classes: number of output logits.
        d_model: token width inside the Transformer.
        n_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        dim_feedforward: hidden width of each encoder layer's feed-forward part.
        dropout: dropout rate inside the encoder layers.
    """

    def __init__(self, cnn_in_channels=2, num_classes=2, d_model=256, n_heads=4, num_layers=2, dim_feedforward=512, dropout=0.1):
        super().__init__()

        # --- CNN feature extractor ---
        # Both pooling layers halve only the last (one-hot) axis; the
        # sequence axis is shortened solely by the unpadded second convolution.
        self.conv1 = nn.Conv2d(cnn_in_channels, 32, kernel_size=(30, 31), padding='same', bias=True)
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(10, 8), bias=True)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.dropout2d = nn.Dropout2d(p=0.3)

        # Shape walk-through for an input of (B, 2, 298, 64), written as
        # (channels, sequence axis, one-hot axis):
        #   conv1 ('same')  -> (32, 298, 64)
        #   pool1 (1, 2)    -> (32, 298, 32)
        #   conv2 (10, 8)   -> (64, 298-10+1=289, 32-8+1=25)
        #   pool2 (1, 2)    -> (64, 289, (25-2)//2+1=12)
        conv_out_channels = 64
        pooled_onehot_width = 12
        cnn_features_dim = conv_out_channels * pooled_onehot_width  # 768

        # --- Projection of each token to the Transformer width ---
        self.feature_projection = nn.Linear(cnn_features_dim, d_model)

        # --- Transformer encoder ---
        self.pos_enc = PositionalEncoding(d_model, max_len=600)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- Classification head ---
        self.fc = nn.Linear(d_model, num_classes)

    def _extract_features(self, x):
        """Run the convolutional stage; (B, 2, 298, 64) in -> (B, 64, 289, 12) out."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        return self.dropout2d(x)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for a batch of images."""
        feature_map = self._extract_features(x)

        # Dimension 2 of the feature map is the sequence axis: move it next to
        # the batch axis and merge channels with the remaining axis, so every
        # sequence position becomes one token -> (B, 289, 768).
        tokens = feature_map.permute(0, 2, 1, 3).flatten(start_dim=2)

        tokens = self.feature_projection(tokens)  # (B, 289, d_model)
        encoded = self.encoder(self.pos_enc(tokens))

        # Mean-pool over the token axis, then classify.
        pooled = encoded.mean(dim=1)  # (B, d_model)
        return self.fc(pooled)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed before the model is built so weight initialisation is repeatable.
torch.manual_seed(1337)

model = HybridCNNTransformer(
    cnn_in_channels=2,
    num_classes=2,
    d_model=256,
    n_heads=4,
    num_layers=2,
    dim_feedforward=512,
    dropout=0.1,
)
model = model.to(device)

# Count only the parameters that receive gradients.
trainable_param_counts = [p.numel() for p in model.parameters() if p.requires_grad]
print(f"Model parameters: {sum(trainable_param_counts)}")

In [None]:
# Cross-entropy over the class logits, with a small amount of label smoothing.
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
# AdamW over all model parameters; lr is the base rate that the scheduler's
# multiplier is applied to.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

def lr_lambda(step, total_steps=None, planned_epochs=100):
    """Learning-rate multiplier: linear warm-up followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over the first 6% of the
    schedule, then follows half a cosine down to 0 at ``total_steps`` and
    stays at 0 afterwards (the progress is clamped, so the cosine can never
    wrap around and push the learning rate back up).

    Args:
        step: number of scheduler steps taken so far.
        total_steps: length of the schedule in scheduler steps. When omitted
            it is derived from the training loader, because the scheduler is
            stepped once per training batch:
            ``planned_epochs * len(dataloaders['train'])``.
        planned_epochs: number of epochs the schedule should span when
            ``total_steps`` is not given; keep it equal to ``num_epochs`` in
            the training cell.

    Returns:
        A float in [0, 1] that multiplies the optimizer's base learning rate.
    """
    if total_steps is None:
        total_steps = planned_epochs * len(dataloaders['train'])
    warmup_steps = 0.06 * total_steps
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    # Past the end of the schedule, hold the final value instead of letting
    # cos() start a new cycle.
    progress = min(1.0, progress)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# LambdaLR scales the optimizer's base learning rate by lr_lambda(step); the
# training loop calls scheduler.step() once per batch.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
# Upper bound on the number of epochs; early stopping below may end sooner.
num_epochs = 100

# Per-epoch mean losses. Entries for epochs that never run (early stop) stay
# uninitialised because torch.empty does not zero the memory.
losses = torch.empty(num_epochs)
val_losses = torch.empty(num_epochs)

# Early-stopping state: best validation loss seen so far and the epoch it
# occurred in; training stops once more than `patience` epochs pass without
# an improvement.
best_loss = torch.inf
best_epoch = -1
patience = 10

progress_bar = tqdm(range(num_epochs))

for epoch in progress_bar:
    batch_losses = []
    
    ## Training (augmented loader, dropout active)
    model.train()
    for batch in dataloaders['train']:
        # The collate function yields uint8 one-hot features; the model needs floats.
        features = batch[0].float().to(device)
        labels = batch[1].long().to(device)
        
        predictions = model(features)
        loss = loss_fn(predictions, labels)
        loss.backward()

        batch_losses.append(loss.item())
        optimizer.step()
        optimizer.zero_grad()
        # The learning-rate schedule advances once per batch, not per epoch.
        scheduler.step()
        
    # Unweighted mean over batches (a smaller last batch counts like a full one).
    losses[epoch] = torch.tensor(batch_losses).mean()

    ## Validation (no gradients, dropout disabled)
    model.eval()
    with torch.no_grad():
        batch_losses = []
        for val_batch in dataloaders['val']:
            features = val_batch[0].float().to(device)
            labels = val_batch[1].long().to(device)

            predictions = model(features)
            loss = loss_fn(predictions, labels)
            batch_losses.append(loss.item())

        val_losses[epoch] = torch.tensor(batch_losses).mean()
        
        progress_bar.set_description(f"Train Loss: {losses[epoch]:.4f}, Val Loss: {val_losses[epoch]:.4f}")
        
        # Checkpoint on every strict improvement of the validation loss;
        # otherwise stop once the best epoch is more than `patience` epochs old.
        if val_losses[epoch] < best_loss:
            best_loss = val_losses[epoch]
            best_epoch = epoch
            torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'best_hybrid_model.pth')
        elif (epoch - best_epoch) > patience:
            print(f"Early stopping at epoch #{epoch}")
            break

In [None]:
# Restore the weights of the best checkpoint written by the training loop.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU also loads on a CPU-only runtime.
states = torch.load('best_hybrid_model.pth', map_location=device)
model.load_state_dict(states['model'])
print("Loaded best model.")

In [None]:
# Per-class precision and recall plus overall accuracy, reported for every
# (split, organism) pair.
metric1 = evaluate.load('precision', average=None)
metric2 = evaluate.load('recall', average=None)
metric3 = evaluate.load('accuracy')

model.eval()

# Reload dataset for filtering logic if needed, or reuse 'dataset' object
for split in ['train', 'val', 'test']:
    # Species specific testing. The filter runs eagerly inside the
    # comprehension, so each lambda sees the organism it was created for.
    subsets = [
        (org, dataset[split].filter(lambda row: row['organism'] == org and row['signal'] == signal and row['motif'] == motif))
        for org in ['hs', 'bt', 'dm', 'mm']
    ]

    for org, subset in subsets:
        print(f'Set: {split} / Organism: {org}')
        if len(subset) == 0:
            print("No samples.")
            continue

        # Same batch size and channel layout as the training loaders, and no
        # random flipping.
        dl = DataLoader(subset, batch_size=bsize, collate_fn=gen_collate_fn(n_channels=n_channels))

        for batch in tqdm(dl):
            features, labels = batch
            features = features.float().to(device)

            with torch.no_grad():
                predictions = model(features)

            # argmax over the class axis is already 1-D. Squeezing it would
            # collapse a single-sample batch to a scalar, and tolist() would
            # then return an int instead of a list.
            pred_class = predictions.argmax(dim=1).tolist()
            labels = labels.tolist()

            metric1.add_batch(references=labels, predictions=pred_class)
            metric2.add_batch(references=labels, predictions=pred_class)
            metric3.add_batch(references=labels, predictions=pred_class)

        # compute() consumes the accumulated batches, so each metric is
        # computed exactly once per subset; any failure is allowed to surface
        # rather than being retried on an already-emptied cache.
        precision = metric1.compute(average=None)
        recall = metric2.compute(average=None)
        accuracy = metric3.compute()
        print(split, precision, recall, accuracy)