In [1]:
import torch
import polars as pl
from torch.utils.data import Dataset
import yaml

class SMRTSequenceDataset(Dataset):
    """
    loads a full context of SMRT data (seq, ipd, pw)
    """
    def __init__(self, parquet_path: str, columns: list):
        self.data_df = pl.read_parquet(parquet_path)
        self.columns = columns
        # always include seq
        if 'seq' not in self.columns:
            self.columns.insert(0, 'seq')
        self.data_df = self.data_df.select(self.columns)
        
    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        row_data = self.data_df.row(idx, named=True)
        
        seq_ids = torch.tensor(row_data['seq'], dtype=torch.long)
        
        kinetics = []
        for col in ['fi', 'fp', 'ri', 'rp']:
            if col in row_data:
                kinetics.append(torch.tensor(row_data[col], dtype=torch.float32))
            
        if kinetics:
            # make into [L, 4] tensor
            kinetics_tensor = torch.stack(kinetics, dim=1)
            # TODO: normalize 
        else:
            kinetics_tensor = torch.empty(len(seq_ids), 0, dtype=torch.float32)

        return {
            "seq_ids": seq_ids,       # [L]
            "kinetics": kinetics_tensor # nominal case -> [L, 4];  no kinetics -> [L, 0] 
        }

In [38]:
torch.arange(5)

tensor([0, 1, 2, 3, 4])

In [37]:
import polars as pl
q = (
    pl.scan_parquet('../data/01_processed/ssl_sets/da1.parquet')
)
df=q.collect()
df.shape

(7878, 10)

In [3]:
ds = SMRTSequenceDataset('../data/01_processed/ssl_sets/da1.parquet', columns = ['seq', 'fi', 'fp', 'ri', 'rp'])
ds[0]['kinetics']

