# Preliminaries


## Copy step

In [1]:
# %load_ext autoreload
# %autoreload 2

# The allocator option is written to the environment before torch is imported.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

# Colab-only: mount Google Drive so the datasets can be copied to local disk.
from google.colab import drive
drive.mount("/content/gdrive")

# ssl training data (numpy arrays)
!cp -r /content/gdrive/MyDrive/smrt-foundation/ob007.memmap/ /content/
# downstream methylation dataset (parquet, tabular)
!cp -r /content/gdrive/MyDrive/smrt-foundation/pacbio_standard_train_1m.parquet /content/
!cp -r /content/gdrive/MyDrive/smrt-foundation/pacbio_standard_test_1m.parquet /content/
# Notebook-wide device handles: `device` is CUDA when available, `cpu` is used
# to bring tensors back to host memory for plotting / dataframes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device('cpu')


Mounted at /content/gdrive


In [2]:
import sys
import gc
from tqdm import tqdm
# Plotting back-ends for the altair "vegafusion" data transformer enabled later.
# NOTE(review): versions are lower-bounded only, not pinned.
!pip install "vegafusion[embed]>=1.5.0" "vl-convert-python>=1.6.0"
# Make the project directory on the mounted drive importable.
sys.path.append('/content/gdrive/MyDrive/smrt-foundation')


Collecting vegafusion>=1.5.0 (from vegafusion[embed]>=1.5.0)
  Downloading vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Collecting vl-convert-python>=1.6.0
  Downloading vl_convert_python-1.9.0.post1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.1 kB)
Collecting arro3-core (from vegafusion>=1.5.0->vegafusion[embed]>=1.5.0)
  Downloading arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (363 bytes)
