In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the mounted Kaggle input directory recursively and print the full path
# of every file found, so the available datasets/checkpoints show in the output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt
/kaggle/input/1-var-dataset/1_var_test.json
/kaggle/input/1-var-dataset/1_var_val.json
/kaggle/input/1-var-dataset/1_var_train.json
/kaggle/input/xye_9var/pytorch/default/1/XYE_9Var.pt
/kaggle/input/symbolic_diffusion_initial/pytorch/default/1/symbolic_diffusion_model.pth


In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import glob
import json
from torch.utils.data import Dataset
import re
import numpy as np
import tqdm
import random
from math import exp, sin, cos


def generateDataStrEq(
    eq, n_points=2, n_vars=3, decimals=4, supportPoints=None, min_x=0, max_x=3
):
    """Evaluate a string equation on random (or supplied) support points.

    Args:
        eq: Python expression over the variables ``x1`` ... ``x<n_vars>``.
        n_points: number of points to generate/evaluate.
        n_vars: number of variables substituted into ``eq``.
        decimals: rounding applied to both the sampled inputs and the outputs.
        supportPoints: optional sequence of points; when given, point ``p`` is
            ``supportPoints[p]`` and nothing is sampled.
        min_x, max_x: sampling bounds. If ``min_x`` is a list, every variable
            draws a random interval index and samples from
            ``[min_x[idx], max_x[idx]]``.

    Returns:
        ``(X, Y)``: the list of points and the list of evaluated floats.

    Raises:
        Whatever ``eval`` raises for the substituted expression (for example
        ``ZeroDivisionError``, or ``NameError`` for an unknown variable).

    Note:
        The expression is executed with ``eval``; only pass trusted equations.
    """
    var_pattern = re.compile(r"x(\d+)")
    X = []
    Y = []
    # TODO: Need to make this faster
    for p in range(n_points):
        if supportPoints is None:
            if isinstance(min_x, list):
                x = []
                for _ in range(n_vars):
                    idx = np.random.randint(len(min_x))
                    x += list(
                        np.round(np.random.uniform(min_x[idx], max_x[idx], 1), decimals)
                    )
            else:
                x = list(np.round(np.random.uniform(min_x, max_x, n_vars), decimals))
            assert (
                len(x) != 0
            ), "For some reason, we didn't generate the points correctly!"
        else:
            x = supportPoints[p]

        def substitute(match, point=x):
            # Match the whole variable token so "x1" is never replaced inside
            # "x10", and parenthesise the value so a negative number keeps its
            # sign under "**" (e.g. "(-2.0)**2" rather than "-2.0**2").
            digits = match.group(1)
            index = int(digits)
            if str(index) == digits and 1 <= index <= n_vars:
                return "({})".format(str(point[index - 1]))
            return match.group(0)

        tmpEq = var_pattern.sub(substitute, eq)
        y = float(np.round(eval(tmpEq), decimals))
        X.append(x)
        Y.append(y)
    return X, Y


# def processDataFiles(files):
#     text = ""
#     for f in tqdm(files):
#         with open(f, 'r') as h:
#             lines = h.read() # don't worry we won't run out of file handles
#             if lines[-1]==-1:
#                 lines = lines[:-1]
#             #text += lines #json.loads(line)
#             text = ''.join([lines,text])
#     return text


def processDataFiles(files):
    """Read every file in ``files`` and return their contents as one string.

    The contents are concatenated in reverse order of ``files`` (the last file
    comes first), which is the order the previous implementation produced by
    prepending each file to the accumulated text.

    Args:
        files: iterable of paths to text files.

    Returns:
        The concatenated text; ``""`` when ``files`` is empty. Empty files are
        accepted and contribute nothing.
    """
    contents = []
    for path in files:
        with open(path, "r") as handle:
            contents.append(handle.read())
    # One join instead of rebuilding the accumulated string on every iteration.
    return "".join(reversed(contents))


def tokenize_equation(eq):
    """Split an equation string into a list of token strings.

    The alternatives are tried in the order listed at each position, and any
    character that matches none of them is skipped. Because the operator class
    precedes the numeric patterns, a leading ``-`` is always emitted as its own
    token.
    """
    alternatives = (
        r'\*\*',            # exponentiation
        r'exp',             # exp function
        r'[+\-*/=()]',      # operators and parentheses
        r'sin',             # sin function
        r'cos',             # cos function
        r'log',             # log function
        r'x\d+',            # variables like x1, x23, etc.
        r'C',               # constants placeholder
        r'-?\d+\.\d+',      # decimal numbers
        r'-?\d+',           # integers
        r'_',               # padding token
    )
    pattern = re.compile('|'.join('(' + alt + ')' for alt in alternatives))
    found = []
    for hit in pattern.finditer(eq):
        found.append(hit.group(0))
    return found


class CharDataset(Dataset):
    """Dataset of JSON-encoded symbolic-regression examples.

    Every element of ``data`` is a JSON string. ``__getitem__`` reads the
    equation stored under the key ``target`` and the support points stored
    under ``"X"`` and ``"Y"``.

    Args:
        data: list of JSON strings, one per example.
        block_size: fixed length of the returned token sequences.
        tokens: vocabulary; must contain ``"_"`` (padding), ``"<"`` and ``">"``,
            plus ``":"`` and the digits when ``addVars`` is true.
        numVars: number of input-variable rows in the points tensor.
        numYs: number of output rows in the points tensor.
        numPoints: either an int ``n`` (the points tensor has ``n - 1`` columns)
            or a ``(low, high)`` pair (the tensor has ``high - 1`` columns and
            augmentation draws the number of points from ``[low, high)``).
        target: key of the equation inside each example.
        addVars: prefix the token sequence with ``<numVars>:``.
        const_range: sampling interval for constants during augmentation.
        xRange: sampling interval for inputs during augmentation.
        decimals: rounding used during augmentation.
        augment: regenerate constants and points on the fly (only when
            ``target == "Skeleton"``).
    """

    def __init__(
        self,
        data,
        block_size,
        tokens,
        numVars,
        numYs,
        numPoints,
        target="Skeleton",
        addVars=False,
        const_range=[-0.4, 0.4],
        xRange=[-3.0, 3.0],
        decimals=4,
        augment=False,
    ):
        # The list defaults are only read, never mutated.
        data_size, vocab_size = len(data), len(tokens)
        print("data has %d examples, %d unique." % (data_size, vocab_size))

        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.itos = {i: tok for i, tok in enumerate(tokens)}

        self.numVars = numVars
        self.numYs = numYs
        self.numPoints = numPoints

        # padding token (must already be part of the vocabulary)
        self.paddingToken = "_"
        self.paddingID = self.stoi[self.paddingToken]

        # replacement values for -inf/nan/+inf in the points tensor
        self.threshold = [-1000, 1000]

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data  # it should be a list of examples
        self.target = target
        self.addVars = addVars

        self.const_range = const_range
        self.xRange = xRange
        self.decimals = decimals
        self.augment = augment

    def __len__(self):
        # NOTE(review): one less than len(data); presumably this skips the
        # empty string left by splitting the raw text on a trailing newline.
        # A caller that already dropped that entry loses its last example.
        return len(self.data) - 1

    def __getitem__(self, idx):
        """Return ``(inputs, outputs, points, numVars)`` for example ``idx``.

        ``inputs``/``outputs`` are the token ids shifted by one position and
        padded/truncated to ``block_size``; ``points`` has shape
        ``(numVars + numYs, max_points)``; ``numVars`` is the largest variable
        index appearing in the equation (0 when there is none).
        """
        chunk = self._load_example(idx)

        # With a tiny probability, print debugging details for this example.
        printInfoCondition = random.random() < 0.0000001
        eq = chunk[self.target]
        if printInfoCondition:
            print(f"\nEquation: {eq}")

        # Largest variable index used by the equation.
        numVars = max(
            (int(found.group(1)) for found in re.finditer(r"x(\d+)", eq)),
            default=0,
        )

        if self.target == "Skeleton" and self.augment:
            self._resample_points(chunk, eq, printInfoCondition)

        inputs, outputs = self._encode(eq, numVars)
        points = self._build_points(chunk)
        numVars = torch.tensor(numVars, dtype=torch.long)
        return inputs, outputs, points, numVars

    def _load_example(self, idx):
        """Parse example ``idx``; on failure fall back to the previous one."""
        raw = self.data[idx]
        try:
            return json.loads(raw)
        except Exception as e:
            print("Couldn't convert to json: {} \n error is: {}".format(raw, e))
            return json.loads(self.data[max(idx - 1, 0)])

    def _max_points(self):
        """Number of columns of the points tensor (see ``numPoints``)."""
        if isinstance(self.numPoints, (int, np.integer)):
            return self.numPoints - 1
        return self.numPoints[1] - 1

    def _resample_points(self, chunk, eq, verbose):
        """Replace ``chunk["X"]``/``chunk["Y"]`` with freshly generated points.

        Every ``C`` in the skeleton gets a random constant, the resulting
        equation is evaluated on random inputs, and the example is updated only
        if the outputs are finite. Any failure keeps the original points.
        """
        threshold = 5000
        cleanEqn = "".join(
            "{}".format(np.random.uniform(self.const_range[0], self.const_range[1]))
            if symbol == "C"
            else symbol
            for symbol in eq
        )

        # An int numPoints fills the whole tensor; a pair is a sampling range.
        if isinstance(self.numPoints, (int, np.integer)):
            nPoints = self.numPoints - 1
        else:
            nPoints = np.random.randint(*self.numPoints)

        try:
            if verbose:
                print("Org:", chunk["X"], chunk["Y"])

            X, y = generateDataStrEq(
                cleanEqn,
                n_points=nPoints,
                n_vars=self.numVars,
                decimals=self.decimals,
                min_x=self.xRange[0],
                max_x=self.xRange[1],
            )

            # clip out-of-threshold values to +/- threshold
            y = [e if abs(e) < threshold else np.sign(e) * threshold for e in y]

            # reject nan/inf/empty/very large outputs
            conditions = (
                (np.isnan(y).any() or np.isinf(y).any())
                or len(y) == 0
                or (abs(min(y)) > threshold or abs(max(y)) > threshold)
            )
            if not conditions:
                chunk["X"], chunk["Y"] = X, y

            if verbose:
                print("Evd:", chunk["X"], chunk["Y"])
        except Exception as e:
            # this can happen for many reasons, e.g. division by zero
            print(
                "".join(
                    [
                        f"We just used the original equation and support points because of {e}. ",
                        f"The equation is {eq}, and we update the equation to {cleanEqn}",
                    ]
                )
            )

    def _encode(self, eq, numVars):
        """Tokenise ``eq`` and return padded ``(inputs, outputs)`` id tensors.

        ``<`` is the start token and ``>`` the end token. The equation is always
        split with ``tokenize_equation`` so multi-character tokens such as
        ``sin`` or ``x1`` map to a single vocabulary entry; with ``addVars`` the
        digits of ``numVars`` and ``:`` are inserted after the start token.
        """
        eq_tokens = tokenize_equation(eq)
        if self.addVars:
            token_seq = ["<", *str(numVars), ":", *eq_tokens, ">"]
        else:
            token_seq = ["<", *eq_tokens, ">"]
        dix = [self.stoi[tok] for tok in token_seq]

        inputs = dix[:-1]
        outputs = dix[1:]

        # pad up to block_size, then make sure nothing exceeds it
        padding = [self.paddingID] * max(self.block_size - len(inputs), 0)
        inputs = (inputs + padding)[: self.block_size]
        outputs = (outputs + padding)[: self.block_size]

        return (
            torch.tensor(inputs, dtype=torch.long),
            torch.tensor(outputs, dtype=torch.long),
        )

    def _build_points(self, chunk):
        """Pack ``chunk["X"]``/``chunk["Y"]`` into a zero-padded tensor.

        Column ``k`` holds the ``k``-th point: the inputs padded to ``numVars``
        followed by the outputs padded to ``numYs``. Points of an unexpected
        type are reported and leave their column at zero.
        """
        max_points = self._max_points()
        points = torch.zeros(self.numVars + self.numYs, max_points)
        for col, (x, y) in enumerate(zip(chunk["X"], chunk["Y"])):
            if not isinstance(x, list) or not isinstance(
                y, (list, float, np.float64)
            ):
                print(f"Unexpected types: {type(x)}, {type(y)}")
                continue  # Skip if types are incorrect

            # don't exceed the maximum number of points
            if col >= max_points:
                break

            x = x + [0] * max(self.numVars - len(x), 0)  # padding
            y = y if isinstance(y, list) else [y]
            y = y + [0] * max(self.numYs - len(y), 0)  # padding
            points[:, col] = torch.tensor(x + y)

        # replace nan and inf once for the whole tensor
        return torch.nan_to_num(
            points,
            nan=self.threshold[1],
            posinf=self.threshold[1],
            neginf=self.threshold[0],
        )