tensor([[13., 12.,  8., 14.],
        [13., 16., 16., 24.],
        [12.,  8., 11., 24.],
        ...,
        [10., 24.,  9., 11.],
        [13.,  5., 15., 14.],
        [17., 33., 16., 13.]])

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SMRTEncoder(nn.Module):
    """Encodes (seq_id, kinetics) at each position into a latent vector z_t."""
    def __init__(self, vocab_size, n_kinetics, embed_dim, latent_dim):
        super().__init__()
        self.seq_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # A small MLP to process the 4 kinetic features
        self.kinetics_mlp = nn.Sequential(
            nn.Linear(n_kinetics, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        
        # Project the combined embedding to the final latent dimension
        self.projection = nn.Linear(embed_dim * 2, latent_dim)
        self.layer_norm = nn.LayerNorm(latent_dim)

    def forward(self, seq_ids, kinetics):
        # seq_ids: [N, L]
        # kinetics: [N, L, 4]
        
        seq_z = self.seq_embed(seq_ids)      # [N, L, embed_dim]
        kinetics_z = self.kinetics_mlp(kinetics) # [N, L, embed_dim]
        
        # Combine by concatenation
        combined_z = torch.cat([seq_z, kinetics_z], dim=-1) # [N, L, embed_dim * 2]
        
        # Project to latent space
        z_t = self.projection(combined_z)    # [N, L, latent_dim]
        z_t = self.layer_norm(z_t)
        return z_t
    

class SMRTEncoder(nn.Module):
    """Encodes (seq_id, kinetics) at each position into a latent vector z_t."""
    def __init__(self, vocab_size, n_kinetics, embed_dim, latent_dim):
        super().__init__()
        self.seq_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # A small MLP to process the 4 kinetic features
        self.kinetics_mlp = nn.Sequential(
            nn.Linear(n_kinetics, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        
        # Project the combined embedding to the final latent dimension
        self.projection = nn.Linear(embed_dim * 2, latent_dim)
        self.layer_norm = nn.LayerNorm(latent_dim)

    def forward(self, seq_ids, kinetics):
        # seq_ids: [N, L]
        # kinetics: [N, L, 4]
        
        seq_z = self.seq_embed(seq_ids)      # [N, L, embed_dim]
        kinetics_z = self.kinetics_mlp(kinetics) # [N, L, embed_dim]
        
        # Combine by concatenation
        combined_z = torch.cat([seq_z, kinetics_z], dim=-1) # [N, L, embed_dim * 2]
        
        # Project to latent space
        z_t = self.projection(combined_z)    # [N, L, latent_dim]
        z_t = self.layer_norm(z_t)
        return z_t
    

class CPCModel(nn.Module):
    def __init__(self, encoder, autoregressive_model, predictor):
        super().__init__()
        self.g_enc = encoder
        self.g_ar = autoregressive_model
        
        # A simple linear layer to predict future latents from context
        # predictor projects context_dim -> latent_dim
        self.W_k = predictor 

    def forward(self, seq_ids, kinetics):
        # seq_ids: [N, L], kinetics: [N, L, 4]
        N, L = seq_ids.shape
        
        # 1. Get all latents
        z = self.g_enc(seq_ids, kinetics) # [N, L, D_z]
        
        # 2. Split into context and target
        # Let's use a random split point for robustness
        # e.g., split halfway
        split_point = L // 2
        
        z_context = z[:, :split_point, :]
        z_target = z[:, split_point:, :]
        
        # 3. Summarize context
        # c_all shape: [N, L_ctx, D_c]
        c_all, _ = self.g_ar(z_context)
        
        # Get the last context vector
        c_summary = c_all[:, -1, :] # [N, D_c]
        
        # 4. Predict future latents
        # This z_hat is the *prediction* for the future
        z_hat = self.W_k(c_summary) # [N, D_z]
        
        # 5. Calculate InfoNCE Loss
        
        # We will compare our prediction z_hat[i] against all
        # *actual* future latents z_target[j] in the batch.
        
        # For simplicity, let's just use the *first* target latent
        # from each sequence as the "positive" target.
        z_positive = z_target[:, 0, :] # [N, D_z]
        
        # This implementation uses in-batch negatives.
        # We calculate an N x N score matrix.
        # scores[i, j] = dot(z_hat[i], z_positive[j])
        
        # Normalize vectors for stable dot product (cosine similarity)
        z_hat_norm = F.normalize(z_hat, p=2, dim=1)
        z_positive_norm = F.normalize(z_positive, p=2, dim=1)
        
        scores = torch.matmul(z_hat_norm, z_positive_norm.T) # [N, N]
        
        # Temperature scaling
        temperature = 0.1 
        scores = scores / temperature
        
        # The positive pairs are on the diagonal (i, i).
        # We want to maximize scores[i, i] and minimize scores[i, j].
        # This is a standard cross-entropy loss problem.
        
        # Labels are simply [0, 1, 2, ..., N-1]
        labels = torch.arange(N, device=seq_ids.device)
        
        loss = F.cross_entropy(scores, labels)
        
        return loss

In [5]:
encoder = SMRTEncoder(4, n_kinetics=4, embed_dim=128, latent_dim=256)
encoder(ds[0]['seq_ids'], ds[0]['kinetics']).shape

torch.Size([2048, 256])

In [6]:
# Example setup
from torch.utils.data import DataLoader

# 1. Dataset
# Use the SMRT tags you extracted
data_cols = ['seq', 'fi', 'fp', 'ri', 'rp']
dataset = SMRTSequenceDataset(
    parquet_path="path/to/your.parquet",
    columns=data_cols
)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 2. Model
# (Define vocab_size, embed_dim, latent_dim, context_dim)
encoder = SMRTEncoder(vocab_size, n_kinetics=4, embed_dim=128, latent_dim=256)
gru = nn.GRU(input_size=256, hidden_size=512, batch_first=True)
predictor = nn.Linear(512, 256) # D_c -> D_z

model = CPCModel(encoder, gru, predictor)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# 3. Training
model.train()
for batch in dataloader:
    seq_ids = batch['seq_ids']
    kinetics = batch['kinetics']
    
    optimizer.zero_grad()
    
    loss = model(seq_ids, kinetics)
    
    loss.backward()
    optimizer.step()
    
    print(f"Loss: {loss.item()}")

FileNotFoundError: No such file or directory (os error 2): path/to/your.parquet

This error occurred with the following context stack:
	[1] 'parquet scan'
	[2] 'sink'


In [None]:
import torch
import torch.nn as nn
from torch import Tensor
import math
from typing import Optional

class CPCTransformer(nn.Module):
    def __init__(self, 
                 vocab_size: 5, 
                 embed_dim: 64, 
                 nhead: 4, 
                 num_layers: 2, 
                 mlp_hidden_dim: 128,
                 k_max: 3, # number of steps to predict ahead
                 max_seq_len: 2048,
                 dropout: 0.1):
        
        super().__init__()
        self.embed_dim = embed_dim
        self.k_max = k_max
        self.max_seq_len = max_seq_len

        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=nhead, 
            dim_feedforward=mlp_hidden_dim, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )

        # make a seperate projection head for each step of extrapolation
        self.predictors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, mlp_hidden_dim),
                nn.ReLU(),
                nn.Linear(mlp_hidden_dim, embed_dim)
            ) for _ in range(k_max)
        ])
        
        self.loss_fn = nn.CrossEntropyLoss()
        
    def _generate_causal_mask(self, sz: int) -> Tensor:
        mask = (torch.tril(torch.ones(sz, sz)) == 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x: Tensor) -> Tensor:
        B, T = x.shape
        
        if T > self.max_seq_len:
            raise ValueError(f"Input sequence length ({T}) exceeds max_seq_len ({self.max_seq_len})")
            
        causal_mask = self._generate_causal_mask(T).to(x.device)
        
        x_emb = self.token_embed(x) + self.pos_encoder[:, :T, :]
        
        z = self.transformer_encoder(x_emb, mask=causal_mask)
        
        total_loss = 0.0
        num_losses = 0
        
        for k in range(1, self.k_max + 1):
            t_max = T - k
            if t_max <= 0:
                continue
                
            c_t = z[:, :t_max, :]
            z_k = z[:, k:, :][:, :t_max, :]
            
            z_hat = self.predictors[k-1](c_t)
            
            z_hat_flat = z_hat.reshape(-1, self.embed_dim)
            z_k_flat = z_k.reshape(-1, self.embed_dim)
            
            scores = torch.matmul(z_hat_flat, z_k_flat.T)
            
            labels = torch.arange(z_hat_flat.size(0)).to(x.device)
            
            loss = self.loss_fn(scores, labels)
            total_loss += loss
            num_losses += 1
            
        return total_loss / num_losses if num_losses > 0 else torch.tensor(0.0).to(x.device)