Downloading vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m81.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading vl_convert_python-1.9.0.post1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (33.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.5/33.5 MB[0m [31m56.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading arro3_core-0.6.5-cp311-abi3-

## Glimpse ssl data

In [3]:
# This is our pretraining dataset.
# Take a peek at one of its numpy shards.
# Note how each slice of the array is a single-stranded sample for pretraining.
# Columns are organized [seq, ipd, pw, padding_mask].

import numpy as np
# Load one full shard and move it to the active device.
x = torch.tensor(np.load('ob007.memmap/shard_00002.npy')).to(device)
# Invert the padding flag (last column): `mask` is True at NON-padded positions.
mask = ~x[...,-1].bool()
print(x.shape)
print(x[0,0:10])

torch.Size([16384, 4096, 4])
tensor([[ 2.0000, -0.0258,  0.7886,  0.0000],
        [ 0.0000,  0.5142,  1.3975,  0.0000],
        [ 0.0000, -1.3145,  0.3120,  0.0000],
        [ 3.0000,  1.6426, -1.4004,  0.0000],
        [ 0.0000, -0.5776, -0.5669,  0.0000],
        [ 2.0000,  0.8359,  0.9258,  0.0000],
        [ 1.0000, -0.7842, -0.0809,  0.0000],
        [ 1.0000, -1.8984, -1.6172,  0.0000],
        [ 1.0000, -0.8994,  0.1249,  0.0000],
        [ 3.0000, -0.0928, -0.8608,  0.0000]], device='cuda:0',
       dtype=torch.float16)


In [4]:
# First column is seq; the centre two are features which should have mean 0, sd 1.
# Last column is the padding flag and should be 0's and 1's (mostly 0's).
# NOTE(review): the reduction below runs over every position of `x`, including
# padded ones (`mask` is not applied), which presumably pulls the kinetics sd
# below 1 — confirm by computing the statistics on `x[mask]`.

x.mean(dim=(0, 1)).round(decimals=2), x.std(dim=(0,1)).round(decimals=2)

(tensor([ 1.1904, -0.0100,  0.0000,  0.2000], device='cuda:0',
        dtype=torch.float16),
 tensor([1.1904, 0.8999, 0.8901, 0.3999], device='cuda:0', dtype=torch.float16))

In [5]:
# Distribution of nucleotide codes in the flattened seq column.
# A natural genome is not uniformly distributed over the four bases; after
# restricting to non-padded positions the counts match that expectation.

import altair as alt
import polars as pl

alt.data_transformers.enable("vegafusion")

# keep only non-padded positions of channel 0 (the nucleotide code)
seq = x[..., 0][mask].reshape(-1)
seq_df = pl.DataFrame({'seq': seq.to(cpu)})
alt.Chart(seq_df).mark_bar(width=70).encode(
    x=alt.X('seq:Q'),
    y='count()',
)

In [6]:
# and here we can see the proportions of each nucleotide in tabular format
# `mask` already lives on the same device as `x`, so it is used directly (as in
# the histogram cell) instead of being copied to the CPU just to index `x`.
seq = x[:, :, 0][mask].flatten()
seq_df = pl.DataFrame({'seq': seq.to(cpu)})
seq_df["seq"].value_counts(sort=True, normalize=True)

seq,proportion
f32,f64
0.0,0.271047
3.0,0.271047
2.0,0.228953
1.0,0.228953


In [7]:
# check that the long transformation doesn't corrupt the floats
# if it did, we might see an excess of 0 (A)
# `mask` is applied on the device it already lives on; only the (much smaller)
# selected values are moved to the CPU for the dataframe.
seq = x[:, :, 0].long()[mask].flatten()
seq_df = pl.DataFrame({'seq': seq.to(cpu)})
seq_df["seq"].value_counts(sort=True, normalize=True)

seq,proportion
i64,f64
0,0.271047
3,0.271047
2,0.228953
1,0.228953


In [8]:


# Histograms of the two kinetics channels (ipd = column 1, pw = column 2),
# restricted to non-padded positions.
cont_df = pl.DataFrame({
    name: x[..., channel][mask].flatten().to(cpu)
    for name, channel in (('ipd', 1), ('pw', 2))
})
(
    alt.Chart(cont_df.unpivot())
    .mark_bar()
    .encode(
        alt.X('value:Q').title('normalized zmw frames'),
        alt.Y('count():Q').scale(type='linear').title('count'),
    )
    .properties(width=400, height=400)
    .facet(column='variable:N')
    .properties(title="Memmap Kinetics Distributions")
)

## Glimpse downstream data

This is our current downstream dataset. Let's look at the first 10 samples. Note how each row is both a forward and reverse sample, and the features are not normalized

In [9]:

# Load the downstream training table and preview its first 10 rows.
df_train = pl.read_parquet('pacbio_standard_train_1m.parquet')
df_train.head(10)

read_name,cg_pos,seq,qual,np,fi,fp,ri,rp,label
str,i64,str,list[u8],u8,list[u16],list[u16],list[u16],list[u16],i32
"""m64168_200820_000733/48169889/…",3058,"""GATGTCCTGGGGATTCGGGGGCATAACTGC…","[60, 67, … 69]",8,"[15, 29, … 35]","[7, 19, … 23]","[10, 10, … 5]","[34, 39, … 33]",0
"""m64168_200820_000733/45943110/…",8167,"""TCTCCACGTTGGCCACGCTGGTCTCGAACT…","[93, 73, … 93]",13,"[33, 18, … 26]","[20, 46, … 27]","[48, 70, … 20]","[20, 16, … 51]",0
"""m64168_200823_191315/50332760/…",1413,"""AATTTCTTGAAGAGACGAAAGTCTGTGGGT…","[93, 93, … 93]",32,"[32, 12, … 18]","[21, 9, … 13]","[16, 17, … 13]","[16, 23, … 22]",1
"""m64168_200823_191315/177537981…",4708,"""CAACCCACTGCCAAGCGCTTCCTGCCACCT…","[93, 82, … 93]",9,"[13, 19, … 19]","[23, 14, … 34]","[9, 21, … 31]","[21, 13, … 34]",1
"""m64168_200823_191315/49154585/…",5695,"""CCTCCCTACCGAAAACGGGGATCGTGTGAA…","[13, 58, … 53]",3,"[6, 35, … 10]","[12, 34, … 19]","[17, 13, … 20]","[24, 12, … 43]",1
"""m64168_200823_191315/163316698…",6320,"""ATGCAATCAACCTAACGTAAGTGCTCTCAC…","[93, 93, … 93]",24,"[38, 51, … 76]","[25, 12, … 15]","[28, 18, … 21]","[14, 25, … 21]",1
"""m64168_200820_000733/4260355/c…",2902,"""CACCATGCCTGGCCACGAGACCCCATCTCA…","[93, 93, … 93]",28,"[11, 8, … 10]","[23, 11, … 17]","[21, 13, … 23]","[39, 34, … 26]",0
"""m64168_200823_191315/71829932/…",5186,"""GGCAAGGCCCAGGCACGTGGTGCATCTGAA…","[93, 93, … 93]",22,"[28, 30, … 18]","[19, 24, … 29]","[23, 12, … 25]","[17, 29, … 19]",1
"""m64168_200820_000733/50529600/…",6311,"""GAGGGTGGGGGTTAGCGAGTGATAGTGTGG…","[93, 93, … 93]",15,"[13, 8, … 12]","[32, 38, … 17]","[28, 9, … 13]","[27, 22, … 39]",0
"""m64168_200820_000733/2360276/c…",3405,"""GCTGGAGTGCAGTGACGTGATCTCGGCTCA…","[93, 60, … 88]",6,"[26, 32, … 21]","[28, 50, … 13]","[20, 26, … 32]","[13, 30, … 10]",0


In [10]:
# Load the held-out table (the "test" parquet) and preview its first 10 rows.
df_val = pl.read_parquet('pacbio_standard_test_1m.parquet')
df_val.head(10)

read_name,cg_pos,seq,qual,np,fi,fp,ri,rp,label
str,i64,str,list[u8],u8,list[u16],list[u16],list[u16],list[u16],i32
"""m64168_200823_191315/78053961/…",3630,"""GGAGTCTCACTCTGTCGCCCAGGCTGGAGC…","[93, 93, … 93]",34,"[20, 22, … 17]","[43, 41, … 15]","[26, 15, … 11]","[41, 25, … 22]",1
"""m64168_200823_191315/151587048…",5646,"""TGCAGCAACACATGACGCATTCTAAAATGT…","[93, 93, … 93]",14,"[26, 37, … 19]","[9, 25, … 18]","[52, 71, … 14]","[20, 20, … 9]",1
"""m64168_200823_191315/139657571…",4396,"""ACATTTTTAAGTTGCCGTCTCTAGGACAAA…","[93, 93, … 93]",18,"[20, 62, … 22]","[21, 24, … 19]","[23, 25, … 37]","[18, 13, … 21]",1
"""m64168_200820_000733/18940076/…",7207,"""GCAGTGGCATGATCTCGGCTCACTGCAACC…","[93, 93, … 93]",27,"[9, 14, … 11]","[26, 13, … 36]","[15, 22, … 14]","[23, 17, … 21]",0
"""m64168_200823_191315/76875948/…",5185,"""GTCTCCAGCACCCAGCGCTCCCACAAGCCT…","[93, 93, … 93]",17,"[10, 29, … 39]","[13, 30, … 41]","[27, 8, … 16]","[26, 13, … 15]",1
"""m64168_200823_191315/170133346…",2983,"""AGTTCTTGCCTAGCTCGACCTCAGTCCCGT…","[93, 93, … 93]",26,"[13, 15, … 28]","[13, 14, … 30]","[24, 30, … 11]","[23, 16, … 15]",1
"""m64168_200823_191315/148045839…",5830,"""GGGCGCGGTGGCTCACGCCTGTAATCCCAG…","[31, 93, … 93]",22,"[18, 14, … 11]","[22, 13, … 13]","[23, 31, … 56]","[18, 15, … 26]",1
"""m64168_200820_000733/18220356/…",267,"""TAAGTTTCTAGTAACCGTATTAAAAAGTAA…","[93, 93, … 93]",22,"[16, 18, … 46]","[19, 31, … 40]","[17, 17, … 23]","[40, 33, … 22]",0
"""m64168_200820_000733/6751166/c…",4642,"""GAGGTTGTGGTAAGCCGAGATCGCGCCATT…","[93, 93, … 93]",20,"[17, 16, … 30]","[20, 19, … 18]","[19, 13, … 14]","[31, 21, … 23]",0
"""m64168_200820_000733/45418455/…",399,"""CTGGGTGTGGTGGCACGTGCCTGTAATCTC…","[93, 93, … 93]",10,"[28, 43, … 6]","[7, 9, … 14]","[23, 16, … 13]","[45, 10, … 11]",0


# SSL Dataset Class

Each "shard" is a 512 MB numpy array, so with this truncated dataset we have around 15 GB of data. Based on the test below, it looks like we can transfer that to the machine at a rate well over 1 GB/s.

In [11]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, Dataset
import glob
import os
from collections import OrderedDict


class ShardedMemmapDataset(Dataset):
    """Map-style dataset over a directory of ``*.npy`` shards.

    Shards are opened lazily as read-only memory maps and kept in a small LRU
    cache. Every shard except the last is assumed to hold the same number of
    samples as the first one; the last shard may be shorter.

    Args:
        data_dir: Directory containing the shards; environment variables in the
            path are expanded.
        cache_size: Maximum number of shards kept memory-mapped at once.

    Raises:
        IndexError: If ``data_dir`` contains no ``.npy`` files, or if an item
            index is out of range.
    """

    def __init__(self, data_dir, cache_size=100):
        expanded_dir = os.path.expandvars(data_dir)
        # Lexicographic order defines the global sample order.
        self.shard_paths = sorted(glob.glob(os.path.join(expanded_dir, "*.npy")))
        if not self.shard_paths:
            # Same exception type the bare list lookup used to raise, but with
            # a message that says what is actually wrong.
            raise IndexError(f"no .npy shards found in {expanded_dir!r}")
        first_shard = np.load(self.shard_paths[0], mmap_mode='r')
        self.shard_size = first_shard.shape[0]
        last_shard = np.load(self.shard_paths[-1], mmap_mode='r')
        self.total_len = ((len(self.shard_paths) - 1) * self.shard_size) + last_shard.shape[0]
        self.cache_size = cache_size
        # shard index -> memmap, ordered from least to most recently used
        self.memmaps = OrderedDict()

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        """Return sample ``idx`` as a bfloat16 tensor (negative indices count from the end)."""
        if idx < 0:
            idx += self.total_len
        if not 0 <= idx < self.total_len:
            raise IndexError(f"index {idx} out of range for dataset of length {self.total_len}")
        shard_idx, local_idx = divmod(idx, self.shard_size)

        shard = self.memmaps.get(shard_idx)
        if shard is None:
            # Evict least-recently-used shards; the emptiness guard keeps a
            # non-positive cache_size from popping an empty dict.
            while self.memmaps and len(self.memmaps) >= self.cache_size:
                self.memmaps.popitem(last=False)
            shard = np.load(self.shard_paths[shard_idx], mmap_mode='r')
            self.memmaps[shard_idx] = shard
        else:
            self.memmaps.move_to_end(shard_idx)
        # np.array(...) copies the row out of the read-only memmap before
        # handing it to torch.
        return torch.from_numpy(np.array(shard[local_idx])).bfloat16()

In [12]:
from tqdm import tqdm
# ssl_ds = ShardedMemmapDataset("ob007.memmap/")
# ssl_dl = DataLoader(ssl_ds, batch_size=256, num_workers=4, pin_memory=True, prefetch_factor=2, shuffle=True)

# for batch in iter(tqdm(ssl_dl)):
#   x = batch.to(device)


## Load dataset

In [13]:

# Pretraining configuration.
# NOTE(review): SEQ_LEN and D_MODEL are not referenced again in the visible
# notebook (the model is built with its default arguments) — confirm whether
# they should be passed through.
SEQ_LEN = 4096
BATCH_SIZE = 64
D_MODEL = 128

# Shuffled loader over the sharded pretraining set, with 4 worker processes.
ssl_ds = ShardedMemmapDataset("ob007.memmap/")
ssl_dl = DataLoader(ssl_ds, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, prefetch_factor=2, shuffle=True)


# Smrt2Vec

## Dataflow Plan
### Embed data: [B, T, C] -> [B, T, d_model] = E
Since we have 1 categorical channel and 2 continuous channels, we'll use a hybrid embedding. The nucleotide channel gets an embedding table, and the 2 continuous kinetics channels get a single linear projection with a GELU nonlinearity. Note that the continuous channels are normalized across the genome to have zero mean and unit variance. GELU is important here because, unlike ReLU, it does not discard negative values... (Note: `SmrtEmbedding` as currently written below applies the linear projection followed by a LayerNorm, without the GELU.)

### Extract features: [B, T, d_model], Pad -> [B, T', d_model], [B, T', 1] = Z, Pad'
Separate out the padding channel. Run a CNN over the sequence to generate a new sequence of features. Calculate the new padding mask based on the CNN downsampling stride.

### Mask random indices: [B, T', d_model], Pad' -> [B, T', d_model] = Z_masked, Mask_idx
We mask the output of the CNN at randomly sampled indices (say 5 percent of them) and then replace a window (say 5 indices) around each sampled index with the learnable mask vector (of size d_model), so that the sequence length remains the same as the output of the CNN.

### Positional encoding: [B, T', d_model] -> [B, T', d_model]
Only add the positional encoding at this point since the CNN and addition of masking vectors would overwrite its information otherwise

### Transformer block: [B, T', d_model], Pad' -> [B, T', d_model]
Run through a series of transformer blocks to get contextualized embeddings

### Compute contrastive loss: [B, T', d_model], Mask_idx -> Loss
Using the masked indices, use each C_t from the transformer output to predict the latent embedding. We will use an MLP for this transformation, and a separate one for the targets, and I suspect a smaller space than d_model will perform better (say 32 instead of 128). Score the prediction with InfoNCE, i.e. how much more similar the predicted embedding vector is to the true target at that position (which we retained) than to a set of randomly sampled indices from the batch.


## Model Building Blocks

In [14]:


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a ``(B, T, d_model)`` input.

    Args:
        d_model: Embedding width.
        max_len: Largest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        # Positions must be float32: bfloat16 has an 8-bit significand, so
        # integers above 256 are rounded and neighbouring positions would
        # receive identical encodings.
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as a buffer of shape (1, max_len, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x``."""
        return x + self.pe[:, :x.size(1)]

class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    The hidden width is ``d_model * expansion``; input and output widths are
    both ``d_model``.
    """

    def __init__(self, d_model, expansion=4):
        super().__init__()
        hidden = d_model * expansion
        self.c_fc = nn.Linear(d_model, hidden)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, d_model)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))

class SmrtEmbedding(nn.Module):
  """Hybrid embedding: nucleotide lookup plus linear projection of kinetics.

  Each half produces ``d_model // 2`` features; the halves are concatenated and
  layer-normalised, so ``d_model`` must be even for the LayerNorm width to
  match.

  Args:
    d_model: Output embedding width.
    n_nucleotides: Size of the nucleotide vocabulary.
    n_continuous: Number of continuous (kinetics) channels.
  """
  def __init__(self, d_model, n_nucleotides=5, n_continuous=2):
    super().__init__()
    self.nuc_embed = nn.Embedding(n_nucleotides, d_model//2)
    # NOTE(review): this projection is stored in bfloat16 while the rest of the
    # module uses the default dtype — confirm this is intended.
    self.kin_embed = nn.Linear(n_continuous, d_model//2, dtype=torch.bfloat16)
    self.layernorm = nn.LayerNorm(d_model)
    self.d_model = d_model

  def forward(self, x_nuc, x_kin, is_padding):
    """Embed one batch.

    Args:
      x_nuc: Nucleotide codes; cast to int before the table lookup.
      x_kin: Kinetics features with ``n_continuous`` values in the last dim.
      is_padding: Accepted for interface compatibility; not used here.

    Returns:
      The layer-normalised concatenation of both embeddings.
    """
    scale = math.sqrt(self.d_model)
    seq_emb = self.nuc_embed(x_nuc.int()) * scale
    # Cast to the projection's parameter dtype so float32/float16 inputs do not
    # raise a dtype-mismatch error outside autocast; a no-op for bf16 input.
    x_kin = x_kin.to(self.kin_embed.weight.dtype)
    kin_emb = self.kin_embed(x_kin) * scale
    x = torch.concat((seq_emb, kin_emb), dim=-1)
    return self.layernorm(x)

class BidirectionalSelfAttention(nn.Module):
    """Multi-head, non-causal self-attention that ignores padded key positions.

    ``max_len`` (constructor) and ``pad_val`` (forward) are accepted but unused.
    """

    def __init__(self, d_model, n_head=4, max_len=4096):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.head_dim = d_model // n_head
        # fused q/k/v projection, hence 3 * d_model outputs
        self.c_attn = nn.Linear(d_model, 3 * d_model, bias=False)
        self.c_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, x_pad, pad_val=1):
        """Attend over ``x`` of shape (B, T, C); ``x_pad`` is True where padded."""
        B, T, C = x.size()

        # One matmul, then split the projection axis into q, k, v and move the
        # head axis in front of time: each becomes (B, n_head, T, head_dim).
        fused = self.c_attn(x).view(B, T, 3, self.n_head, self.head_dim)
        q, k, v = (part.transpose(1, 2) for part in fused.unbind(dim=2))

        # scaled_dot_product_attention wants True = attend, False = ignore,
        # whereas our mask is True for padding, so it is inverted here.
        # Shape (B, 1, 1, T) broadcasts over heads and query positions.
        keep = ~x_pad.view(B, 1, 1, T)
        drop_p = 0.05 if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=keep,
            dropout_p=drop_p,
            is_causal=False,  # everything outside the padding mask is visible
        )

        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = attended.transpose(1, 2).reshape(B, T, C)
        return self.c_proj(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a plain
    (unscaled) residual connection."""

    def __init__(self, d_model, n_head, max_len):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = BidirectionalSelfAttention(d_model, n_head, max_len)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)

    def forward(self, x, x_pad):
        # attention sub-layer
        attn_out = self.attn(self.ln1(x), x_pad)
        h = x + attn_out
        # feed-forward sub-layer
        return h + self.mlp(self.ln2(h))

class ResBlock(nn.Module):
  """Pre-activation 1-D residual block that also resizes a padding mask.

  Layout: BN -> ReLU -> conv (strided) -> BN -> ReLU -> conv, plus a shortcut.
  The shortcut is a 1x1 strided convolution when the channel count or the
  stride changes, and an empty ``nn.Sequential`` (identity) otherwise.
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1):
    super().__init__()

    self.kernel_size = kernel_size
    self.stride = stride
    # "same"-style padding for odd kernel sizes
    self.padding = (kernel_size - 1) // 2

    self.bn1 = nn.BatchNorm1d(in_channels)
    self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                           stride=stride, padding=self.padding, bias=False)
    self.bn2 = nn.BatchNorm1d(out_channels)
    self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                           stride=1, padding=self.padding, bias=False)
    self.relu = nn.ReLU(inplace=True)

    needs_projection = in_channels != out_channels or stride != 1
    shortcut = []
    if needs_projection:
      shortcut.append(nn.Conv1d(in_channels, out_channels,
                                kernel_size=1, stride=stride, bias=False))
    self.residual = nn.Sequential(*shortcut)

  def _resize_mask(self, mask, pad_val=1):
    """Pool ``mask`` with this block's kernel/stride/padding.

    In both conventions an output position counts as valid when at least one
    input position in its window is valid. ``pad_val`` says which value marks
    padding (1: mask is True at padding, 0: mask is True at valid positions).
    """
    if mask.dtype == torch.bool:
      mask = mask.float()
    if pad_val not in (0, 1):
      raise ValueError("Invalid pad value: Pad value must be 0 or 1")

    pool_args = dict(kernel_size=self.kernel_size,
                     stride=self.stride,
                     padding=self.padding)
    if pad_val == 0:
      pooled = F.max_pool1d(mask, **pool_args)
    else:
      # min-pool expressed through max-pool on the complement
      pooled = 1 - F.max_pool1d(1 - mask, **pool_args)
    return pooled.bool()

  def forward(self, x, mask):
    h = self.conv1(self.relu(self.bn1(x)))
    h = self.conv2(self.relu(self.bn2(h)))
    h += self.residual(x)
    return h, self._resize_mask(mask)

