# Login wandb project

In [None]:
#!pip install -U transformers torch

# Mount google drive if using colab

In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [8]:
# Add this import at the top of your model loading cell (cell 6)

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.utils as utils  # Add this line
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from huggingface_hub import login
from tqdm import tqdm


class PreferenceDataset(Dataset):
    def __init__(self, json_path):
        with open(json_path, 'r') as f:
            raw_data = json.load(f)

        self.data = []
        for question, entry in raw_data.items():
            for hop, hop_data in entry["hops"].items():
                queries = hop_data["queries"]
                preferences = hop_data["preference_pairs"]
                for i, j in preferences:
                    self.data.append({
                        "question": question,
                        "preferred": queries[i],
                        "dispreferred": queries[j]
                    })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [None]:
!pip install transformers==4.40.0

In [None]:
login("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")

# Use a smaller model to avoid memory issues
model_name = "microsoft/DialoGPT-medium"  # Much smaller than LLaMA-3-8B
# Alternative: "distilbert-base-uncased" or "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Load model with more memory-efficient settings
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",  # Let it decide device placement
    low_cpu_mem_usage=True,  # More memory efficient loading
    trust_remote_code=True
)

model.train()

# Training configuration
config = {
    "model": model_name,
    "optimizer": "AdamW",
    "lr": 5e-6,
    "tau": 0.05,
    "batch_size": 1,  # Reduced batch size
    "epochs": 3,
}

# === Helpers ===
def safe_ipo_loss(logp_win, logp_lose, tau=0.05):
    """Safe IPO loss with NaN detection and clipping"""
    # Ensure inputs are valid
    if torch.isnan(logp_win).any() or torch.isnan(logp_lose).any():
        print(f"NaN detected in logp: win={logp_win.item()}, lose={logp_lose.item()}")
        return torch.tensor(0.0, device=logp_win.device, requires_grad=True)
    
    # Clip extreme values
    logp_win = torch.clamp(logp_win, min=-10, max=10)
    logp_lose = torch.clamp(logp_lose, min=-10, max=10)
    
    # Compute IPO loss with safety checks
    diff = logp_win - logp_lose - 0.5 / tau
    loss = (diff ** 2).mean()
    
    # Check for NaN in loss
    if torch.isnan(loss):
        print(f"NaN in loss computation: diff={diff.item()}")
        return torch.tensor(0.0, device=logp_win.device, requires_grad=True)
    
    return loss

