<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/LLM_REASONING_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers -q
!pip install datasets -q
!pip install torch -q
!pip install gymnasium -q

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.optim import AdamW
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import numpy as np

In [None]:
# --- 1. Pure Supervised Fine-tuning (Creating an initial reasoning model) ---
#print("\n--- 1. Pure Supervised Fine-tuning ---")
sft_model_name = "gpt2"
sft_tokenizer = AutoTokenizer.from_pretrained(sft_model_name)
sft_model = AutoModelForCausalLM.from_pretrained(sft_model_name)

# In the "Pure Supervised Fine-tuning" section:
sft_tokenizer = AutoTokenizer.from_pretrained(sft_model_name)
if sft_tokenizer.pad_token is None:
    sft_tokenizer.pad_token = sft_tokenizer.eos_token

class ReasoningDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        encoding = self.tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.max_length)
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": encoding["input_ids"].squeeze()
        }

sft_train_data = [
    {"text": "Question: What is the capital of Canada? Answer: Ottawa"},
    {"text": "Question: What is the largest city in Canada? Answer: Toronto"},
    {"text": "Question: What is the official language of Quebec? Answer: French"}
]


sft_train_dataset = ReasoningDataset(sft_train_data, sft_tokenizer)
sft_train_loader = torch.utils.data.DataLoader(sft_train_dataset, batch_size=2)

sft_optimizer = AdamW(sft_model.parameters(), lr=5e-5)
sft_epochs = 5

In [4]:
print("\n--- 1. Pure Supervised Fine-tuning ---")
print('\n')
print("Creating the SFT model...")

# Training loop
for epoch in range(sft_epochs):
    sft_model.train()
    for batch in sft_train_loader:
        input_ids = batch["input_ids"].to(sft_model.device)
        attention_mask = batch["attention_mask"].to(sft_model.device)
        labels = batch["labels"].to(sft_model.device)
        sft_optimizer.zero_grad()
        outputs = sft_model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        sft_optimizer.step()
    print(f"SFT Epoch {epoch+1} Loss: {loss.item()}")

print("\n")
print("Saving the SFT model...")
sft_model.save_pretrained("./sft_reasoning_model")
sft_tokenizer.save_pretrained("./sft_reasoning_model")


--- 1. Pure Supervised Fine-tuning ---


Creating the SFT model...


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


SFT Epoch 1 Loss: 6.823558807373047
SFT Epoch 2 Loss: 1.2183752059936523
SFT Epoch 3 Loss: 0.7251948118209839
SFT Epoch 4 Loss: 0.7156901955604553
SFT Epoch 5 Loss: 0.7075764536857605


Saving the SFT model...


('./sft_reasoning_model/tokenizer_config.json',
 './sft_reasoning_model/special_tokens_map.json',
 './sft_reasoning_model/vocab.json',
 './sft_reasoning_model/merges.txt',
 './sft_reasoning_model/added_tokens.json',
 './sft_reasoning_model/tokenizer.json')

In [6]:
# --- 2. Reinforcement Learning Fine-tuning (Using the SFT model as initialization) ---
print("\n--- 2. Reinforcement Learning Fine-tuning ---")

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import torch.nn.functional as F

from torch.optim.lr_scheduler import StepLR
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split

class FlightWaypointDataset(Dataset):
    def __init__(self, input_texts, num_waypoints, tokenizer):
        self.input_texts = input_texts
        self.num_waypoints = num_waypoints
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, idx):
        input_text = self.input_texts[idx]
        num_waypoint = self.num_waypoints[idx]

        encoding = self.tokenizer(input_text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "label": torch.tensor(num_waypoint, dtype=torch.float32)
        }

!rm -rf *.pth
loss_fn = torch.nn.MSELoss()
rl_epochs = 300
batch_size = 32
lr=1e-5

#sft_model_name = "gpt2"
#sft_tokenizer = AutoTokenizer.from_pretrained(sft_model_name)
#if sft_tokenizer.pad_token is None:
#    sft_tokenizer.pad_token = sft_tokenizer.eos_token


# Early stopping initialization
best_val_loss = np.inf
patience = 3
epochs_without_improvement = 0

# 1. Load and Prepare the Dataset
dataset = load_dataset("frankmorales2020/flight_plan_waypoints")

