In [None]:
# --------------------------
# 0) Tokenizer & constants
# --------------------------
from transformers import EsmTokenizer
import torch, torch.nn as nn, torch.nn.functional as F
import math, random

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
MASK_ID = tokenizer.mask_token_id
PAD_ID  = tokenizer.pad_token_id
# ESM wraps each sequence as <cls> ... <eos>; the start token is <cls>, so use
# cls_token_id (the ESM tokenizer defines no bos_token, so bos_token_id is None).
BOS_ID  = tokenizer.cls_token_id
EOS_ID  = tokenizer.eos_token_id

# canonical 20 amino acids (generation allowlist)
CANON = list("ACDEFGHIKLMNPQRSTVWY")
CANON_IDS = tokenizer.convert_tokens_to_ids(CANON)
# every canonical residue must be a real vocab entry, not silently mapped to <unk>
assert tokenizer.unk_token_id not in CANON_IDS, "canonical amino acid missing from tokenizer vocab"
CANON_IDS_T = None  # filled lazily on device

def encode_core(seq: str) -> torch.LongTensor:
    """Tokenize a single sequence and return its core token ids.

    The tokenizer adds one leading (BOS) and one trailing (EOS) special token;
    both are sliced away, leaving a fresh 1-D tensor [L] of residue ids.
    """
    encoded = tokenizer(seq, add_special_tokens=True, return_tensors="pt")
    with_specials = encoded["input_ids"][0]   # [L + 2]
    core = with_specials[1:-1]                # drop BOS and EOS
    return core.clone()

def batch_encode_core(seqs):
    """Tokenize a list of sequences into a padded [B, L] batch of core ids.

    The tokenizer returns rows of the form BOS, residues..., EOS, PAD... with
    shape [B, L+2]. Dropping the first and last column removes BOS from every
    row, but it removes EOS only from the longest row(s). In shorter rows EOS
    sits before the padding and would stay inside the core. It is rewritten to
    PAD_ID, so every row is its residues followed by uniform PAD_ID padding.
    This assumes the tokenizer pads on the right (its default padding side);
    confirm this if the tokenizer config changes.

    Args:
        seqs: list of amino-acid strings (truncated to the tokenizer's max length).
    Returns:
        LongTensor [B, L], where L is the longest core length in the batch.
    """
    ids = tokenizer(
        seqs, add_special_tokens=True, padding=True, truncation=True, return_tensors="pt"
    )["input_ids"]              # [B, L+2]
    core = ids[:, 1:-1]         # [B, L] (view; may still contain EOS in short rows)
    # out-of-place fill, then contiguous copy, so `ids` is never mutated
    return core.masked_fill(core == EOS_ID, PAD_ID).contiguous()  # [B, L]

def decode_core(core_ids: torch.LongTensor) -> str:
    """Convert core token ids back into an amino-acid string.

    Tokens that are not one of the 20 canonical residues (padding, EOS, a
    leftover <mask>, non-canonical letters) are dropped rather than replaced.
    This keeps padding or unfinished positions from turning into fabricated
    residues, so the result may be shorter than `core_ids`.

    Args:
        core_ids: 1-D LongTensor [L] of token ids (no batch dimension).
    Returns:
        String made only of letters from CANON, with length <= L.
    """
    allowed = set(CANON)
    toks = tokenizer.convert_ids_to_tokens(core_ids.tolist())
    return "".join(t for t in toks if t in allowed)


In [None]:
# --------------------------
# 1) Schedules: alpha(t), alpha'(t)
# --------------------------
def alpha_cosine(t: torch.Tensor) -> torch.Tensor:
    """Cosine masking schedule alpha(t) = cos^2(pi * t / 2), applied elementwise.

    alpha(0) = 1 (nothing masked) and alpha decreases to ~0 at t = 1.
    """
    half_angle = 0.5 * math.pi * t
    return torch.cos(half_angle) ** 2

def alpha_prime_cosine(t: torch.Tensor) -> torch.Tensor:
    """Time derivative of alpha_cosine, applied elementwise.

    With a = pi/2:  d/dt cos^2(a t) = -2 a cos(a t) sin(a t) = -(pi/2) sin(pi t).
    The result is <= 0 on [0, 1] because alpha is decreasing there.
    """
    a = 0.5 * math.pi
    angle = a * t
    return -2 * a * torch.cos(angle) * torch.sin(angle)

def weight_w(t: torch.Tensor) -> torch.Tensor:
    """MDLM NELBO time factor alpha'(t) / (1 - alpha(t))  (Eq. (47) factor).

    For the decreasing cosine schedule alpha'(t) <= 0, so this factor is <= 0.
    Multiplying it by a cross-entropy (= -log p >= 0) therefore needs a sign
    flip to give a non-negative loss. The denominator is clamped to 1e-6 to
    avoid dividing by zero as t -> 0.
    """
    denom = (1.0 - alpha_cosine(t)).clamp_min(1e-6)
    return alpha_prime_cosine(t) / denom


