In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel

# -----------------------------
# Models
# -----------------------------
model_name = "Qwen/Qwen2.5-14B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00,  1.22it/s]


In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
device = 'cuda'
hidden_size = model.config.hidden_size
scale = 2.0

model.eval()
model.gradient_checkpointing_enable()
for p in model.parameters():
    p.requires_grad = False

In [3]:
import pandas as pd
from tqdm.auto import tqdm 

# -----------------------------
# Dataset
# -----------------------------
class LatentVectorDataset(Dataset):
    def __init__(self, df):
        self.prompts = df['full_prompt'].tolist()
        self.targets = df['original_completion'].tolist()

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx], self.targets[idx]

df = pd.read_csv("scenarios_cleaned.csv")
dataset = LatentVectorDataset(df)
train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
train_ds, test_ds = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

In [4]:
v = torch.randn(hidden_size, device=device, dtype=torch.bfloat16, requires_grad=True)
optimizer = torch.optim.Adam([v], lr=1e-2)

# Hook
handles = []
def unregister_all_hooks(handles):
    for h in handles: h.remove()
    handles.clear()

def random_hook(module, input, output):
    hidden = output[0]
    v_scaled = (v / v.norm(p=2)) * scale
    return (hidden + v_scaled.to(hidden.device),) + output[1:]

for layer in model.model.layers:
    handles.append(layer.register_forward_hook(random_hook))

# -----------------------------
# Trainable vector
# -----------------------------
v = torch.randn(hidden_size, device=device, dtype=torch.bfloat16, requires_grad=True)
optimizer = torch.optim.Adam([v], lr=1e-2)

# Hook
handles = []
def unregister_all_hooks(handles):
    for h in handles: h.remove()
    handles.clear()

def random_hook(module, input, output):
    hidden = output[0]
    # Normalization moved to after optimizer step to prevent vanishing gradients
    v_scaled = v * scale
    return (hidden + v_scaled.to(hidden.device),) + output[1:]

# -----------------------------
# Training loop with tqdm
# -----------------------------
max_new_tokens = 50
eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
epochs = 10

torch.cuda.empty_cache()

for epoch in range(epochs):
    # --- TQDM WRAPPER: Wrap train_loader with tqdm for a live progress bar ---
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    
    for prompt, target in progress_bar:
        prompt_ids = tokenizer(prompt[0], return_tensors="pt").input_ids.to(device)
        
        optimizer.zero_grad()

        gen_ids = prompt_ids.clone()
        kl_loss_seq = []

        for step in range(max_new_tokens):
            unregister_all_hooks(handles)
            with torch.no_grad():
                logits_base = model(gen_ids, use_cache=False).logits[:, -1, :]
            
            for layer in model.model.layers:
                handles.append(layer.register_forward_hook(random_hook))

            logits_hook = model(gen_ids, use_cache=False).logits[:, -1, :]

            log_probs_base = F.log_softmax(logits_base, dim=-1)
            kl = F.kl_div(F.log_softmax(logits_hook, dim=-1), log_probs_base, reduction='batchmean', log_target=True)
            kl_loss_seq.append(kl)

            with torch.no_grad():
                next_token = torch.multinomial(F.softmax(logits_base, dim=-1), num_samples=1)
                gen_ids = torch.cat([gen_ids, next_token], dim=1)

            if next_token.item() == eos_id:
                break

        if not kl_loss_seq: continue

        avg_kl = torch.stack(kl_loss_seq).mean()
        
        loss = -avg_kl
        loss.backward()
        
        optimizer.step()

        # Normalize v *after* the optimizer step
        with torch.no_grad():
            v.copy_((v / v.norm(p=2)))
            
        # --- TQDM UPDATE: Show the latest KL divergence in the progress bar ---
        progress_bar.set_postfix(kl_div=f"{avg_kl.item():.4f}")

unregister_all_hooks(handles)

Epoch 1/10:   2%|▏         | 8/384 [00:44<34:45,  5.55s/it, kl_div=25.5000]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 79.22 GiB of which 12.06 MiB is free. Process 78354 has 17.65 GiB memory in use. Process 1317044 has 61.54 GiB memory in use. Of the allocated memory 59.70 GiB is allocated by PyTorch, and 1.11 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)