# Split the dataset into train and validation sets
train_size = int(0.8 * len(dataset["train"]))
val_size = len(dataset["train"]) - train_size
train_dataset, val_dataset = random_split(dataset["train"], [train_size, val_size])

# Create data loaders
# Access the underlying data using the original dataset and the indices from the subset
train_dataset_processed = FlightWaypointDataset(
    [dataset["train"][i]["input"] for i in train_dataset.indices],
    [dataset["train"][i]["label"] for i in train_dataset.indices],
    sft_tokenizer
)
train_loader = DataLoader(train_dataset_processed, batch_size=batch_size, shuffle=True)

val_dataset_processed = FlightWaypointDataset(
    [dataset["train"][i]["input"] for i in val_dataset.indices],
    [dataset["train"][i]["label"] for i in val_dataset.indices],
    sft_tokenizer
)
val_loader = DataLoader(val_dataset_processed, batch_size=batch_size, shuffle=False)  # No need to shuffle validation data

print(f"Train dataset size: {len(train_dataset_processed)}")
print(f"Validation dataset size: {len(val_dataset_processed)}")


# 2. Load the Pre-trained Model (Corrected)
sft_model = AutoModelForCausalLM.from_pretrained("./sft_reasoning_model")  # Load from sft_reasoning_model directory
sft_tokenizer = AutoTokenizer.from_pretrained("./sft_reasoning_model")
if sft_tokenizer.pad_token is None:
    sft_tokenizer.pad_token = sft_tokenizer.eos_token

# 3. Define the RLAgent (Waypoint Predictor) - CORRECTED
class RLAgent(torch.nn.Module):
    def __init__(self, sft_model, sft_tokenizer):
        super().__init__()
        self.sft_model = sft_model
        self.sft_tokenizer = sft_tokenizer
        self.waypoint_predictor = torch.nn.Linear(sft_model.config.hidden_size, 1)

        # Unfreeze language model parameters (important for fine-tuning)
        for param in self.sft_model.parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask):
        #print('\n')
        #print("Input IDs shape:", input_ids.shape)
        #print("Attention Mask shape:", attention_mask.shape)
        #print('\n')

        outputs = self.sft_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Averaging the last 4 hidden states (or experiment with other numbers)
        hidden_states = torch.stack(outputs.hidden_states[-4:]).mean(0)[:, 0, :]
        #hidden_states.retain_grad()

        # Debugging: Print hidden states shape and grad
        #print("Hidden states shape:", hidden_states.shape)
        #print("Hidden states grad:", hidden_states.grad)

        predicted_numberofwaypoints = self.waypoint_predictor(hidden_states)
        #predicted_numberofwaypoints.retain_grad()  # Add this line


        # Debugging: Print prediction shape and grad
        #print('\n')
        #print("Predicted waypoints shape:", predicted_numberofwaypoints.shape)
        #print("Predicted waypoints grad:", predicted_numberofwaypoints.grad)


        return predicted_numberofwaypoints


    def predict_numberofwaypoints(self, input_text):
        encoding = self.sft_tokenizer(input_text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.sft_model.device)
        attention_mask = encoding["attention_mask"].to(self.sft_model.device)
        with torch.no_grad():
            predicted_numberofwaypoints = self.forward(input_ids, attention_mask)
        return predicted_numberofwaypoints.item()

# 4. Set up Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rl_model = RLAgent(sft_model, sft_tokenizer).to(device)
rl_optimizer = AdamW(rl_model.parameters(), lr=lr)

# Ensure waypoint_predictor parameters are trainable
for name, param in rl_model.named_parameters():
    #print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")
    if "waypoint_predictor" in name:
        param.requires_grad = True

scheduler = StepLR(rl_optimizer, step_size=2, gamma=0.1)  # Reduce LR by 0.1 every 2 epochs

print('\n')
print(f"Total Epoch: {rl_epochs}")
print('\n')