class CNN(nn.Module):
  """Stack of residual blocks that featurises and downsamples a sequence.

  Only one block uses ``stride=2``, so the time axis is halved once overall
  (T -> T/2); every other block preserves the length.

  Args:
    d_model: Channel count for every block (input and output).
    max_len: Sequence length used for the shape probe in ``__init__``.
    dropout_p: Probability for ``self.dropout``. NOTE(review): the dropout
      layer is constructed but not applied in ``forward`` — confirm intent.
  """
  def __init__(self, d_model, max_len, dropout_p):
    super().__init__()
    self.max_len = max_len
    self.in_channels = d_model
    # extractor
    self.extractor = nn.ModuleList([
          ResBlock(self.in_channels, self.in_channels, kernel_size=7),            # (B, C, T)   -> (B, C, T)

          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T)   -> (B, C, T)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T)   -> (B, C, T)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T)   -> (B, C, T)

          ResBlock(self.in_channels, self.in_channels, kernel_size=3, stride=2),  # (B, C, T)   -> (B, C, T/2)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T/2) -> (B, C, T/2)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T/2) -> (B, C, T/2)

          # stride=1 here, so the length stays at T/2 (no second downsampling)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3, stride=1),  # (B, C, T/2) -> (B, C, T/2)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T/2) -> (B, C, T/2)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T/2) -> (B, C, T/2)
          ResBlock(self.in_channels, self.in_channels, kernel_size=3),            # (B, C, T/2) -> (B, C, T/2)
          ])
    self.dropout = nn.Dropout(p=dropout_p)
    # record the output shapes via a dummy pass-through
    self.output_shapes = self._get_output_shape()

  def forward(self, x, mask):
    """Run every block in order, threading the padding mask through.

    Args:
      x: Features in (B, C, T) layout.
      mask: Padding mask, resized by each block alongside the features.

    Returns:
      Tuple of (features, resized mask).
    """
    for block in self.extractor:
      x, mask = block(x, mask)
    return x, mask

  def _get_output_shape(self):
    """
    Returns output shapes for the data and mask.

    The probe runs in eval mode under ``torch.no_grad()`` so that the random
    dummy input neither updates the BatchNorm running statistics nor builds an
    autograd graph; the previous training mode is restored afterwards.
    """
    dummy_x = torch.randn(1, self.in_channels, self.max_len)
    # only the shape of the dummy mask matters here
    dummy_mask = torch.randn(1, self.max_len)

    was_training = self.training
    self.eval()
    try:
      with torch.no_grad():
        output, mask = self.forward(dummy_x, dummy_mask)
    finally:
      self.train(was_training)
    return output.shape, mask.shape


