In [None]:
!pip install -i https://pypi.org/simple/ bitsandbytes --quiet
!pip install torch
!pip install transformers --quiet
!pip install datasets --quiet
!pip install accelerate --quiet
!pip install fairscale --quiet
!pip install peft --quiet
!pip install tqdm --quiet
!pip install matplotlib


In [None]:
from tqdm import tqdm
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM 
import pandas as pd
import random

# Directory passed as `cache_dir=` to the tokenizer/model `from_pretrained` calls
# in the training function.
# NOTE(review): None presumably selects the library's default cache location — confirm.
cache_dir = None



In [6]:
import json

def load_jsonl_data(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path to a UTF-8 encoded file holding one JSON document per line.

    Returns:
        list: The parsed value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank or whitespace-only lines are skipped rather than handed to
        # json.loads, which would reject them and abort the whole load.
        return [json.loads(line) for line in file if line.strip()]

# Path of the JSONL training file, relative to the notebook's working directory.
file_path = 'phase1_train.jsonl'
# Parse every record of the file into memory; `data` is reused by later cells.
data = load_jsonl_data(file_path)

# # Example: Print the first data point to check
# print(json.dumps(data[0], indent=4))

# Pretty-print one record (index 7) to inspect the structure of a datapoint.
print(json.dumps(data[7], indent=4))

{
    "labeler": "d8aa7923-b970-45e1-9734-e4a7f6c4a7db",
    "timestamp": "2022-07-17T17:03:28.219211",
    "generation": null,
    "is_quality_control_question": false,
    "is_initial_screening_question": false,
    "question": {
        "problem": "Find the integer $n,$ $-90 < n < 90,$ such that $\\tan n^\\circ = \\tan 312^\\circ.$",
        "ground_truth_answer": "-48"
    },
    "label": {
        "steps": [
            {
                "completions": [
                    {
                        "text": "So we have that $\\tan n^\\circ = \\tan 312^\\circ$.",
                        "rating": 0,
                        "flagged": false
                    },
                    {
                        "text": "So I guess we want to find an integer n such that $\\tan n = \\tan 312$.",
                        "rating": 0,
                        "flagged": false
                    },
                    {
                        "text": "So we're looking for an integer $n$ tha

In [7]:
from itertools import product

def get_all_paths_with_ratings(datapoint, limit=50):
    """Build up to ``limit`` distinct solution paths with their per-step ratings.

    A path picks one completion per step. Its text is the problem statement
    followed by each chosen completion, every step terminated by the
    ' [SIGNAL] ' marker and a newline. A path is cut right after the first
    completion rated -1, so both the text and the ratings end at that step.

    Args:
        datapoint: Record providing ``['question']['problem']`` and
            ``['label']['steps']``, where every step has a ``'completions'``
            list of dicts with ``'text'`` and ``'rating'`` keys.
        limit: Maximum number of step combinations to consider (and therefore
            the maximum number of paths returned).

    Returns:
        list: ``[path_text, path_ratings]`` pairs without duplicates. Empty
        when the datapoint has no steps or a step has no completions.
    """
    question = datapoint['question']['problem']
    steps_completions = [step['completions'] for step in datapoint['label']['steps']]

    # product() of zero iterables yields one empty tuple, which has no first
    # step to attach the question to; there is simply no path in that case.
    if not steps_completions:
        return []

    all_paths_combinations = list(product(*steps_completions))
    random.shuffle(all_paths_combinations)
    all_paths_combinations = all_paths_combinations[:limit]

    all_paths = []
    for path in all_paths_combinations:
        path_texts = []
        path_ratings = []
        # Text and ratings are built in one pass so they always stop at the
        # same step.
        for step_index, completion in enumerate(path):
            prefix = f"{question} \n " if step_index == 0 else ""
            path_texts.append(f"{prefix}{completion['text']} [SIGNAL] \n ")
            path_ratings.append(completion['rating'])
            if completion['rating'] == -1:
                # Nothing after the first negative step belongs to the path.
                break

        path_entry = [''.join(path_texts), path_ratings]
        # Different combinations collapse to the same path once truncated.
        if path_entry not in all_paths:
            all_paths.append(path_entry)

    # At most `limit` combinations were examined, so no further sampling is
    # needed to respect the limit.
    return all_paths

# Build the paths for a single record (index 6) as a sanity check.
example_datapoint = get_all_paths_with_ratings(data[6])


# Show how many distinct paths were produced, followed by the paths themselves.
print(len(example_datapoint))
print(example_datapoint)


6
[['One day Max says to Liz, "Out of the 25 people taking either English or French, you and I are the only two taking both.\'\' Liz, being mathematically inclined, responds by pointing out that there are exactly twice as many people in the English class as there are in the French class. How many people are taking English but not French? \n I think we should start by letting the number of people taking French be a variable, say $f$. [SIGNAL] \n And then the number of people taking English is $2f$ because there are twice as many people in the English class as there are in the French class. [SIGNAL] \n That\'s right. Now we can use the fact that there are 25 people taking either English or French. [SIGNAL] \n So the total number of people taking either English or French is $f + 2f = 25$. [SIGNAL] \n ', [1, 1, 1, -1]], ['One day Max says to Liz, "Out of the 25 people taking either English or French, you and I are the only two taking both.\'\' Liz, being mathematically inclined, responds b

In [8]:

def is_datapoint_within_char_limit(datapoint, char_limit=3000):
    """Tell whether a datapoint is usable and short enough.

    Args:
        datapoint: Record with a ``'label'`` dict holding ``'finish_reason'``
            and ``'steps'``; every step has a ``'completions'`` list of dicts
            with a ``'text'`` entry.
        char_limit: Exclusive upper bound on the combined length of all
            completion texts.

    Returns:
        bool: False when the finish reason is ``'give_up'``; otherwise whether
        the summed length of all completion texts is strictly below
        ``char_limit``.
    """
    label = datapoint['label']
    if label['finish_reason'] == 'give_up':
        return False

    char_count = 0
    for step in label['steps']:
        for completion in step['completions']:
            char_count += len(completion['text'])

    return char_count < char_limit

# Flat list of [text, ratings] training paths gathered from all usable records.
paths = []

for example in data:
    # Records that were given up on or exceed the character budget are skipped.
    if is_datapoint_within_char_limit(example):
        paths.extend(get_all_paths_with_ratings(example))


In [None]:
# Total number of [text, ratings] training paths collected from the dataset.
print(len(paths))

In [11]:

from torch.nn.utils.rnn import pad_sequence
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW

import torch.distributed as dist

# Device used for the input batches and the label lookup table: CUDA when
# available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class MathProblemDataset(Dataset):
    """Dataset of solution paths paired with their per-step ratings.

    Every element of ``data`` is a ``[text, ratings]`` pair: ``text`` is the
    path text and ``ratings`` holds one rating per step (None for unrated
    steps).

    Args:
        data: Sequence of ``[text, ratings]`` pairs.
        tokenizer: Callable tokenizer returning a mapping with ``'input_ids'``.
        max_token_len: Length every token sequence is padded/truncated to.
        special_token: Step marker token; stored on the instance.
        max_steps: Fixed length of the returned label vector. Longer rating
            lists are truncated, shorter ones are padded.
        label_pad_value: Value used to pad the label vector. It defaults to
            -2.0, the same padding sentinel used when batches are collated and
            when the training loop masks out padded steps.
    """

    def __init__(self, data, tokenizer, max_token_len=1024, special_token="[SIGNAL]",
                 max_steps=30, label_pad_value=-2.0):
        self.data = data
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.special_token = special_token
        self.max_steps = max_steps
        self.label_pad_value = label_pad_value

    def __len__(self):
        """Return the number of paths in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the token ids and the fixed-length label vector of one path.

        Returns:
            dict: ``'input_ids'`` is a 1-D tensor of ``max_token_len`` token
            ids; ``'labels'`` is a float tensor of length ``max_steps`` holding
            the ratings followed by ``label_pad_value``.
        """
        datapoint = self.data[idx]
        text = datapoint[0]
        labels = datapoint[1]

        # Unrated steps (None) are stored as rating 0; anything beyond
        # max_steps is dropped explicitly instead of relying on negative padding.
        labels = [0 if x is None else x for x in labels][:self.max_steps]
        labels_tensor = torch.tensor(labels, dtype=torch.float)

        # Pad the ratings up to the fixed label length. A finite sentinel is
        # used because NaN cannot be compared against the padding value or be
        # converted to an integer class index downstream.
        padding_length = self.max_steps - labels_tensor.size(0)
        padded_labels = F.pad(labels_tensor, (0, padding_length), 'constant', self.label_pad_value)

        text_prepared = text + " "

        encoding = self.tokenizer(text_prepared, max_length=self.max_token_len, padding='max_length', truncation=True, return_tensors="pt")

        return {
          'input_ids': encoding['input_ids'].flatten(),
          'labels': padded_labels
        }


def collate_fn(batch):
    """Merge dataset items into one batch.

    Args:
        batch: List of dicts, each with an ``'input_ids'`` tensor and a
            ``'labels'`` tensor. All ``'input_ids'`` tensors must share one
            shape because they are stacked.

    Returns:
        dict: ``'input_ids'`` stacked along a new leading batch dimension and
        ``'labels'`` right-padded with -2.0 to the longest label vector in the
        batch.
    """
    stacked_ids = torch.stack([sample['input_ids'] for sample in batch], dim=0)
    # -2.0 marks label padding; it is distinct from the ratings 1, 0 and -1.
    padded_labels = pad_sequence(
        [sample['labels'] for sample in batch],
        batch_first=True,
        padding_value=-2.0,
    )
    return {'input_ids': stacked_ids, 'labels': padded_labels}


class PRM(nn.Module):
    """Process reward model head on top of a causal language model.

    For every '[SIGNAL]' token in the input, the language-model logits at that
    position are restricted to the '[POS]', '[NEU]' and '[NEG]' tokens and
    passed through a sigmoid.

    Args:
        pre_trained_model: Causal LM called as
            ``model(input_ids, output_hidden_states=True, attention_mask=...)``
            whose output exposes ``.logits``.
        tokenizer: Tokenizer providing ``convert_tokens_to_ids`` for the four
            marker tokens.
    """

    def __init__(self, pre_trained_model = None, tokenizer=None):
        super(PRM, self).__init__()
        self.pre_trained_model = pre_trained_model
        self.tokenizer = tokenizer
        # Sigmoid holds no parameters or buffers, so it needs no device placement.
        self.sigmoid = nn.Sigmoid()

    def forward(self, data, attention_mask=None):
        """Score every '[SIGNAL]' position of every sample.

        Args:
            data: Token ids, indexed as (batch, position).
            attention_mask: Optional mask forwarded to the language model.

        Returns:
            Tensor of shape ``(batch_size, max_signals, 3)`` where
            ``max_signals`` is the largest number of '[SIGNAL]' tokens in any
            sample. Row ``i`` always belongs to sample ``i``. Slot ``j`` holds
            the sigmoid of the ('[POS]', '[NEU]', '[NEG]') logits at the j-th
            signal of that sample; slots beyond a sample's signal count are
            filled with -2. When no sample contains a signal the result has
            shape ``(batch_size, 0, 3)``.
        """
        outputs = self.pre_trained_model(data, output_hidden_states=True, attention_mask=attention_mask)
        logits = outputs.logits
        batch_size = logits.shape[0]

        signal_token_id = self.tokenizer.convert_tokens_to_ids('[SIGNAL]')
        # Column order of the output: positive, neutral, negative.
        rating_token_ids = [self.tokenizer.convert_tokens_to_ids(token)
                            for token in ('[POS]', '[NEU]', '[NEG]')]

        signal_mask = data == signal_token_id
        max_signals = int(signal_mask.sum(dim=1).max().item()) if batch_size > 0 else 0

        # One row per sample, pre-filled with the padding value, so samples
        # without any signal keep their place in the batch.
        scores = logits.new_full((batch_size, max_signals, 3), -2)
        if max_signals == 0:
            return scores

        batch_indices, pos_indices = signal_mask.nonzero(as_tuple=True)
        # Rank of each signal within its own sample (0 for the first one).
        slot_indices = (signal_mask.cumsum(dim=1) - 1)[batch_indices, pos_indices]

        signal_logits = logits[batch_indices, pos_indices][:, rating_token_ids]
        scores[batch_indices, slot_indices] = self.sigmoid(signal_logits)

        return scores


In [12]:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def init_distributed_mode():
    """Join the NCCL process group and bind this process to its own GPU.

    The process group is only created when none exists yet, so the function
    can be called again in an already initialised process. The GPU is chosen
    from the LOCAL_RANK environment variable when the launcher provides it;
    otherwise the global rank is wrapped by the number of visible GPUs, because
    the global rank itself is not a valid device index on multi-node runs.
    """
    import os

    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')

    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is not None:
        device_index = int(local_rank)
    else:
        device_index = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    
def training_function():
    """Fine-tune the process reward model under FSDP and save a checkpoint.

    Runs inside every worker process: joins the process group, loads the
    tokenizer and base model, registers the marker tokens, trains for a fixed
    number of epochs on the module-level ``paths`` and finally writes the full
    state dict to 'model_checkpoint.pth' from rank 0.

    Returns:
        str or None: The checkpoint path on rank 0, None on every other rank.

    Raises:
        RuntimeError: If a batch has no rated '[SIGNAL]' position to train on.
    """
    # Only needed inside the worker processes.
    from torch.utils.data.distributed import DistributedSampler

    init_distributed_mode()

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b", cache_dir=cache_dir)

    # Define the new special token
    special_token = "[SIGNAL]"
    special_tokens_dict = {'additional_special_tokens': [special_token]}
    pos_token, neu_token, neg_token = '[POS]', '[NEU]', '[NEG]'

    # Add the special token to the tokenizer
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens([pos_token, neu_token, neg_token])

    # Load the model
    base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/llemma_7b", cache_dir=cache_dir)

    # Resize the embedding matrix so the newly added tokens have rows.
    base_model.resize_token_embeddings(len(tokenizer))

    # Pad with the end-of-sequence token (set unconditionally).
    tokenizer.pad_token = tokenizer.eos_token

    dataset = MathProblemDataset(paths, tokenizer)

    # Give every rank its own slice of the data; without a distributed sampler
    # all processes would train on identical batches. Order is kept
    # deterministic (no shuffling), matching the previous debugging setup.
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=collate_fn)  # Adjust batch_size as needed

    prm_model = PRM(base_model, tokenizer)

    prm_model = prm_model.cuda()
    prm_model = FSDP(prm_model)

    optimizer = AdamW(prm_model.parameters(), lr=1e-6)
    criterion = nn.BCELoss()

    # One-hot targets in the model's output column order ([POS], [NEU], [NEG]).
    label_mappings = torch.tensor([
        [0, 1, 0],  # Index 0: maps to label 0
        [1, 0, 0],  # Index 1: maps to label 1
        [0, 0, 1],  # Index 2: maps to label -1
        [0, 0, 0]   # Index 3: maps to label -2 (padding)
    ], device=device)

    num_epochs = 2

    for epoch in range(num_epochs):
        data_loader_with_progress = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch in data_loader_with_progress:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            # Directly use model's forward pass
            predictions = prm_model(input_ids)
            num_slots = predictions.shape[1]

            # Bring the labels to the padding convention and to the number of
            # signal slots the model produced: NaN padding becomes -2, missing
            # slots are padded with -2 and surplus labels (steps cut off by
            # token truncation) are dropped.
            labels = torch.where(torch.isnan(labels), torch.full_like(labels, -2.0), labels)
            if labels.shape[1] < num_slots:
                labels = F.pad(labels, (0, num_slots - labels.shape[1]), 'constant', -2.0)
            labels = labels[:, :num_slots]

            # Mapping -2 and -1 to their new indices
            labels = torch.where(labels == -2, 3, labels)  # Maps -2 to index 3
            labels = torch.where(labels == -1, 2, labels)  # Maps -1 to index 2

            labels = labels.long()

            labels_mapped = label_mappings[labels]
            # Keep a slot only when it has a real label and a real prediction
            # (padded prediction slots hold -2, outside the sigmoid range).
            mask = (labels != 3) & (predictions[..., 0] >= 0)

            valid_predictions = predictions[mask]
            valid_labels = labels_mapped[mask].float()

            if valid_predictions.numel() == 0:
                # A mean over zero elements would yield a NaN loss and corrupt
                # the weights on the optimizer step.
                raise RuntimeError("Batch contains no rated [SIGNAL] position to compute the loss on")

            # Calculate loss only on valid data
            loss = criterion(valid_predictions, valid_labels)

            loss.backward()
            optimizer.step()
            data_loader_with_progress.set_description(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

    torch.distributed.barrier()

    # Gathering the unsharded state of an FSDP module is a collective
    # operation, so every rank has to take part even though only rank 0 writes
    # the file.
    full_state_dict = prm_model.state_dict()

    checkpoint_path = None
    # Only the main process saves the model
    if torch.distributed.get_rank() == 0:
        checkpoint_path = 'model_checkpoint.pth'
        torch.save(full_state_dict, checkpoint_path)
        print(f"Model checkpoint saved at {checkpoint_path}")

    # Checkpoint path on rank 0, None elsewhere
    return checkpoint_path

In [None]:
from accelerate import notebook_launcher

# Launch training_function in one process per GPU; set num_processes to the
# number of GPUs you have.
# NOTE(review): whether notebook_launcher hands back training_function's return
# value is not visible here — confirm before relying on checkpoint_path.
checkpoint_path = notebook_launcher(training_function, num_processes=2)

Launching training on 2 GPUs.


  return self.fget.__get__(instance, owner)()
  return self.fget.__get__(instance, owner)()
