In [1]:
import math
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
def data_from_df(df, target_level = None, label_pipeline=None):
    barcodes = df["nucleotides"].to_list()
     
    if target_level:
        species = df[target_level].to_list()
        species = np.array(list(map(label_pipeline, species)))

    print(f"[INFO]: There are {len(barcodes)} barcodes")
    # Number of training samples and entire data
    N = len(barcodes)

    # Reading barcodes and labels into python list
    labels = []

    for i in range(N):
        if len(barcodes[i]) > 0:
            barcodes.append(barcodes[i])
            if target_level:
                labels.append(species[i])

    sl = 660  # Max_length

    nucleotide_dict = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}

    X = np.zeros((N, sl, 5), dtype=np.float32) #Can't do zeros because 
    for i in range(N):
        j = 0
        while j < min(sl, len(barcodes[i])):
            k = nucleotide_dict[barcodes[i][j]]
            X[i][j][k] = 1.0
            j += 1
                
            
    # print(X.shape, )
    return X, np.array(labels)

df = pd.read_csv("data/dev.csv")
X, labels = data_from_df(df)

[INFO]: There are 4278 barcodes


In [3]:

# 1) DATASET WITH CONTIGUOUS-CHUNK MASKING
class MaskedOneHotDataset(Dataset):
    def __init__(self, X, mask_ratio=0.15, chunk_size=4):
        """
        X: numpy array of shape (N, L, 5), one-hot over {A,C,G,T,N}.
           Padding rows must be all zeros.
        mask_ratio: fraction of real tokens to mask
        chunk_size: length of each contiguous masked span
        """
        self.X = torch.from_numpy(X)  # (N, L, 5)
        self.mask_ratio = mask_ratio
        self.chunk_size = chunk_size

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        x = self.X[idx].clone()  # (L, 5)
        L = x.size(0)

        # 1) Padding mask: 1 for real tokens (sum>0), 0 for padding (zero-vector)
        att_mask = (x.sum(dim=-1) > 0).int()

        # 2) Targets: argmax over one-hot; set padding positions to -1
        targets = x.argmax(dim=-1)
        targets[att_mask == 0] = -1  # padding → -1

        # 3) Determine valid maskable positions: real A/C/G/T only
        valid = (targets >= 0) & (targets < 4)
        n_valid = valid.sum().item()
        n_chunks = math.ceil(self.mask_ratio * n_valid / self.chunk_size)

        # 4) Sample non-overlapping contiguous spans
        starts = []
        while len(starts) < n_chunks:
            s = random.randrange(0, L - self.chunk_size + 1)
            if valid[s] and all(abs(s - p) >= self.chunk_size for p in starts):
                starts.append(s)

        mask = torch.zeros(L, dtype=torch.bool)
        for s in starts:
            mask[s : s + self.chunk_size] = True
        mask &= valid  # ensure no padding or N masked

        # 5) Apply uniform fill (1/5) to masked positions
        x_masked = x.clone()
        x_masked[mask] = 1.0 / 5.0

        return x_masked, targets, att_mask, mask
    
dataset = MaskedOneHotDataset(X)
dataset[3]

