In [None]:
import os
import torch
import torch.nn.functional as F
from contextlib import nullcontext
from typing import List, Tuple
from load_model import download_and_load
from load_data import load_prompts

ImportError: attempted relative import with no known parent package

In [None]:
def refresh_repo():
    """Delete and re-clone the hotflip repository, leaving the kernel inside the fresh checkout.

    Side effects:
      * Removes /kaggle/working/hotflip entirely, so any local edits there are lost.
      * Clones https://github.com/jefri021/hotflip.git into /kaggle/working/hotflip.
      * Changes the kernel's working directory to /kaggle/working/hotflip/, which later
        relative paths and imports are resolved against.

    NOTE(review): the first cell imports `load_model` / `load_data`, which presumably live in
    this repository; that cell appears to have run before this one. Consider running this cell
    first — TODO confirm.
    """
    # Start from the parent directory so the rm/clone below target /kaggle/working/hotflip.
    %cd /kaggle/working
    %rm -rf hotflip
    !git clone https://github.com/jefri021/hotflip.git
    %cd /kaggle/working/hotflip/
    # NOTE(review): right after a fresh clone this pull is presumably a no-op unless the
    # repository's default branch is not `main`.
    !git pull origin main

refresh_repo()

In [None]:
# Fetch the model archive and load the model + tokenizer from the extracted directory.
# NOTE(review): exact download/extract behaviour lives in load_model.download_and_load — inferred
# here from argument names only.
model, tokenizer = download_and_load(
    load_model_path="/kaggle/tmp/id-00000000",   # directory the model is loaded from
    output_filename="model0.tar.gz",             # local name for the downloaded archive
    file_id="1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc",  # remote file identifier
)

In [None]:
# Options for load_prompts (from the repo's load_data module).
# data_dir must be absolute: refresh_repo() leaves the working directory at
# /kaggle/working/hotflip/, so the former relative "kaggle/working/data" resolved to
# /kaggle/working/hotflip/kaggle/working/data instead of /kaggle/working/data.
args = {
    "data_dir": "/kaggle/working/data",
    "max_length": 512,
    "batch_size": 16
}
dataloader = load_prompts(tokenizer, args)

In [None]:
# ---------- config ----------
# Saved prompt samples: a list of (input_ids, attention_mask) pairs.
FILE_PATH = "/kaggle/working/hotflip/rounds/round_009_samples.pt"
# Generation budget. NOTE(review): not referenced by any cell visible here.
MAX_NEW_TOKENS = 50
# Number of samples grouped per batch in the preview loop below.
BATCH_SIZE = 4
# Set True to run under float16 autocast (meant for the T4 GPUs); currently disabled.
USE_AMP = False
# ----------------------------

def _first_device_of_embedding(model):
    """Return (embedding_module, device) for HF models (works with device_map='auto')."""
    emb = model.get_input_embeddings()
    dev = emb.weight.device
    return emb, dev

