## Optimize Embeddings

In this notebook, we're aiming to optimize embeddings directly, regardless of their values being a valid token representetive or not. We then use a similarity matrix + softmax to estimate a distribution on possible tokens for the optimized embeddings.

### Check Model

### Load Model

In [14]:
import os
import gdown
import torch
import json
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
from typing import List, Union

def clear_memory(keep_vars: Union[List[str], None] = None, verbose: bool = True):
    """
    Clears memory while preserving specified variables.
    Still clears GPU memory for all CUDA objects, including kept variables.
    
    Args:
        keep_vars: List of variable names to preserve in memory (will still be cleared from GPU)
        verbose: Whether to print memory clearing information
    """
    if verbose:
        print("Starting memory clearing process...")
    
    # Convert keep_vars to set for faster lookups
    keep_set = set(keep_vars) if keep_vars else set()
    
    # First pass: Move kept CUDA variables to CPU
    if torch.cuda.is_available():
        for name, var in list(globals().items()):
            if name in keep_set and isinstance(var, torch.Tensor) and var.is_cuda:
                if verbose:
                    print(f"Moving kept tensor '{name}' to CPU")
                globals()[name] = var.cpu()
    
    # Clear Python garbage collector
    gc.collect()
    if verbose:
        print("Ran Python garbage collection")
    
    # Clear CUDA memory if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if verbose:
            print("Cleared CUDA cache")
            print(f"Current CUDA memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
            print(f"Current CUDA memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
    
    # Try to clear TensorFlow/Keras if available
    try:
        import tensorflow as tf
        tf.keras.backend.clear_session()
        if verbose:
            print("Cleared TensorFlow/Keras session")
    except ImportError:
        pass
    
    # Delete objects not in keep_vars
    for name, var in list(globals().items()):
        if not name.startswith('__') and name not in keep_set:
            if isinstance(var, (torch.Tensor, torch.nn.Module)):
                del globals()[name]
                if verbose:
                    print(f"Deleted torch object: {name}")
            elif isinstance(var, list) and var and isinstance(var[0], torch.Tensor):
                del globals()[name]
                if verbose:
                    print(f"Deleted list of torch tensors: {name}")
    
    # Final garbage collection
    gc.collect()
    
    if verbose:
        print("Memory clearing complete")


def _two_gpu_max_memory(headroom_gb=2):
    """
    Reserve headroom so HF sharding MUST split across both 16GB T4s.
    """
    if not torch.cuda.is_available():
        return None
    n = torch.cuda.device_count()
    cap = f"{16 - headroom_gb}GiB"  # e.g., "14GiB"
    return {i: cap for i in range(n)}

def _common_from_pretrained_kwargs():
    """
    Settings that reduce both CPU and GPU peak memory and use a lean attention impl.
    """
    kw = dict(
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=torch.float16,     # T4 → FP16
        low_cpu_mem_usage=True,        # streaming load
        offload_state_dict=True,       # avoid CPU spikes
        attn_implementation="sdpa",    # available by default on Kaggle
    )
    mm = _two_gpu_max_memory(headroom_gb=2)
    if mm and torch.cuda.device_count() > 1:
        kw["device_map"] = "auto"
        kw["max_memory"] = mm
        # Optional if host RAM is tight:
        # kw["offload_folder"] = "/kaggle/working/offload"
    else:
        kw["device_map"] = {"": 0}
    return kw

def download_file_from_google_drive(file_id, output_dir, output_filename, quiet=False):
    """
    Downloads a file from Google Drive given its file ID and saves it to the specified directory.
    
    Args:
        file_id (str): The Google Drive file ID (found in the file URL)
        output_dir (str): Directory where the file should be saved
        output_filename (str): Name of the output file
        quiet (bool): Whether to suppress gdown output (default: False)
    
    Returns:
        str: Path to the downloaded file if successful, None otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Full output path
    output_file = os.path.join(output_dir, output_filename)
    
    print("Downloading the file...")
    try:
        gdown.download(id=file_id, output=output_file, quiet=quiet, fuzzy=True)
    except Exception as e:
        print(f"Download failed: {str(e)}")
        return None
    
    # Verify download
    if os.path.exists(output_file):
        file_size = os.path.getsize(output_file) / (1024 * 1024)  # in MB
        print(f"Download successful! File saved to: {output_file}")
        print(f"File size: {file_size:.2f} MB")
        return output_file
    else:
        print("Download failed - file not found")
        return None

In [15]:
import os
import tarfile
from typing import List, Union

def extract_and_delete_tar_gz(file_path: str, delete_compressed: bool = True) -> bool:
    """
    Extracts a .tar.gz file and optionally deletes the compressed file.
    
    Args:
        file_path (str): Path to the .tar.gz file
        delete_compressed (bool): Whether to delete the compressed file after extraction (default: True)
    
    Returns:
        bool: True if extraction was successful, False otherwise
    """
    try:
        print(f"Extracting: {file_path}")
        with tarfile.open(file_path, 'r:gz') as tar:
            tar.extractall(path=os.path.dirname(file_path))
        
        if delete_compressed:
            os.remove(file_path)
            print(f"Deleted compressed file: {file_path}")
        return True
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return False

def process_directory(directory: str, recursive: bool = True, max_depth: Union[int, None] = None) -> int:
    """
    Processes a directory to find and extract .tar.gz files.
    
    Args:
        directory (str): Directory path to process
        recursive (bool): Whether to process subdirectories (default: True)
        max_depth (int|None): Maximum recursion depth (None for unlimited)
    
    Returns:
        int: Number of .tar.gz files processed
    """
    processed_count = 0
    current_depth = 0
    
    while True:
        found_tar_gz = False
        for root, dirs, files in os.walk(directory):
            # Calculate current depth
            rel_path = os.path.relpath(root, directory)
            current_depth = rel_path.count(os.sep) + 1 if rel_path != '.' else 0
            
            # Skip if beyond max depth
            if max_depth is not None and current_depth > max_depth:
                continue
                
            for file in files:
                if file.endswith('.tar.gz'):
                    file_path = os.path.join(root, file)
                    if extract_and_delete_tar_gz(file_path):
                        processed_count += 1
                        found_tar_gz = True
        
        # If not recursive or no more .tar.gz files found, exit
        if not recursive or not found_tar_gz:
            break
    
    return processed_count

def process_paths(paths: List[str], recursive: bool = True, max_depth: Union[int, None] = None) -> int:
    """
    Processes a list of paths (files or directories) to extract .tar.gz files.
    
    Args:
        paths (List[str]): List of file/directory paths to process
        recursive (bool): Whether to process directories recursively (default: True)
        max_depth (int|None): Maximum recursion depth for directories (None for unlimited)
    
    Returns:
        int: Total number of .tar.gz files processed
    """
    total_processed = 0
    
    for path in paths:
        if not os.path.exists(path):
            print(f"Warning: Path does not exist - {path}")
            continue
            
        if path.endswith('.tar.gz'):
            if extract_and_delete_tar_gz(path):
                total_processed += 1
        elif os.path.isdir(path):
            print(f"Processing directory: {path}")
            total_processed += process_directory(
                directory=path,
                recursive=recursive,
                max_depth=max_depth
            )
    
    print(f"Total .tar.gz files processed: {total_processed}")
    return total_processed

In [16]:
from os.path import join

import torch
import json
import os
import logging
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def load_model(model_filepath: str, torch_dtype:torch.dtype=torch.float16):
    """Load a model given a specific model_path.

    Args:
        model_filepath: str - Path to where the model is stored

    Returns:
        model, dict, str - Torch model + dictionary representation of the model + model class name
    """

    conf_filepath = os.path.join(model_filepath, 'reduced-config.json')
    logging.info("Loading config file from: {}".format(conf_filepath))
    with open(conf_filepath, 'r') as fh:
        round_config = json.load(fh)

    logging.info("Loading model from filepath: {}".format(model_filepath))
    # https://huggingface.co/docs/transformers/installation#offline-mode
    if round_config['use_lora']:
        base_model_filepath = os.path.join(model_filepath, 'base-model')
        logging.info("loading the base model (before LORA) from {}".format(base_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(base_model_filepath, trust_remote_code=True, torch_dtype=torch_dtype, local_files_only=True)
        # model = AutoModelForCausalLM.from_pretrained(round_config['model_architecture'], trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch_dtype)

        fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')
        logging.info("loading the LORA adapter onto the base model from {}".format(fine_tuned_model_filepath))
        model.load_adapter(fine_tuned_model_filepath)
    else:
        fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')
        logging.info("Loading full fine tune checkpoint into cpu from {}".format(fine_tuned_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, trust_remote_code=True, torch_dtype=torch_dtype, local_files_only=True)
        # model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch_dtype)

    model.eval()

    tokenizer_filepath = os.path.join(model_filepath, 'tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_filepath)

    # Runtime memory knobs for your gradient-based rollout
    model.eval()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False  # reduce KV/activation memory during your search

    # Optional: quick sanity check of sharding
    try:
        print(getattr(model, "hf_device_map", "no device map"))
    except Exception:
        pass

    return model, tokenizer

In [17]:
def download_and_load(file_id, output_filename, load_model_path):
    """Run the complete embedding optimization pipeline with custom weight function"""
    # Clear memory before starting
    clear_memory()
    
    # Download the model
    downloaded_file = download_file_from_google_drive(
        file_id=file_id,#"1-K-HcT-3-00rxPpvQxZ75o2be3STchsv",
        output_dir="/kaggle/tmp",
        output_filename=output_filename,#"model4.tar.gz",
        quiet=False
    )
    
    # Process paths
    process_paths(
        paths=['/kaggle/tmp',],
        recursive=True,
        max_depth=None
    )
    
    # Load model and tokenizer
    model, tokenizer = load_model(load_model_path)#"/kaggle/tmp/id-00000004")

    return model, tokenizer

file_id = "1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc"
output_filename = "model0.tar.gz"
load_model_path = "/kaggle/tmp/id-00000000"

model, tokenizer = download_and_load(file_id=file_id, output_filename=output_filename, load_model_path=load_model_path)

model.to('cuda:0')

print(model.device)

tokenizer.padding_side = "right"

Starting memory clearing process...
Ran Python garbage collection
Cleared CUDA cache
Current CUDA memory allocated: 12908.04 MB
Current CUDA memory cached: 12934.00 MB
Cleared TensorFlow/Keras session
Deleted torch object: model
Deleted torch object: suffix_z
Deleted torch object: emb_layer
Memory clearing complete
Downloading the file...


Downloading...
From (original): https://drive.google.com/uc?id=1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc
From (redirected): https://drive.google.com/uc?id=1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc&confirm=t&uuid=6ac8f928-a4cf-4c0f-b235-f057b69c4bc7
To: /kaggle/tmp/model0.tar.gz
100%|██████████| 10.6G/10.6G [01:14<00:00, 142MB/s] 


Download successful! File saved to: /kaggle/tmp/model0.tar.gz
File size: 10092.92 MB
Processing directory: /kaggle/tmp
Extracting: /kaggle/tmp/model0.tar.gz
Deleted compressed file: /kaggle/tmp/model0.tar.gz
Total .tar.gz files processed: 1


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

no device map
cuda:0


In [18]:
import torch
from datasets import load_dataset
from torch.utils.data import RandomSampler, DataLoader

def load_prompts_unpadded(tokenizer, args, seed=42):
    """
    Returns DataLoader with reproducible shuffling across runs.
    """
    # Set global seeds for any randomness in dataset loading
    torch.manual_seed(seed)
    
    ds = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])

    # Subsample with fixed seed if needed
    if "sample_size" in args and args["sample_size"] is not None and args["sample_size"] < len(ds):
        # Create deterministic indices for subsampling
        import numpy as np
        np.random.seed(seed)
        indices = np.random.permutation(len(ds))[:args["sample_size"]]
        ds = ds.select(indices.tolist())
    
    def collate(batch):
        texts = [ex["instruction"] for ex in batch]
        enc = tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=args["max_length"],
        )
        prompts = [torch.tensor(ids, dtype=torch.long) for ids in enc["input_ids"]]
        prompt_lens = [len(p) for p in prompts]
        
        return {
            "input_ids": prompts,
            "prompt_lens": torch.tensor(prompt_lens, dtype=torch.long),
        }
    
    # Create a seeded generator for the sampler
    generator = torch.Generator()
    generator.manual_seed(seed)
    
    # Use RandomSampler with the seeded generator
    sampler = RandomSampler(ds, generator=generator)
    
    # DataLoader with sampler instead of shuffle=True
    return DataLoader(
        ds,
        batch_size=args["batch_size"],
        sampler=sampler,  # ← Use sampler instead of shuffle
        pin_memory=True,
        num_workers=0,  # Use 0 workers for perfect reproducibility
        collate_fn=collate,
        # Don't specify shuffle when using sampler!
    )

### Entropy Loss

In [19]:
import torch
import torch.nn.functional as F
from torch import amp

def entropy_loss(batch_logits, is_logit=True):
    """
    batch_logits: can be (n, V) or (B, n, V) 
    is_logit: whether input is logits (True) or probabilities (False)
    
    Returns: scalar = mean(min(entropy across n) across B)
    """
    original_dim = batch_logits.dim()
    if original_dim == 2:
        batch_logits = batch_logits.unsqueeze(0)

    if is_logit:
        log_probs = F.log_softmax(batch_logits, dim=-1)
    else:
        log_probs = torch.log(batch_logits + 1e-12)

    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)  # (B, n)
    min_entropy = entropy.min(dim=-1).values    # (B,)
    mean_min_entropy = min_entropy.mean()
    
    return mean_min_entropy

In [20]:
import torch
import torch.nn.functional as F

def vocab_contrastive_loss(
    z: torch.Tensor,        # (L, d) optimized embeddings, requires_grad=True
    E: torch.Tensor,        # (V, d) vocab embeddings (can be detached)
    k: int = 5,             # size of positive set (top-k closest)
    margin: float = 15,    # margin between pos cluster and closest negative
    neg_sample: int | None = None,  # optionally subsample negatives
) -> torch.Tensor:
    """
    Contrastive loss between top-k closest vocab embeddings (positives)
    and the rest (negatives), in squared Euclidean distance space.

    For each position l:
      d_{li} = ||z_l - E_i||^2
      P_l = indices of k smallest d_{li}  (positives)
      N_l = all other indices            (negatives)

      d_pos_mean  = mean_{i in P_l} d_{li}
      d_neg_min   = min_{j in N_l} d_{lj}

      L_l = relu(margin + d_pos_mean - d_neg_min)

    Loss = mean_l L_l
    """
    L, d = z.shape
    V, dE = E.shape
    assert d == dE, "Dimension mismatch between z and E"

    # (L, V, d): pairwise differences
    diff = z.unsqueeze(1) - E.unsqueeze(0)   # (L, V, d)
    # (L, V): squared distances
    dists = (diff ** 2).sum(dim=-1)
    dists = torch.sqrt(dists)  # Actual Euclidean distance

    # Get top-k *smallest* distances => positives
    k = min(k, V - 1)  # ensure at least one negative
    pos_dists, pos_idx = dists.topk(k, dim=-1, largest=False)  # (L, k)

    # Mask to separate negatives
    neg_mask = torch.ones_like(dists, dtype=torch.bool)        # (L, V)
    neg_mask.scatter_(1, pos_idx, False)  # mark positives as False
    neg_dists_full = dists.masked_select(neg_mask).view(L, -1) # (L, V-k)

    # Optionally subsample negatives for efficiency
    if neg_sample is not None and neg_sample < neg_dists_full.size(1):
        # random permutation per batch (simple approx)
        perm = torch.randperm(neg_dists_full.size(1), device=neg_dists_full.device)
        neg_dists = neg_dists_full[:, perm[:neg_sample]]       # (L, neg_sample)
    else:
        neg_dists = neg_dists_full                             # (L, V-k)

    # For each position: avg positive distance, and closest negative distance
    pos_mean = pos_dists.mean(dim=-1)                          # (L,)
    neg_min  = neg_dists.min(dim=-1).values                    # (L,)

    # Hinge: margin + pos_mean - neg_min <= 0
    loss_per_pos = F.relu(margin + pos_mean - neg_min)         # (L,)
    loss = loss_per_pos.mean()
    return loss

### Rollout Loss

In [None]:
import gc
import torch
import torch.nn.functional as F
from torch import amp
from torch.nn.utils.rnn import pad_sequence


def compute_loss_for_suffix(
    model,
    emb_layer,
    batch,
    suffix_z,           # (Ls, V) nn.Parameter  == pre-softmax logits over vocab per suffix position
    n_tokens=10,
    nt=1,
    amp_dtype=torch.float16,
    cos_reg_weight=0.1,  # used as weight for soft one-hot entropy regularizer
    E_norm_cpu=None,     # unused now, kept for API compatibility
    chunk_size=1024,
    top_k=5,             # unused now
    neg_weight=1.0,      # unused now
):
    """
    - suffix_z is pre-softmax logits over vocab: (Ls, V).
    - Convert suffix_z -> soft one-hot (suffix_probs) -> suffix embeddings via suffix_probs @ E.
    - For each example, build [prompt][suffix_embs] in embedding space.
    - Pad all to same length -> [prompt][suffix][PAD].
    - Roll out n_tokens-1 tokens under inference_mode.
    - Final forward WITH grad gives entropy loss on last n_tokens generated tokens.
    - PLUS: entropy regularizer on the soft one-hot (suffix_probs) to push it toward one-hot.

    Additionally:
    - Returns a rich `characteristics` dict that contains:
      * full tensors (on CPU) for suffix logits/probs/embs
      * per-position entropy, max prob, norms, margins, nearest-token distances
      * full singular value vectors and eigenvalue spectra
      * output logits/prob features from the final step
    """
    prompts = batch["input_ids"]   # list of 1D LongTensors (Li,)
    dev = emb_layer.weight.device
    emb_dtype = emb_layer.weight.dtype

    # ---------- Soft one-hot over vocab and suffix embeddings ----------
    # suffix_z: (Ls, V) logits over vocab
    suffix_logits = suffix_z.to(device=dev, dtype=torch.float32)  # keep logits in fp32 for stability
    Ls, V_logits = suffix_logits.shape

    E = emb_layer.weight  # (V, E_dim)
    V, E_dim = E.shape
    assert V_logits == V, f"suffix_z second dim ({V_logits}) must match vocab size ({V})."

    # Soft one-hot over vocab (fp32)
    suffix_probs_fp32 = F.softmax(suffix_logits, dim=-1)  # (Ls, V)

    # Suffix embeddings: convex combination of token embeddings
    suffix_probs = suffix_probs_fp32.to(dtype=emb_dtype, device=dev)
    suffix_embs = suffix_probs @ E  # (Ls, E_dim), with grad wrt suffix_logits

    # ---------- Build per-example [prompt][suffix] in embedding space ----------
    B = len(prompts)
    base_embs = []   # each: (Li+Ls, E_dim)
    base_lens = []   # each: scalar length Li+Ls

    for p_ids in prompts:
        p_ids_dev = p_ids.to(dev)
        p_emb = emb_layer(p_ids_dev).detach()      # (Li, E_dim), prompts are constants
        base = torch.cat([p_emb, suffix_embs], dim=0)  # (Li+Ls, E_dim)
        base_embs.append(base)
        base_lens.append(base.size(0))

    # Pad to [prompt][suffix][PAD...] across the batch
    base = pad_sequence(base_embs, batch_first=True)   # (B, max_len, E_dim)
    base_lens = torch.tensor(base_lens, device=dev)    # (B,)
    max_len = base.size(1)

    # Attention mask: 1 for real tokens, 0 for pad
    arange = torch.arange(max_len, device=dev).unsqueeze(0)  # (1, max_len)
    base_mask = (arange < base_lens.unsqueeze(1)).long()     # (B, max_len)

    # Now base has structure [prompt][suffix][PAD] per row (masked pads)

    def _one_step_logits(e, m, n=1):
        with amp.autocast("cuda", dtype=amp_dtype):
            out = model(
                inputs_embeds=e,
                attention_mask=m,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
            )
        if n == 1:
            return out.logits[:, -1, :]          # (B, V)
        else:
            return out.logits[:, -n:, :]         # (B, n, V)

    # ---------- Rollout under no grad (from detached base) ----------
    work_e = base.detach()  # rollout uses constants
    work_m = base_mask
    added_embs = []         # list of (B, E_dim) constants

    T = max(0, n_tokens - 1)
    with torch.inference_mode():
        for _ in range(T):
            logits_t = _one_step_logits(work_e, work_m, n=1)   # (B, V)
            probs_t = torch.softmax(logits_t, dim=-1)          # (B, V)
            next_ids = torch.argmax(probs_t, dim=-1)           # (B,)

            next_emb = emb_layer(next_ids.to(dev)).detach()    # (B, E_dim)
            added_embs.append(next_emb)

            work_e = torch.cat([work_e, next_emb.unsqueeze(1)], dim=1)
            work_m = torch.cat(
                [work_m, torch.ones((B, 1), dtype=work_m.dtype, device=dev)],
                dim=1,
            )

        # rollout temporaries
        del logits_t, probs_t, next_ids, next_emb

    # ---------- Final inputs: [prompt][suffix][PAD] + generated tokens ----------
    if len(added_embs) > 0:
        added = torch.stack(added_embs, dim=1)              # (B, T, E_dim)
        del added_embs, work_e, work_m

        final_emb = torch.cat([base, added], dim=1)         # (B, max_len+T, E_dim)
        gen_mask = torch.ones((B, T), dtype=base_mask.dtype, device=dev)
        final_mask = torch.cat([base_mask, gen_mask], dim=1)
    else:
        final_emb = base
        final_mask = base_mask

    # ---------- Final step WITH grad (depends on suffix_logits via suffix_embs) ----------
    logits_last = _one_step_logits(final_emb, final_mask, n=n_tokens)  # (B, n_tokens, V)
    ent = entropy_loss(logits_last, is_logit=True)                     # main term: minimize entropy

    # ---------- Soft one-hot entropy regularizer ----------
    # Entropy over suffix_probs (per suffix position), minimize to encourage near-one-hot
    ent_soft = entropy_loss(suffix_probs_fp32, is_logit=False)

    # total loss: model entropy + weighted soft one-hot entropy
    total_loss = ent + cos_reg_weight * ent_soft

    # ----- Suffix joint log-prob term (discrete argmax-based) -----
    # We'll treat this as an extra constant term (no gradient wrt suffix_z)
    base_loss = total_loss
    dev = emb_layer.weight.device

    # defaults in case batch is empty
    suffix_joint_logprob = torch.tensor(0.0, device=dev)
    suffix_logprob_term = torch.tensor(0.0, device=dev)
    model_suffix_probs_cpu = None  # NEW: will hold model-assigned suffix probs (Ls, V) on CPU

    with torch.no_grad():
        # Discrete suffix tokens from current soft one-hot
        suffix_token_ids = suffix_probs_fp32.argmax(dim=-1)  # (Ls,)

        pad_id = getattr(tokenizer, "pad_token_id", 0)
        # print(f'pad_id{pad_id}')

        full_ids_list = []
        prompt_lens = []
        for p_ids in prompts:
            p_ids_dev = p_ids.to(dev)
            prompt_lens.append(p_ids_dev.size(0))
            full_ids_list.append(torch.cat([p_ids_dev, suffix_token_ids.to(dev)], dim=0))

        if len(full_ids_list) > 0:
            full_ids = pad_sequence(
                full_ids_list,
                batch_first=True,
                padding_value=pad_id,
            ).to(dev)                                             # (B_ids, L_max_ids)
            prompt_lens = torch.tensor(prompt_lens, device=dev)  # (B_ids,)
            B_ids, L_max_ids = full_ids.size()

            # attention mask for [prompt][suffix] region
            full_lengths = prompt_lens + Ls
            arange_ids = torch.arange(L_max_ids, device=dev).unsqueeze(0)
            attn_ids = (arange_ids < full_lengths.unsqueeze(1)).long()

            from torch import amp as _amp_mod  # to avoid ambiguity if needed

            with _amp_mod.autocast("cuda", dtype=amp_dtype):
                outputs_suffix = model(
                    input_ids=full_ids,
                    attention_mask=attn_ids,
                    use_cache=False,
                    output_attentions=False,
                    output_hidden_states=False,
                    return_dict=True,
                )

            logits_suffix = outputs_suffix.logits                 # (B_ids, L_max_ids, V)
            log_probs_suffix = F.log_softmax(logits_suffix, dim=-1)

            # positions of suffix tokens in full_ids and corresponding logits
            positions = prompt_lens.unsqueeze(1) + torch.arange(Ls, device=dev).unsqueeze(0)  # (B_ids, Ls)
            positions_logits = (positions - 1).clamp(min=0, max=L_max_ids - 1)                # (B_ids, Ls)

            token_ids_expand = suffix_token_ids.unsqueeze(0).expand(B_ids, -1)                # (B_ids, Ls)

            # FULL model-assigned suffix distributions over vocab at each suffix position
            # pick the distribution at the timestep where each suffix token is predicted
            V = logits_suffix.size(-1)
            model_suffix_log_probs = log_probs_suffix.gather(
                dim=1,
                index=positions_logits.unsqueeze(-1).expand(-1, -1, V),
            )  # (B_ids, Ls, V)

            model_suffix_probs = model_suffix_log_probs.exp()        # (B_ids, Ls, V)
            # average across batch -> (Ls, V), one distribution per suffix position
            model_suffix_probs_mean = model_suffix_probs.mean(dim=0)  # (Ls, V)
            # stash on CPU for later feature extraction
            model_suffix_probs_cpu = model_suffix_probs_mean.detach().cpu()

            # log p(s_j | prompt + s_<j) for each batch, each suffix position
            suffix_token_logprobs = log_probs_suffix[
                torch.arange(B_ids, device=dev).unsqueeze(1),
                positions_logits,
                token_ids_expand,
            ]  # (B_ids, Ls)


            # schedule weights along suffix positions
            nt_sched = min(Ls, nt)          # "nt": first nt suffix positions
            weights = torch.ones(Ls, device=dev)  # default weight = 1.0
            if nt_sched > 0:
                # lower values at the start, gradually increasing up to 1.0
                weight_schedule = torch.linspace(0.3, 1.0, steps=nt_sched, device=dev)
                weights[:nt_sched] = weight_schedule

            # apply weights and average over batch
            weighted_logprobs = (suffix_token_logprobs * weights.unsqueeze(0)).sum(dim=-1)  # (B_ids,)
            suffix_joint_logprob = weighted_logprobs.mean()                                 # scalar

            # negative-signed multiplier: encourages higher joint suffix log-prob
            suffix_logprob_weight = 0.01
            suffix_logprob_term = -suffix_logprob_weight * suffix_joint_logprob

    # include the suffix joint log-prob term in the loss
    total_loss = base_loss + suffix_logprob_term

    # ---------- Derive characteristics (forward-only, store tensors on CPU) ----------
    characteristics = {}

    with torch.no_grad():
        eps = 1e-12

        def _stat(t: torch.Tensor):
            # 1D tensor stats
            t = t.detach()
            if t.numel() == 0:
                return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
            mean = t.mean().item()
            std = t.std(unbiased=False).item() if t.numel() > 1 else 0.0
            return {
                "mean": mean,
                "std": std,
                "min": t.min().item(),
                "max": t.max().item(),
            }

        # ----- suffix_logits features -----
        suffix_logits_f = suffix_logits.detach()            # (Ls, V)
        logits_flat = suffix_logits_f.view(-1)
        suffix_logits_cpu = suffix_logits_f.cpu()

        suffix_logits_features = {
            "tensor": suffix_logits_cpu,                    # (Ls, V) on CPU
            "flat_stats": _stat(logits_flat),
            "flat_norm": logits_flat.norm().item(),
        }

        # ----- suffix_probs features (row-wise + spectral) -----
        suffix_probs_f = suffix_probs_fp32.detach()         # (Ls, V)
        suffix_probs_cpu = suffix_probs_f.cpu()

        logp_suffix = (suffix_probs_f + eps).log()          # (Ls, V)
        row_entropy = -(suffix_probs_f * logp_suffix).sum(dim=-1)       # (Ls,)
        row_maxprob, row_maxidx = suffix_probs_f.max(dim=-1)            # (Ls,)
        row_l2 = suffix_probs_f.norm(dim=-1)                            # (Ls,)
        row_gini = 1.0 - (suffix_probs_f ** 2).sum(dim=-1)              # (Ls,)
        top2p, top2idx = suffix_probs_f.topk(2, dim=-1)
        row_margin = top2p[:, 0] - top2p[:, 1]                          # (Ls,)

        suffix_probs_features = {
            "tensor": suffix_probs_cpu,                      # (Ls, V) on CPU
            "row_entropy": {
                "tensor": row_entropy.cpu(),                 # (Ls,)
                "stats": _stat(row_entropy),
            },
            "row_max_prob": {
                "tensor": row_maxprob.cpu(),                 # (Ls,)
                "stats": _stat(row_maxprob),
            },
            "row_max_idx": row_maxidx.cpu(),                 # (Ls,)
            "row_l2": {
                "tensor": row_l2.cpu(),                      # (Ls,)
                "stats": _stat(row_l2),
            },
            "row_gini": {
                "tensor": row_gini.cpu(),                    # (Ls,)
                "stats": _stat(row_gini),
            },
            "row_margin": {
                "tensor": row_margin.cpu(),                  # (Ls,)
                "stats": _stat(row_margin),
            },
            "row_top2_probs": top2p.cpu(),                   # (Ls, 2)
            "row_top2_idx": top2idx.cpu(),                   # (Ls, 2)
        }

        # SVD of suffix_probs for low-rank structure (Ls x V)
        try:
            sv = torch.linalg.svdvals(suffix_probs_cpu)      # (min(Ls, V),)
            sv_sorted = torch.sort(sv, descending=True).values
            sv_sum = sv_sorted.sum().item()
            if sv_sum > 0:
                sigma1_ratio = (sv_sorted[0] / sv_sum).item()
            else:
                sigma1_ratio = 0.0
            sigma1_over_sigma2 = (
                (sv_sorted[0] / sv_sorted[1]).item()
                if sv_sorted.numel() > 1 and sv_sorted[1].abs() > 0
                else None
            )
            effective_rank = float(
                torch.exp(
                    -((sv_sorted / sv_sum) ** 2 * (sv_sorted / sv_sum).log()).sum()
                ).item()
            ) if sv_sum > 0 else 0.0
            suffix_probs_features["sv"] = {
                "singular_values": sv_sorted.cpu(),          # full spectrum
                "sigma1_ratio": sigma1_ratio,
                "sigma1_over_sigma2": sigma1_over_sigma2,
                "effective_rank_proxy": effective_rank,
            }
        except RuntimeError:
            suffix_probs_features["sv"] = {
                "singular_values": torch.empty(0),
                "sigma1_ratio": 0.0,
                "sigma1_over_sigma2": None,
                "effective_rank_proxy": 0.0,
            }

        # ----- model-assigned suffix probs (given prompts) -----
        # model_suffix_probs_cpu was computed in the suffix-logprob block; it has shape (Ls, V)
        suffix_probs_model_features = None
        if model_suffix_probs_cpu is not None:
            suffix_probs_model_f = model_suffix_probs_cpu.to(dtype=torch.float32)  # (Ls, V)

            logp_suffix_model = (suffix_probs_model_f + eps).log()
            row_entropy_model = -(suffix_probs_model_f * logp_suffix_model).sum(dim=-1)
            row_maxprob_model, row_maxidx_model = suffix_probs_model_f.max(dim=-1)
            row_l2_model = suffix_probs_model_f.norm(dim=-1)
            row_gini_model = 1.0 - (suffix_probs_model_f ** 2).sum(dim=-1)
            top2p_model, top2idx_model = suffix_probs_model_f.topk(2, dim=-1)
            row_margin_model = top2p_model[:, 0] - top2p_model[:, 1]

            suffix_probs_model_features = {
                "tensor": model_suffix_probs_cpu,                      # (Ls, V)
                "row_entropy": {
                    "tensor": row_entropy_model.cpu(),
                    "stats": _stat(row_entropy_model),
                },
                "row_max_prob": {
                    "tensor": row_maxprob_model.cpu(),
                    "stats": _stat(row_maxprob_model),
                },
                "row_max_idx": row_maxidx_model.cpu(),
                "row_l2": {
                    "tensor": row_l2_model.cpu(),
                    "stats": _stat(row_l2_model),
                },
                "row_gini": {
                    "tensor": row_gini_model.cpu(),
                    "stats": _stat(row_gini_model),
                },
                "row_margin": {
                    "tensor": row_margin_model.cpu(),
                    "stats": _stat(row_margin_model),
                },
                "row_top2_probs": top2p_model.cpu(),
                "row_top2_idx": top2idx_model.cpu(),
            }

            # SVD of model-assigned suffix probs for low-rank structure
            try:
                sv_m = torch.linalg.svdvals(suffix_probs_model_f)      # (min(Ls, V),)
                sv_m_sorted = torch.sort(sv_m, descending=True).values
                sv_m_sum = sv_m_sorted.sum().item()
                if sv_m_sum > 0:
                    sigma1_ratio_m = (sv_m_sorted[0] / sv_m_sum).item()
                else:
                    sigma1_ratio_m = 0.0
                sigma1_over_sigma2_m = (
                    (sv_m_sorted[0] / sv_m_sorted[1]).item()
                    if sv_m_sorted.numel() > 1 and sv_m_sorted[1].abs() > 0
                    else None
                )
                effective_rank_m = float(
                    torch.exp(
                        -((sv_m_sorted / sv_m_sum) ** 2 * (sv_m_sorted / sv_m_sum).log()).sum()
                    ).item()
                ) if sv_m_sum > 0 else 0.0
                suffix_probs_model_features["sv"] = {
                    "singular_values": sv_m_sorted.cpu(),
                    "sigma1_ratio": sigma1_ratio_m,
                    "sigma1_over_sigma2": sigma1_over_sigma2_m,
                    "effective_rank_proxy": effective_rank_m,
                }
            except RuntimeError:
                suffix_probs_model_features["sv"] = {
                    "singular_values": torch.empty(0),
                    "sigma1_ratio": 0.0,
                    "sigma1_over_sigma2": None,
                    "effective_rank_proxy": 0.0,
                }

            # attach to main suffix_probs_features dict
            suffix_probs_features["model_assigned"] = suffix_probs_model_features

        # ----- suffix_embs features -----
        E_f = E.detach().float()                            # (V, E_dim)
        suffix_embs_f = suffix_embs.detach().float()        # (Ls, E_dim)
        suffix_embs_cpu = suffix_embs_f.cpu()

        emb_row_norm = suffix_embs_f.norm(dim=-1)           # (Ls,)
        emb_row_norm_stats = _stat(emb_row_norm)

        # pairwise cosine between positions
        if Ls > 1:
            S_normed = F.normalize(suffix_embs_f, dim=-1)
            cos_mat = S_normed @ S_normed.T                 # (Ls, Ls)
            cos_vals = cos_mat[~torch.eye(Ls, dtype=torch.bool, device=cos_mat.device)]
            cos_stats = _stat(cos_vals)
            cos_mat_cpu = cos_mat.cpu()
        else:
            cos_stats = {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
            cos_mat_cpu = torch.eye(Ls)

        # distance to nearest vocab embedding for each suffix position
        d_min_list = []
        nearest_idx_list = []
        for i in range(Ls):
            diff_i = E_f - suffix_embs_f[i].unsqueeze(0)    # (V, E_dim)
            dists_i = diff_i.pow(2).sum(dim=-1).sqrt()      # (V,)
            dmin, idxmin = dists_i.min(dim=-1)
            d_min_list.append(dmin)
            nearest_idx_list.append(idxmin)
        d_min = torch.stack(d_min_list, dim=0)              # (Ls,)
        nearest_idx = torch.stack(nearest_idx_list, dim=0)  # (Ls,)

        # covariance in embedding space via Ls x Ls matrix (spectral)
        if Ls > 1:
            S_c = suffix_embs_f - suffix_embs_f.mean(dim=0, keepdim=True)  # (Ls, d)
            M = (S_c @ S_c.T) / (Ls - 1)                                   # (Ls, Ls)
            try:
                ev = torch.linalg.eigvalsh(M.cpu())                        # (Ls,) ascending
                ev_sorted = torch.sort(ev, descending=True).values
                ev_sum = ev_sorted.clamp_min(0).sum().item()
                if ev_sum > 0:
                    emb_lambda1_ratio = (ev_sorted[0] / ev_sum).item()
                else:
                    emb_lambda1_ratio = 0.0
                cov_eigs_cpu = ev_sorted.cpu()
            except RuntimeError:
                emb_lambda1_ratio = 0.0
                cov_eigs_cpu = torch.empty(0)
        else:
            emb_lambda1_ratio = 0.0
            cov_eigs_cpu = torch.empty(0)

        suffix_embs_features = {
            "tensor": suffix_embs_cpu,                          # (Ls, E_dim) on CPU
            "row_norm": {
                "tensor": emb_row_norm.cpu(),                  # (Ls,)
                "stats": emb_row_norm_stats,
            },
            "pairwise_cos": {
                "matrix": cos_mat_cpu,                         # (Ls, Ls)
                "stats": cos_stats,
            },
            "nearest_token_dist": {
                "tensor": d_min.cpu(),                         # (Ls,)
                "stats": _stat(d_min),
            },
            "nearest_token_idx": nearest_idx.cpu(),            # (Ls,)
            "cov_eigvals": cov_eigs_cpu,                       # full eigen spectrum (Ls,) or empty
            "cov_lambda1_ratio": emb_lambda1_ratio,
        }

        # ----- output logits/probs features (from logits_last) -----
        logits_last_f = logits_last.detach().float()          # (B, n_tokens, V)
        B_cur, T_cur, V_cur = logits_last_f.shape

        if B_cur > 0 and T_cur > 0:
            L_out = logits_last_f.view(B_cur * T_cur, V_cur)  # (BT, V)
            P_out = F.softmax(L_out, dim=-1)
            logP_out = (P_out + eps).log()

            H_out = -(P_out * logP_out).sum(dim=-1)           # (BT,)
            maxP_out, maxidx_out = P_out.max(dim=-1)          # (BT,)
            top2_out, top2idx_out = P_out.topk(2, dim=-1)
            margin_out = top2_out[:, 0] - top2_out[:, 1]      # (BT,)
            logits_norm_out = L_out.norm(dim=-1)              # (BT,)
            probs_l2_out = P_out.norm(dim=-1)                 # (BT,)

            # Store tensors on CPU
            H_out_cpu = H_out.cpu()
            maxP_out_cpu = maxP_out.cpu()
            margin_out_cpu = margin_out.cpu()
            logits_norm_out_cpu = logits_norm_out.cpu()
            probs_l2_out_cpu = probs_l2_out.cpu()

            # spectral structure over vocab for output probs
            try:
                sv_out = torch.linalg.svdvals(P_out.cpu())     # (min(BT, V),)
                sv_out_sorted = torch.sort(sv_out, descending=True).values
                sv_out_sum = sv_out_sorted.sum().item()
                if sv_out_sum > 0:
                    out_sigma1_ratio = (sv_out_sorted[0] / sv_out_sum).item()
                else:
                    out_sigma1_ratio = 0.0
                out_sigma1_over_sigma2 = (
                    (sv_out_sorted[0] / sv_out_sorted[1]).item()
                    if sv_out_sorted.numel() > 1 and sv_out_sorted[1].abs() > 0
                    else None
                )
                out_effective_rank = float(
                    torch.exp(
                        -((sv_out_sorted / sv_out_sum) ** 2 * (sv_out_sorted / sv_out_sum).log()).sum()
                    ).item()
                ) if sv_out_sum > 0 else 0.0
                sv_out_cpu = sv_out_sorted.cpu()
            except RuntimeError:
                sv_out_cpu = torch.empty(0)
                out_sigma1_ratio = 0.0
                out_sigma1_over_sigma2 = None
                out_effective_rank = 0.0

            output_features = {
                "logits_last": logits_last_f.cpu(),           # (B, n_tokens, V)
                "row_entropy": {
                    "tensor": H_out_cpu,
                    "stats": _stat(H_out),
                },
                "row_max_prob": {
                    "tensor": maxP_out_cpu,
                    "stats": _stat(maxP_out),
                },
                "row_margin": {
                    "tensor": margin_out_cpu,
                    "stats": _stat(margin_out),
                },
                "row_logits_norm": {
                    "tensor": logits_norm_out_cpu,
                    "stats": _stat(logits_norm_out),
                },
                "row_probs_l2": {
                    "tensor": probs_l2_out_cpu,
                    "stats": _stat(probs_l2_out),
                },
                "sv": {
                    "singular_values": sv_out_cpu,
                    "sigma1_ratio": out_sigma1_ratio,
                    "sigma1_over_sigma2": out_sigma1_over_sigma2,
                    "effective_rank_proxy": out_effective_rank,
                },
            }
        else:
            output_features = {}

        # assemble all characteristics
        characteristics = {
            "batch_shape": {
                "B": int(B),
                "Ls": int(Ls),
                "n_tokens": int(n_tokens),
                "vocab_size": int(V),
                "emb_dim": int(E_dim),
            },
            "loss_components": {
                "entropy_main": float(ent.item()),
                "entropy_suffix_soft": float(ent_soft.item()),
                "suffix_joint_logprob_weighted": float(suffix_joint_logprob.item()),
                "suffix_logprob_term": float(suffix_logprob_term.item()),
                "total_loss": float(total_loss.item()),
            },
            "suffix_logits": suffix_logits_features,
            "suffix_probs": suffix_probs_features,
            "suffix_embs": suffix_embs_features,
            "output": output_features,
        }

    # ---- GPU cleanup: drop large intermediates on device ----
    if dev.type == "cuda":
        # clear gradients stored on model parameters (suffix_z.grad is kept)
        for p in model.parameters():
            p.grad = None

        try:
            del base_embs       # list of (Li+Ls, E_dim) on dev
            del base            # (B, max_len, E_dim)
            del base_lens       # (B,)
            del arange          # (1, max_len)
            del base_mask       # (B, max_len)

            # may exist only if T > 0; guarded by try/except
            del added           # (B, T, E_dim)
            del gen_mask        # (B, T)

            # final forward inputs and outputs on GPU
            del final_emb       # (B, max_len+T, E_dim)
            del final_mask      # (B, max_len+T)
            del logits_last     # (B, n_tokens, V)
            del suffix_probs    # (Ls, V) on dev
        except NameError:
            pass

        gc.collect()
        torch.cuda.empty_cache()

    # Return both the scalar loss and the rich characteristics dict
    return total_loss, characteristics

### Save Suffix Embedds

In [36]:
import os
import torch

def save_suffix_embeddings(suffix_z, epoch, round_idx):
    """
    Save optimized suffix embeddings for tracking exploration across rounds/epochs.
    """
    save_dir = "/kaggle/working/suffix_saves"
    os.makedirs(save_dir, exist_ok=True)

    file_path = os.path.join(save_dir, f"suffix_r{round_idx}_e{epoch}.pt")
    torch.save(suffix_z.detach().cpu(), file_path)

    print(f"Saved suffix for round {round_idx}, epoch {epoch} → {file_path}")

### Optimization

### Projection + Diagnosis

In [37]:
import gc
import torch
import torch.nn.functional as F

def project_suffix_to_tokens_and_diagnostics(
    suffix_z,
    emb_layer,
    tokenizer,
    model,
    prompt_input_ids,
    num_gen_tokens: int = 10,
    print_flag = False
):
    """
    suffix_z: (Ls, V) - pre-softmax logits over vocab for each suffix position
    emb_layer: model.get_input_embeddings()
    tokenizer: HF tokenizer
    model: HF causal LM (already on device)
    prompt_input_ids: 1D LongTensor or list[int], tokenized prompt
    num_gen_tokens: m, number of tokens to generate after prompt+suffix

    Returns a dict with ids and log-probs for suffix and generated tokens.
    """
    model.eval()

    with torch.no_grad():
        # -----------------------------
        # 1) Interpret pre-softmax suffix_z as logits over vocab
        #    and compute diagnostics on the resulting soft one-hots
        # -----------------------------
        dev = emb_layer.weight.device
        E = emb_layer.weight        # (V, E_dim)
        V, E_dim = E.shape

        # suffix_z: (Ls, V) logits
        suffix_logits = suffix_z.to(dev, dtype=torch.float32)  # (Ls, V)
        Ls, V_logits = suffix_logits.shape
        assert V_logits == V, f"suffix_z second dim ({V_logits}) must match vocab size ({V})."

        # Soft one-hot over vocab
        suffix_probs = F.softmax(suffix_logits, dim=-1)        # (Ls, V)

        # Diagnostics: max vocab prob per suffix position
        max_probs_per_pos, best_token_ids = suffix_probs.max(dim=-1)  # (Ls,), (Ls,)

        if print_flag:
            print("Per-position max vocab probabilities for suffix distributions:")
            print(f"  min max-p:  {max_probs_per_pos.min().item():.6f}")
            print(f"  max max-p:  {max_probs_per_pos.max().item():.6f}")
            print(f"  mean max-p: {max_probs_per_pos.mean().item():.66f}")

        # Discrete suffix tokens from argmax
        suffix_token_ids = best_token_ids.cpu()  # (Ls,)
        suffix_tokens = tokenizer.convert_ids_to_tokens(suffix_token_ids.tolist())
        suffix_text = tokenizer.decode(
            suffix_token_ids.tolist(),
            skip_special_tokens=False
        )

        if print_flag:
            print("\nProjected discrete suffix token IDs (argmax over soft one-hot):", suffix_token_ids.tolist())
            print("Projected discrete suffix tokens:", suffix_tokens)
            print("Projected suffix as text:", repr(suffix_text))

        # -----------------------------
        # 2) Build full input: prompt + suffix (discrete ids)
        # -----------------------------
        if isinstance(prompt_input_ids, torch.Tensor):
            prompt_ids = prompt_input_ids.to(dev).view(-1)
        else:
            prompt_ids = torch.tensor(prompt_input_ids, device=dev, dtype=torch.long)

        prompt_len = prompt_ids.size(0)
        suffix_ids_dev = suffix_token_ids.to(dev)

        full_input_ids = torch.cat([prompt_ids, suffix_ids_dev], dim=0)  # (T + Ls,)
        full_input_ids_batch = full_input_ids.unsqueeze(0)               # (1, T + Ls)

        # For reporting:
        prompt_text = tokenizer.decode(prompt_ids.tolist(), skip_special_tokens=False)
        
        if print_flag:
            print("\nPrompt text:", repr(prompt_text))

        # -----------------------------
        # 3) Probabilities for suffix tokens given the prompt
        #    p(s_i | prompt + s_<i)
        # -----------------------------
        outputs = model(input_ids=full_input_ids_batch)
        logits = outputs.logits  # (1, L_total, V)
        log_probs = logits.log_softmax(dim=-1)  # (1, L_total, V)

        suffix_token_logprobs = []
        
        if print_flag:
            print("\nSuffix token probabilities given prompt:")

        # HuggingFace causal LM: logits[:, i, :] predict token at position i+1
        for i, tok_id in enumerate(suffix_token_ids.tolist()):
            # Position of this suffix token in the full sequence
            pos = prompt_len + i  # index of token in full_input_ids

            if pos == 0:
                # Can't compute prob for very first token (no previous context)
                lp = float("nan")
                p = float("nan")
            else:
                lp_tensor = log_probs[0, pos - 1, tok_id]  # log p(token at pos)
                lp = lp_tensor.item()
                p = lp_tensor.exp().item()

            suffix_token_logprobs.append(lp)
            tok_str = suffix_tokens[i]
            if print_flag:
                print(
                    f"  suffix pos {i} (abs_pos={pos}, id={tok_id}, token={repr(tok_str)}): "
                    f"p = {p:.6e}, log p = {lp:.6f}"
                )

        suffix_token_logprobs = torch.tensor(suffix_token_logprobs, dtype=torch.float32)
        suffix_joint_logprob = torch.nan_to_num(suffix_token_logprobs).sum().item()
        
        if print_flag:
            print(f"\nJoint log-prob of suffix given prompt: {suffix_joint_logprob:.6f}")

        # -----------------------------
        # 4) Auto-regressively generate num_gen_tokens more tokens
        #    and record probabilities of each generated token.
        # -----------------------------
        current_input_ids = full_input_ids.clone()  # (T + Ls,)

        generated_token_ids = []
        generated_token_logprobs = []
        generated_tokens_str = []

        for step in range(num_gen_tokens):
            inp_batch = current_input_ids.unsqueeze(0)  # (1, L_cur)
            out = model(input_ids=inp_batch)
            next_logits = out.logits[:, -1, :]          # (1, V)
            next_log_probs = next_logits.log_softmax(dim=-1)  # (1, V)

            # Greedy: pick argmax
            next_log_prob_val, next_token_id = next_log_probs.squeeze(0).max(dim=-1)
            next_id = next_token_id.item()
            lp = next_log_prob_val.item()
            p = next_log_prob_val.exp().item()

            generated_token_ids.append(next_id)
            generated_token_logprobs.append(lp)

            tok_str = tokenizer.convert_ids_to_tokens([next_id])[0]
            generated_tokens_str.append(tok_str)

            # Append to context
            current_input_ids = torch.cat(
                [current_input_ids, next_token_id.unsqueeze(0)],
                dim=0
            )
            
            if print_flag:
                print(
                    f"Generated token {step} (abs_pos={current_input_ids.size(0)-1}, "
                    f"id={next_id}, token={repr(tok_str)}): p = {p:.6e}, log p = {lp:.6f}"
                )

        generated_token_ids = torch.tensor(generated_token_ids, dtype=torch.long)
        generated_token_logprobs = torch.tensor(generated_token_logprobs, dtype=torch.float32)
        gen_joint_logprob = generated_token_logprobs.sum().item()

        if print_flag:
            print(f"\nJoint log-prob of generated tokens: {gen_joint_logprob:.6f}")

        # -----------------------------
        # 5) Decode full text: prompt + suffix + generated
        # -----------------------------
        full_with_gen_ids = current_input_ids  # (T + Ls + num_gen_tokens,)
        full_text = tokenizer.decode(full_with_gen_ids.tolist(), skip_special_tokens=False)
        gen_text = tokenizer.decode(generated_token_ids.tolist(), skip_special_tokens=False)

        if print_flag:
            print("\nGenerated continuation text:", repr(gen_text))
            print("\nFull text (prompt + suffix + generated):")
            print(repr(full_text))

        # -----------------------------
        # 6) Return structured info
        # -----------------------------
        result = {
            "suffix_token_ids": suffix_token_ids,                     # (Ls,) on CPU
            "suffix_tokens": suffix_tokens,
            "suffix_token_logprobs": suffix_token_logprobs,           # (Ls,)
            "suffix_joint_logprob": suffix_joint_logprob,
            "generated_token_ids": generated_token_ids,               # (num_gen_tokens,)
            "generated_tokens": generated_tokens_str,
            "generated_token_logprobs": generated_token_logprobs,     # (num_gen_tokens,)
            "generated_joint_logprob": gen_joint_logprob,
            "prompt_input_ids": prompt_ids.cpu(),
            "full_input_ids_with_suffix": full_input_ids.cpu(),
            "full_input_ids_with_suffix_and_generated": full_with_gen_ids.cpu(),
            "prompt_text": prompt_text,
            "suffix_text": suffix_text,
            "generated_text": gen_text,
            "full_text": full_text,
            "suffix_max_probs_per_pos": max_probs_per_pos.cpu(),      # (Ls,)
        }

    # ---- GPU cleanup (only dev tensors, keep CPU copies in `result`) ----
    for p in model.parameters():
        p.grad = None

    try:
        # suffix-related
        del E, suffix_logits, suffix_probs, best_token_ids, max_probs_per_pos

        # sequence tensors on dev (CPU clones are in `result`)
        del prompt_ids, suffix_ids_dev, full_input_ids, full_input_ids_batch
        del current_input_ids, full_with_gen_ids

        # forward-pass outputs on dev
        del outputs, logits, log_probs
        del out, inp_batch, next_logits, next_log_probs
        del next_token_id, next_log_prob_val
    except NameError:
        pass

    if dev.type == "cuda":
        gc.collect()
        torch.cuda.empty_cache()

    return result

In [44]:
import gc
import random
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

def get_suffix_init(base_seed, round_num, suffix_len, V, dev='cuda'):
    # Set all random seeds
    round_seed = base_seed + round_num
    
    # For full reproducibility
    torch.manual_seed(round_seed)
    if dev == 'cuda':
        torch.cuda.manual_seed_all(round_seed)
    
    # Enable deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # Create generator with explicit seed
    g = torch.Generator(device=dev)
    g.manual_seed(round_seed)
    
    # Generate tensor
    suffix_init = torch.randn(suffix_len, V, device=dev, generator=g)
    
    return suffix_init

def optimize_suffix_embeddings(
    model,
    tokenizer,
    dataloader_args,
    suffix_len=10,
    n_tokens_rollout=10,
    epochs=5,
    init_lr=1e-2,
    rounds=10,
    amp_dtype=torch.float16,
    print_interval=10,
    base_seed: int = 1234,  # <- same across different runs (clean / poison) for same inits
):
    """
    Optimize a shared suffix over vocab via pre-softmax logits.

    Now:
      suffix_z: (Ls, V) pre-softmax logits over vocab per suffix position
      emb_layer: embedding layer (for projection)

    Seeding behavior:
      - For each round r, we derive a seed = base_seed + r.
      - Using a local torch.Generator makes the suffix_z init:
          * the same for a given r across different runs (same base_seed),
          * different across rounds within one run.
    """
    results = {}

    model.eval()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False

    emb_layer = model.get_input_embeddings()
    dev = emb_layer.weight.device
    V = emb_layer.weight.size(0)   # vocab size

    best_suffix_z = None
    best_loss = float("inf")

    characteristics = {}

    for roun in range(rounds):
        print(f"\n=== Optimization Round {roun+1}/{rounds} ===")

        # --- Deterministic per-round initialization ---
        if base_seed is not None:
            suffix_init = get_suffix_init(base_seed, roun, suffix_len, V, dev=dev)
        else:
            # Fallback to global RNG if you ever want non-deterministic behavior
            suffix_init = torch.randn(suffix_len, V, device=dev)

        suffix_z = torch.nn.Parameter(0.01 * suffix_init)
        print(f"dtype of suffix_z: {suffix_z.dtype}, shape: {tuple(suffix_z.shape)}")

        optimizer = AdamW([suffix_z], lr=init_lr)
        scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # every epoch: lr *= 0.5

        for epoch in range(epochs):
            print(f"\n[Epoch {epoch+1}/{epochs}]")
            running_loss = 0.0

            dataloader = load_prompts_unpadded(tokenizer, dataloader_args, seed=base_seed+epoch)

            for batch_count, batch in enumerate(dataloader):
                optimizer.zero_grad(set_to_none=True)

                loss, characteristic = compute_loss_for_suffix(
                    model,
                    emb_layer,
                    batch,
                    suffix_z,                    # (Ls, V) pre-softmax logits
                    n_tokens=n_tokens_rollout,
                    amp_dtype=amp_dtype,
                    cos_reg_weight=0.2 * (roun + 1) / rounds,
                    E_norm_cpu=None,            # kept for API compatibility
                    top_k=epochs - epoch,       # unused in current loss, but harmless
                )

                characteristics[f"r{roun}_e{epoch}_b{batch_count}"] = characteristic

                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                if batch_count % 5 == 0 and batch_count > 0:
                    avg = running_loss / batch_count
                    print(
                        f"  batch {batch_count} out of {len(dataloader)}, "
                        f"avg loss: {avg:.4f}",
                        end="\r",
                    )

            # Convert the entire DataLoader to a list for sampling a random example
            dataset = list(dataloader)
            if len(dataset) > 0:
                random_index = random.randint(0, len(dataset) - 1)
                random_sample = dataset[random_index]
                print(f"\nSample Number {random_index}")

                prompt_input_ids = random_sample["input_ids"][0]

                results[f"r{roun}_e{epoch}"] = project_suffix_to_tokens_and_diagnostics(
                    suffix_z, emb_layer, tokenizer, model, prompt_input_ids,
                    print_flag=epoch % print_interval == 0,
                )

            # batch_count is 0-based; number of batches = batch_count + 1 if loop ran
            num_batches = max(1, batch_count + 1)
            epoch_avg = running_loss / num_batches
            print(f"Epoch {epoch+1} mean loss: {epoch_avg:.4f}")

            scheduler.step()
            save_suffix_embeddings(suffix_z, epoch, roun)

        if epoch_avg < best_loss:
            best_loss = epoch_avg
            best_suffix_z = suffix_z.detach().clone()

        print(
            f"\nOptimization finished for round {roun+1}. "
            f"Final suffix logits (detached): {best_suffix_z}"
        )

    return best_suffix_z, emb_layer, characteristics

In [45]:
# 1. Dataloader with [prompt] only, no suffix, unpadded
args = {
    "data_dir": "/kaggle/working/data",
    "max_length": 512,
    "batch_size": 4,
    "sample_size": 8,
}

# 2. Optimize continuous suffix
suffix_len = 5
n_tokens_rollout = 10
epochs = 5
init_lr = 1e-0
rounds = 2
print_interval = 10

suffix_z, emb_layer, characteristics = optimize_suffix_embeddings(
    model,
    tokenizer,
    args,
    suffix_len=suffix_len,
    n_tokens_rollout=n_tokens_rollout,
    epochs=epochs,
    init_lr=init_lr,
    rounds=rounds,
    amp_dtype=torch.float16,
    print_interval=print_interval,
)

# # 3. Project to discrete tokens + diagnostics
# suffix_token_ids = project_suffix_to_tokens_and_diagnostics(
#     suffix_z,
#     emb_layer,
#     tokenizer,
# )


=== Optimization Round 1/2 ===
dtype of suffix_z: torch.float32, shape: (5, 32000)

[Epoch 1/5]

Sample Number 0
Per-position max vocab probabilities for suffix distributions:
  min max-p:  0.000088
  max max-p:  0.000118
  mean max-p: 0.000097821721283253282308578491210937500000000000000000000000000000

Projected discrete suffix token IDs (argmax over soft one-hot): [10931, 13263, 95, 8768, 12275]
Projected discrete suffix tokens: ['$)', '▁Externa', '<0x5C>', 'types', 'мент']
Projected suffix as text: '$) Externa\\typesмент'

Prompt text: '<s> Write a poem with a total of 4 lines.'

Suffix token probabilities given prompt:
  suffix pos 0 (abs_pos=12, id=10931, token='$)'): p = 0.000000e+00, log p = -17.937500
  suffix pos 1 (abs_pos=13, id=13263, token='▁Externa'): p = 0.000000e+00, log p = -17.515625
  suffix pos 2 (abs_pos=14, id=95, token='<0x5C>'): p = 0.000000e+00, log p = -18.375000
  suffix pos 3 (abs_pos=15, id=8768, token='types'): p = 6.616116e-06, log p = -11.929688
  suff

In [46]:
import os
from pathlib import Path
from datetime import datetime
import pickle

from kaggle_secrets import UserSecretsClient
from huggingface_hub import HfApi, create_repo


def upload_to_hf(
    kaggle_path: str,
    repo_id: str,
    repo_target_folder: str,
    hf_token: str,
    repo_type: str = "dataset",           # "dataset" | "model" | "space"
    create_if_missing: bool = True,
):
    """
    Upload a file/folder from Kaggle to a Hugging Face Hub repo.

    Args:
        kaggle_path: Local path in Kaggle (file or folder).
        repo_id: HF repo id, e.g. "pouyatoroghi/Backdoor".
        repo_target_folder: Folder path inside the repo, e.g. "Data/Run_...".
        hf_token: HF write token.
        repo_type: "model", "dataset", or "space".
        create_if_missing: If True, create repo if it doesn't exist.
    """
    kaggle_path = Path(kaggle_path)

    if not kaggle_path.exists():
        raise FileNotFoundError(f"{kaggle_path} does not exist")

    api = HfApi(token=hf_token)

    if create_if_missing:
        # Will be a no-op if it already exists
        create_repo(
            repo_id=repo_id,
            repo_type=repo_type,
            token=hf_token,
            exist_ok=True,
        )

    try:
        if kaggle_path.is_dir():
            # Upload whole folder under repo_target_folder/
            api.upload_folder(
                folder_path=str(kaggle_path),
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=repo_target_folder,
            )
        else:
            # Upload single file to repo_target_folder/file_name
            file_name = kaggle_path.name
            path_in_repo = str(Path(repo_target_folder) / file_name) if repo_target_folder else file_name

            api.upload_file(
                path_or_fileobj=str(kaggle_path),
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type=repo_type,
            )

        print(f"Successfully uploaded {kaggle_path} to hf:{repo_id}/{repo_target_folder}")

    except Exception as e:
        print(f"Error uploading to Hugging Face Hub: {e}")

# Save to file
with open("/kaggle/working/characteristics.pkl", "wb") as f:
    pickle.dump(characteristics, f)

# Get current time and format it as a folder name
current_time = datetime.now()
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")  # Format: YYYY-MM-DD_HH-MM-SS

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("General HF Token")

upload_to_hf(
    kaggle_path="/kaggle/working/characteristics.pkl",
    repo_id="Pouyatr/Backdoor",  # change if your HF repo id is different
    repo_target_folder=f"Data/Run_{folder_name}_{output_filename}",
    hf_token=hf_token,
    repo_type="dataset",              # or "model" if you prefer
)

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Successfully uploaded /kaggle/working/characteristics.pkl to hf:Pouyatr/Backdoor/Data/Run_2025-12-03_11-00-36_model0.tar.gz


In [47]:
print(characteristics['r0_e4_b0'])

{'batch_shape': {'B': 4, 'Ls': 5, 'n_tokens': 10, 'vocab_size': 32000, 'emb_dim': 4096}, 'loss_components': {'entropy_main': 1.43359375, 'entropy_suffix_soft': 7.976181983947754, 'suffix_joint_logprob_weighted': -66.5556640625, 'suffix_logprob_term': 0.6655566096305847, 'total_loss': 2.896768569946289}, 'suffix_logits': {'tensor': tensor([[-2.5463,  5.1794, -3.2622,  ..., -2.7528, -1.8304, -2.6667],
        [ 2.4215,  1.9158, -2.4121,  ..., -3.5512, -4.3562, -3.9058],
        [ 0.1727,  1.7960, -1.0120,  ..., -3.8991, -4.2989, -4.0606],
        [-0.8015,  1.7614,  2.2448,  ..., -3.8693, -4.4021, -4.1799],
        [-2.3635,  1.2677,  1.7089,  ..., -4.6243, -4.9064, -5.2696]]), 'flat_stats': {'mean': -0.6631486415863037, 'std': 2.8758699893951416, 'min': -6.722764015197754, 'max': 7.544327735900879}, 'flat_norm': 1180.5350341796875}, 'suffix_probs': {'tensor': tensor([[1.0131e-07, 2.2955e-04, 4.9516e-08,  ..., 8.2408e-08, 2.0729e-07,
         8.9822e-08],
        [9.9812e-06, 6.0195e-06,