## Main Model Classes

In [15]:
### Encoder Class

class SmrtEncoder(nn.Module):
  """Embedding -> CNN -> positional encoding -> transformer stack.

  The stages are exposed separately (``get_latents``, ``add_pe``,
  ``forward_transformer``) so a caller can intervene between them; ``forward``
  simply chains them.
  """
  def __init__(self, d_model=128, n_layers=4, n_head=4, max_len=4096, dropout_p=0.01):
    super().__init__()
    self.d_model = d_model
    self.embed = SmrtEmbedding(d_model)
    self.pe = PositionalEncoding(d_model, max_len=max_len)
    self.downsample = CNN(d_model, max_len=max_len, dropout_p=dropout_p)
    self.layer_norm_target = nn.LayerNorm(d_model)
    layers = [TransformerBlock(d_model=d_model, n_head=n_head, max_len=max_len)
              for _ in range(n_layers)]
    self.blocks = nn.ModuleList(layers)

  def get_latents(self, x):
    """
    Runs the [x -> Embedding -> CNN] part of the stack.

    Returns:
      z (CNN latents in (B, T', C) layout; no positional encoding added yet)
      z_pad (padding mask as returned by the CNN)
      targets (layer-normalised copy of z)
    """
    # channel layout of the last axis: [seq, kinetics x2, padding flag]
    x_nuc, x_kin, x_pad = x[..., 0], x[..., 1:3], x[..., 3]
    embedded = self.embed(x_nuc, x_kin, x_pad)
    # the CNN works on (B, C, T); swap back to (B, T, C) afterwards
    feats, z_pad = self.downsample(embedded.permute(0, 2, 1), x_pad)
    z = feats.permute(0, 2, 1)
    targets = self.layer_norm_target(z.clone())
    return z, z_pad, targets

  def add_pe(self, z):
    """Add the positional encoding to ``z``."""
    return self.pe(z)

  def forward_transformer(self, z, z_pad):
    """
    Runs the transformer blocks on the downsampled latents.

    Returns:
      c (context-aware latents)
    """
    hidden = z
    for layer in self.blocks:
      hidden = layer(hidden, z_pad)
    return hidden

  def forward(self, x):
    z, z_pad, _ = self.get_latents(x)
    return self.forward_transformer(self.add_pe(z), z_pad)

