In [1]:
#!pip install -U transformers torch wandb
!wandb login a796e1b90cee67247e930fff3ef9e6001f4db5c2

from google.colab import drive
drive.mount('/content/drive')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!


In [3]:
from torch.utils.data import Dataset, DataLoader
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
import wandb
from huggingface_hub import login
from tqdm import tqdm

class PreferenceDataset(Dataset):
    def __init__(self, json_path):
        with open(json_path, 'r') as f:
            raw_data = json.load(f)

        self.data = []
        for question, entry in raw_data.items():
            for hop, hop_data in entry["hops"].items():
                queries = hop_data["queries"]
                preferences = hop_data["preference_pairs"]
                for i, j in preferences:
                    self.data.append({
                        "question": question,
                        "preferred": queries[i],
                        "dispreferred": queries[j]
                    })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [4]:
!pip install transformers==4.40.0



In [None]:
login("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")

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"


model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

model.train()

# === WandB Init ===
wandb.init(project="TA-LeReT", name="IPO-finetune-LLaMA3", config={
    "model": model_name,
    "optimizer": "AdamW",
    "lr": 5e-6,
    "tau": 0.05,
    "batch_size": 2,
    "epochs": 3,
})

# === Helpers ===
def ipo_loss(logp_win, logp_lose, tau=0.05):
    return ((logp_win - logp_lose - 0.5 / tau) ** 2).mean()

def compute_logp(prompt, completion):
    full_input = prompt + completion
    encoded = tokenizer(full_input, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    
    # Get the device of the model's first parameter
    device = next(model.parameters()).device
    
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # Get prompt length
    prompt_len = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.shape[-1]

    labels = input_ids.clone()
    labels[:, :prompt_len] = -100  # mask out prompt

    # Updated autocast usage - no more deprecation warning
    with torch.amp.autocast('cuda'):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return -outputs.loss

def clear_cuda_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# === Training ===
dataset_path = '/content/preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  # Reduced batch size to 1

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05
num_epochs = 3

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    for batch_idx, batch in enumerate(pbar):
        questions = batch["question"]
        preferred = batch["preferred"]
        dispreferred = batch["dispreferred"]

        logp_w_list, logp_l_list = [], []

        for q, w, l in zip(questions, preferred, dispreferred):
            try:
                prompt = f"Generate a search query for the following question:\n{q}\nQuery:"
                logp_w = compute_logp(prompt, " " + w)
                logp_l = compute_logp(prompt, " " + l)
                logp_w_list.append(logp_w)
                logp_l_list.append(logp_l)
            except RuntimeError as e:
                print(f"Error in batch {batch_idx}: {e}")
                clear_cuda_cache()
                continue

        if logp_w_list and logp_l_list:
            logp_w_batch = torch.stack(logp_w_list)
            logp_l_batch = torch.stack(logp_l_list)
            loss = ipo_loss(logp_w_batch, logp_l_batch, tau)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            
            # Clear intermediate tensors and cache
            del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
            
        clear_cuda_cache()

        avg_loss = total_loss / max(1, batch_idx + 1)
        pbar.set_postfix({"loss": avg_loss})
        wandb.log({"loss": avg_loss, "epoch": epoch + 1})

    print(f"[Epoch {epoch + 1}] Average Loss: {total_loss / len(dataloader):.4f}")
    clear_cuda_cache()

wandb.finish()

Epoch 1:   0%|          | 0/35037 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [15]:
# prompt: free the cuda gpu memory aggresivley, atomic way, nuclear it : "OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 76.88 MiB is free. Process 72946 has 39.47 GiB memory in use. Of the allocated memory 38.61 GiB is allocated by PyTorch, and 371.07 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
# "

# Function to clear CUDA cache
def clear_cuda_cache():
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache cleared.")

# Call this function after each batch or at the end of each epoch
# Example usage within your training loop:
# ... (inside the for batch loop) ...
#         optimizer.step()
#         clear_cuda_cache() # Call after optimizer step
# ... (outside the for batch loop) ...
# print(f"[Epoch {epoch + 1}] Average Loss: {total_loss / len(dataloader):.4f}")
# clear_cuda_cache() # Call at the end of the epoch

# You can also call it at the beginning of the loop or whenever you suspect memory issues.

# Alternatively, you can also try deleting tensors that are no longer needed.
# For example, inside the batch loop:
# del logp_w_list, logp_l_list, logp_w_batch, logp_l_batch, loss
# clear_cuda_cache()

# Add this function definition somewhere before the training loop.