(tensor([[0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]),
 tensor([ 4,  0,  1,  3,  3,  3,  0,  3,  0,  1,  3,  3,  3,  0,  3,  3,  3,  3,
          3,  2,  2,  0,  2,  1,  3,  3,  2,  0,  3,  1,  2,  2,  2,  0,  0,  3,
          0,  0,  3,  3,  2,  2,  0,  0,  1,  3,  3,  1,  0,  3,  3,  0,  0,  2,
          0,  0,  3,  3,  1,  3,  0,  0,  3,  3,  1,  2,  0,  2,  1,  1,  2,  0,
          0,  1,  3,  0,  2,  2,  0,  1,  0,  3,  1,  1,  3,  2,  2,  0,  2,  1,
          0,  3,  3,  0,  0,  3,  3,  2,  2,  0,  2,  0,  3,  2,  0,  3,  1,  0,
          0,  0,  3,  3,  3,  0,  3,  0,  0,  3,  2,  3,  0,  0,  3,  3,  2,  3,
          3,  0,  1,  0,  2,  1,  3,  1,  0,  3,  2,  1,  3,  3,  3,  3,  0,  3,
          3,  0,  3,  0,  0,  3,  3,  3,  3,  3,  3,  3,  3,  0,  3,  0,  2,  3,
          0,  0,  3,  0,  1,  1,  0,  0,  3,  3,  0,  3,  0,  0,  3, 

In [4]:
class CNN_MLM(nn.Module):
    def __init__(
        self,
        max_len: int,
        d_model: int = 768,
        nhead: int = 4,
        num_layers: int = 3,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.max_len = max_len

        # --- Single Conv block + Norm + Dropout + Pool ---
        self.conv = nn.Conv1d(5, d_model, kernel_size=4, stride=2, padding=1)
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Compute downsampled length
        # After conv: ceil(max_len/2), after pool: ceil(prev/2)
        L1 = math.ceil(max_len / 2)
        L2 = math.ceil(L1 / 2)
        self.down_len = L2

        # --- Learned positional embeddings ---
        self.pos_emb = nn.Embedding(self.down_len, d_model)

        # --- BERT-style Transformer ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead-1,
            dim_feedforward=4*d_model,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- Upsampling via ConvTranspose1d x2 ---
        self.up1 = nn.ConvTranspose1d(d_model, d_model, kernel_size=2, stride=2)
        self.act_up1 = nn.GELU()
        self.up2 = nn.ConvTranspose1d(d_model, d_model, kernel_size=2, stride=2)
        self.act_up2 = nn.GELU()

        # --- Final classification head over 4 bases ---
        self.classifier = nn.Linear(d_model, 4)

    def forward(self, x: torch.Tensor, att_mask: torch.Tensor):
        """
        x: (B, L, 5) one-hot input
        att_mask: (B, L) attention mask (1 = valid, 0 = pad)
        """
        B, L, _ = x.shape

        # --- Conv stage ---
        h = x.transpose(1, 2)            # (B, 5, L)
        h = self.conv(h)                # (B, d_model, L1)
        h = h.transpose(1, 2)            # (B, L1, d_model)
        h = self.norm(h)
        h = F.gelu(h)
        h = self.drop(h)
        h = h.transpose(1, 2)            # (B, d_model, L1)
        h = self.pool(h)                # (B, d_model, down_len)

        # --- Add positional embeddings ---
        h = h.transpose(1, 2)            # (B, down_len, d_model)
        pos = self.pos_emb.weight.unsqueeze(0)  # (1, down_len, d_model)
        h = h + pos

        # --- Transformer ---
        # Downsample attention mask to match down_len
        factor = L // self.down_len
        att_ds = att_mask[:, ::factor] == 0
        z = self.transformer(
            h.permute(1, 0, 2),
            src_key_padding_mask=att_ds
        )                                # (down_len, B, d_model)
        z = z.permute(1, 0, 2)           # (B, down_len, d_model)

        # --- Upsampling back to original resolution ---
        u = z.transpose(1, 2)            # (B, d_model, down_len)
        u = self.act_up1(self.up1(u))    # (B, d_model, down_len*2)
        u = self.act_up2(self.up2(u))    # (B, d_model, down_len*4)
        u = u.transpose(1, 2)            # (B, L_out, d_model)

        # --- Classification across 4 bases ---
        logits = self.classifier(u)      # (B, L_out, 4)
        return logits

In [5]:
loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)
model = CNN_MLM(660)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")
seq_masked, targets, att_mask, mask = next(iter(loader))
print(seq_masked.shape)
print(model(seq_masked, att_mask))


Number of trainable parameters: 23771908
torch.Size([10, 660, 5])
tensor([[[ 0.0049, -0.0096,  0.0476,  0.0325],
         [ 0.0443, -0.0489,  0.0145,  0.0230],
         [ 0.0309, -0.0773,  0.0362,  0.0444],
         ...,
         [ 0.0180, -0.0442,  0.0173,  0.0157],
         [ 0.0036, -0.0191,  0.0598,  0.0135],
         [ 0.0674, -0.0493, -0.0094,  0.0099]],

        [[ 0.0079, -0.0030,  0.0570,  0.0162],
         [ 0.0463, -0.0486,  0.0438,  0.0364],
         [ 0.0193, -0.0931,  0.0470,  0.0485],
         ...,
         [ 0.0086, -0.0357,  0.0365,  0.0338],
         [ 0.0301, -0.0172,  0.0117,  0.0206],
         [ 0.0446, -0.0377,  0.0238,  0.0149]],

        [[ 0.0068, -0.0081,  0.0382,  0.0075],
         [ 0.0281, -0.0725,  0.0260,  0.0490],
         [ 0.0097, -0.0555,  0.0436,  0.0517],
         ...,
         [ 0.0259, -0.0437,  0.0361,  0.0388],
         [ 0.0242,  0.0009,  0.0133,  0.0317],
         [ 0.0477, -0.0406,  0.0132,  0.0199]],

        ...,

        [[ 0.0126, -0.0205

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from torch.nn.utils import clip_grad_norm_

def train(
    model: nn.Module,
    X: torch.Tensor,
    epochs: int = 5,
    batch_size: int = 32,
    lr: float = 1e-4,
    weight_mask: float = 2.0,
    device: str = "cuda",
):
    """
    Training loop that:
      - takes full-length logits from model(x_masked, att_mask)
      - splits masked vs. seen losses (weighted by weight_mask)
      - logs per-100-step losses
      - uses OneCycleLR + gradient clipping
    Assumes:
      MaskedOneHotDataset returns (x_masked, targets, att_mask, mask)
      with x_masked: (B, L, 5),
           targets:  (B, L) in {-1,0,1,2,3,4},
           att_mask: (B, L),
           mask:     (B, L) Boolean for masked A/C/G/T only.
      model(x_masked, att_mask) -> logits (B, L, 4).
    """
    loader = DataLoader(
        MaskedOneHotDataset(X, mask_ratio=0.15, chunk_size=4),
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

    model     = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=epochs,
        steps_per_epoch=len(loader),
        pct_start=0.3,
        anneal_strategy="linear",
    )
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        total_loss_masked = total_acc_masked = count_masked = 0
        total_loss_seen   = total_acc_seen   = count_seen   = 0

        for step, (x_masked, targets, att_mask, mask) in enumerate(loader, start=1):
            # --- Move to GPU ---
            x_masked = x_masked.to(device)    # (B, L, 5)
            targets  = targets.to(device)     # (B, L)
            att_mask = att_mask.to(device)    # (B, L)
            mask      = mask.to(device)       # (B, L)

            # --- Forward ---
            logits = model(x_masked, att_mask)  # (B, L, 4)
            print(logits.shape)

            preds = logits.argmax(dim=-1)       # (B, L)

            # --- Build masks for loss ---
            valid_pos  = (targets >= 0) & (targets < 4)
            masked_pos = mask & valid_pos
            seen_pos   = (~mask) & valid_pos

            # --- Compute losses ---
            loss_masked = (
                criterion(logits[masked_pos], targets[masked_pos])
                if masked_pos.any() else torch.tensor(0.0, device=device)
            )
            loss_seen = (
                criterion(logits[seen_pos], targets[seen_pos])
                if seen_pos.any() else torch.tensor(0.0, device=device)
            )
            loss = weight_mask * loss_masked + loss_seen
            print(loss)

            # --- Backprop & step ---
            optimizer.zero_grad()
            loss.backward()
            #clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # --- Accumulate metrics ---
            if masked_pos.any():
                total_acc_masked  += (preds[masked_pos] == targets[masked_pos]).sum().item()
                total_loss_masked += loss_masked.item() * masked_pos.sum().item()
                count_masked     += masked_pos.sum().item()

            if seen_pos.any():
                total_acc_seen  += (preds[seen_pos] == targets[seen_pos]).sum().item()
                total_loss_seen += loss_seen.item() * seen_pos.sum().item()
                count_seen     += seen_pos.sum().item()

            # --- Log every 100 steps ---
            if step % 100 == 0:
                print(
                    f"Epoch {epoch} Step {step}/{len(loader)} | "
                    f"Loss_masked: {loss_masked.item():.4f} | "
                    f"Loss_seen: {loss_seen.item():.4f} | "
                    f"Total: {loss.item():.4f}"
                )

        # --- Epoch summary ---
        avg_loss_masked = total_loss_masked / count_masked if count_masked else 0.0
        avg_acc_masked  = 100.0 * total_acc_masked / count_masked if count_masked else 0.0
        avg_loss_seen   = total_loss_seen  / count_seen   if count_seen   else 0.0
        avg_acc_seen    = 100.0 * total_acc_seen   / count_seen   if count_seen   else 0.0

        print(
            f"Epoch {epoch}/{epochs} DONE ➞ "
            f"Masked Loss: {avg_loss_masked:.4f}, Acc: {avg_acc_masked:.2f}% | "
            f"Seen Loss: {avg_loss_seen:.4f}, Acc: {avg_acc_seen:.2f}%"
        )


# One-liner to launch training:
# train(model, X, epochs=10, batch_size=32, lr=1e-4, weight_mask=2.0, device='cuda')


In [9]:
train(model, X, epochs=10, batch_size=32, lr=1e-4, weight_mask=0.5, device='cpu')

torch.Size([32, 660, 4])
tensor(2.0745, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0745, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0750, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0737, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0749, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0746, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0744, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0748, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0743, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0745, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0729, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0739, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0746, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0736, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0738, grad_fn=<AddBackward0>)
torch.Size([32, 660, 4])
tensor(2.0738, 

KeyboardInterrupt: 