In [None]:
# --------------------------
# 2) Forward masking q_t (absorbing state)
# --------------------------
def forward_mask(core_ids: torch.LongTensor, t_scalar: float):
    """Absorbing-state forward process q(z_t | x) at a single time t.

    Every position of `core_ids` ([B, L] integer ids) is independently
    replaced by MASK_ID with probability 1 - alpha(t). t_scalar is expected
    in [0, 1] and is shared by the whole batch.

    Returns:
        z_t:  [B, L] noised copy of core_ids (the input is not modified).
        mask: [B, L] bool, True where a token was replaced by MASK_ID.
    """
    n_rows, n_cols = core_ids.shape
    dev = core_ids.device
    t_vec = torch.full((n_rows,), float(t_scalar), device=dev)
    keep_prob = alpha_cosine(t_vec).view(n_rows, 1)   # broadcast over length
    uniform = torch.rand(n_rows, n_cols, device=dev)
    mask = uniform < (1.0 - keep_prob)
    z_t = core_ids.clone()
    z_t[mask] = MASK_ID
    return z_t, mask


In [None]:
# --------------------------
# 3) Denoiser: encoder-only Transformer + time conditioning
# --------------------------
class TimeEmbed(nn.Module):
    """Map a scalar diffusion time t to a [1, d_model] conditioning vector.

    A fixed sinusoidal encoding of t goes through a two-layer SiLU MLP.
    """

    def __init__(self, d_model):
        super().__init__()
        self.lin = nn.Sequential(
            nn.Linear(d_model, 4*d_model), nn.SiLU(), nn.Linear(4*d_model, d_model)
        )
        self.d_model = d_model

    def forward(self, t_scalar: float, device):
        """
        t_scalar: diffusion time (a Python float).
        device:   device on which to build the encoding.
        Returns:  [1, d_model] tensor, which broadcasts over [B, L, d_model].
        """
        t = torch.tensor([t_scalar], device=device).float()
        half = self.d_model // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=device)/half)
        ang = t[:, None] * freqs[None, :]
        te = torch.cat([ang.sin(), ang.cos()], dim=-1)  # [1, 2*half]
        if te.size(-1) < self.d_model:
            # odd d_model: the sin/cos halves give d_model-1 features, so pad one
            # zero column to match the MLP's input width
            te = F.pad(te, (0, self.d_model - te.size(-1)))
        return self.lin(te)  # [1, d_model]

