In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the /kaggle/input directory tree and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        # join the directory currently being visited with the bare file name
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/config.txt
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Val/0_2_0_13062021_184140.json
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Test/0_2_0_13062021_184319.json
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Train/4_2_0_15062021_123153.json
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Train/2_2_0_15062021_123153.json
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Train/7_2_0_15062021_123153.json
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Train/5_2_0_15062021_123153.json
/kaggle/input/2-var-dataset/2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points/Train/0_2_0_15062021_123153.json
/kaggle/input/2-v

In [2]:
import torch
import json
from torch.utils.data import Dataset
import re
import numpy as np
import random
from sympy import sympify, Symbol, sin, cos, log, exp


def generateDataStrEq(
    eq, n_points=2, n_vars=3, decimals=4, supportPoints=None, min_x=0, max_x=3
):
    """Evaluate the string equation ``eq`` at ``n_points`` input points.

    Variables are written ``x1``, ``x2``, ... in ``eq``.  For every point the
    variables ``x1`` .. ``x<n_vars>`` are replaced textually by their numeric
    values and the resulting expression is evaluated with ``eval``.

    Args:
        eq: equation as a Python-evaluable string.
        n_points: number of points to generate/evaluate.
        n_vars: number of variables to substitute.
        decimals: number of decimals x-values and the result are rounded to.
        supportPoints: optional sequence of points; ``supportPoints[p]`` is
            used for point ``p`` instead of sampling.
        min_x, max_x: sampling range.  Either two scalars, or two
            lists/tuples of equal length describing several ranges; in that
            case each coordinate picks one range at random.

    Returns:
        ``(X, Y)`` where ``X`` is the list of points and ``Y`` the list of
        float results.

    Raises:
        Whatever ``eval`` raises for the substituted expression (callers in
        this file wrap the call in ``try``/``except``).
    """
    X = []
    Y = []
    multi_range = isinstance(min_x, (list, tuple))
    # TODO: Need to make this faster
    for p in range(n_points):
        if supportPoints is None:
            if multi_range:
                x = []
                for _ in range(n_vars):
                    idx = np.random.randint(len(min_x))
                    x += list(
                        np.round(np.random.uniform(min_x[idx], max_x[idx], 1), decimals)
                    )
            else:
                x = list(np.round(np.random.uniform(min_x, max_x, n_vars), decimals))
            assert (
                len(x) != 0
            ), "For some reason, we didn't generate the points correctly!"
        else:
            x = supportPoints[p]

        def _substitute(match, point=x):
            # Replace only x1..x<n_vars>; other indices are left untouched.
            # Matching the whole index keeps "x10" from being hit by the
            # replacement for "x1".
            var_id = int(match.group(1))
            if 1 <= var_id <= n_vars:
                # Parenthesise the value: without it a negative value changes
                # operator precedence ("-2.0**2" is -4, "(-2.0)**2" is 4).
                return "({})".format(point[var_id - 1])
            return match.group(0)

        tmpEq = re.sub(r"x(\d+)", _substitute, eq)
        # NOTE(review): eval() runs arbitrary code; eq must come from a trusted source.
        y = float(np.round(eval(tmpEq), decimals))
        X.append(x)
        Y.append(y)
    return X, Y


# def processDataFiles(files):
#     text = ""
#     for f in tqdm(files):
#         with open(f, 'r') as h:
#             lines = h.read() # don't worry we won't run out of file handles
#             if lines[-1]==-1:
#                 lines = lines[:-1]
#             #text += lines #json.loads(line)
#             text = ''.join([lines,text])
#     return text


def processDataFiles(files):
    """Read every file in ``files`` and return their contents as one string.

    As in the original implementation, each file is *prepended* to what has
    been read so far, so the last file in ``files`` comes first in the result.

    Differences from the previous version:
      * empty files are skipped instead of raising ``IndexError``;
      * a newline is inserted between two files when the earlier chunk does
        not already end with one, so the last record of one file is never
        glued to the first record of the next.

    Args:
        files: iterable of paths to text files.

    Returns:
        The concatenated text ("" when there is nothing to read).
    """
    chunks = []
    for f in files:
        with open(f, "r") as h:
            content = h.read()
        if content:
            chunks.append(content)

    # keep the historical "later files first" ordering
    chunks.reverse()
    last = len(chunks) - 1
    # list-join instead of repeated string concatenation (linear, not quadratic)
    return "".join(
        chunk if (i == last or chunk.endswith("\n")) else chunk + "\n"
        for i, chunk in enumerate(chunks)
    )


def tokenize_equation(eq):
    """Break an equation string into a list of token strings.

    The alternatives below are tried in order at each position, so ``**`` wins
    over a single ``*`` and a ``-`` is always emitted as an operator token
    (the optional sign of the number patterns is never reached).  Characters
    that match none of the alternatives are skipped silently.
    """
    patterns = (
        r"\*\*",  # exponentiation
        r"exp",  # exp function
        r"[+\-*/=()]",  # operators and parentheses
        r"sin",  # sin function
        r"cos",  # cos function
        r"log",  # log function
        r"x\d+",  # variables like x1, x23, etc.
        r"C",  # constants placeholder
        r"-?\d+\.\d+",  # decimal numbers
        r"-?\d+",  # integers
        r"_",  # padding token
    )
    combined = "|".join(f"({pattern})" for pattern in patterns)

    found = []
    for hit in re.finditer(combined, eq):
        found.append(hit.group(0))
    return found


class CharDataset(Dataset):
    """Token-level dataset over JSON-encoded equation examples.

    Every entry of ``data`` is a JSON object (as a string) from which
    ``__getitem__`` reads the keys ``"X"``, ``"Y"`` and the key named by
    ``target``.  One entry becomes padded input/target token-id sequences
    plus a fixed-size tensor of support points.

    Args:
        data: list of JSON strings, one example each.
        block_size: length the token sequences are padded/truncated to.
        tokens: vocabulary.  Must contain ``"_"``, ``"<"`` and ``">"`` and,
            when ``addVars`` is set, ``":"`` and the decimal digits.
        numVars: rows reserved for x-values in the points tensor.
        numYs: rows reserved for y-values in the points tensor.
        numPoints: the points tensor has ``numPoints - 1`` columns.
        target: JSON key holding the equation string.
        addVars: prefix the token sequence with ``<number of variables>:``.
        const_range: range the constants are drawn from when augmenting.
        xRange: range the x-values are drawn from when augmenting.
        decimals: rounding passed to ``generateDataStrEq`` when augmenting.
        augment: when true and ``target == "Skeleton"``, constants and points
            are re-sampled on every access.

    NOTE(review): ``numPoints`` is used as an int (``numPoints - 1``) when the
    points tensor is built but is unpacked as a range
    (``np.random.randint(*numPoints)``) in the augmentation path, so
    ``augment=True`` cannot currently work with either form.  The intended
    semantics need to be confirmed before this is changed.
    """

    def __init__(
        self,
        data,
        block_size,
        tokens,
        numVars,
        numYs,
        numPoints,
        target="Skeleton",
        addVars=False,
        const_range=[-0.4, 0.4],
        xRange=[-3.0, 3.0],
        decimals=4,
        augment=False,
    ):

        data_size, vocab_size = len(data), len(tokens)
        print("data has %d examples, %d unique." % (data_size, vocab_size))

        # token <-> id lookup tables, ids follow the order of ``tokens``
        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.itos = {i: tok for i, tok in enumerate(tokens)}

        self.numVars = numVars
        self.numYs = numYs
        self.numPoints = numPoints

        # padding token
        self.paddingToken = "_"
        self.paddingID = self.stoi["_"]  # or another ID not already used
        self.stoi[self.paddingToken] = self.paddingID
        self.itos[self.paddingID] = self.paddingToken

        # replacement values for -inf / +inf (and nan) in the points tensor
        self.threshold = [-1000, 1000]

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data  # it should be a list of examples
        self.target = target
        self.addVars = addVars

        self.const_range = const_range
        self.xRange = xRange
        self.decimals = decimals
        self.augment = augment

    def __len__(self):
        # NOTE(review): the last entry of ``data`` is never served; presumably
        # this skips the empty string left by splitting on a trailing newline.
        return len(self.data) - 1

    def __getitem__(self, idx):
        """Return ``(inputs, outputs, points, numVars)`` for example ``idx``.

        ``inputs``/``outputs`` are long tensors of length ``block_size`` (the
        token sequence and the same sequence shifted by one), ``points`` is a
        ``(numVars + numYs, numPoints - 1)`` tensor and ``numVars`` is the
        highest variable index found in the equation, as a long tensor.
        """
        # grab an example from the data
        chunk = self.data[idx]  # sequence of tokens including x, y, eq, etc.

        try:
            chunk = json.loads(chunk)  # convert the sequence tokens to a dictionary
        except Exception as e:
            print("Couldn't convert to json: {} \n error is: {}".format(chunk, e))
            # try the previous example
            idx = max(idx - 1, 0)
            chunk = json.loads(self.data[idx])

        # find the number of variables in the equation
        printInfoCondition = random.random() < 0.0000001
        eq = chunk[self.target]
        if printInfoCondition:
            print(f"\nEquation: {eq}")
        numVars = 0
        for var_match in re.finditer(r"x[\d]+", eq):
            # int() instead of eval(): the match is digits only, and int()
            # also accepts leading zeros
            numVars = max(numVars, int(var_match.group(0).strip("x")))

        if self.target == "Skeleton" and self.augment:
            threshold = 5000
            # randomly generate the constants
            cleanEqn = ""
            for ch in eq:
                if ch == "C":
                    # generate a new random number
                    ch = "{}".format(
                        np.random.uniform(self.const_range[0], self.const_range[1])
                    )
                cleanEqn += ch

            # update the points
            nPoints = np.random.randint(
                *self.numPoints
            )  # if supportPoints is None else len(supportPoints)
            try:
                if printInfoCondition:
                    print("Org:", chunk["X"], chunk["Y"])

                X, y = generateDataStrEq(
                    cleanEqn,
                    n_points=nPoints,
                    n_vars=self.numVars,
                    decimals=self.decimals,
                    min_x=self.xRange[0],
                    max_x=self.xRange[1],
                )

                # replace out of threshold with maximum numbers
                y = [e if abs(e) < threshold else np.sign(e) * threshold for e in y]

                # check if there is nan/inf/very large numbers in the y
                conditions = (
                    (np.isnan(y).any() or np.isinf(y).any())
                    or len(y) == 0
                    or (abs(min(y)) > threshold or abs(max(y)) > threshold)
                )
                if not conditions:
                    chunk["X"], chunk["Y"] = X, y

                if printInfoCondition:
                    print("Evd:", chunk["X"], chunk["Y"])
            except Exception as e:
                # this can happen for several reasons, including but not limited to division by zero
                print(
                    "".join(
                        [
                            f"We just used the original equation and support points because of {e}. ",
                            f"The equation is {eq}, and we update the equation to {cleanEqn}",
                        ]
                    )
                )

        # encode every token in the equation to an integer
        # < is SOS, > is EOS
        eq_tokens = tokenize_equation(eq)
        if self.addVars:
            # The vocabulary is token-level, so the equation must be tokenized
            # here as well; the variable count is emitted digit by digit.
            token_seq = ["<", *str(numVars), ":", *eq_tokens, ">"]
        else:
            token_seq = ["<", *eq_tokens, ">"]
        dix = [self.stoi[tok] for tok in token_seq]

        inputs = dix[:-1]
        outputs = dix[1:]

        # add the padding to the equations
        paddingSize = max(self.block_size - len(inputs), 0)
        paddingList = [self.paddingID] * paddingSize
        inputs += paddingList
        outputs += paddingList

        # make sure it is not more than what should be
        inputs = inputs[: self.block_size]
        outputs = outputs[: self.block_size]

        points = torch.zeros(self.numVars + self.numYs, self.numPoints - 1)
        for col, xy in enumerate(zip(chunk["X"], chunk["Y"])):

            if not isinstance(xy[0], list) or not isinstance(
                xy[1], (list, float, np.float64)
            ):
                print(f"Unexpected types: {type(xy[0])}, {type(xy[1])}")
                continue  # Skip if types are incorrect

            # don't let to exceed the maximum number of points
            if col >= self.numPoints - 1:
                break

            x = xy[0]
            x = x + [0] * (max(self.numVars - len(x), 0))  # padding

            # a scalar y becomes a one-element list
            y = xy[1] if isinstance(xy[1], list) else [xy[1]]

            y = y + [0] * (max(self.numYs - len(y), 0))  # padding
            p = x + y  # because it is only one point
            p = torch.tensor(p)
            # replace nan and inf
            p = torch.nan_to_num(
                p,
                nan=self.threshold[1],
                posinf=self.threshold[1],
                neginf=self.threshold[0],
            )

            points[:, col] = p

        points = torch.nan_to_num(
            points,
            nan=self.threshold[1],
            posinf=self.threshold[1],
            neginf=self.threshold[0],
        )

        inputs = torch.tensor(inputs, dtype=torch.long)
        outputs = torch.tensor(outputs, dtype=torch.long)
        numVars = torch.tensor(numVars, dtype=torch.long)
        return inputs, outputs, points, numVars


