<a href="https://colab.research.google.com/github/carolynw898/STAT946Proj/blob/main/DiffuSym.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from models import SymbolicDiffusion, PointNetConfig, tNet  
import numpy as np

# --- Hyperparameters ---------------------------------------------------------
# Token vocabulary size passed to SymbolicDiffusion; presumably must match the
# dataset's tokenizer — TODO confirm.
vocab_size = 50
# Maximum token-sequence length passed to SymbolicDiffusion.
max_seq_len = 32
# Shared embedding width: used for both PointNetConfig.embeddingSize and the
# model's n_embd, so the two stay in sync.
n_embd = 512
# Number of diffusion timesteps passed to the model (also read back as
# model.timesteps by the train/validate loops).
timesteps = 1000
# Not used in this cell; presumably consumed when building DataLoaders elsewhere.
batch_size = 32
# Learning rate for the Adam optimizer below.
learning_rate = 1e-4
# Not used in this cell; presumably the length of the outer training loop.
num_epochs = 10

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration for the point-cloud encoder (PointNetConfig from models.py).
# NOTE(review): numberofVars=1 here while SymbolicDiffusion gets
# max_num_vars=9 below — confirm these are meant to differ.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=250,
    numberofVars=1,
    numberofYs=1,
)
# Build the diffusion model and move its parameters to the selected device.
# n_layer / n_head are presumably transformer depth and attention-head count;
# beta_start / beta_end are presumably the noise-schedule endpoints — the
# schedule itself is defined inside SymbolicDiffusion (not visible here).
model = SymbolicDiffusion(
    pconfig=pconfig,
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    padding_idx=0, # one of the dataset properties
    max_num_vars=9,
    n_layer=6,
    n_head=8,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
).to(device)

# Adam over all model parameters; constructed after .to(device) so it holds
# references to the on-device parameters.
optimizer = Adam(model.parameters(), lr=learning_rate)

def train_epoch(model, loader, optimizer, device):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(points, tokens, variables)`` that
            returns a ``(y_pred, noise_pred, noise)`` triple. It must also
            expose ``timesteps`` (int) and
            ``loss_fn(noise_pred, noise, y_pred, tokens, t)``.
        loader: Iterable of ``(_, tokens, points, variables)`` batches. The
            first element of each batch is ignored. ``len(loader)`` is not
            required, so plain iterators and iterable-style loaders work.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        The arithmetic mean of the per-batch losses, or ``float("nan")`` if
        ``loader`` yielded no batches (NaN rather than 0.0 so an empty epoch
        is never mistaken for a perfect one).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for _, tokens, points, variables in loader:
        points = points.to(device)
        tokens = tokens.to(device)
        variables = variables.to(device)

        optimizer.zero_grad()
        y_pred, noise_pred, noise = model(points, tokens, variables)
        # NOTE(review): ``t`` is drawn here, after the forward pass, so it is
        # independent of any timestep the model may have used internally to
        # produce ``noise``. If ``loss_fn`` depends on t, confirm whether the
        # model should return (or accept) the timestep it actually used.
        t = torch.randint(0, model.timesteps, (tokens.shape[0],), device=device)
        loss = model.loss_fn(noise_pred, noise, y_pred, tokens, t)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches as they are consumed instead of using ``len(loader)``:
    # avoids ZeroDivisionError on an empty loader and works for iterables
    # that do not define ``__len__``.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def validate_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients; return mean batch loss.

    Args:
        model: Module called as ``model(points, tokens, variables)`` that
            returns a ``(y_pred, noise_pred, noise)`` triple. It must also
            expose ``timesteps`` (int) and
            ``loss_fn(noise_pred, noise, y_pred, tokens, t)``.
        loader: Iterable of ``(_, tokens, points, variables)`` batches. The
            first element of each batch is ignored. ``len(loader)`` is not
            required, so plain iterators and iterable-style loaders work.
        device: Device the batch tensors are moved to.

    Returns:
        The arithmetic mean of the per-batch losses, or ``float("nan")`` if
        ``loader`` yielded no batches (NaN rather than 0.0 so an empty
        validation set is never mistaken for a perfect score).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for _, tokens, points, variables in loader:
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)

            y_pred, noise_pred, noise = model(points, tokens, variables)
            # NOTE(review): ``t`` is sampled randomly (and independently of the
            # forward pass), so this validation loss is stochastic across runs
            # unless the global RNG is seeded beforehand.
            t = torch.randint(0, model.timesteps, (tokens.shape[0],), device=device)
            loss = model.loss_fn(noise_pred, noise, y_pred, tokens, t)
            total_loss += loss.item()
            num_batches += 1

    # Count batches as they are consumed instead of using ``len(loader)``:
    # avoids ZeroDivisionError on an empty loader and works for iterables
    # that do not define ``__len__``.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches
