In [None]:
# ================== 2×GPU threaded pipeline (notebook-friendly) ==================
import os, json, torch, threading, queue, gc
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---- You already have these; included for completeness ----
def entropy_loss(batch_logits):
    """Return the mean Shannon entropy of the softmax distributions in *batch_logits*.

    The softmax is taken over the last dimension; the per-row entropies
    ``-sum(p * log p)`` are then averaged over all remaining dimensions.

    Args:
        batch_logits: Tensor of unnormalized scores; the last dimension is
            treated as the class/vocabulary axis.

    Returns:
        A scalar tensor holding the mean entropy (in nats).
    """
    import torch.nn.functional as F

    log_probs = F.log_softmax(batch_logits, dim=-1)
    probs = log_probs.exp()
    # A logit of -inf gives p == 0 and log p == -inf, and 0 * -inf is NaN,
    # which would poison the whole mean. By convention 0 * log 0 == 0, so
    # zero out the log term wherever the probability is exactly zero.
    safe_log_probs = log_probs.masked_fill(probs == 0, 0.0)
    entropy = -(probs * safe_log_probs).sum(dim=-1)
    return entropy.mean()

from torch.cuda.amp import autocast as cuda_autocast

@torch.no_grad()
def _one_step(model, embeddings, attention_mask, amp_dtype=torch.float16):
    """Run one gradient-free forward pass and return next-token probabilities.

    The model is called with ``inputs_embeds``/``attention_mask`` inside a
    CUDA autocast region using *amp_dtype*. Only the logits at the last
    sequence position are kept and turned into a softmax distribution.
    Because of the ``torch.no_grad()`` decorator, the result carries no
    autograd history.
    """
    with cuda_autocast(dtype=amp_dtype):
        forward_out = model(inputs_embeds=embeddings, attention_mask=attention_mask)
        last_position_logits = forward_out.logits[:, -1, :]
        next_token_probs = last_position_logits.softmax(dim=-1)
    return next_token_probs

def compute_loss(
    model, emb_layer, embeddings, attention_mask, loss_fn,
    n_tokens=10, amp_dtype=torch.float16, track_last_only=True
):
    """Roll the model forward on soft embeddings and score the final logits.

    For ``n_tokens - 1`` steps the next-token distribution is obtained from
    ``_one_step`` (which runs under ``torch.no_grad()``), converted into a
    "soft" embedding as the probability-weighted sum of the rows of
    ``emb_layer.weight``, and appended to the sequence together with a new
    attention-mask column of ones. A final forward pass (with autograd left
    enabled) then produces the logits that *loss_fn* is applied to.

    Args:
        model: Callable accepting ``inputs_embeds`` and ``attention_mask``
            keyword arguments and returning an object with ``.logits``.
        emb_layer: Module whose ``.weight`` matrix maps vocabulary
            probabilities to embeddings.
        embeddings: 3-D tensor unpacked as (B, L, E).
        attention_mask: 2-D tensor with B rows; extended along dim 1.
        loss_fn: Callable applied to the last-position logits.
        n_tokens: Total number of forward passes; values below 1 still
            perform the single final pass.
        amp_dtype: dtype passed to the CUDA autocast context.
        track_last_only: When true, return ``loss_fn(logits)``. When false,
            the function returns the Python float ``0.0`` (not a tensor),
            so the result cannot be back-propagated.

    Returns:
        Whatever *loss_fn* returns, or ``0.0`` if *track_last_only* is false.
    """
    B, _, _ = embeddings.shape  # also enforces a 3-D input
    dev = embeddings.device
    n_rollout_steps = max(0, n_tokens - 1)

    if n_rollout_steps:
        # Both values are loop-invariant: fetch the embedding matrix and
        # build the extra mask column once rather than on every step.
        w = emb_layer.weight.to(dev)
        mask_column = torch.ones((B, 1), dtype=attention_mask.dtype, device=dev)

    for _ in range(n_rollout_steps):
        probs = _one_step(model, embeddings, attention_mask, amp_dtype)
        probs = probs.to(w.dtype)
        next_embeds = probs @ w  # expected embedding under the predicted distribution
        embeddings = torch.cat([embeddings, next_embeds.unsqueeze(1)], dim=1)
        attention_mask = torch.cat([attention_mask, mask_column], dim=1)

    with cuda_autocast(dtype=amp_dtype):
        out = model(inputs_embeds=embeddings, attention_mask=attention_mask)
        logits = out.logits[:, -1, :]
        loss = loss_fn(logits) if track_last_only else 0.0
    return loss

# ---- Model/Tokenizer loader pinned to a single GPU ----
def load_model_single_device(model_dir: str, device_id: int):
    """Load a causal LM and its tokenizer from *model_dir* onto one device.

    ``<model_dir>/reduced-config.json`` is read first; its ``use_lora`` flag
    selects how the weights are loaded:

    * not set: a plain model from ``<model_dir>/fine-tuned-model``;
    * set, and ``peft`` imported fully: ``AutoPeftModelForCausalLM`` from
      ``fine-tuned-model``, merged via ``merge_and_unload`` when available;
    * set, but the initial ``peft`` import failed: the base model from
      ``<model_dir>/base-model`` with the adapter attached through
      ``PeftModel`` or, failing that, ``model.load_adapter``.

    The tokenizer comes from ``<model_dir>/tokenizer``, falls back to the
    EOS token for padding when no pad token is set, and pads on the left.
    The model is put in eval mode with ``config.use_cache`` disabled when
    that attribute exists.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        FileNotFoundError: if ``reduced-config.json`` is missing.
    """
    try:
        from peft import AutoPeftModelForCausalLM, PeftModel
        peft_available = True
    except Exception:
        peft_available = False

    load_kwargs = {
        "trust_remote_code": True,
        "local_files_only": True,
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
        "offload_state_dict": True,
        "attn_implementation": "sdpa",
        "device_map": {"": device_id},  # place every module on this one device
    }

    config_path = os.path.join(model_dir, "reduced-config.json")
    with open(config_path, "r") as config_file:
        cfg = json.load(config_file)

    wants_lora = cfg.get("use_lora", False)
    finetuned_dir = os.path.join(model_dir, "fine-tuned-model")

    if not wants_lora:
        # Fully fine-tuned checkpoint: load it directly.
        model = AutoModelForCausalLM.from_pretrained(finetuned_dir, **load_kwargs)
    elif peft_available:
        model = AutoPeftModelForCausalLM.from_pretrained(finetuned_dir, **load_kwargs)
        if hasattr(model, "merge_and_unload"):
            model = model.merge_and_unload()
    else:
        # The combined peft import failed; load the base weights and try to
        # attach the adapter by other means.
        base_dir = os.path.join(model_dir, "base-model")
        model = AutoModelForCausalLM.from_pretrained(base_dir, **load_kwargs)
        try:
            from peft import PeftModel
            model = PeftModel.from_pretrained(model, finetuned_dir, is_trainable=False)
        except Exception:
            model.load_adapter(finetuned_dir)

    tokenizer = AutoTokenizer.from_pretrained(
        os.path.join(model_dir, "tokenizer"),
        use_fast=True,
        local_files_only=True,
    )
    no_pad_token = tokenizer.pad_token_id is None
    if no_pad_token and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model.eval()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False
    return model, tokenizer

