# 3. Diffusion Model Training 

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import random
import numpy as np
import json
import math

# Define a dummy PAD token (you can set this to a proper value based on your vocab)
PAD_TOKEN = "<PAD>"

# Set the random seed (for replicability)
seed = 20777980
random.seed(seed)
np.random.seed(seed)

def determine_max_seq_len(data, max_length='max_length'):
    """Calculate the max sequence length dynamically if 'max_length' is used as argument."""
    if max_length == 'max_length':
        MAX_LENGTH = max(len(dp["tokens"]) for dp in data)
    else:
        MAX_LENGTH = max_length
    return MAX_LENGTH

def load_dataset(file_path):
    """Load the dataset from a JSON file."""
    with open(file_path, 'r') as file:
        dataset = [json.loads(line) for line in file]
    return dataset

def load_JSON(filename):
    """Load a JSON file."""
    with open(filename, 'r') as f:
        data = json.load(f)
    return data

# Define a Noise Schedule using a Cosine schedule (commonly used in DDPMs)
class CosineNoiseSchedule:
    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        self.alphas = torch.cos(torch.linspace(0, math.pi / 2, timesteps))**2
        self.betas = 1.0 - self.alphas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)  # Cumulative product of alphas

    def get_alpha(self, t):
        return self.alphas[t]

    def get_beta(self, t):
        return self.betas[t]

    def get_variance(self, t):
        return self.get_beta(t) * (1 - self.get_alpha(t))
    
    def get_alpha_bar(self, t):
        return self.alpha_bar[t]

