In [None]:
import torch
import json
import os
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

def load_model(model_filepath: str, torch_dtype: torch.dtype = torch.float16):
    """Load a causal LM and its tokenizer from a model directory.

    Expected layout under ``model_filepath``:
      - ``reduced-config.json`` containing a ``use_lora`` boolean
      - ``base-model/`` plus ``fine-tuned-model/`` (LoRA adapter) when ``use_lora`` is true
      - ``fine-tuned-model/`` (full checkpoint) otherwise
      - ``tokenizer/``

    Args:
        model_filepath: str - Path to the directory where the model is stored.
        torch_dtype: torch.dtype - Dtype the model weights are loaded in (default float16).

    Returns:
        (model, tokenizer) - The causal LM (in eval mode) and its tokenizer.

    Raises:
        KeyError: if ``reduced-config.json`` has no ``use_lora`` entry.
    """
    conf_filepath = os.path.join(model_filepath, 'reduced-config.json')
    logging.info("Loading config file from: {}".format(conf_filepath))
    with open(conf_filepath, 'r') as fh:
        round_config = json.load(fh)

    # Everything is read from local files only (offline mode):
    # https://huggingface.co/docs/transformers/installation#offline-mode
    load_kwargs = dict(
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch_dtype,
        local_files_only=True,
    )
    fine_tuned_model_filepath = os.path.join(model_filepath, 'fine-tuned-model')

    logging.info("Loading model from filepath: {}".format(model_filepath))
    if round_config['use_lora']:
        base_model_filepath = os.path.join(model_filepath, 'base-model')
        logging.info("loading the base model (before LORA) from {}".format(base_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(base_model_filepath, **load_kwargs)

        logging.info("loading the LORA adapter onto the base model from {}".format(fine_tuned_model_filepath))
        model.load_adapter(fine_tuned_model_filepath)
    else:
        # device_map="auto" decides placement; the checkpoint is not forced onto CPU.
        logging.info("Loading full fine tune checkpoint from {}".format(fine_tuned_model_filepath))
        model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_filepath, **load_kwargs)

    model.eval()

    tokenizer_filepath = os.path.join(model_filepath, 'tokenizer')
    # Same offline guarantee as the model weights.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_filepath, local_files_only=True)

    return model, tokenizer


def _two_gpu_max_memory(headroom_gb=2, gpu_mem_gb=16):
    """
    Build a per-GPU ``max_memory`` map for Hugging Face ``device_map="auto"``.

    Every visible GPU is capped at ``gpu_mem_gb - headroom_gb`` GiB. With the
    defaults (16 GB T4 cards, 2 GB headroom -> "14GiB" each) the reserved
    headroom forces HF sharding to split the model across both T4s.

    Args:
        headroom_gb: memory (GB) to keep free on each GPU.
        gpu_mem_gb: nominal memory (GB) of each GPU; 16 matches a T4.

    Returns:
        dict mapping GPU index -> cap string (e.g. ``{0: "14GiB", 1: "14GiB"}``),
        or None when CUDA is not available.

    Raises:
        ValueError: if the headroom leaves no positive memory budget.
    """
    if not torch.cuda.is_available():
        return None
    cap_gb = gpu_mem_gb - headroom_gb
    if cap_gb <= 0:
        raise ValueError(
            f"headroom_gb={headroom_gb} leaves no usable memory on a {gpu_mem_gb} GB GPU"
        )
    cap = f"{cap_gb}GiB"  # e.g., "14GiB"
    return {i: cap for i in range(torch.cuda.device_count())}

def _common_from_pretrained_kwargs():
    """
    Shared keyword arguments for ``AutoModelForCausalLM.from_pretrained``.

    The settings reduce both CPU and GPU peak memory during loading and select
    the SDPA attention implementation. Device placement:
      - more than one visible GPU: ``device_map="auto"`` with per-GPU caps from
        ``_two_gpu_max_memory`` so the weights are sharded across GPUs;
      - exactly one visible GPU: the whole model on GPU 0;
      - no CUDA: the whole model on CPU (mapping to GPU 0 would fail there).

    Returns:
        dict: kwargs to splat into ``from_pretrained``.
    """
    kw = dict(
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=torch.float16,     # T4 → FP16
        low_cpu_mem_usage=True,        # streaming load
        offload_state_dict=True,       # avoid CPU spikes
        attn_implementation="sdpa",    # available by default on Kaggle
    )
    if not torch.cuda.is_available():
        # NOTE(review): fp16 on CPU can be slow or unsupported for some ops;
        # consider float32 here if CPU-only runs matter.
        kw["device_map"] = {"": "cpu"}
        return kw

    mm = _two_gpu_max_memory(headroom_gb=2)
    if mm and torch.cuda.device_count() > 1:
        kw["device_map"] = "auto"
        kw["max_memory"] = mm
        # Optional if host RAM is tight:
        # kw["offload_folder"] = "/kaggle/working/offload"
    else:
        kw["device_map"] = {"": 0}
    return kw

def load_model_and_tokenizer(model_dir: str, merge_lora: bool = True):
    """
    Robust loader for full fine-tunes or LoRA adapters stored under `model_dir`.

    Expects:
      - reduced-config.json with {"use_lora": <bool>, ...}
      - For LoRA: base-model/, fine-tuned-model/ (adapter)
      - For full FT: fine-tuned-model/
      - tokenizer/ with tokenizer files

    Args:
        model_dir: directory containing the files listed above.
        merge_lora: currently unused and kept only for call compatibility.
            The adapter is attached with ``model.load_adapter`` and is NOT
            merged into the base weights.

    Returns: (model, tokenizer). The model is in eval mode with
        ``config.use_cache`` disabled when the config exposes it.
    """
    conf_path = os.path.join(model_dir, "reduced-config.json")
    logging.info(f"Loading config: {conf_path}")
    with open(conf_path, "r") as fh:
        cfg = json.load(fh)

    kw = _common_from_pretrained_kwargs()

    if cfg.get("use_lora", False):
        base_dir = os.path.join(model_dir, "base-model")
        lora_dir = os.path.join(model_dir, "fine-tuned-model")

        logging.info(f"Loading base model: {base_dir}")
        model = AutoModelForCausalLM.from_pretrained(base_dir, **kw)
        logging.info(f"Attaching LoRA adapter: {lora_dir}")
        # Attach via the model's own `load_adapter`. `PeftModel` is not imported
        # in this notebook, and a broad try/except around it would only hide
        # genuine adapter-loading errors.
        model.load_adapter(lora_dir)
    else:
        ft_dir = os.path.join(model_dir, "fine-tuned-model")
        logging.info(f"Loading full fine-tuned model: {ft_dir}")
        model = AutoModelForCausalLM.from_pretrained(ft_dir, **kw)

    # Tokenizer hygiene
    tok_dir = os.path.join(model_dir, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tok_dir, use_fast=True, local_files_only=True)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        # Reuse EOS as the pad token so batched inputs can be padded.
        tokenizer.pad_token = tokenizer.eos_token
    # Pad on the right; switch to "left" if this tokenizer is used for batched generation.
    tokenizer.padding_side = "right"

    # Runtime memory knobs for the gradient-based search
    model.eval()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False  # no KV cache: less memory during the search

    # Quick sanity check of sharding (getattr with a default cannot miss the attribute)
    print(getattr(model, "hf_device_map", "no device map"))

    return model, tokenizer

# Directory holding reduced-config.json, the model weights and tokenizer/.
MODEL_DIR = "/kaggle/input/trojai-rev2-00000001/id-00000001"

model, tokenizer = load_model_and_tokenizer(model_dir=MODEL_DIR)

In [None]:
def get_emb_layer(model):
    """
    Switch `model` to eval mode, disable its KV cache if the config has one,
    and return the model's input-embedding module.
    """
    model.eval()
    config = model.config
    if hasattr(config, "use_cache"):
        # Caching is not needed while optimizing embeddings; turn it off.
        config.use_cache = False
    embeddings = model.get_input_embeddings()
    return embeddings

# Input-embedding module of the loaded model; its `.weight` matrix is what the
# optimized suffix embeddings are projected onto below.
emb_layer = get_emb_layer(model)

In [None]:
def project_suffix_to_tokens_and_diagnostics(
    suffix_z,
    emb_layer,
    tokenizer,
):
    """
    Project optimized continuous suffix embeddings onto their nearest vocabulary tokens.

    "Nearest" is measured by cosine similarity against the rows of the
    embedding matrix. Prints L2-distance and cosine-similarity diagnostics,
    the projected token IDs and tokens, and the decoded suffix text.

    Args:
        suffix_z: (Ls, E) optimized continuous suffix embeddings. Inputs with
            extra leading dims (e.g. a (1, Ls, E) batch) are flattened to
            (-1, E) first.
        emb_layer: model.get_input_embeddings(); its ``weight`` is (V, E).
        tokenizer: tokenizer used to render the projected token IDs.

    Returns:
        torch.Tensor: (Ls,) projected token IDs on CPU (empty if Ls == 0).
    """
    with torch.no_grad():
        weight = emb_layer.weight                            # (V, E)
        dev = weight.device

        # Move suffix to the embedding matrix's device
        z = suffix_z.to(dev)
        if z.dim() > 2:
            # `.T` below is only a true transpose for 2-D tensors.
            z = z.reshape(-1, z.shape[-1])                   # (Ls, E)

        if z.shape[0] == 0:
            print("Empty suffix: nothing to project.")
            return torch.empty(0, dtype=torch.long)

        # Work in float32 for stability (the model may be loaded in fp16).
        # Normalizing the fp32 copy inline keeps only one (V, E) fp32 matrix alive.
        z_f = z.float()                                      # (Ls, E)
        weight_norm = F.normalize(weight.float(), dim=-1)    # (V, E)
        z_norm = F.normalize(z_f, dim=-1)                    # (Ls, E)

        # Cosine similarity: (V, E) @ (E, Ls) -> (V, Ls); best token per suffix position.
        cos_sim = weight_norm @ z_norm.T
        best_cos, best_token_ids = cos_sim.max(dim=0)        # (Ls,), (Ls,)

        # Diagnostics: L2 distances between z[i] and the embedding of best_token_ids[i]
        nearest_embs = weight[best_token_ids].float()        # (Ls, E)
        l2_dists = (z_f - nearest_embs).norm(dim=-1)         # (Ls,)

        print("L2 distance between optimized embeddings and nearest token embeddings:")
        print(f"  min:  {l2_dists.min().item():.6f}")
        print(f"  max:  {l2_dists.max().item():.6f}")
        print(f"  mean: {l2_dists.mean().item():.6f}")

        print("Cosine similarity of optimized embeddings to nearest tokens:")
        print(f"  min:  {best_cos.min().item():.6f}")
        print(f"  max:  {best_cos.max().item():.6f}")
        print(f"  mean: {best_cos.mean().item():.6f}")

        suffix_token_ids = best_token_ids.cpu()
        token_id_list = suffix_token_ids.tolist()
        suffix_tokens = tokenizer.convert_ids_to_tokens(token_id_list)
        suffix_text = tokenizer.decode(token_id_list, skip_special_tokens=False)

        print("\nProjected discrete suffix token IDs:", token_id_list)
        print("Projected discrete suffix tokens:", suffix_tokens)
        print("Projected suffix as text:", repr(suffix_text))

        return suffix_token_ids


In [None]:
def print_suffix(suffix_z, emb_layer, tokenizer):
    """
    Print projected suffix tokens and diagnostics.

    `project_suffix_to_tokens_and_diagnostics` already prints the L2/cosine
    diagnostics, the projected token IDs, the tokens and the decoded text, so
    this wrapper only delegates instead of printing the same information twice.

    Args:
        suffix_z: continuous suffix embeddings to project.
        emb_layer: the model's input-embedding module.
        tokenizer: tokenizer used to render the projected token IDs.
    """
    project_suffix_to_tokens_and_diagnostics(suffix_z, emb_layer, tokenizer)

In [None]:
def read_suffix_pt(filepath: str, map_location="cpu", weights_only: bool = True) -> torch.Tensor:
    """
    Read suffix embeddings from a .pt file.

    Args:
        filepath: path to a file written with ``torch.save`` holding a tensor.
        map_location: device to load onto (default CPU). This avoids failures
            when the file was saved from a GPU that is not visible now; the
            projection step moves the tensor to the embedding device anyway.
        weights_only: forwarded to ``torch.load``. True restricts unpickling to
            tensors and other safe types, and makes the choice explicit instead
            of depending on the torch version's default.

    Returns:
        torch.Tensor: the stored suffix embeddings.

    Raises:
        TypeError: if the file does not contain a tensor.
    """
    suffix_z = torch.load(filepath, map_location=map_location, weights_only=weights_only)
    if not isinstance(suffix_z, torch.Tensor):
        raise TypeError(
            f"Expected a tensor in {filepath}, got {type(suffix_z).__name__}"
        )
    return suffix_z

In [None]:
# Load the suffix embeddings saved for each round under hotflip/rounds/ and
# print their projection onto discrete tokens.
NUM_SUFFIX_ROUNDS = 5
SUFFIX_PATH_TEMPLATE = "/kaggle/working/hotflip/rounds/suffix_r{round_idx}_e2.pt"

for i in range(NUM_SUFFIX_ROUNDS):
    suffix_z = read_suffix_pt(SUFFIX_PATH_TEMPLATE.format(round_idx=i))
    print(f"Read suffix{i} successfully.")
    print_suffix(suffix_z, emb_layer, tokenizer)