def safe_compute_logp(prompt, completion):
    """Safe log probability computation with error handling"""
    try:
        full_input = prompt + completion
        
        # Shorter sequences to avoid issues
        encoded = tokenizer(full_input, return_tensors="pt", padding=True, truncation=True, max_length=128)
        
        device = next(model.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Get prompt length
        prompt_encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
        prompt_len = prompt_encoded.input_ids.shape[-1]

        labels = input_ids.clone()
        labels[:, :prompt_len] = -100  # mask out prompt

        # Use full precision for stability
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            
        logp = -outputs.loss
        
        # Safety checks
        if torch.isnan(logp) or torch.isinf(logp):
            print(f"Invalid logp: {logp.item()}")
            return torch.tensor(-1.0, device=device, requires_grad=True)
            
        return logp.detach().requires_grad_(True)
        
    except Exception as e:
        print(f"Error in compute_logp: {e}")
        return torch.tensor(-1.0, device=next(model.parameters()).device, requires_grad=True)

        
print(f"Model {model_name} loaded successfully!")
print(f"Model is on device: {next(model.parameters()).device}")
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [10]:
def clear_cuda_cache():
    """Clear CUDA cache to free up memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [None]:
dataset_path = 'preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05
num_epochs = 1

def compute_logp(prompt, completion):
    """Basic log probability computation"""
    return safe_compute_logp(prompt, completion)

def ipo_loss(logp_win, logp_lose, tau=0.05):
    """Basic IPO loss"""
    return safe_ipo_loss(logp_win, logp_lose, tau)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, batch in enumerate(pbar):
        questions = batch["question"]
        preferred = batch["preferred"] 
        dispreferred = batch["dispreferred"]

        logp_w_list = []
        logp_l_list = []

        for q, w, l in zip(questions, preferred, dispreferred):
            prompt = f"Query: "
            
            logp_w = compute_logp(prompt, w.strip())
            logp_l = compute_logp(prompt, l.strip())
            
            logp_w_list.append(logp_w)
            logp_l_list.append(logp_l)

        logp_w_batch = torch.stack(logp_w_list)
        logp_l_batch = torch.stack(logp_l_list)
        
        loss = ipo_loss(logp_w_batch, logp_l_batch, tau)
        
        optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        # Update progress bar with current loss
        avg_loss = total_loss / (batch_idx + 1)
        pbar.set_postfix({"loss": f"{avg_loss:.4f}"})

    avg_loss = total_loss / len(dataloader)


In [None]:
# === TEST TRAINING - Stop after 1 iteration ===
dataset_path = 'preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05
num_epochs = 1  # Only 1 epoch for testing

print(f"Model is on device: {next(model.parameters()).device}")
print(f"Starting test run - will stop after 1 iteration")
print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# Define missing functions
def compute_logp(prompt, completion):
    """Basic log probability computation"""
    return safe_compute_logp(prompt, completion)

def ipo_loss(logp_win, logp_lose, tau=0.05):
    """Basic IPO loss"""
    return safe_ipo_loss(logp_win, logp_lose, tau)
i = 0
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Test Epoch {epoch+1}")
    for batch_idx, batch in enumerate(pbar):
        print(f"Processing batch {batch_idx}")
        
        questions = batch["question"]
        preferred = batch["preferred"]
        dispreferred = batch["dispreferred"]
        
        print(f"Question: {questions[0][:100]}...")
        print(f"Preferred: {preferred[0]}")
        print(f"Dispreferred: {dispreferred[0]}")

        logp_w_list, logp_l_list = [], []

        for q, w, l in zip(questions, preferred, dispreferred):
            try:
                prompt = f"Generate a search query for the following question:\n{q}\nQuery:"
                print(f"Computing logp for preferred: {w}")
                logp_w = compute_logp(prompt, " " + w)
                print(f"Preferred logp: {logp_w.item():.4f}")
                
                print(f"Computing logp for dispreferred: {l}")
                logp_l = compute_logp(prompt, " " + l)
                print(f"Dispreferred logp: {logp_l.item():.4f}")
                
                logp_w_list.append(logp_w)
                logp_l_list.append(logp_l)
                print(f"Batch {batch_idx} completed successfully")
            except RuntimeError as e:
                print(f"Error in batch {batch_idx}: {e}")
                clear_cuda_cache()
                continue

        if logp_w_list and logp_l_list:
            logp_w_batch = torch.stack(logp_w_list)
            logp_l_batch = torch.stack(logp_l_list)
            loss = ipo_loss(logp_w_batch, logp_l_batch, tau)
            
            print(f"IPO Loss: {loss.item():.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            print(f"Loss computed and backprop completed: {loss.item():.4f}")
            
            # Clear intermediate tensors and cache
            del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
            
        clear_cuda_cache()
        print(f"GPU memory after batch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        avg_loss = total_loss / max(1, batch_idx + 1)
        pbar.set_postfix({"loss": avg_loss})
        print(f"Test loss: {avg_loss:.6f}")
        
        # STOP AFTER FIRST ITERATION FOR TESTING
        print("TEST COMPLETE - Stopping after 1 iteration")
        if i == 10:
            break
        i += 1

    print(f"[Test Epoch {epoch + 1}] Average Loss: {total_loss / max(1, 1):.4f}")
    clear_cuda_cache()

print("Test run completed successfully!")
print(f"Final GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [5]:
# prompt: free the cuda gpu memory aggresivley, atomic way, nuclear it : "OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 76.88 MiB is free. Process 72946 has 39.47 GiB memory in use. Of the allocated memory 38.61 GiB is allocated by PyTorch, and 371.07 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
# "

# Function to clear CUDA cache
def clear_cuda_cache():
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache cleared.")

# Call this function after each batch or at the end of each epoch
# Example usage within your training loop:
# ... (inside the for batch loop) ...
#         optimizer.step()
#         clear_cuda_cache() # Call after optimizer step
# ... (outside the for batch loop) ...
# print(f"[Epoch {epoch + 1}] Average Loss: {total_loss / len(dataloader):.4f}")
# clear_cuda_cache() # Call at the end of the epoch

# You can also call it at the beginning of the loop or whenever you suspect memory issues.

# Alternatively, you can also try deleting tensors that are no longer needed.
# For example, inside the batch loop:
# del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
# clear_cuda_cache()

# Add this function definition somewhere before the training loop.


In [None]:
# Nuclear GPU memory clearing - run this first
import torch
import gc
import os

def nuclear_gpu_clear():
    """Aggressively clear GPU memory"""
    if torch.cuda.is_available():
        # Clear all cached memory
        torch.cuda.empty_cache()
        
        # Force garbage collection
        gc.collect()
        
        # Reset memory stats
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        
        # Try to clear any remaining allocations
        try:
            torch.cuda.synchronize()
        except:
            pass
            
        print(f"GPU memory cleared. Free: {torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()} bytes")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory} bytes")
        print(f"Currently allocated: {torch.cuda.memory_allocated()} bytes")

# Set environment variable for memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Clear GPU memory now
nuclear_gpu_clear()

In [14]:
# Test compute_logp function with various examples
import torch

def test_compute_logp_examples():
    """Test compute_logp with various prompt-completion pairs"""
    
    # Test examples with different scenarios
    test_cases = [
        {
            "name": "Simple Query Generation",
            "prompt": "Query: ",
            "completions": [
                "what is machine learning",
                "machine learning definition",
                "ML basics",
                "artificial intelligence",
                "random text here"
            ]
        },
        {
            "name": "Question-Answer Format",
            "prompt": "Question: What is the capital of France?\nAnswer:",
            "completions": [
                " Paris",
                " London", 
                " Berlin",
                " The capital is Paris",
                " I don't know"
            ]
        },
        {
            "name": "Search Query Context",
            "prompt": "Generate a search query for: How does photosynthesis work?\nQuery:",
            "completions": [
                " photosynthesis process",
                " how photosynthesis works",
                " plant biology photosynthesis",
                " chlorophyll sunlight conversion",
                " unrelated topic"
            ]
        },
        {
            "name": "Short vs Long Completions",
            "prompt": "Query: ",
            "completions": [
                "AI",
                "artificial intelligence",
                "artificial intelligence and machine learning overview",
                "comprehensive guide to artificial intelligence and machine learning technologies",
                ""  # empty completion
            ]
        }
    ]
    
    print("=" * 80)
    print("TESTING compute_logp FUNCTION")
    print("=" * 80)
    
    for test_case in test_cases:
        print(f"\n📋 Test Case: {test_case['name']}")
        print(f"Prompt: '{test_case['prompt']}'")
        print("-" * 60)
        
        results = []
        
        for i, completion in enumerate(test_case['completions']):
            try:
                # Compute log probability
                logp = compute_logp(test_case['prompt'], completion)
                logp_value = logp.item() if hasattr(logp, 'item') else float(logp)
                
                # Store result
                results.append({
                    'completion': completion,
                    'logp': logp_value,
                    'prob': torch.exp(logp).item() if hasattr(logp, 'item') else float(torch.exp(torch.tensor(logp_value)))
                })
                
                print(f"{i+1}. Completion: '{completion}'")
                print(f"   Log Prob: {logp_value:.4f}")
                print(f"   Probability: {results[-1]['prob']:.6f}")
                print()
                
            except Exception as e:
                print(f"{i+1}. Completion: '{completion}' - ERROR: {e}")
                print()
        
        # Sort by log probability (highest first)
        results.sort(key=lambda x: x['logp'], reverse=True)
        
        print("🏆 RANKING (Best to Worst):")
        for rank, result in enumerate(results, 1):
            print(f"{rank}. '{result['completion']}' (logp: {result['logp']:.4f})")
        
        print("\n" + "="*60)

def test_prompt_variations():
    """Test how different prompts affect the same completion"""
    
    completion = "machine learning algorithms"
    
    prompts = [
        "Query: ",
        "Search: ",
        "Find information about ",
        "Generate a query for ",
        "What should I search for regarding ",
        ""  # no prompt
    ]
    
    print("\n" + "=" * 80)
    print("TESTING PROMPT VARIATIONS")
    print(f"Fixed completion: '{completion}'")
    print("=" * 80)
    
    results = []
    
    for prompt in prompts:
        try:
            logp = compute_logp(prompt, completion)
            logp_value = logp.item() if hasattr(logp, 'item') else float(logp)
            
            results.append({
                'prompt': prompt,
                'logp': logp_value
            })
            
            print(f"Prompt: '{prompt}'")
            print(f"Log Prob: {logp_value:.4f}")
            print()
            
        except Exception as e:
            print(f"Prompt: '{prompt}' - ERROR: {e}")
            print()
    
    # Sort by log probability
    results.sort(key=lambda x: x['logp'], reverse=True)
    
    print("🏆 BEST PROMPTS:")
    for rank, result in enumerate(results, 1):
        print(f"{rank}. '{result['prompt']}' (logp: {result['logp']:.4f})")

def test_length_effect():
    """Test how completion length affects log probability"""
    
    prompt = "Query: "
    base_text = "machine learning"
    
    completions = [
        base_text,
        base_text + " algorithms",
        base_text + " algorithms and neural networks",
        base_text + " algorithms and neural networks for data science",
        base_text + " algorithms and neural networks for data science applications"
    ]
    
    print("\n" + "=" * 80)
    print("TESTING LENGTH EFFECT")
    print("=" * 80)
    
    for completion in completions:
        try:
            logp = compute_logp(prompt, completion)
            logp_value = logp.item() if hasattr(logp, 'item') else float(logp)
            
            print(f"Length: {len(completion)} chars")
            print(f"Text: '{completion}'")
            print(f"Log Prob: {logp_value:.4f}")
            print(f"Avg Log Prob per char: {logp_value/len(completion):.4f}")
            print("-" * 40)
            
        except Exception as e:
            print(f"Text: '{completion}' - ERROR: {e}")
            print("-" * 40)

# Run all tests
if __name__ == "__main__":
    test_compute_logp_examples()
    test_prompt_variations()
    test_length_effect()
    
    print("\n✅ Testing completed!")

TESTING compute_logp FUNCTION

📋 Test Case: Simple Query Generation
Prompt: 'Query: '
------------------------------------------------------------
Invalid logp: nan
1. Completion: 'what is machine learning'
   Log Prob: -1.0000
   Probability: 0.367879

Invalid logp: nan
2. Completion: 'machine learning definition'
   Log Prob: -1.0000
   Probability: 0.367879

Invalid logp: nan
3. Completion: 'ML basics'
   Log Prob: -1.0000
   Probability: 0.367879

Invalid logp: nan
4. Completion: 'artificial intelligence'
   Log Prob: -1.0000
   Probability: 0.367879

Invalid logp: nan
5. Completion: 'random text here'
   Log Prob: -1.0000
   Probability: 0.367879

🏆 RANKING (Best to Worst):
1. 'what is machine learning' (logp: -1.0000)
2. 'machine learning definition' (logp: -1.0000)
3. 'ML basics' (logp: -1.0000)
4. 'artificial intelligence' (logp: -1.0000)
5. 'random text here' (logp: -1.0000)


📋 Test Case: Question-Answer Format
Prompt: 'Question: What is the capital of France?
Answer:'
------

In [15]:
# Debug compute_logp function - run this to understand what's happening
import torch

def debug_compute_logp(prompt, completion):
    """Debug version of compute_logp with detailed logging"""
    print(f"\n🔍 DEBUGGING compute_logp")
    print(f"Prompt: '{prompt}'")
    print(f"Completion: '{completion}'")
    
    full_input = prompt + completion
    print(f"Full input: '{full_input}'")
    
    # Tokenize and show tokens
    encoded = tokenizer(full_input, return_tensors="pt", padding=True, truncation=True, max_length=128)
    print(f"Full input tokens: {encoded.input_ids}")
    print(f"Full input decoded: {tokenizer.decode(encoded.input_ids[0])}")
    
    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Get prompt length
    prompt_encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    prompt_len = prompt_encoded.input_ids.shape[-1]
    print(f"Prompt tokens: {prompt_encoded.input_ids}")
    print(f"Prompt length: {prompt_len}")
    print(f"Prompt decoded: {tokenizer.decode(prompt_encoded.input_ids[0])}")
    
    # Show what tokens we're computing loss on
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100
    print(f"Labels (masked): {labels}")
    
    # Show which tokens will be used for loss computation
    completion_tokens = input_ids[:, prompt_len:]
    print(f"Completion tokens (for loss): {completion_tokens}")
    print(f"Completion decoded: {tokenizer.decode(completion_tokens[0])}")
    
    # Model forward pass with debugging
    print(f"Model device: {device}")
    print(f"Input shape: {input_ids.shape}")
    
    try:
        with torch.no_grad():  # Don't track gradients for debugging
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            print(f"Raw loss: {outputs.loss}")
            print(f"Loss device: {outputs.loss.device}")
            print(f"Loss dtype: {outputs.loss.dtype}")
            print(f"Is loss NaN: {torch.isnan(outputs.loss)}")
            print(f"Is loss infinite: {torch.isinf(outputs.loss)}")
            
            if hasattr(outputs, 'logits'):
                print(f"Logits shape: {outputs.logits.shape}")
                print(f"Logits range: {outputs.logits.min():.4f} to {outputs.logits.max():.4f}")
                
            logp = -outputs.loss
            print(f"Log probability: {logp}")
            
            return logp
            
    except Exception as e:
        print(f"❌ Error in model forward pass: {e}")
        import traceback
        traceback.print_exc()
        return None

def test_simple_cases():
    """Test very simple cases to isolate the issue"""
    print("=" * 80)
    print("TESTING SIMPLE CASES")
    print("=" * 80)
    
    test_cases = [
        ("", "hello"),  # No prompt
        ("Q:", "A"),    # Very short
        ("Query:", "search"),  # Simple case
    ]
    
    for prompt, completion in test_cases:
        print("\n" + "-" * 50)
        result = debug_compute_logp(prompt, completion)
        if result is not None:
            print(f"✅ Success! Log prob: {result.item():.4f}")
        else:
            print("❌ Failed!")

def check_model_state():
    """Check if model is in correct state"""
    print("=" * 80)
    print("CHECKING MODEL STATE")
    print("=" * 80)
    
    print(f"Model type: {type(model)}")
    print(f"Model device: {next(model.parameters()).device}")
    print(f"Model dtype: {next(model.parameters()).dtype}")
    print(f"Model training mode: {model.training}")
    print(f"Model vocab size: {model.config.vocab_size}")
    print(f"Tokenizer vocab size: {len(tokenizer)}")
    
    # Test simple generation
    try:
        test_input = "Hello"
        inputs = tokenizer(test_input, return_tensors="pt").to(next(model.parameters()).device)
        with torch.no_grad():
            outputs = model(**inputs)
            print(f"✅ Model forward pass works")
            print(f"Output logits shape: {outputs.logits.shape}")
            print(f"Output logits range: {outputs.logits.min():.4f} to {outputs.logits.max():.4f}")
    except Exception as e:
        print(f"❌ Model forward pass failed: {e}")

# Run diagnostics
check_model_state()
test_simple_cases()

CHECKING MODEL STATE
Model type: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>
Model device: cuda:0
Model dtype: torch.float16
Model training mode: True
Model vocab size: 50257
Tokenizer vocab size: 50257
✅ Model forward pass works
Output logits shape: torch.Size([1, 1, 50257])
Output logits range: nan to nan
TESTING SIMPLE CASES

--------------------------------------------------

🔍 DEBUGGING compute_logp
Prompt: ''
Completion: 'hello'
Full input: 'hello'
Full input tokens: tensor([[31373]])
Full input decoded: hello
Prompt tokens: tensor([], size=(1, 0))
Prompt length: 0
Prompt decoded: 
Labels (masked): tensor([[31373]], device='cuda:0')
Completion tokens (for loss): tensor([[31373]], device='cuda:0')
Completion decoded: hello
Model device: cuda:0
Input shape: torch.Size([1, 1])
Raw loss: nan
Loss device: cuda:0
Loss dtype: torch.float32
Is loss NaN: True
Is loss infinite: False
Logits shape: torch.Size([1, 1, 50257])
Logits range: nan to nan
Log probability: nan


In [17]:
# Fix the NaN logits issue - run this cell to reload the model properly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Clear any existing model from memory
if 'model' in globals():
    del model


print("🔄 Reloading model with proper configuration...")

# Use GPT2 instead of DialoGPT for better stability
model_name = "gpt2"  # More stable than DialoGPT-medium

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Load model with FLOAT32 precision to avoid NaN issues
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # Use float32 instead of float16
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

model.train()

print(f"✅ Model {model_name} loaded successfully!")
print(f"Model device: {next(model.parameters()).device}")
print(f"Model dtype: {next(model.parameters()).dtype}")

# Test the model to ensure it works
def test_model_sanity():
    """Quick test to ensure model produces valid outputs"""
    test_input = "Hello world"
    inputs = tokenizer(test_input, return_tensors="pt").to(next(model.parameters()).device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        
    print(f"Test logits shape: {logits.shape}")
    print(f"Test logits range: {logits.min():.4f} to {logits.max():.4f}")
    print(f"Any NaN in logits: {torch.isnan(logits).any()}")
    print(f"Any Inf in logits: {torch.isinf(logits).any()}")
    
    if not torch.isnan(logits).any() and not torch.isinf(logits).any():
        print("✅ Model is working correctly!")
        return True
    else:
        print("❌ Model still has issues!")
        return False

test_model_sanity()

🔄 Reloading model with proper configuration...


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

✅ Model gpt2 loaded successfully!
Model device: cuda:0
Model dtype: torch.float32
Test logits shape: torch.Size([1, 2, 50257])
Test logits range: -114.9173 to -75.9695
Any NaN in logits: False
Any Inf in logits: False
✅ Model is working correctly!


True

In [18]:
# Now test the fixed compute_logp function
def fixed_compute_logp(prompt, completion):
    """Fixed version of compute_logp"""
    full_input = prompt + completion
    encoded = tokenizer(full_input, return_tensors="pt", padding=True, truncation=True, max_length=128)
    
    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # Get prompt length more carefully
    prompt_encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    prompt_len = prompt_encoded.input_ids.shape[-1]

    # Create labels for loss computation
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100  # Mask prompt tokens

    # Forward pass
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    
    # Check for valid loss
    if torch.isnan(outputs.loss) or torch.isinf(outputs.loss):
        print(f"⚠️ Invalid loss detected: {outputs.loss}")
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    return -outputs.loss

# Test the fixed function
def test_fixed_compute_logp():
    print("=" * 80)
    print("TESTING FIXED compute_logp FUNCTION")
    print("=" * 80)
    
    test_cases = [
        ("", "hello"),
        ("Query: ", "machine learning"),
        ("Question: What is AI?\nAnswer: ", "Artificial Intelligence"),
        ("Search: ", "python tutorial"),
    ]
    
    for prompt, completion in test_cases:
        print(f"\nPrompt: '{prompt}'")
        print(f"Completion: '{completion}'")
        
        try:
            logp = fixed_compute_logp(prompt, completion)
            print(f"Log probability: {logp.item():.4f}")
            print(f"Probability: {torch.exp(logp).item():.6f}")
            print("✅ Success!")
        except Exception as e:
            print(f"❌ Error: {e}")
        
        print("-" * 40)

test_fixed_compute_logp()

TESTING FIXED compute_logp FUNCTION

Prompt: ''
Completion: 'hello'
⚠️ Invalid loss detected: nan
Log probability: 0.0000
Probability: 1.000000
✅ Success!
----------------------------------------

Prompt: 'Query: '
Completion: 'machine learning'
Log probability: -6.5192
Probability: 0.001475
✅ Success!
----------------------------------------

Prompt: 'Question: What is AI?
Answer: '
Completion: 'Artificial Intelligence'
Log probability: -0.5831
Probability: 0.558155
✅ Success!
----------------------------------------

Prompt: 'Search: '
Completion: 'python tutorial'
Log probability: -9.3671
Probability: 0.000085
✅ Success!
----------------------------------------


In [19]:
# Updated training code with fixed compute_logp function

dataset_path = 'preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05
num_epochs = 1

def compute_logp(prompt, completion):
    """Fixed version of compute_logp with NaN handling"""
    full_input = prompt + completion
    encoded = tokenizer(full_input, return_tensors="pt", padding=True, truncation=True, max_length=128)
    
    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    prompt_encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    prompt_len = prompt_encoded.input_ids.shape[-1]

    labels = input_ids.clone()
    labels[:, :prompt_len] = -100

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    
    # Check for valid loss
    if torch.isnan(outputs.loss) or torch.isinf(outputs.loss):
        print(f"⚠️ Invalid loss detected for: '{completion[:50]}...'")
        return torch.tensor(-10.0, device=device, requires_grad=True)  # Use -10 instead of 0
    
    return -outputs.loss

def ipo_loss(logp_win, logp_lose, tau=0.05):
    """IPO loss with additional safety checks"""
    # Clamp values to prevent extreme differences
    logp_win = torch.clamp(logp_win, min=-15, max=5)
    logp_lose = torch.clamp(logp_lose, min=-15, max=5)
    
    diff = logp_win - logp_lose - 0.5 / tau
    loss = (diff ** 2).mean()
    
    if torch.isnan(loss):
        print(f"⚠️ NaN in IPO loss computation")
        return torch.tensor(1.0, device=logp_win.device, requires_grad=True)
    
    return loss

print("🚀 Starting IPO training with fixed compute_logp...")
print(f"Dataset size: {len(dataset)}")
print(f"Model device: {next(model.parameters()).device}")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    valid_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, batch in enumerate(pbar):
        questions = batch["question"]
        preferred = batch["preferred"] 
        dispreferred = batch["dispreferred"]

        logp_w_list = []
        logp_l_list = []

        for q, w, l in zip(questions, preferred, dispreferred):
            # Use a more informative prompt
            prompt = f"Generate a search query for: {q[:100]}...\nQuery: "
            
            try:
                logp_w = compute_logp(prompt, w.strip())
                logp_l = compute_logp(prompt, l.strip())
                
                logp_w_list.append(logp_w)
                logp_l_list.append(logp_l)
                
                # Optional: Print first few examples to see what's happening
                if batch_idx < 3:
                    print(f"\nBatch {batch_idx}:")
                    print(f"Preferred: '{w.strip()}' → logp: {logp_w.item():.4f}")
                    print(f"Dispreferred: '{l.strip()}' → logp: {logp_l.item():.4f}")
                    
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue

        if logp_w_list and logp_l_list:
            logp_w_batch = torch.stack(logp_w_list)
            logp_l_batch = torch.stack(logp_l_list)
            
            loss = ipo_loss(logp_w_batch, logp_l_batch, tau)
            
            optimizer.zero_grad()
            loss.backward()
            utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
            valid_batches += 1
            
            # Update progress bar with current loss
            avg_loss = total_loss / valid_batches
            pbar.set_postfix({
                "loss": f"{avg_loss:.4f}",
                "batch_loss": f"{loss.item():.4f}",
                "valid_batches": valid_batches
            })
            
            # Clear memory
            del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
            
        # Periodic memory cleanup
        if batch_idx % 10 == 0:
            clear_cuda_cache()

    final_avg_loss = total_loss / max(valid_batches, 1)
    print(f"\n[Epoch {epoch + 1}] Average Loss: {final_avg_loss:.4f}")
    print(f"Valid batches processed: {valid_batches}/{len(dataloader)}")

print("✅ Training completed!")

🚀 Starting IPO training with fixed compute_logp...
Dataset size: 70073
Model device: cuda:0


Epoch 1/1:   0%|          | 1/70073 [00:00<2:07:22,  9.17it/s, loss=101.9563, batch_loss=107.8358, valid_batches=2]


Batch 0:
Preferred: 'Query:location Oksan Station East Asian country' → logp: -3.2360
Dispreferred: '(中國 /中国 ) in its native language.
Matsu Daily () is a newspaper owned by the government of the Lienchiang County, Fujian Province, Republic of China, an East Asian country which is commonly known by its largest island Taiwan.

Generate a search query for the following question:
Oksan Station is located in which East Asian country? 
Query:Oksan Station East Asian country

Context' → logp: -3.4341

Batch 1:
Preferred: 'Examples:
Question:Henry of Almain was the song of one of the wealthiest men in Europe who joined what?
Query:Henry of Almain father wealthiest man Europe joined what

Context:
 The song was written by group members Van Stephenson and Dave Robbins, along with Desmond Child.
 The song was written by group members Dave Robbins and Henry Paul, along with Lee Thomas Miller.
 Urban plays guitar on the Dixie Chicks' rendition.
"Some Days You Gotta Dance" is a song written by Tro

Epoch 1/1:   3%|▎         | 2232/70073 [02:24<1:12:58, 15.49it/s, loss=96.9433, batch_loss=258.4277, valid_batches=2232]


KeyboardInterrupt: 