## T-GCG With Multiple Runs

In [None]:
def refresh_repo():
    """Delete and re-clone the hotflip repo under /kaggle/working, then cd into it.

    Uses IPython magics (%cd, %rm, !git), so it only runs inside an IPython /
    Jupyter kernel, not as a plain Python script.
    """
    %cd /kaggle/working
    # Remove any previous checkout so the clone below starts from a clean directory.
    %rm -rf hotflip
    !git clone https://github.com/jefri021/hotflip.git
    %cd /kaggle/working/hotflip/
    # NOTE(review): likely redundant right after a fresh clone if `main` is the
    # default branch; presumably kept from an earlier non-recloning version.
    !git pull origin main

# Leaves the kernel's cwd at /kaggle/working/hotflip; presumably this is what
# lets `from load_model import ...` in the next cell resolve - confirm.
refresh_repo()

In [None]:
from load_model import download_and_load

# Presumably downloads the archive identified by `file_id` (saved as
# `output_filename`) and loads the model/tokenizer from `load_model_path`;
# see `load_model.download_and_load` for the exact behavior.
model, tokenizer = download_and_load(
    load_model_path="/kaggle/tmp/id-00000000",
    output_filename="model0.tar.gz",
    file_id="1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc",
)

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
import os

suffix_ref = {"ids": None}  # holds the current suffix (1-D LongTensor); read by the collate fn


def load_prompts(tokenizer, args, suffix_ref, dataset=None, num_workers=None):
    """Build a DataLoader of Alpaca instructions with the current suffix appended.

    Each batch is right-padded and laid out as ``[prompt][suffix][PAD ...]``,
    so for row ``b`` the suffix starts at index ``prompt_lens[b]``.

    Args:
        tokenizer: HF-style tokenizer. It is called with ``padding=False``,
            ``truncation=True`` and ``max_length=args["max_length"]`` and must
            return a mapping with ``"input_ids"``. ``pad_token_id`` may be
            ``None``; then ``eos_token_id`` (or 0) is used as pad filler.
        args: dict with ``"data_dir"`` (dataset cache dir), ``"max_length"``
            (truncation length of the prompt alone; the suffix is appended
            afterwards) and ``"batch_size"``.
        suffix_ref: mutable dict whose ``"ids"`` entry holds the suffix token
            ids. It is read every time a batch is collated.
        dataset: optional indexable collection of examples that have an
            ``"instruction"`` field. Defaults to the ``tatsu-lab/alpaca``
            train split.
        num_workers: optional number of DataLoader worker processes. Defaults
            to half the CPU count, with a minimum of 2.

    Returns:
        A DataLoader yielding dicts of CPU tensors:
        ``input_ids`` (B, T), ``attention_mask`` (B, T) and ``prompt_lens`` (B,).

    Raises:
        ValueError: (while iterating) if ``suffix_ref["ids"]`` is None.
    """
    if dataset is None:
        dataset = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])

    # The pad value is only filler: the attention mask below is built from
    # sequence lengths. Many causal-LM tokenizers define no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", None)
    if pad_token_id is None:
        pad_token_id = 0

    def collate(batch):
        texts = [ex["instruction"] for ex in batch]

        # No padding here; we pad after appending the suffix.
        enc = tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=args["max_length"],
        )
        prompts = [torch.tensor(ids, dtype=torch.long) for ids in enc["input_ids"]]
        prompt_lens = torch.tensor([len(p) for p in prompts], dtype=torch.long)

        suffix = suffix_ref["ids"]
        if suffix is None:
            raise ValueError("suffix_ref['ids'] is None - set suffix before building dataloader.")
        suffix = suffix.to(torch.long)  # stays on CPU; the model moves inputs as needed

        # [prompt][suffix], then right-pad: [prompt][suffix][PAD ...]
        full_seqs = [torch.cat([p, suffix]) for p in prompts]
        padded = pad_sequence(full_seqs, batch_first=True, padding_value=pad_token_id)

        # Mask from true lengths, not `padded != pad_token_id`: the prompt or the
        # (random) suffix may legitimately contain the pad id (pad often equals
        # eos), and those real tokens must stay attended.
        seq_lens = prompt_lens + len(suffix)
        positions = torch.arange(padded.size(1))
        attention_mask = (positions.unsqueeze(0) < seq_lens.unsqueeze(1)).long()

        return {
            "input_ids": padded,              # (B, T) on CPU
            "attention_mask": attention_mask, # (B, T) on CPU
            "prompt_lens": prompt_lens,       # (B,) on CPU
        }

    if num_workers is None:
        # os.cpu_count() returns None when the CPU count cannot be determined.
        num_workers = max(2, (os.cpu_count() or 1) // 2)
    return DataLoader(
        dataset,
        batch_size=args["batch_size"],
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
        # persistent_workers is only valid when worker processes exist.
        persistent_workers=num_workers > 0,
        collate_fn=collate,
    )

In [None]:
def make_collate_fn(suffix_ref, pad_token_id):
    """Return a collate function that appends the current suffix to each prompt.

    Each dataset item is expected to look like
    ``{"input_ids": 1-D LongTensor of prompt tokens (no pad, no suffix yet)}``.
    The suffix is read from ``suffix_ref["ids"]`` every time a batch is built,
    so updating that entry affects subsequent batches.

    The returned function produces right-padded batches laid out as
    ``[prompt][suffix][PAD ...]``:
      - ``input_ids``: (B, T) long tensor,
      - ``attention_mask``: (B, T) long tensor, 1 exactly on prompt+suffix tokens,
      - ``prompt_lens``: (B,) long tensor (index where the suffix starts).
    All outputs live on the device of the first prompt tensor.

    ``pad_token_id`` is only used as filler; ``None`` falls back to 0.
    """
    pad_value = 0 if pad_token_id is None else pad_token_id

    def collate(batch):
        prompts = [item["input_ids"] for item in batch]  # list of 1-D tensors
        device = prompts[0].device

        suffix = suffix_ref["ids"]
        if suffix is None:
            raise ValueError("suffix_ref['ids'] is None - set the suffix before collating.")
        suffix = suffix.to(device=device, dtype=torch.long)  # (len_s,)

        full_seqs = [torch.cat([p, suffix]) for p in prompts]
        padded = pad_sequence(full_seqs, batch_first=True, padding_value=pad_value)

        # Build the mask from true lengths: prompt or suffix tokens that happen to
        # equal the pad id are real tokens and must not be masked out.
        prompt_lens = torch.tensor([len(p) for p in prompts], dtype=torch.long, device=device)
        seq_lens = prompt_lens + len(suffix)
        positions = torch.arange(padded.size(1), device=device)
        attention_mask = (positions.unsqueeze(0) < seq_lens.unsqueeze(1)).long()

        return {
            "input_ids": padded,              # (B, T)
            "attention_mask": attention_mask, # (B, T)
            "prompt_lens": prompt_lens,       # (B,)
        }

    return collate


In [None]:
import torch
import torch.nn.functional as F

def loss_for_suffix_pos(logits, prompt_lens, pos, n_steps_ahead=1, attention_mask=None):
    """Mean entropy of the model's distribution at one suffix-relative position.

    For row ``b`` the scored position is
    ``t = prompt_lens[b] + pos + (n_steps_ahead - 1)``; the entropy of
    ``softmax(logits[b, t])`` is computed and averaged over the rows whose
    ``t`` lies in ``[0, T)``.

    Args:
        logits: (B, T, V) tensor on any device.
        prompt_lens: (B,) integer tensor; moved to ``logits.device`` if needed.
        pos: suffix index.
        n_steps_ahead: 1 scores the position of ``suffix[pos]`` itself; larger
            values score positions further to the right.
        attention_mask: optional (B, T) mask. Rows whose scored position has
            mask 0 are skipped. Without it, a scored position past a row's real
            length reads logits computed at padding.

    Returns:
        0-dim float32 tensor. If no row qualifies, a zero that is still attached
        to the autograd graph of ``logits``, so ``loss.backward()`` works and
        yields zero gradients instead of raising.
    """
    B, T = logits.shape[:2]
    device = logits.device

    t_idx = prompt_lens.to(device=device, dtype=torch.long) + pos + (n_steps_ahead - 1)
    valid = (t_idx >= 0) & (t_idx < T)
    if attention_mask is not None and T > 0:
        # Clamp only to keep the gather in range; out-of-range rows are already invalid.
        safe_idx = t_idx.clamp(0, T - 1)
        row_ids = torch.arange(B, device=device)
        valid &= attention_mask.to(device)[row_ids, safe_idx].bool()

    if not bool(valid.any()):
        # Sum over an empty slice: exactly 0, but keeps a grad_fn when logits has one.
        return logits[:, :0].sum().float()

    rows = torch.arange(B, device=device)[valid]
    # Gather only the scored rows (K, V) instead of a softmax over all (B, T, V);
    # upcast for a numerically stable entropy under half-precision logits.
    selected = logits[rows, t_idx[valid]].float()
    log_p = F.log_softmax(selected, dim=-1)
    entropy = -(log_p.exp() * log_p).sum(dim=-1)  # (K,)
    return entropy.mean()

In [None]:
def get_mean_grad_for_pos(model, batch, pos, n_steps_ahead=1):
    """Batch-mean gradient of the suffix loss w.r.t. the input embedding of ``suffix[pos]``.

    Runs one forward pass and evaluates ``loss_for_suffix_pos`` on the logits.
    It then differentiates that loss with respect to the output of the model's
    input-embedding layer, which is captured with a forward hook. The gradient
    rows at index ``prompt_lens[b] + pos`` are averaged over the batch.

    Args:
        model: causal LM exposing ``get_input_embeddings()`` whose call returns
            an object with ``.logits``. It may be sharded (e.g.
            ``device_map='auto'``); the inputs are passed as given (CPU), and
            device placement is presumably handled by the model - confirm.
        batch: dict with ``input_ids``, ``attention_mask`` and ``prompt_lens``
            (as built by ``load_prompts``).
        pos: suffix index.
        n_steps_ahead: forwarded to ``loss_for_suffix_pos``.

    Returns:
        Detached (d_model,) tensor on the embedding output's device and dtype.
        Zeros if the loss does not depend on the embeddings, or if no row
        satisfies ``prompt_lens[b] + pos < T``.

    Notes:
        ``model.zero_grad()`` is called, but parameter ``.grad`` fields are not
        populated. Only the gradient w.r.t. the embedding output is computed
        (``torch.autograd.grad``), which avoids allocating a gradient buffer
        for every weight.
        When ``pos + n_steps_ahead - 1`` reaches past the end of the suffix, the
        scored position lies after the suffix. Because batches are padded after
        the suffix, that means padding for every row shorter than the longest.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    prompt_lens = batch["prompt_lens"]

    model.zero_grad()

    captured = {}

    def emb_hook(module, inputs, out):
        # Make the embedding output a differentiable node even if the embedding
        # weights are frozen, and keep a reference to differentiate against.
        if not out.requires_grad:
            out = out.detach().requires_grad_(True)
        captured["emb"] = out
        return out

    handle = model.get_input_embeddings().register_forward_hook(emb_hook)
    try:
        # Gradients are required even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits  # (B, T, V)
            loss = loss_for_suffix_pos(
                logits, prompt_lens.to(logits.device), pos, n_steps_ahead=n_steps_ahead
            )
    finally:
        # Always remove the hook (also on OOM / errors) so later forwards are unaffected.
        handle.remove()

    emb = captured["emb"]  # (B, T, d_model)
    B, T, d = emb.shape

    grads = None
    if loss.requires_grad:
        (grads,) = torch.autograd.grad(loss, emb, allow_unused=True)
    if grads is None:
        # The loss is not connected to the embeddings (e.g. no position was scored).
        return torch.zeros(d, device=emb.device, dtype=emb.dtype)

    idx = prompt_lens.to(device=grads.device, dtype=torch.long) + pos
    keep = idx < T
    if not bool(keep.any()):
        return torch.zeros(d, device=grads.device, dtype=grads.dtype)

    rows = torch.arange(B, device=grads.device)[keep]
    mean_grad = grads[rows, idx[keep]].mean(dim=0)  # (d_model,)
    return mean_grad.detach()

In [None]:
def update_suffix_coordinate(model, suffix, mean_grad, pos, T, tokenizer):
    """Resample ``suffix[pos]`` in place from a temperature softmax over token scores.

    Candidate token ``v`` gets the negated first-order estimate of the loss
    change for swapping ``suffix[pos]`` to ``v``:
    ``score[v] = -(e_v - e_cur) . g``. Here ``e`` are rows of the input
    embedding matrix and ``g = mean_grad``. A higher score means a larger
    estimated loss decrease. The new token is drawn from
    ``softmax(score / max(T, 1e-5))``.

    Args:
        model: exposes ``get_input_embeddings().weight`` of shape (V, d).
        suffix: 1-D tensor of token ids; modified in place.
        mean_grad: (d,) gradient w.r.t. the embedding at ``pos``; it is cast to
            the embedding's device and dtype.
        pos: suffix index to update.
        T: sampling temperature (clamped below at 1e-5; small T ~ argmax).
        tokenizer: not used here. Kept for the call signature, e.g. for masking
            special tokens.

    Returns:
        ``suffix`` (the same tensor object, updated in place).
    """
    with torch.no_grad():  # pure scoring; no autograd graph on the embedding matrix
        emb_weight = model.get_input_embeddings().weight  # (V, d)
        # Matching dtype avoids matmul errors when the grad was accumulated in
        # another precision (e.g. float32 zeros mixed into half-precision grads).
        g = mean_grad.to(device=emb_weight.device, dtype=emb_weight.dtype)

        current_tok = int(suffix[pos])
        # (E - e_cur) @ g == E @ g - e_cur @ g, without a (V, d) temporary.
        approx_delta_L = (emb_weight @ g - emb_weight[current_tok] @ g).float()  # (V,)
        scores = -approx_delta_L

        # Special tokens could be masked here, e.g.:
        # scores[tokenizer.pad_token_id] = float("-inf")

        probs = torch.softmax(scores / max(T, 1e-5), dim=-1)
        new_tok = int(torch.multinomial(probs, num_samples=1))

    suffix[pos] = new_tok
    return suffix

In [None]:
"""
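Now we optimize the suffix with a temperature-sampled variant of Greedy Coordinate Gradient (T-GCG). For each suffix position we average the loss gradient w.r.t. that position's input embedding over all prompt batches. We then score every vocabulary token with a first-order estimate of the loss change and sample the replacement from a softmax over those scores at temperature T.
In the collate function, we append the suffix to each prompt in the batch and then pad the sequences.
The loss is the mean entropy of the model's next-token distribution at a single position per prompt (prompt_len + pos + n_steps_ahead - 1), so no other tokens contribute to it.
Positions are visited in cyclic order: once we reach the end of the suffix, we start again from the beginning, and T is decayed after each full pass.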
"""
# Hyperparameters
len_s = 10                 # suffix length (tokens)
epochs = 10                # full passes over all suffix positions
T = 1.0                    # initial sampling temperature, decayed by 0.8 per epoch
n_steps_ahead = 5          # which position after suffix[pos] the entropy loss scores
max_batches_per_pos = None # optional cap on batches per position (None = whole dataset)

# Start from a random suffix (ids below 2 are never sampled here)
suffix = torch.randint(2, tokenizer.vocab_size, (len_s,), dtype=torch.long)
suffix_ref["ids"] = suffix

# args of dataloader
args = {
    # Absolute path: after refresh_repo() the cwd is the cloned repo, so the
    # previous relative "kaggle/working/data" ended up inside the checkout.
    "data_dir": "/kaggle/working/data",
    "max_length": 512,
    "batch_size": 16,
}


model.eval()
if hasattr(model.config, "use_cache"):
    model.config.use_cache = False

for epoch in range(epochs):
    print(f"Epoch {epoch}, T={T:.4f}, suffix: {suffix.tolist()}")

    for pos in range(len_s):
        print(f"  Optimizing suffix position {pos}")
        grad_accum = None
        batch_count = 0

        # Rebuilt per position: worker processes are separate processes, so they
        # do not observe later in-place updates to suffix_ref made in this one.
        dataloader = load_prompts(tokenizer, args, suffix_ref)

        for batch in dataloader:
            mean_grad_batch = get_mean_grad_for_pos(model, batch, pos, n_steps_ahead=n_steps_ahead)
            grad_accum = mean_grad_batch if grad_accum is None else grad_accum + mean_grad_batch

            batch_count += 1
            print(f"    Processed batch {batch_count}", end="\r")
            if max_batches_per_pos is not None and batch_count >= max_batches_per_pos:
                break

        if batch_count == 0:
            continue

        # Average of per-batch mean gradients.
        mean_grad = grad_accum / batch_count

        update_suffix_coordinate(model, suffix, mean_grad, pos, T, tokenizer)
        suffix_ref["ids"] = suffix  # make collate see updated suffix

    T *= 0.8

# The per-epoch print shows the suffix *before* each epoch; report the final result too.
print(f"Final suffix: {suffix.tolist()}")
print(f"Decoded suffix: {tokenizer.decode(suffix.tolist())!r}")