# Define the Diffusion Model (symbolic regression version)
class DiffusionModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=4, num_timesteps=1000):
        super(DiffusionModel, self).__init__()
        self.num_timesteps = num_timesteps
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer = nn.Transformer(d_model=embedding_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(embedding_dim, vocab_size)  # For token prediction (symbolic generation)
    
    def forward(self, tokens):
        embedded_tokens = self.embedding(tokens)
        embedded_tokens = embedded_tokens.transpose(0, 1)  # Change to (seq_len, batch_size, embedding_dim)
        transformer_output = self.transformer(embedded_tokens, embedded_tokens)
        logits = self.fc_out(transformer_output)
        return logits
    
    def add_noise(self, token_ids, t, schedule):
        """Add noise to the token sequence during the forward diffusion process."""
        noisy_token_ids = token_ids.clone()
        noise_level = schedule.get_variance(t)

        # Apply noise as per variance (e.g., Gaussian noise)
        noise = torch.normal(mean=0, std=noise_level, size=noisy_token_ids.shape)
        noisy_token_ids = noisy_token_ids + noise.long()  # Convert to integer token IDs

        return noisy_token_ids

    def reverse_diffusion(self, noisy_input, schedule):
        """Apply reverse diffusion to denoise the sequence."""
        x_t = noisy_input
        for t in reversed(range(self.num_timesteps)):
            predicted_noise = self.forward(x_t)  # Predict noise at timestep t
            alpha_t = schedule.get_alpha(t)
            beta_t = schedule.get_beta(t)
            alpha_bar_t = schedule.get_alpha_bar(t)
            # Reverse the noise process
            x_t = (x_t - beta_t * predicted_noise) / alpha_bar_t
        return x_t

# Loss function for training: Predict the noise at each timestep
def denoising_loss(predicted_noise, noisy_input, clean_input):
    """Loss function for symbolic regression."""
    return nn.MSELoss()(predicted_noise, clean_input - noisy_input)

# Dataset for symbolic regression
class SymbolicRegressionDataset(torch.utils.data.Dataset):
    def __init__(self, data, vocab, max_seq_len, noise_schedule):
        self.data = data
        self.vocab = vocab
        self.max_seq_len = max_seq_len
        self.noise_schedule = noise_schedule

    def add_noise(self, token_ids, t, schedule):
        """Add noise to the token sequence during the forward diffusion process."""
        # Convert token_ids from list to a PyTorch tensor
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        noisy_token_ids = token_ids.clone()  # Now .clone() will work on the tensor
        noise_level = schedule.get_variance(t)  # Get noise level

        if isinstance(noise_level, torch.Tensor):
            noise_level = noise_level.item()
        
        # Apply Gaussian noise
        noise = torch.normal(mean=0, std=noise_level, size=noisy_token_ids.shape)
        noisy_token_ids = noisy_token_ids + noise.long()  # Convert to integer token IDs

        # Ensure the token IDs are within bounds
        vocab_size = len(self.vocab)
        noisy_token_ids = torch.clamp(noisy_token_ids, 0, vocab_size - 1)
        
        return noisy_token_ids

    def __getitem__(self, idx):
        data_point = self.data[idx]
        token_ids = data_point['tokens']
        
        # Convert tokens to their corresponding indices using the vocab
        token_ids = [self.vocab.get(token, self.vocab[PAD_TOKEN]) for token in token_ids]
        
        # Pad the token sequence if necessary
        pad_token_id = self.vocab.get('<PAD>', 5)
        token_ids = token_ids + [pad_token_id] * (self.max_seq_len - len(token_ids))

        # Convert to tensor (this step is now safe as token_ids contains integers)
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        t = torch.randint(0, 1000, (1,))  # Sample a random timestep
        noisy_token_ids = self.add_noise(token_ids, t, self.noise_schedule)

        # Return noisy token IDs and clean token IDs for loss computation
        return noisy_token_ids, token_ids
    
    def __len__(self):
        """Return the length of the dataset."""
        return len(self.data)
    
# Training loop for the diffusion model
def train_diffusion_model(model, train_loader, val_loader, optimizer, schedule, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, (noisy_tokens, target_tokens) in enumerate(train_loader):
            optimizer.zero_grad()

            # Randomly pick a timestep for each batch
            t = random.randint(0, model.num_timesteps - 1)

            # Add noise at the forward diffusion process
            noisy_batch = torch.stack([model.add_noise(seq, t, schedule) for seq in noisy_tokens])
            
            # Forward pass through the model (predict denoised expression at timestep t)
            logits = model(noisy_batch)  # Predict logits (denoised tokens)

            # Flatten the logits and target tokens for loss computation
            logits = logits.view(-1, logits.size(-1))  # Flatten the logits (seq_len * batch_size, vocab_size)
            target_tokens = target_tokens.view(-1)  # Flatten the target tokens (seq_len * batch_size)

            # Loss computation using CrossEntropyLoss (you can switch to another loss if needed)
            loss = nn.CrossEntropyLoss()(logits, target_tokens)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

        # Validation step
        validate_diffusion_model(model, val_loader, schedule)
    return

def validate_diffusion_model(model, val_loader, schedule):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for noisy_tokens, target_tokens in val_loader:
            logits = model(noisy_tokens)  # Predict logits (denoised tokens)
            
            # Flatten logits and target tokens for loss computation
            logits = logits.view(-1, logits.size(-1))  # Flatten (seq_len * batch_size, vocab_size)
            target_tokens = target_tokens.view(-1)  # Flatten target tokens (seq_len * batch_size)
            
            # CrossEntropyLoss for token prediction
            loss = nn.CrossEntropyLoss()(logits, target_tokens)
            total_loss += loss.item()

    print(f"Validation Loss: {total_loss / len(val_loader)}")
    model.train()
    return

# Sample generation
def sample(model, schedule, batch_size=16):
    # Start with random noise
    noisy_input = torch.randint(0, len(vocab), (batch_size, model.num_timesteps))
    denoised_output = model.reverse_diffusion(noisy_input, schedule)
    return denoised_output

# Example usage

# Load preprocessed data (assuming `preprocessed_data.json` exists)
dataset = load_dataset('Dataset/preprocessed_data.json')
vocab = load_JSON("Dataset/vocab.json")

MAX_LENGTH = determine_max_seq_len(dataset)  # Determine the max length dynamically

schedule = CosineNoiseSchedule(timesteps=1000)

# Now pass the schedule when initializing SymbolicRegressionDataset
full_dataset = SymbolicRegressionDataset(dataset, vocab, MAX_LENGTH, schedule)

train_size = int(0.7 * len(full_dataset))  # 70% for training
val_size = int(0.15 * len(full_dataset))  # 15% for validation
test_size = len(full_dataset) - train_size - val_size  # 15% for testing

# Perform random split
train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# Create DataLoader objects for each subset
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Initialize model, optimizer, and loss function
model = DiffusionModel(vocab_size=len(vocab), embedding_dim=128, hidden_dim=256, num_layers=4, num_timesteps=1000)
optimizer = optim.Adam(model.parameters(), lr=0.001)
schedule = CosineNoiseSchedule(timesteps=1000)

# Train the model
train_diffusion_model(model, train_loader, val_loader, optimizer, schedule, num_epochs=10)

  token_ids = torch.tensor(token_ids, dtype=torch.long)


KeyboardInterrupt: 