### Main Model
class Smrt2Vec(nn.Module):
  """Pretraining wrapper: encoder plus span masking and a projection head."""
  def __init__(self, d_model=128, n_layers=4, n_head=4, max_len=4096):
    super().__init__()
    self.d_model = d_model
    self.encoder = SmrtEncoder(d_model, n_layers, n_head, max_len)

    # pretraining-only components
    self.mask_vec = nn.Parameter(torch.randn(d_model))
    self.project = nn.Sequential(
        nn.Linear(d_model, d_model),
        nn.GELU(),  # GELU rather than ReLU so negative values are not ignored
        nn.Linear(d_model, d_model),
    )

  def apply_mask(self, x_emb, pad, prob=0.05, size=6):
    """Replace random spans of non-padded positions with ``self.mask_vec``.

    Span centres are drawn independently with probability ``prob`` and grown
    to width ``size`` with a stride-1 max-pool (``size`` is a hyperparameter).

    Returns:
      (masked copy of ``x_emb``, boolean (B, T) matrix of masked positions)
    """
    B, T, _ = x_emb.shape
    not_pad = ~(pad.bool())
    centers = (torch.rand(B, T, device=x_emb.device) < prob) & not_pad
    grown = F.max_pool1d(
        centers.bfloat16(),
        kernel_size=size,
        stride=1,
        padding=size//2,
    ).bool()
    # trim back to T columns and never mask padding
    mask_idx_full = grown[:, :T] & not_pad

    x_masked = x_emb.clone()
    x_masked[mask_idx_full] = self.mask_vec.to(dtype=x_emb.dtype, device=x_emb.device)
    return x_masked, mask_idx_full

  def forward(self, x):
    """
    Returns:
      projected transformer output,
      detached unmasked CNN latents (no transformer applied),
      boolean matrix marking where the targets are.
    """
    # CNN latents; the positional encoding is added only after masking
    z, z_pad, targets = self.encoder.get_latents(x)
    z_masked, z_masked_bool = self.apply_mask(z, z_pad)
    context = self.encoder.forward_transformer(self.encoder.add_pe(z_masked), z_pad)
    return self.project(context), targets.detach(), z_masked_bool