# Relative Mean Square Error
def relativeErr(y, yHat, info=False, eps=1e-5):
    """Mean squared error between ``yHat`` and ``y`` scaled by ``||y + eps||``.

    Both inputs are flattened first. When the flattened targets are empty or
    the two lengths differ, the fixed penalty 100 is returned instead. With
    ``info`` set, five randomly chosen entries are printed.
    """
    predicted = np.reshape(yHat, [1, -1])[0]
    actual = np.reshape(y, [1, -1])[0]

    if len(actual) == 0 or len(actual) != len(predicted):
        return np.mean(100)

    scale = np.linalg.norm(actual + eps)
    err = (predicted - actual) ** 2 / scale
    if info:
        for _ in range(5):
            pick = np.random.randint(len(actual))
            print(
                "yPR,yTrue:{},{}, Err:{}".format(
                    predicted[pick], actual[pick], err[pick]
                )
            )
    return np.mean(err)


def lossFunc(constants, eq, X, Y, eps=1e-5):
    """Average relative error of a skeleton equation with given constants.

    Args:
        constants: values substituted, in order, for each ``C`` in ``eq``.
        eq: skeleton expression over ``C`` and the variables ``x1``, ``x2``, ...
        X: iterable of points; each point is a sequence of variable values (a
            scalar is treated as a one-variable point). Tensor elements are
            converted with ``.item()``.
        Y: targets, one per point.
        eps: stabiliser forwarded to ``relativeErr``.

    Returns:
        Sum of the per-point relative errors divided by ``len(Y)``; ``0`` when
        ``Y`` is empty. Points whose evaluation or error computation fails are
        reported and contribute nothing to the sum.

    Note:
        The expression is executed with ``eval``; only pass trusted equations.
        Constants are inserted verbatim (not parenthesised).
    """
    eq = eq.replace("C", "{}").format(*constants)
    n_targets = len(Y)
    if n_targets == 0:
        return 0

    var_pattern = re.compile(r"x(\d+)")
    err = 0
    for x, y in zip(X, Y):
        if np.ndim(x) == 0:
            x = [x]
        # make sure no value is a tensor before turning it into text
        values = [e.item() if isinstance(e, torch.Tensor) else e for e in x]

        def substitute(match, values=values):
            # Replace whole variable tokens only ("x1" must not match inside
            # "x10") and parenthesise so negative values keep their sign
            # under "**".
            digits = match.group(1)
            index = int(digits)
            if str(index) == digits and 1 <= index <= len(values):
                return "({})".format(str(values[index - 1]))
            return match.group(0)

        eqTemp = var_pattern.sub(substitute, eq)
        try:
            yHat = eval(eqTemp)
        except Exception:
            print("Exception has been occured! EQ: {}, OR: {}".format(eqTemp, eq))
            continue
        try:
            # may fail on overflow
            err += relativeErr(y, yHat, eps=eps)
        except Exception:
            print(
                "Exception has been occured! EQ: {}, OR: {}, y:{}-yHat:{}".format(
                    eqTemp, eq, y, yHat
                )
            )
            continue

    return err / n_targets


In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from tqdm import tqdm

# from SymbolicGPT: https://github.com/mojivalipour/symbolicgpt/blob/master/models.py
class PointNetConfig:
    """Plain attribute container with the settings of the point encoder.

    The named arguments are stored under their own names; every additional
    keyword argument becomes an attribute as well.
    """

    def __init__(
        self,
        embeddingSize,
        numberofPoints,
        numberofVars,
        numberofYs,
        method="GPT",
        varibleEmbedding="NOT_VAR",
        **kwargs,
    ):
        settings = {
            "embeddingSize": embeddingSize,
            "numberofPoints": numberofPoints,  # number of points
            "numberofVars": numberofVars,  # input dimension (Xs)
            "numberofYs": numberofYs,  # output dimension (Ys)
            "method": method,
            "varibleEmbedding": varibleEmbedding,
        }
        # extra keyword arguments are applied after the named settings
        settings.update(kwargs)
        for name, value in settings.items():
            setattr(self, name, value)


class tNet(nn.Module):
    """
    Point-set encoder following the structure of the original PointNet paper:
    PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation by Qi et. al. 2017

    Three point-wise convolutions, a max over the point axis and two linear
    layers, each followed by batch normalisation and ReLU.
    """

    def __init__(self, config):
        super().__init__()

        self.activation_func = F.relu
        self.num_units = config.embeddingSize

        width = self.num_units
        n_features = config.numberofVars + config.numberofYs

        # kernel size 1: every point is transformed independently
        self.conv1 = nn.Conv1d(n_features, width, 1)
        self.conv2 = nn.Conv1d(width, 2 * width, 1)
        self.conv3 = nn.Conv1d(2 * width, 4 * width, 1)
        self.fc1 = nn.Linear(4 * width, 2 * width)
        self.fc2 = nn.Linear(2 * width, width)

        self.input_batch_norm = nn.BatchNorm1d(n_features)

        self.bn1 = nn.BatchNorm1d(width)
        self.bn2 = nn.BatchNorm1d(2 * width)
        self.bn3 = nn.BatchNorm1d(4 * width)
        self.bn4 = nn.BatchNorm1d(2 * width)
        self.bn5 = nn.BatchNorm1d(width)

    def forward(self, x):
        """
        :param x: [batch, #features, #points]
        :return:
            logit: [batch, embedding_size]
        """
        hidden = self.input_batch_norm(x)

        pointwise = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in pointwise:
            hidden = self.activation_func(norm(conv(hidden)))

        # global max pooling over the point axis
        hidden = torch.max(hidden, dim=2)[0]
        assert hidden.size(1) == 4 * self.num_units

        for dense, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            hidden = self.activation_func(norm(dense(hidden)))

        return hidden


class NoisePredictionTransformer(nn.Module):
    """Transformer encoder that denoises a sequence of embeddings.

    The noisy embeddings are summed with a learned positional embedding, an
    embedding of the timestep and the conditioning vector before being passed
    through the encoder stack.
    """

    def __init__(self, n_embd, max_seq_len, n_layer=6, n_head=8, max_timesteps=1000):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, n_embd))
        self.time_emb = nn.Embedding(max_timesteps, n_embd)
        layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=4 * n_embd,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layer)

    def forward(self, x_t, t, condition):
        """Return the encoder output for ``x_t`` ([batch, length, n_embd]).

        ``t`` is indexed into the timestep embedding and ``condition`` is a
        per-example vector; both are broadcast over the sequence axis.
        """
        _, seq_len, _ = x_t.shape
        hidden = x_t + self.pos_emb[:, :seq_len, :]
        hidden = hidden + self.time_emb(t).unsqueeze(1)
        hidden = hidden + condition.unsqueeze(1)
        return self.encoder(hidden)  # Predicts x_start


# Symbolic Diffusion with Hybrid Loss
class SymbolicGaussianDiffusion(nn.Module):
    """Gaussian diffusion over token embeddings, conditioned on data points.

    Token ids are embedded with ``tok_emb``, noised with a linear beta
    schedule, and a transformer predicts the clean embeddings from the noisy
    ones. The condition is the ``tNet`` encoding of the points plus an
    embedding of the variable count. The output projection shares its weight
    with the token embedding.

    Timestep conventions:
        * Training (``forward``/``p_losses``/``q_sample``) takes 0-based
          schedule indices ``t`` in ``[0, timesteps - 1]``.
        * Sampling (``sample``/``p_sample``/``p_mean_variance``) numbers the
          steps ``timesteps, ..., 1`` and uses ``0`` for the clean sample;
          step ``k`` uses schedule entry ``k - 1``.

    NOTE(review): ``vars_emb`` has ``max_num_vars`` rows, so a variable count
    equal to ``max_num_vars`` is out of range; confirm against the datasets
    used before relying on it.
    """

    def __init__(
        self,
        tnet_config,
        vocab_size,
        max_seq_len,
        padding_idx: int = 0,
        max_num_vars: int = 9,
        n_layer=6,
        n_head=8,
        n_embd=512,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        ce_weight=1.0,  # Weight for CE loss relative to MSE
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.n_embd = n_embd
        self.timesteps = timesteps
        self.ce_weight = ce_weight

        # Kept for callers that read it; internal code uses the buffers' device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Embedding layers
        self.tok_emb = nn.Embedding(vocab_size, n_embd, padding_idx=self.padding_idx)
        self.vars_emb = nn.Embedding(max_num_vars, n_embd)

        # Decoder (weight tied to the token embedding)
        self.decoder = nn.Linear(n_embd, vocab_size, bias=False)
        self.decoder.weight = self.tok_emb.weight

        # Models
        self.tnet = tNet(tnet_config)
        self.model = NoisePredictionTransformer(
            n_embd, max_seq_len, n_layer, n_head, timesteps
        )

        # Noise schedule
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, timesteps))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to schedule index ``t`` (one index per example).

        Uses ``noise`` when given; otherwise standard normal noise is drawn.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1)

        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise
        return x_t

    def p_mean_variance(self, x, t, t_next, condition):
        """Posterior mean and variance for one reverse step.

        Args:
            x: current noisy embeddings, ``[batch, length, n_embd]``.
            t: current step in ``[1, timesteps]`` (scalar or one per example).
            t_next: step being moved to, ``0`` meaning the clean sample.
            condition: conditioning vectors passed to the denoiser.

        Returns:
            ``(mean, variance)``; ``variance`` has shape ``[batch, 1, 1]``.

        Raises:
            ValueError: if ``t`` lies outside ``[1, timesteps]``.
        """
        batch = x.shape[0]
        t = torch.as_tensor(t, device=x.device).long().reshape(-1).expand(batch)
        t_next = (
            torch.as_tensor(t_next, device=x.device).long().reshape(-1).expand(batch)
        )
        if bool((t < 1).any()) or bool((t > self.timesteps).any()):
            raise ValueError(
                "sampling step t must be in [1, {}]".format(self.timesteps)
            )

        # Step k uses schedule entry k - 1, the indexing used during training.
        idx = t - 1
        alpha_t = self.alpha[idx].view(-1, 1, 1)
        alpha_bar_t = self.alpha_bar[idx].view(-1, 1, 1)
        beta_t = self.beta[idx].view(-1, 1, 1)

        # The clean sample (step 0) has a cumulative alpha of exactly 1.
        prev_alpha_bar = self.alpha_bar[(t_next - 1).clamp(min=0)]
        alpha_bar_t_next = torch.where(
            t_next > 0, prev_alpha_bar, torch.ones_like(prev_alpha_bar)
        ).view(-1, 1, 1)

        x_start_pred = self.model(x, idx, condition)

        coeff1 = torch.sqrt(alpha_bar_t_next) * beta_t / (1 - alpha_bar_t)
        coeff2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_next) / (1 - alpha_bar_t)
        mean = coeff1 * x_start_pred + coeff2 * x
        variance = (1 - alpha_bar_t_next) / (1 - alpha_bar_t) * beta_t
        return mean, variance

    @torch.no_grad()
    def p_sample(self, x, t, t_next, condition):
        """Draw the sample for step ``t_next``; no noise is added at step 0."""
        mean, variance = self.p_mean_variance(x, t, t_next, condition)
        if torch.all(torch.as_tensor(t_next) == 0):
            return mean
        noise = torch.randn_like(x)
        return mean + torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, points, variables, batch_size=16):
        """Run the full reverse process and return token ids ``[B, L]``.

        ``batch_size`` must be compatible with the batch dimension of
        ``points``/``variables`` (equal, or their batch is 1), since the
        condition is broadcast onto the sampled embeddings.
        """
        # Imported locally so a notebook-level ``import tqdm`` (which rebinds
        # the name to the module) cannot break the call below.
        from tqdm import tqdm as progress_bar

        device = self.beta.device  # wherever the model currently lives
        condition = self.tnet(points) + self.vars_emb(variables)
        shape = (batch_size, self.max_seq_len, self.n_embd)
        x = torch.randn(shape, device=device)

        for step in progress_bar(
            range(self.timesteps, 0, -1), desc="sampling loop", total=self.timesteps
        ):
            t = torch.full((batch_size,), step, dtype=torch.long, device=device)
            x = self.p_sample(x, t, t - 1, condition)

        # Map embeddings to token indices via decoder
        logits = self.decoder(x)  # [B, L, vocab_size]
        token_indices = torch.argmax(logits, dim=-1)  # [B, L]
        return token_indices

    def p_losses(self, x_start, points, tokens, variables, t, noise=None):
        """Hybrid loss: MSE on embeddings + CE on tokens.

        Returns ``(total_loss, mse_loss, ce_loss)`` with
        ``total_loss = mse_loss + ce_weight * ce_loss``.
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        x_t = self.q_sample(x_start, t, noise)
        condition = self.tnet(points) + self.vars_emb(variables)
        x_start_pred = self.model(x_t, t.long(), condition)

        # MSE loss on embeddings
        mse_loss = F.mse_loss(x_start_pred, x_start)

        # CE loss on tokens; padding positions are ignored
        logits = self.decoder(x_start_pred)  # [B, L, vocab_size]
        ce_loss = F.cross_entropy(
            logits.view(-1, self.vocab_size),  # [B*L, vocab_size]
            tokens.view(-1),  # [B*L]
            ignore_index=self.padding_idx,
            reduction="mean",
        )

        # Combine losses
        total_loss = mse_loss + self.ce_weight * ce_loss
        return total_loss, mse_loss, ce_loss

    def forward(self, points, tokens, variables, t):
        """Embed ``tokens`` and return ``(total_loss, mse_loss, ce_loss)``."""
        token_emb = self.tok_emb(tokens)
        total_loss, mse_loss, ce_loss = self.p_losses(
            token_emb, points, tokens, variables, t
        )
        return total_loss, mse_loss, ce_loss


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm


def train_epoch(
    model: SymbolicGaussianDiffusion,  # Changed type hint
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:  # Now returns all loss components
    """Run one optimisation pass over ``train_loader``.

    For every batch a random timestep per example is drawn, the model loss is
    back-propagated and the optimizer stepped. Every 250th batch the current
    losses are printed.

    Returns:
        The per-batch averages ``(total, mse, ce)`` over the epoch.
    """
    model.train()
    n_batches = len(train_loader)
    running = {"total": 0, "mse": 0, "ce": 0}

    progress = tqdm.tqdm(
        enumerate(train_loader),
        total=n_batches,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    )
    for step, batch in progress:
        _, tokens, points, variables = batch
        points = points.to(device)
        tokens = tokens.to(device)
        variables = variables.to(device)
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

        optimizer.zero_grad()

        # Model returns (total_loss, mse_loss, ce_loss)
        total_loss, mse_loss, ce_loss = model(points, tokens, variables, t)

        if (step + 1) % 250 == 0:
            print(f"Batch {step + 1}/{len(train_loader)}:")
            print(f"total_loss: {total_loss}, mse: {mse_loss}, ce: {ce_loss}")

        total_loss.backward()
        optimizer.step()

        # Accumulate all losses
        running["total"] += total_loss.item()
        running["mse"] += mse_loss.item()
        running["ce"] += ce_loss.item()

    return (
        running["total"] / n_batches,
        running["mse"] / n_batches,
        running["ce"] / n_batches,
    )


def val_epoch(
    model: SymbolicGaussianDiffusion,  # Changed type hint
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:  # Now returns all loss components
    """Evaluate the model on ``val_loader`` without gradient tracking.

    A random timestep per example is drawn for every batch, as in training.

    Returns:
        The per-batch averages ``(total, mse, ce)`` over the loader.
    """
    model.eval()
    n_batches = len(val_loader)
    running = {"total": 0, "mse": 0, "ce": 0}

    with torch.no_grad():
        progress = tqdm.tqdm(val_loader, total=n_batches, desc="Validating")
        for batch in progress:
            _, tokens, points, variables = batch
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

            # Model returns (total_loss, mse_loss, ce_loss)
            total_loss, mse_loss, ce_loss = model(points, tokens, variables, t)

            running["total"] += total_loss.item()
            running["mse"] += mse_loss.item()
            running["ce"] += ce_loss.item()

    return (
        running["total"] / n_batches,
        running["mse"] / n_batches,
        running["ce"] / n_batches,
    )


def train_single_gpu(
    model: SymbolicGaussianDiffusion,  # Changed type hint
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
):
    """Train ``model`` on one device and keep the best checkpoint.

    Uses Adam with a ``ReduceLROnPlateau`` scheduler driven by the validation
    loss. After each epoch a summary is printed, and whenever the validation
    loss improves the state dict is written to ``best_model.pth``.
    ``save_every`` is accepted but not used by this function.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    # Settings shared by both loaders; only the training data is shuffled.
    loader_options = dict(batch_size=batch_size, pin_memory=True, num_workers=4)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        train_stats = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )
        avg_train_loss, avg_mse_loss, avg_ce_loss = train_stats

        val_stats = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )
        avg_val_loss, val_mse_loss, val_ce_loss = val_stats

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        # Print detailed epoch summary
        print("\nEpoch Summary:")
        print(
            f"Train Total Loss: {avg_train_loss:.4f} (MSE: {avg_mse_loss:.4f}, CE: {avg_ce_loss:.4f})"
        )
        print(
            f"Val Total Loss: {avg_val_loss:.4f} (MSE: {val_mse_loss:.4f}, CE: {val_ce_loss:.4f})"
        )
        print(f"Learning Rate: {current_lr:.6f}")

        improved = avg_val_loss < best_val_loss
        if improved:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_model.pth")
            print(f"New best model saved with val loss: {best_val_loss:.4f}")

        print("-" * 50)


