# 3. Diffusion Model Training 

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import random
import numpy as np
import json
import math

# Reproducibility: Python, NumPy and PyTorch each keep an independent random
# generator, so every one of them is seeded with the same value.
seed = 20777980

torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    # Explicitly (re)seed the current CUDA device as well.
    torch.cuda.manual_seed(seed)

def determine_max_seq_len(data, max_length='max_length'):
    """Return the sequence length to pad/truncate to.

    Args:
        data: Iterable of records, each with a ``"tokens"`` sequence. Only read
            when ``max_length`` is the sentinel string ``'max_length'``.
        max_length: Either the sentinel ``'max_length'`` (use the longest
            ``"tokens"`` entry in ``data``) or an explicit length returned as-is.

    Raises:
        ValueError: when the sentinel is used and ``data`` is empty (from ``max``).
    """
    if max_length != 'max_length':
        return max_length
    return max(len(record["tokens"]) for record in data)

def setup_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")

def save_model(model, filepath):
    """Persist ``model.state_dict()`` to ``filepath`` with ``torch.save``.

    Returns:
        The same ``model`` object, unchanged, so calls can be chained.
    """
    weights = model.state_dict()
    torch.save(weights, filepath)
    return model

def load_model(model, filepath):
    """Load a saved state dictionary into ``model`` and prepare it for inference.

    Checkpoint tensors are mapped onto the device chosen by ``setup_device()``
    while loading. Without ``map_location``, ``torch.load`` restores tensors on
    the device they were saved from, so a checkpoint written on a GPU would fail
    to load on a CPU-only machine.

    Args:
        model: Module whose architecture matches the checkpoint.
        filepath: Path to a file holding a state dict saved with ``torch.save``.

    Returns:
        ``(model, device)``: the model in eval mode and moved to ``device``, and
        that device.
    """
    device = setup_device()
    state_dict = torch.load(filepath, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device

def load_dataset(file_path):
    """Load a JSON Lines file (one JSON document per line).

    The file is decoded as UTF-8 regardless of the platform's locale encoding.
    Whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    being handed to ``json.loads``, which would reject them.

    Returns:
        List of the decoded documents, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def save_JSON(data, filename):
    """Serialize ``data`` as JSON into ``filename``, replacing existing contents.

    Returns:
        None.
    """
    with open(filename, 'w') as out_file:
        json.dump(data, out_file)

def load_JSON(filename):
    """Load and return the single JSON document stored in ``filename``.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than the
    platform's locale default, so non-ASCII keys/values load the same everywhere.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

class CosineNoiseSchedule:
    """Per-timestep noise coefficients derived from a squared-cosine curve.

    With ``theta`` spaced evenly over ``[0, pi/2]`` in ``timesteps`` points:
    ``alphas = cos(theta) ** 2``, ``betas = 1 - alphas`` and ``alpha_bar`` is the
    running product of ``alphas``, floored at ``epsilon``.

    NOTE(review): the per-step ``alphas`` are taken directly from the squared
    cosine (rather than deriving ``alpha_bar`` from a cosine), and
    ``get_variance`` returns ``beta * (1 - alpha)``, which equals ``beta ** 2``
    up to rounding. Confirm both are intended.
    """

    def __init__(self, timesteps=1000, epsilon=1e-6, device=None):
        self.timesteps = timesteps
        self.epsilon = epsilon
        self.device = device

        theta = torch.linspace(0, math.pi / 2, timesteps, device=device)
        self.alphas = torch.cos(theta) ** 2
        self.betas = 1.0 - self.alphas
        running_product = torch.cumprod(self.alphas, dim=0)
        # Floor keeps alpha_bar strictly positive once the product underflows.
        floor = torch.tensor(self.epsilon, device=self.alphas.device)
        self.alpha_bar = torch.maximum(running_product, floor)

    def get_alpha(self, t):
        """Return ``alphas[t]``; ``t`` may be an int or an index tensor."""
        return self.alphas[t]

    def get_beta(self, t):
        """Return ``betas[t]``; ``t`` may be an int or an index tensor."""
        return self.betas[t]

    def get_variance(self, t):
        """Return ``beta_t * (1 - alpha_t)`` for timestep(s) ``t``."""
        beta_t = self.get_beta(t)
        alpha_t = self.get_alpha(t)
        return beta_t * (1 - alpha_t)

    def get_alpha_bar(self, t):
        """Return the floored cumulative product ``alpha_bar[t]``."""
        return self.alpha_bar[t]

class SymbolicRegressionDataset(Dataset):
    """Dataset yielding clean/noisy token-embedding pairs for diffusion training.

    Each record in ``data`` must provide ``record['tokens']`` (a sequence of
    per-token vectors, presumably all of length ``embedding_dim``) and
    ``record['data']`` with ``'x'``, ``'y'`` and ``'mask'`` entries.

    ``__getitem__`` returns
    ``(token_embeddings, noisy_token_embeddings, noisy_x, noisy_y, t, mask)``
    where the embeddings have shape ``(embedding_dim, max_seq_len)`` and ``t`` is
    an int drawn uniformly from ``[0, noise_schedule.timesteps - 1]``.
    """

    def __init__(self, data, vocab, max_seq_len, noise_schedule):
        self.data = data
        self.vocab = vocab  # kept for callers; not read by the methods below
        self.max_seq_len = max_seq_len
        self.noise_schedule = noise_schedule

    def get_input_embeddings(self, tokens):
        """Stack token vectors and zero-pad the *sequence* axis to ``max_seq_len``.

        Sequences longer than ``max_seq_len`` are truncated.

        Returns:
            float32 tensor of shape ``(embedding_dim, max_seq_len)``.
        """
        embeddings = torch.stack([torch.tensor(token, dtype=torch.float32) for token in tokens])  # (seq_len, embedding_dim)
        embeddings = embeddings[: self.max_seq_len]
        # F.pad lists amounts starting from the LAST dim: (last_left, last_right,
        # first_left, first_right). Leave the embedding axis alone and append zero
        # rows along the sequence axis so every sample ends up the same shape.
        pad_rows = self.max_seq_len - embeddings.size(0)
        padded_embeddings = nn.functional.pad(embeddings, (0, 0, 0, pad_rows))
        return padded_embeddings.transpose(0, 1)

    def add_noise(self, token_ids, t, schedule):
        """Return ``token_ids`` plus Gaussian noise, clamped to ``[0, 1]``.

        The noise std is ``sqrt(schedule.get_variance(t))``. The input tensor is
        not modified.
        """
        # Convert to a Python float so a schedule living on another device (e.g.
        # CUDA) can scale noise generated next to the (CPU) input tensor.
        noise_level = float(torch.sqrt(schedule.get_variance(t)))
        noise = torch.randn_like(token_ids) * noise_level
        # NOTE(review): clamping to [0, 1] discards any negative components of
        # the inputs; confirm the embeddings and x/y values are meant to lie in [0, 1].
        return torch.clamp(token_ids + noise, min=0.0, max=1.0)

    def __getitem__(self, idx):
        data_point = self.data[idx]
        tokens = data_point['tokens']
        current_data = data_point['data']
        x = torch.tensor(current_data['x'], dtype=torch.float32)
        y = torch.tensor(current_data['y'], dtype=torch.float32)
        mask = torch.tensor(current_data['mask'], dtype=torch.float32)
        token_embeddings = self.get_input_embeddings(tokens)

        # One timestep per sample, shared by the embeddings and the x/y data.
        t = random.randint(0, self.noise_schedule.timesteps - 1)
        noisy_token_embeddings = self.add_noise(token_embeddings, t, self.noise_schedule)
        noisy_x = self.add_noise(x, t, self.noise_schedule)
        noisy_y = self.add_noise(y, t, self.noise_schedule)

        return token_embeddings, noisy_token_embeddings, noisy_x, noisy_y, t, mask

    def __len__(self):
        return len(self.data)

class DiffusionModel(nn.Module):
    """Transformer mapping (noisy) token-embedding sequences to embedding sequences.

    ``forward`` consumes and returns tensors laid out as
    ``(batch, embedding_dim, seq_len)``.

    Args:
        vocab_size: Number of vocabulary entries; sizes ``self.embedding`` when no
            pretrained table is given.
        embedding_dim: Width of each token embedding (the middle input axis).
        hidden_dim: ``d_model`` of the internal ``nn.Transformer``; inputs are
            projected from ``embedding_dim`` to this width when the two differ.
        num_layers: Number of encoder layers and of decoder layers.
        num_heads: Attention heads; ``hidden_dim`` must be divisible by it.
        num_timesteps: Number of steps iterated by ``reverse_diffusion``.
        max_seq_len: Accepted for interface compatibility; not used by this class.
        pretrained_embeddings: Optional mapping whose values are vectors of length
            ``embedding_dim``; stored as the parameter ``self.embedding``.

    Raises:
        ValueError: if the pretrained vectors' width differs from ``embedding_dim``.

    Note: ``self.embedding`` is not referenced by ``forward``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, num_timesteps, max_seq_len=5000, pretrained_embeddings=None):
        super(DiffusionModel, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_timesteps = num_timesteps

        if pretrained_embeddings is not None:
            pretrained_embeddings = torch.tensor(list(pretrained_embeddings.values()), dtype=torch.float32)
            if pretrained_embeddings.size(1) != embedding_dim:
                raise ValueError(
                    f"Pretrained embeddings size {pretrained_embeddings.size(1)} does not match the required embedding_dim {embedding_dim}."
                )
            self.embedding = nn.Parameter(pretrained_embeddings)
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Identity when the widths already agree.
        self.projection = nn.Linear(embedding_dim, hidden_dim) if embedding_dim != hidden_dim else nn.Identity()
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=False
        )

        self.fc_out = nn.Linear(hidden_dim, embedding_dim)

    def forward(self, embeddings):
        """Map ``(batch, embedding_dim, seq_len)`` to the same layout.

        Raises:
            ValueError: if ``embeddings`` is not 3-D.
        """
        if embeddings.dim() != 3:
            raise ValueError(
                f"Expected a 3-D (batch, embedding_dim, seq_len) tensor, got shape {tuple(embeddings.shape)}."
            )
        # Move features to the last axis first so the projection and LayerNorm act
        # on the embedding dimension, not on seq_len.
        hidden = embeddings.transpose(1, 2)      # (batch, seq_len, embedding_dim)
        hidden = self.projection(hidden)         # (batch, seq_len, hidden_dim)
        hidden = self.layer_norm(hidden)
        hidden = hidden.transpose(0, 1)          # (seq_len, batch, hidden_dim): batch_first=False
        hidden = self.transformer(hidden, hidden)
        hidden = hidden.transpose(0, 1)          # (batch, seq_len, hidden_dim)
        logits = self.fc_out(hidden)             # (batch, seq_len, embedding_dim)
        return logits.transpose(1, 2)            # (batch, embedding_dim, seq_len)

    def add_noise(self, token_embeddings, t, schedule):
        """Return ``token_embeddings`` plus Gaussian noise of std ``sqrt(schedule.get_variance(t))``."""
        noise_level = torch.sqrt(schedule.get_variance(t))
        noise = torch.normal(mean=0, std=noise_level, size=token_embeddings.shape).to(token_embeddings.device)
        return token_embeddings + noise

    def reverse_diffusion(self, noisy_input, schedule):
        """Iteratively denoise ``noisy_input`` from step ``num_timesteps - 1`` down to 0.

        Each step computes ``(x_t - beta_t * model(x_t)) / sqrt(alpha_t)``, adds
        ``sqrt(beta_t)``-scaled Gaussian noise for ``t > 0`` and clamps to ``[-1, 1]``.

        NOTE(review): the model output is used here as a noise estimate; confirm
        this matches what the model is trained to predict. The usual DDPM
        posterior mean also scales the noise term by ``1 / sqrt(1 - alpha_bar_t)``,
        which this loop omits; confirm intended.
        """
        x_t = noisy_input
        device = x_t.device
        for t in reversed(range(self.num_timesteps)):
            predicted_noise = self.forward(x_t)
            alpha_t = schedule.get_alpha(t)
            beta_t = schedule.get_beta(t)
            mean_x_prev = (x_t - beta_t * predicted_noise) / torch.sqrt(alpha_t)
            if t > 0:
                std_dev = torch.sqrt(beta_t)
                noise = torch.randn_like(x_t, device=device) * std_dev
                x_t = mean_x_prev + noise
            else:
                x_t = mean_x_prev
            x_t = torch.clamp(x_t, min=-1.0, max=1.0)
        return x_t

def denoising_loss(predicted_embeddings, clean_embeddings):
    """Mean squared error between predicted and clean embeddings (mean over all elements)."""
    return nn.functional.mse_loss(predicted_embeddings, clean_embeddings)

def train_diffusion_model(model, train_loader, val_loader, num_epochs=10, patience_num_epochs=3,
                          checkpoint_path="best_diffusion_model.pt"):
    """Train ``model`` to reconstruct clean token embeddings from noisy ones.

    Each loader batch is unpacked as ``(token_embeddings, noisy_token_embeddings, ...)``.
    The model is fed the noisy embeddings and scored with ``denoising_loss``
    against the clean ones; the remaining batch fields (x, y, t, mask) are ignored.

    Optimisation uses AdamW (lr 5e-5, weight decay 0.01), gradient-norm clipping at
    1.0 and ``ReduceLROnPlateau`` on the validation loss. Whenever the validation
    loss improves, the weights are saved to ``checkpoint_path``. Training stops
    early after ``patience_num_epochs`` consecutive epochs without improvement.

    Args:
        model: ``nn.Module`` mapping noisy embeddings to predicted clean embeddings.
        train_loader: Iterable of training batches (must support ``len``).
        val_loader: Iterable of validation batches (must support ``len``).
        num_epochs: Maximum number of epochs.
        patience_num_epochs: Early-stopping patience, in epochs.
        checkpoint_path: File that receives the best-validation weights.

    Returns:
        ``(model, performance_metrics)``. ``model`` is as it stands after the last
        epoch run (not reloaded from the checkpoint). ``performance_metrics`` has
        lists under ``'epoch_list'``, ``'train_loss_list'`` and ``'val_loss_list'``.

    Raises:
        ZeroDivisionError: if either loader has zero batches.
    """
    device = setup_device()
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

    best_val_loss = float('inf')
    checkpoint_saved = False
    num_epochs_without_improvement = 0
    performance_metrics = {"epoch_list": [], "train_loss_list": [], "val_loss_list": []}

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for token_embeddings, noisy_token_embeddings, *_ in train_loader:
            token_embeddings = token_embeddings.to(device)
            noisy_token_embeddings = noisy_token_embeddings.to(device)

            optimizer.zero_grad()
            predicted_embeddings = model(noisy_token_embeddings)
            loss = denoising_loss(predicted_embeddings, token_embeddings)
            loss.backward()
            # Clip first, then take exactly ONE step so the update uses the
            # clipped gradients (stepping twice would apply each batch twice).
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        train_loss = total_loss / len(train_loader)
        performance_metrics['train_loss_list'].append(train_loss)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for token_embeddings, noisy_token_embeddings, *_ in val_loader:
                token_embeddings = token_embeddings.to(device)
                noisy_token_embeddings = noisy_token_embeddings.to(device)
                predicted_embeddings = model(noisy_token_embeddings)
                val_loss += denoising_loss(predicted_embeddings, token_embeddings).item()

        val_loss = val_loss / len(val_loader)
        performance_metrics['val_loss_list'].append(val_loss)
        performance_metrics['epoch_list'].append(epoch + 1)

        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}')

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, checkpoint_path)
            checkpoint_saved = True
            num_epochs_without_improvement = 0
        else:
            num_epochs_without_improvement += 1

        if num_epochs_without_improvement >= patience_num_epochs:
            print(f"Training stopped early at epoch {epoch + 1}. Best validation loss: {best_val_loss}")
            break

    # Fall back to saving the final weights only when no epoch ever improved
    # (e.g. every validation loss was NaN, or num_epochs == 0). Otherwise the
    # checkpoint already holds the best-validation weights and must not be
    # overwritten with a possibly worse final epoch.
    if not checkpoint_saved:
        save_model(model, checkpoint_path)

    return model, performance_metrics

# Load the preprocessed samples and the vocabulary (a mapping whose values are
# continuous token embeddings).
dataset = load_dataset('Data/preprocessed_data_with_embeddings.json')
vocab = load_JSON("Data/vocab_embeddings.json")

# Pad every sample to the longest token sequence in the dataset.
MAX_LENGTH = determine_max_seq_len(dataset)

schedule = CosineNoiseSchedule(timesteps=1000, device=setup_device())

# Split the raw records 70 / 15 / 15 (train / validation / test); the test share
# absorbs any rounding remainder.
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_dataset_SR = SymbolicRegressionDataset(train_dataset, vocab, MAX_LENGTH, schedule)
val_dataset_SR = SymbolicRegressionDataset(val_dataset, vocab, MAX_LENGTH, schedule)
test_dataset_SR = SymbolicRegressionDataset(test_dataset, vocab, MAX_LENGTH, schedule)

# Only the training loader shuffles; evaluation order does not affect the model,
# and a fixed order keeps validation/test batches reproducible.
train_loader = DataLoader(train_dataset_SR, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset_SR, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset_SR, batch_size=16, shuffle=False)

# Model hyper-parameters.
num_heads = 4
embedding_dim = 100  # must match the width of the vectors stored in `vocab`
hidden_dim = embedding_dim  # transformer d_model; equal widths skip the input projection

# nn.Transformer requires d_model (= hidden_dim) to be divisible by the head count.
if hidden_dim % num_heads != 0:
    raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads}).")

model = DiffusionModel(
    vocab_size=len(vocab),
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_layers=4,
    num_heads=num_heads,
    num_timesteps=1000,
    pretrained_embeddings=vocab
)

model, performance_metrics_DICT = train_diffusion_model(model, train_loader, val_loader, num_epochs=10, patience_num_epochs=3)

### Plotting Results

In [36]:
# import matplotlib.pyplot as plt

# #Visualize the train and validation loss
# def plot_train_valid(model_name,performance_metrics_DICT):
#     plt.figure();
#     plt.plot(performance_metrics_DICT['epoch_list'], performance_metrics_DICT['train_loss_list'], label=f'Train Loss', color='blue', linestyle='--', marker='o');
#     plt.plot(performance_metrics_DICT['epoch_list'], performance_metrics_DICT['val_loss_list'], label=f'Validation Loss', color='green', linestyle='-', marker='x');
#     plt.title(f'{model_name} Training and Validation Loss');
#     plt.xlabel('Epochs');
#     plt.ylabel('Loss');
#     plt.legend();
#     plt.grid();
#     plt.xlim(0,max(performance_metrics_DICT['epoch_list'])+1);
#     return

In [37]:
# model_name = 'Diffusion Model'

# plot_train_valid(model_name,performance_metrics_DICT)

### Prediction

In [38]:
# def decode_embeddings_to_tokens(embeddings, vocab):
#     vocab_embeddings = torch.stack([torch.tensor(embed) for embed in vocab.values()])
    
#     decoded_tokens = []
#     for embedding in embeddings:
#         # Compute the distance between the embedding and all vocab embeddings
#         distances = torch.norm(embedding - vocab_embeddings, dim=1)
#         closest_token_idx = torch.argmin(distances).item()
#         closest_token = list(vocab.keys())[closest_token_idx]
#         decoded_tokens.append(closest_token)
    
#     return decoded_tokens

# def evaluate_diffusion_model(model, test_loader, vocab, schedule, device):
#     model.eval()  # Set the model to evaluation mode
#     total_test_loss = 0.0
#     decoded_formulas = []
#     actual_formulas = []

#     with torch.no_grad():
#         for noisy_embeddings, target_embeddings in test_loader:
#             # Get the predicted denoised embeddings
#             t = random.randint(0, model.num_timesteps - 1)  # Random timestep for diffusion
#             pred_embeddings = model.reverse_diffusion(noisy_embeddings, schedule)

#             # Calculate the loss (MSE between predicted and target embeddings)
#             loss = denoising_loss(pred_embeddings, target_embeddings)
#             total_test_loss += loss.item()

#             # Now, we need to decode the denoised embeddings back to tokens
#             decoded_tokens = decode_embeddings_to_tokens(pred_embeddings, vocab)

#             # Convert the decoded tokens to a formula string
#             predicted_formula = " ".join(decoded_tokens)
#             decoded_formulas.append(predicted_formula)

#             # Assuming target embeddings have a corresponding ground truth formula (you can adjust this part)
#             actual_formula = decode_embeddings_to_tokens(target_embeddings, vocab)
#             actual_formulas.append(" ".join(actual_formula))

#     # Calculate average test loss
#     avg_test_loss = total_test_loss / len(test_loader)
#     return avg_test_loss, decoded_formulas, actual_formulas


In [39]:
# # Example of loading and preparing the dataset
# dataset = load_dataset('Data/preprocessed_data_with_embeddings.json')
# vocab = load_JSON("Data/vocab_embeddings.json")  # Vocabulary is a dictionary of continuous embeddings

# MAX_LENGTH = determine_max_seq_len(dataset)  # Determine the max length dynamically

# schedule = CosineNoiseSchedule(timesteps=1000,device=setup_device())

# # First, perform the split on the raw dataset
# train_size = int(0.7 * len(dataset))  # 70% for training
# val_size = int(0.15 * len(dataset))  # 15% for validation
# test_size = len(dataset) - train_size - val_size  # 15% for testing

# # Perform random split
# train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# # Initialize the SymbolicRegressionDataset with the schedule for each subset
# train_dataset_SR = SymbolicRegressionDataset(train_dataset, vocab, MAX_LENGTH, schedule)
# val_dataset_SR = SymbolicRegressionDataset(val_dataset, vocab, MAX_LENGTH, schedule)
# test_dataset_SR = SymbolicRegressionDataset(test_dataset, vocab, MAX_LENGTH, schedule)

# # Create DataLoader objects for each subset
# train_loader = DataLoader(train_dataset_SR, batch_size=16, shuffle=True)
# val_loader = DataLoader(val_dataset_SR, batch_size=16, shuffle=True)
# test_loader = DataLoader(test_dataset_SR, batch_size=16, shuffle=True)

# # Initialize the model
# num_heads = 4  # Number of attention heads, ensure this is a divisor of embedding_dim
# embedding_dim = 100  # The embedding dimension is 100 as per your problem
# hidden_dim = embedding_dim  # Hidden dimension stays the same for simplicity

# # Ensure embedding_dim is divisible by num_heads
# if embedding_dim % num_heads != 0:
#     raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads}).")

# model = DiffusionModel(
#     vocab_size=len(vocab),
#     embedding_dim=embedding_dim,
#     hidden_dim=hidden_dim,
#     num_layers=6,
#     num_heads=num_heads,
#     num_timesteps=1000,
#     pretrained_embeddings=vocab
# )

# model, device = load_model(model, 'Data/best_diffusion_model.pt')
# model = model.to(device)

In [40]:
# # Example: Evaluate the model on the test set
# test_loss, decoded_formulas, actual_formulas = evaluate_diffusion_model(model, test_loader, vocab, schedule, device)

# # Print out the average test loss
# print(f"Test Loss: {test_loss}")

# # Print out the first few decoded formulas and their corresponding actual formulas
# for predicted, actual in zip(decoded_formulas[:5], actual_formulas[:5]):
#     print(f"Predicted Formula: {predicted}")
#     print(f"Actual Formula: {actual}")
#     print("-" * 50)