### Loss

class InfoNCELoss(nn.Module):
  """InfoNCE over masked positions using cosine similarity.

  Every masked position's prediction is scored against the targets of all
  masked positions in the batch; the target at the same position is the
  positive and all the others act as negatives.

  Args:
    temperature: Divisor applied to the cosine-similarity logits.
  """
  def __init__(self, temperature=0.1):
    super().__init__()
    self.cross_entropy = nn.CrossEntropyLoss()
    self.temperature = temperature

  def forward(self, c_proj, targets, mask_idx):
    """
    Args:
      c_proj: Predictions, indexable by ``mask_idx``.
      targets: Target vectors, same leading shape as ``c_proj``.
      mask_idx: Boolean matrix selecting the positions to score.

    Returns:
      Scalar loss. When no position is selected, a zero that is still attached
      to ``c_proj``'s graph, so ``backward()`` yields zero gradients rather
      than NaNs.
    """
    # gather the predictions and truth vectors
    preds = c_proj[mask_idx]
    truth = targets[mask_idx]
    if preds.shape[0] == 0:
      # nothing was masked in this batch: cross-entropy over an empty set is
      # undefined, so contribute nothing instead
      return preds.sum() * 0.0
    # normalise along the embedding dim so the dot product is a cosine similarity
    preds = F.normalize(preds, dim=-1)
    truth = F.normalize(truth, dim=-1)
    logits = torch.mm(preds, truth.permute(1, 0)) / self.temperature
    # the positive for row i is column i
    labels = torch.arange(truth.shape[0], device=truth.device)
    return self.cross_entropy(logits, labels)






## Model forward pass check

In [16]:
# Free cached GPU memory left over from the exploration cells.
import gc
gc.collect()
torch.cuda.empty_cache()

# Smoke test: push one batch through a freshly initialised model.
batch = next(iter(ssl_dl)).to(device)
model = Smrt2Vec().to(device)
# NOTE: `mask` here rebinds the name used earlier for the shard validity mask.
c_proj, targets, mask = model(batch)
print((c_proj.shape, targets.shape, mask.shape))

(torch.Size([64, 2048, 128]), torch.Size([64, 2048, 128]), torch.Size([64, 2048]))


In [17]:
# Loss on the untrained model's outputs from the previous cell.
loss = InfoNCELoss()
loss(c_proj, targets, mask)

tensor(10.5058, device='cuda:0', grad_fn=<NllLossBackward0>)

In [18]:
# Re-create the model (discarding the smoke-test instance) and count the
# parameters that require gradients.
model = Smrt2Vec().to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable params: {total_params}")

trainable params: 2059648


# Pre-Training

In [None]:
# Optimiser, loss and schedule for self-supervised pretraining.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.02)
criterion = InfoNCELoss(temperature=0.1).to(device)
EPOCHS = 2

# One scheduler step per optimiser step, hence total_steps = batches * epochs.
# NOTE(review): OneCycleLR is given max_lr=6e-4, so the lr=3e-4 passed to AdamW
# is presumably overridden by the schedule — confirm.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=6e-4,
    total_steps=len(ssl_dl) * EPOCHS,
    pct_start=0.05
)

model.train()

