In [11]:
# Install required packages if not already installed
# !pip install torch transformers

import torch
import torch.nn as nn
from transformers import AutoTokenizer

# Parameters
BATCH_SIZE = 2                            # number of mock image encodings created below
ENCODING_DIM = 384                        # length of each mock image-encoding vector
REPORT_MAX_LEN = 20                       # sequence length used for reports in the cells below
VOCAB_MODEL = "distilbert-base-uncased"   # checkpoint name handed to AutoTokenizer

# Mock image encoding data: random values of shape (BATCH_SIZE, ENCODING_DIM)
# standing in for real image features.
image_encodings = torch.randn(BATCH_SIZE, ENCODING_DIM)

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(VOCAB_MODEL)
# Vocabulary size reported by the tokenizer; a later cell recomputes it with
# len(tokenizer) after registering extra special tokens.
vocab_size = tokenizer.vocab_size

# BiLSTM Decoder Model
class BiLSTMReportGenerator(nn.Module):
    """Non-autoregressive report head: a BiLSTM run over a tiled image encoding.

    The same image vector is presented at each of ``max_len`` time steps and
    every BiLSTM output is projected to vocabulary logits.

    Args:
        encoding_dim: size of the image encoding fed to the LSTM.
        hidden_dim: LSTM hidden size per direction.
        vocab_size: number of output classes per time step.
        max_len: number of time steps the encoding is repeated for.
    """

    def __init__(self, encoding_dim, hidden_dim, vocab_size, max_len):
        super().__init__()
        self.max_len = max_len
        self.hidden_dim = hidden_dim
        # NOTE: `embedding` is created but never used by forward().
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_dim)
        self.bilstm = nn.LSTM(
            input_size=encoding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence 2 * hidden_dim input features.
        self.fc = nn.Linear(in_features=hidden_dim * 2, out_features=vocab_size)

    def forward(self, image_encoding):
        """Return logits of shape (batch, max_len, vocab_size) for a (batch, encoding_dim) input."""
        # Tile the single vector along a new time axis: (batch, max_len, encoding_dim).
        tiled = image_encoding.unsqueeze(1).repeat(1, self.max_len, 1)
        hidden_states, _final_state = self.bilstm(tiled)
        return self.fc(hidden_states)

# Instantiate model
hidden_dim = 256
model = BiLSTMReportGenerator(ENCODING_DIM, hidden_dim, vocab_size, REPORT_MAX_LEN)

# Generate mock output
# No training step happens in this cell, so the weights are still at their
# random initialisation and the decoded text is only a smoke test of the pipeline.
with torch.no_grad():
    logits = model(image_encodings)
    # Get predicted token IDs (arg-max over the last, vocabulary, axis)
    predicted_ids = torch.argmax(logits, dim=-1)
    # Decode to text, one report per row of predicted ids
    reports = []
    for ids in predicted_ids:
        tokens = tokenizer.convert_ids_to_tokens(ids.tolist())
        report = tokenizer.convert_tokens_to_string(tokens)
        reports.append(report)

# Print generated reports (numbered from 1)
for i, report in enumerate(reports):
    print(f"Report {i+1}: {report}")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Report 1: powerful powerful powerful powerfulaithaithaithaithaithaithaithaithaithaithaithaithaithaithaithaith
Report 2: cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback cornerback 141 141


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# --- ensure special tokens ---
# Register BOS/EOS/PAD only for the roles the tokenizer does not already define.
# `tokenizer` is the object created in an earlier cell.
specials = {}
if tokenizer.bos_token is None: specials["bos_token"] = "<bos>"
if tokenizer.eos_token is None: specials["eos_token"] = "<eos>"
if tokenizer.pad_token is None: specials["pad_token"] = "<pad>"
if specials:
    tokenizer.add_special_tokens(specials)
vocab_size = len(tokenizer)  # update size with new specials

# Token ids read by the decoder class, the target builder and the loss in this cell.
PAD_ID = tokenizer.pad_token_id
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id

# --- Autoregressive image-conditioned LSTM decoder ---
class LSTMReportGenerator(nn.Module):
    """Autoregressive LSTM decoder whose initial state is derived from an image encoding.

    Args:
        encoding_dim: size D of the image encoding.
        hidden_dim: LSTM hidden/cell size.
        vocab_size: number of token ids (embedding rows and output classes).
        emb_dim: token embedding size; defaults to ``hidden_dim`` when falsy.
        dropout: dropout probability applied in ``forward`` (not in ``generate``).
        pad_id, bos_id, eos_id: special token ids. Each defaults to the
            notebook-level ``PAD_ID`` / ``BOS_ID`` / ``EOS_ID`` when left as None,
            which keeps existing three-argument constructions working.
    """

    def __init__(self, encoding_dim, hidden_dim, vocab_size, emb_dim=None, dropout=0.1,
                 pad_id=None, bos_id=None, eos_id=None):
        super().__init__()
        emb_dim = emb_dim or hidden_dim
        self.hidden_dim = hidden_dim

        # Resolve special ids once; the globals are only touched when no
        # explicit value was supplied.
        self.pad_id = PAD_ID if pad_id is None else pad_id
        self.bos_id = BOS_ID if bos_id is None else bos_id
        self.eos_id = EOS_ID if eos_id is None else eos_id

        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=self.pad_id)

        # init hidden & cell from image encoding
        self.init_h = nn.Linear(encoding_dim, hidden_dim)
        self.init_c = nn.Linear(encoding_dim, hidden_dim)

        self.lstm = nn.LSTMCell(emb_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)
        self.drop = nn.Dropout(dropout)

    def _initial_state(self, img_enc):
        """Map the image encoding (B, D) to the starting (h, c) pair, each (B, hidden_dim)."""
        return torch.tanh(self.init_h(img_enc)), torch.tanh(self.init_c(img_enc))

    def forward(self, img_enc, tgt_ids, teacher_forcing=0.5):
        """
        img_enc: (B, D)
        tgt_ids: (B, T)  where tgt[:,0] == BOS and sequence ends with EOS (+ PAD)
        teacher_forcing: probability, drawn once per step for the whole batch,
            of feeding the ground-truth token instead of the model's arg-max.
        returns logits for next-token prediction: (B, T-1, V)
        """
        B, T = tgt_ids.size()
        h, c = self._initial_state(img_enc)

        logits = []
        y_t = tgt_ids[:, 0]  # BOS
        for t in range(1, T):
            emb = self.embed(y_t)
            h, c = self.lstm(self.drop(emb), (h, c))
            step_logits = self.out(self.drop(h))        # (B, V)
            logits.append(step_logits.unsqueeze(1))
            # scheduled sampling
            if random.random() < teacher_forcing:
                y_t = tgt_ids[:, t]
            else:
                y_t = step_logits.argmax(dim=-1)
        if not logits:
            # T == 1: nothing to predict; torch.cat([]) would raise.
            return h.new_zeros((B, 0, self.embed.num_embeddings))
        return torch.cat(logits, dim=1)

    @torch.no_grad()
    def generate(self, img_enc, max_len=50):
        """
        Greedy decoding from BOS until every sequence has produced EOS or
        max_len steps have run.

        A sequence that has already emitted EOS is filled with PAD on later
        steps, so it cannot keep producing tokens while other rows are still
        decoding.

        img_enc: (B, D)
        returns token ids: (B, L<=max_len+1), each row BOS ... EOS [PAD ...]
        """
        B = img_enc.size(0)
        h, c = self._initial_state(img_enc)
        y_t = torch.full((B,), self.bos_id, dtype=torch.long, device=img_enc.device)
        finished = torch.zeros(B, dtype=torch.bool, device=img_enc.device)

        seq = [y_t.unsqueeze(1)]
        for _ in range(max_len):
            emb = self.embed(y_t)
            h, c = self.lstm(emb, (h, c))
            y_t = self.out(h).argmax(dim=-1)
            # Rows that ended on an earlier step only receive padding.
            y_t = y_t.masked_fill(finished, self.pad_id)
            seq.append(y_t.unsqueeze(1))
            finished = finished | (y_t == self.eos_id)
            if finished.all():
                break
        return torch.cat(seq, dim=1)