def _pad_batch(batch_ids: List[torch.Tensor],
               batch_msk: List[torch.Tensor],
               pad_token_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a list of 1D tensors to the batch max length."""
    T = max(t.numel() for t in batch_ids)
    ids = []
    msk = []
    for x, a in zip(batch_ids, batch_msk):
        if x.numel() < T:
            x = F.pad(x, (0, T - x.numel()), value=pad_token_id)
            a = F.pad(a, (0, T - a.numel()), value=0)
        ids.append(x)
        msk.append(a)
    return torch.stack(ids, 0), torch.stack(msk, 0)

# Load the saved samples: expected to be a list of (input_ids, attention_mask) pairs.
# NOTE(review): torch.load unpickles the file — only load files you trust.
data = torch.load(FILE_PATH, map_location="cpu")

# Normalize to a list of (ids, mask) 1D long tensors and validate each pair.
pairs: List[Tuple[torch.Tensor, torch.Tensor]] = []
if isinstance(data, list):
    for idx, item in enumerate(data):
        # reshape(-1) flattens tensors saved with a leading batch dim, e.g. (1, T), into the
        # 1D form _pad_batch expects; already-1D tensors are unchanged.
        ids = torch.as_tensor(item[0], dtype=torch.long).reshape(-1)
        msk = torch.as_tensor(item[1], dtype=torch.long).reshape(-1)
        if ids.numel() != msk.numel():
            raise ValueError(
                f"Sample {idx}: input_ids has {ids.numel()} tokens "
                f"but attention_mask has {msk.numel()}"
            )
        pairs.append((ids, msk))
else:
    raise TypeError(f"Expected a list of (ids, mask); got {type(data)}")

# Inputs go to the device of the input embeddings (correct for sharded device_map='auto' models).
emb, first_dev = _first_device_of_embedding(model)
vocab_size = emb.num_embeddings
# float16 autocast only when enabled in the config AND a GPU is present; otherwise a no-op context.
amp_ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if (USE_AMP and torch.cuda.is_available()) else nullcontext()

# Fall back to id 0 when the tokenizer defines no pad token.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

print(f"Loaded {len(pairs)} samples from {os.path.basename(FILE_PATH)}")
print(f"Using device: {first_dev} | pad_token_id={pad_id}")

model.eval()

# Preview: decode every sample, batch by batch.
with torch.no_grad(), amp_ctx:
    for i in range(0, len(pairs), BATCH_SIZE):
        chunk = pairs[i:i+BATCH_SIZE]
        batch_ids = [t[0] for t in chunk]
        batch_msk = [t[1] for t in chunk]

        input_ids, attention_mask = _pad_batch(batch_ids, batch_msk, pad_token_id=pad_id)
        # Move once to the embedding device (works with device_map='auto').
        input_ids = input_ids.to(first_dev)
        attention_mask = attention_mask.to(first_dev)

        # Decoded text includes any right-padding added by _pad_batch.
        for j, src in enumerate(input_ids.tolist()):
            print(f"[sample {i+j}]")
            print("Input:", tokenizer.decode(src))

In [None]:
import torch
from contextlib import nullcontext

# Send inputs to the device holding the input-embedding table. With device_map="auto" the model
# may be sharded across GPUs, and next(model.parameters()) is not guaranteed to be on that device
# (same rule as _first_device_of_embedding in the sample-loading cell).
device = model.get_input_embeddings().weight.device
model.eval()

def encode_batch(input_ids, attention_mask, encoder=None):
    """Mean-pool final-layer hidden states into one vector per sequence.

    Args:
        input_ids: (B, T) token ids, already on the model's input device.
        attention_mask: (B, T) mask, 1 for real tokens and 0 for padding. It is used as a
            per-position weight in the mean.
        encoder: model to run; defaults to the notebook-level ``model``.

    Returns:
        (B, D) float32 tensor of masked means. A row whose mask is all zero comes out as zeros,
        because the divisor is clamped to 1.
    """
    if encoder is None:
        encoder = model
    # Hugging Face head models (e.g. *ForCausalLM) expose their backbone as `base_model`. The
    # backbone's output carries `last_hidden_state`, and calling it directly also skips the
    # (B, T, vocab) LM-head logits. Objects without `base_model` are called as-is.
    backbone = getattr(encoder, "base_model", encoder)
    with torch.no_grad():
        out = backbone(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        hidden = getattr(out, "last_hidden_state", None)
        if hidden is None:
            # Head-model outputs have no `last_hidden_state` (accessing it raised AttributeError
            # before): request all hidden states and take the final layer instead.
            out = backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = out.hidden_states[-1]

        # Pool in float32: summing up to T half-precision activations can overflow fp16.
        hidden = hidden.float()                                                   # (B, T, D)
        # A sharded model may return `hidden` on a different device than the inputs.
        mask = attention_mask.to(device=hidden.device, dtype=hidden.dtype).unsqueeze(-1)  # (B, T, 1)

        summed = (hidden * mask).sum(dim=1)             # (B, D)
        lengths = mask.sum(dim=1).clamp(min=1)          # (B, 1)
        return summed / lengths                         # (B, D)

corpus_embs = []
corpus_input_ids = []  # keep ids so nearest neighbours can be decoded later

for batch in dataloader:
    # NOTE(review): assumes load_prompts yields dict batches with these keys — adjust if not.
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    embs = encode_batch(input_ids, attention_mask)      # (B, D)
    corpus_embs.append(embs.cpu())

    # Keep a CPU copy of the raw ids (no autograd graph is involved here).
    corpus_input_ids.append(input_ids.cpu())

if not corpus_embs:
    raise RuntimeError("dataloader produced no batches; the corpus is empty")

corpus_embs = torch.cat(corpus_embs, dim=0)        # (N_corpus, D)

# Batches may be padded to different widths (e.g. dynamic padding to each batch's longest
# prompt), and torch.cat needs equal widths, so right-pad every batch to the widest first.
corpus_pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
corpus_max_len = max(ids_b.size(1) for ids_b in corpus_input_ids)
corpus_input_ids = torch.cat(
    [F.pad(ids_b, (0, corpus_max_len - ids_b.size(1)), value=corpus_pad_id)
     for ids_b in corpus_input_ids],
    dim=0,
)  # (N_corpus, T_max)
print("corpus_embs:", corpus_embs.shape)


In [None]:
from typing import List, Tuple

pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

new_embs = []
new_input_ids = []

BATCH_SIZE_NEW = 16  # any reasonable batch size

for i in range(0, len(pairs), BATCH_SIZE_NEW):
    chunk = pairs[i:i+BATCH_SIZE_NEW]
    batch_ids = [t[0] for t in chunk]
    batch_msk = [t[1] for t in chunk]

    # _pad_batch pads only to the longest sample in THIS chunk, so widths vary across chunks.
    input_ids, attention_mask = _pad_batch(batch_ids, batch_msk, pad_token_id=pad_id)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    embs = encode_batch(input_ids, attention_mask)    # (B, D)
    new_embs.append(embs.cpu())
    new_input_ids.append(input_ids.cpu())

if not new_embs:
    raise RuntimeError("`pairs` is empty; no new prompts to embed")

new_embs = torch.cat(new_embs, dim=0)           # (N_new, D)

# Re-pad every chunk to the global max width; concatenating chunks of different widths
# would otherwise raise in torch.cat.
new_max_len = max(ids_c.size(1) for ids_c in new_input_ids)
new_input_ids = torch.cat(
    [F.pad(ids_c, (0, new_max_len - ids_c.size(1)), value=pad_id) for ids_c in new_input_ids],
    dim=0,
)  # (N_new, T_max)
print("new_embs:", new_embs.shape)


In [None]:
def l2_normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Divide each row by its L2 norm (plus ``eps``, so zero rows stay zero instead of NaN)."""
    denom = x.norm(dim=1, keepdim=True)
    denom = denom + eps
    return x / denom

# Cosine similarity = dot product of L2-normalized rows. Cast to float32 first: the embeddings
# may arrive in the model's half precision, and half-precision matmul on CPU is slow (and on
# older PyTorch builds unsupported).
corpus_norm = l2_normalize(corpus_embs.float())  # (N_corpus, D)
new_norm = l2_normalize(new_embs.float())        # (N_new, D)

# Full (N_new, N_corpus) similarity matrix held in memory.
sims = new_norm @ corpus_norm.T

# For each new prompt, the top-k most similar original prompts. k is capped at the corpus
# size because topk raises when k exceeds the dimension length.
k = min(5, corpus_norm.size(0))
top_vals, top_idx = sims.topk(k, dim=1)  # (N_new, k)

novelty_scores = 1.0 - top_vals[:, 0]  # higher = more novel


In [None]:
def decode_row(ids_row):
    """Decode one row of token ids to text after trimming trailing ``pad_id`` tokens.

    Reads the notebook globals ``pad_id`` and ``tokenizer``; special tokens are skipped
    during decoding.
    """
    token_ids = ids_row.tolist()
    if pad_id is not None:
        end = len(token_ids)
        # Walk back over the trailing run of pad tokens only; interior pads are kept.
        while end > 0 and token_ids[end - 1] == pad_id:
            end -= 1
        token_ids = token_ids[:end]
    return tokenizer.decode(token_ids, skip_special_tokens=True)

# Show each new prompt next to its nearest corpus prompt (top-1 by cosine similarity).
for i in range(new_input_ids.size(0)):
    best_sim, best_idx = top_vals[i, 0].item(), top_idx[i, 0].item()
    new_text, nearest_text = decode_row(new_input_ids[i]), decode_row(corpus_input_ids[best_idx])

    print("=" * 80,
          f"[New prompt #{i}]  best_sim={best_sim:.4f}  novelty={1.0-best_sim:.4f}",
          sep="\n")
    print("NEW:    ", repr(new_text[:300]))
    print("CLOSEST:", repr(nearest_text[:300]))


In [None]:
import numpy as np

# Summary of the top-1 cosine similarity per new prompt (fractions are means of boolean masks).
best_sims = top_vals.select(1, 0).numpy()
print("mean best similarity:", best_sims.mean())
print("fraction with sim > 0.95:", (best_sims > 0.95).mean())
print("fraction with sim > 0.90:", (best_sims > 0.90).mean())

In [None]:
def jaccard_token_sim(a_ids, b_ids):
    """Jaccard similarity between the sets of distinct token ids in two 1D tensors.

    Returns 1.0 when both sets are empty; every id present (including padding) is counted.
    """
    tokens_a, tokens_b = set(a_ids.tolist()), set(b_ids.tolist())
    union = tokens_a | tokens_b
    if not union:
        return 1.0
    return len(tokens_a & tokens_b) / max(1, len(union))

# Token-set overlap between each new prompt and its nearest corpus prompt. Padding is removed
# first: the new ids are right-padded with pad_id (by _pad_batch and the re-padding before
# concatenation), and the corpus ids are presumably padded with the same tokenizer pad id. A
# shared pad token would otherwise count as overlap and inflate the score for any pair of
# padded rows.
# NOTE(review): this also drops non-padding occurrences of pad_id (e.g. when pad == eos).
for i in range(new_input_ids.size(0)):
    best_idx = top_idx[i, 0].item()
    new_row = new_input_ids[i]
    nearest_row = corpus_input_ids[best_idx]
    jacc = jaccard_token_sim(new_row[new_row != pad_id], nearest_row[nearest_row != pad_id])
    print(f"new #{i}: cos_sim={top_vals[i,0].item():.4f}, jaccard={jacc:.4f}")