class MDLMTransformer(nn.Module):
    """Encoder-only Transformer denoiser for masked diffusion.

    The input is the token embedding plus a learned absolute position embedding
    plus a time embedding (one vector broadcast to every position). It feeds a
    pre-LN (norm_first) TransformerEncoder. The output projection shares its
    weight matrix with the token embedding.

    Args:
        vocab_size: number of token ids; logits have this last dimension.
        d_model, n_layers, n_heads, dropout: encoder hyperparameters.
        max_len: number of rows in the position table (longest supported L).
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, n_heads=12, max_len=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos   = nn.Embedding(max_len, d_model)
        enc = nn.TransformerEncoderLayer(
            d_model, n_heads, 4*d_model, dropout=dropout, batch_first=True, norm_first=True
        )
        self.tr = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.to_logits = nn.Linear(d_model, vocab_size, bias=False)
        self.to_logits.weight = self.embed.weight  # tie weights
        self.time = TimeEmbed(d_model)

    def forward(self, core_ids: torch.LongTensor, t_scalar: float):
        """
        core_ids: [B, L] (NO BOS/EOS), with L <= max_len.
        t_scalar: diffusion time shared by the whole batch.
        Returns logits over vocab: [B, L, V].
        Raises ValueError if L exceeds max_len.
        """
        B, L = core_ids.shape
        if L > self.max_len:
            # the position table has only max_len rows; fail clearly instead of
            # with a shape-mismatch error in the addition below
            raise ValueError(f"sequence length {L} exceeds max_len={self.max_len}")
        pos = self.pos.weight[:L][None, ...].expand(B, L, -1)
        h = self.embed(core_ids) + pos + self.time(t_scalar, device=core_ids.device)
        # NOTE(review): no key-padding mask is passed, so padded positions are attended to
        h = self.tr(h)
        return self.to_logits(h)


In [None]:
# --------------------------
# 4) MDLM loss step (Algorithm 1)
# --------------------------
def mdlm_loss_step(model, core_ids, t_scalar: float):
    """One Monte-Carlo estimate of the MDLM NELBO at a single time t (Algorithm 1).

    1) sample z_t by masking each real token with prob 1 - alpha(t);
       PAD/EOS positions are never masked and never scored
    2) predict logits on z_t
    3) sum the cross-entropy over the masked positions
    4) weight by -w(t), where w(t) = alpha'(t) / (1 - alpha(t)), and divide
       by the number of real tokens in the batch

    Sign: for the cosine schedule w(t) <= 0, and cross-entropy is -log p >= 0.
    The NELBO term alpha'/(1-alpha) * log p therefore equals -w(t) * CE >= 0.
    Multiplying CE by w(t) directly would give a negative loss, and minimizing
    it would push the model away from the targets.

    Normalization: the NELBO sums -log p over masked tokens. Dividing that sum
    by the real-token count gives a per-token loss. Averaging over masked
    tokens instead would divide by about (1 - alpha(t)) * #tokens and inflate
    small-t steps by an extra factor of 1 / (1 - alpha(t)).

    Args:
        model: callable (ids [B, L], t) -> logits [B, L, V].
        core_ids: [B, L] clean target ids.
        t_scalar: diffusion time in [0, 1], shared by the batch.
    Returns:
        Scalar loss tensor; a zero tensor with requires_grad when nothing is masked.
    """
    # padding / end-of-sequence slots are not part of the protein
    ignore = torch.zeros_like(core_ids, dtype=torch.bool)
    for special_id in (PAD_ID, EOS_ID):
        if special_id is not None:
            ignore |= core_ids == special_id

    z_t, masked = forward_mask(core_ids, t_scalar)
    # keep special tokens visible so the model can see where each sequence ends
    z_t = torch.where(ignore, core_ids, z_t)
    masked = masked & ~ignore

    if not masked.any():
        # e.g. t == 0 or an all-padding batch: no learning signal this step
        return torch.zeros((), device=core_ids.device, requires_grad=True)

    logits = model(z_t, t_scalar)  # [B,L,V]

    # Gather targets at masked sites and sum their cross-entropy
    ce_sum = F.cross_entropy(logits[masked], core_ids[masked], reduction="sum")
    n_tokens = (~ignore).sum().clamp_min(1)

    # weight per Eq. (47); w <= 0, so negate to obtain a non-negative loss
    w = weight_w(torch.tensor([t_scalar], device=logits.device)).item()
    return ce_sum * (-w) / n_tokens


In [None]:
# --------------------------
# 5) Trainer skeleton
# --------------------------
def train_mdlm(model, dataloader, epochs=1, lr=3e-4, device="cuda"):
    """Train `model` with the MDLM objective using AdamW (weight decay 0.01).

    Each batch draws a single t ~ U[0, 1) and takes one optimizer step with
    gradient-norm clipping at 1.0.

    Args:
        model: denoiser taking (ids [B, L], t) and returning logits.
        dataloader: re-iterable source of batches. Each batch is either a
            [B, L] LongTensor or a tuple/list whose first element is one (as
            yielded by a DataLoader over a TensorDataset).
        epochs: number of passes over `dataloader`.
        lr: AdamW learning rate.
        device: device the model and batches are moved to.
    Returns:
        List of per-step loss values (Python floats).
    """
    import collections.abc
    import warnings

    if epochs > 1 and isinstance(dataloader, collections.abc.Iterator):
        # a generator/iterator is exhausted after one pass, so later epochs
        # would silently see no batches
        warnings.warn(
            "train_mdlm: `dataloader` is a one-shot iterator; only the first of "
            f"{epochs} epochs will see any data. Pass a re-iterable such as a DataLoader.",
            stacklevel=2,
        )

    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    model.train()
    history = []
    for _ in range(epochs):
        for batch in dataloader:
            batch_core = batch[0] if isinstance(batch, (tuple, list)) else batch  # [B,L]
            batch_core = batch_core.to(device)
            t_scalar = random.random()              # t ~ U[0,1)
            loss = mdlm_loss_step(model, batch_core, t_scalar)
            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            history.append(loss.item())
    return history


In [None]:
# --------------------------
# 6) Ancestral sampling (SUBS) with unmask probability
# --------------------------
def _nucleus_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Zero out tokens outside the top-p nucleus and renormalize over the last dim.

    A token is kept when the mass of the tokens ranked strictly above it is
    <= top_p. The most likely token is therefore always kept, and when no
    prefix exceeds top_p (e.g. top_p >= 1) every token stays.
    """
    sorted_p, order = probs.sort(-1, descending=True)
    mass_before = F.pad(sorted_p.cumsum(-1)[..., :-1], (1, 0))  # exclusive prefix sum
    keep_sorted = mass_before <= top_p
    keep = keep_sorted.gather(-1, order.argsort(-1))  # back to vocabulary order
    filtered = probs * keep
    return filtered / filtered.sum(-1, keepdim=True)