# Relative Mean Square Error
def relativeErr(y, yHat, info=False, eps=1e-5):
    """Mean of the squared errors divided by the norm of ``y + eps``.

    Both arguments are flattened first.  When the flattened lengths differ,
    or ``y`` is empty, the fixed value 100 is returned instead.

    ``info`` no longer prints anything (the print is commented out) but still
    draws five random indices, which advances NumPy's global RNG.
    """
    predicted = np.reshape(yHat, [1, -1])[0]
    actual = np.reshape(y, [1, -1])[0]

    comparable = len(actual) > 0 and len(actual) == len(predicted)
    if not comparable:
        return np.mean(100)

    scaled_sq = ((predicted - actual)) ** 2 / np.linalg.norm(actual + eps)
    if info:
        for _ in range(5):
            np.random.randint(len(actual))
            # print("yPR,yTrue:{},{}, Err:{}".format(yHat[i], y[i], err[i]))
    return np.mean(scaled_sq)


def lossFunc(constants, eq, X, Y, eps=1e-5):
    """Average ``relativeErr`` of a skeleton with the given constants.

    Every ``C`` in ``eq`` is replaced, in order, by the entries of
    ``constants``; the resulting expression is then evaluated at each point of
    ``X`` and compared with the matching entry of ``Y``.

    Args:
        constants: values substituted for the ``C`` placeholders, in order.
        eq: skeleton string using ``x1``, ``x2``, ... and ``C``.
        X: iterable of points; a point is a sequence of values or one scalar.
        Y: target values, one per point.
        eps: unused, kept for interface compatibility.

    Returns:
        Sum of per-point errors divided by ``len(Y)``.  A point whose
        expression cannot be evaluated is scored with a prediction of 100; a
        point whose error cannot be computed adds 10.
    """
    err = 0
    eq = eq.replace("C", "{}").format(*constants)

    for x, y in zip(X, Y):
        # a bare scalar is treated as a one-variable point
        if isinstance(x, (int, float, np.number)):
            x = [x]
        # make sure no value is a tensor
        values = [v.item() if isinstance(v, torch.Tensor) else v for v in x]

        def _substitute(match, values=values):
            # Match the whole index so "x10" is not hit by the value of "x1",
            # and parenthesise so a negative value keeps its sign under "**".
            pos = int(match.group(1)) - 1
            if 0 <= pos < len(values):
                return "({})".format(values[pos])
            return match.group(0)

        eqTemp = re.sub(r"x(\d+)", _substitute, eq)
        try:
            # NOTE(review): eval() runs arbitrary code; eq must be trusted.
            yHat = eval(eqTemp)
        except Exception:
            # print("An exception occurred! EQ: {}, OR: {}".format(eqTemp, eq))
            yHat = 100
        try:
            # handle overflow
            err += relativeErr(y, yHat)  # (y-yHat)**2
        except Exception:
            # print(
            #    "An exception occurred! EQ: {}, OR: {}, y:{}-yHat:{}".format(
            #        eqTemp, eq, y, yHat
            #    )
            # )
            err += 10

    err /= len(Y)
    return err


def get_skeleton(tokens, train_dataset: "CharDataset"):
    """Decode a sequence of token ids into a skeleton string.

    The ids are mapped through ``train_dataset.itos`` and joined, surrounding
    padding is stripped, and the text before the first ``>`` is kept (or the
    text after it when the sequence starts with ``>``).  Leading ``<`` is
    removed and ``Ce`` is rewritten to ``C*e``.

    Returns an empty string when nothing is left after stripping, e.g. for a
    sequence made only of padding; previously that case raised IndexError.
    """
    decoded = "".join([train_dataset.itos[int(idx)] for idx in tokens])
    segments = decoded.strip(train_dataset.paddingToken).split(">")

    skeleton = segments[0]
    if not skeleton and len(segments) > 1:
        # sequence began with ">": fall back to what follows it
        skeleton = segments[1]

    skeleton = skeleton.strip("<").strip(">")
    skeleton = skeleton.replace("Ce", "C*e")

    return skeleton


def validate_predictions(
    skeletons, variables=["x1", "x2", "x3", "x4", "x5"], constant_symbol="C"
):
    """Score each skeleton string 0.0 if SymPy can parse it, else 1.0.

    Empty or whitespace-only skeletons are penalised without being parsed.
    The result is a float32 tensor with one entry per skeleton, placed on
    CUDA when it is available and on the CPU otherwise.
    """
    # names the parser is allowed to resolve: variables, the constant, functions
    namespace = dict((name, Symbol(name)) for name in variables)
    namespace[constant_symbol] = Symbol(constant_symbol)
    namespace.update(sin=sin, cos=cos, log=log, exp=exp)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    penalties = torch.full(
        (len(skeletons),), fill_value=0, dtype=torch.float32, device=device
    )

    for position, candidate in enumerate(skeletons):
        invalid = (not candidate) or candidate.isspace()
        if not invalid:
            try:
                sympify(
                    candidate, locals=namespace, evaluate=False, convert_xor=True
                )
            except Exception:
                # any parse failure counts as an invalid expression
                invalid = True
        if invalid:
            penalties[position] = 1.0

    return penalties


In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math


# from SymbolicGPT: https://github.com/mojivalipour/symbolicgpt/blob/master/models.py
class PointNetConfig:
    """Settings holder for the point-set encoder.

    Stores the four required sizes as attributes; any additional keyword
    argument is attached as an attribute of the same name.
    """

    def __init__(
        self,
        embeddingSize,
        numberofPoints,
        numberofVars,
        numberofYs,
        **kwargs,
    ):
        # embedding width and number of points
        self.embeddingSize, self.numberofPoints = embeddingSize, numberofPoints
        # input dimension (Xs) and output dimension (Ys)
        self.numberofVars, self.numberofYs = numberofVars, numberofYs

        for extra_name in kwargs:
            setattr(self, extra_name, kwargs[extra_name])


class tNet(nn.Module):
    """
    Point-set encoder: three 1x1 convolutions, a max over the points axis and
    two linear layers, each followed by batch norm and ReLU.

    Per the original author's note, the structure follows the PointNet paper:
    PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation by Qi et. al. 2017
    """

    def __init__(self, config):
        super().__init__()

        self.activation_func = F.relu
        self.num_units = config.embeddingSize
        width = self.num_units
        in_channels = config.numberofVars + config.numberofYs

        self.conv1 = nn.Conv1d(in_channels, width, 1)
        self.conv2 = nn.Conv1d(width, 2 * width, 1)
        self.conv3 = nn.Conv1d(2 * width, 4 * width, 1)
        self.fc1 = nn.Linear(4 * width, 2 * width)
        self.fc2 = nn.Linear(2 * width, width)

        # self.relu = nn.ReLU()

        self.input_batch_norm = nn.BatchNorm1d(in_channels)
        # self.input_layer_norm = nn.LayerNorm(config.numberofPoints)

        self.bn1 = nn.BatchNorm1d(width)
        self.bn2 = nn.BatchNorm1d(2 * width)
        self.bn3 = nn.BatchNorm1d(4 * width)
        self.bn4 = nn.BatchNorm1d(2 * width)
        self.bn5 = nn.BatchNorm1d(width)

    def forward(self, x):
        """
        :param x: [batch, #features, #points]
        :return:
            logit: [batch, embedding_size]
        """
        act = self.activation_func
        h = self.input_batch_norm(x)

        # per-point feature extraction
        for conv, norm in (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        ):
            h = act(norm(conv(h)))

        h = torch.max(h, dim=2)[0]  # global max pooling
        assert h.size(1) == 4 * self.num_units

        # projection down to the embedding size
        for linear, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            h = act(norm(linear(h)))
        # x = self.fc2(x)

        return h


# from https://github.com/juho-lee/set_transformer/blob/master/modules.py
class MAB(nn.Module):
    """Multi-head attention block (queries from ``Q``, keys/values from ``K``).

    Heads are formed by splitting the projected features along the last axis
    and stacking the pieces along the batch axis.  Layer norms are applied
    only when the module was built with ``ln=True``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        queries = self.fc_q(Q)
        keys = self.fc_k(K)
        values = self.fc_v(K)

        head_width = self.dim_V // self.num_heads

        def to_heads(t):
            # split features into heads, stack the heads on the batch axis
            return torch.cat(t.split(head_width, 2), 0)

        q_heads, k_heads, v_heads = to_heads(queries), to_heads(keys), to_heads(values)

        scores = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        weights = torch.softmax(scores, 2)

        # residual on the queries, then undo the head stacking
        attended = q_heads + weights.bmm(v_heads)
        out = torch.cat(attended.split(queries.size(0), 0), 2)

        norm0 = getattr(self, "ln0", None)
        if norm0 is not None:
            out = self.ln0(out)
        out = out + F.relu(self.fc_o(out))
        norm1 = getattr(self, "ln1", None)
        if norm1 is not None:
            out = self.ln1(out)
        return out


class SAB(nn.Module):
    """Self-attention block: a MAB whose queries and keys are the same input."""

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super().__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        attended = self.mab(X, X)
        return attended


class ISAB(nn.Module):
    """Attention through ``num_inds`` learned inducing points.

    The inducing points first attend to the input, then the input attends to
    that intermediate result.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()
        inducing = torch.Tensor(1, num_inds, dim_out)
        self.I = nn.Parameter(inducing)
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        batch = X.size(0)
        # one copy of the inducing points per batch element
        tiled = self.I.repeat(batch, 1, 1)
        summary = self.mab0(tiled, X)
        return self.mab1(X, summary)


class PMA(nn.Module):
    """Pooling by attention: ``num_seeds`` learned seed vectors attend to the input."""

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        seeds = torch.Tensor(1, num_seeds, dim)
        self.S = nn.Parameter(seeds)
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        batch = X.size(0)
        # one copy of the seeds per batch element
        tiled = self.S.repeat(batch, 1, 1)
        return self.mab(tiled, X)


# from https://github.com/juho-lee/set_transformer/blob/master/models.py
class SetTransformer(nn.Module):
    """Encoder of two ISAB layers followed by a PMA / SAB / SAB / Linear decoder.

    ``forward`` removes axis 1 of the decoder output when it has size one
    (i.e. when ``num_outputs == 1``).
    """

    def __init__(
        self,
        dim_input,
        num_outputs,
        dim_output,
        num_inds=32,
        dim_hidden=128,
        num_heads=4,
        ln=False,
    ):
        super().__init__()
        encoder_layers = [
            ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        decoder_layers = [
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            nn.Linear(dim_hidden, dim_output),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, X):
        encoded = self.enc(X)
        decoded = self.dec(encoded)
        return decoded.squeeze(1)


class NoisePredictionTransformer(nn.Module):
    """Transformer encoder applied to a noisy sequence.

    The input is summed with a learned positional embedding, an embedding of
    the timestep and the conditioning vector before entering the encoder.
    """

    def __init__(self, n_embd, max_seq_len, n_layer=6, n_head=8, max_timesteps=1000):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, n_embd))
        self.time_emb = nn.Embedding(max_timesteps, n_embd)
        layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * 4,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layer)

    def forward(self, x_t, t, condition):
        _batch, seq_len, _width = x_t.shape
        positions = self.pos_emb[:, :seq_len, :]  # [1, L, n_embd]

        step_vec = self.time_emb(t)
        if step_vec.dim() == 1:  # Scalar t case, [n_embd]
            step_vec = step_vec.unsqueeze(0)  # [1, n_embd]
        step_vec = step_vec.unsqueeze(1)  # axis 1 is broadcast over the sequence

        cond_vec = condition.unsqueeze(1)  # [B, 1, n_embd]

        summed = x_t + positions + step_vec + cond_vec
        return self.encoder(summed)


# influenced by https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
class SymbolicGaussianDiffusion(nn.Module):
    """Diffusion model over token embeddings, conditioned on a point set.

    Token ids are embedded, noised according to a linear beta schedule and
    passed to ``NoisePredictionTransformer``, whose output is treated as a
    prediction of the clean embeddings.  A decoder that shares its weight with
    the token embedding maps the prediction back to vocabulary logits.

    Args:
        tnet_config: sizes for the point-set encoder.
        vocab_size: number of tokens.
        max_seq_len: sequence length used for positions and for sampling.
        padding_idx: id ignored by the embedding and by the cross-entropy.
        max_num_vars: size of the variable-count embedding table.
        n_layer, n_head, n_embd: transformer depth, heads and width.
        timesteps: number of diffusion steps.
        beta_start, beta_end: end points of the linear beta schedule.
        set_transformer: encode points with ``SetTransformer`` when true,
            with ``tNet`` otherwise.
        ce_weight: factor applied to the cross-entropy loss.
    """

    def __init__(
        self,
        tnet_config: PointNetConfig,
        vocab_size,
        max_seq_len,
        padding_idx: int = 0,
        max_num_vars: int = 9,
        n_layer=6,
        n_head=8,
        n_embd=512,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        set_transformer=True,
        ce_weight=1.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.n_embd = n_embd
        self.timesteps = timesteps
        self.set_transformer = set_transformer
        self.ce_weight = ce_weight

        # kept for backward compatibility; sampling uses the device of its inputs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tok_emb = nn.Embedding(vocab_size, n_embd, padding_idx=self.padding_idx)
        self.vars_emb = nn.Embedding(max_num_vars, n_embd)

        # decoder shares its weight matrix with the token embedding
        self.decoder = nn.Linear(n_embd, vocab_size, bias=False)
        self.decoder.weight = self.tok_emb.weight

        if set_transformer:
            dim_input = tnet_config.numberofVars + tnet_config.numberofYs
            self.tnet = SetTransformer(
                dim_input=dim_input,
                num_outputs=1,
                dim_output=tnet_config.embeddingSize,
                num_inds=tnet_config.numberofPoints,
                num_heads=4,
            )
        else:
            self.tnet = tNet(tnet_config)

        self.model = NoisePredictionTransformer(
            n_embd, max_seq_len, n_layer, n_head, timesteps
        )

        # Noise schedule
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, timesteps))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to the level of timestep(s) ``t``.

        ``noise`` is used when given; fresh Gaussian noise is drawn only when
        it is ``None`` (previously the argument was silently overwritten).
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1)

        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise
        return x_t

    def p_mean_variance(self, x, t, t_next, condition):
        """Mean and variance of the step from level ``t`` to level ``t_next``.

        A negative ``t_next`` means "there is no further noise level": the
        model's prediction of the clean embeddings is returned with zero
        variance.
        """
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        beta_t = self.beta[t]

        x_start_pred = self.model(x, t.long(), condition)

        t_next = torch.as_tensor(t_next, device=alpha_bar_t.device)
        if bool(torch.all(t_next < 0)):
            # With alpha_bar = 1 for the clean sample, the posterior weight on
            # x becomes zero and the mean is the clean prediction itself.
            return x_start_pred, torch.zeros_like(alpha_bar_t)

        alpha_bar_t_next = self.alpha_bar[t_next]

        # NOTE(review): alpha_t/beta_t are single-step quantities, so these
        # coefficients are only exact when t_next == t - 1 (ddim_step == 1).
        coeff1 = torch.sqrt(alpha_bar_t_next) * beta_t / (1 - alpha_bar_t)
        coeff2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_next) / (1 - alpha_bar_t)
        mean = coeff1 * x_start_pred + coeff2 * x
        variance = (1 - alpha_bar_t_next) / (1 - alpha_bar_t) * beta_t
        return mean, variance

    @torch.no_grad()
    def p_sample(self, x, t, t_next, condition):
        """One reverse step; no noise is added when the target level is <= 0."""
        mean, variance = self.p_mean_variance(x, t, t_next, condition)
        if torch.all(t_next <= 0):
            return mean
        noise = torch.randn_like(x)
        return mean + torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, points, variables, train_dataset, batch_size=16, ddim_step=1):
        """Generate ``batch_size`` skeleton strings conditioned on ``points``.

        Starts from Gaussian noise, visits every ``ddim_step``-th timestep from
        the last one down, decodes the result with the tied decoder and turns
        the arg-max token ids into strings with ``get_skeleton``.
        """
        if self.set_transformer:
            points = points.transpose(1, 2)

        condition = self.tnet(points) + self.vars_emb(variables)
        # follow the tensors actually in use, not the CUDA-availability guess
        device = condition.device
        shape = (batch_size, self.max_seq_len, self.n_embd)
        x = torch.randn(shape, device=device)
        steps = torch.arange(self.timesteps - 1, -1, -1, device=device)

        for i in range(0, self.timesteps, ddim_step):
            t = steps[i]
            nxt = i + ddim_step
            if nxt < self.timesteps:
                t_next = steps[nxt]
            elif self.timesteps - 1 - i == 0:
                # Already at level 0: signal "no further level" so the step
                # returns the clean prediction.  Using 0 here made the final
                # mean roughly x_start_pred + x.
                t_next = torch.tensor(-1, device=device)
            else:
                t_next = torch.tensor(0, device=device)
            x = self.p_sample(x, t, t_next, condition)

            # Print prediction every 250 steps
            # if (i + 1) % 250 == 0:
            #    logits = self.decoder(x)  # [B, L, vocab_size]
            #    token_indices = torch.argmax(logits, dim=-1)  # [B, L]
            #    for j in range(batch_size):
            #       token_indices_j = token_indices[j]  # [L]
            #        predicted_skeleton = get_predicted_skeleton(
            #            token_indices_j, train_dataset
            #        )
            #        tqdm.write(f" sample {j}: predicted_skeleton: {predicted_skeleton}")

        logits = self.decoder(x)  # [B, L, vocab_size]
        token_indices = torch.argmax(logits, dim=-1)  # [B, L]
        predicted_skeletons = []
        for j in range(batch_size):
            token_indices_j = token_indices[j]  # [L]
            predicted_skeleton = get_skeleton(token_indices_j, train_dataset)
            predicted_skeletons.append(predicted_skeleton)
        return predicted_skeletons

    def get_expr_validation_pen(self, logits, train_dataset: CharDataset):
        """Per-example penalty from ``validate_predictions`` for the arg-max decoding of ``logits``."""
        B = logits.shape[0]
        pred_tokens = torch.argmax(logits, dim=-1)  # [B, L]

        pred_skeletons = [get_skeleton(pred_tokens[i], train_dataset) for i in range(B)]

        penalties = validate_predictions(pred_skeletons)

        return penalties

    def p_losses(
        self,
        x_start,
        points,
        tokens,
        variables,
        t,
        train_dataset: CharDataset,
    ):
        """Compute the training losses for one batch.

        Returns ``(total_loss, expr_validation_pen, ce_loss)`` where
        ``total_loss = ce_loss + expr_validation_pen * ce_loss``.  The penalty
        comes from an arg-max decoding, so it acts as a weight on the
        cross-entropy and carries no gradient of its own.
        """
        noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)

        if self.set_transformer:
            points = points.transpose(1, 2)

        condition = self.tnet(points) + self.vars_emb(variables)

        x_start_pred = self.model(x_t, t.long(), condition)

        logits = self.decoder(x_start_pred)  # [B, L, vocab_size]

        expr_validation_pen = self.get_expr_validation_pen(logits, train_dataset).mean()

        ce_loss = F.cross_entropy(
            logits.view(-1, self.vocab_size),  # [B*L, vocab_size]
            tokens.view(-1),  # [B*L]
            ignore_index=self.padding_idx,
            reduction="mean",
        )

        ce_loss = ce_loss * self.ce_weight

        total_loss = ce_loss + expr_validation_pen * ce_loss

        return total_loss, expr_validation_pen, ce_loss

    def forward(self, points, tokens, variables, t, train_dataset: CharDataset):
        """Embed ``tokens`` and return the losses from ``p_losses``."""
        token_emb = self.tok_emb(tokens)
        total_loss, expr_validation_pen, ce_loss = self.p_losses(
            token_emb, points, tokens, variables, t, train_dataset
        )
        return total_loss, expr_validation_pen, ce_loss


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple

def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    """Run one optimisation pass over ``train_loader``.

    For each batch a random timestep per example is drawn, the model's losses
    are computed and an optimiser step is taken.  Every 250th batch the
    current losses are printed.

    Returns:
        The per-batch averages ``(total loss, CE loss, expression penalty)``.
    """
    model.train()
    n_batches = len(train_loader)
    running_total, running_ce, running_pen = 0, 0, 0

    progress = tqdm.tqdm(
        enumerate(train_loader),
        total=n_batches,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    )
    for step, batch in progress:
        # the first element of each batch is not used here
        _, tokens, points, variables = batch
        points = points.to(device)
        tokens = tokens.to(device)
        variables = variables.to(device)

        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
        optimizer.zero_grad()

        loss, penalty, ce = model(points, tokens, variables, t, train_dataset)

        if (step + 1) % 250 == 0:
            print(f"Batch {step + 1}/{len(train_loader)}:")
            print(f"total_loss: {loss.item():.4f}")
            print(f"expr_validation_pen: {penalty.item():.4f}")
            print(f"ce_loss: {ce.item():.4f}")

        loss.backward()
        optimizer.step()

        running_total += loss.item()
        running_ce += ce.item()
        running_pen += penalty.item()

    return (
        running_total / len(train_loader),
        running_ce / len(train_loader),
        running_pen / len(train_loader),
    )


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    """Evaluate the model on ``val_loader`` without gradient tracking.

    Timesteps are drawn at random for every batch, so the returned numbers
    vary between calls.  ``epoch`` and ``num_epochs`` are accepted but unused.

    Returns:
        The per-batch averages ``(total loss, CE loss, expression penalty)``.
    """
    model.eval()
    running_total, running_ce, running_pen = 0, 0, 0

    with torch.no_grad():
        progress = tqdm.tqdm(val_loader, total=len(val_loader), desc="Validating")
        for batch in progress:
            # the first element of each batch is not used here
            _, tokens, points, variables = batch
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)

            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

            loss, penalty, ce = model(points, tokens, variables, t, train_dataset)

            running_total += loss.item()
            running_ce += ce.item()
            running_pen += penalty.item()

    return (
        running_total / len(val_loader),
        running_ce / len(val_loader),
        running_pen / len(val_loader),
    )


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
    path=None,
):
    """Train ``model`` on one device and keep the best checkpoint.

    Uses CUDA when available.  After every epoch the validation CE loss drives
    a ``ReduceLROnPlateau`` scheduler, and the model's state dict is written to
    ``path`` (default ``"model_best_ce.pth"``) whenever that loss improves.
    ``save_every`` is accepted but unused.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    # settings shared by both loaders; only the training loader shuffles
    loader_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=4)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    best_val_ce_loss = float("inf")
    path = path if path else "model_best_ce.pth"

    for epoch in range(num_epochs):
        train_stats = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )
        avg_train_loss, avg_train_ce_loss, avg_train_expr_pen = train_stats

        # the training dataset is passed on purpose: it supplies the vocabulary
        val_stats = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )
        avg_val_loss, avg_val_ce_loss, avg_val_expr_pen = val_stats

        scheduler.step(avg_val_ce_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(
            f"Train Total Loss: {avg_train_loss:.4f}, "
            f"Train CE Loss: {avg_train_ce_loss:.4f}, "
            f"Train Expr Validation Pen: {avg_train_expr_pen:.4f}"
        )
        print(
            f"Val Total Loss: {avg_val_loss:.4f}, "
            f"Val CE Loss: {avg_val_ce_loss:.4f}, "
            f"Val Expr Validation Pen: {avg_val_expr_pen:.4f}"
        )
        print(f"Learning Rate: {current_lr:.6f}")

        # Save model if validation CE loss improves
        improved = avg_val_ce_loss < best_val_ce_loss
        if improved:
            best_val_ce_loss = avg_val_ce_loss
            torch.save(model.state_dict(), path)
            print(f"New best model saved with val CE loss: {best_val_ce_loss:.4f}")

        print("-" * 50)

In [5]:
# Experiment configuration shared by the cells below.
# Model / optimisation settings. They are not consumed in the cells shown
# here; presumably they are passed to the model and to train_single_gpu later.
n_embd = 512
timesteps = 1000
batch_size = 64
learning_rate = 1e-4
num_epochs = 5
# Dataset settings, forwarded to CharDataset:
# block_size (padded token-sequence length)
blockSize = 32
# rows reserved for x-values / y-values in the points tensor
numVars = 2
numYs = 1
# the points tensor gets numPoints - 1 columns
numPoints = 250
# JSON key that holds the equation string
target = 'Skeleton'
# sampling ranges for constants and x-values (only read when augment=True)
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]
decimals = 8
# whether the token sequence is prefixed with the number of variables
addVars = False
# upper bound on the number of training files that are read
maxNumFiles = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Root of the Kaggle input dataset and the sub-folder that holds the
# Train/ and Val/ JSON files globbed by the cells below.
dataDir = "/kaggle/input/2-var-dataset"
dataFolder = "2Var_RandSupport_FixedLength_-3to3_-5.0to-3.0-3.0to5.0_200Points"

In [7]:
import numpy as np
import glob
import random

# Build the training set:
#   1. read up to `maxNumFiles` Train/*.json files,
#   2. derive the token vocabulary from the equation skeletons,
#   3. wrap the shuffled JSON lines in a CharDataset,
#   4. print one random sample as a sanity check.
path = '{}/{}/Train/*.json'.format(dataDir, dataFolder)
# glob returns matches in filesystem-dependent order; sort first so the
# `maxNumFiles` subset is the same on every run.
files = sorted(glob.glob(path))[:maxNumFiles]
if not files:
    raise FileNotFoundError('No training files match {}'.format(path))
text = processDataFiles(files)
text = text.split('\n')  # one JSON-encoded example per line

# Drop blank lines once, so the vocabulary scan and the dataset see exactly
# the same records (previously only a single trailing blank was removed).
trainText = [item for item in text if item.strip()]

skeletons = [json.loads(item)['Skeleton'] for item in trainText]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # make sure every digit is part of the vocabulary
# Special symbols; '>' and '_' appear as terminator and filler in the decoded
# sample printed below (roles of the others: presumably markers, TODO confirm).
special_tokens = {'_', 'T', '<', '>', ':'}
# Set union instead of list concatenation: a special symbol that also occurs
# in a skeleton must not appear twice in the vocabulary.
tokens = sorted(all_tokens | special_tokens)

# Shuffle the examples; this matters especially for the experiment that
# combines different numbers of variables.
random.shuffle(trainText)
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# Print a random sample, decoded back to text.
idx = np.random.randint(len(train_dataset))
inputs, outputs, points, variables = train_dataset[idx]
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 499035 examples, 30 unique.
id:345639
outputs:C*sin(C*x2)**3+C>___________________
variables:2


In [8]:
# Build the validation set from a single Val/*.json file, reusing the training
# vocabulary (`tokens`), then print one random sample as a sanity check.
path = '{}/{}/Val/*.json'.format(dataDir, dataFolder)
# Sort so that "the first file" is well defined (glob order is filesystem-dependent).
files = sorted(glob.glob(path))
if not files:
    raise FileNotFoundError('No validation files match {}'.format(path))
textVal = processDataFiles([files[0]])
# One JSON-encoded example per line; drop blank lines (e.g. the empty string
# left after a trailing newline), as is done for the training text.
textVal = [item for item in textVal.split('\n') if item.strip()]
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# print a random sample
idx = np.random.randint(len(val_dataset))
inputs, outputs, points, variables = val_dataset[idx]
print(points.min(), points.max())
# Decode with the validation dataset's own id-to-token table rather than the
# training dataset's.
inputs = ''.join([val_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([val_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 949 examples, 30 unique.
tensor(-2.9860) tensor(7.2492)
id:931
outputs:C*x1*x2+C*x1+C*x2+C>________________
variables:2


In [9]:
# Configuration for the point encoder, built from the shared hyper-parameters.
pconfig = PointNetConfig(
    numberofYs=numYs,
    numberofVars=numVars,
    numberofPoints=numPoints,
    embeddingSize=n_embd,
)

# Keyword arguments for the diffusion model, collected in one place so the
# architecture settings can be read (and tuned) at a glance.
diffusion_kwargs = dict(
    tnet_config=pconfig,
    # Vocabulary size and padding id come from the training dataset.
    vocab_size=train_dataset.vocab_size,
    padding_idx=train_dataset.paddingID,
    max_seq_len=blockSize,
    # NOTE(review): 9 is larger than numVars (2) — confirm this is intentional.
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    # Diffusion schedule: number of steps and the beta range.
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
)
model = SymbolicGaussianDiffusion(**diffusion_kwargs)

# Keyword arguments for the training run.
training_kwargs = dict(
    num_epochs=num_epochs,
    # NOTE(review): save_every is not referenced in the visible tail of
    # train_single_gpu — confirm it is still used.
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
    # train_single_gpu writes the best-validation-CE state_dict to this path.
    path="2_var_set_transformer_expr_val",
)
train_single_gpu(model, train_dataset, val_dataset, **training_kwargs)


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:   3%|▎         | 250/7798 [00:48<23:21,  5.39it/s]

Batch 250/7798:
total_loss: 3.1920
expr_validation_pen: 0.5469
ce_loss: 2.0635



Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or ad

Batch 500/7798:
total_loss: 1.5398
expr_validation_pen: 0.4844
ce_loss: 1.0374



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  10%|▉         | 750/7798 [02:25<22:14,  5.28it/s]

Batch 750/7798:
total_loss: 1.3775
expr_validation_pen: 0.4219
ce_loss: 0.9688



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  13%|█▎        | 1000/7798 [03:14<22:17,  5.08it/s]

Batch 1000/7798:
total_loss: 1.2789
expr_validation_pen: 0.4531
ce_loss: 0.8801



Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  16%|█▌        | 1250/7798 [04:03<20:30,  5.32it/s]

Batch 1250/7798:
total_loss: 1.2107
expr_validation_pen: 0.4375
ce_loss: 0.8422



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  19%|█▉        | 1500/7798 [04:52<20:50,  5.04it/s]

Batch 1500/7798:
total_loss: 1.0315
expr_validation_pen: 0.4375
ce_loss: 0.7175



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

  return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact)

Using non-Expr arguments in Pow is deprecated (in thi

Batch 1750/7798:
total_loss: 1.0067
expr_validation_pen: 0.4062
ce_loss: 0.7159


Epoch 1/5:  26%|██▌       | 2000/7798 [06:30<18:34,  5.20it/s]

Batch 2000/7798:
total_loss: 0.7255
expr_validation_pen: 0.4062
ce_loss: 0.5159



Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  29%|██▉       | 2250/7798 [07:19<17:40,  5.23it/s]

Batch 2250/7798:
total_loss: 1.4331
expr_validation_pen: 0.5000
ce_loss: 0.9554



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  32%|███▏      | 2500/7798 [08:08<17:39,  5.00it/s]

Batch 2500/7798:
total_loss: 0.8692
expr_validation_pen: 0.3750
ce_loss: 0.6321


Epoch 1/5:  35%|███▌      | 2750/7798 [08:56<16:12,  5.19it/s]

Batch 2750/7798:
total_loss: 1.2032
expr_validation_pen: 0.4688
ce_loss: 0.8192


Epoch 1/5:  38%|███▊      | 3000/7798 [09:45<15:28,  5.16it/s]

Batch 3000/7798:
total_loss: 0.5962
expr_validation_pen: 0.3125
ce_loss: 0.4543


Epoch 1/5:  42%|████▏     | 3250/7798 [10:34<14:48,  5.12it/s]

Batch 3250/7798:
total_loss: 0.8352
expr_validation_pen: 0.2500
ce_loss: 0.6681



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  45%|████▍     | 3500/7798 [11:22<13:31,  5.30it/s]

Batch 3500/7798:
total_loss: 0.6277
expr_validation_pen: 0.2656
ce_loss: 0.4959


Epoch 1/5:  48%|████▊     | 3750/7798 [12:12<12:47,  5.27it/s]

Batch 3750/7798:
total_loss: 1.3112
expr_validation_pen: 0.5312
ce_loss: 0.8563



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend 

Batch 4000/7798:
total_loss: 1.1108
expr_validation_pen: 0.4688
ce_loss: 0.7563


Epoch 1/5:  55%|█████▍    | 4250/7798 [13:49<11:45,  5.03it/s]

Batch 4250/7798:
total_loss: 0.5917
expr_validation_pen: 0.3750
ce_loss: 0.4303


Epoch 1/5:  58%|█████▊    | 4500/7798 [14:38<10:47,  5.09it/s]

Batch 4500/7798:
total_loss: 1.0847
expr_validation_pen: 0.4219
ce_loss: 0.7629


Epoch 1/5:  61%|██████    | 4750/7798 [15:27<09:35,  5.30it/s]

Batch 4750/7798:
total_loss: 1.2480
expr_validation_pen: 0.4219
ce_loss: 0.8777


Epoch 1/5:  64%|██████▍   | 5000/7798 [16:15<09:07,  5.11it/s]

Batch 5000/7798:
total_loss: 0.6072
expr_validation_pen: 0.3125
ce_loss: 0.4626



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  67%|██████▋   | 5250/7798 [17:04<08:09,  5.21it/s]

Batch 5250/7798:
total_loss: 0.6253
expr_validation_pen: 0.3281
ce_loss: 0.4708


Epoch 1/5:  71%|███████   | 5500/7798 [17:53<07:31,  5.09it/s]

Batch 5500/7798:
total_loss: 0.6151
expr_validation_pen: 0.3750
ce_loss: 0.4473


Epoch 1/5:  74%|███████▎  | 5750/7798 [18:43<06:45,  5.05it/s]

Batch 5750/7798:
total_loss: 0.7092
expr_validation_pen: 0.4062
ce_loss: 0.5043


Epoch 1/5:  77%|███████▋  | 6000/7798 [19:32<06:01,  4.97it/s]

Batch 6000/7798:
total_loss: 0.7251
expr_validation_pen: 0.3438
ce_loss: 0.5396


Epoch 1/5:  80%|████████  | 6250/7798 [20:21<04:56,  5.23it/s]

Batch 6250/7798:
total_loss: 0.9429
expr_validation_pen: 0.4219
ce_loss: 0.6631



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  83%|████████▎ | 6500/7798 [21:09<04:10,  5.17it/s]

Batch 6500/7798:
total_loss: 0.7568
expr_validation_pen: 0.4219
ce_loss: 0.5322



Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  87%|████████▋ | 6750/7798 [21:58<03:30,  4.97it/s]

Batch 6750/7798:
total_loss: 0.8578
expr_validation_pen: 0.3438
ce_loss: 0.6384


Epoch 1/5:  90%|████████▉ | 7000/7798 [22:47<02:29,  5.33it/s]

Batch 7000/7798:
total_loss: 1.2553
expr_validation_pen: 0.5000
ce_loss: 0.8369


Epoch 1/5:  93%|█████████▎| 7250/7798 [23:37<01:49,  5.01it/s]

Batch 7250/7798:
total_loss: 0.6969
expr_validation_pen: 0.3594
ce_loss: 0.5127



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 1/5:  96%|█████████▌| 7500/7798 [24:26<00:56,  5.28it/s]

Batch 7500/7798:
total_loss: 0.5650
expr_validation_pen: 0.2812
ce_loss: 0.4410


Epoch 1/5:  99%|█████████▉| 7750/7798 [25:15<00:09,  5.01it/s]

Batch 7750/7798:
total_loss: 0.8189
expr_validation_pen: 0.4219
ce_loss: 0.5760


Epoch 1/5: 100%|██████████| 7798/7798 [25:24<00:00,  5.12it/s]
Validating: 100%|██████████| 15/15 [00:03<00:00,  4.69it/s]


Epoch Summary:
Train Total Loss: 1.0722, Train CE Loss: 0.7440, Train Expr Validation Pen: 0.4086
Val Total Loss: 0.6154, Val CE Loss: 0.4523, Val Expr Validation Pen: 0.3481
Learning Rate: 0.000100
New best model saved with val CE loss: 0.4523
--------------------------------------------------




Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:   3%|▎         | 250/7798 [00:50<24:02,  5.23it/s]

Batch 250/7798:
total_loss: 0.6884
expr_validation_pen: 0.3750
ce_loss: 0.5007


Epoch 2/5:   6%|▋         | 500/7798 [01:39<22:56,  5.30it/s]

Batch 500/7798:
total_loss: 0.8773
expr_validation_pen: 0.4062
ce_loss: 0.6238



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  10%|▉         | 750/7798 [02:28<24:26,  4.81it/s]

Batch 750/7798:
total_loss: 0.3840
expr_validation_pen: 0.3125
ce_loss: 0.2926


Epoch 2/5:  13%|█▎        | 1000/7798 [03:17<21:59,  5.15it/s]

Batch 1000/7798:
total_loss: 0.5472
expr_validation_pen: 0.3594
ce_loss: 0.4025



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  16%|█▌        | 1250/7798 [04:06<20:44,  5.26it/s]

Batch 1250/7798:
total_loss: 0.3898
expr_validation_pen: 0.2656
ce_loss: 0.3080


Epoch 2/5:  19%|█▉        | 1500/7798 [04:55<19:40,  5.34it/s]

Batch 1500/7798:
total_loss: 0.4109
expr_validation_pen: 0.2969
ce_loss: 0.3168


Epoch 2/5:  22%|██▏       | 1750/7798 [05:45<19:01,  5.30it/s]

Batch 1750/7798:
total_loss: 0.5917
expr_validation_pen: 0.2500
ce_loss: 0.4733


Epoch 2/5:  26%|██▌       | 2000/7798 [06:34<19:18,  5.01it/s]

Batch 2000/7798:
total_loss: 0.4724
expr_validation_pen: 0.3281
ce_loss: 0.3557


Epoch 2/5:  29%|██▉       | 2250/7798 [07:24<17:36,  5.25it/s]

Batch 2250/7798:
total_loss: 0.7499
expr_validation_pen: 0.3906
ce_loss: 0.5392



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  32%|███▏      | 2500/7798 [08:13<17:16,  5.11it/s]

Batch 2500/7798:
total_loss: 0.5038
expr_validation_pen: 0.3281
ce_loss: 0.3793


Epoch 2/5:  35%|███▌      | 2750/7798 [09:02<15:48,  5.32it/s]

Batch 2750/7798:
total_loss: 0.6242
expr_validation_pen: 0.3594
ce_loss: 0.4592


Epoch 2/5:  38%|███▊      | 3000/7798 [09:51<15:45,  5.08it/s]

Batch 3000/7798:
total_loss: 0.7522
expr_validation_pen: 0.4062
ce_loss: 0.5349


Epoch 2/5:  42%|████▏     | 3250/7798 [10:40<14:57,  5.06it/s]

Batch 3250/7798:
total_loss: 0.6045
expr_validation_pen: 0.3438
ce_loss: 0.4499


Epoch 2/5:  45%|████▍     | 3500/7798 [11:29<13:42,  5.22it/s]

Batch 3500/7798:
total_loss: 0.4327
expr_validation_pen: 0.3594
ce_loss: 0.3183


Epoch 2/5:  48%|████▊     | 3750/7798 [12:19<13:23,  5.04it/s]

Batch 3750/7798:
total_loss: 0.5786
expr_validation_pen: 0.2812
ce_loss: 0.4516


Epoch 2/5:  51%|█████▏    | 4000/7798 [13:08<13:41,  4.63it/s]

Batch 4000/7798:
total_loss: 0.7573
expr_validation_pen: 0.3906
ce_loss: 0.5446


Epoch 2/5:  55%|█████▍    | 4250/7798 [13:57<11:20,  5.22it/s]

Batch 4250/7798:
total_loss: 0.7146
expr_validation_pen: 0.3125
ce_loss: 0.5445



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  58%|█████▊    | 4500/7798 [14:46<12:26,  4.42it/s]

Batch 4500/7798:
total_loss: 1.2559
expr_validation_pen: 0.3906
ce_loss: 0.9031



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend 

Batch 4750/7798:
total_loss: 0.7496
expr_validation_pen: 0.3438
ce_loss: 0.5578


Epoch 2/5:  64%|██████▍   | 5000/7798 [16:24<10:05,  4.62it/s]

Batch 5000/7798:
total_loss: 0.6637
expr_validation_pen: 0.3906
ce_loss: 0.4773



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  67%|██████▋   | 5250/7798 [17:13<07:49,  5.42it/s]

Batch 5250/7798:
total_loss: 0.9029
expr_validation_pen: 0.3438
ce_loss: 0.6720


Epoch 2/5:  71%|███████   | 5500/7798 [18:03<07:41,  4.98it/s]

Batch 5500/7798:
total_loss: 0.7390
expr_validation_pen: 0.2969
ce_loss: 0.5698


Epoch 2/5:  74%|███████▎  | 5750/7798 [18:52<06:39,  5.12it/s]

Batch 5750/7798:
total_loss: 0.3360
expr_validation_pen: 0.2812
ce_loss: 0.2622



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  77%|███████▋  | 6000/7798 [19:41<05:40,  5.28it/s]

Batch 6000/7798:
total_loss: 0.4774
expr_validation_pen: 0.3281
ce_loss: 0.3595


Epoch 2/5:  80%|████████  | 6250/7798 [20:30<05:01,  5.13it/s]

Batch 6250/7798:
total_loss: 0.5814
expr_validation_pen: 0.3594
ce_loss: 0.4277



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  83%|████████▎ | 6500/7798 [21:20<04:18,  5.02it/s]

Batch 6500/7798:
total_loss: 0.7339
expr_validation_pen: 0.3750
ce_loss: 0.5337



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  87%|████████▋ | 6750/7798 [22:09<03:28,  5.02it/s]

Batch 6750/7798:
total_loss: 0.3213
expr_validation_pen: 0.3906
ce_loss: 0.2311



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 2/5:  90%|████████▉ | 7000/7798 [22:58<02:39,  4.99it/s]

Batch 7000/7798:
total_loss: 0.5386
expr_validation_pen: 0.3594
ce_loss: 0.3962


Epoch 2/5:  93%|█████████▎| 7250/7798 [23:47<01:50,  4.98it/s]

Batch 7250/7798:
total_loss: 0.8809
expr_validation_pen: 0.3906
ce_loss: 0.6335


Epoch 2/5:  96%|█████████▌| 7500/7798 [24:37<00:58,  5.05it/s]

Batch 7500/7798:
total_loss: 0.7246
expr_validation_pen: 0.3906
ce_loss: 0.5211


Epoch 2/5:  99%|█████████▉| 7750/7798 [25:26<00:09,  5.04it/s]

Batch 7750/7798:
total_loss: 0.5667
expr_validation_pen: 0.3125
ce_loss: 0.4318


Epoch 2/5: 100%|██████████| 7798/7798 [25:36<00:00,  5.08it/s]
Validating: 100%|██████████| 15/15 [00:03<00:00,  4.62it/s]


Epoch Summary:
Train Total Loss: 0.6102, Train CE Loss: 0.4526, Train Expr Validation Pen: 0.3396
Val Total Loss: 0.4831, Val CE Loss: 0.3811, Val Expr Validation Pen: 0.2617
Learning Rate: 0.000100
New best model saved with val CE loss: 0.3811
--------------------------------------------------



Epoch 3/5:   3%|▎         | 250/7798 [00:50<23:57,  5.25it/s]

Batch 250/7798:
total_loss: 0.4398
expr_validation_pen: 0.2812
ce_loss: 0.3432


Epoch 3/5:   6%|▋         | 500/7798 [01:39<23:29,  5.18it/s]

Batch 500/7798:
total_loss: 0.8143
expr_validation_pen: 0.3906
ce_loss: 0.5856


Epoch 3/5:  10%|▉         | 750/7798 [02:29<22:27,  5.23it/s]

Batch 750/7798:
total_loss: 0.6028
expr_validation_pen: 0.3281
ce_loss: 0.4538



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  13%|█▎        | 1000/7798 [03:18<21:37,  5.24it/s]

Batch 1000/7798:
total_loss: 0.6908
expr_validation_pen: 0.3594
ce_loss: 0.5081


Epoch 3/5:  16%|█▌        | 1250/7798 [04:08<21:43,  5.02it/s]

Batch 1250/7798:
total_loss: 0.3564
expr_validation_pen: 0.1562
ce_loss: 0.3082


Epoch 3/5:  19%|█▉        | 1500/7798 [04:57<23:01,  4.56it/s]

Batch 1500/7798:
total_loss: 0.3213
expr_validation_pen: 0.1562
ce_loss: 0.2779


Epoch 3/5:  22%|██▏       | 1750/7798 [05:47<20:15,  4.97it/s]

Batch 1750/7798:
total_loss: 0.1076
expr_validation_pen: 0.1719
ce_loss: 0.0918



Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  26%|██▌       | 2000/7798 [06:36<19:30,  4.95it/s]

Batch 2000/7798:
total_loss: 0.5981
expr_validation_pen: 0.3125
ce_loss: 0.4557


Epoch 3/5:  29%|██▉       | 2250/7798 [07:25<17:30,  5.28it/s]

Batch 2250/7798:
total_loss: 0.5934
expr_validation_pen: 0.2969
ce_loss: 0.4575


Epoch 3/5:  32%|███▏      | 2500/7798 [08:15<17:18,  5.10it/s]

Batch 2500/7798:
total_loss: 0.7813
expr_validation_pen: 0.3594
ce_loss: 0.5748



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  35%|███▌      | 2750/7798 [09:04<16:35,  5.07it/s]

Batch 2750/7798:
total_loss: 0.4829
expr_validation_pen: 0.2500
ce_loss: 0.3864



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  38%|███▊      | 3000/7798 [09:54<16:07,  4.96it/s]

Batch 3000/7798:
total_loss: 0.3659
expr_validation_pen: 0.2188
ce_loss: 0.3002


Epoch 3/5:  42%|████▏     | 3250/7798 [10:43<15:44,  4.81it/s]

Batch 3250/7798:
total_loss: 0.6973
expr_validation_pen: 0.2812
ce_loss: 0.5442


Epoch 3/5:  45%|████▍     | 3500/7798 [11:33<14:09,  5.06it/s]

Batch 3500/7798:
total_loss: 0.5161
expr_validation_pen: 0.2656
ce_loss: 0.4078



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  48%|████▊     | 3750/7798 [12:23<17:01,  3.96it/s]

Batch 3750/7798:
total_loss: 0.3299
expr_validation_pen: 0.1406
ce_loss: 0.2892


Epoch 3/5:  51%|█████▏    | 4000/7798 [13:12<12:08,  5.21it/s]

Batch 4000/7798:
total_loss: 0.4132
expr_validation_pen: 0.3281
ce_loss: 0.3111



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  55%|█████▍    | 4250/7798 [14:02<11:40,  5.07it/s]

Batch 4250/7798:
total_loss: 0.6392
expr_validation_pen: 0.2969
ce_loss: 0.4929


Epoch 3/5:  58%|█████▊    | 4500/7798 [14:52<10:57,  5.02it/s]

Batch 4500/7798:
total_loss: 0.5814
expr_validation_pen: 0.3750
ce_loss: 0.4229


Epoch 3/5:  61%|██████    | 4750/7798 [15:42<09:39,  5.26it/s]

Batch 4750/7798:
total_loss: 0.5038
expr_validation_pen: 0.3281
ce_loss: 0.3793



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  64%|██████▍   | 5000/7798 [16:31<08:46,  5.31it/s]

Batch 5000/7798:
total_loss: 0.5983
expr_validation_pen: 0.2500
ce_loss: 0.4786


Epoch 3/5:  67%|██████▋   | 5250/7798 [17:21<07:56,  5.35it/s]

Batch 5250/7798:
total_loss: 0.6522
expr_validation_pen: 0.3438
ce_loss: 0.4854



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  71%|███████   | 5500/7798 [18:10<07:33,  5.07it/s]

Batch 5500/7798:
total_loss: 0.3582
expr_validation_pen: 0.2656
ce_loss: 0.2830


Epoch 3/5:  74%|███████▎  | 5750/7798 [19:00<06:35,  5.18it/s]

Batch 5750/7798:
total_loss: 0.2936
expr_validation_pen: 0.1719
ce_loss: 0.2505


Epoch 3/5:  77%|███████▋  | 6000/7798 [19:50<05:58,  5.02it/s]

Batch 6000/7798:
total_loss: 0.6464
expr_validation_pen: 0.2656
ce_loss: 0.5108


Epoch 3/5:  80%|████████  | 6250/7798 [20:40<05:08,  5.02it/s]

Batch 6250/7798:
total_loss: 0.3504
expr_validation_pen: 0.2500
ce_loss: 0.2803


Epoch 3/5:  83%|████████▎ | 6500/7798 [21:30<04:23,  4.92it/s]

Batch 6500/7798:
total_loss: 0.6918
expr_validation_pen: 0.4375
ce_loss: 0.4812



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  87%|████████▋ | 6750/7798 [22:20<03:24,  5.12it/s]

Batch 6750/7798:
total_loss: 0.4753
expr_validation_pen: 0.2969
ce_loss: 0.3665


Epoch 3/5:  90%|████████▉ | 7000/7798 [23:09<02:47,  4.76it/s]

Batch 7000/7798:
total_loss: 0.2626
expr_validation_pen: 0.2031
ce_loss: 0.2182


Epoch 3/5:  93%|█████████▎| 7250/7798 [23:59<01:48,  5.04it/s]

Batch 7250/7798:
total_loss: 0.3872
expr_validation_pen: 0.2656
ce_loss: 0.3060



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 3/5:  96%|█████████▌| 7500/7798 [24:49<01:09,  4.27it/s]

Batch 7500/7798:
total_loss: 0.5428
expr_validation_pen: 0.2031
ce_loss: 0.4512


Epoch 3/5:  99%|█████████▉| 7750/7798 [25:39<00:09,  4.93it/s]

Batch 7750/7798:
total_loss: 0.4618
expr_validation_pen: 0.2656
ce_loss: 0.3648


Epoch 3/5: 100%|██████████| 7798/7798 [25:48<00:00,  5.03it/s]
Validating: 100%|██████████| 15/15 [00:03<00:00,  4.63it/s]



Epoch Summary:
Train Total Loss: 0.4957, Train CE Loss: 0.3830, Train Expr Validation Pen: 0.2860
Val Total Loss: 0.4453, Val CE Loss: 0.3579, Val Expr Validation Pen: 0.2248
Learning Rate: 0.000100
New best model saved with val CE loss: 0.3579
--------------------------------------------------



Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

  p = self.func(*self.as_base_exp())  # in case it's unevaluated
Epoch 4/5:   3%|▎         | 250/7798 [00:51<25:14,  4.98it/s]

Batch 250/7798:
total_loss: 0.7168
expr_validation_pen: 0.2656
ce_loss: 0.5663


Epoch 4/5:   6%|▋         | 500/7798 [01:41<24:17,  5.01it/s]

Batch 500/7798:
total_loss: 0.4659
expr_validation_pen: 0.2812
ce_loss: 0.3637


Epoch 4/5:  10%|▉         | 750/7798 [02:30<23:14,  5.05it/s]

Batch 750/7798:
total_loss: 0.4176
expr_validation_pen: 0.2188
ce_loss: 0.3426



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  13%|█▎        | 1000/7798 [03:20<21:58,  5.15it/s]

Batch 1000/7798:
total_loss: 0.4787
expr_validation_pen: 0.2188
ce_loss: 0.3928



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  16%|█▌        | 1250/7798 [04:10<22:12,  4.91it/s]

Batch 1250/7798:
total_loss: 0.5793
expr_validation_pen: 0.2500
ce_loss: 0.4634


Epoch 4/5:  19%|█▉        | 1500/7798 [05:01<22:05,  4.75it/s]

Batch 1500/7798:
total_loss: 0.5535
expr_validation_pen: 0.2500
ce_loss: 0.4428


Epoch 4/5:  22%|██▏       | 1750/7798 [05:53<22:12,  4.54it/s]

Batch 1750/7798:
total_loss: 0.3700
expr_validation_pen: 0.2031
ce_loss: 0.3075



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  26%|██▌       | 2000/7798 [06:42<18:36,  5.19it/s]

Batch 2000/7798:
total_loss: 0.6055
expr_validation_pen: 0.2656
ce_loss: 0.4784



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  29%|██▉       | 2250/7798 [07:32<19:06,  4.84it/s]

Batch 2250/7798:
total_loss: 0.3005
expr_validation_pen: 0.2344
ce_loss: 0.2435


Epoch 4/5:  32%|███▏      | 2500/7798 [08:22<16:44,  5.27it/s]

Batch 2500/7798:
total_loss: 0.3494
expr_validation_pen: 0.2031
ce_loss: 0.2904


Epoch 4/5:  35%|███▌      | 2750/7798 [09:12<17:01,  4.94it/s]

Batch 2750/7798:
total_loss: 0.2415
expr_validation_pen: 0.1094
ce_loss: 0.2177


Epoch 4/5:  38%|███▊      | 3000/7798 [10:03<15:09,  5.27it/s]

Batch 3000/7798:
total_loss: 0.4310
expr_validation_pen: 0.2344
ce_loss: 0.3491


Epoch 4/5:  42%|████▏     | 3250/7798 [10:53<18:16,  4.15it/s]

Batch 3250/7798:
total_loss: 0.5196
expr_validation_pen: 0.2812
ce_loss: 0.4055


Epoch 4/5:  45%|████▍     | 3500/7798 [11:43<14:28,  4.95it/s]

Batch 3500/7798:
total_loss: 0.4414
expr_validation_pen: 0.2500
ce_loss: 0.3531



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  48%|████▊     | 3750/7798 [12:33<13:27,  5.02it/s]

Batch 3750/7798:
total_loss: 0.4209
expr_validation_pen: 0.2188
ce_loss: 0.3453


Epoch 4/5:  51%|█████▏    | 4000/7798 [13:23<11:35,  5.46it/s]

Batch 4000/7798:
total_loss: 0.3795
expr_validation_pen: 0.2656
ce_loss: 0.2999



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Pow is deprecated (in this case, one of the
arguments is of type 'Tuple').

If you really did intend to construct a power with this base, use the **
operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  55%|█████▍    | 4250/7798 [14:13<11:14,  5.26it/s]

Batch 4250/7798:
total_loss: 0.5130
expr_validation_pen: 0.2500
ce_loss: 0.4104


Epoch 4/5:  58%|█████▊    | 4500/7798 [15:03<10:49,  5.08it/s]

Batch 4500/7798:
total_loss: 0.1700
expr_validation_pen: 0.1719
ce_loss: 0.1451


Epoch 4/5:  61%|██████    | 4750/7798 [15:53<09:38,  5.27it/s]

Batch 4750/7798:
total_loss: 0.4635
expr_validation_pen: 0.2656
ce_loss: 0.3663


Epoch 4/5:  64%|██████▍   | 5000/7798 [16:43<09:52,  4.73it/s]

Batch 5000/7798:
total_loss: 0.3687
expr_validation_pen: 0.1875
ce_loss: 0.3105


Epoch 4/5:  67%|██████▋   | 5250/7798 [17:32<08:05,  5.24it/s]

Batch 5250/7798:
total_loss: 0.4373
expr_validation_pen: 0.2812
ce_loss: 0.3413


Epoch 4/5:  71%|███████   | 5500/7798 [18:23<08:34,  4.47it/s]

Batch 5500/7798:
total_loss: 0.3730
expr_validation_pen: 0.2031
ce_loss: 0.3100


Epoch 4/5:  74%|███████▎  | 5750/7798 [19:13<07:01,  4.86it/s]

Batch 5750/7798:
total_loss: 0.2899
expr_validation_pen: 0.1719
ce_loss: 0.2474


Epoch 4/5:  77%|███████▋  | 6000/7798 [20:03<05:54,  5.08it/s]

Batch 6000/7798:
total_loss: 0.5302
expr_validation_pen: 0.2812
ce_loss: 0.4138


Epoch 4/5:  80%|████████  | 6250/7798 [20:53<05:06,  5.05it/s]

Batch 6250/7798:
total_loss: 0.2587
expr_validation_pen: 0.2031
ce_loss: 0.2150


Epoch 4/5:  83%|████████▎ | 6500/7798 [21:43<04:15,  5.07it/s]

Batch 6500/7798:
total_loss: 0.4245
expr_validation_pen: 0.2188
ce_loss: 0.3483



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  87%|████████▋ | 6750/7798 [22:32<03:24,  5.13it/s]

Batch 6750/7798:
total_loss: 0.3854
expr_validation_pen: 0.2031
ce_loss: 0.3203



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  90%|████████▉ | 7000/7798 [23:23<02:37,  5.06it/s]

Batch 7000/7798:
total_loss: 0.3817
expr_validation_pen: 0.2344
ce_loss: 0.3092


Epoch 4/5:  93%|█████████▎| 7250/7798 [24:13<01:47,  5.08it/s]

Batch 7250/7798:
total_loss: 0.2924
expr_validation_pen: 0.2031
ce_loss: 0.2431


Epoch 4/5:  96%|█████████▌| 7500/7798 [25:03<00:57,  5.15it/s]

Batch 7500/7798:
total_loss: 0.4519
expr_validation_pen: 0.2188
ce_loss: 0.3708



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 4/5:  99%|█████████▉| 7750/7798 [25:53<00:09,  5.15it/s]

Batch 7750/7798:
total_loss: 0.2158
expr_validation_pen: 0.2188
ce_loss: 0.1771


Epoch 4/5: 100%|██████████| 7798/7798 [26:02<00:00,  4.99it/s]
Validating: 100%|██████████| 15/15 [00:03<00:00,  4.14it/s]


Epoch Summary:
Train Total Loss: 0.4199, Train CE Loss: 0.3338, Train Expr Validation Pen: 0.2498
Val Total Loss: 0.3653, Val CE Loss: 0.3004, Val Expr Validation Pen: 0.2009
Learning Rate: 0.000100
New best model saved with val CE loss: 0.3004
--------------------------------------------------




Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:   3%|▎         | 250/7798 [00:50<25:18,  4.97it/s]

Batch 250/7798:
total_loss: 0.2995
expr_validation_pen: 0.2188
ce_loss: 0.2458


Epoch 5/5:   6%|▋         | 500/7798 [01:40<23:18,  5.22it/s]

Batch 500/7798:
total_loss: 0.2968
expr_validation_pen: 0.1562
ce_loss: 0.2567


Epoch 5/5:  10%|▉         | 750/7798 [02:30<23:01,  5.10it/s]

Batch 750/7798:
total_loss: 0.2752
expr_validation_pen: 0.1562
ce_loss: 0.2380


Epoch 5/5:  13%|█▎        | 1000/7798 [03:20<21:40,  5.23it/s]

Batch 1000/7798:
total_loss: 0.5878
expr_validation_pen: 0.2812
ce_loss: 0.4588


Epoch 5/5:  16%|█▌        | 1250/7798 [04:10<21:50,  5.00it/s]

Batch 1250/7798:
total_loss: 0.3355
expr_validation_pen: 0.2031
ce_loss: 0.2788


Epoch 5/5:  19%|█▉        | 1500/7798 [05:01<20:57,  5.01it/s]

Batch 1500/7798:
total_loss: 0.5042
expr_validation_pen: 0.2656
ce_loss: 0.3984


Epoch 5/5:  22%|██▏       | 1750/7798 [05:50<19:48,  5.09it/s]

Batch 1750/7798:
total_loss: 0.4717
expr_validation_pen: 0.3125
ce_loss: 0.3594


Epoch 5/5:  26%|██▌       | 2000/7798 [06:41<19:08,  5.05it/s]

Batch 2000/7798:
total_loss: 0.3844
expr_validation_pen: 0.1875
ce_loss: 0.3237



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:  29%|██▉       | 2250/7798 [07:31<18:47,  4.92it/s]

Batch 2250/7798:
total_loss: 0.5847
expr_validation_pen: 0.3750
ce_loss: 0.4253



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:  32%|███▏      | 2500/7798 [08:21<17:51,  4.94it/s]

Batch 2500/7798:
total_loss: 0.2699
expr_validation_pen: 0.1875
ce_loss: 0.2272


Epoch 5/5:  35%|███▌      | 2750/7798 [09:11<16:21,  5.14it/s]

Batch 2750/7798:
total_loss: 0.1672
expr_validation_pen: 0.1562
ce_loss: 0.1446


Epoch 5/5:  38%|███▊      | 3000/7798 [10:02<16:18,  4.91it/s]

Batch 3000/7798:
total_loss: 0.2309
expr_validation_pen: 0.2812
ce_loss: 0.1803



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:  42%|████▏     | 3250/7798 [10:52<14:02,  5.40it/s]

Batch 3250/7798:
total_loss: 0.4022
expr_validation_pen: 0.2344
ce_loss: 0.3258


Epoch 5/5:  45%|████▍     | 3500/7798 [11:42<14:20,  4.99it/s]

Batch 3500/7798:
total_loss: 0.4963
expr_validation_pen: 0.2656
ce_loss: 0.3921



Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.


Using non-Expr arguments in Mul is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:  48%|████▊     | 3750/7798 [12:32<13:07,  5.14it/s]

Batch 3750/7798:
total_loss: 0.3743
expr_validation_pen: 0.1719
ce_loss: 0.3194


Epoch 5/5:  51%|█████▏    | 4000/7798 [13:22<12:17,  5.15it/s]

Batch 4000/7798:
total_loss: 0.5034
expr_validation_pen: 0.2969
ce_loss: 0.3882



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:  55%|█████▍    | 4250/7798 [14:12<11:44,  5.04it/s]

Batch 4250/7798:
total_loss: 0.2546
expr_validation_pen: 0.1250
ce_loss: 0.2263


Epoch 5/5:  58%|█████▊    | 4500/7798 [15:02<10:03,  5.47it/s]

Batch 4500/7798:
total_loss: 0.1596
expr_validation_pen: 0.1719
ce_loss: 0.1362


Epoch 5/5:  61%|██████    | 4750/7798 [15:52<10:06,  5.02it/s]

Batch 4750/7798:
total_loss: 0.4275
expr_validation_pen: 0.2031
ce_loss: 0.3553


Epoch 5/5:  64%|██████▍   | 5000/7798 [16:43<09:43,  4.79it/s]

Batch 5000/7798:
total_loss: 0.3838
expr_validation_pen: 0.1875
ce_loss: 0.3232


Epoch 5/5:  67%|██████▋   | 5250/7798 [17:32<08:21,  5.08it/s]

Batch 5250/7798:
total_loss: 0.2399
expr_validation_pen: 0.1562
ce_loss: 0.2075


Epoch 5/5:  71%|███████   | 5500/7798 [18:23<07:38,  5.02it/s]

Batch 5500/7798:
total_loss: 0.3289
expr_validation_pen: 0.1719
ce_loss: 0.2807


Epoch 5/5:  74%|███████▎  | 5750/7798 [19:13<06:25,  5.31it/s]

Batch 5750/7798:
total_loss: 0.4600
expr_validation_pen: 0.2031
ce_loss: 0.3824


Epoch 5/5:  77%|███████▋  | 6000/7798 [20:03<06:01,  4.97it/s]

Batch 6000/7798:
total_loss: 0.1993
expr_validation_pen: 0.1406
ce_loss: 0.1747


Epoch 5/5:  80%|████████  | 6250/7798 [20:54<05:01,  5.13it/s]

Batch 6250/7798:
total_loss: 0.2001
expr_validation_pen: 0.1406
ce_loss: 0.1754



Using non-Expr arguments in Add is deprecated (in this case, one of
the arguments has type 'Tuple').

If you really did intend to use a multiplication or addition operation with
this object, use the * or + operator instead.

See https://docs.sympy.org/latest/explanation/active-deprecations.html#non-expr-args-deprecated
for details.

This has been deprecated since SymPy version 1.7. It
will be removed in a future version of SymPy.

Epoch 5/5:  83%|████████▎ | 6500/7798 [21:44<04:26,  4.86it/s]

Batch 6500/7798:
total_loss: 0.4149
expr_validation_pen: 0.2656
ce_loss: 0.3279


Epoch 5/5:  87%|████████▋ | 6750/7798 [22:34<03:27,  5.04it/s]

Batch 6750/7798:
total_loss: 0.1609
expr_validation_pen: 0.1719
ce_loss: 0.1373


Epoch 5/5:  90%|████████▉ | 7000/7798 [23:24<02:34,  5.17it/s]

Batch 7000/7798:
total_loss: 0.3070
expr_validation_pen: 0.2031
ce_loss: 0.2552


Epoch 5/5:  93%|█████████▎| 7250/7798 [24:14<01:50,  4.96it/s]

Batch 7250/7798:
total_loss: 0.1779
expr_validation_pen: 0.1562
ce_loss: 0.1538


Epoch 5/5:  96%|█████████▌| 7500/7798 [25:05<01:01,  4.85it/s]

Batch 7500/7798:
total_loss: 0.3742
expr_validation_pen: 0.2344
ce_loss: 0.3032


Epoch 5/5:  99%|█████████▉| 7750/7798 [25:55<00:11,  4.02it/s]

Batch 7750/7798:
total_loss: 0.4646
expr_validation_pen: 0.2031
ce_loss: 0.3861


Epoch 5/5: 100%|██████████| 7798/7798 [26:04<00:00,  4.98it/s]
Validating: 100%|██████████| 15/15 [00:03<00:00,  4.13it/s]


Epoch Summary:
Train Total Loss: 0.3613, Train CE Loss: 0.2940, Train Expr Validation Pen: 0.2209
Val Total Loss: 0.3230, Val CE Loss: 0.2672, Val Expr Validation Pen: 0.1954
Learning Rate: 0.000100
New best model saved with val CE loss: 0.2672
--------------------------------------------------