# --- instantiate ---
# Rebinds `hidden_dim` and `model`; `ENCODING_DIM` comes from the first cell and
# `vocab_size` is the value recomputed above after adding special tokens.
hidden_dim = 256
model = LSTMReportGenerator(ENCODING_DIM, hidden_dim, vocab_size)

# --- demo: build dummy target batch for training (BOS ... EOS padded) ---
def make_dummy_targets(texts, max_len):
    """Tokenise `texts` into a tensor of rows shaped BOS + tokens + EOS + PAD*.

    Each text is encoded without special tokens and cut to ``max_len - 2`` ids
    so the BOS/EOS frame fits; rows are right-padded with PAD_ID to ``max_len``.
    Uses the notebook-level `tokenizer`, BOS_ID, EOS_ID and PAD_ID.
    """
    rows = []
    for text in texts:
        body = tokenizer.encode(text, add_special_tokens=False)[:max_len-2]
        row = [BOS_ID, *body, EOS_ID]
        row.extend([PAD_ID] * (max_len - len(row)))
        rows.append(row)
    return torch.tensor(rows)

# Two short example reports used as training targets for the smoke test.
dummy_texts = ["no acute cardiopulmonary process.",
               "heart size normal. lungs clear."]
tgt_ids = make_dummy_targets(dummy_texts, REPORT_MAX_LEN)

# --- forward (training) ---
# Only a forward pass and a loss are computed here; there is no backward pass
# or optimizer step in this cell, so the weights stay at their initial values.
logits = model(image_encodings, tgt_ids, teacher_forcing=0.7)
# cross-entropy ignoring PAD
# logits cover positions 1..T-1, so they are scored against tgt_ids[:, 1:].
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt_ids[:, 1:].reshape(-1),
    ignore_index=PAD_ID,
    label_smoothing=0.1
)
print("Loss:", float(loss))

# --- generate (inference) ---
with torch.no_grad():
    pred_ids = model.generate(image_encodings, max_len=REPORT_MAX_LEN)
    for i, ids in enumerate(pred_ids.tolist()):
        # skip_special_tokens=True is passed so special tokens are left out of the text
        print(f"Report {i+1}:", tokenizer.decode(ids, skip_special_tokens=True))


Loss: 10.336823463439941
Report 1: situ diving kurt 36thب gonzales haste laundry rd manufacturingroll rivalry issuing moroccan secrecy likely illustrator pere cobb contractors
Report 2: teased delgado security asks ag poorly realhus minnesotawani liberation nebraska speech slices festival slices martyutive surgeon cooking