max_steps = (len(train_dataset_processed) // int(batch_size)) * int(rl_epochs)
# Correct for the dropped batch if it's incomplete

#print(f"Total Steps before correction: {max_steps}")
#print('\n')

# Correct for the dropped batch if it's incomplete
if len(train_dataset_processed) % batch_size != 0:
    max_steps -= (rl_epochs) - ((len(train_dataset_processed) % batch_size) / batch_size) * rl_epochs


print(f"Total Steps: {max_steps}")


def custom_loss(predicted_numberofwaypoints, num_waypoints_target, reward_weight=0.1):
    original_loss = loss_fn(predicted_numberofwaypoints.squeeze(), num_waypoints_target)
    reward_loss = -reward_weight * rewards.mean()  # Negative to maximize reward
    total_loss = original_loss + reward_loss
    return total_loss



# 5. Training Loop
for epoch in range(rl_epochs):
    rl_model.train()
    total_loss = 0
    total_reward = 0  # Track total reward
    scheduler.step()  # Update learning rate

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        num_waypoints_target = batch["label"].to(device)

        predicted_numberofwaypoints = rl_model(input_ids, attention_mask)


        # Calculate Reward Score - Example using inverse difference (Corrected)
        rewards = 1 / (1 + torch.abs(predicted_numberofwaypoints.squeeze() - num_waypoints_target))
        reward = rewards.mean().item()  # Calculate average reward


        # Loss function can still be used for parameter updates
        loss = loss_fn(predicted_numberofwaypoints.squeeze(), num_waypoints_target)

        #loss = custom_loss(predicted_numberofwaypoints, num_waypoints_target, reward_weight=0.1)  # Adjust reward_weight as needed

        # Debugging: Print loss and gradients

        #print("Loss:", loss.item())
        #print("Loss grad:", loss.grad)
        #print('\n')


        #for name, param in rl_model.named_parameters():
        #    if param.grad is not None:
        #        print(f"Parameter: {name}, Grad Norm: {param.grad.norm()}")
        #    else:
        #        print(f"Parameter: {name}, Grad: None")

        # Alternatively, you could use a policy gradient approach
        # and update parameters based on the reward directly

        rl_optimizer.zero_grad()
        loss.backward()
        rl_optimizer.step()

        total_loss += loss.item()
        total_reward += reward  # Accumulate reward

    avg_loss = total_loss / len(train_loader)
    avg_reward = total_reward / len(train_loader)  # Calculate average reward

    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}, Average Reward: {avg_reward:.4f}")


    # Validation loop (after each epoch)
    rl_model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            num_waypoints_target = batch["label"].to(device)

            predicted_numberofwaypoints = rl_model(input_ids, attention_mask)
            loss = loss_fn(predicted_numberofwaypoints.squeeze(), num_waypoints_target)  # Use original loss for validation
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0
        # Save the best model (optional)
        torch.save(rl_model.state_dict(), "fine_tuned_rl_best_model.pth")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered!")
            break  # Exit the training loop


# 6. Save the Fine-tuned Model
torch.save(rl_model.state_dict(), "./fine_tuned_rl_agent.pth")
print("\nFine-tuned RL agent model saved to ./fine_tuned_rl_agent.pth")

#rl_model.waypoint_predictor.weight.requires_grad = True
#rl_model.waypoint_predictor.bias.requires_grad = True


--- 2. Reinforcement Learning Fine-tuning ---
Train dataset size: 1600
Validation dataset size: 400


Total Epoch: 300


Total Steps: 15000
Epoch 1, Average Loss: 11.2898, Average Reward: 0.3986
Validation Loss: 3.9722
Epoch 2, Average Loss: 4.6988, Average Reward: 0.4646
Validation Loss: 3.9719
Epoch 3, Average Loss: 4.3585, Average Reward: 0.4599
Validation Loss: 3.8952
Epoch 4, Average Loss: 4.6743, Average Reward: 0.4472
Validation Loss: 3.9076
Epoch 5, Average Loss: 4.6194, Average Reward: 0.4565
Validation Loss: 3.9230
Epoch 6, Average Loss: 4.7148, Average Reward: 0.4531
Validation Loss: 3.9234
Early stopping triggered!

Fine-tuned RL agent model saved to ./fine_tuned_rl_agent.pth


In [25]:
# Recreate the model architecture (assuming RLAgent class is defined)
sft_model = AutoModelForCausalLM.from_pretrained("./sft_reasoning_model")
sft_tokenizer = AutoTokenizer.from_pretrained("./sft_reasoning_model")
rl_model = RLAgent(sft_model, sft_tokenizer)

# Load the state dictionary
state_dict = torch.load("./fine_tuned_rl_best_model.pth")

# Load weights into the model
rl_model.load_state_dict(state_dict)


# Define your flight plan text:
data=val_dataset[359]
data['input']
data['label']
flight_plan_text = data['input']
actual_numberofwaypoint=data['label']

