https://chatgpt.com/share/67f9255c-a1ec-8012-be33-0f10696efe98

In [None]:
# Revised training script
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import json
import copy
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import sys

# Put the parent of the current working directory at the front of the import
# path (once), so packages that live one level up can be imported.
project_root = os.path.abspath('..')
already_importable = project_root in sys.path
if not already_importable:
    sys.path.insert(0, project_root)

from scripts.function_composition import HierarchicalCompositionalModel

# Revised load_dataset function that loads intermediate steps and operations.
def load_dataset(path):
    """Load a function-composition dataset from JSON as zero-padded tensors.

    The file must hold a JSON list of samples. Each sample is a dict with the
    keys ``'input'``, ``'output'``, ``'intermediate_steps'`` and
    ``'operations'``. The first entry of ``'intermediate_steps'`` (the initial
    input) is dropped, so the remaining entries are the results after each
    operation.

    Args:
        path: Path of the JSON dataset file.

    Returns:
        A 5-tuple ``(inputs, final_outputs, intermediates, operations,
        op_lengths)``:

        * ``inputs``: float32 ``[N, max_len_in]``, zero padded.
        * ``final_outputs``: float32 ``[N, max_len_out]``, zero padded.
        * ``intermediates``: float32 ``[N, max_steps, step_width]``; short
          steps and missing steps are filled with zeros.
        * ``operations``: long ``[N, max_ops]`` of operation indices, padded
          with ``-1``.
        * ``op_lengths``: long ``[N]``, the number of operations per sample.

    Raises:
        ValueError: If the file contains no samples.
        KeyError: If a sample lacks a required key or names an unknown
            operation.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not data:
        raise ValueError(f"Dataset at {path!r} contains no samples")

    inputs = [sample['input'] for sample in data]
    final_outputs = [sample['output'] for sample in data]
    # Use intermediate_steps excluding the initial input (i.e. the output after each op)
    intermediate_steps = [sample['intermediate_steps'][1:] for sample in data]
    operations = [sample['operations'] for sample in data]
    op_lengths = [len(op_list) for op_list in operations]

    max_len_in = max(len(x) for x in inputs)
    padded_inputs = [x + [0] * (max_len_in - len(x)) for x in inputs]
    max_len_out = max(len(x) for x in final_outputs)
    padded_final_outputs = [x + [0] * (max_len_out - len(x)) for x in final_outputs]

    # Pad intermediate steps along both axes. The width comes from the widest
    # step in the whole dataset; a sample with no operations has no steps to
    # measure, so it must not be indexed. If no sample has any step at all,
    # fall back to the padded input width.
    max_steps = max(len(steps) for steps in intermediate_steps)
    step_width = max(
        (len(step) for steps in intermediate_steps for step in steps),
        default=max_len_in,
    )
    padded_intermediate = []
    for steps in intermediate_steps:
        rows = [step + [0] * (step_width - len(step)) for step in steps]
        missing = max_steps - len(rows)
        # Build a fresh list per filler row so rows never alias each other.
        rows.extend([0] * step_width for _ in range(missing))
        padded_intermediate.append(rows)

    # Convert operations to indices; -1 marks padding positions.
    op_to_index = {'sort': 0, 'reverse': 1, 'add': 2, 'subtract': 3, 'multiply': 4, 'divide': 5}
    max_ops = max(op_lengths)
    padded_ops = []
    for op_list in operations:
        indices = [op_to_index[op] for op in op_list]
        indices = indices + [-1] * (max_ops - len(indices))
        padded_ops.append(indices)

    inputs_tensor = torch.tensor(padded_inputs, dtype=torch.float32)
    final_outputs_tensor = torch.tensor(padded_final_outputs, dtype=torch.float32)
    # The reshape is a no-op normally; it keeps the tensor 3-D when max_steps is 0.
    intermediate_tensor = torch.tensor(padded_intermediate, dtype=torch.float32).reshape(
        len(data), max_steps, step_width
    )  # [N, max_steps, step_width]
    operations_tensor = torch.tensor(padded_ops, dtype=torch.long)  # [N, max_ops]
    op_lengths_tensor = torch.tensor(op_lengths, dtype=torch.long)
    return inputs_tensor, final_outputs_tensor, intermediate_tensor, operations_tensor, op_lengths_tensor

# Load every split once. load_dataset returns, in order: inputs, final
# outputs, intermediate steps, operation indices and per-sample op counts.
train_tensors = load_dataset('../datasets/function_composition/train_dataset.json')
val_tensors = load_dataset('../datasets/function_composition/validation_dataset.json')
test_tensors = load_dataset('../datasets/function_composition/test_dataset.json')

# Expose the individual tensors under their usual names.
train_x, train_y, train_intermediates, train_ops, train_op_lengths = train_tensors
val_x, val_y, val_intermediates, val_ops, val_op_lengths = val_tensors
test_x, test_y, test_intermediates, test_ops, test_op_lengths = test_tensors

# Only the training split is shuffled; validation and test keep file order.
train_loader = DataLoader(TensorDataset(*train_tensors), batch_size=8, shuffle=True)
val_loader = DataLoader(TensorDataset(*val_tensors), batch_size=8)
test_loader = DataLoader(TensorDataset(*test_tensors), batch_size=8)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model dimensions are taken from the training tensors: the padded input width
# and the maximum number of intermediate steps (used as the sequence length).
input_size = train_x.shape[1]
hidden_size = 64
max_sequence_length = train_intermediates.shape[1]

model = HierarchicalCompositionalModel(input_size, hidden_size, max_sequence_length)
model = model.to(device)

# Loss functions: MSE for intermediate outputs and CrossEntropy for operations.
# MSE stays element-wise so padded steps can be masked out before averaging;
# CrossEntropy skips the -1 entries used to pad the operation targets.
mse_loss_fn = nn.MSELoss(reduction='none')
ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

optimizer = optim.Adam(model.parameters(), lr=0.001)

def _batch_loss(batch):
    """Return the combined intermediate-step and operation loss for one batch.

    Uses the module-level ``model``, ``device``, ``mse_loss_fn`` and
    ``ce_loss_fn``. ``batch`` is one item from the data loaders: inputs, final
    targets, intermediate targets, operation targets and per-sample operation
    counts. The final targets are not part of the loss.

    Returns:
        A scalar tensor: masked mean squared error over the valid intermediate
        steps plus cross-entropy over the non-padded operation targets.
    """
    inputs, _final_targets, intermediate_targets, operations_targets, op_lengths = batch
    inputs = inputs.to(device)
    intermediate_targets = intermediate_targets.to(device)
    operations_targets = operations_targets.to(device)
    op_lengths = op_lengths.to(device)

    outputs, module_logits = model(inputs)  # outputs: [B, seq_len, input_size]
                                            # module_logits: [B, seq_len, num_modules]
    seq_len = outputs.size(1)
    # Mask of valid intermediate steps (for each sample, valid steps: [0, op_length))
    step_range = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, seq_len]
    mask = (step_range < op_lengths.unsqueeze(1)).float()  # [B, seq_len]
    # Intermediate loss: per-step MSE, averaged over the valid steps only.
    per_step = mse_loss_fn(outputs, intermediate_targets).mean(dim=2)  # [B, seq_len]
    # Clamp the divisor: a batch with no valid step would otherwise give 0/0 = NaN.
    intermediate_loss = (per_step * mask).sum() / mask.sum().clamp(min=1.0)

    # Operations loss: supervise the controller's module predictions.
    logits_valid = module_logits[:, :operations_targets.size(1), :]  # [B, max_ops, num_modules]
    flat_targets = operations_targets.reshape(-1)
    if (flat_targets != ce_loss_fn.ignore_index).any():
        operations_loss = ce_loss_fn(logits_valid.reshape(-1, logits_valid.size(-1)),
                                     flat_targets)
    else:
        # Every target is padding: mean-reduced cross-entropy would be NaN.
        operations_loss = outputs.new_zeros(())
    return intermediate_loss + operations_loss


def _evaluate(loader):
    """Return the mean per-batch loss of ``model`` over ``loader`` without gradients."""
    model.eval()
    running = 0
    with torch.no_grad():
        for batch in loader:
            running += _batch_loss(batch).item()
    return running / len(loader)


best_model = None
best_val_loss = float('inf')
num_epochs = 50
train_losses = []
val_losses = []
pbar = tqdm(range(num_epochs), desc="Epoch 0")
for epoch in pbar:
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    avg_val_loss = _evaluate(val_loader)
    val_losses.append(avg_val_loss)
    pbar.set_description(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    # Snapshot the weights whenever validation improves.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model = copy.deepcopy(model.state_dict())

if best_model is None:
    # No epoch improved on the initial best (e.g. zero epochs or a non-finite
    # validation loss); fall back to the current weights.
    best_model = copy.deepcopy(model.state_dict())
model.load_state_dict(best_model)
avg_test_loss = _evaluate(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

# Plot the per-epoch training and validation loss curves and write the figure
# to disk before showing it.
os.makedirs('./plots', exist_ok=True)
plt.figure()
for history, label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    plt.plot(history, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('./plots/fc_model.png')
plt.show()

# Save the best weights (a state dict) captured during training.
os.makedirs('./models', exist_ok=True)
torch.save(best_model, './models/best_fc_model.pth')
print("Best model saved.")
