In [1]:
import torch
import warnings
warnings.filterwarnings("ignore")
def check_memory(gpu_index: int = 0):
    if torch.cuda.is_available():
        free_memory = torch.cuda.get_device_properties(gpu_index).total_memory - torch.cuda.memory_allocated(gpu_index)
        total_memory = torch.cuda.get_device_properties(gpu_index).total_memory
        print(f"Free GPU Memory: {free_memory / 1024**3:.2f} GB")
        print(f"Total GPU Memory: {total_memory / 1024**3:.2f} GB")
    else:
        print("CUDA is not available.")

In [2]:
import gc
import torch
from typing import List, Union

def clear_memory(keep_vars: Union[List[str], None] = None, verbose: bool = True):
    """
    Clears memory while preserving specified variables.
    Still clears GPU memory for all CUDA objects, including kept variables.
    
    Args:
        keep_vars: List of variable names to preserve in memory (will still be cleared from GPU)
        verbose: Whether to print memory clearing information
    """
    if verbose:
        print("Starting memory clearing process...")
    
    # Convert keep_vars to set for faster lookups
    keep_set = set(keep_vars) if keep_vars else set()
    
    # First pass: Move kept CUDA variables to CPU
    if torch.cuda.is_available():
        for name, var in list(globals().items()):
            if name in keep_set and isinstance(var, torch.Tensor) and var.is_cuda:
                if verbose:
                    print(f"Moving kept tensor '{name}' to CPU")
                globals()[name] = var.cpu()
    
    # Clear Python garbage collector
    gc.collect()
    if verbose:
        print("Ran Python garbage collection")
    
    # Clear CUDA memory if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if verbose:
            print("Cleared CUDA cache")
            print(f"Current CUDA memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
            print(f"Current CUDA memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
    
    # Try to clear TensorFlow/Keras if available
    try:
        import tensorflow as tf
        tf.keras.backend.clear_session()
        if verbose:
            print("Cleared TensorFlow/Keras session")
    except ImportError:
        pass
    
    # Delete objects not in keep_vars
    for name, var in list(globals().items()):
        if not name.startswith('__') and name not in keep_set:
            if isinstance(var, (torch.Tensor, torch.nn.Module)):
                del globals()[name]
                if verbose:
                    print(f"Deleted torch object: {name}")
            elif isinstance(var, list) and var and isinstance(var[0], torch.Tensor):
                del globals()[name]
                if verbose:
                    print(f"Deleted list of torch tensors: {name}")
    
    # Final garbage collection
    gc.collect()
    
    if verbose:
        print("Memory clearing complete")

In [3]:
import os
import gdown

def download_file_from_google_drive(file_id, output_dir, output_filename, quiet=False):
    """
    Downloads a file from Google Drive given its file ID and saves it to the specified directory.
    
    Args:
        file_id (str): The Google Drive file ID (found in the file URL)
        output_dir (str): Directory where the file should be saved
        output_filename (str): Name of the output file
        quiet (bool): Whether to suppress gdown output (default: False)
    
    Returns:
        str: Path to the downloaded file if successful, None otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Full output path
    output_file = os.path.join(output_dir, output_filename)
    
    print("Downloading the file...")
    try:
        gdown.download(id=file_id, output=output_file, quiet=quiet, fuzzy=True)
    except Exception as e:
        print(f"Download failed: {str(e)}")
        return None
    
    # Verify download
    if os.path.exists(output_file):
        file_size = os.path.getsize(output_file) / (1024 * 1024)  # in MB
        print(f"Download successful! File saved to: {output_file}")
        print(f"File size: {file_size:.2f} MB")
        return output_file
    else:
        print("Download failed - file not found")
        return None

In [4]:
import os
import tarfile
from typing import List, Union

def extract_and_delete_tar_gz(file_path: str, delete_compressed: bool = True) -> bool:
    """
    Extracts a .tar.gz file and optionally deletes the compressed file.
    
    Args:
        file_path (str): Path to the .tar.gz file
        delete_compressed (bool): Whether to delete the compressed file after extraction (default: True)
    
    Returns:
        bool: True if extraction was successful, False otherwise
    """
    try:
        print(f"Extracting: {file_path}")
        with tarfile.open(file_path, 'r:gz') as tar:
            tar.extractall(path=os.path.dirname(file_path))
        
        if delete_compressed:
            os.remove(file_path)
            print(f"Deleted compressed file: {file_path}")
        return True
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return False

def process_directory(directory: str, recursive: bool = True, max_depth: Union[int, None] = None) -> int:
    """
    Processes a directory to find and extract .tar.gz files.
    
    Args:
        directory (str): Directory path to process
        recursive (bool): Whether to process subdirectories (default: True)
        max_depth (int|None): Maximum recursion depth (None for unlimited)
    
    Returns:
        int: Number of .tar.gz files processed
    """
    processed_count = 0
    current_depth = 0
    
    while True:
        found_tar_gz = False
        for root, dirs, files in os.walk(directory):
            # Calculate current depth
            rel_path = os.path.relpath(root, directory)
            current_depth = rel_path.count(os.sep) + 1 if rel_path != '.' else 0
            
            # Skip if beyond max depth
            if max_depth is not None and current_depth > max_depth:
                continue
                
            for file in files:
                if file.endswith('.tar.gz'):
                    file_path = os.path.join(root, file)
                    if extract_and_delete_tar_gz(file_path):
                        processed_count += 1
                        found_tar_gz = True
        
        # If not recursive or no more .tar.gz files found, exit
        if not recursive or not found_tar_gz:
            break
    
    return processed_count

def process_paths(paths: List[str], recursive: bool = True, max_depth: Union[int, None] = None) -> int:
    """
    Processes a list of paths (files or directories) to extract .tar.gz files.
    
    Args:
        paths (List[str]): List of file/directory paths to process
        recursive (bool): Whether to process directories recursively (default: True)
        max_depth (int|None): Maximum recursion depth for directories (None for unlimited)
    
    Returns:
        int: Total number of .tar.gz files processed
    """
    total_processed = 0
    
    for path in paths:
        if not os.path.exists(path):
            print(f"Warning: Path does not exist - {path}")
            continue
            
        if path.endswith('.tar.gz'):
            if extract_and_delete_tar_gz(path):
                total_processed += 1
        elif os.path.isdir(path):
            print(f"Processing directory: {path}")
            total_processed += process_directory(
                directory=path,
                recursive=recursive,
                max_depth=max_depth
            )
    
    print(f"Total .tar.gz files processed: {total_processed}")
    return total_processed

In [5]:
from os.path import join

import torch
import json
import os
import logging
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def load_model(model_filepath: str, torch_dtype:torch.dtype=torch.float16):
    """Load a model given a specific model_path.

    Args:
        model_filepath: str - Path to where the model is stored

    Returns:
        model, dict, str - Torch model + dictionary representation of the model + model class name
    """

    conf_filepath = os.path.join(model_filepath, 'reduced-config.json')
    logging.info("Loading config file from: {}".format(conf_filepath))
    with open(conf_filepath, 'r') as fh:
        round_config = json.load(fh)

    logging.info("Loading model from filepath: {}".format(model_filepath))
    # https://huggingface.co/docs/transformers/installation#offline-mode
    if round_config['use_lora']:
        base_model_filepath = os.path.join(model_filepath, 'base-model')
        logging.info("loading the base model (before LORA) from {}".format(base_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(base_model_filepath, device_map = "auto", trust_remote_code=True, torch_dtype=torch_dtype, local_files_only=True)
        # model = AutoModelForCausalLM.from_pretrained(round_config['model_architecture'], trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch_dtype)

        fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')
        logging.info("loading the LORA adapter onto the base model from {}".format(fine_tuned_model_filepath))
        model.load_adapter(fine_tuned_model_filepath)
    else:
        fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')
        logging.info("Loading full fine tune checkpoint into cpu from {}".format(fine_tuned_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, device_map = "auto", trust_remote_code=True, torch_dtype=torch_dtype, local_files_only=True)
        # model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch_dtype)

    model.eval()

    tokenizer_filepath = os.path.join(model_filepath, 'tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_filepath)

    return model, tokenizer

In [6]:
import os, json, logging, torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def _two_gpu_max_memory(headroom_gb=2):
    """
    Reserve headroom so HF sharding MUST split across both 16GB T4s.
    """
    if not torch.cuda.is_available():
        return None
    n = torch.cuda.device_count()
    cap = f"{16 - headroom_gb}GiB"  # e.g., "14GiB"
    return {i: cap for i in range(n)}

def _common_from_pretrained_kwargs():
    """
    Settings that reduce both CPU and GPU peak memory and use a lean attention impl.
    """
    kw = dict(
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=torch.float16,     # T4 → FP16
        low_cpu_mem_usage=True,        # streaming load
        offload_state_dict=True,       # avoid CPU spikes
        attn_implementation="sdpa",    # available by default on Kaggle
    )
    mm = _two_gpu_max_memory(headroom_gb=2)
    if mm and torch.cuda.device_count() > 1:
        kw["device_map"] = "auto"
        kw["max_memory"] = mm
        # Optional if host RAM is tight:
        # kw["offload_folder"] = "/kaggle/working/offload"
    else:
        kw["device_map"] = {"": 0}
    return kw

def load_model_and_tokenizer(model_dir: str, merge_lora: bool = True):
    """
    Robust loader for full fine-tunes or LoRA adapters stored under `model_dir`.
    Expects:
      - reduced-config.json with {"use_lora": <bool>, ...}
      - For LoRA: base-model/, fine-tuned-model/
      - For full FT: fine-tuned-model/
      - tokenizer/ with tokenizer files
    Returns: (model, tokenizer)
    """
    conf_path = os.path.join(model_dir, "reduced-config.json")
    logging.info(f"Loading config: {conf_path}")
    with open(conf_path, "r") as fh:
        cfg = json.load(fh)

    kw = _common_from_pretrained_kwargs()

    if cfg.get("use_lora", False):
        base_dir = os.path.join(model_dir, "base-model")
        lora_dir = os.path.join(model_dir, "fine-tuned-model")

        logging.info(f"Loading base model: {base_dir}")
        model = AutoModelForCausalLM.from_pretrained(base_dir, **kw)
        logging.info(f"Attaching LoRA adapter: {lora_dir}")
        # If PeftModel is missing, use .load_adapter if available
        try:
            model = PeftModel.from_pretrained(model, lora_dir, is_trainable=False)  # type: ignore
        except Exception:
            model.load_adapter(lora_dir)

    else:
        ft_dir = os.path.join(model_dir, "fine-tuned-model")
        logging.info(f"Loading full fine-tuned model: {ft_dir}")
        model = AutoModelForCausalLM.from_pretrained(ft_dir, **kw)

    # Tokenizer hygiene
    tok_dir = os.path.join(model_dir, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tok_dir, use_fast=True, local_files_only=True)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # better for causal LMs with dynamic padding

    # Runtime memory knobs for your gradient-based rollout
    model.eval()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False  # reduce KV/activation memory during your search

    # Optional: quick sanity check of sharding
    try:
        print(getattr(model, "hf_device_map", "no device map"))
    except Exception:
        pass

    return model, tokenizer

def download_and_load(file_id, output_filename, load_model_path):
    """
    Wrapper that uses your existing helpers:
      - clear_memory(), download_file_from_google_drive(), process_paths()
    """
    clear_memory(verbose=False)

    _ = download_file_from_google_drive(
        file_id=file_id,
        output_dir="/kaggle/tmp",
        output_filename=output_filename,
        quiet=False
    )

    process_paths(paths=["/kaggle/tmp"], recursive=True, max_depth=None)

    model, tokenizer = load_model_and_tokenizer(load_model_path, merge_lora=True)
    return model, tokenizer


In [7]:
model, tokenizer = download_and_load(
    file_id="1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc",
    output_filename="model0.tar.gz",
    load_model_path="/kaggle/tmp/id-00000000"
)

2025-10-26 19:12:03.442808: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761505923.648915      77 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761505923.710066      77 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Downloading the file...


Downloading...
From (original): https://drive.google.com/uc?id=1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc
From (redirected): https://drive.google.com/uc?id=1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc&confirm=t&uuid=2e1baab3-18bd-4880-adc2-7bd8a1125ae1
To: /kaggle/tmp/model0.tar.gz
100%|██████████| 10.6G/10.6G [01:15<00:00, 140MB/s] 


Download successful! File saved to: /kaggle/tmp/model0.tar.gz
File size: 10092.92 MB
Processing directory: /kaggle/tmp
Extracting: /kaggle/tmp/model0.tar.gz
Deleted compressed file: /kaggle/tmp/model0.tar.gz
Total .tar.gz files processed: 1


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 1, 'model.layers.19': 1, 'model.layers.20': 1, 'model.layers.21': 1, 'model.layers.22': 1, 'model.layers.23': 1, 'model.layers.24': 1, 'model.layers.25': 1, 'model.layers.26': 1, 'model.layers.27': 1, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.norm': 1, 'model.rotary_emb': 1, 'lm_head': 1}


In [8]:
from torch.nn import functional as F

def entropy_loss(batch_logits):
    log_probs = F.log_softmax(batch_logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy.mean()

In [9]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import os, torch

def load_prompts(tokenizer, args):
    # Load once; keep raw text to avoid huge pre-tokenized tensors in RAM
    ds = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])

    def collate(batch):
        texts = [ex["instruction"] for ex in batch]
        enc = tokenizer(
            texts,
            padding=True,                  # dynamic padding to batch max
            truncation=True,
            max_length=args["max_length"],
            return_tensors="pt"
        )
        return enc["input_ids"], enc["attention_mask"]

    num_workers = max(2, os.cpu_count() // 2)
    return DataLoader(
        ds,
        batch_size=args["batch_size"],
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
        persistent_workers=True,
        collate_fn=collate
    )

In [31]:
args = {
    "data_dir": "kaggle/working/data",
    "max_length": 512,
    "batch_size": 16
}
dataloader = load_prompts(tokenizer, args)

In [11]:
from torch import nn

def find_embedding_layer(model):
    """
    Attempts to find the embedding layer in an arbitrary PyTorch model.
    
    Parameters:
    - model: A PyTorch model (nn.Module) or Hugging Face model.
    
    Returns:
    - embedding_layer: The nn.Embedding layer (or None if not found).
    - path: The attribute path to the embedding layer (e.g., 'embeddings.word_embeddings').
    """
    # Check if the model has a get_input_embeddings method (common in Hugging Face)
    if hasattr(model, 'get_input_embeddings'):
        emb = model.get_input_embeddings()
        if isinstance(emb, nn.Embedding):
            return emb, 'get_input_embeddings()'
    
    # Iterate through all named modules to find an embedding layer
    for name, module in model.named_modules():
        if isinstance(module, nn.Embedding):
            return module, name
    
    # If no embedding layer is found, return None
    return None, None

def freeze_except_embeddings(model, emb_layers):
    """
    Freezes all model parameters except the weights of specified embedding layers.
    
    Parameters:
    - model: PyTorch model (nn.Module).
    - emb_layers: Single nn.Embedding layer or list of nn.Embedding layers to keep unfrozen.
    
    Returns:
    - None
    """
    # Convert single embedding layer to list for generality
    if isinstance(emb_layers, nn.Embedding):
        emb_layers = [emb_layers]
    
    # Validate that emb_layers are part of the model
    model_params = set(model.parameters())
    for emb_layer in emb_layers:
        if not isinstance(emb_layer, nn.Embedding):
            raise ValueError(f"Expected nn.Embedding, got {type(emb_layer)}")
        if emb_layer.weight not in model_params:
            raise ValueError("Embedding layer weight is not part of the model's parameters")
    
    # Get set of embedding weights to keep unfrozen
    emb_weights = set(emb_layer.weight for emb_layer in emb_layers)
    
    # Freeze parameters and clear gradients
    for name, param in model.named_parameters():
        if param in emb_weights:
            param.requires_grad = True  # Ensure embedding weights are trainable
        else:
            param.requires_grad = False
            param.grad = None  # Clear gradients to save memory
    
    # Verify embedding layers remain unfrozen
    for emb_layer in emb_layers:
        assert emb_layer.weight.requires_grad, f"Embedding layer {emb_layer} was unexpectedly frozen"

In [12]:
def token_to_word(token_id, tokenizer):
    """Convert a token ID to its corresponding word using the tokenizer."""
    return tokenizer.decode([token_id]).strip()

def word_to_token(word, tokenizer):
    """Convert a word to its corresponding token ID using the tokenizer."""
    return tokenizer.encode(word, add_special_tokens=False)[0]

In [26]:
from torch.cuda.amp import autocast

@torch.no_grad()
def _one_step(model, embeddings, attention_mask, amp_dtype=torch.float16):
    with autocast(dtype=amp_dtype):
        out = model(inputs_embeds=embeddings, attention_mask=attention_mask)
        logits = out.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
    return probs  # (B, V)

def compute_loss(
    model, emb_layer, embeddings, attention_mask, loss_fn,
    n_tokens=10, amp_dtype=torch.float16, track_last_only=True
):
    """
    If track_last_only=True:
      - Roll n-1 steps with no grad (cheap).
      - Take 1 final tracked step and compute loss there only.
    """
    B, L, E = embeddings.shape
    dev = embeddings.device

    # Roll n-1 steps without grad to grow the sequence cheaply
    for _ in range(n_tokens):
        probs = _one_step(model, embeddings, attention_mask, amp_dtype)
        w = emb_layer.weight.to(dev)
        probs = probs.to(w.dtype)           # align dtype: float16 on T4
        next_embeds = probs @ w
        embeddings = torch.cat([embeddings, next_embeds.unsqueeze(1)], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((B, 1), dtype=attention_mask.dtype, device=dev)], dim=1
        )

    # Final step with grad tracking
    with autocast(dtype=amp_dtype):
        out = model(inputs_embeds=embeddings, attention_mask=attention_mask)
        logits = out.logits[:, -1, :]          # (B, V)
        loss = loss_fn(logits) if track_last_only else 0.0

    return loss

In [1]:
import torch
from tqdm import tqdm


def _first_device_of_embedding(model):
    emb, _ = find_embedding_layer(model)
    if emb is None:
        raise RuntimeError("Could not find an nn.Embedding in the model.")
    return emb, emb.weight.device

def find_best_flip(
    model, batch, loss_fn, max_length_tokens, batch_id, topk, vocab_chunk
):
    """
    Memory-lean search:
      - compute grad wrt embeddings only for the final step,
      - process each sample independently,
      - chunk vocab to avoid allocating (V×L) at once.
    """
    model.eval()
    emb_layer, first_dev = _first_device_of_embedding(model)

    input_ids = batch[0].to(first_dev, non_blocking=True)      # (B, L)
    attention_mask = batch[1].to(first_dev, non_blocking=True) # (B, L)

    # Build initial embeddings (needs to be tracked)
    freeze_except_embeddings(model, emb_layer)
    embeddings = emb_layer(input_ids)                          # (B, L, E)
    embeddings.retain_grad()

    # Compute loss (only last step tracked)
    loss = compute_loss(
        model, emb_layer, embeddings, attention_mask,
        loss_fn=loss_fn, n_tokens=max_length_tokens,
        amp_dtype=torch.float16, track_last_only=True
    )

    # Backprop to get grads wrt embeddings
    model.zero_grad(set_to_none=True)
    loss.backward()

    grads = embeddings.grad.detach()   # (B, L, E)
    embeds_det = embeddings.detach()   # (B, L, E)
    del embeddings, loss
    torch.cuda.empty_cache()

    # We'll need vocab embeddings on GPU, but in CHUNKS
    results = []

    B, L, E = grads.shape
    V = emb_layer.weight.size(0)
    dev = first_dev

    # s_i = (g_i · v_i) per position, per sample -> (B, L)
    s_i = (grads * embeds_det).sum(dim=2)  # (B, L)

    # If you have attention_mask available in this scope, mask out padding positions now.
    # Make sure it's on the same device as scores later.
    # Uncomment if attention_mask exists:
    # attn_mask = attention_mask.to(dev, non_blocking=True)  # (B, L)
    # mask_b1l = (attn_mask == 0).unsqueeze(1)               # (B, 1, L)

    # Running accumulators across vocab chunks
    if topk == 1:
        best_vals = torch.full((B,), float("inf"), device=dev, dtype=grads.dtype)
        best_flat_idx = torch.full((B,), -1, device=dev, dtype=torch.long)
    else:
        vals_keep = None  # (B, k_accum)
        idx_keep  = None  # (B, k_accum)

    offset = 0  # how many (vocab positions * L) we've traversed so far

    for start in tqdm(range(0, V, vocab_chunk), desc="Processing vocab chunks"):
        end = min(start + vocab_chunk, V)
        vocab_slice = emb_layer.weight[start:end].to(dev, non_blocking=True)  # (vchunk, E)
        # scores = (B, vchunk, L) = (vchunk, E) @ (B, E, L)
        # using einsum handles broadcasting cleanly
        scores = torch.einsum("ve,ble->bvl", vocab_slice, grads)  # (B, vchunk, L)

        # subtract s_i across L
        scores = scores - s_i.unsqueeze(1)  # (B, 1, L) broadcast

        # Optional: mask padding positions so they are never picked
        # if 'mask_b1l' defined above:
        # scores = scores.masked_fill(mask_b1l, float("inf"))

        # Flatten chunk per sample to shape (B, vchunk*L)
        flat = scores.reshape(B, -1)

        if topk == 1:
            # best of this chunk per sample
            chunk_vals, chunk_idx = torch.min(flat, dim=1)  # (B,)
            # where better than current best, update
            update = chunk_vals < best_vals
            best_vals = torch.where(update, chunk_vals, best_vals)
            best_flat_idx = torch.where(update, chunk_idx + offset, best_flat_idx)
        else:
            k_here = min(topk, flat.size(1))
            chunk_vals, chunk_idx = torch.topk(flat, k=k_here, largest=False, dim=1)  # (B, k_here)
            chunk_idx = chunk_idx + offset  # globalize indices

            if vals_keep is None:
                vals_keep, idx_keep = chunk_vals, chunk_idx
            else:
                # Merge with accumulated bests and reselect top-k
                vals_keep = torch.cat([vals_keep, chunk_vals], dim=1)  # (B, k_accum + k_here)
                idx_keep  = torch.cat([idx_keep,  chunk_idx],  dim=1)  # (B, k_accum + k_here)
                # Re-topk across the concatenated candidates
                k_sel = min(topk, vals_keep.size(1))
                sel_vals, sel_pos = torch.topk(vals_keep, k=k_sel, largest=False, dim=1)  # (B, k_sel)
                # Gather the matching global indices
                batch_ids = torch.arange(B, device=dev).unsqueeze(1).expand_as(sel_pos)
                idx_keep = idx_keep[batch_ids, sel_pos]  # (B, k_sel)
                vals_keep = sel_vals

        # free chunk temporaries
        del vocab_slice, scores, flat
        torch.cuda.empty_cache()

        # advance flat offset by the chunk length (vchunk * L)
        offset += (end - start) * L

    # ----- Convert batched results to your result list -----
    if topk == 1:
        # best_flat_idx encodes (v, pos) as: idx = v * L + pos
        best_v   = (best_flat_idx // L).tolist()
        best_pos = (best_flat_idx %  L).tolist()
        best_val = best_vals.tolist()

        for b in range(B):
            results.append({
                "best_position": int(best_pos[b]),
                "best_vocab_index": int(best_v[b]),
                "min_score": float(best_val[b]),
                "sample_id": int(b),
                "batch_id": int(batch_id)
            })
    else:
        # idx_keep/vals_keep are (B, topk)
        v_idx = (idx_keep // L).tolist()
        pos_i = (idx_keep %  L).tolist()
        vals  =  vals_keep.tolist()

        for b in range(B):
            pairs = [{"best_position": int(pos_i[b][j]),
                    "best_vocab_index": int(v_idx[b][j]),
                    "min_score": float(vals[b][j]),
                    "sample_id": int(b),
                    "batch_id": int(batch_id)}
                    for j in range(len(v_idx[b]))]
            results.extend(pairs)

    del grads, embeds_det, input_ids, attention_mask
    torch.cuda.empty_cache(); gc.collect()
    return results


In [15]:
check_memory()

Free GPU Memory: 8.45 GB
Total GPU Memory: 14.74 GB


In [None]:
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.nn.functional as F
from collections import defaultdict


def get_top_k_by_min_score(dict_list, k, per_sample=True):
    items = [d for d in dict_list if ("min_score" in d and "sample_id" in d)]
    if not items:
        raise NotImplementedError
    if per_sample:
        best = {}
        for d in items:
            sid = int(d["sample_id"])
            if (sid not in best) or (d["min_score"] < best[sid]["min_score"]):
                best[sid] = d
        items = list(best.values())
    items.sort(key=lambda x: x["min_score"])
    return items[:k]


def _collate_varlen(samples):
    """
    samples: list of tuples (ids_1D, mask_1D) with possibly varying lengths.
    Pads within the batch using each sample's last token id as pad value; mask pads with 0.
    """
    if not samples:
        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)
    max_len = max(s[0].numel() for s in samples)
    ids_out, msk_out = [], []
    for ids, msk in samples:
        pad_val = int(ids[-1].item()) if ids.numel() > 0 else 0
        if ids.numel() < max_len:
            ids = F.pad(ids, (0, max_len - ids.numel()), value=pad_val)
            msk = F.pad(msk, (0, max_len - msk.numel()), value=0)
        ids_out.append(ids)
        msk_out.append(msk)
    return torch.stack(ids_out, dim=0), torch.stack(msk_out, dim=0)

def apply_topk_results_to_inputs(
    original_dataloader,
    top_results,
    batch_cache,                 # REQUIRED: dict {batch_id: (ids_cpu, mask_cpu)}
    device="cpu",
):
    """
    Build and return a DataLoader containing ONE edited sample PER flip entry.
    Uses `batch_cache` captured during the scoring pass, so it doesn't rescan the loader.

    Each top_results item must include: batch_id, sample_id, best_position, best_vocab_index.
    """
    # Mirror original loader knobs
    bs   = getattr(original_dataloader, "batch_size", 1)
    nw   = getattr(original_dataloader, "num_workers", 0)
    pin  = getattr(original_dataloader, "pin_memory", device != "cpu")
    pers = getattr(original_dataloader, "persistent_workers", False)

    if not top_results:
        return DataLoader([], batch_size=bs, shuffle=False, pin_memory=pin,
                          num_workers=nw, persistent_workers=pers if nw > 0 else False,
                          collate_fn=_collate_varlen)

    # Group flips by batch for efficient lookup
    flips_by_batch = defaultdict(list)  # batch_id -> list[(sample_id, pos, tok)]
    for r in top_results:
        flips_by_batch[int(r["batch_id"])].append(
            (int(r["sample_id"]), int(r["best_position"]), int(r["best_vocab_index"]))
        )

    edited_samples = []
    for b_id, flips in flips_by_batch.items():
        if b_id not in batch_cache:
            # No cache for this batch — skip its flips
            continue
        ids_cpu, msk_cpu = batch_cache[b_id]  # (B,L), (B,L)
        B = ids_cpu.size(0)
        for s, p, t in flips:
            if 0 <= s < B:
                ids_1d = ids_cpu[s].clone()
                msk_1d = msk_cpu[s].clone()
                if 0 <= p < ids_1d.numel():
                    ids_1d[p] = t
                    msk_1d[p] = 1
                    edited_samples.append((ids_1d, msk_1d))
                # else: out-of-range pos → skip

    # Return loader over ONLY edited samples; safe even if empty
    return DataLoader(
        edited_samples,
        batch_size=bs,
        shuffle=True,
        pin_memory=pin,
        num_workers=nw,
        persistent_workers=pers if nw > 0 else False,
        collate_fn=_collate_varlen,
    )

In [None]:
import torch, gc

results = []
batch_id = 0
rounds = 10
topk_values = [26858, 14427, 7750, 4163, 2236, 1201, 645, 347, 186, 100]

for r in tqdm(range(rounds)):
    print(f"=== Round {r+1}/{rounds} ===")
    batch_cache = {}
    #############
    if round == 0:
        results = [{'best_position': 21, 'best_vocab_index': 25167, 'min_score': -32.375, 'sample_id': 6, 'batch_id': 2426}, {'best_position': 21, 'best_vocab_index': 30615, 'min_score': -32.25, 'sample_id': 9, 'batch_id': 2522}, {'best_position': 21, 'best_vocab_index': 30289, 'min_score': -31.40625, 'sample_id': 10, 'batch_id': 1427}, {'best_position': 24, 'best_vocab_index': 30466, 'min_score': -30.0, 'sample_id': 0, 'batch_id': 2087}, {'best_position': 23, 'best_vocab_index': 9588, 'min_score': -28.40625, 'sample_id': 13, 'batch_id': 84}, {'best_position': 28, 'best_vocab_index': 30143, 'min_score': -28.03125, 'sample_id': 7, 'batch_id': 3213}, {'best_position': 15, 'best_vocab_index': 30289, 'min_score': -27.484375, 'sample_id': 1, 'batch_id': 827}, {'best_position': 33, 'best_vocab_index': 9588, 'min_score': -26.265625, 'sample_id': 2, 'batch_id': 2389}, {'best_position': 21, 'best_vocab_index': 22506, 'min_score': -24.515625, 'sample_id': 4, 'batch_id': 3142}, {'best_position': 18, 'best_vocab_index': 30143, 'min_score': -22.921875, 'sample_id': 3, 'batch_id': 312}, {'best_position': 37, 'best_vocab_index': 4250, 'min_score': -22.65625, 'sample_id': 12, 'batch_id': 1920}, {'best_position': 19, 'best_vocab_index': 2, 'min_score': -21.671875, 'sample_id': 11, 'batch_id': 2232}, {'best_position': 21, 'best_vocab_index': 9588, 'min_score': -21.34375, 'sample_id': 5, 'batch_id': 583}, {'best_position': 14, 'best_vocab_index': 7784, 'min_score': -21.328125, 'sample_id': 8, 'batch_id': 1272}, {'best_position': 5, 'best_vocab_index': 21588, 'min_score': -16.875, 'sample_id': 14, 'batch_id': 2546}, {'best_position': 23, 'best_vocab_index': 9588, 'min_score': -16.203125, 'sample_id': 15, 'batch_id': 203}]
        for batch in tqdm(dataloader, desc="Loading batches"):
            batch_cache[batch_id] = (batch[0].cpu(), batch[1].cpu())
    else:
        for batch in tqdm(dataloader, desc="Processing batches"):
            _, first_dev = _first_device_of_embedding(model)
            batch = (batch[0].to(first_dev, non_blocking=True),
                    batch[1].to(first_dev, non_blocking=True))
            batch_cache[batch_id] = (batch[0].cpu(), batch[1].cpu())

            results.extend(find_best_flip(
                model=model,
                batch=batch,
                loss_fn=entropy_loss,
                max_length_tokens=10,
                batch_id=batch_id,
                topk=5,
                vocab_chunk=8192,         # tune: 4096–16384 depending on VRAM
            ))
            batch_id += 1

            # free per-iteration junk
            del batch
            torch.cuda.empty_cache()

    # Keep only the top-k and rebuild the loader
    print(f"results before: {results}")
    results = get_top_k_by_min_score(results, k=topk_values[int(r)])
    print(f"results after: {results}")
    dataloader = apply_topk_results_to_inputs(dataloader, results, batch_cache)

    # housekeeping
    torch.cuda.empty_cache(); gc.collect()

print(f"final results: {results}")

In [None]:
check_memory()

Free GPU Memory: 1.95 GB
Total GPU Memory: 14.74 GB
