In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found beneath the Kaggle input directory.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    full_paths = (os.path.join(folder, name) for name in names)
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/1-var-dataset/1_var_test.json
/kaggle/input/1-var-dataset/1_var_val.json
/kaggle/input/1-var-dataset/1_var_train.json
/kaggle/input/symbolic_diffusion_initial/pytorch/default/1/symbolic_diffusion_model.pth
/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt
/kaggle/input/xye_9var/pytorch/default/1/XYE_9Var.pt


In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import glob
import json
from torch.utils.data import Dataset
import re
import numpy as np
import tqdm
import random
from math import exp, sin, cos

def generateDataStrEq(eq, n_points=2, n_vars=3,
                      decimals=4, supportPoints=None,
                      min_x=0, max_x=3):
    """Evaluate a string equation at sampled (or supplied) support points.

    Args:
        eq: Equation as a Python expression over ``x1 .. x{n_vars}``; it is
            evaluated with ``eval`` so it may use any name visible in this
            module (e.g. ``sin``, ``cos``, ``exp``).
        n_points: Number of points to generate.
        n_vars: Number of input variables that are substituted.
        decimals: Rounding applied to sampled inputs and to the outputs.
        supportPoints: Optional sequence of points; when given, point ``p`` is
            ``supportPoints[p]`` and nothing is sampled.
        min_x, max_x: Sampling bounds. Either two scalars, or two lists of
            equal length from which one ``[min, max]`` interval is picked at
            random for every variable.

    Returns:
        ``(X, Y)`` where ``X`` is a list of ``n_points`` points and ``Y`` the
        list of corresponding rounded float outputs.

    Raises:
        Whatever ``eval`` raises for the substituted expression (for example
        ``ZeroDivisionError`` or ``NameError``).
    """
    # A whole-token match ("x" followed by all of its digits) keeps "x1" from
    # being substituted inside "x10".
    var_pattern = re.compile(r'x(\d+)')

    X = []
    Y = []
    # TODO: Need to make this faster
    for p in range(n_points):
        if supportPoints is None:
            if isinstance(min_x, list):
                x = []
                for _ in range(n_vars):
                    idx = np.random.randint(len(min_x))
                    x += list(np.round(np.random.uniform(min_x[idx], max_x[idx], 1), decimals))
            else:
                x = list(np.round(np.random.uniform(min_x, max_x, n_vars), decimals))
            assert len(x)!=0, "For some reason, we didn't generate the points correctly!"
        else:
            x = supportPoints[p]

        def _value_for(match):
            var_id = int(match.group(1))
            if 1 <= var_id <= n_vars:
                # Parenthesise the value: without it a negative input turns
                # "x1**2" into "-2.0**2", which Python evaluates as -(2.0**2).
                return '(' + str(x[var_id - 1]) + ')'
            return match.group(0)  # not one of the substituted variables

        tmpEq = var_pattern.sub(_value_for, eq)
        # NOTE(security): eval executes arbitrary code; only use with trusted equations.
        y = float(np.round(eval(tmpEq), decimals))
        X.append(x)
        Y.append(y)
    return X, Y


# def processDataFiles(files):
#     text = ""
#     for f in tqdm(files):
#         with open(f, 'r') as h: 
#             lines = h.read() # don't worry we won't run out of file handles
#             if lines[-1]==-1:
#                 lines = lines[:-1]
#             #text += lines #json.loads(line)    
#             text = ''.join([lines,text])    
#     return text

def processDataFiles(files):
    """Read every file in ``files`` and return their contents as one string.

    The contents are concatenated with no separator and in reverse listing
    order (the last file listed comes first), which is the order the previous
    prepend-based implementation produced.

    Args:
        files: Iterable of paths to text files.

    Returns:
        The concatenated text; an empty string when ``files`` is empty.
    """
    pieces = []
    for f in files:
        # Empty files are fine here; the old `lines[-1]` probe raised
        # IndexError on them and its comparison against the integer -1 could
        # never be true for a character anyway.
        with open(f, 'r') as h:
            pieces.append(h.read())
    # A single join is linear; repeatedly prepending to a string was quadratic.
    return ''.join(reversed(pieces))