# Call predict_numberofwaypoints to get the prediction:
predicted_waypoints = rl_model.predict_numberofwaypoints(flight_plan_text)

# Print the prediction:
print('\n')
print(f"Flight Plan Text: {flight_plan_text}")
print(f"Predicted number of waypoints: {round(predicted_waypoints)}")
print(f"Actual number of waypoints: {actual_numberofwaypoint}")
#delta=round(predicted_waypoints)-actual_numberofwaypoint
#print(f"Delta waypoints number: {delta}")
print('\n')



Flight Plan Text: Calculate the waypoints from PVG to SJU. Departure: 2024-04-17, Aircraft: Boeing 777, Weather: Rainy
Predicted number of waypoints: 6
Actual number of waypoints: 6




In [8]:
class RLAgent(torch.nn.Module):
    def __init__(self, sft_model, sft_tokenizer):
        super().__init__()
        self.sft_model = sft_model
        self.sft_tokenizer = sft_tokenizer
        self.waypoint_predictor = torch.nn.Linear(sft_model.config.hidden_size, 1)

        # Unfreeze language model parameters (important for fine-tuning)
        for param in self.sft_model.parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask):

        outputs = self.sft_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Averaging the last 4 hidden states (or experiment with other numbers)
        hidden_states = torch.stack(outputs.hidden_states[-4:]).mean(0)[:, 0, :]

        predicted_numberofwaypoints = self.waypoint_predictor(hidden_states)

        return predicted_numberofwaypoints


    def predict_numberofwaypoints(self, input_text):
        encoding = self.sft_tokenizer(input_text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.sft_model.device)
        attention_mask = encoding["attention_mask"].to(self.sft_model.device)
        with torch.no_grad():
            predicted_numberofwaypoints = self.forward(input_ids, attention_mask)
        return predicted_numberofwaypoints.item()

In [15]:
#print("\n--- 3. Distillation ---")

#from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import torch.nn.functional as F

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load Teacher Model
sft_model = AutoModelForCausalLM.from_pretrained("./sft_reasoning_model")
sft_tokenizer = AutoTokenizer.from_pretrained("./sft_reasoning_model")
teacher_model = RLAgent(sft_model, sft_tokenizer).to(device) # assuming 'device' is defined as before
teacher_tokenizer = AutoTokenizer.from_pretrained("./sft_reasoning_model")

# Load teacher model weights
state_dict = torch.load("./fine_tuned_rl_agent.pth", map_location=device)  # Load on the same device
teacher_model.load_state_dict(state_dict)
teacher_model.eval()  # Set to evaluation mode

# 2. Define Student Model
#student_model_name = "bert-base-uncased"
#student_model_name = "distilbert-base-uncased"  # Smaller than bert-base-uncased
#student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
#student_model = torch.nn.Linear(student_tokenizer.vocab_size, teacher_tokenizer.vocab_size).to(device)
#student_model.train()

# 2. Define Student Model
# Use a smaller, pre-trained model as the student model
student_model_name = "distilbert-base-uncased"  # Smaller than bert-base-uncased
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
student_model = AutoModel.from_pretrained(student_model_name).to(device)
student_model.train()


# Add a linear layer on top for prediction
hidden_size = student_model.config.hidden_size
output_size = 1  # For waypoint prediction
prediction_layer = torch.nn.Linear(hidden_size, output_size).to(device)
prediction_layer.train()


# 3. Distillation Data
distill_data = [
    'Input: Calculate the waypoints from MBJ to AMS. Departure: 2024-04-19, Aircraft: Boeing 747, Weather: Stormy',
    'Input: Calculate the waypoints from MDW to SJU. Departure: 2024-12-22, Aircraft: Airbus A320, Weather: Snowy',
    'Input: Calculate the waypoints from DTW to SEA. Departure: 2024-09-11, Aircraft: Airbus A320, Weather: Partly Cloudy'
]

# 4. Distillation Loop
#distill_optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)

distill_optimizer = torch.optim.AdamW(list(student_model.parameters()) + list(prediction_layer.parameters()), lr=5e-5)
distill_epochs = 10