# ---- Data subset loader with dynamic padding ----
def build_subset_loader(tokenizer, args, indices):
    """Build a DataLoader over a subset of the ``tatsu-lab/alpaca`` train split.

    Each batch is tokenized on the fly from the examples' ``"instruction"``
    field, padded to the longest item in the batch and truncated to
    ``args["max_length"]``.

    Args:
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        args: Mapping providing ``"data_dir"`` (dataset cache directory),
            ``"max_length"`` and ``"batch_size"``.
        indices: Dataset indices to include, iterated in the given order
            (no shuffling).

    Returns:
        A ``DataLoader`` yielding ``(input_ids, attention_mask)`` tuples.
    """
    import os

    ds = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])
    sub = Subset(ds, indices)

    def collate(batch):
        texts = [ex["instruction"] for ex in batch]
        enc = tokenizer(texts, padding=True, truncation=True,
                        max_length=args["max_length"], return_tensors="pt")
        return enc["input_ids"], enc["attention_mask"]

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a value that yields the minimum of two workers.
    cpu_total = os.cpu_count() or 2
    num_workers = max(2, cpu_total // 2)
    return DataLoader(sub,
                      batch_size=args["batch_size"],
                      shuffle=False,
                      pin_memory=True,
                      num_workers=num_workers,
                      persistent_workers=True,
                      collate_fn=collate)

# ---- Your helpers (as you defined earlier) ----
from torch import nn

def find_embedding_layer(model):
    """Locate an ``nn.Embedding`` inside *model*.

    The result of ``model.get_input_embeddings()`` is preferred when that
    method exists and returns an ``nn.Embedding``. Otherwise the first
    ``nn.Embedding`` encountered by ``model.named_modules()`` is used.

    Returns:
        ``(layer, label)`` where *label* is ``'get_input_embeddings()'`` or
        the module's dotted name, or ``(None, None)`` if nothing was found.
    """
    if hasattr(model, 'get_input_embeddings'):
        candidate = model.get_input_embeddings()
        if isinstance(candidate, nn.Embedding):
            return candidate, 'get_input_embeddings()'

    embedding_hits = (
        (submodule, dotted_name)
        for dotted_name, submodule in model.named_modules()
        if isinstance(submodule, nn.Embedding)
    )
    return next(embedding_hits, (None, None))

def freeze_except_embeddings(model, emb_layers):
    """Leave only the given embedding weights trainable.

    Every parameter of *model* gets ``requires_grad = False`` and its
    ``.grad`` cleared, except the ``weight`` tensors of *emb_layers*, which
    get ``requires_grad = True``.

    Args:
        model: Module whose parameters are updated in place.
        emb_layers: A single ``nn.Embedding`` or an iterable of them.

    Raises:
        ValueError: if an entry is not an ``nn.Embedding`` or its weight is
            not one of the model's parameters.
    """
    layers = [emb_layers] if isinstance(emb_layers, nn.Embedding) else emb_layers

    known_params = set(model.parameters())
    for layer in layers:
        if not isinstance(layer, nn.Embedding):
            raise ValueError(f"Expected nn.Embedding, got {type(layer)}")
        if layer.weight not in known_params:
            raise ValueError("Embedding layer weight not in model params")

    trainable = {layer.weight for layer in layers}
    for _name, param in model.named_parameters():
        if param in trainable:
            param.requires_grad = True
            continue
        param.requires_grad = False
        param.grad = None

# ---- Batched vocab search (fast) ----
@torch.no_grad()
def batched_best_scores(emb_layer, grads, embeds_det, attention_mask, topk, vocab_chunk, device_id):
    """Find, per sample, the lowest-scoring (position, vocab token) swaps.

    For every sample ``b``, position ``l`` and vocabulary row ``v`` the score
    is ``weight[v] . grads[b, l] - embeds_det[b, l] . grads[b, l]``. Positions
    whose attention-mask entry is 0 are excluded by setting their score to
    ``+inf``. The vocabulary is processed in chunks of *vocab_chunk* rows so
    the full (B, V, L) score tensor is never materialized.

    The whole search runs under ``torch.no_grad()``: only plain Python
    numbers are returned, so recording an autograd graph through
    ``emb_layer.weight`` for every chunk would be pure overhead.

    Args:
        emb_layer: Module whose ``.weight`` is the (V, E) vocabulary matrix.
        grads: (B, L, E) tensor.
        embeds_det: (B, L, E) tensor.
        attention_mask: (B, L) tensor; entries equal to 0 mark padding.
        topk: 1 to keep only the single best candidate per sample, otherwise
            the number of lowest-scoring candidates to keep.
        vocab_chunk: Number of vocabulary rows scored per iteration.
        device_id: Device on which the running best tensors are created and
            to which each vocabulary slice is moved.

    Returns:
        A list with one dict per sample. With ``topk == 1`` each dict has
        ``best_position``, ``best_vocab_index``, ``min_score`` and
        ``sample_id``; otherwise it has ``topk`` (a list of
        ``{"position", "vocab_index", "score"}`` dicts, lowest score first)
        and ``sample_id``. ``sample_id`` is the row index within this batch.

        NOTE(review): with ``topk == 1``, if no score ever beats ``+inf``
        (e.g. a fully masked row) the stored flat index stays -1, which
        decodes to vocab index -1 and position ``L - 1``.
    """
    B, L, _ = grads.shape
    V = emb_layer.weight.size(0)
    dev = device_id

    s_i = (grads * embeds_det).sum(dim=2)          # (B,L) score of the current token
    mask_b1l = (attention_mask == 0).unsqueeze(1)  # (B,1,L) padding mask

    results = []
    if topk == 1:
        best_vals = torch.full((B,), float("inf"), device=dev, dtype=grads.dtype)
        best_idx = torch.full((B,), -1, device=dev, dtype=torch.long)
    else:
        vals_keep = None
        idx_keep = None

    # `offset` turns an index into the flattened (vchunk, L) chunk into an
    # index into the flattened (V, L) grid; it always equals start * L.
    offset = 0
    for start in range(0, V, vocab_chunk):
        end = min(start + vocab_chunk, V)
        vocab_slice = emb_layer.weight[start:end].to(dev, non_blocking=True)  # (vchunk,E)
        scores = torch.einsum("ve,ble->bvl", vocab_slice, grads)  # (B,vchunk,L)
        scores = scores - s_i.unsqueeze(1)                        # (B,1,L) broadcast
        scores = scores.masked_fill(mask_b1l, float("inf"))
        flat = scores.reshape(B, -1)

        if topk == 1:
            chunk_vals, chunk_idx = torch.min(flat, dim=1)
            upd = chunk_vals < best_vals
            best_vals = torch.where(upd, chunk_vals, best_vals)
            best_idx = torch.where(upd, chunk_idx + offset, best_idx)
        else:
            k_here = min(topk, flat.size(1))
            chunk_vals, chunk_idx = torch.topk(flat, k=k_here, largest=False, dim=1)
            chunk_idx = chunk_idx + offset
            if vals_keep is None:
                vals_keep, idx_keep = chunk_vals, chunk_idx
            else:
                # Merge this chunk's candidates with the running best and
                # re-select the k smallest from the union.
                vals_keep = torch.cat([vals_keep, chunk_vals], dim=1)
                idx_keep = torch.cat([idx_keep, chunk_idx], dim=1)
                k_sel = min(topk, vals_keep.size(1))
                sel_vals, sel_pos = torch.topk(vals_keep, k=k_sel, largest=False, dim=1)
                batch_ids = torch.arange(B, device=dev).unsqueeze(1).expand_as(sel_pos)
                idx_keep = idx_keep[batch_ids, sel_pos]
                vals_keep = sel_vals

        del vocab_slice, scores, flat
        torch.cuda.empty_cache()
        offset += (end - start) * L

    # Decode flat indices: flat = vocab_index * L + position.
    if topk == 1:
        best_v = (best_idx // L).tolist()
        best_p = (best_idx % L).tolist()
        best_vl = best_vals.tolist()
        for b in range(B):
            results.append({
                "best_position": int(best_p[b]),
                "best_vocab_index": int(best_v[b]),
                "min_score": float(best_vl[b]),
                "sample_id": int(b)
            })
    else:
        v_idx = (idx_keep // L).tolist()
        pos_i = (idx_keep % L).tolist()
        vals = vals_keep.tolist()
        for b in range(B):
            pairs = [{"position": int(pos_i[b][j]),
                      "vocab_index": int(v_idx[b][j]),
                      "score": float(vals[b][j])}
                     for j in range(len(v_idx[b]))]
            results.append({"topk": pairs, "sample_id": int(b)})
    return results

# ---- Thread worker (no __main__ needed) ----
def worker_thread(device_id, model_dir, args, indices, vocab_chunk, topk, n_tokens, out_queue):
    """Process one shard of the dataset on a single GPU and report via a queue.

    The worker loads its own model and tokenizer on *device_id*, iterates
    the subset given by *indices*, back-propagates the entropy loss to the
    input embeddings and runs ``batched_best_scores`` on the resulting
    gradients.

    Exactly one item is put on *out_queue*:

    * on success, the list of per-sample result dicts; each is tagged with
      ``"device"`` and ``"dataset_index"`` (the entry of *indices* the
      sample came from), since ``"sample_id"`` is only the position inside
      its batch;
    * on failure, the tuple ``("ERROR", device_id, message)`` where
      *message* contains the exception type, its text and the traceback.

    Args:
        device_id: CUDA device index used for the model and all tensors.
        model_dir: Directory passed to ``load_model_single_device``.
        args: Mapping passed to ``build_subset_loader``.
        indices: Indexable sequence of dataset indices for this worker.
        vocab_chunk: Forwarded to ``batched_best_scores``.
        topk: Forwarded to ``batched_best_scores``.
        n_tokens: Forwarded to ``compute_loss``.
        out_queue: Queue receiving the single result or error message.
    """
    import traceback

    try:
        torch.cuda.set_device(device_id)
        model, tokenizer = load_model_single_device(model_dir, device_id)
        loader = build_subset_loader(tokenizer, args, indices)
        emb_layer, _ = find_embedding_layer(model)
        if emb_layer is None:
            raise RuntimeError("No nn.Embedding layer found in the model")

        # Which parameters are trainable does not change between batches,
        # so configure it once instead of walking all parameters per batch.
        freeze_except_embeddings(model, emb_layer)

        local_results = []
        seen = 0  # samples processed so far; the loader keeps `indices` order
        for batch in loader:
            input_ids = batch[0].to(device_id, non_blocking=True)
            attention = batch[1].to(device_id, non_blocking=True)

            embeddings = emb_layer(input_ids).to(emb_layer.weight.dtype)
            # Non-leaf tensor: its gradient is only kept when requested.
            embeddings.retain_grad()

            loss = compute_loss(model, emb_layer, embeddings, attention,
                                loss_fn=entropy_loss, n_tokens=n_tokens,
                                amp_dtype=torch.float16, track_last_only=True)
            model.zero_grad(set_to_none=True)
            loss.backward()

            grads = embeddings.grad.detach()
            embeds_det = embeddings.detach()

            batch_results = batched_best_scores(
                emb_layer, grads, embeds_det, attention,
                topk=topk, vocab_chunk=vocab_chunk, device_id=device_id
            )
            for row, r in enumerate(batch_results):
                r["device"] = int(device_id)  # tag with device for debugging
                r["dataset_index"] = int(indices[seen + row])
            seen += len(batch_results)
            local_results.extend(batch_results)

            del grads, embeds_det, embeddings, loss, input_ids, attention
            torch.cuda.empty_cache(); gc.collect()

        out_queue.put(local_results)
    except Exception as e:
        # Exceptions cannot propagate out of a thread; ship a description
        # that includes the type and traceback so the failure is diagnosable.
        detail = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
        out_queue.put(("ERROR", int(device_id), detail))

# ---- Notebook-friendly 2×GPU runner ----
def run_two_gpu_threads(model_dir, args, vocab_chunk=8192, topk=1, n_tokens=5):
    """Split the dataset across cuda:0 and cuda:1 and gather both workers' results.

    The ``tatsu-lab/alpaca`` train split is loaded once to get its length.
    Even indices go to device 0 and odd indices to device 1. Each shard is
    handled by a daemon thread running ``worker_thread``; one message per
    thread is read from a shared queue in completion order.

    Returns:
        The concatenated result lists of both workers.

    Raises:
        RuntimeError: as soon as a worker reports an ``("ERROR", ...)``
            message; the other thread is not joined in that case.
    """
    # The dataset is only needed here for its size.
    dataset = load_dataset("tatsu-lab/alpaca", split="train", cache_dir=args["data_dir"])
    total = len(dataset)
    shards = [list(range(gpu, total, 2)) for gpu in (0, 1)]
    del dataset

    result_queue = queue.Queue()
    workers = [
        threading.Thread(
            target=worker_thread,
            args=(gpu, model_dir, args, shard, vocab_chunk, topk, n_tokens, result_queue),
            daemon=True,
        )
        for gpu, shard in enumerate(shards)
    ]
    for worker in workers:
        worker.start()

    # One blocking read per worker, whichever finishes first.
    collected = []
    for _ in workers:
        message = result_queue.get()
        is_error = isinstance(message, tuple) and message[0] == "ERROR"
        if is_error:
            _, dev, err = message
            raise RuntimeError(f"Thread on cuda:{dev} failed: {err}")
        collected.extend(message)

    for worker in workers:
        worker.join()
    return collected
# ==============================================================================

In [2]:
# Allocator setting; setdefault keeps any value already present in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

# Settings read by the data loader: dataset cache dir, truncation length, batch size.
# NOTE(review): data_dir is a relative path while model_dir is absolute — confirm intended.
args = dict(
    data_dir="kaggle/working/data",
    max_length=512,
    batch_size=2,
)

# Must contain reduced-config.json plus the model and tokenizer sub-directories.
model_dir = "/kaggle/tmp/id-00000000"

run_options = {
    "vocab_chunk": 8192,   # tune per VRAM (4096–16384)
    "n_tokens": 10,
    "topk": 5,
}
results = run_two_gpu_threads(model_dir=model_dir, args=args, **run_options)

print(f"Collected {len(results)} tweaks from both GPUs.")


2025-10-26 15:31:33.141668: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761492693.319248     141 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761492693.374381     141 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


RuntimeError: Thread on cuda:1 failed: [Errno 2] No such file or directory: '/kaggle/tmp/id-00000000/reduced-config.json'