# Login wandb project

In [None]:
#!pip install -U transformers torch

# Mount google drive if using colab

In [1]:
from google.colab import drive
drive.mount('/content/drive')


ModuleNotFoundError: No module named 'google.colab'

In [2]:
# Add this import at the top of your model loading cell (cell 6)

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.utils as utils  # Add this line
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from huggingface_hub import login
from tqdm import tqdm


class PreferenceDataset(Dataset):
    def __init__(self, json_path):
        with open(json_path, 'r') as f:
            raw_data = json.load(f)

        self.data = []
        for question, entry in raw_data.items():
            for hop, hop_data in entry["hops"].items():
                queries = hop_data["queries"]
                preferences = hop_data["preference_pairs"]
                for i, j in preferences:
                    self.data.append({
                        "question": question,
                        "preferred": queries[i],
                        "dispreferred": queries[j]
                    })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [None]:
!pip install transformers==4.40.0

In [3]:
login("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")

# Use a smaller model to avoid memory issues
model_name = "microsoft/DialoGPT-medium"  # Much smaller than LLaMA-3-8B
# Alternative: "distilbert-base-uncased" or "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Load model with more memory-efficient settings
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",  # Let it decide device placement
    low_cpu_mem_usage=True,  # More memory efficient loading
    trust_remote_code=True
)

model.train()

# Training configuration
config = {
    "model": model_name,
    "optimizer": "AdamW",
    "lr": 5e-6,
    "tau": 0.05,
    "batch_size": 1,  # Reduced batch size
    "epochs": 3,
}

# === Helpers ===
def safe_ipo_loss(logp_win, logp_lose, tau=0.05):
    """Safe IPO loss with NaN detection and clipping"""
    # Ensure inputs are valid
    if torch.isnan(logp_win).any() or torch.isnan(logp_lose).any():
        print(f"NaN detected in logp: win={logp_win.item()}, lose={logp_lose.item()}")
        return torch.tensor(0.0, device=logp_win.device, requires_grad=True)
    
    # Clip extreme values
    logp_win = torch.clamp(logp_win, min=-10, max=10)
    logp_lose = torch.clamp(logp_lose, min=-10, max=10)
    
    # Compute IPO loss with safety checks
    diff = logp_win - logp_lose - 0.5 / tau
    loss = (diff ** 2).mean()
    
    # Check for NaN in loss
    if torch.isnan(loss):
        print(f"NaN in loss computation: diff={diff.item()}")
        return torch.tensor(0.0, device=logp_win.device, requires_grad=True)
    
    return loss