In [5]:
# setting hyperparameters
n_embd = 512  # embedding width shared by the point encoder and the diffusion model
timesteps = 1000  # number of diffusion steps
batch_size = 64
learning_rate = 1e-4
num_epochs = 5
blockSize = 32  # token-sequence length (dataset block_size and model max_seq_len)
numVars = 1  # number of input variables per point
numYs = 1  # number of outputs per point
numPoints = 250  # the dataset builds a points tensor with numPoints - 1 columns
target = 'Skeleton'  # JSON key holding the equation to predict
const_range = [-2.1, 2.1]  # constant sampling range, read only when the dataset augments
trainRange = [-3.0, 3.0]  # input sampling range, read only when the dataset augments
decimals = 8  # rounding used when the dataset regenerates points
addVars = False  # do not prefix sequences with the variable count
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Locations of the one-variable train/validation splits under the Kaggle input
# directory; both are passed to glob.glob in the following cells.
train_path = "/kaggle/input/1-var-dataset/1_var_train.json"
val_path = "/kaggle/input/1-var-dataset/1_var_val.json"

In [7]:
import numpy as np
import glob
import random

# Read the training file(s) and split the raw text into one JSON string per line.
files = glob.glob(train_path)
text = processDataFiles(files)
text = text.split('\n') # convert the raw text to a set of examples
# skeletons = []
# Build the vocabulary from every token appearing in the training skeletons.
skeletons = [json.loads(item)['Skeleton'] for item in text if item.strip()]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # add all integers to the token set
# NOTE(review): '_' is also a tokenizer pattern; if a skeleton ever contained it,
# the list below would hold a duplicate entry -- confirm the data never does.
tokens = sorted(list(all_tokens) + ['_', 'T', '<', '>', ':'])  # special tokens
# Drop the empty string that split('\n') leaves after a trailing newline.
trainText = text[:-1] if len(text[-1]) == 0 else text
random.shuffle(trainText) # shuffle the dataset; this matters especially for the combined number-of-variables experiment
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# Decode and print one random training example as a sanity check.
idx = np.random.randint(train_dataset.__len__())
inputs, outputs, points, variables = train_dataset.__getitem__(idx)
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 498795 examples, 27 unique.
id:50277
outputs:C*x1+C>__________________________
variables:1


In [8]:
# Read the first validation file and split it into one JSON string per line.
# The validation set reuses the vocabulary built from the training data.
files = glob.glob(val_path)
textVal = processDataFiles([files[0]])
textVal = textVal.split('\n') # convert the raw text to a set of examples
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# print a random sample
idx = np.random.randint(val_dataset.__len__())
inputs, outputs, points, variables = val_dataset.__getitem__(idx)
print(points.min(), points.max())
# Decoding uses train_dataset.itos; both datasets were built from the same tokens list.
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 972 examples, 27 unique.
tensor(-2.9973) tensor(2.8438)
id:485
outputs:C>______________________________
variables:0


In [9]:
# Load a saved checkpoint onto the active device.
# NOTE(review): torch.load unpickles the file, so it must come from a trusted
# source. If the file is a plain state dict (it is indexed by parameter name
# below), passing weights_only=True would be the safer call -- confirm first.
weights = torch.load("/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt", map_location=torch.device(device))

#vars_emb_weights = weights['vars_emb.weight']
# Token-embedding matrix from the checkpoint; not used by any cell shown here.
tok_emb_weights = weights['tok_emb.weight']

  weights = torch.load("/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt", map_location=torch.device(device))


In [10]:
# PointNetConfig built from the notebook-level hyperparameters; it is handed to
# the diffusion model below as `tnet_config`.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

# Instantiate the diffusion model. Vocabulary size and padding index come from
# the training dataset so the model agrees with the tokenisation used there.
model = SymbolicGaussianDiffusion(
    tnet_config=pconfig,  # the PointNetConfig created above
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    # NOTE(review): hard-coded to 9 while `numberofVars` above uses `numVars` -- confirm this is intended.
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,  # same value is also passed to train_single_gpu below
    beta_start=0.0001,
    beta_end=0.02,
    ce_weight=1.0  # presumably scales the `ce` term of the logged total_loss -- TODO confirm
)

# Run training on the train/validation datasets built in the earlier cells.
# The captured output shows a validation pass and an "Epoch Summary" after
# every epoch, with a best-model save reported when the validation loss improves.
train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,  # presumably a checkpoint interval in epochs -- verify in train_single_gpu
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate
)

Epoch 1/5:   3%|▎         | 250/7794 [00:36<17:36,  7.14it/s]

Batch 250/7794:
total_loss: 2.803213119506836, mse: 1.3523979187011719, ce: 1.4508150815963745


Epoch 1/5:   6%|▋         | 500/7794 [01:12<17:30,  6.94it/s]

Batch 500/7794:
total_loss: 2.1594676971435547, mse: 1.3281495571136475, ce: 0.831318199634552


Epoch 1/5:  10%|▉         | 750/7794 [01:49<17:48,  6.59it/s]

Batch 750/7794:
total_loss: 2.3239758014678955, mse: 1.3251063823699951, ce: 0.9988694190979004


Epoch 1/5:  13%|█▎        | 1000/7794 [02:28<17:55,  6.31it/s]

Batch 1000/7794:
total_loss: 2.246151924133301, mse: 1.236047625541687, ce: 1.0101042985916138


Epoch 1/5:  16%|█▌        | 1250/7794 [03:06<16:47,  6.50it/s]

Batch 1250/7794:
total_loss: 1.8546398878097534, mse: 1.1486310958862305, ce: 0.706008791923523


Epoch 1/5:  19%|█▉        | 1500/7794 [03:45<16:21,  6.41it/s]

Batch 1500/7794:
total_loss: 1.851926565170288, mse: 1.0315630435943604, ce: 0.820363461971283


Epoch 1/5:  22%|██▏       | 1750/7794 [04:24<15:35,  6.46it/s]

Batch 1750/7794:
total_loss: 1.7515376806259155, mse: 0.9856534004211426, ce: 0.765884280204773


Epoch 1/5:  26%|██▌       | 2000/7794 [05:03<15:03,  6.41it/s]