In [None]:
# pip install torch transformers

import torch, torch.nn as nn, torch.nn.functional as F
from transformers import AutoTokenizer
import random

# ---------------- tokenizer & specials ----------------
# Reload the tokenizer so this cell does not depend on earlier cells, then
# register whichever of BOS/EOS/PAD it does not already define.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
specials = {}
if tokenizer.bos_token is None: specials["bos_token"] = "<bos>"
if tokenizer.eos_token is None: specials["eos_token"] = "<eos>"
if tokenizer.pad_token is None: specials["pad_token"] = "<pad>"
if specials: tokenizer.add_special_tokens(specials)

# Special-token ids and vocabulary size used by the model classes below.
PAD_ID = tokenizer.pad_token_id
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id
VOCAB_SIZE = len(tokenizer)  # len() is taken after the specials were added

# ---------------- LSTM encoder-decoder ----------------
class BiLSTMImageEncoder(nn.Module):
    """Bidirectional LSTM over a sequence of image tokens.

    Optionally projects a patient-feature vector to the token width and
    prepends it as one extra token.

    Args:
        in_dim: feature size of each image token.
        hidden: LSTM hidden size per direction (output width is 2 * hidden).
        num_layers: number of stacked LSTM layers.
        dropout: input dropout probability; also used between LSTM layers
            when ``num_layers > 1``.
        use_patient / patient_dim: the patient token is enabled only when
            ``use_patient`` is truthy and ``patient_dim > 0``.
    """

    def __init__(self, in_dim, hidden, num_layers=1, dropout=0.1,
                 use_patient=False, patient_dim=0):
        super().__init__()
        self.use_patient = use_patient and patient_dim > 0
        if self.use_patient:
            self.patient_proj = nn.Linear(patient_dim, in_dim)
        else:
            self.patient_proj = None
        # nn.LSTM only applies dropout between layers, so pass 0 for a single layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.bilstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hidden,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, img_tokens, patient_feats=None):
        """Encode (B, L, D) tokens into (B, L', 2H); L' = L + 1 with a patient token."""
        sequence = img_tokens
        if self.use_patient:
            assert patient_feats is not None
            patient_token = self.patient_proj(patient_feats).unsqueeze(1)   # (B, 1, D)
            sequence = torch.cat([patient_token, sequence], dim=1)          # (B, L+1, D)
        encoded, _state = self.bilstm(self.dropout(sequence))
        return encoded

class AdditiveAttention(nn.Module):
    """Additive attention of a decoder state over encoder outputs.

    Scores are ``v^T tanh(W_h enc + W_s s)`` and are soft-maxed over the
    encoder length axis.
    """

    def __init__(self, dec_hidden, enc_dim, attn_dim=256):
        super().__init__()
        # All three projections are bias-free.
        self.W_h = nn.Linear(enc_dim, attn_dim, False)
        self.W_s = nn.Linear(dec_hidden, attn_dim, False)
        self.v = nn.Linear(attn_dim, 1, False)

    def forward(self, enc_out, s_t):
        """enc_out: (B, L, E), s_t: (B, H) -> (context (B, E), weights (B, L))."""
        keys = self.W_h(enc_out)                         # (B, L, A)
        query = self.W_s(s_t).unsqueeze(1)               # (B, 1, A), broadcast over L
        scores = self.v(torch.tanh(keys + query))        # (B, L, 1)
        weights = torch.softmax(scores, dim=1)           # normalised over L
        context = (weights * enc_out).sum(1)             # (B, E)
        return context, weights.squeeze(-1)