In [9]:
B, T = 16, 100
V = 5 

model = CPCTransformer(
    vocab_size=V,
    embed_dim=64,
    nhead=4,
    num_layers=2,
    mlp_hidden_dim=128,
    k_max=3,
    max_seq_len=512
)

seq_data = torch.randint(0, V, (B, T))

print(f"Input shape: {seq_data.shape}")

loss = model(seq_data)

print(f"Calculated CPC Loss: {loss.item()}")

loss.backward()

print("Backward pass successful.")

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {num_params}")

Input shape: torch.Size([16, 100])
Calculated CPC Loss: 7.840989589691162
Backward pass successful.
Total trainable parameters: 149760


In [17]:
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple

class CPCTransformer(nn.Module):
    
    def __init__(self, 
                 vocab_size: int = 5, 
                 nuc_embed_dim: int = 48,
                 kinetic_embed_dim: int = 8,
                 embed_dim: int = 64, 
                 nhead: int = 4, 
                 num_layers: int = 2, 
                 mlp_hidden_dim: int = 128,
                 k_max: int = 3, 
                 max_seq_len: int = 2048,
                 dropout: float = 0.1):
        
        super().__init__()
        self.embed_dim = embed_dim
        self.k_max = k_max
        self.max_seq_len = max_seq_len

        # 1. Separate embedding layers for each input channel
        self.token_embed = nn.Embedding(vocab_size, nuc_embed_dim)
        # We assume kinetic values are integers from 0-255
        self.kinetic_embed_1 = nn.Embedding(256, kinetic_embed_dim)
        self.kinetic_embed_2 = nn.Embedding(256, kinetic_embed_dim)

        # 2. A projection layer to mix features and match model's embed_dim
        total_input_dim = nuc_embed_dim + kinetic_embed_dim + kinetic_embed_dim
        self.input_projector = nn.Linear(total_input_dim, embed_dim)

        # 3. Positional encoder now adds to the *projected* embedding
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        
        # The rest of the model is unchanged, as it operates on embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=nhead, 
            dim_feedforward=mlp_hidden_dim, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        self.predictors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, mlp_hidden_dim),
                nn.ReLU(),
                nn.Linear(mlp_hidden_dim, embed_dim)
            ) for _ in range(k_max)
        ])
        
        self.loss_fn = nn.CrossEntropyLoss()
        
    def _generate_causal_mask(self, sz: int) -> Tensor:
        # Using .tril for a cleaner implementation
        mask = (torch.tril(torch.ones(sz, sz)) == 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
        x_nuc, x_k1, x_k2 = x
        
        B, T = x_nuc.shape
        
        if T > self.max_seq_len:
            raise ValueError(f"Input sequence length ({T}) exceeds max_seq_len ({self.max_seq_len})")
            
        causal_mask = self._generate_causal_mask(T).to(x_nuc.device)
        
        # 1. Get separate embeddings
        nuc_emb = self.token_embed(x_nuc)
        k1_emb = self.kinetic_embed_1(x_k1)
        k2_emb = self.kinetic_embed_2(x_k2)
        
        # 2. Concatenate and project
        combined_emb = torch.cat([nuc_emb, k1_emb, k2_emb], dim=-1)
        projected_emb = self.input_projector(combined_emb)
        
        # 3. Add positional encoding
        x_emb = projected_emb + self.pos_encoder[:, :T, :]
        
        # The rest of the logic is identical
        z = self.transformer_encoder(x_emb, mask=causal_mask)
        
        total_loss = 0.0
        num_losses = 0
        
        for k in range(1, self.k_max + 1):
            t_max = T - k
            if t_max <= 0:
                continue
                
            c_t = z[:, :t_max, :]
            z_k = z[:, k:, :][:, :t_max, :]
            
            z_hat = self.predictors[k-1](c_t)
            
            z_hat_flat = z_hat.reshape(-1, self.embed_dim)
            z_k_flat = z_k.reshape(-1, self.embed_dim)
            
            scores = torch.matmul(z_hat_flat, z_k_flat.T)
            
            labels = torch.arange(z_hat_flat.size(0)).to(x_nuc.device)
            
            loss = self.loss_fn(scores, labels)
            total_loss += loss
            num_losses += 1
            
        return total_loss / num_losses if num_losses > 0 else torch.tensor(0.0).to(x_nuc.device)

if __name__ == '__main__':
    
    B, T = 16, 100
    V = 5 # A, C, G, T, N
    
    # Main model dimension
    D_MODEL = 64
    
    model = CPCTransformer(
        vocab_size=V,
        nuc_embed_dim=128,       #128 dims for nucleotides
        kinetic_embed_dim=64,    #64 dims for each kinetic channel
        embed_dim=D_MODEL,      # Total projected dim is 64 (48 + 8 + 8)
        nhead=4,
        num_layers=2,
        mlp_hidden_dim=128,
        k_max=3,
        max_seq_len=2048
    )
    
    # Create the 3 input tensors
    seq_data = torch.randint(0, V, (B, T))
    # Kinetic data (integers 0-255)
    kinetic_1 = torch.randint(0, 256, (B, T))
    kinetic_2 = torch.randint(0, 256, (B, T))

    # Pass data as a tuple
    model_input = (seq_data, kinetic_1, kinetic_2)
    
    print(f"Input shapes: nuc={seq_data.shape}, k1={kinetic_1.shape}, k2={kinetic_2.shape}")
    
    loss = model(model_input)
    
    print(f"Calculated CPC Loss: {loss.item()}")
    
    loss.backward()
    
    print("Backward pass successful.")
    
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_params}")