def safe_compute_logp(prompt, completion):
    """Safe log probability computation with error handling"""
    try:
        full_input = prompt + completion
        
        # Shorter sequences to avoid issues
        encoded = tokenizer(full_input, return_tensors="pt", padding=True, truncation=True, max_length=128)
        
        device = next(model.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Get prompt length
        prompt_encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
        prompt_len = prompt_encoded.input_ids.shape[-1]

        labels = input_ids.clone()
        labels[:, :prompt_len] = -100  # mask out prompt

        # Use full precision for stability
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            
        logp = -outputs.loss
        
        # Safety checks
        if torch.isnan(logp) or torch.isinf(logp):
            print(f"Invalid logp: {logp.item()}")
            return torch.tensor(-1.0, device=device, requires_grad=True)
            
        return logp.detach().requires_grad_(True)
        
    except Exception as e:
        print(f"Error in compute_logp: {e}")
        return torch.tensor(-1.0, device=next(model.parameters()).device, requires_grad=True)

        
print(f"Model {model_name} loaded successfully!")
print(f"Model is on device: {next(model.parameters()).device}")
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

Model microsoft/DialoGPT-medium loaded successfully!
Model is on device: cuda:0
GPU memory allocated: 0.73 GB


In [None]:
def clear_cuda_cache():
    """Clear CUDA cache to free up memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [None]:
dataset_path = 'preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05
num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    valid_batches = 0

    pbar = tqdm(dataloader, desc=f"SAFE Epoch {epoch+1}")
    for batch_idx, batch in enumerate(pbar):
        try:
            questions = batch["question"]
            preferred = batch["preferred"] 
            dispreferred = batch["dispreferred"]

            logp_w_list, logp_l_list = [], []

            for q, w, l in zip(questions, preferred, dispreferred):
                # Clean the queries (remove extra formatting)
                w_clean = w.strip()
                l_clean = l.strip()
                
                prompt = f"Query: "
                
                print(f"Batch {batch_idx}: Computing logp...")
                logp_w = safe_compute_logp(prompt, w_clean)
                logp_l = safe_compute_logp(prompt, l_clean)
                
                print(f"  Preferred logp: {logp_w.item():.4f}")
                print(f"  Dispreferred logp: {logp_l.item():.4f}")
                
                logp_w_list.append(logp_w)
                logp_l_list.append(logp_l)

            if logp_w_list and logp_l_list:
                logp_w_batch = torch.stack(logp_w_list)
                logp_l_batch = torch.stack(logp_l_list)
                
                loss = safe_ipo_loss(logp_w_batch, logp_l_batch, tau)
                
                print(f"  Computed loss: {loss.item():.6f}")
                
                # Skip if loss is 0 (safety fallback)
                if loss.item() > 0:
                    optimizer.zero_grad()
                    loss.backward()
                    
                    # Gradient clipping for stability
                    utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    
                    optimizer.step()
                    
                    total_loss += loss.item()
                    valid_batches += 1
                    
                    print(f"  Backprop completed successfully")
                else:
                    print(f"  Skipped batch due to zero loss")
                
                # Clean up
                del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss

            clear_cuda_cache()

            if valid_batches > 0:
                avg_loss = total_loss / valid_batches
                pbar.set_postfix({"loss": avg_loss, "valid_batches": valid_batches})
                print(f"Average loss: {avg_loss:.6f}")
            
            # Stop after a few batches for testing
            if batch_idx >= 5:
                print(f"Stopping after {batch_idx+1} batches for testing")
                break
                
        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            clear_cuda_cache()
            continue

    if valid_batches > 0:
        final_avg = total_loss / valid_batches
        print(f"[Epoch {epoch + 1}] Valid batches: {valid_batches}, Average Loss: {final_avg:.6f}")
    else:
        print(f"[Epoch {epoch + 1}] No valid batches processed!")
    
    clear_cuda_cache()

print("SAFE training completed!")


In [7]:
# === TEST TRAINING - Stop after 1 iteration ===
dataset_path = 'preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05
num_epochs = 1  # Only 1 epoch for testing

print(f"Model is on device: {next(model.parameters()).device}")
print(f"Starting test run - will stop after 1 iteration")
print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# Define missing functions
def compute_logp(prompt, completion):
    """Basic log probability computation"""
    return safe_compute_logp(prompt, completion)

def ipo_loss(logp_win, logp_lose, tau=0.05):
    """Basic IPO loss"""
    return safe_ipo_loss(logp_win, logp_lose, tau)
i = 0
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Test Epoch {epoch+1}")
    for batch_idx, batch in enumerate(pbar):
        print(f"Processing batch {batch_idx}")
        
        questions = batch["question"]
        preferred = batch["preferred"]
        dispreferred = batch["dispreferred"]
        
        print(f"Question: {questions[0][:100]}...")
        print(f"Preferred: {preferred[0]}")
        print(f"Dispreferred: {dispreferred[0]}")

        logp_w_list, logp_l_list = [], []

        for q, w, l in zip(questions, preferred, dispreferred):
            try:
                prompt = f"Generate a search query for the following question:\n{q}\nQuery:"
                print(f"Computing logp for preferred: {w}")
                logp_w = compute_logp(prompt, " " + w)
                print(f"Preferred logp: {logp_w.item():.4f}")
                
                print(f"Computing logp for dispreferred: {l}")
                logp_l = compute_logp(prompt, " " + l)
                print(f"Dispreferred logp: {logp_l.item():.4f}")
                
                logp_w_list.append(logp_w)
                logp_l_list.append(logp_l)
                print(f"Batch {batch_idx} completed successfully")
            except RuntimeError as e:
                print(f"Error in batch {batch_idx}: {e}")
                clear_cuda_cache()
                continue

        if logp_w_list and logp_l_list:
            logp_w_batch = torch.stack(logp_w_list)
            logp_l_batch = torch.stack(logp_l_list)
            loss = ipo_loss(logp_w_batch, logp_l_batch, tau)
            
            print(f"IPO Loss: {loss.item():.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            print(f"Loss computed and backprop completed: {loss.item():.4f}")
            
            # Clear intermediate tensors and cache
            del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
            
        clear_cuda_cache()
        print(f"GPU memory after batch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        avg_loss = total_loss / max(1, batch_idx + 1)
        pbar.set_postfix({"loss": avg_loss})
        print(f"Test loss: {avg_loss:.6f}")
        
        # STOP AFTER FIRST ITERATION FOR TESTING
        print("TEST COMPLETE - Stopping after 1 iteration")
        if i == 10:
            break
        i += 1

    print(f"[Test Epoch {epoch + 1}] Average Loss: {total_loss / max(1, 1):.4f}")
    clear_cuda_cache()

print("Test run completed successfully!")
print(f"Final GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

Model is on device: cuda:0
Starting test run - will stop after 1 iteration
GPU memory before training: 0.74 GB


Test Epoch 1:   0%|          | 2/70073 [00:00<1:03:41, 18.33it/s, loss=76]  

Processing batch 0
Question: What is the name of the engagement on the northern coast of France on 19 August 1942 in which the HM...
Preferred: Query: HMS Calpe engagement northern coast France 19
Dispreferred: launch point for a fleet of small boats that attacked a French convoy in the Gironde.
 In 1805 he sighted the French fleet under Admiral Ganteaume attempting to escape and warned the Offshore Squadron, who drove the French back into Brest in a brief engagement.

Generate a search query for the following question:
What is the name of the engagement on the northern coast of France on 19 August 1942 in which the HMS "Calpe" participated? 
Query:engagement northern coast France 1942 HMS
Computing logp for preferred: Query: HMS Calpe engagement northern coast France 19
Preferred logp: -6.4026
Computing logp for dispreferred: launch point for a fleet of small boats that attacked a French convoy in the Gironde.
 In 1805 he sighted the French fleet under Admiral Ganteaume attempting to 

Test Epoch 1:   0%|          | 8/70073 [00:00<49:09, 23.75it/s, loss=93.5]  

Dispreferred logp: -5.6678
Batch 4 completed successfully
IPO Loss: 115.6833
Loss computed and backprop completed: 115.6833
CUDA cache cleared.
GPU memory after batch: 0.74 GB
Test loss: 83.923558
TEST COMPLETE - Stopping after 1 iteration
Processing batch 5
Question: What type of gesture does Donets-Krivoy Rog Soviet Republic and have in common?...
Preferred: Examples:
Question:Comedian Frank Gorshin was a member of The Kopykats, a group of comedians on which television variety show in 1972?
Query:Frank Gorshin The Kopykats television variety show 1972

Generate a search query for the following question:
What type of gesture does Donets-Krivoy Rog Soviet Republic and have in common? Donets-Krivoy Rog Soviet Republic and the People's
Dispreferred: (The answer is a flag.)

Query:Donets-K
Computing logp for preferred: Examples:
Question:Comedian Frank Gorshin was a member of The Kopykats, a group of comedians on which television variety show in 1972?
Query:Frank Gorshin The Kopykats tele

Test Epoch 1:   0%|          | 8/70073 [00:00<49:09, 23.75it/s, loss=92.2]

Dispreferred logp: -6.6928
Batch 9 completed successfully
IPO Loss: 80.6311
Loss computed and backprop completed: 80.6311
CUDA cache cleared.
GPU memory after batch: 0.74 GB
Test loss: 92.183425
TEST COMPLETE - Stopping after 1 iteration
Processing batch 10
Question: Who nominated the justice who wrote the majority opinion in the Supreme Court case Christopher v. Sm...
Preferred: Question:Are Stuart Murdoch and Lee Sung-min both singers?
Query:Are Stuart Murdoch and Lee Sung-min both singers?

Generate a search query for the following question:
Who nominated the justice who wrote the majority opinion in the Supreme Court case Christopher v. SmithKline Beecham Corp? 
Query:Who nominated the justice who wrote the majority opinion
Dispreferred: genus has more species, Worsleya or Gordonia?
Query:Worsleya vs Gordonia number of species

Generate a search query for the following question:
Who nominated the justice who wrote the majority opinion in the Supreme Court case Christopher v. SmithK

Test Epoch 1:   0%|          | 10/70073 [00:00<54:38, 21.37it/s, loss=93.2]

Preferred logp: -5.2667
Computing logp for dispreferred: genus has more species, Worsleya or Gordonia?
Query:Worsleya vs Gordonia number of species

Generate a search query for the following question:
Who nominated the justice who wrote the majority opinion in the Supreme Court case Christopher v. SmithKline Beecham Corp? 
Query:Christopher v. SmithKline Beecham
Dispreferred logp: -5.0829
Batch 10 completed successfully
IPO Loss: 103.7095
Loss computed and backprop completed: 103.7095
CUDA cache cleared.
GPU memory after batch: 0.74 GB
Test loss: 93.231250
TEST COMPLETE - Stopping after 1 iteration
[Test Epoch 1] Average Loss: 1025.5438
CUDA cache cleared.
Test run completed successfully!
Final GPU memory: 0.74 GB





In [5]:
# prompt: free the cuda gpu memory aggresivley, atomic way, nuclear it : "OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 76.88 MiB is free. Process 72946 has 39.47 GiB memory in use. Of the allocated memory 38.61 GiB is allocated by PyTorch, and 371.07 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
# "

# Function to clear CUDA cache
def clear_cuda_cache():
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache cleared.")

# Call this function after each batch or at the end of each epoch
# Example usage within your training loop:
# ... (inside the for batch loop) ...
#         optimizer.step()
#         clear_cuda_cache() # Call after optimizer step
# ... (outside the for batch loop) ...
# print(f"[Epoch {epoch + 1}] Average Loss: {total_loss / len(dataloader):.4f}")
# clear_cuda_cache() # Call at the end of the epoch

# You can also call it at the beginning of the loop or whenever you suspect memory issues.

# Alternatively, you can also try deleting tensors that are no longer needed.
# For example, inside the batch loop:
# del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
# clear_cuda_cache()

# Add this function definition somewhere before the training loop.


In [None]:
# Nuclear GPU memory clearing - run this first
import torch
import gc
import os

def nuclear_gpu_clear():
    """Aggressively clear GPU memory"""
    if torch.cuda.is_available():
        # Clear all cached memory
        torch.cuda.empty_cache()
        
        # Force garbage collection
        gc.collect()
        
        # Reset memory stats
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        
        # Try to clear any remaining allocations
        try:
            torch.cuda.synchronize()
        except:
            pass
            
        print(f"GPU memory cleared. Free: {torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()} bytes")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory} bytes")
        print(f"Currently allocated: {torch.cuda.memory_allocated()} bytes")

# Set environment variable for memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Clear GPU memory now
nuclear_gpu_clear()