In [5]:
# --- 0) Setup
import torch, torch.nn as nn, torch.nn.functional as F, random, math
from torch.utils.data import Dataset, DataLoader
from transformers import EsmTokenizer
from datasets import load_dataset
from tqdm.auto import tqdm
import time
import torch

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
MASK_ID = tokenizer.mask_token_id
PAD_ID  = tokenizer.pad_token_id
BOS_ID  = tokenizer.bos_token_id
EOS_ID  = tokenizer.eos_token_id

# canonical AAs for *sampling* later (don’t clamp during training CE)
CANON = list("ACDEFGHIKLMNPQRSTVWY")
CANON_IDS = tokenizer.convert_tokens_to_ids(CANON)

# --- 1) Data
ds = load_dataset("dahvid12/uniprot50-sequences-subsample100", split="train")
seq_series = ds.to_pandas()["sequence"]  # pandas Series[str]


class ProteinSeqDataset(Dataset):
    def __init__(self, s):
        self.items = [str(x).strip().upper() for x in s.tolist()]
    def __len__(self): return len(self.items)
    def __getitem__(self, i): return self.items[i]

def collate_batch(batch_seqs, max_len=None):
    toks = tokenizer(
        batch_seqs,
        add_special_tokens=True, padding=True, truncation=True, max_length=max_len,
        return_tensors="pt"
    )
    input_ids = toks["input_ids"]          # [B, Lw]
    attn      = toks["attention_mask"]     # [B, Lw]
    core_ids  = input_ids[:, 1:-1].contiguous()      # drop BOS/EOS
    core_attn = attn[:, 1:-1].contiguous()
    # ensure PAD_ID in padded spots (robust)
    core_ids  = torch.where(core_attn.bool(), core_ids, torch.full_like(core_ids, PAD_ID))
    return {"core_ids": core_ids, "core_attn": core_attn}

# --- 2) MDLM schedules (cosine α) and weight
def alpha_cosine(t: torch.Tensor) -> torch.Tensor:
    return torch.cos(0.5 * math.pi * t).pow(2)

def alpha_prime_cosine(t: torch.Tensor) -> torch.Tensor:
    a = 0.5 * math.pi
    return -2*a*torch.cos(a*t)*torch.sin(a*t)

def weight_w(t: torch.Tensor) -> torch.Tensor:
    a = alpha_cosine(t)
    ap = alpha_prime_cosine(t)
    return ap / (1.0 - a).clamp_min(1e-6)

# --- 3) Forward masking that respects padding
def forward_mask(core_ids: torch.LongTensor, core_attn: torch.LongTensor, t_scalar: float):
    """
    Mask only where core_attn==1. Returns z_t, masked_bool.
    """
    B, L = core_ids.shape
    valid = core_attn.bool()
    a = alpha_cosine(torch.tensor([t_scalar], device=core_ids.device))
    p_mask = (1.0 - a).item()
    U = torch.rand(B, L, device=core_ids.device)
    masked = (U < p_mask) & valid
    z_t = core_ids.clone()
    z_t[masked] = MASK_ID
    return z_t, masked

# --- 4) Simple time-conditioned encoder (same signature as earlier)
class TimeEmbed(nn.Module):
    def __init__(self, d): 
        super().__init__()
        self.lin = nn.Sequential(nn.Linear(d, 4*d), nn.SiLU(), nn.Linear(4*d, d))
        self.d = d
    def forward(self, t_scalar, device):
        t = torch.tensor([t_scalar], device=device).float()
        half = self.d // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=device)/half)
        ang = t[:,None]*freqs[None,:]
        te = torch.cat([ang.sin(), ang.cos()], dim=-1)
        return self.lin(te)

class MDLMTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_layers=12, n_heads=8, max_len=2048, dropout=0.1):
        super().__init__()
        self.d = d_model
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        enc = nn.TransformerEncoderLayer(d_model, n_heads, 4*d_model, dropout=dropout,
                                         batch_first=True, norm_first=True)
        self.tr = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.emb.weight   # tie
        self.time = TimeEmbed(d_model)
    def forward(self, core_ids, t_scalar):
        B, L = core_ids.shape
        h = self.emb(core_ids) + self.pos.weight[:L][None].expand(B, L, -1) + self.time(t_scalar, core_ids.device)
        h = self.tr(h)
        return self.head(h)                  # [B,L,V]