Input shapes: nuc=torch.Size([16, 100]), k1=torch.Size([16, 100]), k2=torch.Size([16, 100])
Calculated CPC Loss: 8.542759895324707
Backward pass successful.
Total trainable parameters: 297600


In [18]:
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple

class CPCTransformer(nn.Module):
    
    def __init__(self, 
                 vocab_size: int = 5, 
                 nuc_embed_dim: int = 48,
                 kinetics_channels: int = 4,
                 embed_dim: int = 64, 
                 nhead: int = 4, 
                 num_layers: int = 2, 
                 mlp_hidden_dim: int = 128,
                 k_max: int = 3, 
                 max_seq_len: int = 512,
                 dropout: float = 0.1):
        
        super().__init__()
        self.embed_dim = embed_dim
        self.k_max = k_max
        self.max_seq_len = max_seq_len

        # 1. Separate embedding layers for each input channel
        self.token_embed = nn.Embedding(vocab_size, nuc_embed_dim)
        # Kinetic data is now assumed to be a [B, T, 4] float tensor
        # self.kinetic_embed_1 = nn.Embedding(256, kinetic_embed_dim)
        # self.kinetic_embed_2 = nn.Embedding(256, kinetic_embed_dim)

        # 2. A projection layer to mix features and match model's embed_dim
        total_input_dim = nuc_embed_dim + kinetics_channels
        self.input_projector = nn.Linear(total_input_dim, embed_dim)

        # 3. Positional encoder now adds to the *projected* embedding
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        
        # The rest of the model is unchanged, as it operates on embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=nhead, 
            dim_feedforward=mlp_hidden_dim, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        self.predictors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, mlp_hidden_dim),
                nn.ReLU(),
                nn.Linear(mlp_hidden_dim, embed_dim)
            ) for _ in range(k_max)
        ])
        
        self.loss_fn = nn.CrossEntropyLoss()
        
    def _generate_causal_mask(self, sz: int) -> Tensor:
        # Using .tril for a cleaner implementation
        mask = (torch.tril(torch.ones(sz, sz)) == 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x: Tuple[Tensor, Tensor]) -> Tensor:
        x_nuc, x_kinetics = x
        
        B, T = x_nuc.shape
        
        if T > self.max_seq_len:
            raise ValueError(f"Input sequence length ({T}) exceeds max_seq_len ({self.max_seq_len})")
            
        causal_mask = self._generate_causal_mask(T).to(x_nuc.device)
        
        # 1. Get separate embeddings
        nuc_emb = self.token_embed(x_nuc)
        # k1_emb = self.kinetic_embed_1(x_k1)
        # k2_emb = self.kinetic_embed_2(x_k2)
        
        # 2. Concatenate and project
        # nuc_emb shape: [B, T, nuc_embed_dim]
        # x_kinetics shape: [B, T, kinetics_channels]
        combined_emb = torch.cat([nuc_emb, x_kinetics], dim=-1)
        projected_emb = self.input_projector(combined_emb)
        
        # 3. Add positional encoding
        x_emb = projected_emb + self.pos_encoder[:, :T, :]
        
        # The rest of the logic is identical
        z = self.transformer_encoder(x_emb, mask=causal_mask)
        
        total_loss = 0.0
        num_losses = 0
        
        for k in range(1, self.k_max + 1):
            t_max = T - k
            if t_max <= 0:
                continue
                
            c_t = z[:, :t_max, :]
            z_k = z[:, k:, :][:, :t_max, :]
            
            z_hat = self.predictors[k-1](c_t)
            
            z_hat_flat = z_hat.reshape(-1, self.embed_dim)
            z_k_flat = z_k.reshape(-1, self.embed_dim)
            
            scores = torch.matmul(z_hat_flat, z_k_flat.T)
            
            labels = torch.arange(z_hat_flat.size(0)).to(x_nuc.device)
            
            loss = self.loss_fn(scores, labels)
            total_loss += loss
            num_losses += 1
            
        return total_loss / num_losses if num_losses > 0 else torch.tensor(0.0).to(x_nuc.device)