Batch 2000/7794:
total_loss: 1.878409504890442, mse: 0.9916896820068359, ce: 0.886719822883606


Epoch 1/5:  29%|██▉       | 2250/7794 [05:42<14:17,  6.46it/s]

Batch 2250/7794:
total_loss: 1.5370471477508545, mse: 0.872600793838501, ce: 0.6644462943077087


Epoch 1/5:  32%|███▏      | 2500/7794 [06:20<13:40,  6.45it/s]

Batch 2500/7794:
total_loss: 1.1254217624664307, mse: 0.7271144390106201, ce: 0.39830729365348816


Epoch 1/5:  35%|███▌      | 2750/7794 [06:59<13:02,  6.45it/s]

Batch 2750/7794:
total_loss: 1.3049204349517822, mse: 0.7360402345657349, ce: 0.5688801407814026


Epoch 1/5:  38%|███▊      | 3000/7794 [07:38<12:17,  6.50it/s]

Batch 3000/7794:
total_loss: 1.1256673336029053, mse: 0.6306405067443848, ce: 0.4950268864631653


Epoch 1/5:  42%|████▏     | 3250/7794 [08:17<11:43,  6.46it/s]

Batch 3250/7794:
total_loss: 1.2224040031433105, mse: 0.5853101015090942, ce: 0.6370939016342163


Epoch 1/5:  45%|████▍     | 3500/7794 [08:55<11:05,  6.45it/s]

Batch 3500/7794:
total_loss: 1.3936150074005127, mse: 0.6489464044570923, ce: 0.7446686029434204


Epoch 1/5:  48%|████▊     | 3750/7794 [09:34<10:28,  6.43it/s]

Batch 3750/7794:
total_loss: 0.7958780527114868, mse: 0.4690704047679901, ce: 0.3268076479434967


Epoch 1/5:  51%|█████▏    | 4000/7794 [10:13<09:47,  6.46it/s]

Batch 4000/7794:
total_loss: 1.256026268005371, mse: 0.5406988263130188, ce: 0.7153274416923523


Epoch 1/5:  55%|█████▍    | 4250/7794 [10:52<09:07,  6.47it/s]

Batch 4250/7794:
total_loss: 1.1390266418457031, mse: 0.5264507532119751, ce: 0.612575888633728


Epoch 1/5:  58%|█████▊    | 4500/7794 [11:31<08:34,  6.41it/s]

Batch 4500/7794:
total_loss: 0.9479575157165527, mse: 0.4382385313510895, ce: 0.5097190141677856


Epoch 1/5:  61%|██████    | 4750/7794 [12:09<07:51,  6.45it/s]

Batch 4750/7794:
total_loss: 1.1318581104278564, mse: 0.45336461067199707, ce: 0.6784934401512146


Epoch 1/5:  64%|██████▍   | 5000/7794 [12:48<07:12,  6.46it/s]

Batch 5000/7794:
total_loss: 1.2643734216690063, mse: 0.47545960545539856, ce: 0.7889137864112854


Epoch 1/5:  67%|██████▋   | 5250/7794 [13:27<06:34,  6.45it/s]

Batch 5250/7794:
total_loss: 0.8296408653259277, mse: 0.401164174079895, ce: 0.4284766614437103


Epoch 1/5:  71%|███████   | 5500/7794 [14:06<05:54,  6.46it/s]

Batch 5500/7794:
total_loss: 0.9335244297981262, mse: 0.3839479088783264, ce: 0.5495765209197998


Epoch 1/5:  74%|███████▍  | 5750/7794 [14:45<05:16,  6.46it/s]

Batch 5750/7794:
total_loss: 0.9255407452583313, mse: 0.3734057545661926, ce: 0.5521349906921387


Epoch 1/5:  77%|███████▋  | 6000/7794 [15:23<04:37,  6.47it/s]

Batch 6000/7794:
total_loss: 0.7544429302215576, mse: 0.31095415353775024, ce: 0.443488746881485


Epoch 1/5:  80%|████████  | 6250/7794 [16:02<03:58,  6.46it/s]

Batch 6250/7794:
total_loss: 0.9050230979919434, mse: 0.3247774839401245, ce: 0.5802456140518188


Epoch 1/5:  83%|████████▎ | 6500/7794 [16:41<03:20,  6.44it/s]

Batch 6500/7794:
total_loss: 0.8645598292350769, mse: 0.3059195876121521, ce: 0.5586402416229248


Epoch 1/5:  85%|████████▌ | 6642/7794 [17:03<02:58,  6.44it/s]


Equation: C*x1**2+C*x1+C


Epoch 1/5:  87%|████████▋ | 6750/7794 [17:20<02:42,  6.42it/s]

Batch 6750/7794:
total_loss: 0.8003764748573303, mse: 0.3235405683517456, ce: 0.4768359065055847


Epoch 1/5:  90%|████████▉ | 7000/7794 [17:59<02:03,  6.43it/s]

Batch 7000/7794:
total_loss: 0.6529006958007812, mse: 0.28027912974357605, ce: 0.3726215660572052


Epoch 1/5:  93%|█████████▎| 7250/7794 [18:37<01:24,  6.44it/s]

Batch 7250/7794:
total_loss: 0.8600479960441589, mse: 0.32982516288757324, ce: 0.5302228331565857


Epoch 1/5:  96%|█████████▌| 7500/7794 [19:16<00:45,  6.46it/s]

Batch 7500/7794:
total_loss: 0.752748966217041, mse: 0.2932053208351135, ce: 0.4595436751842499


Epoch 1/5:  99%|█████████▉| 7750/7794 [19:55<00:06,  6.44it/s]

Batch 7750/7794:
total_loss: 0.881679356098175, mse: 0.31184613704681396, ce: 0.5698332190513611


Epoch 1/5: 100%|██████████| 7794/7794 [20:02<00:00,  6.48it/s]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.78it/s]


Epoch Summary:
Train Total Loss: 1.3768 (MSE: 0.6623, CE: 0.7145)
Val Total Loss: 0.7366 (MSE: 0.2480, CE: 0.4885)
Learning Rate: 0.000100
New best model saved with val loss: 0.7366
--------------------------------------------------



Epoch 2/5:   3%|▎         | 250/7794 [00:38<19:30,  6.44it/s]

Batch 250/7794:
total_loss: 0.7915871143341064, mse: 0.24429020285606384, ce: 0.547296941280365


Epoch 2/5:   6%|▋         | 500/7794 [01:17<18:49,  6.46it/s]

Batch 500/7794:
total_loss: 0.8267897367477417, mse: 0.25402727723121643, ce: 0.5727624297142029


Epoch 2/5:  10%|▉         | 750/7794 [01:56<18:11,  6.45it/s]

Batch 750/7794:
total_loss: 0.7057859897613525, mse: 0.21454575657844543, ce: 0.4912402331829071


Epoch 2/5:  13%|█▎        | 1000/7794 [02:35<17:43,  6.39it/s]

Batch 1000/7794:
total_loss: 0.6985313296318054, mse: 0.2410198599100113, ce: 0.4575114846229553


Epoch 2/5:  16%|█▌        | 1250/7794 [03:14<16:54,  6.45it/s]

Batch 1250/7794:
total_loss: 0.6319164037704468, mse: 0.2041534185409546, ce: 0.4277629554271698


Epoch 2/5:  19%|█▉        | 1500/7794 [03:53<16:24,  6.39it/s]

Batch 1500/7794:
total_loss: 0.6805254220962524, mse: 0.21353968977928162, ce: 0.46698570251464844


Epoch 2/5:  22%|██▏       | 1750/7794 [04:32<15:42,  6.41it/s]

Batch 1750/7794:
total_loss: 0.7015860080718994, mse: 0.2219967395067215, ce: 0.4795892536640167


Epoch 2/5:  26%|██▌       | 2000/7794 [05:11<15:02,  6.42it/s]

Batch 2000/7794:
total_loss: 0.7235637903213501, mse: 0.20749598741531372, ce: 0.5160678029060364


Epoch 2/5:  29%|██▉       | 2250/7794 [05:49<14:31,  6.36it/s]

Batch 2250/7794:
total_loss: 0.6745002865791321, mse: 0.19819115102291107, ce: 0.4763091504573822


Epoch 2/5:  32%|███▏      | 2500/7794 [06:28<13:48,  6.39it/s]

Batch 2500/7794:
total_loss: 0.7362369298934937, mse: 0.2343657910823822, ce: 0.5018711090087891


Epoch 2/5:  35%|███▌      | 2750/7794 [07:07<13:02,  6.45it/s]

Batch 2750/7794:
total_loss: 0.6201194524765015, mse: 0.1941819190979004, ce: 0.4259375333786011


Epoch 2/5:  38%|███▊      | 3000/7794 [07:46<12:30,  6.39it/s]

Batch 3000/7794:
total_loss: 0.9477267861366272, mse: 0.3179131746292114, ce: 0.6298136115074158


Epoch 2/5:  42%|████▏     | 3250/7794 [08:25<11:49,  6.40it/s]

Batch 3250/7794:
total_loss: 0.670863926410675, mse: 0.20058073103427887, ce: 0.470283180475235


Epoch 2/5:  45%|████▍     | 3500/7794 [09:04<11:14,  6.37it/s]