for epoch in range(EPOCHS):
    progress_bar = tqdm(ssl_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, batch in enumerate(progress_bar):
        batch = batch.to(device, non_blocking=True)

        # forward pass and loss under bfloat16 autocast
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            c_proj, targets, mask_idx = model(batch)
            loss = criterion(c_proj, targets, mask_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # refresh the progress-bar readout every 10 steps
        if i % 10 == 0:
            progress_bar.set_postfix(
                loss=f"{loss.item():.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.6f}"
            )

Epoch 1/2:  51%|█████     | 4077/7976 [13:59<13:24,  4.84it/s, loss=4.2953, lr=0.000533]

# Downstream Task

## Dataset
Honestly this feels a bit funky, and I'm debating whether to make a new preprocessing script that produces numpy arrays like the SSL dataset. This parquet style dataset is inherited from the CNN a while back and is much more difficult to work with, I find. Also much more difficult to get good bandwidth.

In [None]:
import torch
import polars as pl
import numpy as np
import pyarrow.parquet as pq
from pathlib import Path
from torch.utils.data import IterableDataset
def compute_log_normalization_stats(df, features, epsilon=1):
    """Per-feature mean and std of ``log(value + epsilon)`` over all list elements.

    Args:
        df: Frame whose ``features`` columns are list-typed (one list per row).
        features: Column names to summarise.
        epsilon: Offset added before taking the log.

    Returns:
        ``(means, stds)``: two dicts keyed by feature name.
    """
    means, stds = {}, {}
    for col in features:
        # Flatten the per-row lists once and log-transform once. The flattened
        # series is not exploded again: it is already flat, and a second
        # explode on a non-list column is redundant at best.
        log_values = (df[col].explode() + epsilon).log()
        means[col] = log_values.mean()
        stds[col] = log_values.std()
    return means, stds

class MethylIterableDataset(IterableDataset):
    """Streams per-site samples from a parquet file, one row group at a time.

    Each yielded item is a dict with:
      * ``'data'``: tensor whose last axis is [seq code, 2 kinetics channels, mask]
      * ``'metadata'``: read name, position and strand tag
      * ``'label'``: present unless ``inference`` is set

    When ``single_strand`` is set, every row yields a forward item followed by
    a reverse item; otherwise only the forward item is yielded (tagged ``'ds'``).
    NOTE(review): in the latter case the reverse kinetics ('ri', 'rp') are not
    part of the output — confirm this is intended.
    """
    def __init__(self, data_path, means, stds, context, restrict_row_groups=0, single_strand=False, inference=False):
        """
        Args:
            data_path: Path to the parquet file.
            means, stds: Per-feature statistics (keyed by kinetics column name)
                used to standardise the log-transformed kinetics.
            context: Fixed list width the kinetics columns are cast to.
            restrict_row_groups: If non-zero, only the first this-many row
                groups are used.
            single_strand: Emit separate forward and reverse items per row.
            inference: If set, no labels are read or yielded.
        """
        super().__init__()
        self.data_path = Path(data_path)
        self.means, self.stds = means, stds
        self.context = context
        self.single_strand = single_strand
        self.inference = inference
        self.restrict = restrict_row_groups

        self.kin_feats = ['fi', 'fp', 'ri', 'rp']
        self.vocab = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        # complement lookup in vocab codes: A<->T, C<->G, N stays N
        self.comp_map = torch.tensor([3, 2, 1, 0, 4], dtype=torch.long)

        # Derive the length from parquet metadata only (no data pages read).
        # On any failure the dataset reports length 0 after printing a message.
        try:
            meta = pq.read_metadata(self.data_path)
            self.n_groups = meta.num_row_groups
            use_groups = min(self.restrict, self.n_groups) if self.restrict else self.n_groups

            # fast row count
            n_rows = sum(meta.row_group(i).num_rows for i in range(use_groups))
            self.len = n_rows * (2 if single_strand else 1)
        except Exception:
            print(f'Failed to read parquet: {self.data_path}')
            self.n_groups, self.len = 0, 0

    def __len__(self):
        """Number of items: rows in the used row groups, doubled for single-strand mode."""
        return self.len

    def _process_batch(self, df):
        """Turn one row-group frame into a stream of item dicts."""
        # seq: split each string into characters and map them to vocab codes
        # (anything not in the vocab becomes 4)
        seq_arr = np.stack(
            df['seq'].str.split("")
            .list.eval(pl.element().replace_strict(self.vocab, default=4))
            .to_numpy()
        )
        seq_t = torch.tensor(seq_arr, dtype=torch.long)

        # kinetics: log(v + 1), then standardise with the supplied statistics
        kin_list = []
        for k in self.kin_feats:
            vals = df[k].to_numpy() # (N, L)
            vals = (np.log(vals + 1) - self.means[k]) / self.stds[k]
            kin_list.append(vals)
        # stacked along axis 1 in kin_feats order: fi, fp, ri, rp
        kin_t = torch.tensor(np.stack(kin_list, axis=1), dtype=torch.bfloat16)

        # mask, labels, etc (note that there is no masked data in the downstream set, so it's all zeros here)
        mask = torch.zeros((seq_t.shape[0], seq_t.shape[1], 1), dtype=torch.bfloat16)
        labels = torch.tensor(df['label'].to_numpy(), dtype=torch.long) if not self.inference else None
        r_names, pos = df['read_name'].to_list(), df['cg_pos'].to_list()

        # construct forward sample
        # Seq (N, L, 1) + Kin (N, 2, L)->(N, L, 2) + Mask (N, L, 1) = (N, L, 4)
        fwd_data = torch.cat([
            seq_t.unsqueeze(-1).to(torch.bfloat16),
            kin_t[:, 0:2].permute(0, 2, 1),
            mask
        ], dim=2)

        # construct reverse data: reverse-complemented sequence plus the
        # reverse-strand kinetics flipped along the sequence axis
        rev_data = None
        if self.single_strand:
            rev_seq_t = torch.flip(self.comp_map.to(seq_t.device)[seq_t], dims=[1])
            # Kin: slice 2:4, flip time (dim 2), permute channels
            rev_kin = torch.flip(kin_t[:, 2:4], dims=[2]).permute(0, 2, 1)
            rev_data = torch.cat([
                rev_seq_t.unsqueeze(-1).to(torch.bfloat16),
                rev_kin,
                mask
            ], dim=2)

        # yield
        for i in range(len(df)):
            # forward
            strand_name = 'fwd' if self.single_strand else 'ds'
            item_fwd = {
                'data': fwd_data[i],
                'metadata': {'read_name': r_names[i], 'position': pos[i], 'strand': strand_name}
            }
            if labels is not None: item_fwd['label'] = labels[i]
            yield item_fwd

            # reverse (shares the row's label with the forward item)
            if rev_data is not None:
                item_rev = {
                    'data': rev_data[i],
                    'metadata': {'read_name': r_names[i], 'position': pos[i], 'strand': 'rev'}
                }
                if labels is not None: item_rev['label'] = labels[i]
                yield item_rev
            else:
              continue

    def __iter__(self):
        """Iterate over this worker's share of the row groups."""
        worker = torch.utils.data.get_worker_info()
        valid_groups = min(self.restrict, self.n_groups) if self.restrict else self.n_groups
        indices = np.arange(valid_groups)

        # with multiple DataLoader workers, each takes a disjoint slice of row groups
        if worker:
            indices = np.array_split(indices, worker.num_workers)[worker.id]

        pqf = pq.ParquetFile(self.data_path)
        for i in indices:
            # array cast: fixed-width arrays so the kinetics convert to 2-D numpy
            df = pl.from_arrow(pqf.read_row_group(i)).with_columns([
                pl.col(c).list.to_array(self.context) for c in self.kin_feats
            ])
            yield from self._process_batch(df)

In [None]:
KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']

# Normalisation statistics come from the training split only and are reused
# for the validation split.
df = pl.read_parquet('pacbio_standard_train_1m.parquet')
train_means, train_stds = compute_log_normalization_stats(df, KINETICS_FEATURES)

# Loader configuration. These settings were previously defined but never
# passed on, so the datasets silently ran with single_strand=False; they are
# now wired through to the datasets and loaders.
it_workers=0
batch_size=256
single_strand=True
#train
methyl_train_ds = MethylIterableDataset('./pacbio_standard_train_1m.parquet',
                                    means=train_means,
                                    stds=train_stds,
                                    context=32,
                                    single_strand=single_strand)
methyl_train_dl = DataLoader(methyl_train_ds,
                             batch_size=batch_size,
                             num_workers=it_workers,
                             drop_last=True,
                             persistent_workers=False,
                             prefetch_factor=None,
                            )
# val
methyl_val_ds = MethylIterableDataset('./pacbio_standard_test_1m.parquet',
                                    means=train_means,
                                    stds=train_stds,
                                    context=32,
                                    single_strand=single_strand)
methyl_val_dl = DataLoader(methyl_val_ds,
                        batch_size=batch_size,
                        num_workers=it_workers,
                        drop_last=True,
                        persistent_workers=False,
                        prefetch_factor=None)

## Linear Probe

In [None]:
class SingleIdxProbe(nn.Module):
    """Classifier that reads the encoder output at a single position.

    A two-layer head is applied to the last time step of the encoder output.

    Args:
        encoder: Module exposing ``d_model`` and returning (B, T, d_model).
        n_classes: Number of output logits.
        freeze_encoder: If True the encoder's parameters stop requiring
            gradients; otherwise they are (re-)enabled.
    """

    def __init__(self, encoder, n_classes=1, freeze_encoder=False):
        super().__init__()
        self.encoder = encoder
        self.encoder.requires_grad_(not freeze_encoder)

        width = encoder.d_model
        self.head = nn.Sequential(
            nn.Linear(width, width // 2),
            nn.ReLU(),
            nn.Linear(width // 2, n_classes),
        )

    def forward(self, x):
        context = self.encoder(x)
        # probe the final position of the contextualised sequence
        return self.head(context[:, -1, :])

In [None]:
import copy

EPOCHS = 20
# Use the notebook-wide device for both the model and the data; a hard-coded
# 'cuda' here would disagree with `device` on a CPU-only runtime.
DEVICE = device
# Fine-tune a copy so the pretrained encoder itself stays untouched.
encoder_clone = copy.deepcopy(model.encoder)
probe = SingleIdxProbe(encoder_clone, freeze_encoder=False).to(DEVICE)

# Much smaller learning rate for the pretrained encoder than for the fresh head.
optimizer = torch.optim.AdamW([
    {'params': probe.encoder.parameters(), 'lr': 5e-7},
    {'params': probe.head.parameters(), 'lr': 3e-5}
    ])

criterion = torch.nn.BCEWithLogitsLoss()
# mean training loss per window of 100 steps
loss_history = []

total_params = sum(p.numel() for p in probe.parameters() if p.requires_grad)
print(f"trainable params: {total_params}")

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    probe.train()
    running_loss = 0.0
    for i, batch in enumerate(tqdm(methyl_train_dl)):
        inputs = batch['data'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        logits = probe(inputs)
        loss = criterion(logits, labels.unsqueeze(1).to(torch.float32))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            loss_history.append(running_loss / 100)
            running_loss = 0.0

    probe.eval()
    sample_count = 0
    sample_correct = 0
    # No gradients are needed for evaluation; without no_grad every validation
    # batch would build (and hold) an autograd graph.
    with torch.no_grad():
        for batch in tqdm(methyl_val_dl):
            inputs = batch['data'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            logits = probe(inputs)
            # logit > 0 is equivalent to sigmoid(logit) > 0.5
            preds = logits > 0
            correct = labels == preds.squeeze(-1)
            sample_count += correct.shape[0]
            # accumulate as a Python int rather than a device tensor
            sample_correct += int(correct.sum())
    # guard against an empty validation loader
    print(f"epoch val top1_acc: {sample_correct / max(sample_count, 1)}")

In [None]:
def plot_loss(loss_df):
  """Line chart of the ``loss`` column against its row index.

  Args:
    loss_df: Frame with a ``loss`` column, one row per logged window.

  Returns:
    The altair chart.
  """
  indexed = loss_df.with_row_index()
  return (
    alt.Chart(indexed)
    .mark_line()
    .encode(
      alt.X('index:Q'),
      alt.Y('loss:Q'),
    )
    .properties(
      width=700,
      height=500,
      title='Direct Downstream Train Loss',
    )
  )

df = pl.DataFrame({'loss': loss_history})
plot_loss(df)