@torch.no_grad()
def ancestral_sample_mdlm(model, seq_len, T=64, top_p=0.9, temperature=1.0, device="cuda"):
    """Generate one sequence of `seq_len` canonical residues by iterative unmasking.

    Starts from an all-MASK core and walks t = 1 -> 0 in T steps. At step
    t -> s it unmasks k = floor(p_unmask * #masked) positions, where
    p_unmask = (alpha_s - alpha_t) / (1 - alpha_t).

    Proposals come from the model's temperature-scaled, nucleus-filtered
    categorical, restricted to the 20 canonical amino acids. The k still-masked
    positions with the highest maximum (filtered) probability are committed.
    This is a confidence ordering, not the uniformly random choice of strict
    ancestral sampling. Committed tokens are never changed again. On the last
    step alpha_s = 1, so p_unmask = 1 and every remaining mask is filled.

    The model's train/eval mode is restored on return.

    Args:
        model: denoiser (ids [1, L], t) -> logits [1, L, V].
        seq_len: number of residues to generate.
        T: number of denoising steps.
        top_p: nucleus mass to keep (>= 1 disables filtering).
        temperature: softmax temperature (must be > 0).
        device: device to sample on (should match the model's device).
    Returns:
        Generated amino-acid string.
    """
    global CANON_IDS_T
    was_training = model.training
    model.eval()
    try:
        x = torch.full((1, seq_len), MASK_ID, dtype=torch.long, device=device)  # core tokens
        # (re)build the cached allowlist if it is missing or lives on another device
        if CANON_IDS_T is None or CANON_IDS_T.device != x.device:
            CANON_IDS_T = torch.tensor(CANON_IDS, device=x.device, dtype=torch.long)

        bans = None  # [V] bool, True for non-canonical ids; built once V is known
        for step in range(T, 0, -1):
            t = step / T
            s = (step - 1) / T
            a_t = alpha_cosine(torch.tensor([t], device=device))
            a_s = alpha_cosine(torch.tensor([s], device=device))
            p_unmask = ((a_s - a_t) / (1 - a_t).clamp_min(1e-6)).clamp(0, 1).item()

            logits = model(x, t)  # [1, L, V]
            if bans is None:
                bans = torch.ones(logits.size(-1), dtype=torch.bool, device=logits.device)
                bans[CANON_IDS_T] = False
            logits[..., bans] = float("-inf")  # clamp to the 20 canonical residues

            probs = _nucleus_filter((logits / temperature).softmax(-1), top_p)  # [1, L, V]
            sampled = torch.distributions.Categorical(probs=probs).sample()     # [1, L]

            masked_now = x == MASK_ID
            k = int(p_unmask * masked_now.sum().item())
            if k > 0:
                # commit the k still-masked positions whose distribution is most peaked
                confidence = probs.max(-1).values.masked_fill(~masked_now, -1)  # [1, L]
                pick = torch.topk(confidence, k=k, dim=-1).indices[0]
                x[0, pick] = sampled[0, pick]
            # already-committed positions are never rewritten (carry-over)

        return decode_core(x[0].cpu())
    finally:
        model.train(was_training)


In [None]:
# load data
from datasets import load_dataset
import pandas as pd

# Login using e.g. `huggingface-cli login` to access this dataset
hf_ds = load_dataset("dahvid12/uniprot50-sequences-subsample100")
sequences = pd.Series(hf_ds["train"]["sequence"])
display(sequences.head())

# seed before model init so weights, masking and sampling are reproducible
SEED = 0
random.seed(SEED)        # python RNG (t draws in train_mdlm)
torch.manual_seed(SEED)  # torch RNG (init, masking, sampling)

vocab_size = tokenizer.vocab_size
model = MDLMTransformer(vocab_size=vocab_size, d_model=512, n_layers=12, n_heads=8)

# Example tiny loop (pseudo)
from torch.utils.data import DataLoader, TensorDataset

seqs = sequences.tolist()  # your corpus
core = batch_encode_core(seqs)            # [N, L] padded core ids
assert core.size(1) <= model.max_len, "padded length exceeds the model's position table"
train_ds = TensorDataset(core)            # 1-tensor dataset: items are 1-tuples
dl = DataLoader(train_ds, batch_size=4, shuffle=True)

In [None]:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# A generator such as `(b[0] for b in dl)` is exhausted after one pass, so
# epochs > 1 would train only once. A DataLoader over the [N, L] tensor yields
# [B, L] batches directly and reshuffles on every epoch.
train_loader = DataLoader(core, batch_size=dl.batch_size, shuffle=True)
train_mdlm(model, train_loader, epochs=10, lr=3e-4, device=DEVICE)
sample = ancestral_sample_mdlm(model, seq_len=256, T=64, device=DEVICE)
print(sample)