Batch 3500/7794:
total_loss: 0.5779291391372681, mse: 0.20096677541732788, ce: 0.3769623935222626


Epoch 2/5:  48%|████▊     | 3750/7794 [09:43<10:28,  6.44it/s]

Batch 3750/7794:
total_loss: 0.8059256672859192, mse: 0.251481831073761, ce: 0.5544438362121582


Epoch 2/5:  51%|█████▏    | 4000/7794 [10:22<09:51,  6.42it/s]

Batch 4000/7794:
total_loss: 0.7530562877655029, mse: 0.21827971935272217, ce: 0.5347765684127808


Epoch 2/5:  55%|█████▍    | 4250/7794 [11:01<09:13,  6.40it/s]

Batch 4250/7794:
total_loss: 0.4806978106498718, mse: 0.14509722590446472, ce: 0.3356005847454071


Epoch 2/5:  58%|█████▊    | 4500/7794 [11:40<08:33,  6.41it/s]

Batch 4500/7794:
total_loss: 0.76448655128479, mse: 0.2078508734703064, ce: 0.5566356778144836


Epoch 2/5:  61%|██████    | 4750/7794 [12:19<07:56,  6.39it/s]

Batch 4750/7794:
total_loss: 0.8720264434814453, mse: 0.22907927632331848, ce: 0.6429471373558044


Epoch 2/5:  64%|██████▍   | 5000/7794 [12:58<07:13,  6.44it/s]

Batch 5000/7794:
total_loss: 0.7333993911743164, mse: 0.2308233678340912, ce: 0.5025759935379028


Epoch 2/5:  67%|██████▋   | 5250/7794 [13:37<06:36,  6.41it/s]

Batch 5250/7794:
total_loss: 0.7391877770423889, mse: 0.21513676643371582, ce: 0.5240510106086731


Epoch 2/5:  71%|███████   | 5500/7794 [14:16<05:59,  6.37it/s]

Batch 5500/7794:
total_loss: 0.5564961433410645, mse: 0.20220834016799927, ce: 0.3542878031730652


Epoch 2/5:  74%|███████▍  | 5750/7794 [14:55<05:17,  6.44it/s]

Batch 5750/7794:
total_loss: 0.682887077331543, mse: 0.22393199801445007, ce: 0.4589550495147705


Epoch 2/5:  77%|███████▋  | 6000/7794 [15:34<04:39,  6.41it/s]

Batch 6000/7794:
total_loss: 0.6766022443771362, mse: 0.22873437404632568, ce: 0.44786790013313293


Epoch 2/5:  80%|████████  | 6250/7794 [16:13<04:01,  6.38it/s]

Batch 6250/7794:
total_loss: 0.638864278793335, mse: 0.1991690844297409, ce: 0.43969517946243286


Epoch 2/5:  83%|████████▎ | 6500/7794 [16:52<03:20,  6.44it/s]

Batch 6500/7794:
total_loss: 0.8758271932601929, mse: 0.2177312672138214, ce: 0.6580958962440491


Epoch 2/5:  87%|████████▋ | 6750/7794 [17:31<02:43,  6.40it/s]

Batch 6750/7794:
total_loss: 0.6600271463394165, mse: 0.2205074429512024, ce: 0.4395196735858917


Epoch 2/5:  90%|████████▉ | 7000/7794 [18:11<02:04,  6.37it/s]

Batch 7000/7794:
total_loss: 0.5175879001617432, mse: 0.18286627531051636, ce: 0.3347215950489044


Epoch 2/5:  93%|█████████▎| 7250/7794 [18:50<01:25,  6.39it/s]

Batch 7250/7794:
total_loss: 0.7068728804588318, mse: 0.20909705758094788, ce: 0.4977758228778839


Epoch 2/5:  96%|█████████▌| 7500/7794 [19:29<00:45,  6.40it/s]

Batch 7500/7794:
total_loss: 0.520897626876831, mse: 0.17970585823059082, ce: 0.34119173884391785


Epoch 2/5:  99%|█████████▉| 7750/7794 [20:08<00:06,  6.38it/s]

Batch 7750/7794:
total_loss: 0.5957502722740173, mse: 0.1883668154478073, ce: 0.40738344192504883


Epoch 2/5: 100%|██████████| 7794/7794 [20:15<00:00,  6.41it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 16.10it/s]



Epoch Summary:
Train Total Loss: 0.7019 (MSE: 0.2198, CE: 0.4820)
Val Total Loss: 0.6334 (MSE: 0.1842, CE: 0.4492)
Learning Rate: 0.000100
New best model saved with val loss: 0.6334
--------------------------------------------------


Epoch 3/5:   3%|▎         | 250/7794 [00:39<19:53,  6.32it/s]

Batch 250/7794:
total_loss: 0.8637449741363525, mse: 0.2536913752555847, ce: 0.6100535988807678


Epoch 3/5:   6%|▋         | 500/7794 [01:18<19:04,  6.37it/s]

Batch 500/7794:
total_loss: 0.4435518682003021, mse: 0.17338255047798157, ce: 0.27016931772232056


Epoch 3/5:  10%|▉         | 750/7794 [01:57<18:22,  6.39it/s]

Batch 750/7794:
total_loss: 0.8594338893890381, mse: 0.22772040963172913, ce: 0.6317134499549866


Epoch 3/5:  13%|█▎        | 1000/7794 [02:37<17:47,  6.36it/s]

Batch 1000/7794:
total_loss: 0.8323633670806885, mse: 0.2255403995513916, ce: 0.6068229675292969


Epoch 3/5:  16%|█▌        | 1250/7794 [03:16<17:03,  6.39it/s]

Batch 1250/7794:
total_loss: 0.7218965291976929, mse: 0.21148058772087097, ce: 0.5104159712791443


Epoch 3/5:  19%|█▉        | 1500/7794 [03:55<16:33,  6.34it/s]

Batch 1500/7794:
total_loss: 0.52289879322052, mse: 0.1698913276195526, ce: 0.353007435798645


Epoch 3/5:  22%|██▏       | 1750/7794 [04:34<15:45,  6.39it/s]

Batch 1750/7794:
total_loss: 0.5713565349578857, mse: 0.19812965393066406, ce: 0.3732268512248993


Epoch 3/5:  26%|██▌       | 2000/7794 [05:13<14:58,  6.45it/s]

Batch 2000/7794:
total_loss: 0.42746174335479736, mse: 0.18111146986484528, ce: 0.2463502734899521


Epoch 3/5:  29%|██▉       | 2250/7794 [05:52<14:26,  6.40it/s]

Batch 2250/7794:
total_loss: 0.4286259412765503, mse: 0.16185304522514343, ce: 0.26677289605140686


Epoch 3/5:  32%|███▏      | 2500/7794 [06:31<13:47,  6.40it/s]

Batch 2500/7794:
total_loss: 0.8180109858512878, mse: 0.21771275997161865, ce: 0.6002982258796692


Epoch 3/5:  35%|███▌      | 2750/7794 [07:10<13:05,  6.42it/s]

Batch 2750/7794:
total_loss: 0.6408965587615967, mse: 0.22486740350723267, ce: 0.4160291850566864


Epoch 3/5:  38%|███▊      | 3000/7794 [07:49<12:29,  6.40it/s]

Batch 3000/7794:
total_loss: 0.5708451271057129, mse: 0.1749822050333023, ce: 0.3958629369735718


Epoch 3/5:  42%|████▏     | 3250/7794 [08:28<11:52,  6.38it/s]

Batch 3250/7794:
total_loss: 0.6646755933761597, mse: 0.1948835253715515, ce: 0.46979206800460815


Epoch 3/5:  45%|████▍     | 3500/7794 [09:08<11:10,  6.40it/s]

Batch 3500/7794:
total_loss: 0.7698661684989929, mse: 0.2098091095685959, ce: 0.5600570440292358


Epoch 3/5:  48%|████▊     | 3750/7794 [09:47<10:31,  6.40it/s]

Batch 3750/7794:
total_loss: 0.6435375213623047, mse: 0.16964885592460632, ce: 0.47388866543769836


Epoch 3/5:  51%|█████▏    | 4000/7794 [10:26<09:53,  6.39it/s]

Batch 4000/7794:
total_loss: 0.5603810548782349, mse: 0.17005428671836853, ce: 0.39032676815986633


Epoch 3/5:  55%|█████▍    | 4250/7794 [11:05<09:14,  6.39it/s]

Batch 4250/7794:
total_loss: 0.6809452176094055, mse: 0.19509561359882355, ce: 0.48584961891174316


Epoch 3/5:  58%|█████▊    | 4500/7794 [11:44<08:39,  6.33it/s]

Batch 4500/7794:
total_loss: 0.5694869756698608, mse: 0.1671200692653656, ce: 0.4023669362068176


Epoch 3/5:  61%|██████    | 4750/7794 [12:23<07:54,  6.42it/s]

Batch 4750/7794:
total_loss: 0.7044150233268738, mse: 0.22161081433296204, ce: 0.48280420899391174


Epoch 3/5:  64%|██████▍   | 5000/7794 [13:02<07:17,  6.38it/s]

Batch 5000/7794:
total_loss: 0.6527141332626343, mse: 0.19341617822647095, ce: 0.45929792523384094