# --- 5) MDLM loss (Algorithm 1) that ignores PADs
def mdlm_loss_step(model, core_ids, core_attn, t_scalar: float, return_metrics: bool = False):
    """
    MDLM loss per Algorithm 1 with weight w(t) = alpha'(t)/(1-alpha(t)).
    Returns loss (and optional metrics dict).
    """
    z_t, masked = forward_mask(core_ids, core_attn, t_scalar)
    logits = model(z_t, t_scalar)           # [B,L,V]

    # Count tokens
    valid_tokens  = core_attn.sum().item()
    masked_tokens = masked.sum().item()
    masked_frac   = masked_tokens / max(1, valid_tokens)

    if masked_tokens == 0:
        loss = logits.new_tensor(0.0, requires_grad=True)
    else:
        targets  = core_ids[masked]         # [Nmask]
        logits_m = logits[masked]           # [Nmask,V]
        ce = torch.nn.functional.cross_entropy(logits_m, targets, reduction="mean")
        w  = weight_w(torch.tensor([t_scalar], device=logits.device)).item()
        loss = ce * float(w)

    if not return_metrics:
        return loss

    # metrics payload
    a_t = alpha_cosine(torch.tensor([t_scalar], device=core_ids.device)).item()
    ap_t = alpha_prime_cosine(torch.tensor([t_scalar], device=core_ids.device)).item()
    return loss, {
        "t": t_scalar,
        "alpha_t": a_t,
        "alpha_prime_t": ap_t,
        "w_t": ap_t / max(1e-6, 1 - a_t),
        "valid_tokens": int(valid_tokens),
        "masked_tokens": int(masked_tokens),
        "masked_frac": float(masked_frac),
        "mean_len": float(core_attn.sum(dim=1).float().mean().item()),
    }

In [6]:
# --- (B) Nice-to-have: gradient-norm + CUDA mem helpers ---
def grad_global_norm(module: torch.nn.Module) -> float:
    total = 0.0
    for p in module.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total += param_norm.item() ** 2
    return total ** 0.5

def cuda_mem_mb():
    if not torch.cuda.is_available():
        return 0.0, 0.0
    alloc = torch.cuda.memory_allocated() / (1024**2)
    reserv = torch.cuda.memory_reserved() / (1024**2)
    return alloc, reserv

In [9]:
train_loader = DataLoader(
    ProteinSeqDataset(seq_series),
    batch_size=32, shuffle=True, num_workers=2, pin_memory=True,
    collate_fn=lambda b: collate_batch(b, max_len=1024),  # ESM limit ≈1024 incl. specials
)

In [7]:
# --- (C) Train with tqdm + live metrics ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MDLMTransformer(vocab_size=tokenizer.vocab_size, d_model=512, n_layers=12, n_heads=8).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)




In [None]:
epochs = 5
model.train()
for epoch in range(1, epochs + 1):
    epoch_tokens = 0
    epoch_masked = 0
    epoch_loss_sum = 0.0
    epoch_steps = 0
    t0 = time.time()

    pbar = tqdm(train_loader, desc=f"epoch {epoch}/{epochs}", leave=True)
    for batch in pbar:
        core_ids  = batch["core_ids"].to(device, non_blocking=True)
        core_attn = batch["core_attn"].to(device, non_blocking=True)

        t = random.random()  # t ~ U[0,1]
        loss, m = mdlm_loss_step(model, core_ids, core_attn, t, return_metrics=True)

        optim.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        # accounting
        epoch_steps += 1
        epoch_loss_sum += float(loss.detach().item())
        epoch_tokens += m["valid_tokens"]
        epoch_masked += m["masked_tokens"]

        # live stats
        gn = grad_global_norm(model)
        alloc, reserv = cuda_mem_mb()
        elapsed = max(1e-6, time.time() - t0)
        toks_per_s = epoch_tokens / elapsed
        pbar.set_postfix({
            "loss": f"{(epoch_loss_sum/epoch_steps):.4f}",
            "mask%": f"{(100.0 * m['masked_frac']):.1f}",
            "len": f"{m['mean_len']:.0f}",
            "α(t)": f"{m['alpha_t']:.3f}",
            "w(t)": f"{m['w_t']:.3f}",
            "g||": f"{gn:.2f}",
            "tok/s": f"{toks_per_s:.0f}",
            "CUDA(MB)": f"{alloc:.0f}/{reserv:.0f}" if torch.cuda.is_available() else "CPU",
        })

    epoch_time = time.time() - t0
    print(
        f"epoch {epoch}: "
        f"loss={epoch_loss_sum/epoch_steps:.4f} | "
        f"masked={epoch_masked}/{epoch_tokens} ({100*epoch_masked/max(1,epoch_tokens):.1f}%) | "
        f"tok/s={epoch_tokens/max(1e-6, epoch_time):.0f} | "
        f"time={epoch_time:.1f}s"
    )

epoch 1/5:   1%|          | 158/21433 [03:37<6:33:46,  1.11s/it, loss=-85856377.5372, mask%=83.1, len=224, α(t)=0.171, w(t)=-1.427, g||=1.00, tok/s=6267, CUDA(MB)=657/20412] 