In [16]:
print("\n--- 3. Distillation ---")
for epoch in range(distill_epochs):
    total_loss = 0
    for text in distill_data:
        # Teacher prediction
        predicted_waypoints = teacher_model.predict_numberofwaypoints(text)

        # Target sequence for distillation (not used for MSE loss)
        # teacher_output_text = f"Predicted Waypoints: {predicted_waypoints}"
        # teacher_target = teacher_tokenizer(teacher_output_text, return_tensors="pt").input_ids.squeeze().to(device)

        # Student prediction - Update to use the pre-trained student model and prediction layer
        student_input = student_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
        student_outputs = student_model(**student_input)
        student_logits = prediction_layer(student_outputs.last_hidden_state[:, 0, :])

        # Reshape student_logits to match the target shape (1D tensor with 1 element)
        student_logits = student_logits.view(-1)

        # Calculate and apply loss (use MSE loss for waypoint prediction)
        loss = F.mse_loss(student_logits, torch.tensor([predicted_waypoints], device=device))
        distill_optimizer.zero_grad()
        loss.backward()
        distill_optimizer.step()
        total_loss += loss.item()

    print(f"Distillation Epoch {epoch+1} Loss: {total_loss / len(distill_data)}")


--- 3. Distillation ---
Distillation Epoch 1 Loss: 20.516501744588215
Distillation Epoch 2 Loss: 4.883150577545166
Distillation Epoch 3 Loss: 0.5096774883568287
Distillation Epoch 4 Loss: 0.1139904862890641
Distillation Epoch 5 Loss: 0.10458541909853618
Distillation Epoch 6 Loss: 0.03858725090685766
Distillation Epoch 7 Loss: 0.03599905284045235
Distillation Epoch 8 Loss: 0.013790201162919402
Distillation Epoch 9 Loss: 0.04426688700914383
Distillation Epoch 10 Loss: 0.021623721268648904


In [20]:
# --- 4. Inference-time Scaling (Using the distilled model in a simple agent) ---
print("\n--- 4. Inference-time Scaling with Distilled Model ---")

# Assuming student_model, student_tokenizer, prediction_layer are defined from the Distillation section

class SimpleFlightAgent:
    def __init__(self, reasoner, prediction_layer, tokenizer, knowledge_base):
        self.reasoner = reasoner  # The student model (DistilBERT in this case)
        self.prediction_layer = prediction_layer  # The linear layer for prediction
        self.tokenizer = tokenizer  # The student tokenizer
        self.knowledge = knowledge_base

    def plan_flight(self, flight_details):
        """
        Generates a flight plan (number of waypoints) using the distilled model.

        Args:
            flight_details (str): A string containing flight details
                                  in the format:
                                  "Input: Calculate the waypoints from [origin] to [destination].
                                  Departure: [date], Aircraft: [aircraft_type], Weather: [weather]"

        Returns:
            str: The predicted number of waypoints.
        """
        # Tokenize the flight details prompt
        student_input = self.tokenizer(flight_details, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.reasoner.device)

        # Get the prediction from the student model and prediction layer
        with torch.no_grad():  # No need to calculate gradients during inference
            student_outputs = self.reasoner(**student_input)
            predicted_waypoints = self.prediction_layer(student_outputs.last_hidden_state[:, 0, :])
            predicted_waypoints = round(predicted_waypoints.item())

        return f"Predicted number of waypoints: {predicted_waypoints}"


        #return f"Predicted number of waypoints: {predicted_waypoints.item()}"

# Create an instance of the agent
agent_knowledge = {"flight_info": "Direct flights are sometimes available between major cities."}
flight_agent = SimpleFlightAgent(student_model, prediction_layer, student_tokenizer, agent_knowledge)  # Pass the student model and prediction layer


# Example usage
data=val_dataset[359]
data['input']
data['label']
flight_plan_text = data['input']
actual_numberofwaypoint=data['label']

# Define your flight plan text:
flight_plan_text = data['input']
actual_numberofwaypoint=data['label']
flight_details = f"Input: {flight_plan_text}"
flight_plan = flight_agent.plan_flight(flight_details)

print('\n')
print(f"flight_details: {flight_details}")
print(f"Number of waypoints: {flight_plan}")

print("\n--- End of Full Sequence ---")


--- 4. Inference-time Scaling with Distilled Model ---


flight_details: Input: Calculate the waypoints from PVG to SJU. Departure: 2024-04-17, Aircraft: Boeing 777, Weather: Rainy
Number of waypoints: Predicted number of waypoints: 6

--- End of Full Sequence ---