if __name__ == '__main__':
    
    B, T = 16, 100
    V = 5 # A, C, G, T, N
    
    # Main model dimension
    D_MODEL = 64
    KINETICS_CHANNELS = 4
    
    model = CPCTransformer(
        vocab_size=V,
        nuc_embed_dim=48,       # 48 dims for nucleotides
        kinetics_channels=KINETICS_CHANNELS,    # 4 dims for kinetics data
        embed_dim=D_MODEL,      # Total projected dim is 64
        nhead=4,
        num_layers=2,
        mlp_hidden_dim=128,
        k_max=3,
        max_seq_len=512
    )
    
    # Create the 2 input tensors
    seq_data = torch.randint(0, V, (B, T))
    # Kinetic data (floats, shape [B, T, 4])
    kinetics_data = torch.randn(B, T, KINETICS_CHANNELS)

    # Pass data as a tuple
    model_input = (seq_data, kinetics_data)
    
    print(f"Input shapes: nuc={seq_data.shape}, kinetics={kinetics_data.shape}")
    
    loss = model(model_input)
    
    print(f"Calculated CPC Loss: {loss.item()}")
    
    loss.backward()
    
    print("Backward pass successful.")
    
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_params}")

Input shapes: nuc=torch.Size([16, 100]), kinetics=torch.Size([16, 100, 4])
Calculated CPC Loss: 8.252659797668457
Backward pass successful.
Total trainable parameters: 153072


