In [8]:
def refresh_repo():
    """Delete any existing checkout and re-clone the hotflip repository.

    Uses IPython magics / shell escapes, so this only runs inside an
    IPython/Jupyter kernel. Side effects: removes /kaggle/working/hotflip,
    clones it again, and leaves the kernel's working directory set to
    /kaggle/working/hotflip/.
    """
    # Start from the Kaggle working directory so the relative `hotflip` path
    # below refers to /kaggle/working/hotflip.
    %cd /kaggle/working
    # Drop the previous checkout (if any) so the clone starts clean.
    %rm -rf hotflip
    !git clone https://github.com/jefri021/hotflip.git
    %cd /kaggle/working/hotflip/
    # NOTE(review): this pull directly follows a fresh clone, so it is expected
    # to be a no-op (the captured output reports "Already up to date.").
    !git pull origin main

refresh_repo()

/kaggle/working
Cloning into 'hotflip'...
remote: Enumerating objects: 276, done.[K
remote: Counting objects: 100% (64/64), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 276 (delta 0), reused 63 (delta 0), pack-reused 212 (from 1)[K
Receiving objects: 100% (276/276), 21.43 MiB | 18.95 MiB/s, done.
Resolving deltas: 100% (94/94), done.
/kaggle/working/hotflip
From https://github.com/jefri021/hotflip
 * branch            main       -> FETCH_HEAD
Already up to date.


In [2]:
import torch
import json
import os
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

def load_model(model_filepath: str, torch_dtype:torch.dtype=torch.float16):
    """Load a causal LM and its tokenizer from a model directory.

    The directory is expected to contain ``reduced-config.json`` (with a
    ``use_lora`` entry), a ``tokenizer/`` folder, a ``fine-tuned-model/``
    folder and, when ``use_lora`` is true, a ``base-model/`` folder.

    Args:
        model_filepath: str - Path to where the model is stored
        torch_dtype: torch.dtype - dtype forwarded to ``from_pretrained``

    Returns:
        model, tokenizer - the model (in eval mode, placed with
        ``device_map="auto"``) and the tokenizer loaded from ``tokenizer/``

    Raises:
        FileNotFoundError: if ``reduced-config.json`` does not exist.
        KeyError: if the config has no ``use_lora`` entry.
    """

    conf_filepath = os.path.join(model_filepath, 'reduced-config.json')
    logging.info("Loading config file from: {}".format(conf_filepath))
    with open(conf_filepath, 'r') as fh:
        round_config = json.load(fh)

    logging.info("Loading model from filepath: {}".format(model_filepath))
    # https://huggingface.co/docs/transformers/installation#offline-mode
    # Shared by both branches so the two load paths cannot drift apart.
    from_pretrained_kwargs = dict(
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch_dtype,
        local_files_only=True,
    )
    fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')

    if round_config['use_lora']:
        base_model_filepath = os.path.join(model_filepath, 'base-model')
        logging.info("loading the base model (before LORA) from {}".format(base_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(base_model_filepath, **from_pretrained_kwargs)

        logging.info("loading the LORA adapter onto the base model from {}".format(fine_tuned_model_filepath))
        model.load_adapter(fine_tuned_model_filepath)
    else:
        # The checkpoint is placed by device_map="auto", not forced onto the CPU,
        # so the log message no longer claims a CPU load.
        logging.info("Loading full fine tune checkpoint from {}".format(fine_tuned_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, **from_pretrained_kwargs)

    model.eval()

    tokenizer_filepath = os.path.join(model_filepath, 'tokenizer')
    # local_files_only keeps the tokenizer load offline, matching the model load.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_filepath, local_files_only=True)

    return model, tokenizer


def _two_gpu_max_memory(headroom_gb=2):
    """
    Reserve headroom so HF sharding MUST split across both 16GB T4s.
    """
    if not torch.cuda.is_available():
        return None
    n = torch.cuda.device_count()
    cap = f"{16 - headroom_gb}GiB"  # e.g., "14GiB"
    return {i: cap for i in range(n)}

def _common_from_pretrained_kwargs():
    """
    Assemble the keyword arguments shared by every `from_pretrained` call.

    The fixed part requests FP16 weights, local-only files, streaming load and
    SDPA attention. The placement part depends on the hardware: with more than
    one CUDA device the model is sharded ("auto") under per-GPU memory caps,
    otherwise everything is mapped to device 0.
    """
    gpu_caps = _two_gpu_max_memory(headroom_gb=2)
    shard_across_gpus = bool(gpu_caps) and torch.cuda.device_count() > 1

    if shard_across_gpus:
        placement = {"device_map": "auto", "max_memory": gpu_caps}
        # Optional if host RAM is tight:
        # placement["offload_folder"] = "/kaggle/working/offload"
    else:
        placement = {"device_map": {"": 0}}

    return {
        "trust_remote_code": True,
        "local_files_only": True,
        "torch_dtype": torch.float16,      # T4 → FP16
        "low_cpu_mem_usage": True,         # streaming load
        "offload_state_dict": True,        # avoid CPU spikes
        "attn_implementation": "sdpa",     # available by default on Kaggle
        **placement,
    }

def load_model_and_tokenizer(model_dir: str, merge_lora: bool = True):
    """
    Robust loader for full fine-tunes or LoRA adapters stored under `model_dir`.
    Expects:
      - reduced-config.json with {"use_lora": <bool>, ...}
      - For LoRA: base-model/, fine-tuned-model/
      - For full FT: fine-tuned-model/
      - tokenizer/ with tokenizer files

    Args:
        model_dir: directory laid out as described above.
        merge_lora: accepted for backward compatibility; it is currently
            ignored (the adapter is attached, never merged).

    Returns: (model, tokenizer)
        The model is in eval mode with `config.use_cache` disabled when that
        attribute exists. The tokenizer pads on the right and reuses EOS as
        PAD when no PAD token is defined.

    Raises:
        FileNotFoundError: if reduced-config.json is missing.
    """
    conf_path = os.path.join(model_dir, "reduced-config.json")
    logging.info(f"Loading config: {conf_path}")
    with open(conf_path, "r") as fh:
        cfg = json.load(fh)

    kw = _common_from_pretrained_kwargs()

    if cfg.get("use_lora", False):
        base_dir = os.path.join(model_dir, "base-model")
        lora_dir = os.path.join(model_dir, "fine-tuned-model")

        logging.info(f"Loading base model: {base_dir}")
        model = AutoModelForCausalLM.from_pretrained(base_dir, **kw)
        logging.info(f"Attaching LoRA adapter: {lora_dir}")
        # The previous code first tried `PeftModel.from_pretrained`, but
        # `PeftModel` is never imported in this notebook, so that attempt always
        # raised NameError and a blanket `except Exception` silently fell back to
        # `load_adapter`. Call the path that actually ran, so real adapter-loading
        # errors are no longer masked.
        model.load_adapter(lora_dir)

    else:
        ft_dir = os.path.join(model_dir, "fine-tuned-model")
        logging.info(f"Loading full fine-tuned model: {ft_dir}")
        model = AutoModelForCausalLM.from_pretrained(ft_dir, **kw)

    # Tokenizer hygiene
    tok_dir = os.path.join(model_dir, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tok_dir, use_fast=True, local_files_only=True)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # better for causal LMs with dynamic padding

    # Runtime memory knobs for the gradient-based rollout
    model.eval()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False  # reduce KV/activation memory during the search

    # Quick sanity check of sharding; getattr with a default needs no try/except.
    print(getattr(model, "hf_device_map", "no device map"))

    return model, tokenizer

# Load the model and tokenizer once for the whole notebook; the later cells
# read these module-level `model` and `tokenizer` names.
model, tokenizer = load_model_and_tokenizer(
    model_dir="/kaggle/input/trojai-rev2-00000001/id-00000001"
)

2025-11-25 17:41:51.161017: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764092511.319970      87 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764092511.365428      87 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 1, 'model.layers.19': 1, 'model.layers.20': 1, 'model.layers.21': 1, 'model.layers.22': 1, 'model.layers.23': 1, 'model.layers.24': 1, 'model.layers.25': 1, 'model.layers.26': 1, 'model.layers.27': 1, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.norm': 1, 'model.rotary_emb': 1, 'lm_head': 1}


In [3]:
def get_emb_layer(model):
    """Return the model's input-embedding layer.

    Side effects: switches the model to eval mode and, when the config exposes
    a `use_cache` attribute, sets it to False.
    """
    model.eval()
    cfg = model.config
    if hasattr(cfg, "use_cache"):
        setattr(cfg, "use_cache", False)
    embedding_layer = model.get_input_embeddings()
    return embedding_layer

# Module-level handle to the input-embedding layer used by the cells below.
emb_layer = get_emb_layer(model)

In [4]:
def project_suffix_to_tokens_and_diagnostics(
    suffix_z,
    emb_layer,
    tokenizer,
):
    """
    Snap each continuous suffix embedding to its nearest vocabulary token.

    suffix_z: (Ls, E) - optimized continuous suffix embeddings
    emb_layer: model.get_input_embeddings()
    tokenizer: used to turn the selected ids into tokens / text for printing

    "Nearest" means highest cosine similarity against the embedding table,
    computed in float32. Prints L2-distance and cosine-similarity statistics
    plus the projected ids, tokens and decoded text.

    Returns a CPU LongTensor (Ls,) with the selected token ids.
    """

    def _print_min_max_mean(header, values):
        # Shared formatter for the two diagnostic summaries.
        print(header)
        print(f"  min:  {values.min().item():.6f}")
        print(f"  max:  {values.max().item():.6f}")
        print(f"  mean: {values.mean().item():.6f}")

    with torch.no_grad():
        table = emb_layer.weight                       # (V, E)
        # Work in float32 on the embedding table's device for stability.
        table_f32 = table.float()                      # (V, E)
        query_f32 = suffix_z.to(table.device).float()  # (Ls, E)

        # Cosine similarity of every vocab row against every suffix position.
        similarity = torch.matmul(
            F.normalize(table_f32, dim=-1),
            F.normalize(query_f32, dim=-1).T,
        )                                              # (V, Ls)

        best_token_ids = similarity.argmax(dim=0)      # (Ls,)

        # Distance from each optimized vector to the embedding it snaps to.
        snap_distance = (query_f32 - table_f32[best_token_ids]).norm(dim=-1)  # (Ls,)
        _print_min_max_mean(
            "L2 distance between optimized embeddings and nearest token embeddings:",
            snap_distance,
        )
        _print_min_max_mean(
            "Cosine similarity of optimized embeddings to nearest tokens:",
            similarity.max(dim=0).values,
        )

        suffix_token_ids = best_token_ids.cpu()
        id_list = suffix_token_ids.tolist()
        suffix_tokens = tokenizer.convert_ids_to_tokens(id_list)
        suffix_text = tokenizer.decode(id_list, skip_special_tokens=False)

        print("\nProjected discrete suffix token IDs:", id_list)
        print("Projected discrete suffix tokens:", suffix_tokens)
        print("Projected suffix as text:", repr(suffix_text))

        return suffix_token_ids


In [5]:
def read_suffix_pt(filepath: str) -> torch.Tensor:
    """
    Read suffix embeddings from a .pt file.

    The object is always materialised on the CPU (`map_location="cpu"`), so a
    checkpoint saved from a GPU can be read without that device being present;
    callers move it to the device they need.

    Raises FileNotFoundError if `filepath` does not exist.
    """
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; consider
    # `weights_only=True` once it is confirmed the files hold plain tensors.
    suffix_z = torch.load(filepath, map_location="cpu")
    return suffix_z

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch import amp

def entropy_loss(batch_logits):
    """
    Mean Shannon entropy (in nats) of the softmax distributions in a batch.

    batch_logits: (B, V) logits, one row per distribution.
    Returns a scalar tensor: the entropy of each row, averaged over rows.
    """
    log_p = F.log_softmax(batch_logits, dim=-1)
    # H = -sum_v p_v * log p_v, with p recovered from the log-probabilities.
    per_row_entropy = -torch.sum(log_p.exp() * log_p, dim=-1)  # (B,)
    return per_row_entropy.mean()

In [None]:
def _cosine_topk_regularizer(suffix_z, E_norm, chunk_size=1024, top_k=5, neg_weight=1.0):
    """Cosine regularizer pulling each suffix vector toward its top-k nearest tokens.

    suffix_z: (Ls, E) suffix embeddings (gradients flow through this tensor).
    E_norm:   (V, E) row-normalized vocabulary embeddings, fp32, any device;
              it is streamed to suffix_z's device `chunk_size` rows at a time.

    Returns the scalar `(1 - mean_topk) + neg_weight * mean_others`, where
    `mean_topk` is the mean cosine to the top-k tokens per position and
    `mean_others` the mean cosine to the remaining vocabulary.
    """
    dev = suffix_z.device
    Ls = suffix_z.size(0)
    V = E_norm.size(0)

    # normalized suffix embeddings, fp32, with grad
    z_norm = F.normalize(suffix_z.float(), dim=-1)  # (Ls, E)

    # Running sum of cosines over the whole vocab, per suffix position.
    sum_cos_per_pos = torch.zeros(Ls, device=dev, dtype=z_norm.dtype)
    # Running top-k cosines, per suffix position. Starting from a very low
    # sentinel lets every chunk (including a first chunk smaller than top_k)
    # go through the same merge step.
    topk_vals = torch.full((top_k, Ls), -1e9, device=dev, dtype=z_norm.dtype)

    for start in range(0, V, chunk_size):
        chunk = E_norm[start:start + chunk_size].to(dev, non_blocking=True)  # (c, E)
        chunk_sim = torch.matmul(chunk, z_norm.T)                            # (c, Ls)

        sum_cos_per_pos = sum_cos_per_pos + chunk_sim.sum(dim=0)

        # Merge this chunk's best candidates into the global top-k.
        chunk_topk, _ = chunk_sim.topk(min(top_k, chunk_sim.size(0)), dim=0)
        topk_vals, _ = torch.cat([topk_vals, chunk_topk], dim=0).topk(top_k, dim=0)

        del chunk, chunk_sim, chunk_topk

    mean_all_per_pos = sum_cos_per_pos / float(V)   # (Ls,)
    mean_topk_per_pos = topk_vals.mean(dim=0)       # (Ls,)
    # mean_others = (V * mean_all - k * mean_topk) / (V - k)
    mean_others_per_pos = (V * mean_all_per_pos - top_k * mean_topk_per_pos) / max(
        1.0, (V - top_k)
    )

    # 1) pull suffix toward the top-k tokens (small when close to them)
    reg_pos = 1.0 - mean_topk_per_pos.mean()
    # 2) push suffix away from all other tokens (no clamping at zero)
    reg_neg = mean_others_per_pos.mean()

    return reg_pos + neg_weight * reg_neg


def compute_loss_for_suffix(
    model,
    emb_layer,
    batch,
    suffix_z,           # (Ls, E) nn.Parameter
    n_tokens=10,
    amp_dtype=torch.float16,
    cos_reg_weight=0.1,
    E_norm_cpu=None,    # (V, E) on CPU, fp32
    chunk_size=1024,
    top_k=5,
    neg_weight=1.0,     # how strongly to push away from non-top-k
):
    """
    Entropy-plus-cosine loss for a continuous suffix appended to each prompt.

    - For each example, build [prompt][suffix_z] in embedding space.
    - Pad all to same length -> [prompt][suffix][PAD].
    - Greedily roll out n_tokens-1 tokens under inference_mode.
    - One final forward WITH grad over [prompt][suffix][PAD][generated]; the
      entropy term is the mean next-token entropy over all n_tokens prediction
      positions (not only the last one).
    - Gradients flow into suffix_z only (prompts and generated tokens are constants).
    - PLUS: `cos_reg_weight` times a regularizer that pulls suffix_z toward
      real token embeddings via cosine similarity.

    Args:
        batch: dict with "input_ids": list of 1D LongTensors (unpadded prompts).
        E_norm_cpu: optional precomputed row-normalized embedding table
            (V, E), fp32. When None it is derived from `emb_layer.weight` on
            the embedding device; pass a CPU copy to keep the full fp32 table
            out of GPU memory.

    Returns:
        Scalar tensor `entropy + cos_reg_weight * cos_reg`.

    NOTE(review): rows are right-padded. For batch size > 1 with unequal
    lengths, shorter rows read rollout logits at a PAD position, and generated
    tokens are appended after the padding while the per-row positions below
    assume they directly follow [prompt][suffix]. This is exact for batch
    size 1 (what the notebook uses) or equal-length rows.
    """
    prompts = batch["input_ids"]   # list of 1D LongTensors (Li,)
    dev = emb_layer.weight.device
    suffix_z = suffix_z.to(dev)    # (Ls, E)

    B = len(prompts)

    # The default E_norm_cpu=None used to crash on `.size(0)`; derive the
    # normalized table from the embedding layer instead.
    if E_norm_cpu is None:
        with torch.no_grad():
            E_norm_cpu = F.normalize(emb_layer.weight.detach().float(), dim=-1)

    # --- Build per-example [prompt][suffix] in embedding space ---
    base_embs = []   # each: (Li+Ls, E)
    base_lens = []   # each: int length Li+Ls
    for p_ids in prompts:
        p_emb = emb_layer(p_ids.to(dev)).detach()          # (Li, E), prompts are constants
        row = torch.cat([p_emb, suffix_z], dim=0)          # (Li+Ls, E)
        base_embs.append(row)
        base_lens.append(row.size(0))

    # Pad to [prompt][suffix][PAD...] across the batch
    base = pad_sequence(base_embs, batch_first=True)   # (B, max_len, E)
    base_lens = torch.tensor(base_lens, device=dev)    # (B,)
    max_len = base.size(1)

    # Attention mask: 1 for real tokens, 0 for pad
    arange = torch.arange(max_len, device=dev).unsqueeze(0)  # (1, max_len)
    base_mask = (arange < base_lens.unsqueeze(1)).long()     # (B, max_len)

    def _one_step_logits(e, m):
        with amp.autocast("cuda", dtype=amp_dtype):
            out = model(
                inputs_embeds=e,
                attention_mask=m,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
            )
        return out.logits[:, -1, :]  # (B, V)

    # ---------- Rollout under no grad (from detached base) ----------
    work_e = base.detach()  # rollout uses constants
    work_m = base_mask
    added_embs = []         # list of (B, E) constants

    T = max(0, n_tokens - 1)
    with torch.inference_mode():
        for _ in range(T):
            logits_t = _one_step_logits(work_e, work_m)
            probs_t = torch.softmax(logits_t, dim=-1)
            next_ids = torch.argmax(probs_t, dim=-1)        # (B,) greedy choice

            next_emb = emb_layer(next_ids.to(dev)).detach() # (B, E)
            added_embs.append(next_emb)

            work_e = torch.cat([work_e, next_emb.unsqueeze(1)], dim=1)
            work_m = torch.cat(
                [work_m, torch.ones((B, 1), dtype=work_m.dtype, device=dev)],
                dim=1,
            )

    # ---------- Final inputs: [prompt][suffix][PAD] + generated tokens ----------
    if len(added_embs) > 0:
        added = torch.stack(added_embs, dim=1)              # (B, T, E)
        final_emb = torch.cat([base, added], dim=1)         # (B, max_len+T, E)
        gen_mask = torch.ones((B, T), dtype=base_mask.dtype, device=dev)
        final_mask = torch.cat([base_mask, gen_mask], dim=1)
    else:
        final_emb = base
        final_mask = base_mask

    # ---------- Forward WITH grad for ALL n_tokens steps ----------
    with amp.autocast("cuda", dtype=amp_dtype):
        out = model(
            inputs_embeds=final_emb,
            attention_mask=final_mask,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )

    logits_all = out.logits   # (B, L_total, V)

    # For row b the prediction positions are
    #   base_len - 1, base_len, ..., base_len - 1 + T   (T + 1 positions).
    # Gather them for every row at once instead of looping with .item().
    idx_dev = logits_all.device
    step_offsets = torch.arange(T + 1, device=idx_dev).unsqueeze(0)          # (1, T+1)
    positions = (base_lens.to(idx_dev) - 1).unsqueeze(1) + step_offsets      # (B, T+1)
    rows = torch.arange(logits_all.size(0), device=idx_dev).unsqueeze(1)     # (B, 1)
    step_logits = logits_all[rows, positions]                                # (B, T+1, V)
    logits_for_loss = step_logits.reshape(-1, step_logits.size(-1))          # (B*(T+1), V)

    # mean entropy over all prediction steps for all examples
    ent = entropy_loss(logits_for_loss)

    cos_reg = _cosine_topk_regularizer(
        suffix_z,
        E_norm_cpu,
        chunk_size=chunk_size,
        top_k=top_k,
        neg_weight=neg_weight,
    )

    total_loss = ent + cos_reg_weight * cos_reg

    return total_loss

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset

def load_prompts_unpadded(tokenizer, args):
    """
    Build a DataLoader over Alpaca instructions tokenized WITHOUT padding.

    Args:
        tokenizer: callable tokenizer applied to a list of instruction strings.
        args: mapping with
            "data_dir"    - cache directory for the dataset,
            "batch_size"  - prompts per batch,
            "max_length"  - truncation length in tokens,
            "sample_size" - optional; if set and smaller than the dataset, a
                            fixed-seed (42) shuffled subset of that size is used.

    Returns DataLoader where each batch is:
      {
        "input_ids": list of 1D LongTensors (prompts, no padding),
        "prompt_lens": LongTensor (B,)
      }
    """
    ds = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])

    # Subsample for speed
    sample_size = args["sample_size"] if "sample_size" in args else None
    if sample_size is not None and sample_size < len(ds):
        ds = ds.shuffle(seed=42).select(range(sample_size))

    def collate(batch):
        # Tokenize the instruction text only; each prompt keeps its own length.
        texts = [ex["instruction"] for ex in batch]
        enc = tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=args["max_length"],
        )
        prompts = [torch.tensor(ids, dtype=torch.long) for ids in enc["input_ids"]]
        prompt_lens = [len(p) for p in prompts]

        return {
            "input_ids": prompts,  # list of (Li,)
            "prompt_lens": torch.tensor(prompt_lens, dtype=torch.long),
        }

    # os.cpu_count() may return None; fall back so `// 2` cannot raise TypeError.
    cpu_total = os.cpu_count() or 2
    num_workers = max(2, cpu_total // 2)
    return DataLoader(
        ds,
        batch_size=args["batch_size"],
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
        persistent_workers=True,
        collate_fn=collate,
    )

# Settings consumed by load_prompts_unpadded: dataset cache location, prompts
# per batch, truncation length in tokens, and the size of the random subset.
dataloader_args = {
    "data_dir": "/kaggle/working/datasets",
    "batch_size": 1,
    "max_length": 128,
    "sample_size": 128,
}

# Module-level DataLoader read by the evaluation cell further down.
dataloader = load_prompts_unpadded(tokenizer, dataloader_args)


In [None]:
def generate_with_suffix(
    model,
    tokenizer,
    prompt_ids: torch.Tensor,        # (L_prompt,)
    suffix_token_ids: torch.Tensor,  # (Ls,)
    max_new_tokens: int = 50,
):
    """
    Greedily continue the sequence [prompt][suffix] and return it as text.

    The two id tensors are joined, sent to the device of the model's first
    parameter, and passed to `model.generate` with sampling disabled and EOS
    used as the pad token. The whole generated sequence (input included) is
    decoded with special tokens skipped.
    """
    target_device = next(model.parameters()).device

    # Join on the prompt's device first, then move the batch of one to the model.
    suffix_on_prompt_device = suffix_token_ids.to(prompt_ids.device)
    joined = torch.cat([prompt_ids, suffix_on_prompt_device], dim=0)
    full_input_ids = joined.unsqueeze(0).to(target_device)  # (1, L_total)

    # No padding in a single-row batch, so every position is attended to.
    attn_mask = torch.ones_like(full_input_ids, dtype=torch.long, device=target_device)

    generation_kwargs = dict(
        input_ids=full_input_ids,
        attention_mask=attn_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

    model.eval()
    with torch.no_grad():
        generated = model.generate(**generation_kwargs)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [None]:
import torch.nn.functional as F

def log_suffix_token_probs(
    model,
    tokenizer,
    prompt_ids: torch.Tensor,        # (L_prompt,)
    suffix_token_ids: torch.Tensor,  # (Ls,)
):
    """
    For each position k in the suffix, compute P(suffix[k] | prompt + suffix[:k])
    and print it.

    All Ls conditionals are read from ONE forward pass over
    [prompt][suffix[:-1]]: for a causal LM the logits at position t depend only
    on tokens <= t, so position L_prompt - 1 + k predicts suffix[k]. This
    replaces Ls separate forward passes over growing prefixes.

    Returns:
        list[float] with the Ls probabilities, in suffix order (empty list for
        an empty suffix).

    Raises:
        ValueError: if the suffix is non-empty but the prompt is empty, since
        the first suffix token would have no context to be predicted from.
    """
    device = next(model.parameters()).device
    model.eval()

    prompt_ids = prompt_ids.to(device)
    suffix_token_ids = suffix_token_ids.to(device)

    n_prompt = prompt_ids.size(0)
    n_suffix = suffix_token_ids.size(0)
    if n_suffix == 0:
        return []
    if n_prompt == 0:
        raise ValueError("prompt_ids must contain at least one token")

    # The last suffix token is only ever a target, never context.
    ctx = torch.cat([prompt_ids, suffix_token_ids[:-1]], dim=0).unsqueeze(0)  # (1, L_prompt + Ls - 1)
    msk = torch.ones_like(ctx, dtype=torch.long, device=device)

    with torch.no_grad():
        out = model(
            input_ids=ctx,
            attention_mask=msk,
            use_cache=False,
            return_dict=True,
        )
        # Positions L_prompt-1 .. L_prompt+Ls-2 predict suffix[0] .. suffix[Ls-1].
        step_logits = out.logits[0, n_prompt - 1:, :]        # (Ls, V)
        # Softmax in float32 for stability with half-precision models.
        step_probs = F.softmax(step_logits.float(), dim=-1)  # (Ls, V)
        targets = suffix_token_ids.to(step_probs.device).unsqueeze(1)  # (Ls, 1)
        token_probs = step_probs.gather(1, targets).squeeze(1).tolist()

    # Plain Python ints for formatting and decoding.
    for k, (tok_id, prob) in enumerate(zip(suffix_token_ids.tolist(), token_probs)):
        tok_str = tokenizer.decode([tok_id])
        print(f"  pos {k:2d}, token {tok_id:5d} ({repr(tok_str)}), prob: {prob:.6f}")

    return token_probs

In [None]:
from itertools import islice
import re
from pathlib import Path

device = next(model.parameters()).device
emb_layer = model.get_input_embeddings()

# Pull one batch of prompts. The DataLoader's collate function returns a dict,
# so the prompts live under "input_ids" (indexing the batch with [0] raised
# KeyError).
data_iter = iter(dataloader)
batch = next(data_iter)
prompts = batch["input_ids"]  # list of 1D tensors
prompt_ids = prompts[0]       # pick first prompt
print("Using this prompt for qualitative checks:")
print("  ", tokenizer.decode(prompt_ids, skip_special_tokens=True))
print()

with torch.no_grad():
    E_cpu = model.get_input_embeddings().weight.detach().cpu().float()  # (V, E)
    E_norm_cpu = F.normalize(E_cpu, dim=-1)  # (V, E), fp32 on CPU

# Discover the checkpoints that actually exist instead of assuming 4 rounds x
# 10 epochs: the rounds directory also holds epochs 10-14 and has gaps.
rounds_dir = Path("/kaggle/working/hotflip/rounds")
suffix_name_re = re.compile(r"suffix_r(\d+)_e(\d+)\.pt")
checkpoints = []
for ckpt_path in rounds_dir.glob("suffix_r*_e*.pt"):
    name_match = suffix_name_re.fullmatch(ckpt_path.name)
    if name_match:
        checkpoints.append(
            (int(name_match.group(1)), int(name_match.group(2)), str(ckpt_path))
        )
checkpoints.sort()  # numeric (round, epoch) order, so e10 comes after e9
if not checkpoints:
    print(f"[WARN] No suffix checkpoints found in {rounds_dir}")

# Single-token EOS "prompt" used for the loss probes; identical every iteration.
batch_eos = {"input_ids": [torch.tensor([tokenizer.eos_token_id], dtype=torch.long)]}

for i, j, path in checkpoints:
    try:
        suffix_z = read_suffix_pt(path)  # (Ls, E)
    except FileNotFoundError:
        print(f"[WARN] Missing file: {path}")
        continue

    print(f"\n=== round {i}, epoch {j} ===")
    print(f"Read {path} successfully. shape={tuple(suffix_z.shape)}")

    # Project to tokens + diagnostics
    suffix_token_ids = project_suffix_to_tokens_and_diagnostics(
        suffix_z,
        emb_layer,
        tokenizer,
    )  # LongTensor (Ls,)

    # Only the loss VALUES are needed here, so skip building autograd graphs.
    with torch.no_grad():
        # Loss BEFORE projection (continuous suffix)
        loss_before = compute_loss_for_suffix(
            model,
            emb_layer,
            batch_eos,
            suffix_z.to(device),
            n_tokens=10,
            amp_dtype=torch.float16,
            cos_reg_weight=1.0,
            E_norm_cpu=E_norm_cpu,
        )
        print(f"suffix loss (before projection): {loss_before.item():.6f}")

        # Suffix AFTER projection: embeddings of the discrete tokens
        suffix_z_proj = emb_layer(suffix_token_ids.to(device))  # (Ls, E)

        loss_after = compute_loss_for_suffix(
            model,
            emb_layer,
            batch_eos,
            suffix_z_proj,
            n_tokens=10,
            amp_dtype=torch.float16,
            cos_reg_weight=1.0,
            E_norm_cpu=E_norm_cpu,
        )
        print(f"suffix loss (after  projection): {loss_after.item():.6f}")

    # ---- Qualitative check: generate from a real prompt + suffix ----
    print("\nGenerated text with suffix (projected):")
    gen_text = generate_with_suffix(
        model,
        tokenizer,
        prompt_ids,
        suffix_token_ids,
        max_new_tokens=64,
    )
    print(gen_text)
    print()

    # ---- Per-suffix-token probabilities given the prompt ----
    print("Per-suffix-token next-token probabilities:")
    log_suffix_token_probs(
        model,
        tokenizer,
        prompt_ids,
        suffix_token_ids,
    )

    print("####################")


Read suffix_r0_e0 successfully.
L2 distance between optimized embeddings and nearest token embeddings:
  min:  3.661808
  max:  6.397484
  mean: 4.828564
Cosine similarity of optimized embeddings to nearest tokens:
  min:  0.280621
  max:  0.610969
  mean: 0.440911

Projected discrete suffix token IDs: [28354, 2178, 15316, 7058, 30334, 4320, 3685, 28354, 5032, 14087]
Projected discrete suffix tokens: ['▁Расподела', '▁All', '▁Samuel', 'That', 'Ñ', '▁cast', '▁Sam', '▁Расподела', '▁Bra', '~~~~']
Projected suffix as text: 'Расподела All SamuelThatÑ cast Sam Расподела Bra~~~~'
suffix loss (before projection): 0.0012531280517578125
suffix loss (after projection): 4.7109375


Read suffix_r0_e1 successfully.
L2 distance between optimized embeddings and nearest token embeddings:
  min:  3.422127
  max:  6.299562
  mean: 4.692532
Cosine similarity of optimized embeddings to nearest tokens:
  min:  0.391027
  max:  0.669275
  mean: 0.507330

Projected discrete suffix token IDs: [28354, 2178, 1531

In [12]:
!ls /kaggle/working/hotflip/rounds

suffix_r0_e0.pt   suffix_r1_e0.pt   suffix_r2_e0.pt   suffix_r3_e0.pt
suffix_r0_e10.pt  suffix_r1_e10.pt  suffix_r2_e10.pt  suffix_r3_e10.pt
suffix_r0_e11.pt  suffix_r1_e11.pt  suffix_r2_e11.pt  suffix_r3_e11.pt
suffix_r0_e12.pt  suffix_r1_e12.pt  suffix_r2_e12.pt  suffix_r3_e12.pt
suffix_r0_e13.pt  suffix_r1_e13.pt  suffix_r2_e13.pt  suffix_r3_e13.pt
suffix_r0_e14.pt  suffix_r1_e14.pt  suffix_r2_e14.pt  suffix_r3_e14.pt
suffix_r0_e1.pt   suffix_r1_e1.pt   suffix_r2_e1.pt   suffix_r3_e1.pt
suffix_r0_e2.pt   suffix_r1_e2.pt   suffix_r2_e2.pt   suffix_r3_e2.pt
suffix_r0_e3.pt   suffix_r1_e3.pt   suffix_r2_e3.pt   suffix_r3_e3.pt
suffix_r0_e4.pt   suffix_r1_e4.pt   suffix_r2_e4.pt   suffix_r3_e4.pt
suffix_r0_e5.pt   suffix_r1_e5.pt   suffix_r2_e5.pt   suffix_r3_e5.pt
suffix_r0_e6.pt   suffix_r1_e6.pt   suffix_r2_e6.pt   suffix_r3_e6.pt
suffix_r0_e7.pt   suffix_r1_e7.pt   suffix_r2_e7.pt   suffix_r3_e7.pt
suffix_r0_e8.pt   suffix_r1_e8.pt   suffix_r2_e8.pt   suffix_r3_e8.pt
suffix_r0_e9.pt

In [11]:
refresh_repo()

/kaggle/working
Cloning into 'hotflip'...
remote: Enumerating objects: 341, done.[K
remote: Counting objects: 100% (129/129), done.[K
remote: Compressing objects: 100% (119/119), done.[K
remote: Total 341 (delta 11), reused 127 (delta 10), pack-reused 212 (from 1)[K
Receiving objects: 100% (341/341), 28.44 MiB | 21.22 MiB/s, done.
Resolving deltas: 100% (105/105), done.
/kaggle/working/hotflip
From https://github.com/jefri021/hotflip
 * branch            main       -> FETCH_HEAD
Already up to date.