Epoch 3/5:  67%|██████▋   | 5250/7794 [13:41<06:36,  6.41it/s]

Batch 5250/7794:
total_loss: 0.6529728174209595, mse: 0.20125240087509155, ce: 0.4517204165458679


Epoch 3/5:  71%|███████   | 5500/7794 [14:20<05:58,  6.39it/s]

Batch 5500/7794:
total_loss: 0.6606482267379761, mse: 0.1952683925628662, ce: 0.46537986397743225


Epoch 3/5:  74%|███████▍  | 5750/7794 [14:59<05:19,  6.41it/s]

Batch 5750/7794:
total_loss: 0.5955583453178406, mse: 0.17683716118335724, ce: 0.41872119903564453


Epoch 3/5:  77%|███████▋  | 6000/7794 [15:38<04:40,  6.39it/s]

Batch 6000/7794:
total_loss: 0.606549859046936, mse: 0.22242692112922668, ce: 0.38412296772003174


Epoch 3/5:  80%|████████  | 6250/7794 [16:17<04:01,  6.38it/s]

Batch 6250/7794:
total_loss: 0.7703087329864502, mse: 0.2018507421016693, ce: 0.5684579610824585


Epoch 3/5:  83%|████████▎ | 6500/7794 [16:57<03:21,  6.43it/s]

Batch 6500/7794:
total_loss: 0.6644207835197449, mse: 0.2406044900417328, ce: 0.4238162934780121


Epoch 3/5:  87%|████████▋ | 6750/7794 [17:36<02:43,  6.37it/s]

Batch 6750/7794:
total_loss: 0.6458583474159241, mse: 0.2074071317911148, ce: 0.43845123052597046


Epoch 3/5:  90%|████████▉ | 7000/7794 [18:15<02:03,  6.40it/s]

Batch 7000/7794:
total_loss: 0.5869241952896118, mse: 0.20041459798812866, ce: 0.38650956749916077


Epoch 3/5:  93%|█████████▎| 7250/7794 [18:54<01:24,  6.43it/s]

Batch 7250/7794:
total_loss: 0.6880890130996704, mse: 0.20221000909805298, ce: 0.48587897419929504


Epoch 3/5:  96%|█████████▌| 7500/7794 [19:33<00:45,  6.40it/s]

Batch 7500/7794:
total_loss: 0.580862283706665, mse: 0.1765686571598053, ce: 0.40429365634918213


Epoch 3/5:  99%|█████████▉| 7750/7794 [20:12<00:06,  6.41it/s]

Batch 7750/7794:
total_loss: 0.6021738648414612, mse: 0.19123771786689758, ce: 0.4109361469745636


Epoch 3/5: 100%|██████████| 7794/7794 [20:19<00:00,  6.39it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 16.06it/s]


Epoch Summary:
Train Total Loss: 0.6426 (MSE: 0.1952, CE: 0.4474)
Val Total Loss: 0.6382 (MSE: 0.1839, CE: 0.4543)
Learning Rate: 0.000100
--------------------------------------------------



Epoch 4/5:   3%|▎         | 250/7794 [00:39<19:52,  6.33it/s]

Batch 250/7794:
total_loss: 0.7491709589958191, mse: 0.23252417147159576, ce: 0.5166468024253845


Epoch 4/5:   6%|▋         | 500/7794 [01:18<19:16,  6.31it/s]

Batch 500/7794:
total_loss: 0.7193468809127808, mse: 0.21581682562828064, ce: 0.5035300850868225


Epoch 4/5:  10%|▉         | 750/7794 [01:57<18:29,  6.35it/s]

Batch 750/7794:
total_loss: 0.8206496834754944, mse: 0.2208017259836197, ce: 0.5998479723930359


Epoch 4/5:  13%|█▎        | 1000/7794 [02:37<17:51,  6.34it/s]

Batch 1000/7794:
total_loss: 0.607125997543335, mse: 0.197175532579422, ce: 0.4099504351615906


Epoch 4/5:  16%|█▌        | 1250/7794 [03:16<17:07,  6.37it/s]

Batch 1250/7794:
total_loss: 0.40149378776550293, mse: 0.15884269773960114, ce: 0.242651104927063


Epoch 4/5:  19%|█▉        | 1500/7794 [03:55<16:29,  6.36it/s]

Batch 1500/7794:
total_loss: 0.5946148037910461, mse: 0.17494237422943115, ce: 0.419672429561615


Epoch 4/5:  22%|██▏       | 1750/7794 [04:35<15:52,  6.34it/s]

Batch 1750/7794:
total_loss: 0.5892081260681152, mse: 0.18274009227752686, ce: 0.406468003988266


Epoch 4/5:  26%|██▌       | 2000/7794 [05:14<15:04,  6.40it/s]

Batch 2000/7794:
total_loss: 0.6190183162689209, mse: 0.21263182163238525, ce: 0.40638652443885803


Epoch 4/5:  29%|██▉       | 2250/7794 [05:53<14:27,  6.39it/s]

Batch 2250/7794:
total_loss: 0.5858324766159058, mse: 0.17000101506710052, ce: 0.41583144664764404


Epoch 4/5:  32%|███▏      | 2500/7794 [06:32<13:49,  6.38it/s]

Batch 2500/7794:
total_loss: 0.49271130561828613, mse: 0.15603937208652496, ce: 0.33667194843292236


Epoch 4/5:  35%|███▌      | 2750/7794 [07:11<13:06,  6.41it/s]

Batch 2750/7794:
total_loss: 0.4472140073776245, mse: 0.16588732600212097, ce: 0.28132668137550354


Epoch 4/5:  38%|███▊      | 3000/7794 [07:50<12:30,  6.38it/s]

Batch 3000/7794:
total_loss: 0.6237618327140808, mse: 0.18122366070747375, ce: 0.44253817200660706


Epoch 4/5:  42%|████▏     | 3250/7794 [08:29<11:52,  6.38it/s]

Batch 3250/7794:
total_loss: 0.5501350164413452, mse: 0.15647846460342407, ce: 0.39365655183792114


Epoch 4/5:  45%|████▍     | 3500/7794 [09:08<11:11,  6.40it/s]

Batch 3500/7794:
total_loss: 0.5244737863540649, mse: 0.13274234533309937, ce: 0.3917314410209656


Epoch 4/5:  48%|████▊     | 3750/7794 [09:47<10:35,  6.36it/s]

Batch 3750/7794:
total_loss: 0.5418643951416016, mse: 0.1630270630121231, ce: 0.37883734703063965


Epoch 4/5:  51%|█████▏    | 4000/7794 [10:26<09:56,  6.36it/s]

Batch 4000/7794:
total_loss: 0.6099255084991455, mse: 0.16975447535514832, ce: 0.4401710331439972


Epoch 4/5:  55%|█████▍    | 4250/7794 [11:05<09:18,  6.35it/s]

Batch 4250/7794:
total_loss: 0.4915138781070709, mse: 0.12525048851966858, ce: 0.36626338958740234


Epoch 4/5:  58%|█████▊    | 4500/7794 [11:45<08:37,  6.36it/s]

Batch 4500/7794:
total_loss: 0.5525071620941162, mse: 0.1596391499042511, ce: 0.3928679823875427


Epoch 4/5:  61%|██████    | 4750/7794 [12:24<07:59,  6.35it/s]

Batch 4750/7794:
total_loss: 0.42849376797676086, mse: 0.1532277762889862, ce: 0.27526599168777466


Epoch 4/5:  64%|██████▍   | 5000/7794 [13:03<07:15,  6.41it/s]

Batch 5000/7794:
total_loss: 0.6641747951507568, mse: 0.19052079319953918, ce: 0.47365403175354004


Epoch 4/5:  67%|██████▋   | 5250/7794 [13:42<06:35,  6.44it/s]

Batch 5250/7794:
total_loss: 0.5786541700363159, mse: 0.15516534447669983, ce: 0.4234888255596161


Epoch 4/5:  71%|███████   | 5500/7794 [14:21<05:58,  6.40it/s]

Batch 5500/7794:
total_loss: 0.7393583059310913, mse: 0.17175328731536865, ce: 0.5676050186157227


Epoch 4/5:  74%|███████▍  | 5750/7794 [15:00<05:20,  6.39it/s]

Batch 5750/7794:
total_loss: 0.644080400466919, mse: 0.1747177243232727, ce: 0.46936267614364624


Epoch 4/5:  77%|███████▋  | 6000/7794 [15:39<04:44,  6.31it/s]

Batch 6000/7794:
total_loss: 0.7704215049743652, mse: 0.19552025198936462, ce: 0.574901282787323


Epoch 4/5:  80%|████████  | 6250/7794 [16:19<04:03,  6.35it/s]

Batch 6250/7794:
total_loss: 0.5285043120384216, mse: 0.1468035727739334, ce: 0.3817007541656494


Epoch 4/5:  83%|████████▎ | 6500/7794 [16:58<03:23,  6.35it/s]

Batch 6500/7794:
total_loss: 0.6759126782417297, mse: 0.16318632662296295, ce: 0.512726366519928


Epoch 4/5:  87%|████████▋ | 6750/7794 [17:38<02:44,  6.34it/s]