In [21]:
model(ds[0]['seq_ids'], ds[0]['kinetics'])

TypeError: CPCTransformer.forward() takes 2 positional arguments but 3 were given

In [32]:
class CPCTransformer(nn.Module):
    
    def __init__(self, 
                 vocab_size: int = 5, 
                 nuc_embed_dim: int = 48,
                 kinetics_channels: int = 4,
                 embed_dim: int = 64, 
                 nhead: int = 4, 
                 num_layers: int = 2, 
                 mlp_hidden_dim: int = 128,
                 k_max: int = 3, 
                 max_seq_len: int = 512,
                 dropout: float = 0.1):
        
        super().__init__()
        self.embed_dim = embed_dim
        self.k_max = k_max
        self.max_seq_len = max_seq_len

        # 1. Separate embedding layers for each input channel
        self.token_embed = nn.Embedding(vocab_size, nuc_embed_dim)
        # Kinetic data is now assumed to be a [B, T, 4] float tensor
        # self.kinetic_embed_1 = nn.Embedding(256, kinetic_embed_dim)
        # self.kinetic_embed_2 = nn.Embedding(256, kinetic_embed_dim)

        # 2. A projection layer to mix features and match model's embed_dim
        total_input_dim = nuc_embed_dim + kinetics_channels
        self.input_projector = nn.Linear(total_input_dim, embed_dim)

        # 3. Positional encoder now adds to the *projected* embedding
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        
        # The rest of the model is unchanged, as it operates on embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=nhead, 
            dim_feedforward=mlp_hidden_dim, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        self.predictors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, mlp_hidden_dim),
                nn.ReLU(),
                nn.Linear(mlp_hidden_dim, embed_dim)
            ) for _ in range(k_max)
        ])
        
        self.loss_fn = nn.CrossEntropyLoss()
        
    def _generate_causal_mask(self, sz: int) -> Tensor:
        # Using .tril for a cleaner implementation
        mask = (torch.tril(torch.ones(sz, sz)) == 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x: Tuple[Tensor, Tensor]) -> Tensor:
        x_nuc, x_kinetics = x
        
        B, T = x_nuc.shape
        
        if T > self.max_seq_len:
            raise ValueError(f"Input sequence length ({T}) exceeds max_seq_len ({self.max_seq_len})")
            
        causal_mask = self._generate_causal_mask(T).to(x_nuc.device)
        
        # 1. Get separate embeddings
        # --- FIX for MPS Error ---
        # The nn.Embedding layer on MPS has a bug with long tensors.
        # The fix is to move the input tensor to the CPU for the lookup,
        # even if the embedding layer's weights are on the MPS device.
        nuc_emb = self.token_embed(x_nuc.cpu())
        # Move the result back to the original device (MPS)
        nuc_emb = nuc_emb.to(x_nuc.device)
        # --- End of FIX ---

        # k1_emb = self.kinetic_embed_1(x_k1)
        # k2_emb = self.kinetic_embed_2(x_k2)
        
        # 2. Concatenate and project
        # nuc_emb shape: [B, T, nuc_embed_dim]
        # x_kinetics shape: [B, T, kinetics_channels]
        combined_emb = torch.cat([nuc_emb, x_kinetics], dim=-1)
        projected_emb = self.input_projector(combined_emb)
        
        # 3. Add positional encoding
        x_emb = projected_emb + self.pos_encoder[:, :T, :]
        
        # The rest of the logic is identical
        z = self.transformer_encoder(x_emb, mask=causal_mask)
        
        total_loss = 0.0
        num_losses = 0
        
        for k in range(1, self.k_max + 1):
            t_max = T - k
            if t_max <= 0:
                continue
                
            c_t = z[:, :t_max, :]
            z_k = z[:, k:, :][:, :t_max, :]
            
            z_hat = self.predictors[k-1](c_t)
            
            z_hat_flat = z_hat.reshape(-1, self.embed_dim)
            z_k_flat = z_k.reshape(-1, self.embed_dim)
            
            scores = torch.matmul(z_hat_flat, z_k_flat.T)
            
            labels = torch.arange(z_hat_flat.size(0)).to(x_nuc.device)
            
            loss = self.loss_fn(scores, labels)
            total_loss += loss
            num_losses += 1
            
        return total_loss / num_losses if num_losses > 0 else torch.tensor(0.0).to(x_nuc.device)

In [35]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
from tqdm import tqdm

model = CPCTransformer(
    vocab_size=V,
    nuc_embed_dim=48,       # 48 dims for nucleotides
    kinetics_channels=KINETICS_CHANNELS,    # 4 dims for kinetics data
    embed_dim=D_MODEL,      # Total projected dim is 64
    nhead=4,
    num_layers=2,
    mlp_hidden_dim=128,
    k_max=3,
    max_seq_len=2048
)
model.to(device)
dataset = SMRTSequenceDataset('../data/01_processed/ssl_sets/da1.parquet', columns = ['seq', 'fi', 'fp', 'ri', 'rp'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)

for batch in tqdm(dataloader):
    seq = batch['seq_ids'].to(device)
    kinetics = batch['kinetics'].to(device)
    
    loss = model((seq, kinetics))
    # ...

  0%|          | 0/493 [00:00<?, ?it/s]


RuntimeError: Placeholder storage has not been allocated on MPS device!