class AttnLSTMDecoder(nn.Module):
    """LSTM-cell decoder that attends over encoder outputs at every step.

    Args:
        vocab_size: number of token ids.
        emb_dim: token embedding size.
        dec_hidden: LSTM hidden/cell size.
        enc_dim: feature size of the encoder outputs.
        dropout: dropout probability used in ``forward`` only.
        pad_idx: padding id for the embedding; also used by ``generate`` to
            fill positions after a sequence has emitted EOS.
    """

    def __init__(self, vocab_size, emb_dim, dec_hidden, enc_dim, dropout=0.1, pad_idx=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTMCell(emb_dim + enc_dim, dec_hidden)
        self.attn = AdditiveAttention(dec_hidden, enc_dim)
        self.out  = nn.Linear(dec_hidden + enc_dim, vocab_size)
        self.drop = nn.Dropout(dropout)

    def _zero_state(self, enc_out):
        """Zero (h, c), each (B, dec_hidden), matching enc_out's dtype and device."""
        shape = (enc_out.size(0), self.lstm.hidden_size)
        return enc_out.new_zeros(shape), enc_out.new_zeros(shape)

    def forward(self, enc_out, tgt, bos_id, eos_id, teacher_forcing=0.5):
        """Return next-token logits (B, T-1, V) for targets ``tgt`` of shape (B, T).

        ``tgt[:, 0]`` is fed as the first input. ``bos_id`` and ``eos_id`` are
        accepted for signature compatibility but are not read here.
        ``teacher_forcing`` is the per-step probability (one draw for the whole
        batch) of feeding the ground-truth token instead of the arg-max.
        """
        B, T = tgt.size()
        h, c = self._zero_state(enc_out)
        y = tgt[:, 0]                                              # BOS
        logits = []
        for t in range(1, T):
            emb = self.emb(y)
            ctx, _ = self.attn(enc_out, h)
            h, c = self.lstm(self.drop(torch.cat([emb, ctx], -1)), (h, c))
            step = self.out(self.drop(torch.cat([h, ctx], -1)))
            logits.append(step.unsqueeze(1))
            y = tgt[:, t] if random.random() < teacher_forcing else step.argmax(-1)
        if not logits:
            # T == 1: nothing to predict; torch.cat([]) would raise.
            return enc_out.new_zeros((B, 0, self.emb.num_embeddings))
        return torch.cat(logits, 1)

    @torch.no_grad()
    def generate(self, enc_out, bos_id, eos_id, max_len=60):
        """Greedy decode; returns ids (B, L<=max_len+1), each row starting with ``bos_id``.

        Decoding stops once every row has produced ``eos_id``. Rows that
        finished on an earlier step are filled with the embedding's padding
        index (or ``eos_id`` when none is set) instead of further predictions.
        """
        B = enc_out.size(0)
        h, c = self._zero_state(enc_out)
        fill_id = self.emb.padding_idx if self.emb.padding_idx is not None else eos_id
        y = torch.full((B,), bos_id, dtype=torch.long, device=enc_out.device)
        finished = torch.zeros(B, dtype=torch.bool, device=enc_out.device)
        seq = [y.unsqueeze(1)]
        for _ in range(max_len):
            emb = self.emb(y)
            ctx, _ = self.attn(enc_out, h)
            h, c = self.lstm(torch.cat([emb, ctx], -1), (h, c))
            y = self.out(torch.cat([h, ctx], -1)).argmax(-1)
            y = y.masked_fill(finished, fill_id)
            seq.append(y.unsqueeze(1))
            finished = finished | (y == eos_id)
            if finished.all():
                break
        return torch.cat(seq, 1)

class ImageToReport(nn.Module):
    """BiLSTM image-token encoder followed by an attention LSTM decoder.

    ``enc_hidden`` is the encoder's total output width: each LSTM direction
    gets ``enc_hidden // 2`` units and the decoder attends over ``enc_hidden``
    features. BOS/EOS ids are read from the notebook-level BOS_ID / EOS_ID.
    """

    def __init__(self, img_feat_dim=96, enc_hidden=256, emb_dim=256, dec_hidden=512,
                 vocab_size=VOCAB_SIZE, pad_idx=PAD_ID, use_patient=False, patient_dim=0):
        super().__init__()
        per_direction = enc_hidden//2
        self.encoder = BiLSTMImageEncoder(
            img_feat_dim,
            per_direction,
            num_layers=1,
            dropout=0.1,
            use_patient=use_patient,
            patient_dim=patient_dim,
        )
        self.decoder = AttnLSTMDecoder(
            vocab_size,
            emb_dim,
            dec_hidden,
            enc_dim=enc_hidden,
            dropout=0.1,
            pad_idx=pad_idx,
        )

    def forward(self, img_tokens, tgt_ids, patient_feats=None, tf=0.5):
        """Encode the image tokens and return decoder logits; ``tf`` is the teacher-forcing ratio."""
        memory = self.encoder(img_tokens, patient_feats)           # (B, L', enc_hidden)
        return self.decoder(memory, tgt_ids, BOS_ID, EOS_ID, teacher_forcing=tf)

    @torch.no_grad()
    def generate(self, img_tokens, patient_feats=None, max_len=60):
        """Encode the image tokens and greedily decode up to ``max_len`` steps."""
        memory = self.encoder(img_tokens, patient_feats)
        return self.decoder.generate(memory, BOS_ID, EOS_ID, max_len=max_len)

# ---------------- demo with your shape (B, 96, 96) ----------------
BATCH = 2
# Random stand-in for image tokens of shape (B, L=96, D=96)
IMG_TOKENS = torch.randn(BATCH, 96, 96)      # (B, L=96, D=96) from DINO
# Freshly initialised model; nothing in this cell updates its weights.
model = ImageToReport(img_feat_dim=96, enc_hidden=256, emb_dim=256, dec_hidden=512)

def pack_targets(texts, max_len=40):
    """Encode `texts` as fixed-width id rows: BOS, tokens, EOS, then PAD up to ``max_len``.

    Token ids come from the notebook-level `tokenizer` (no special tokens
    added) and are cut to ``max_len - 2`` so BOS and EOS always fit.
    """
    def _row(text):
        body = tokenizer.encode(text, add_special_tokens=False)
        framed = [BOS_ID] + body[:max_len-2] + [EOS_ID]
        return framed + [PAD_ID] * (max_len - len(framed))

    return torch.tensor([_row(text) for text in texts])

# Two example reports packed into a (2, 30) id tensor.
targets = pack_targets(
    ["no acute cardiopulmonary process.",
     "heart size normal. lungs clear."],
    max_len=30
)

# training step
# Only the forward pass and loss are computed; no backward pass or optimizer
# step follows, so the model stays untrained.
logits = model(IMG_TOKENS, targets, tf=0.7)                      # (B, T-1, V)
# Logits for steps 1..T-1 are compared with targets shifted by one; PAD is ignored.
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    targets[:, 1:].reshape(-1), ignore_index=PAD_ID, label_smoothing=0.1
)
print("loss:", float(loss))

# inference
with torch.no_grad():
    pred_ids = model.generate(IMG_TOKENS, max_len=10)
    for i, seq in enumerate(pred_ids.tolist()):
        print(f"Report {i+1}:", tokenizer.decode(seq, skip_special_tokens=True))