Batch 6750/7794:
total_loss: 0.4960393011569977, mse: 0.1361275315284729, ce: 0.3599117696285248


Epoch 4/5:  90%|████████▉ | 7000/7794 [18:17<02:05,  6.33it/s]

Batch 7000/7794:
total_loss: 0.5229132175445557, mse: 0.15684863924980164, ce: 0.3660646080970764


Epoch 4/5:  93%|█████████▎| 7250/7794 [18:56<01:25,  6.37it/s]

Batch 7250/7794:
total_loss: 0.623106062412262, mse: 0.1626323163509369, ce: 0.4604737460613251


Epoch 4/5:  96%|█████████▌| 7500/7794 [19:36<00:46,  6.39it/s]

Batch 7500/7794:
total_loss: 0.5820409059524536, mse: 0.15329857170581818, ce: 0.42874231934547424


Epoch 4/5:  99%|█████████▉| 7750/7794 [20:15<00:06,  6.40it/s]

Batch 7750/7794:
total_loss: 0.39981788396835327, mse: 0.13094404339790344, ce: 0.26887384057044983


Epoch 4/5: 100%|██████████| 7794/7794 [20:22<00:00,  6.38it/s]
Validating: 100%|██████████| 16/16 [00:01<00:00, 15.83it/s]



Epoch Summary:
Train Total Loss: 0.5945 (MSE: 0.1698, CE: 0.4247)
Val Total Loss: 0.4794 (MSE: 0.1299, CE: 0.3496)
Learning Rate: 0.000100
New best model saved with val loss: 0.4794
--------------------------------------------------


Epoch 5/5:   3%|▎         | 250/7794 [00:39<19:43,  6.37it/s]

Batch 250/7794:
total_loss: 0.7291012406349182, mse: 0.17252248525619507, ce: 0.5565787553787231


Epoch 5/5:   6%|▋         | 500/7794 [01:18<19:03,  6.38it/s]

Batch 500/7794:
total_loss: 0.6222289204597473, mse: 0.17542624473571777, ce: 0.44680267572402954


Epoch 5/5:  10%|▉         | 750/7794 [01:57<18:38,  6.30it/s]

Batch 750/7794:
total_loss: 0.4747333824634552, mse: 0.1394093632698059, ce: 0.3353240191936493


Epoch 5/5:  13%|█▎        | 1000/7794 [02:37<17:49,  6.35it/s]

Batch 1000/7794:
total_loss: 0.5932381749153137, mse: 0.17300330102443695, ce: 0.4202348589897156


Epoch 5/5:  16%|█▌        | 1250/7794 [03:16<17:21,  6.28it/s]

Batch 1250/7794:
total_loss: 0.4903102219104767, mse: 0.13319575786590576, ce: 0.3571144640445709


Epoch 5/5:  19%|█▉        | 1500/7794 [03:56<16:31,  6.35it/s]

Batch 1500/7794:
total_loss: 0.6078946590423584, mse: 0.1678474247455597, ce: 0.4400472044944763


Epoch 5/5:  22%|██▏       | 1750/7794 [04:35<16:00,  6.29it/s]

Batch 1750/7794:
total_loss: 0.49276527762413025, mse: 0.1398104727268219, ce: 0.35295480489730835


Epoch 5/5:  26%|██▌       | 2000/7794 [05:15<15:16,  6.32it/s]

Batch 2000/7794:
total_loss: 0.4763675928115845, mse: 0.1537758708000183, ce: 0.32259172201156616


Epoch 5/5:  29%|██▉       | 2250/7794 [05:54<14:37,  6.32it/s]

Batch 2250/7794:
total_loss: 0.4306081533432007, mse: 0.14019152522087097, ce: 0.2904166281223297


Epoch 5/5:  32%|███▏      | 2500/7794 [06:34<13:54,  6.35it/s]

Batch 2500/7794:
total_loss: 0.5101659297943115, mse: 0.13015982508659363, ce: 0.3800061047077179


Epoch 5/5:  35%|███▌      | 2750/7794 [07:13<13:11,  6.37it/s]

Batch 2750/7794:
total_loss: 0.5315396785736084, mse: 0.13994938135147095, ce: 0.39159026741981506


Epoch 5/5:  38%|███▊      | 3000/7794 [07:53<12:32,  6.37it/s]

Batch 3000/7794:
total_loss: 0.6057891845703125, mse: 0.15494495630264282, ce: 0.4508441984653473


Epoch 5/5:  42%|████▏     | 3250/7794 [08:32<11:49,  6.41it/s]

Batch 3250/7794:
total_loss: 0.7285109758377075, mse: 0.1814500391483307, ce: 0.5470609068870544


Epoch 5/5:  45%|████▍     | 3500/7794 [09:11<11:11,  6.39it/s]

Batch 3500/7794:
total_loss: 0.5719096660614014, mse: 0.1414783000946045, ce: 0.4304313361644745


Epoch 5/5:  48%|████▊     | 3750/7794 [09:50<10:33,  6.38it/s]

Batch 3750/7794:
total_loss: 0.4109147787094116, mse: 0.11870883405208588, ce: 0.29220595955848694


Epoch 5/5:  51%|█████▏    | 4000/7794 [10:29<09:54,  6.38it/s]

Batch 4000/7794:
total_loss: 0.45694321393966675, mse: 0.12049715965986252, ce: 0.33644604682922363


Epoch 5/5:  55%|█████▍    | 4250/7794 [11:08<09:13,  6.41it/s]

Batch 4250/7794:
total_loss: 0.5886490941047668, mse: 0.15834176540374756, ce: 0.4303073287010193


Epoch 5/5:  58%|█████▊    | 4500/7794 [11:48<08:38,  6.36it/s]

Batch 4500/7794:
total_loss: 0.7140775918960571, mse: 0.17279404401779175, ce: 0.5412835478782654


Epoch 5/5:  61%|██████    | 4750/7794 [12:27<07:55,  6.40it/s]

Batch 4750/7794:
total_loss: 0.45568132400512695, mse: 0.11866095662117004, ce: 0.3370203673839569


Epoch 5/5:  64%|██████▍   | 5000/7794 [13:06<07:18,  6.37it/s]

Batch 5000/7794:
total_loss: 0.7432637214660645, mse: 0.17397823929786682, ce: 0.56928551197052


Epoch 5/5:  67%|██████▋   | 5250/7794 [13:45<06:38,  6.39it/s]

Batch 5250/7794:
total_loss: 0.5356984734535217, mse: 0.12627416849136353, ce: 0.4094243049621582


Epoch 5/5:  71%|███████   | 5500/7794 [14:24<05:59,  6.39it/s]

Batch 5500/7794:
total_loss: 0.5065790414810181, mse: 0.14296436309814453, ce: 0.3636147081851959


Epoch 5/5:  74%|███████▍  | 5750/7794 [15:03<05:21,  6.36it/s]

Batch 5750/7794:
total_loss: 0.5288441181182861, mse: 0.1274733692407608, ce: 0.4013707637786865


Epoch 5/5:  77%|███████▋  | 6000/7794 [15:42<04:40,  6.40it/s]

Batch 6000/7794:
total_loss: 0.3331661820411682, mse: 0.10572810471057892, ce: 0.2274380773305893


Epoch 5/5:  80%|████████  | 6250/7794 [16:21<04:02,  6.37it/s]

Batch 6250/7794:
total_loss: 0.6216846704483032, mse: 0.1616356372833252, ce: 0.4600490629673004


Epoch 5/5:  83%|████████▎ | 6500/7794 [17:01<03:21,  6.41it/s]

Batch 6500/7794:
total_loss: 0.5728191137313843, mse: 0.12852302193641663, ce: 0.44429612159729004


Epoch 5/5:  87%|████████▋ | 6750/7794 [17:40<02:43,  6.38it/s]

Batch 6750/7794:
total_loss: 0.6799358129501343, mse: 0.14830438792705536, ce: 0.5316314101219177


Epoch 5/5:  90%|████████▉ | 7000/7794 [18:19<02:05,  6.33it/s]

Batch 7000/7794:
total_loss: 0.49639463424682617, mse: 0.1379700005054474, ce: 0.3584246337413788


Epoch 5/5:  93%|█████████▎| 7250/7794 [18:59<01:25,  6.35it/s]

Batch 7250/7794:
total_loss: 0.6831130385398865, mse: 0.16643311083316803, ce: 0.5166799426078796


Epoch 5/5:  96%|█████████▌| 7500/7794 [19:38<00:46,  6.32it/s]

Batch 7500/7794:
total_loss: 0.4372456669807434, mse: 0.12337194383144379, ce: 0.3138737082481384


Epoch 5/5:  99%|█████████▉| 7750/7794 [20:18<00:06,  6.32it/s]

Batch 7750/7794:
total_loss: 0.34312111139297485, mse: 0.11018188297748566, ce: 0.2329392433166504


Epoch 5/5: 100%|██████████| 7794/7794 [20:25<00:00,  6.36it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 16.07it/s]



Epoch Summary:
Train Total Loss: 0.5494 (MSE: 0.1427, CE: 0.4067)
Val Total Loss: 0.4720 (MSE: 0.1102, CE: 0.3618)
Learning Rate: 0.000100
New best model saved with val loss: 0.4720
--------------------------------------------------
