<a href="https://colab.research.google.com/github/carolynw898/STAT946Proj/blob/main/DiffuSym.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from models import SymbolicDiffusion, PointNetConfig
from utils import load_dataset

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm

def _run_epoch(model, loss_model, loader, device, timesteps, optimizer=None):
    """Run one full pass over ``loader`` and return per-batch mean metrics.

    When ``optimizer`` is given this is a training pass: the model is put in
    train mode, gradients are enabled and one optimizer step is taken per
    batch. Otherwise the model is put in eval mode and gradients are disabled.

    Args:
        model: Callable as ``model(points, tokens, variables, t)`` returning
            ``(y_pred, noise_pred, noise)``; may be an ``nn.DataParallel``
            wrapper.
        loss_model: The unwrapped module that exposes
            ``loss_fn(noise_pred, noise, y_pred, tokens, t)`` returning
            ``(loss, mse, ce)``.
        loader: Non-empty DataLoader yielding 4-tuples; the first element is
            ignored, the rest are ``tokens, points, variables``.
        device: Device the batch tensors are moved to.
        timesteps: Exclusive upper bound of the random timestep index drawn
            for every sample.
        optimizer: Optimizer to step, or ``None`` for an evaluation pass.

    Returns:
        Tuple ``(loss, mse, ce)`` of floats, each averaged over the batches.
    """
    training = optimizer is not None
    model.train(training)
    # Only the training pass is wrapped in a progress bar.
    batches = tqdm.tqdm(loader, total=len(loader)) if training else loader
    totals = [0.0, 0.0, 0.0]

    with torch.set_grad_enabled(training):
        for _, tokens, points, variables in batches:
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)

            if training:
                optimizer.zero_grad()
            # One random timestep index in [0, timesteps) per sample.
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
            y_pred, noise_pred, noise = model(points, tokens, variables, t)
            loss, mse, ce = loss_model.loss_fn(noise_pred, noise, y_pred, tokens, t)

            if training:
                loss.backward()
                optimizer.step()

            for i, value in enumerate((loss, mse, ce)):
                totals[i] += value.item()

    num_batches = len(loader)
    return tuple(total / num_batches for total in totals)


def train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3
):
    """Train ``model`` with Adam, validate every epoch and checkpoint periodically.

    Runs on CUDA when available, otherwise on CPU. Despite the name, the model
    is wrapped in ``nn.DataParallel`` when more than one GPU is visible. The
    learning rate is halved by ``ReduceLROnPlateau`` after 2 epochs without
    improvement of the mean validation loss.

    Args:
        model: Module called as ``model(points, tokens, variables, t)`` that
            returns ``(y_pred, noise_pred, noise)`` and exposes
            ``loss_fn(noise_pred, noise, y_pred, tokens, t)`` returning
            ``(loss, mse, ce)``.
        train_dataset: Dataset whose items are 4-tuples; the first element is
            ignored and the rest are ``tokens, points, variables``.
        val_dataset: Validation dataset with the same item layout.
        num_epochs: Number of passes over ``train_dataset``.
        save_every: Save a checkpoint every this many epochs to
            ``model_epoch_<n>.pth`` in the working directory. ``0`` or
            ``None`` disables checkpointing.
        batch_size: Batch size for both loaders.
        timesteps: Exclusive upper bound of the random timestep index.
        learning_rate: Initial Adam learning rate.

    Raises:
        ValueError: If either dataset yields no batches, since per-epoch
            averages would be undefined.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPU(s)" if num_gpus > 0 else "Using CPU")

    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.to(device)
    # ``loss_fn`` and ``state_dict`` must come from the unwrapped module;
    # resolve it once instead of on every batch.
    loss_model = model.module if isinstance(model, nn.DataParallel) else model

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    # Pinned host memory only helps host-to-GPU copies; on CPU it is wasted
    # and makes PyTorch emit a warning.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=True,
        num_workers=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=False,
        num_workers=2
    )

    # Fail fast: an empty loader would otherwise surface as a
    # ZeroDivisionError only after a full training epoch.
    for arg_name, loader in (("train_dataset", train_loader), ("val_dataset", val_loader)):
        if len(loader) == 0:
            raise ValueError(f"{arg_name} yields no batches; cannot compute epoch averages")

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_mse, avg_train_ce = _run_epoch(
            model, loss_model, train_loader, device, timesteps, optimizer
        )
        # NOTE: validation timesteps are re-drawn at random each epoch, so
        # the validation metrics (and the scheduler signal) are stochastic.
        avg_val_loss, avg_val_mse, avg_val_ce = _run_epoch(
            model, loss_model, val_loader, device, timesteps
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} | MSE: {avg_train_mse:.4f} | CE: {avg_train_ce:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f} | MSE: {avg_val_mse:.4f} | CE: {avg_val_ce:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")

        # A falsy ``save_every`` disables checkpointing instead of dividing by zero.
        if save_every and (epoch + 1) % save_every == 0:
            # Save the unwrapped weights so the checkpoint loads without DataParallel.
            torch.save(loss_model.state_dict(), f"model_epoch_{epoch+1}.pth")
            print(f"Saved checkpoint at epoch {epoch+1}")

In [None]:
# --- Model / diffusion settings ---
n_embd = 512        # embedding width, shared by PointNetConfig and SymbolicDiffusion
timesteps = 1000    # number of diffusion timesteps (model and training loop)
blockSize = 32      # passed to load_dataset and used as the model's max_seq_len

# --- Optimisation settings ---
num_epochs = 10
batch_size = 64
learning_rate = 1e-4

# --- Dataset settings (forwarded to load_dataset) ---
numVars = 1
numYs = 1
numPoints = 250
addVars = False
decimals = 8
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]    # passed as xRange

# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the train and validation splits with identical settings; only the
# JSON file differs. The generator is consumed in order, so the training
# split is loaded first.
train_dataset, val_dataset = (
    load_dataset(
        json_path,
        blockSize,
        numVars=numVars,
        numYs=numYs,
        numPoints=numPoints,
        addVars=addVars,
        const_range=const_range,
        xRange=trainRange,
        decimals=decimals,
        augment=False,
    )
    for json_path in (
        "/kaggle/input/1-var-dataset/1_var_train.json",
        "/kaggle/input/1-var-dataset/1_var_val.json",
    )
)

# Configuration of the point encoder.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

# Vocabulary size and padding id are taken from the training dataset.
model = SymbolicDiffusion(
    pconfig=pconfig,
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
)

# Kick off training; a checkpoint is written every 2 epochs.
train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
)