loss: 10.329631805419922
Report 1: adjustments bahamas prices fishingrod refurbishedᵗ sympathyigen こ fide rihanna chevalieriavent overgrown icelandic icelandic officelating illustration narrativeyne mammals romeo installed wreckage impressiverop ま
Report 2: adjustments castroyan dissent institutes furiously determine ע huntsvilledomsokan mythological message mc nurses dumont accountability mt [unused470] mass disaster 269 pistonzilyverevere swell bei新


In [16]:
# Repeat greedy decoding; `model`, `IMG_TOKENS` and `tokenizer` must already
# exist in the kernel from the previous cell.
with torch.no_grad():
    pred_ids = model.generate(IMG_TOKENS, max_len=10)
    for i, seq in enumerate(pred_ids.tolist()):
        print(f"Report {i+1}:", tokenizer.decode(seq, skip_special_tokens=True))

Report 1: adjustments regardless bazaar variouslyvishiser jp vegasulin robotic
Report 2: adjustments castroyan dissent institutes furiouslylastic cox furiously с


In [22]:
# If you don't have transformers installed:
# !pip install -q transformers

import math
import random
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer

torch.__version__
def get_tokenizer():
    """Load the ``bert-base-uncased`` tokenizer and assert that it defines a pad token id."""
    bert_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    has_pad = bert_tok.pad_token_id is not None
    assert has_pad, "BERT tokenizer must have a pad_token_id."
    return bert_tok