class CharDataset(Dataset):
    """Character-level dataset of symbolic-regression examples.

    Every entry of ``data`` is a JSON string that is decoded on access and is
    expected to provide the keys ``'X'`` (list of input points), ``'Y'`` (list
    of outputs) and ``target`` (the equation text to tokenise).

    Args:
        data: List of JSON strings, one example each.
        block_size: Fixed length of the returned token sequences.
        chars: Vocabulary; the index of a character is its token id. Must
            contain the padding token ``'_'`` and every character that occurs
            in the target text, plus ``'<'`` and ``'>'``.
        numVars: Number of input-variable rows in the points tensor.
        numYs: Number of output rows in the points tensor.
        numPoints: Either an int ``N`` (the points tensor has ``N - 1``
            columns) or a ``[low, high]`` range as accepted by
            ``np.random.randint`` (the tensor has ``high - 1`` columns and the
            augmentation draws the number of points from the range).
        target: Key of the example dictionary holding the equation text.
        addVars: When true, the token sequence is prefixed by
            ``"<numVars>:"``.
        const_range: ``[low, high]`` range for constants drawn during
            augmentation.
        xRange: ``[low, high]`` range for inputs drawn during augmentation.
        decimals: Rounding used when regenerating points.
        augment: When true and ``target == 'Skeleton'``, every ``C`` in the
            skeleton is replaced by a random constant and the points are
            regenerated from the resulting equation.
    """

    def __init__(self, data, block_size, chars,
                 numVars, numYs, numPoints, target='EQ',
                 addVars=False, const_range=[-0.4, 0.4],
                 xRange=[-3.0,3.0], decimals=4, augment=False):

        data_size, vocab_size = len(data), len(chars)
        print('data has %d examples, %d unique.' % (data_size, vocab_size))

        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }

        self.numVars = numVars
        self.numYs = numYs
        self.numPoints = numPoints

        # padding token
        self.paddingToken = '_'
        self.paddingID = self.stoi[self.paddingToken]
        # values that replace NaN / +inf / -inf in the points tensor
        self.threshold = [-1000,1000]

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data # it should be a list of examples
        self.target = target
        self.addVars = addVars

        self.const_range = const_range
        self.xRange = xRange
        self.decimals = decimals
        self.augment = augment

    def __len__(self):
        # NOTE(review): the last entry of `data` is never served. This looks
        # intended to skip the empty string left by a trailing newline after
        # `text.split('\n')`; callers that already strip it lose one example.
        return max(len(self.data)-1, 0)

    def _num_point_columns(self):
        """Number of columns of the points tensor for either form of numPoints."""
        if isinstance(self.numPoints, (int, np.integer)):
            return int(self.numPoints) - 1
        return int(self.numPoints[-1]) - 1

    def _sample_num_points(self):
        """Number of points to regenerate when augmenting."""
        if isinstance(self.numPoints, (int, np.integer)):
            # An int N is treated as the degenerate range [N-1, N): fill every column.
            return int(self.numPoints) - 1
        return np.random.randint(*self.numPoints)

    def _load_example(self, idx):
        """Decode example ``idx``; on a decode error fall back to the previous one."""
        chunk = self.data[idx]
        try:
            return json.loads(chunk)
        except Exception as e:
            print("Couldn't convert to json: {} \n error is: {}".format(chunk, e))
            # try the previous example (clamped to the first one)
            return json.loads(self.data[max(idx - 1, 0)])

    @staticmethod
    def _count_variables(eq):
        """Largest index ``k`` of any ``xk`` occurring in ``eq`` (0 when none)."""
        return max((int(m.group(1)) for m in re.finditer(r'x(\d+)', eq)), default=0)

    def _augment(self, chunk, eq, verbose):
        """Replace the constants of a skeleton and regenerate ``X``/``Y`` in place."""
        threshold = 5000
        # every 'C' gets its own freshly drawn constant
        cleanEqn = ''.join(
            '{}'.format(np.random.uniform(self.const_range[0], self.const_range[1]))
            if ch == 'C' else ch
            for ch in eq
        )
        nPoints = self._sample_num_points()
        try:
            if verbose:
                print('Org:',chunk['X'], chunk['Y'])

            X, y = generateDataStrEq(cleanEqn, n_points=nPoints, n_vars=self.numVars,
                                     decimals=self.decimals, min_x=self.xRange[0],
                                     max_x=self.xRange[1])

            # clip to +-threshold (this also maps +-inf onto the bounds)
            y = [e if abs(e)<threshold else np.sign(e) * threshold for e in y]

            # after clipping only NaN or an empty result can still be invalid;
            # in that case the original support points are kept
            if len(y) > 0 and not np.isnan(y).any():
                chunk['X'], chunk['Y'] = X, y

            if verbose:
                print('Evd:',chunk['X'], chunk['Y'])
        except Exception as e:
            # this may happen for several reasons, e.g. division by zero
            print("".join([
                f"We just used the original equation and support points because of {e}. ",
                f"The equation is {eq}, and we update the equation to {cleanEqn}",
            ]))

    def _points_tensor(self, chunk):
        """Build the ``(numVars + numYs, columns)`` tensor of zero-padded points."""
        n_cols = self._num_point_columns()
        points = torch.zeros(self.numVars+self.numYs, n_cols)
        for col, (x, y) in enumerate(zip(chunk['X'], chunk['Y'])):
            # don't exceed the maximum number of points
            if col >= n_cols:
                break

            # Integers are valid outputs too: JSON encodes e.g. 3.0 as 3.
            y_is_number = (isinstance(y, (int, float, np.integer, np.floating))
                           and not isinstance(y, bool))
            if not isinstance(x, list) or not (y_is_number or isinstance(y, list)):
                print(f"Unexpected types: {type(x)}, {type(y)}")
                continue  # skip this point; its column stays zero

            x = x + [0]*(max(self.numVars-len(x),0)) # padding
            y = [y] if y_is_number else y
            y = y + [0]*(max(self.numYs-len(y),0)) # padding
            points[:,col] = torch.tensor(x+y, dtype=torch.float32)

        # replace nan and inf for the whole tensor at once
        return torch.nan_to_num(points, nan=self.threshold[1],
                                posinf=self.threshold[1],
                                neginf=self.threshold[0])

    def __getitem__(self, idx):
        """Return ``(inputs, outputs, points, numVars)`` for example ``idx``.

        ``inputs`` and ``outputs`` are long tensors of length ``block_size``
        holding the ``'<' + eq + '>'`` tokens without the last and without the
        first token respectively, padded with the padding id. ``points`` is a
        float tensor of shape ``(numVars + numYs, columns)`` and ``numVars``
        is a long scalar with the number of variables found in the equation.
        """
        chunk = self._load_example(idx)

        # debug printing for a tiny random fraction of the accesses
        printInfoCondition = random.random() < 0.0000001
        eq = chunk[self.target]
        if printInfoCondition:
            print(f'\nEquation: {eq}')

        # find the number of variables in the equation
        numVars = self._count_variables(eq)

        if self.target == 'Skeleton' and self.augment:
            self._augment(chunk, eq, printInfoCondition)

        # encode every character in the equation to an integer
        # < is SOS, > is EOS
        if self.addVars:
            sequence = '<'+str(numVars)+':'+eq+'>'
        else:
            sequence = '<'+eq+'>'
        dix = [self.stoi[s] for s in sequence]
        inputs = dix[:-1]
        outputs = dix[1:]

        # pad up to block_size, then make sure nothing is longer than that
        paddingList = [self.paddingID]*max(self.block_size-len(inputs),0)
        inputs = (inputs + paddingList)[:self.block_size]
        outputs = (outputs + paddingList)[:self.block_size]

        points = self._points_tensor(chunk)

        inputs = torch.tensor(inputs, dtype=torch.long)
        outputs = torch.tensor(outputs, dtype=torch.long)
        numVars = torch.tensor(numVars, dtype=torch.long)
        return inputs, outputs, points, numVars

# Relative Mean Square Error
def relativeErr(y, yHat, info=False, eps=1e-5):
    """Mean squared error between ``yHat`` and ``y`` scaled by ``||y + eps||``.

    Both inputs are flattened first. When they are empty or their lengths
    differ, the fixed penalty 100 is returned instead.

    Args:
        y: Reference values (scalar or array-like).
        yHat: Predicted values (scalar or array-like).
        info: When true, print five randomly chosen (prediction, truth, error)
            triples.
        eps: Offset added to ``y`` before taking the norm.

    Returns:
        The mean of the element-wise scaled squared errors, or 100.
    """
    pred = np.reshape(yHat, [1, -1])[0]
    truth = np.reshape(y, [1, -1])[0]
    count = len(truth)

    # mismatching or empty inputs get the fixed penalty
    if count == 0 or count != len(pred):
        return np.mean(100)

    diff = pred - truth
    err = diff * diff / np.linalg.norm(truth + eps)
    if info:
        for _ in range(5):
            i = np.random.randint(count)
            print('yPR,yTrue:{},{}, Err:{}'.format(pred[i],truth[i],err[i]))

    return np.mean(err)

def lossFunc(constants, eq, X, Y, eps=1e-5):
    """Average relative error of a skeleton equation for given constants.

    Every ``C`` in ``eq`` is replaced, in order, by the entries of
    ``constants``; the equation is then evaluated with ``eval`` at each point
    of ``X`` and compared with the matching entry of ``Y`` via
    ``relativeErr``. Points whose evaluation or error computation raises are
    reported and contribute nothing to the sum, but still count in the
    denominator.

    Args:
        constants: Sequence with one value per ``C`` in ``eq``.
        eq: Skeleton equation over ``x1, x2, ...`` with ``C`` placeholders.
        X: Sequence of points; a point may be a list/tuple/array/tensor of
            coordinates or a single scalar.
        Y: Sequence of target values, one per point.
        eps: Unused; kept for interface compatibility.

    Returns:
        The accumulated error divided by ``len(Y)``.

    Raises:
        ZeroDivisionError: If ``Y`` is empty.
    """
    err = 0
    eq = eq.replace('C','{}').format(*constants)
    # whole-token match so that "x1" is never substituted inside "x10"
    var_pattern = re.compile(r'x(\d+)')

    for x,y in zip(X,Y):
        # normalise the point to a flat list of coordinates
        if isinstance(x, (torch.Tensor, np.ndarray)):
            coords = list(x.reshape(-1))
        elif isinstance(x, (list, tuple)):
            coords = list(x)
        else:
            coords = [x]  # a single scalar coordinate
        # make sure no coordinate is a tensor
        coords = [e.item() if isinstance(e, torch.Tensor) else e for e in coords]

        def _value_for(match):
            var_id = int(match.group(1))
            if 1 <= var_id <= len(coords):
                # parenthesised so that a negative value keeps its sign under
                # "**" ("-2.0**2" would evaluate to -4.0)
                return '(' + str(coords[var_id - 1]) + ')'
            return match.group(0)

        eqTemp = var_pattern.sub(_value_for, eq)
        try:
            # NOTE(security): eval executes arbitrary code; only use with trusted equations.
            yHat = eval(eqTemp)
        except Exception:
            print('Exception has been occured! EQ: {}, OR: {}'.format(eqTemp, eq))
            continue
        try:
            # handle overflow
            err += relativeErr(y, yHat) #(y-yHat)**2
        except Exception:
            print('Exception has been occured! EQ: {}, OR: {}, y:{}-yHat:{}'.format(eqTemp, eq, y, yHat))
            continue

    err /= len(Y)
    return err

In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple


# from SymbolicGPT: https://github.com/mojivalipour/symbolicgpt/blob/master/models.py
class PointNetConfig:
    """Plain attribute container with the settings of the point encoder.

    The six named arguments are stored under their own names; any additional
    keyword argument is stored as an attribute as well.
    """

    def __init__(
        self,
        embeddingSize,
        numberofPoints,
        numberofVars,
        numberofYs,
        method="GPT",
        varibleEmbedding="NOT_VAR",
        **kwargs,
    ):
        named = {
            "embeddingSize": embeddingSize,
            "numberofPoints": numberofPoints,  # number of points
            "numberofVars": numberofVars,  # input dimension (Xs)
            "numberofYs": numberofYs,  # output dimension (Ys)
            "method": method,
            "varibleEmbedding": varibleEmbedding,
        }
        # named settings first, then the free-form extras
        for key, value in {**named, **kwargs}.items():
            setattr(self, key, value)


class tNet(nn.Module):
    """
    Point-set encoder following the PointNet structure from
    PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation by Qi et. al. 2017

    Three point-wise (kernel size 1) convolutions, a max over the points and
    two linear layers, each followed by batch normalisation and ReLU.
    """

    def __init__(self, config):
        super(tNet, self).__init__()

        width = config.embeddingSize
        channels_in = config.numberofVars + config.numberofYs

        self.activation_func = F.relu
        self.num_units = width

        # point-wise feature extractors: width -> 2*width -> 4*width
        self.conv1 = nn.Conv1d(channels_in, width, 1)
        self.conv2 = nn.Conv1d(width, 2 * width, 1)
        self.conv3 = nn.Conv1d(2 * width, 4 * width, 1)
        # projection of the pooled feature back down to `width`
        self.fc1 = nn.Linear(4 * width, 2 * width)
        self.fc2 = nn.Linear(2 * width, width)

        self.input_batch_norm = nn.BatchNorm1d(channels_in)

        self.bn1 = nn.BatchNorm1d(width)
        self.bn2 = nn.BatchNorm1d(2 * width)
        self.bn3 = nn.BatchNorm1d(4 * width)
        self.bn4 = nn.BatchNorm1d(2 * width)
        self.bn5 = nn.BatchNorm1d(width)

    def forward(self, x):
        """
        :param x: [batch, #features, #points]
        :return:
            logit: [batch, embedding_size]
        """
        act = self.activation_func
        h = self.input_batch_norm(x)

        conv_stack = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in conv_stack:
            h = act(norm(conv(h)))

        h = torch.max(h, dim=2)[0]  # global max pooling
        assert h.size(1) == 4 * self.num_units

        for linear, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            h = act(norm(linear(h)))

        return h


class NoisePredictionTransformer(nn.Module):
    """Predicts continuous noise in the embedding space for diffusion-based symbolic regression."""

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        padding_idx: int = 0,
        n_layer: int = 6,
        n_head: int = 8,
        n_embd: int = 512,
        max_timesteps: int = 1000,
    ):
        # `vocab_size` and `padding_idx` are accepted but not used here: this
        # module receives embeddings, it does not embed tokens itself.
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, n_embd))
        self.time_emb = nn.Embedding(max_timesteps, n_embd)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=n_embd * 4,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=n_layer,
        )
        self.cross_attention = nn.MultiheadAttention(n_embd, n_head, batch_first=True)

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        condition: torch.Tensor,
    ) -> torch.Tensor:
        """Predicts noise from noisy embeddings, timestep, and condition.

        Args:
            x_t: [B, L, n_embd] - Noisy expression embeddings at time t
            t: [B] - Timestep
            condition: [B, n_embd] - Conditioning vector

        Returns:
            noise_pred: [B, L, n_embd] - Predicted noise in embedding space
        """
        _batch, seq_len, _width = x_t.shape

        # add positional and timestep information to the noisy embeddings
        positions = self.pos_emb[:, :seq_len, :]
        step = self.time_emb(t).unsqueeze(1)
        hidden = x_t + positions + step

        # the condition acts as a single key/value token for cross-attention
        cond_token = condition.unsqueeze(1)
        attended, _ = self.cross_attention(hidden, cond_token, cond_token)

        # residual connection, then the encoder stack
        return self.encoder(hidden + attended)


class SymbolicDiffusion(nn.Module):
    """Diffusion model over token embeddings, conditioned on data points.

    Token ids are embedded, Gaussian noise is added according to a linear beta
    schedule and a transformer predicts that noise given the timestep and a
    condition vector (point-set encoding plus a number-of-variables
    embedding). With ``train_decoder=True`` a linear layer maps embeddings to
    vocabulary logits; otherwise embeddings are rounded to the nearest row of
    the token embedding table.

    Args:
        pconfig: Configuration object passed to ``tNet``.
        vocab_size: Number of rows of the token embedding table.
        max_seq_len: Sequence length used for positions and for sampling.
        padding_idx: Token id of the padding token (ignored by the CE loss).
        max_num_vars: Number of rows of the variable-count embedding.
        n_layer, n_head, n_embd: Transformer depth, heads and width.
        timesteps: Number of diffusion steps.
        beta_start, beta_end: End points of the linear beta schedule.
        tok_emb_weights: Optional tensor installed as the frozen token
            embedding weight.
        vars_emb_weights: Optional tensor installed as the frozen
            variable-count embedding weight.
        train_decoder: Whether to use a trainable linear decoder.
    """

    def __init__(
        self,
        pconfig,
        vocab_size: int,
        max_seq_len: int,
        padding_idx: int = 0,
        max_num_vars: int = 9,
        n_layer: int = 4,
        n_head: int = 4,
        n_embd: int = 512,
        timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        tok_emb_weights: torch.Tensor = None,
        vars_emb_weights: torch.Tensor = None,
        train_decoder: bool = True,
    ):
        super().__init__()
        self.timesteps = timesteps
        self.n_embd = n_embd
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.tok_emb_weights = tok_emb_weights
        self.vars_emb_weights = vars_emb_weights

        # Initialize embedding layers
        self.tok_emb = nn.Embedding(vocab_size, n_embd, padding_idx=padding_idx)
        self.vars_emb = nn.Embedding(max_num_vars, n_embd)

        # Load and freeze (requires_grad = False ensures they won't be updated) weights if provided
        if tok_emb_weights is not None:
            self.tok_emb.weight = nn.Parameter(tok_emb_weights, requires_grad=False)

        if vars_emb_weights is not None:
            self.vars_emb.weight = nn.Parameter(vars_emb_weights, requires_grad=False)

        self.tnet = tNet(pconfig)
        self.transformer = NoisePredictionTransformer(
            vocab_size,
            max_seq_len,
            padding_idx,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            max_timesteps=timesteps,
        )

        self.train_decoder = train_decoder
        if train_decoder:
            self.decoder = nn.Linear(n_embd, vocab_size)
        else:
            self.decoder = self.round

        # linear noise schedule and its cumulative products
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, timesteps))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    def xtoi(self, xt):
        """Map embeddings to the ids of their nearest token embeddings.

        Args:
            xt: [B, L, n_embd] - Embeddings (need not be contiguous)

        Returns:
            [B, L] long tensor of token ids.
        """
        B, L, _ = xt.shape
        emb_table = self.tok_emb.weight  # [vocab_size, n_embd]

        # reshape (not view) so that non-contiguous inputs are accepted
        xt_flat = xt.reshape(B * L, -1)  # [B*L, n_embd]
        # The nearest row under the L2 distance is also the nearest under the
        # squared distance, so no squaring is needed before the argmin.
        distances = torch.cdist(xt_flat, emb_table, p=2)  # [B*L, vocab_size]
        return torch.argmin(distances, dim=1).view(B, L)  # [B, L]

    def round(self, xt):
        """Replace every embedding in ``xt`` by its nearest token embedding."""
        emb_table = self.tok_emb.weight
        idx = self.xtoi(xt)
        return emb_table[idx]

    def q_sample(
        self,
        x0: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Adds Gaussian noise to the input embeddings at time t.

        Args:
            x0: [B, L, n_embd] - Original expression embeddings
            t: [B] - Timestep

        Returns:
            xt: [B, L, n_embd] - Noisy embeddings at time t
            noise: [B, L, n_embd] - Noise added to the original embeddings
        """
        noise = torch.randn_like(x0, dtype=torch.float)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1)
        sqrt_subone_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1)
        xt = sqrt_alpha_bar * x0 + sqrt_subone_alpha_bar * noise
        return xt, noise

    def p_sample(
        self,
        x: torch.Tensor,
        t: int,
        noise_pred: torch.Tensor,
    ) -> torch.Tensor:
        """Removes the predicted noise from the noisy embeddings at time t.

        Only the mean of the reverse step is computed; no noise is added.

        Args:
            x: [B, L, n_embd] - Noisy embeddings at time t
            t: Current timestep; an int, or a [B] tensor of per-sample steps
            noise_pred: [B, L, n_embd] - Noise predicted for ``x``

        Returns:
            mean: [B, L, n_embd] - Denoised embeddings
        """
        # view(-1, 1, 1) makes scalar and per-sample timesteps broadcast alike
        beta_t = self.beta[t].view(-1, 1, 1)
        alpha_t = self.alpha[t].view(-1, 1, 1)
        alpha_bar_t = self.alpha_bar[t].view(-1, 1, 1)
        mean = (1 / torch.sqrt(alpha_t)) * (
            x - (beta_t / torch.sqrt(1 - alpha_bar_t)) * noise_pred
        )
        return mean

    def forward(
        self,
        points: torch.Tensor,
        tokens: torch.Tensor,
        variables: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Training forward pass to predict noise added to embeddings.

        Args:
            points: [B, #features, #points] - Dataset points for conditioning
            tokens: [B, L] - Ground truth expression (token indices)
            variables: [B] - Number of variables per sample
            t: [B] - Timestep

        Returns:
            y_pred: [B, L, vocab_size] logits when ``train_decoder`` is set,
                otherwise [B, L, n_embd] one-step denoised embeddings
            noise_pred: [B, L, n_embd] - Predicted noise
            noise: [B, L, n_embd] - Actual noise added
        """
        condition = self.tnet(points) + self.vars_emb(variables)

        token_emb = self.tok_emb(tokens)
        xt, noise = self.q_sample(token_emb, t)
        noise_pred = self.transformer(xt, t, condition)
        y_pred = self.p_sample(xt, t, noise_pred)
        if self.train_decoder:
            y_pred = self.decoder(y_pred)
        return y_pred, noise_pred, noise

    @torch.no_grad()
    def sample(
        self,
        points: torch.Tensor,
        variables: torch.Tensor,
        device: str = "cuda",
    ) -> torch.Tensor:
        """Generates a sample by denoising from random noise.

        After every denoising step the embeddings are snapped back onto token
        embeddings (through the decoder's argmax, or by nearest-neighbour
        rounding when there is no trainable decoder).

        Args:
            points: [B, #features, #points] - Dataset points for conditioning
            variables: [B] - Number of variables per sample
            device: Kept for backward compatibility and otherwise ignored;
                all tensors are created on the device of the condition
                computed from ``points``/``variables``, which is the only
                placement that can work.

        Returns:
            xt: [B, L] - Generated expression (token indices)
        """
        condition = self.tnet(points) + self.vars_emb(variables)  # [B, n_embd]
        B = condition.shape[0]
        # Follow the inputs instead of the hard-coded default "cuda": a
        # mismatching device made sampling fail (e.g. on CPU-only machines).
        device = condition.device

        xt_emb = torch.randn(B, self.max_seq_len, self.n_embd, device=device)
        t_tensor = torch.zeros(B, dtype=torch.long, device=device)  # [B]
        for t in range(self.timesteps - 1, -1, -1):
            t_tensor.fill_(t)  # Update in-place: [B] all set to t
            noise_pred = self.transformer(xt_emb, t_tensor, condition)
            xt_emb = self.p_sample(xt_emb, t, noise_pred)  # t as int for p_sample
            if self.train_decoder:
                xt_logits = self.decoder(xt_emb)
                xt = torch.argmax(xt_logits, dim=-1)
                xt_emb = self.tok_emb(xt)
            else:
                xt_emb = self.round(xt_emb)

        xt = self.xtoi(xt_emb)
        return xt

    def loss_fn(
        self,
        noise_pred: torch.Tensor,  # [B, L, n_embd]
        noise: torch.Tensor,  # [B, L, n_embd]
        pred_logits: torch.Tensor,  # [B, L, vocab_size] or [B, L, n_embd]
        tokens: torch.Tensor,  # [B, L]
        t: torch.Tensor,  # [B]
        verbose: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Combine the noise MSE with a decoding loss.

        With a trainable decoder the decoding loss is a token-level cross
        entropy (padding ignored) weighted per sample by ``1 - t / timesteps``
        and averaged over all ``B * L`` positions; otherwise it is the MSE
        between ``pred_logits`` and the target token embeddings.

        Returns:
            ``(total_loss, mse_loss, weighted_ce_loss, rounding_loss)``; the
            decoding term that is not in use is a zero tensor.
        """
        B, L = tokens.shape
        device = tokens.device

        mse_loss = F.mse_loss(noise_pred, noise)
        # less noisy steps (small t) get a larger cross-entropy weight
        ce_weight = 1.0 - (t.float() / self.timesteps)

        if self.train_decoder:
            ce_loss = F.cross_entropy(
                pred_logits.view(-1, pred_logits.size(-1)),
                tokens.view(-1),
                reduction="none",
                ignore_index=self.padding_idx,
            ).view(B, L)
            weighted_ce_loss = (ce_weight.unsqueeze(1) * ce_loss).mean()
            rounding_loss = torch.tensor(0.0, device=device)
        else:
            weighted_ce_loss = torch.tensor(0.0, device=device)
            rounding_loss = F.mse_loss(pred_logits, self.tok_emb(tokens))  # predict exact target embeddings

        total_loss = mse_loss + weighted_ce_loss + rounding_loss

        if verbose:
            print(f"MSE: {mse_loss.item():.4f}, CE: {weighted_ce_loss.item():.4f}, Rounding: {rounding_loss.item():.4f}")

        return total_loss, mse_loss, weighted_ce_loss, rounding_loss


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm

def train_epoch(
    model: SymbolicDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int
) -> Tuple[float, float, float]:
    """Run one optimisation pass over ``train_loader``.

    For every batch a random timestep per sample is drawn, the model is run,
    the loss is back-propagated and the optimiser is stepped. The first item
    of each batch is ignored; the second one is used as the target tokens.
    ``train_dataset``, ``epoch`` and ``num_epochs`` are accepted but not used.

    Returns:
        Per-batch averages of (total loss, MSE term, CE term).
    """
    model.train()
    # loss_fn lives on the wrapped module when DataParallel is in use
    core_model = model.module if isinstance(model, nn.DataParallel) else model
    running = [0, 0, 0]  # total loss, mse, ce

    for _, tokens, points, variables in tqdm.tqdm(train_loader, total=len(train_loader)):
        points = points.to(device)
        tokens = tokens.to(device)
        variables = variables.to(device)
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

        optimizer.zero_grad()
        y_pred, noise_pred, noise = model(points, tokens, variables, t)
        loss, mse, ce, _ = core_model.loss_fn(noise_pred, noise, y_pred, tokens, t)
        loss.backward()
        optimizer.step()

        for slot, term in enumerate((loss, mse, ce)):
            running[slot] += term.item()

    n_batches = len(train_loader)
    avg_train_loss, avg_train_mse, avg_train_ce = (total / n_batches for total in running)
    return avg_train_loss, avg_train_mse, avg_train_ce

def val_epoch(
    model: SymbolicDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int
) -> Tuple[float, float, float]:
    """Evaluate the model on ``val_loader`` without gradient tracking.

    A random timestep per sample is drawn for every batch, so the reported
    losses vary from call to call. ``train_dataset``, ``epoch`` and
    ``num_epochs`` are accepted but not used.

    Returns:
        Per-batch averages of (total loss, MSE term, CE term).
    """
    model.eval()
    # loss_fn lives on the wrapped module when DataParallel is in use
    core_model = model.module if isinstance(model, nn.DataParallel) else model
    running = [0, 0, 0]  # total loss, mse, ce

    with torch.no_grad():
        for _, tokens, points, variables in tqdm.tqdm(val_loader, total=len(val_loader)):
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

            y_pred, noise_pred, noise = model(points, tokens, variables, t)
            loss, mse, ce, _ = core_model.loss_fn(noise_pred, noise, y_pred, tokens, t)

            for slot, term in enumerate((loss, mse, ce)):
                running[slot] += term.item()

    n_batches = len(val_loader)
    avg_val_loss, avg_val_mse, avg_val_ce = (total / n_batches for total in running)
    return avg_val_loss, avg_val_mse, avg_val_ce

def train_single_gpu(
    model: SymbolicDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
):
    """Train ``model`` and keep the weights with the lowest validation loss.

    The model is moved to CUDA when available (wrapped in ``nn.DataParallel``
    if more than one GPU is visible), optimised with Adam, and the learning
    rate is halved by ``ReduceLROnPlateau`` based on the validation loss.
    Whenever the validation loss improves, the state dict of the unwrapped
    model is written to ``best_model.pth``.

    ``save_every`` is accepted but currently has no effect.
    """
    num_gpus = torch.cuda.device_count()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if num_gpus > 0:
        print(f"Using {num_gpus} GPU(s)")
    else:
        print("Using CPU")

    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    def build_loader(dataset, shuffle):
        # both loaders share every setting except shuffling
        return DataLoader(
            dataset, batch_size=batch_size, pin_memory=True, shuffle=shuffle, num_workers=4
        )

    train_loader = build_loader(train_dataset, True)
    val_loader = build_loader(val_dataset, False)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_mse, avg_train_ce = train_epoch(
            model, train_loader, optimizer, train_dataset, timesteps, device, epoch, num_epochs
        )
        avg_val_loss, avg_val_mse, avg_val_ce = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        # Print epoch summary after completion
        summary_lines = (
            "\nEpoch Summary:",
            f"Train Loss: {avg_train_loss:.4f} | MSE: {avg_train_mse:.4f} | CE: {avg_train_ce:.4f}",
            f"Val Loss: {avg_val_loss:.4f} | MSE: {avg_val_mse:.4f} | CE: {avg_val_ce:.4f}",
            f"Learning Rate: {current_lr:.6f}",
        )
        for line in summary_lines:
            print(line)

        improved = avg_val_loss < best_val_loss
        if improved:
            best_val_loss = avg_val_loss
            # save the unwrapped weights so they load without DataParallel
            unwrapped = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(unwrapped.state_dict(), "best_model.pth")
            print(f"New best model saved with val loss: {best_val_loss:.4f}")
        print("-" * 50)

In [5]:
# Hyperparameters and run configuration.

# model / diffusion
n_embd, timesteps = 512, 1000

# optimisation
batch_size, learning_rate, num_epochs = 64, 1e-4, 5

# data layout: sequence length, input/output dimensions, points per example
blockSize = 32
numVars, numYs = 1, 1
numPoints = 250

# dataset options
target = 'Skeleton'
addVars = False
decimals = 8
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]

# run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Training and validation splits of the one-variable dataset.
train_path, val_path = (
    "/kaggle/input/1-var-dataset/1_var_train.json",
    "/kaggle/input/1-var-dataset/1_var_val.json",
)

In [7]:
import numpy as np
import glob
import random

# Read the raw training text.
files = glob.glob(train_path)
text = processDataFiles(files)

# Build the vocabulary from the raw text (before it is split into examples)
# plus the special tokens; T is for the test data.
# NOTE(review): the special tokens are not de-duplicated against set(text), so
# any of them that also occurs in the data appears twice in `chars`. Left
# as-is because the vocabulary layout presumably has to match the pretrained
# embeddings - confirm before changing.
chars = sorted([*set(text), '_', 'T', '<', '>', ':'])

# One example per line; drop the empty entry left by a trailing newline.
text = text.split('\n')
trainText = text if len(text[-1]) != 0 else text[:-1]
# Shuffle the dataset; this matters especially for the combined
# number-of-variables experiment.
random.shuffle(trainText)
train_dataset = CharDataset(
    trainText, blockSize, chars,
    numVars=numVars, numYs=numYs, numPoints=numPoints,
    target=target, addVars=addVars,
    const_range=const_range, xRange=trainRange, decimals=decimals,
)

# Decode and print one random sample as a sanity check.
idx = np.random.randint(len(train_dataset))
inputs, outputs, points, variables = train_dataset[idx]
inputs = ''.join(train_dataset.itos[int(i)] for i in inputs)
outputs = ''.join(train_dataset.itos[int(i)] for i in outputs)
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 498795 examples, 49 unique.
id:93724
outputs:C*log(C*exp(C*x1))+C>___________
variables:1


In [8]:
# Read the validation text (first matching file only) and split it into
# examples; the training vocabulary `chars` is reused.
files = glob.glob(val_path)
textVal = processDataFiles([files[0]]).split('\n')
val_dataset = CharDataset(
    textVal, blockSize, chars,
    numVars=numVars, numYs=numYs, numPoints=numPoints,
    target=target, addVars=addVars,
    const_range=const_range, xRange=trainRange, decimals=decimals,
)

# Decode and print one random sample, together with the range of its points.
idx = np.random.randint(len(val_dataset))
inputs, outputs, points, variables = val_dataset[idx]
print(points.min(), points.max())
inputs = ''.join(train_dataset.itos[int(i)] for i in inputs)
outputs = ''.join(train_dataset.itos[int(i)] for i in outputs)
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 972 examples, 49 unique.
tensor(-2.7263) tensor(76.8960)
id:908
outputs:C*log(C*x1+C)**6+C>_____________
variables:1


In [9]:
# Load the pretrained checkpoint and take its token-embedding table.
# weights_only=True restricts unpickling to tensors and plain containers,
# which avoids executing arbitrary pickled code and silences the torch.load
# FutureWarning. NOTE(review): this assumes the file is a plain state dict,
# as the 'tok_emb.weight' lookup below suggests - confirm.
weights = torch.load(
    "/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt",
    map_location=torch.device(device),
    weights_only=True,
)

#vars_emb_weights = weights['vars_emb.weight']
tok_emb_weights = weights['tok_emb.weight']

  weights = torch.load("/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt", map_location=torch.device(device))


In [10]:
# Point-encoder configuration derived from the hyperparameters.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

# Diffusion model; the token embeddings come from the pretrained checkpoint.
diffusion_settings = dict(
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    max_num_vars=9,
    n_layer=8,
    n_head=8,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
    tok_emb_weights=tok_emb_weights,
)
model = SymbolicDiffusion(pconfig=pconfig, **diffusion_settings)

# Train and keep the best checkpoint (written to best_model.pth).
train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
)

Using 2 GPU(s)


100%|██████████| 7794/7794 [18:33<00:00,  7.00it/s]
100%|██████████| 16/16 [00:01<00:00, 12.68it/s]



Epoch Summary:
Train Loss: 0.5462 | MSE: 0.1777 | CE: 0.3685
Val Loss: 0.3217 | MSE: 0.0497 | CE: 0.2720
Learning Rate: 0.000100
New best model saved with val loss: 0.3217
--------------------------------------------------


100%|██████████| 7794/7794 [18:41<00:00,  6.95it/s]
100%|██████████| 16/16 [00:01<00:00, 15.07it/s]



Epoch Summary:
Train Loss: 0.3512 | MSE: 0.0821 | CE: 0.2691
Val Loss: 0.2884 | MSE: 0.0432 | CE: 0.2452
Learning Rate: 0.000100
New best model saved with val loss: 0.2884
--------------------------------------------------


100%|██████████| 7794/7794 [18:40<00:00,  6.95it/s]
100%|██████████| 16/16 [00:01<00:00, 14.95it/s]



Epoch Summary:
Train Loss: 0.3173 | MSE: 0.0682 | CE: 0.2490
Val Loss: 0.2797 | MSE: 0.0400 | CE: 0.2397
Learning Rate: 0.000100
New best model saved with val loss: 0.2797
--------------------------------------------------


100%|██████████| 7794/7794 [18:41<00:00,  6.95it/s]
100%|██████████| 16/16 [00:01<00:00, 14.74it/s]



Epoch Summary:
Train Loss: 0.3020 | MSE: 0.0607 | CE: 0.2413
Val Loss: 0.2673 | MSE: 0.0360 | CE: 0.2313
Learning Rate: 0.000100
New best model saved with val loss: 0.2673
--------------------------------------------------


100%|██████████| 7794/7794 [18:40<00:00,  6.95it/s]
100%|██████████| 16/16 [00:01<00:00, 14.14it/s]



Epoch Summary:
Train Loss: 0.2938 | MSE: 0.0544 | CE: 0.2394
Val Loss: 0.2614 | MSE: 0.0354 | CE: 0.2260
Learning Rate: 0.000100
New best model saved with val loss: 0.2614
--------------------------------------------------
