In [1]:
import torch
import warnings
warnings.filterwarnings("ignore")
def check_memory(gpu_index: int = 0):
    if torch.cuda.is_available():
        free_memory = torch.cuda.get_device_properties(gpu_index).total_memory - torch.cuda.memory_allocated(gpu_index)
        total_memory = torch.cuda.get_device_properties(gpu_index).total_memory
        print(f"Free GPU Memory: {free_memory / 1024**3:.2f} GB")
        print(f"Total GPU Memory: {total_memory / 1024**3:.2f} GB")
    else:
        print("CUDA is not available.")

In [2]:
import gc
import torch
from typing import List, Union

def clear_memory(keep_vars: Union[List[str], None] = None, verbose: bool = True):
    """
    Clears memory while preserving specified variables.
    Still clears GPU memory for all CUDA objects, including kept variables.
    
    Args:
        keep_vars: List of variable names to preserve in memory (will still be cleared from GPU)
        verbose: Whether to print memory clearing information
    """
    if verbose:
        print("Starting memory clearing process...")
    
    # Convert keep_vars to set for faster lookups
    keep_set = set(keep_vars) if keep_vars else set()
    
    # First pass: Move kept CUDA variables to CPU
    if torch.cuda.is_available():
        for name, var in list(globals().items()):
            if name in keep_set and isinstance(var, torch.Tensor) and var.is_cuda:
                if verbose:
                    print(f"Moving kept tensor '{name}' to CPU")
                globals()[name] = var.cpu()
    
    # Clear Python garbage collector
    gc.collect()
    if verbose:
        print("Ran Python garbage collection")
    
    # Clear CUDA memory if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if verbose:
            print("Cleared CUDA cache")
            print(f"Current CUDA memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
            print(f"Current CUDA memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
    
    # Try to clear TensorFlow/Keras if available
    try:
        import tensorflow as tf
        tf.keras.backend.clear_session()
        if verbose:
            print("Cleared TensorFlow/Keras session")
    except ImportError:
        pass
    
    # Delete objects not in keep_vars
    for name, var in list(globals().items()):
        if not name.startswith('__') and name not in keep_set:
            if isinstance(var, (torch.Tensor, torch.nn.Module)):
                del globals()[name]
                if verbose:
                    print(f"Deleted torch object: {name}")
            elif isinstance(var, list) and var and isinstance(var[0], torch.Tensor):
                del globals()[name]
                if verbose:
                    print(f"Deleted list of torch tensors: {name}")
    
    # Final garbage collection
    gc.collect()
    
    if verbose:
        print("Memory clearing complete")

In [3]:
import os
import gdown

def download_file_from_google_drive(file_id, output_dir, output_filename, quiet=False):
    """
    Downloads a file from Google Drive given its file ID and saves it to the specified directory.
    
    Args:
        file_id (str): The Google Drive file ID (found in the file URL)
        output_dir (str): Directory where the file should be saved
        output_filename (str): Name of the output file
        quiet (bool): Whether to suppress gdown output (default: False)
    
    Returns:
        str: Path to the downloaded file if successful, None otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Full output path
    output_file = os.path.join(output_dir, output_filename)
    
    print("Downloading the file...")
    try:
        gdown.download(id=file_id, output=output_file, quiet=quiet, fuzzy=True)
    except Exception as e:
        print(f"Download failed: {str(e)}")
        return None
    
    # Verify download
    if os.path.exists(output_file):
        file_size = os.path.getsize(output_file) / (1024 * 1024)  # in MB
        print(f"Download successful! File saved to: {output_file}")
        print(f"File size: {file_size:.2f} MB")
        return output_file
    else:
        print("Download failed - file not found")
        return None

In [4]:
import os
import tarfile
from typing import List, Union

def extract_and_delete_tar_gz(file_path: str, delete_compressed: bool = True) -> bool:
    """
    Extracts a .tar.gz file and optionally deletes the compressed file.
    
    Args:
        file_path (str): Path to the .tar.gz file
        delete_compressed (bool): Whether to delete the compressed file after extraction (default: True)
    
    Returns:
        bool: True if extraction was successful, False otherwise
    """
    try:
        print(f"Extracting: {file_path}")
        with tarfile.open(file_path, 'r:gz') as tar:
            tar.extractall(path=os.path.dirname(file_path))
        
        if delete_compressed:
            os.remove(file_path)
            print(f"Deleted compressed file: {file_path}")
        return True
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return False

def process_directory(directory: str, recursive: bool = True, max_depth: Union[int, None] = None) -> int:
    """
    Processes a directory to find and extract .tar.gz files.
    
    Args:
        directory (str): Directory path to process
        recursive (bool): Whether to process subdirectories (default: True)
        max_depth (int|None): Maximum recursion depth (None for unlimited)
    
    Returns:
        int: Number of .tar.gz files processed
    """
    processed_count = 0
    current_depth = 0
    
    while True:
        found_tar_gz = False
        for root, dirs, files in os.walk(directory):
            # Calculate current depth
            rel_path = os.path.relpath(root, directory)
            current_depth = rel_path.count(os.sep) + 1 if rel_path != '.' else 0
            
            # Skip if beyond max depth
            if max_depth is not None and current_depth > max_depth:
                continue
                
            for file in files:
                if file.endswith('.tar.gz'):
                    file_path = os.path.join(root, file)
                    if extract_and_delete_tar_gz(file_path):
                        processed_count += 1
                        found_tar_gz = True
        
        # If not recursive or no more .tar.gz files found, exit
        if not recursive or not found_tar_gz:
            break
    
    return processed_count

def process_paths(paths: List[str], recursive: bool = True, max_depth: Union[int, None] = None) -> int:
    """
    Processes a list of paths (files or directories) to extract .tar.gz files.
    
    Args:
        paths (List[str]): List of file/directory paths to process
        recursive (bool): Whether to process directories recursively (default: True)
        max_depth (int|None): Maximum recursion depth for directories (None for unlimited)
    
    Returns:
        int: Total number of .tar.gz files processed
    """
    total_processed = 0
    
    for path in paths:
        if not os.path.exists(path):
            print(f"Warning: Path does not exist - {path}")
            continue
            
        if path.endswith('.tar.gz'):
            if extract_and_delete_tar_gz(path):
                total_processed += 1
        elif os.path.isdir(path):
            print(f"Processing directory: {path}")
            total_processed += process_directory(
                directory=path,
                recursive=recursive,
                max_depth=max_depth
            )
    
    print(f"Total .tar.gz files processed: {total_processed}")
    return total_processed

In [5]:
from os.path import join

import torch
import json
import os
import logging
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def load_model(model_filepath: str, torch_dtype:torch.dtype=torch.float16):
    """Load a model given a specific model_path.

    Args:
        model_filepath: str - Path to where the model is stored

    Returns:
        model, dict, str - Torch model + dictionary representation of the model + model class name
    """

    conf_filepath = os.path.join(model_filepath, 'reduced-config.json')
    logging.info("Loading config file from: {}".format(conf_filepath))
    with open(conf_filepath, 'r') as fh:
        round_config = json.load(fh)

    logging.info("Loading model from filepath: {}".format(model_filepath))
    # https://huggingface.co/docs/transformers/installation#offline-mode
    if round_config['use_lora']:
        base_model_filepath = os.path.join(model_filepath, 'base-model')
        logging.info("loading the base model (before LORA) from {}".format(base_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(base_model_filepath, device_map = "auto", trust_remote_code=True, torch_dtype=torch_dtype, local_files_only=True)
        # model = AutoModelForCausalLM.from_pretrained(round_config['model_architecture'], trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch_dtype)

        fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')
        logging.info("loading the LORA adapter onto the base model from {}".format(fine_tuned_model_filepath))
        model.load_adapter(fine_tuned_model_filepath)
    else:
        fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')
        logging.info("Loading full fine tune checkpoint into cpu from {}".format(fine_tuned_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, device_map = "auto", trust_remote_code=True, torch_dtype=torch_dtype, local_files_only=True)
        # model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch_dtype)

    model.eval()

    tokenizer_filepath = os.path.join(model_filepath, 'tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_filepath)

    return model, tokenizer

In [6]:
def download_and_load(file_id, output_filename, load_model_path):
    """Run the complete embedding optimization pipeline with custom weight function"""
    # Clear memory before starting
    clear_memory()
    
    # Download the model
    downloaded_file = download_file_from_google_drive(
        file_id=file_id,#"1-K-HcT-3-00rxPpvQxZ75o2be3STchsv",
        output_dir="/kaggle/tmp",
        output_filename=output_filename,#"model4.tar.gz",
        quiet=False
    )
    
    # Process
    process_paths(
        paths=['/kaggle/tmp',],
        recursive=True,
        max_depth=None
    )
    
    # Load model and tokenizer
    model, tokenizer = load_model(load_model_path)#"/kaggle/tmp/id-00000004")

    return model, tokenizer

model, tokenizer = download_and_load(file_id="1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc", output_filename="model0.tar.gz", load_model_path="/kaggle/tmp/id-00000000")

Starting memory clearing process...
Ran Python garbage collection
Cleared CUDA cache
Current CUDA memory allocated: 0.00 MB
Current CUDA memory cached: 0.00 MB


2025-10-21 15:38:47.244529: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761061127.276762     112 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761061127.285925     112 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Cleared TensorFlow/Keras session
Memory clearing complete
Downloading the file...


Downloading...
From (original): https://drive.google.com/uc?id=1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc
From (redirected): https://drive.google.com/uc?id=1lwC9JLRu4Z4SSQwjNtetAymStPqQeaDc&confirm=t&uuid=b3a2a1ab-3889-4e2e-a2c9-963a26aaa88f
To: /kaggle/tmp/model0.tar.gz
100%|██████████| 10.6G/10.6G [00:55<00:00, 192MB/s]


Download successful! File saved to: /kaggle/tmp/model0.tar.gz
File size: 10092.92 MB
Processing directory: /kaggle/tmp
Extracting: /kaggle/tmp/model0.tar.gz
Deleted compressed file: /kaggle/tmp/model0.tar.gz
Total .tar.gz files processed: 1


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
from torch.nn import functional as F

def entropy_loss(batch_logits):
    """
    Memory-efficient entropy loss without in-place operations that can break computation graphs
    """
    log_probs = F.log_softmax(batch_logits, dim=-1)

    probs = log_probs.exp()

    entropy = -(probs * log_probs).sum(dim=-1)

    return entropy.mean()

In [8]:
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset

def load_prompts(tokenizer, args):
    # Load Alpaca dataset
    dataset = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])

    # Tokenization function
    def tokenize_function(example):
        encodings = tokenizer(
            example["instruction"],
            padding="max_length",
            truncation=True,
            max_length=args["max_length"],
        )
        return {
            "input_ids": encodings["input_ids"],
            "attention_mask": encodings["attention_mask"]
        }

    # Tokenize all examples
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        batch_size=1000
    )

    # Convert to torch tensors
    input_ids = torch.tensor(tokenized_dataset["input_ids"])
    attention_mask = torch.tensor(tokenized_dataset["attention_mask"])

    del dataset, tokenized_dataset
    gc.collect()

    # Create PyTorch Dataset
    torch_dataset = TensorDataset(input_ids, attention_mask)

    # Wrap in DataLoader
    return DataLoader(
        torch_dataset,
        batch_size=args["batch_size"],
        shuffle=False
    )

In [9]:
args = {
    "data_dir": "kaggle/working/data",
    "max_length": 512,
    "batch_size": 1
}
dataloader = load_prompts(tokenizer, args)

In [None]:
from torch import nn

def find_embedding_layer(model):
    """
    Attempts to find the embedding layer in an arbitrary PyTorch model.
    
    Parameters:
    - model: A PyTorch model (nn.Module) or Hugging Face model.
    
    Returns:
    - embedding_layer: The nn.Embedding layer (or None if not found).
    - path: The attribute path to the embedding layer (e.g., 'embeddings.word_embeddings').
    """
    # Check if the model has a get_input_embeddings method (common in Hugging Face)
    if hasattr(model, 'get_input_embeddings'):
        emb = model.get_input_embeddings()
        if isinstance(emb, nn.Embedding):
            return emb, 'get_input_embeddings()'
    
    # Iterate through all named modules to find an embedding layer
    for name, module in model.named_modules():
        if isinstance(module, nn.Embedding):
            return module, name
    
    # If no embedding layer is found, return None
    return None, None

def freeze_except_embeddings(model, emb_layers):
    """
    Freezes all model parameters except the weights of specified embedding layers.
    
    Parameters:
    - model: PyTorch model (nn.Module).
    - emb_layers: Single nn.Embedding layer or list of nn.Embedding layers to keep unfrozen.
    
    Returns:
    - None
    """
    # Convert single embedding layer to list for generality
    if isinstance(emb_layers, nn.Embedding):
        emb_layers = [emb_layers]
    
    # Validate that emb_layers are part of the model
    model_params = set(model.parameters())
    for emb_layer in emb_layers:
        if not isinstance(emb_layer, nn.Embedding):
            raise ValueError(f"Expected nn.Embedding, got {type(emb_layer)}")
        if emb_layer.weight not in model_params:
            raise ValueError("Embedding layer weight is not part of the model's parameters")
    
    # Get set of embedding weights to keep unfrozen
    emb_weights = set(emb_layer.weight for emb_layer in emb_layers)
    
    # Freeze parameters and clear gradients
    for name, param in model.named_parameters():
        if param in emb_weights:
            param.requires_grad = True  # Ensure embedding weights are trainable
        else:
            param.requires_grad = False
            param.grad = None  # Clear gradients to save memory
    
    # Verify embedding layers remain unfrozen
    for emb_layer in emb_layers:
        assert emb_layer.weight.requires_grad, f"Embedding layer {emb_layer} was unexpectedly frozen"

In [11]:
def token_to_word(token_id, tokenizer):
    """Convert a token ID to its corresponding word using the tokenizer."""
    return tokenizer.decode([token_id]).strip()

def word_to_token(word, tokenizer):
    """Convert a word to its corresponding token ID using the tokenizer."""
    return tokenizer.encode(word, add_special_tokens=False)[0]

In [17]:
def compute_loss(model, input_ids, attention_mask, loss_fn, n_tokens=10):
    """
    Autoregressive generation with gradient tracking only on the final step.
    Returns total loss over n future tokens.
    """
    total_loss = 0.0
    input_clone = input_ids.detach()
    attention_clone = attention_mask.detach()
    # Only track gradients for this step; no need for previous steps' graphs
    input_ids = input_ids.detach()
    attention_mask = attention_mask.detach()
    
    for _ in range(n_tokens):
        print(f"Processing {_} step")
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits[:, -1, :]  # (B, V)

        # Compute step loss
        step_loss = loss_fn(logits)
        total_loss = total_loss + step_loss

        # Generate next token without gradient tracking
        with torch.no_grad():
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones_like(next_token)], dim=1
            )

        # Free unused tensors
        del outputs, logits, step_loss, next_token
        torch.cuda.empty_cache()

    input_ids = input_clone
    attention_mask = attention_clone
    return total_loss


In [18]:
import torch
from tqdm import tqdm

def find_best_flip(model, batch, loss_fn, max_length_tokens, batch_id, topk=1):
    """
    For each example in the batch, find the (position i, replacement vocab index v_hat)
    that minimize g_i · (v_hat - v_i)  where g_i = dJ/dx_i, v_i is current embedding.

    Returns a list (one per batch element) of dicts with best positions / vocab indices / scores.
    """
    model.eval()
    device = next(model.parameters()).device
    emb_layer, _ = find_embedding_layer(model)
    print("found embedding layer")

    freeze_except_embeddings(model, emb_layer)
    print("froze all but embedding layer")


    input_ids = batch[0].to(device)       # (B, L)
    attention_mask = batch[1].to(device)   # (B, L)

    with torch.no_grad():
        vocab_embeds = emb_layer.weight.detach()                # (V, E)


    # 1) Get embeddings for inputs and ensure they retain grad
    embeddings = emb_layer(input_ids)                     # (B, L, E)
    embeddings.retain_grad()

    # 2) Compute loss
    print("gonna compute loss")
    loss = compute_loss(model, input_ids, attention_mask, loss_fn, max_length_tokens)
    # 3) Backprop to get grads wrt embeddings
    model.zero_grad()
    loss.backward()

    grads = embeddings.grad.detach()   # (B, L, E)
    embeds_det = embeddings.detach()   # (B, L, E)

    results = []
    B, L, _ = grads.shape

    # Process one sample at a time
    for b in tqdm(range(B), desc="Searching through batch samples"):
        grads_b = grads[b]         # (L, E)
        embeds_b = embeds_det[b]   # (L, E)

        # Compute per-position scalar s_i = g_i · v_i  (shape (L,))
        # element-wise dot across embedding dim
        s_i = (grads_b * embeds_b).sum(dim=1)   # (L,)

        # Compute dot between vocab_embeds (V, E) and grads_b (L, E) -> (V, L)
        # We want scores_{v,i} = vocab_embeds @ grads_b[i] - s_i[i]
        # Efficient way: (V, E) @ (E, L) = (V, L)
        # grads_b.T is (E, L)
        scores_VxL = torch.matmul(vocab_embeds, grads_b.T)   # (V, L)

        # subtract s_i from each column i
        scores_VxL = scores_VxL - s_i.unsqueeze(0)          # broadcast (V, L)

        # Now find minimal value and its indices
        # Flatten to search global min across V*L
        flat = scores_VxL.view(-1)              # (V*L,)
        if topk == 1:
            min_val, min_idx = torch.min(flat, dim=0)
            v_idx = (min_idx // L).item()       # vocab index
            pos_i = (min_idx % L).item()        # position index
            results.append({
                "best_position": pos_i,
                "best_vocab_index": v_idx,
                "min_score": min_val.item(),
                "sample_id": b,
                "batch_id": batch_id
            })
        else:
            # optionally return top-k pairs
            vals, idxs = torch.topk(flat, k=topk, largest=False)
            pairs = []
            for val, idx in zip(vals.tolist(), idxs.tolist()):
                v_idx = idx // L
                pos_i = idx % L
                pairs.append({"position": int(pos_i), "vocab_index": int(v_idx), "score": float(val)})
            results.append({"topk": pairs})
        del grads_b, embeds_b, s_i, scores_VxL, flat
        torch.cuda.empty_cache()

     # 🔹 Free remaining references
    del grads, embeds_det, embeddings, vocab_embeds, loss, input_ids, attention_mask
    torch.cuda.empty_cache()
    gc.collect()

    return results

In [20]:
check_memory()

Free GPU Memory: 14.74 GB
Total GPU Memory: 14.74 GB


In [15]:
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader


def get_top_k_by_min_score(dict_list, k):
    # Sort the list of dictionaries by min_score in ascending order
    sorted_list = sorted(dict_list, key=lambda x: x['min_score'])
    # Return the first k elements
    return sorted_list[:k]



def apply_topk_results_to_inputs(original_dataloader, top_results, device="cpu"):
    
    """
    Apply flips from top_results to the dataset behind a DataLoader.
    Works safely on multi-GPU (device_map="auto") setups.
    """

    all_input_ids = []
    all_attention_masks = []

    # Collect everything on CPU (safe and memory-stable)
    for batch in original_dataloader:
        input_ids, attention_mask = batch
        all_input_ids.append(input_ids.cpu())
        all_attention_masks.append(attention_mask.cpu())

    # Combine into full tensors
    all_input_ids = torch.cat(all_input_ids, dim=0)
    all_attention_masks = torch.cat(all_attention_masks, dim=0)
    torch.cuda.empty_cache()

    # Clone for modified copy
    new_inputs = all_input_ids.clone()

    # Apply token flips (CPU-side, no GPU load)
    for res in top_results:
        sample_idx = res["sample_id"]
        pos_i = res["best_position"]
        new_token = res["best_vocab_index"]
        new_inputs[sample_idx, pos_i] = new_token

    # Combine datasets
    combined_inputs = torch.cat([all_input_ids, new_inputs], dim=0)
    combined_masks = torch.cat([all_attention_masks, all_attention_masks], dim=0)

    # Cleanup intermediate large tensors to free memory
    del all_input_ids, all_attention_masks, new_inputs
    torch.cuda.empty_cache()

    # Build dataset and DataLoader
    new_dataset = TensorDataset(combined_inputs, combined_masks)
    new_dataloader = DataLoader(
        new_dataset,
        batch_size=original_dataloader.batch_size,
        shuffle=True,
        pin_memory=(device != "cpu"),
    )

    return new_dataloader

In [21]:
results = []
batch_id = 0
rounds = 10
for round in range(rounds):
    print(f"=== Round {round+1}/{rounds} ===")
    for batch in tqdm(dataloader, desc="Processing batches"):
        input_ids, attention_mask = batch
        input_ids = input_ids.to(model.device)
        attention_mask = attention_mask.to(model.device)

        results.extend(find_best_flip(
            model=model,
            batch=batch,
            loss_fn=entropy_loss,
            max_length_tokens=1,
            batch_id=batch_id,
            topk=1
        ))
        batch_id += 1
    # === Use top-k results to update dataloader ===
    results = get_top_k_by_min_score(results, k=500)
    dataloader = apply_topk_results_to_inputs(dataloader, results)

=== Round 1/10 ===


Processing batches:   0%|          | 0/52002 [00:00<?, ?it/s]

found embedding layer
froze all but embedding layer
gonna compute loss
Processing 0 step


Processing batches:   0%|          | 0/52002 [00:17<?, ?it/s]


KeyboardInterrupt: 

In [27]:
clear_memory(keep_vars=['model', 'tokenizer'], verbose=True)

Starting memory clearing process...
Ran Python garbage collection
Cleared CUDA cache
Current CUDA memory allocated: 14186.88 MB
Current CUDA memory cached: 14928.00 MB
Cleared TensorFlow/Keras session
Memory clearing complete