# Rebinds the notebook-level `tokenizer` to the BERT tokenizer used by the rest
# of this cell, then prints its vocabulary size and the pad/[CLS]/[SEP] ids.
tokenizer = get_tokenizer()
print("Vocab size:", tokenizer.vocab_size)
print("Specials -> pad:", tokenizer.pad_token_id, "cls:", tokenizer.cls_token_id, "sep:", tokenizer.sep_token_id)
class EncodedCaptionDataset(Dataset):
    """
    Pairs precomputed image encodings with their text captions.

    mode:
      - 'cls'   : x is tensor of shape (384,)
      - 'patch' : x is tensor of shape (T=96, D=96)

    `mode` is validated and stored; item lookup itself does not depend on it.
    """
    _VALID_MODES = ("cls", "patch")

    def __init__(self, encodings: List[torch.Tensor], captions: List[str], tokenizer, mode: str):
        assert mode in self._VALID_MODES, "mode must be 'cls' or 'patch'"
        assert len(encodings) == len(captions), "encodings and captions length mismatch"
        self.mode = mode
        self.tok = tokenizer
        self.encodings = encodings
        self.captions = captions

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return (encoding, caption token ids as a 1-D long tensor)."""
        # Special tokens are deliberately left out here; [CLS]/[SEP] are added in collate.
        token_ids = self.tok.encode(self.captions[idx], add_special_tokens=False)
        return self.encodings[idx], torch.tensor(token_ids, dtype=torch.long)

def collate_fn_bert(batch, tokenizer):
    """
    Collate (encoding, caption_ids) pairs into padded batch tensors.

    Returns:
      encs: stacked encodings
         - if CLS mode: (B, 384)
         - if PATCH mode: (B, T, D)
      caps_in: (B, T_in)   with [CLS] prefix
      caps_out: (B, T_out) with [SEP] suffix
      lengths_out: (B,) integer lengths for targets (non-padded)
      mode: 'cls' or 'patch' inferred from the first sample
    """
    enc_items, token_seqs = zip(*batch)
    # A 1-D first encoding means CLS vectors; anything else is treated as patch tokens.
    mode = "patch" if enc_items[0].dim() != 1 else "cls"
    encs = torch.stack(enc_items, dim=0)

    pad_id = tokenizer.pad_token_id
    sos = torch.tensor([tokenizer.cls_token_id])   # [CLS] doubles as start-of-sequence
    eos = torch.tensor([tokenizer.sep_token_id])   # [SEP] doubles as end-of-sequence

    shifted_in = [torch.cat([sos, seq], dim=0) for seq in token_seqs]    # [CLS] + tokens
    shifted_out = [torch.cat([seq, eos], dim=0) for seq in token_seqs]   # tokens + [SEP]
    lengths_out = torch.tensor([len(seq) for seq in shifted_out], dtype=torch.long)

    caps_in = pad_sequence(shifted_in, batch_first=True, padding_value=pad_id)
    caps_out = pad_sequence(shifted_out, batch_first=True, padding_value=pad_id)

    return encs, caps_in, caps_out, lengths_out, mode
class AttnPool1D(nn.Module):
    """
    Attention pooling that collapses a token sequence into one vector.
    Input: (B, T, Din) -> Output: (B, Dout)
    """
    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.proj = nn.Linear(d_in, d_hidden)          # token -> scoring space
        self.u = nn.Parameter(torch.randn(d_hidden))   # learned scoring vector
        self.out = nn.Linear(d_in, d_out)              # applied to the pooled vector

    def forward(self, x):  # (B, T, Din)
        """Score each token, soft-max over T, and project the weighted sum."""
        energy = torch.matmul(torch.tanh(self.proj(x)), self.u)   # (B, T)
        weights = torch.softmax(energy, dim=1).unsqueeze(-1)      # (B, T, 1)
        pooled = torch.sum(x * weights, dim=1)                    # (B, Din)
        return self.out(pooled)                                   # (B, Dout)

class DINOv3Adapter(nn.Module):
    """
    Maps either kind of image encoding to one embedding per sample.

    Accepts:
      - CLS vector (B, 384)
      - Patch tokens (B, T=96, D=96)
    Produces a single image embedding (B, E).
    """
    def __init__(self, embed_size: int, cls_dim: int = 384, patch_dim: int = 96):
        super().__init__()
        self.proj_cls = nn.Linear(cls_dim, embed_size)
        self.proj_patch = nn.Linear(patch_dim, embed_size)
        # Patch tokens are projected to embed_size first, then pooled over T.
        self.pool = AttnPool1D(d_in=embed_size, d_hidden=embed_size, d_out=embed_size)

    def forward(self, x):
        """Dispatch on tensor rank: 2-D -> CLS projection, 3-D -> project then pool."""
        rank = x.dim()
        if rank == 2:             # (B, 384)
            return self.proj_cls(x)
        if rank == 3:             # (B, T, D)
            return self.pool(self.proj_patch(x))   # (B, T, E) -> (B, E)
        raise ValueError(f"Unexpected DINOv3 encoding shape: {tuple(x.shape)}")
class LSTMDecoder(nn.Module):
    """LSTM caption decoder conditioned on a single image vector.

    The image vector initialises the hidden and cell state of every LSTM layer
    and is also added, through ``ctx_bias``, to each input token embedding.

    Args:
        vocab_size: number of token ids.
        embed_size: size of token embeddings and of the image vector.
        hidden_size: LSTM hidden size.
        pad_id: padding id for the embedding table.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers (only used when num_layers > 1).
    """

    def __init__(self, vocab_size: int, embed_size: int, hidden_size: int,
                 pad_id: int, num_layers: int = 1, dropout: float = 0.1):
        super().__init__()
        self.word_embed = nn.Embedding(vocab_size, embed_size, padding_idx=pad_id)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True, dropout=(dropout if num_layers > 1 else 0.0))
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.ctx_bias = nn.Linear(embed_size, embed_size)  # per-step bias from image vec
        self.fc = nn.Linear(hidden_size, vocab_size)

    def _init_state(self, img_vec):
        """Build (h0, c0), each (num_layers, B, H), from the image vector (B, E)."""
        h0 = torch.tanh(self.init_h(img_vec)).unsqueeze(0)  # (1, B, H)
        c0 = torch.tanh(self.init_c(img_vec)).unsqueeze(0)  # (1, B, H)
        layers = self.lstm.num_layers
        if layers > 1:
            # nn.LSTM requires one state per layer; share the image-derived
            # state across all of them.
            h0 = h0.repeat(layers, 1, 1)
            c0 = c0.repeat(layers, 1, 1)
        return h0, c0

    def forward(self, img_vec, captions_in, lengths_out):
        """
        img_vec: (B, E)
        captions_in: (B, T_in) with [CLS] at t=0
        lengths_out: (B,) lengths for targets (tokens + [SEP])
        returns: logits over the packed (non-padded) positions, (sumT, V)
        """
        h0, c0 = self._init_state(img_vec)

        emb = self.word_embed(captions_in)                  # (B, T, E)
        bias = self.ctx_bias(img_vec).unsqueeze(1)          # (B, 1, E)
        emb = emb + bias

        packed = pack_padded_sequence(emb, lengths_out.tolist(),
                                      batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed, (h0, c0))
        logits = self.fc(packed_out.data)                   # (sumT, V)
        return logits

    @torch.no_grad()
    def generate(self, img_vec, max_len: int, sos_id: int, eos_id: int) -> List[List[int]]:
        """
        Greedy decode (no beam). Returns token id lists without the SOS token,
        each cut just before its first EOS. Decoding stops as soon as every
        sequence has produced EOS, or after ``max_len`` steps.
        """
        device = img_vec.device
        B = img_vec.size(0)
        h, c = self._init_state(img_vec)
        bias = self.ctx_bias(img_vec).unsqueeze(1)          # (B, 1, E)

        inputs = torch.full((B, 1), sos_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        steps = []
        for _ in range(max_len):
            emb = self.word_embed(inputs) + bias            # (B, 1, E)
            y, (h, c) = self.lstm(emb, (h, c))              # (B, 1, H)
            logits = self.fc(y.squeeze(1))                  # (B, V)
            nxt = torch.argmax(logits, dim=-1)              # (B,)
            steps.append(nxt)
            inputs = nxt.unsqueeze(1)
            finished = finished | (nxt == eos_id)
            if finished.all():
                # Everything after the first EOS is trimmed below, so later
                # steps could not change the result.
                break

        if not steps:
            return [[] for _ in range(B)]

        # One device-to-host transfer instead of one .item() call per token.
        out_ids = torch.stack(steps, dim=1).tolist()

        # Trim at EOS if present
        return [seq[: seq.index(eos_id)] if eos_id in seq else seq for seq in out_ids]

class DINOv3Captioner(nn.Module):
    """Image-encoding-to-caption model: a DINOv3Adapter feeding an LSTMDecoder.

    Args:
        tokenizer: tokenizer providing ``pad_token_id``, ``cls_token_id``
            (used as start token), ``sep_token_id`` (used as end token),
            ``decode`` and ``__len__``.
        embed_size: size of the image embedding and token embeddings.
        hidden_size: decoder LSTM hidden size.
        num_layers: number of decoder LSTM layers.
    """

    def __init__(self, tokenizer, embed_size=256, hidden_size=512, num_layers=1):
        super().__init__()
        self.tok = tokenizer
        self.adapter = DINOv3Adapter(embed_size=embed_size, cls_dim=384, patch_dim=96)
        self.decoder = LSTMDecoder(
            # len(tokenizer) also counts tokens registered after loading, so
            # added special tokens get embedding rows and output classes.
            # For an unmodified tokenizer it equals tokenizer.vocab_size.
            vocab_size=len(self.tok),
            embed_size=embed_size,
            hidden_size=hidden_size,
            pad_id=self.tok.pad_token_id,
            num_layers=num_layers
        )

    def forward(self, dino_enc, caps_in, lengths_out):
        """Return decoder logits for the packed target positions, (sumT, V)."""
        img_vec = self.adapter(dino_enc)  # (B, E)
        return self.decoder(img_vec, caps_in, lengths_out)

    @torch.no_grad()
    def caption(self, dino_enc, max_len: int = 30):
        """Greedily decode one caption string per input encoding."""
        img_vec = self.adapter(dino_enc)
        ids = self.decoder.generate(img_vec, max_len,
                                    sos_id=self.tok.cls_token_id,
                                    eos_id=self.tok.sep_token_id)
        texts = [self.tok.decode(seq, skip_special_tokens=True,
                                 clean_up_tokenization_spaces=True) for seq in ids]
        return texts
def train_one_epoch(model, loader, optimizer, criterion, device="cuda"):
    """Run one optimisation pass over `loader` and return the mean loss per target token.

    Each batch is a 5-tuple (encs, caps_in, caps_out, lengths_out, mode).
    Gradients are clipped to a total norm of 1.0 before every optimizer step.
    The per-batch loss is weighted by the number of packed (non-padded) target
    tokens; an empty loader yields 0.0.
    """
    model.train()
    loss_sum = 0.0
    token_count = 0
    for batch in loader:
        encs, caps_in, caps_out, lengths_out, _mode = batch
        # lengths_out is only used through .tolist(), so it is not moved.
        encs, caps_in, caps_out = (t.to(device) for t in (encs, caps_in, caps_out))

        optimizer.zero_grad(set_to_none=True)
        logits = model(encs, caps_in, lengths_out)  # (sumT, V)

        # Pack the targets the same way the model packs its inputs so rows line up.
        targets = pack_padded_sequence(
            caps_out, lengths_out.tolist(), batch_first=True, enforce_sorted=False
        ).data
        loss = criterion(logits, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        n_tokens = targets.numel()
        loss_sum += loss.item() * n_tokens
        token_count += n_tokens
    return loss_sum / max(1, token_count)

@torch.no_grad()
def evaluate(model, loader, criterion, device="cuda"):
    """Return the mean loss per target token over `loader` without updating the model.

    Puts the model in eval mode. Batches are 5-tuples
    (encs, caps_in, caps_out, lengths_out, mode); each batch loss is weighted
    by its number of packed (non-padded) target tokens. An empty loader yields 0.0.
    """
    model.eval()
    loss_sum = 0.0
    token_count = 0
    for batch in loader:
        encs, caps_in, caps_out, lengths_out, _mode = batch
        encs, caps_in, caps_out = (t.to(device) for t in (encs, caps_in, caps_out))

        logits = model(encs, caps_in, lengths_out)
        targets = pack_padded_sequence(
            caps_out, lengths_out.tolist(), batch_first=True, enforce_sorted=False
        ).data
        batch_loss = criterion(logits, targets)

        n_tokens = targets.numel()
        loss_sum += batch_loss.item() * n_tokens
        token_count += n_tokens
    return loss_sum / max(1, token_count)
def make_mock_data_cls(n=16, seed=1337):
    """Create `n` random CLS-style vectors of shape (384,) with cycled example captions.

    Seeds both `random` and torch with `seed`, so repeated calls return the
    same data. Returns (list of tensors, list of caption strings).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    templates = (
        "a dog running on grass",
        "a cat sitting on a couch",
        "a person riding a bike",
        "a group of people on the beach",
        "a car parked on the street",
        "a child playing with a ball",
        "a man cooking in a kitchen",
        "a woman reading a book",
    )
    # Cycle through the templates until n captions have been produced.
    captions = [templates[i % len(templates)] for i in range(n)]
    encs = []
    for _ in range(n):
        encs.append(torch.randn(384))   # CLS vectors
    return encs, captions

def make_mock_data_patch(n=16, seed=2024):
    """Create `n` random patch-token grids of shape (96, 96) with cycled example captions.

    Seeds both `random` and torch with `seed`, so repeated calls return the
    same data. Returns (list of tensors, list of caption strings).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    templates = (
        "a mountain covered with snow",
        "a river flowing through a forest",
        "an airplane flying in the sky",
        "a delicious pizza on a plate",
        "a soccer player kicking a ball",
        "a dog catching a frisbee",
        "a train arriving at a station",
        "a city skyline at night",
    )
    # Cycle through the templates until n captions have been produced.
    captions = [templates[i % len(templates)] for i in range(n)]
    encs = []
    for _ in range(n):
        encs.append(torch.randn(96, 96))   # Patch tokens (T=96, D=96)
    return encs, captions
def build_loader(encs, caps, tokenizer, mode: str, batch_size=8, workers=0, shuffle=True):
    """Wrap encodings and captions in a DataLoader that collates with ``collate_fn_bert``.

    Args:
        encs: list of per-sample encoding tensors.
        caps: list of caption strings, same length as ``encs``.
        tokenizer: tokenizer used by the dataset and the collate function.
        mode: 'cls' or 'patch', forwarded to ``EncodedCaptionDataset``.
        batch_size: samples per batch.
        workers: number of DataLoader worker processes.
        shuffle: whether to reshuffle every epoch.

    Returns:
        A DataLoader yielding (encs, caps_in, caps_out, lengths_out, mode).
    """
    from functools import partial

    ds = EncodedCaptionDataset(encs, caps, tokenizer, mode=mode)
    # A partial of a named function can be pickled for worker processes,
    # which a lambda cannot (needed when workers > 0 under the spawn start method).
    collate = partial(collate_fn_bert, tokenizer=tokenizer)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      num_workers=workers, collate_fn=collate,
                      # Pinning only helps host-to-GPU copies; skip it on CPU-only machines.
                      pin_memory=torch.cuda.is_available())
# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Hyperparams (adjust for your real training)
embed_size = 256
hidden_size = 512
num_layers = 1
epochs = 2
batch_size = 8
lr = 2e-4

# Model, loss, optimizer
model = DINOv3Captioner(tokenizer, embed_size=embed_size,
                        hidden_size=hidden_size, num_layers=num_layers).to(device)
# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# ---- CLS mock run ----
# NOTE: the "validation" loader is built from the same mock samples as the
# training loader, so the reported val loss is not measured on held-out data.
encs_cls, caps_cls = make_mock_data_cls(n=24)
train_loader_cls = build_loader(encs_cls, caps_cls, tokenizer, mode="cls",
                                batch_size=batch_size, shuffle=True)
val_loader_cls   = build_loader(encs_cls, caps_cls, tokenizer, mode="cls",
                                batch_size=batch_size, shuffle=False)

print("\n=== Training (CLS encodings) ===")
# best_val tracks the lowest validation loss; it is not read again in this cell.
best_val = math.inf
for ep in range(epochs):
    tr = train_one_epoch(model, train_loader_cls, optimizer, criterion, device=device)
    va = evaluate(model, val_loader_cls, criterion, device=device)
    best_val = min(best_val, va)
    print(f"[CLS] epoch {ep+1}/{epochs} | train CE/token={tr:.4f} | val CE/token={va:.4f}")

# Greedy generation
# Caption the first four samples of the first validation batch.
batch = next(iter(val_loader_cls))
encs_b, caps_in_b, caps_out_b, lengths_b, mode_b = batch
pred_texts = model.caption(encs_b[:4].to(device), max_len=20)
for i, t in enumerate(pred_texts):
    print(f"[CLS GEN {i}]: {t}")
# Reuse same model/optimizer for brevity; in practice you may want separate runs/checkpoints.
# The patch run therefore continues from the weights left by the CLS run above.
encs_patch, caps_patch = make_mock_data_patch(n=24)
train_loader_patch = build_loader(encs_patch, caps_patch, tokenizer, mode="patch",
                                  batch_size=batch_size, shuffle=True)
val_loader_patch   = build_loader(encs_patch, caps_patch, tokenizer, mode="patch",
                                  batch_size=batch_size, shuffle=False)

print("\n=== Training (PATCH encodings) ===")
for ep in range(epochs):
    tr = train_one_epoch(model, train_loader_patch, optimizer, criterion, device=device)
    va = evaluate(model, val_loader_patch, criterion, device=device)
    print(f"[PATCH] epoch {ep+1}/{epochs} | train CE/token={tr:.4f} | val CE/token={va:.4f}")

# Greedy generation
# Caption the first four samples of the first patch validation batch.
batch = next(iter(val_loader_patch))
encs_b, caps_in_b, caps_out_b, lengths_b, mode_b = batch
pred_texts = model.caption(encs_b[:4].to(device), max_len=20)
for i, t in enumerate(pred_texts):
    print(f"[PATCH GEN {i}]: {t}")


Vocab size: 30522
Specials -> pad: 0 cls: 101 sep: 102
Using device: cuda

=== Training (CLS encodings) ===
[CLS] epoch 1/2 | train CE/token=10.3267 | val CE/token=10.2007
[CLS] epoch 2/2 | train CE/token=10.1533 | val CE/token=10.0339
[CLS GEN 0]: a ∧ cinema orrclaveclaveletteclaveclave praising lobstershedcek blown also [unused482] solo canadians technicians 03
[CLS GEN 1]: a a pedersen keyscts leaders territorial workings mp3 slacksᵈiary taskedchment conservatoireـ snakes spiked [unused444] ventral
[CLS GEN 2]: a a docked riding riding nakamura chefs adolfllis bruises commandoieri exhaustion humanities reasonably piracy aroma catholicmobilebuilding
[CLS GEN 3]: a group ħ operator laboratories laboratories collaborating opponents jed withdrawing 71 topicnkaeseildeん catchment harcourt octopussari

=== Training (PATCH encodings) ===
[PATCH] epoch 1/2 | train CE/token=10.2271 | val CE/token=10.1419
[PATCH] epoch 2/2 | train CE/token=10.0984 | val CE/token=9.9923
[PATCH GEN 0]: a a
[PATC