In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory recursively and print the full path of every
# file found, so the dataset/model paths used further down can be checked.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/xye_1var/pytorch/default/1/XYE_1Var.pt
/kaggle/input/symbolic_diffusion_initial/pytorch/default/1/symbolic_diffusion_model.pth
/kaggle/input/1-var-dataset/1_var_test.json
/kaggle/input/1-var-dataset/1_var_val.json
/kaggle/input/1-var-dataset/1_var_train.json
/kaggle/input/xye_9var/pytorch/default/1/XYE_9Var.pt


In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import glob
import json
from torch.utils.data import Dataset
import re
import numpy as np
import tqdm
import random
from math import exp, sin, cos


def generateDataStrEq(
    eq, n_points=2, n_vars=3, decimals=4, supportPoints=None, min_x=0, max_x=3
):
    """Evaluate a string equation on random or caller-supplied support points.

    Args:
        eq: Python expression using the variables ``x1`` ... ``x<n_vars>``.
        n_points: number of points to evaluate.
        n_vars: number of variables substituted into ``eq``.
        decimals: rounding applied to the sampled inputs and to the outputs.
        supportPoints: optional sequence of points; when given, point ``p`` is
            ``supportPoints[p]`` and nothing is sampled.
        min_x, max_x: sampling bounds. Scalars give one interval shared by all
            variables; lists give several candidate intervals, one of which is
            drawn at random for every variable.

    Returns:
        ``(X, Y)`` where ``X`` is the list of points and ``Y`` the list of
        float values of ``eq``. A point whose value is complex (e.g. a
        negative base raised to a fractional power) yields ``nan``.

    Raises:
        Whatever evaluating the substituted expression raises (``NameError``,
        ``ZeroDivisionError``, ``OverflowError``, ...).
    """
    X = []
    Y = []
    # TODO: Need to make this faster
    for p in range(n_points):
        if supportPoints is None:
            if isinstance(min_x, list):
                x = []
                for _ in range(n_vars):
                    idx = np.random.randint(len(min_x))
                    x += list(
                        np.round(np.random.uniform(min_x[idx], max_x[idx], 1), decimals)
                    )
            else:
                x = list(np.round(np.random.uniform(min_x, max_x, n_vars), decimals))
            assert (
                len(x) != 0
            ), "For some reason, we didn't generate the points correctly!"
        else:
            x = supportPoints[p]

        def substitute(match, point=x):
            # Match whole variable names so "x1" never rewrites part of "x10",
            # and parenthesise the value: "-2.0**2" is -(2.0**2) in Python,
            # whereas the intended value is (-2.0)**2.
            var_id = int(match.group(1))
            if 1 <= var_id <= n_vars:
                return "(" + str(point[var_id - 1]) + ")"
            return match.group(0)

        tmpEq = re.sub(r"x(\d+)", substitute, eq)
        # NOTE: eval() executes arbitrary code; only use with trusted equations.
        value = eval(tmpEq)
        if isinstance(value, complex):
            # float() would silently drop the imaginary part; report "undefined".
            value = float("nan")
        y = float(np.round(value, decimals))
        X.append(x)
        Y.append(y)
    return X, Y


# def processDataFiles(files):
#     text = ""
#     for f in tqdm(files):
#         with open(f, 'r') as h:
#             lines = h.read() # don't worry we won't run out of file handles
#             if lines[-1]==-1:
#                 lines = lines[:-1]
#             #text += lines #json.loads(line)
#             text = ''.join([lines,text])
#     return text


def processDataFiles(files):
    """Read every file in ``files`` and return their contents as one string.

    Args:
        files: iterable of paths to text files.

    Returns:
        The concatenated file contents. As in the original implementation,
        which prepended each newly read file, the contents appear in the
        reverse of the order in which ``files`` lists them. Empty files
        contribute nothing; an empty ``files`` yields ``""``.
    """
    contents = []
    for f in files:
        with open(f, "r") as h:
            # The former `lines[-1] == -1` check compared a character with an
            # int (never true) and raised IndexError on an empty file, so it
            # has been removed; the text is kept exactly as read.
            contents.append(h.read())
    # Join once instead of rebuilding the accumulated string per file.
    return "".join(reversed(contents))


def tokenize_equation(eq):
    """Split an equation string into a list of tokens.

    The alternatives below are tried in order at every position, so ``**`` is
    recognised before ``*`` and a leading ``-`` is emitted as an operator
    token rather than as part of a number. Characters that match none of the
    alternatives are skipped.
    """
    alternatives = (
        r"\*\*",  # exponentiation
        r"exp",  # exp function
        r"[+\-*/=()]",  # operators and parentheses
        r"sin",  # sin function
        r"cos",  # cos function
        r"log",  # log function
        r"x\d+",  # variables like x1, x23, etc.
        r"C",  # constants placeholder
        r"-?\d+\.\d+",  # decimal numbers
        r"-?\d+",  # integers
        r"_",  # padding token
    )
    combined = "|".join("(" + alternative + ")" for alternative in alternatives)

    found_tokens = []
    for hit in re.finditer(combined, eq):
        found_tokens.append(hit.group(0))
    return found_tokens


class CharDataset(Dataset):
    """Dataset turning JSON-encoded equation records into model-ready tensors.

    Every element of ``data`` is a JSON string holding the equation text under
    the ``target`` key plus the support points ``"X"`` and their values
    ``"Y"``. Indexing returns the token ids of the equation (shifted for
    next-token prediction), the support points as a matrix and the number of
    variables the equation references.
    """

    def __init__(
        self,
        data,
        block_size,
        tokens,
        numVars,
        numYs,
        numPoints,
        target="Skeleton",
        addVars=False,
        const_range=[-0.4, 0.4],
        xRange=[-3.0, 3.0],
        decimals=4,
        augment=False,
    ):
        """Store the vocabulary and the encoding / augmentation settings.

        Args:
            data: list of JSON strings, one example each.
            block_size: fixed length of the returned token sequences.
            tokens: vocabulary; must contain ``"_"``, ``"<"`` and ``">"`` (and
                ``":"`` plus the digits when ``addVars`` is true).
            numVars: number of rows reserved for x-coordinates in ``points``.
            numYs: number of rows reserved for y-values in ``points``.
            numPoints: either an int ``n`` (``points`` gets ``n - 1`` columns)
                or a ``(low, high)`` pair (``points`` gets ``high - 1``
                columns and augmentation draws the point count from
                ``[low, high)``).
            target: key of the equation text inside each example.
            addVars: prefix the token sequence with ``<numVars>:``.
            const_range: interval from which augmented constants are drawn.
            xRange: interval from which augmented x-values are drawn.
            decimals: rounding used when regenerating points.
            augment: regenerate constants and points on every access
                (only when ``target == "Skeleton"``).

        Raises:
            KeyError: if ``"_"`` is missing from ``tokens``.
        """
        data_size, vocab_size = len(data), len(tokens)
        print("data has %d examples, %d unique." % (data_size, vocab_size))

        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.itos = {i: tok for i, tok in enumerate(tokens)}

        self.numVars = numVars
        self.numYs = numYs
        self.numPoints = numPoints

        # padding token
        self.paddingToken = "_"
        self.paddingID = self.stoi[self.paddingToken]

        # replacement values for -inf / (+inf, nan) in the points matrix
        self.threshold = [-1000, 1000]

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data  # it should be a list of examples
        self.target = target
        self.addVars = addVars

        # The list defaults are only ever read, never mutated.
        self.const_range = const_range
        self.xRange = xRange
        self.decimals = decimals
        self.augment = augment

    def __len__(self):
        # NOTE: the last element of ``data`` is deliberately excluded; callers
        # that split raw text on newlines leave a trailing empty string there.
        # If that empty entry was already removed, the last real example is
        # not served.
        return len(self.data) - 1

    def _pointBounds(self):
        """Return ``(low, high)`` describing the allowed number of points.

        An int ``n`` is treated as the degenerate range ``(n - 1, n)`` so that
        both the points matrix (``high - 1`` columns) and augmentation
        (``randint(low, high)``) work for ints and for ``(low, high)`` pairs.
        """
        if isinstance(self.numPoints, (int, np.integer)):
            return int(self.numPoints) - 1, int(self.numPoints)
        low, high = self.numPoints
        return int(low), int(high)

    def _loadExample(self, idx):
        """Parse example ``idx``; on failure fall back to the previous one."""
        chunk = self.data[idx]  # sequence of tokens including x, y, eq, etc.
        try:
            return json.loads(chunk)  # convert the sequence tokens to a dictionary
        except Exception as e:
            print("Couldn't convert to json: {} \n error is: {}".format(chunk, e))
            # try the previous example (a second failure propagates)
            return json.loads(self.data[max(idx - 1, 0)])

    def _augmentPoints(self, chunk, eq, printInfoCondition):
        """Replace ``chunk["X"]``/``chunk["Y"]`` with freshly generated points.

        Every ``C`` in ``eq`` is replaced by a new random constant and the
        resulting equation is evaluated on random inputs. The original points
        are kept when evaluation fails or produces nan/inf values.
        """
        threshold = 5000
        # randomly generate the constants; parenthesised so that a negative
        # constant keeps its sign under "**" (-0.3**2 is -(0.3**2) in Python)
        pieces = []
        for char in eq:
            if char == "C":
                char = "({})".format(
                    np.random.uniform(self.const_range[0], self.const_range[1])
                )
            pieces.append(char)
        cleanEqn = "".join(pieces)

        # update the points
        low, high = self._pointBounds()
        nPoints = np.random.randint(low, high)
        try:
            if printInfoCondition:
                print("Org:", chunk["X"], chunk["Y"])

            X, y = generateDataStrEq(
                cleanEqn,
                n_points=nPoints,
                n_vars=self.numVars,
                decimals=self.decimals,
                min_x=self.xRange[0],
                max_x=self.xRange[1],
            )

            # replace out of threshold with maximum numbers
            y = [e if abs(e) < threshold else np.sign(e) * threshold for e in y]

            # check if there is nan/inf/very large numbers in the y
            conditions = (
                (np.isnan(y).any() or np.isinf(y).any())
                or len(y) == 0
                or (abs(min(y)) > threshold or abs(max(y)) > threshold)
            )
            if not conditions:
                chunk["X"], chunk["Y"] = X, y

            if printInfoCondition:
                print("Evd:", chunk["X"], chunk["Y"])
        except Exception as e:
            # for different reason this might happend including but not limited to division by zero
            print(
                "".join(
                    [
                        f"We just used the original equation and support points because of {e}. ",
                        f"The equation is {eq}, and we update the equation to {cleanEqn}",
                    ]
                )
            )

    def __getitem__(self, idx):
        """Return ``(inputs, outputs, points, numVars)`` for example ``idx``.

        ``inputs``/``outputs`` are long tensors of length ``block_size`` (the
        token sequence without its last / first token, padded with the padding
        id). ``points`` has ``numVars + numYs`` rows and one column per support
        point. ``numVars`` is the largest variable index found in the equation.
        """
        # grab an example from the data
        chunk = self._loadExample(idx)

        # find the number of variables in the equation
        printInfoCondition = random.random() < 0.0000001
        eq = chunk[self.target]
        if printInfoCondition:
            print(f"\nEquation: {eq}")
        numVars = max(
            (int(match.group(1)) for match in re.finditer(r"x(\d+)", eq)), default=0
        )

        if self.target == "Skeleton" and self.augment:
            self._augmentPoints(chunk, eq, printInfoCondition)

        # encode the equation token by token; < is SOS, > is EOS.
        # The vocabulary is token-level ("x1", "log", "**", ...), so the
        # optional "<numVars>:" prefix is added to the token sequence instead
        # of encoding the raw string character by character.
        eq_tokens = tokenize_equation(eq)
        if self.addVars:
            token_seq = ["<", *str(numVars), ":", *eq_tokens, ">"]
        else:
            token_seq = ["<", *eq_tokens, ">"]
        dix = [self.stoi[tok] for tok in token_seq]

        inputs = dix[:-1]
        outputs = dix[1:]

        # add the padding to the equations
        paddingSize = max(self.block_size - len(inputs), 0)
        paddingList = [self.paddingID] * paddingSize
        inputs += paddingList
        outputs += paddingList

        # make sure it is not more than what should be
        inputs = inputs[: self.block_size]
        outputs = outputs[: self.block_size]

        _, high = self._pointBounds()
        maxPoints = high - 1
        points = torch.zeros(self.numVars + self.numYs, maxPoints)
        for col, (xs, ys) in enumerate(zip(chunk["X"], chunk["Y"])):

            if not isinstance(xs, list) or not isinstance(
                ys, (list, float, np.float64)
            ):
                print(f"Unexpected types: {type(xs)}, {type(ys)}")
                continue  # Skip if types are incorrect

            # don't let to exceed the maximum number of points
            if col >= maxPoints:
                break

            x = xs + [0] * (max(self.numVars - len(xs), 0))  # padding
            y = ys if isinstance(ys, list) else [ys]
            y = y + [0] * (max(self.numYs - len(y), 0))  # padding

            # one column per point: x-coordinates first, then y-values
            points[:, col] = torch.tensor(x + y)

        # replace nan and inf
        points = torch.nan_to_num(
            points,
            nan=self.threshold[1],
            posinf=self.threshold[1],
            neginf=self.threshold[0],
        )

        inputs = torch.tensor(inputs, dtype=torch.long)
        outputs = torch.tensor(outputs, dtype=torch.long)
        numVars = torch.tensor(numVars, dtype=torch.long)
        return inputs, outputs, points, numVars


# Relative Mean Square Error
def relativeErr(y, yHat, info=False, eps=1e-5):
    """Mean squared error of ``yHat`` scaled by the norm of ``y + eps``.

    Both arguments are flattened first. When they are empty or differ in
    length, the fixed penalty 100 is returned instead. With ``info`` set,
    five randomly chosen entries are printed.
    """
    yHat = np.reshape(yHat, [1, -1])[0]
    y = np.reshape(y, [1, -1])[0]

    comparable = len(y) > 0 and len(y) == len(yHat)
    if not comparable:
        return np.mean(100)

    err = ((yHat - y)) ** 2 / np.linalg.norm(y + eps)
    if info:
        for _ in range(5):
            pick = np.random.randint(len(y))
            print("yPR,yTrue:{},{}, Err:{}".format(yHat[pick], y[pick], err[pick]))
    return np.mean(err)


def lossFunc(constants, eq, X, Y, eps=1e-5):
    """Average relative error of a skeleton equation for given constants.

    Args:
        constants: values substituted, in order, for each ``C`` in ``eq``.
        eq: skeleton expression using ``C`` placeholders and ``x1``, ``x2``...
        X: sequence of points; each point is a sequence of coordinates (list,
            array or tensor) or a single scalar for one-variable data.
        Y: target value for every point.
        eps: unused; kept so existing callers keep working.

    Returns:
        The summed per-point ``relativeErr`` divided by ``len(Y)``. A point
        whose expression cannot be evaluated, or evaluates to a complex
        number, is scored with the prediction 100; a failing error
        computation adds the penalty 10.

    Raises:
        IndexError: if ``constants`` has fewer entries than ``eq`` has ``C``.
        ZeroDivisionError: if ``Y`` is empty.
    """
    err = 0
    # Parenthesise the constants: "-3.0**2" is -(3.0**2) in Python, while the
    # skeleton "C**2" means (-3.0)**2.
    eq = eq.replace("C", "({})").format(*constants)

    for x, y in zip(X, Y):
        # a bare scalar is a one-variable point
        is_scalar_tensor = isinstance(x, torch.Tensor) and x.dim() == 0
        if is_scalar_tensor or isinstance(x, (int, float, np.number)):
            x = [x]

        def substitute(match, point=x):
            # Whole-name matching keeps "x1" from rewriting part of "x10";
            # parentheses keep negative coordinates correct under "**".
            var_id = int(match.group(1))
            if 1 <= var_id <= len(point):
                value = point[var_id - 1]
                # make sure the value is not a tensor
                if isinstance(value, torch.Tensor):
                    value = value.item()
                return "(" + str(value) + ")"
            return match.group(0)

        eqTemp = re.sub(r"x(\d+)", substitute, eq)
        try:
            # NOTE: eval() executes arbitrary code; only use with trusted equations.
            yHat = eval(eqTemp)
            if isinstance(yHat, complex):
                # not a real number -> treat like a failed evaluation
                yHat = 100
        except Exception:
            yHat = 100
        try:
            # handle overflow
            err += relativeErr(y, yHat)  # (y-yHat)**2
        except Exception:
            err += 10

    err /= len(Y)
    return err


def get_predicted_skeleton(generated_tokens, train_dataset: CharDataset):
    """Decode a 1-D tensor of token ids into a skeleton equation string.

    Args:
        generated_tokens: tensor of token ids for one sequence.
        train_dataset: object providing ``itos`` (id -> token) and
            ``paddingToken``.

    Returns:
        The text before the first ``">"`` (or, when the sequence starts with
        ``">"``, the text up to the next one), without surrounding ``"<"`` /
        ``">"`` and with ``"Ce"`` expanded to ``"C*e"``. A sequence that
        decodes to nothing but padding yields ``""``.
    """
    predicted_tokens = generated_tokens.cpu().numpy()
    predicted = "".join([train_dataset.itos[int(idx)] for idx in predicted_tokens])
    segments = predicted.strip(train_dataset.paddingToken).split(">")
    if segments[0]:
        predicted = segments[0]
    elif len(segments) > 1:
        predicted = segments[1]
    else:
        # nothing but padding was generated; previously an IndexError
        predicted = ""
    predicted = predicted.strip("<").strip(">")
    predicted = predicted.replace("Ce", "C*e")

    return predicted


In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm

# from SymbolicGPT: https://github.com/mojivalipour/symbolicgpt/blob/master/models.py
class PointNetConfig:
    """Settings container for the point-set encoder.

    The named arguments and any extra keyword arguments are stored verbatim
    as attributes of the same name.
    """

    def __init__(
        self,
        embeddingSize,
        numberofPoints,
        numberofVars,
        numberofYs,
        method="GPT",
        varibleEmbedding="NOT_VAR",
        **kwargs,
    ):
        named_settings = {
            "embeddingSize": embeddingSize,
            "numberofPoints": numberofPoints,  # number of points
            "numberofVars": numberofVars,  # input dimension (Xs)
            "numberofYs": numberofYs,  # output dimension (Ys)
            "method": method,
            "varibleEmbedding": varibleEmbedding,
        }
        # named settings first, then any additional options
        for attribute, value in {**named_settings, **kwargs}.items():
            setattr(self, attribute, value)


class tNet(nn.Module):
    """
    Point-set encoder following the PointNet structure of
    "PointNet: Deep Learning on Point Sets for 3D Classification and
    Segmentation" (Qi et al., 2017): per-point 1x1 convolutions, a max over
    the point axis, then two fully connected layers.
    """

    def __init__(self, config):
        super().__init__()

        self.activation_func = F.relu
        self.num_units = config.embeddingSize

        in_channels = config.numberofVars + config.numberofYs
        width = self.num_units

        # per-point feature extractors (kernel size 1 -> applied point-wise)
        self.conv1 = nn.Conv1d(in_channels, width, 1)
        self.conv2 = nn.Conv1d(width, 2 * width, 1)
        self.conv3 = nn.Conv1d(2 * width, 4 * width, 1)
        # layers applied after pooling over the points
        self.fc1 = nn.Linear(4 * width, 2 * width)
        self.fc2 = nn.Linear(2 * width, width)

        self.input_batch_norm = nn.BatchNorm1d(in_channels)

        self.bn1 = nn.BatchNorm1d(width)
        self.bn2 = nn.BatchNorm1d(2 * width)
        self.bn3 = nn.BatchNorm1d(4 * width)
        self.bn4 = nn.BatchNorm1d(2 * width)
        self.bn5 = nn.BatchNorm1d(width)

    def forward(self, x):
        """
        :param x: [batch, #features, #points]
        :return:
            logit: [batch, embedding_size]
        """
        activate = self.activation_func

        features = self.input_batch_norm(x)
        point_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in point_stages:
            features = activate(norm(conv(features)))

        pooled = features.max(dim=2).values  # global max pooling
        assert pooled.size(1) == 4 * self.num_units

        for dense, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            pooled = activate(norm(dense(pooled)))

        return pooled


class NoisePredictionTransformer(nn.Module):
    """Transformer encoder applied to a noisy embedding sequence.

    A learned positional embedding, a timestep embedding and a conditioning
    vector are added to the input sequence before it is passed through the
    encoder stack.
    """

    def __init__(self, n_embd, max_seq_len, n_layer=6, n_head=8, max_timesteps=1000):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, n_embd))
        self.time_emb = nn.Embedding(max_timesteps, n_embd)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=n_embd * 4,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=n_layer,
        )

    def forward(self, x_t, t, condition):
        """Encode ``x_t`` ([B, L, n_embd]) for timestep(s) ``t`` and ``condition`` ([B, n_embd])."""
        _, seq_len, _ = x_t.shape
        positions = self.pos_emb[:, :seq_len, :]  # [1, L, n_embd]

        step_vec = self.time_emb(t)
        if step_vec.dim() == 1:  # scalar t -> [n_embd]
            step_vec = step_vec[None, :]  # [1, n_embd]
        step_vec = step_vec[:, None, :]  # [B or 1, 1, n_embd]

        cond_vec = condition.unsqueeze(1)  # [B, 1, n_embd]

        hidden = x_t + positions + step_vec + cond_vec
        return self.encoder(hidden)


# influenced by https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
class SymbolicGaussianDiffusion(nn.Module):
    """Gaussian diffusion over token embeddings, conditioned on support points.

    Token ids are embedded, noised according to a linear beta schedule and a
    transformer is trained to predict the clean embeddings from the noisy
    ones. The conditioning vector is the point-set encoding plus an embedding
    of the number of variables. Predicted embeddings are mapped back to token
    logits with a decoder whose weight is tied to the token embedding.
    """

    def __init__(
        self,
        tnet_config,
        vocab_size,
        max_seq_len,
        padding_idx: int = 0,
        max_num_vars: int = 9,
        n_layer=6,
        n_head=8,
        n_embd=512,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        ce_weight=1.0,  # Weight for CE loss relative to MSE
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.n_embd = n_embd
        self.timesteps = timesteps
        self.ce_weight = ce_weight

        # Kept for backward compatibility only. It records the preferred
        # device at construction time and is NOT updated by ``.to()``, so the
        # methods below take the device from the registered buffers instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tok_emb = nn.Embedding(vocab_size, n_embd, padding_idx=self.padding_idx)
        # NOTE(review): valid indices are 0 .. max_num_vars - 1, so a
        # variable count equal to ``max_num_vars`` is out of range. The size
        # is left unchanged so existing checkpoints keep loading.
        self.vars_emb = nn.Embedding(max_num_vars, n_embd)

        # decoder shares its weight matrix with the token embedding
        self.decoder = nn.Linear(n_embd, vocab_size, bias=False)
        self.decoder.weight = self.tok_emb.weight

        self.tnet = tNet(tnet_config)
        self.model = NoisePredictionTransformer(
            n_embd, max_seq_len, n_layer, n_head, timesteps
        )

        # Noise schedule
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, timesteps))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    def q_sample(self, x_start, t, noise=None):
        """Return ``x_start`` noised to level ``t`` (one timestep per batch row).

        ``noise`` is drawn from a standard normal when not supplied.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1)

        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise
        return x_t

    def p_mean_variance(self, x, t, t_next, condition):
        """Mean and variance of the reverse step from index ``t`` to ``t_next``.

        ``t`` and ``t_next`` are scalar index tensors into the schedule. At
        ``t == 0`` the step leads to the clean sample, whose cumulative alpha
        is 1, so the mean reduces to the predicted clean embedding and the
        variance to 0 regardless of ``t_next``.
        """
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        beta_t = self.beta[t]
        # Previously alpha_bar[0] was used for the final step as well, which
        # made the last mean roughly x_start_pred + x instead of x_start_pred.
        alpha_bar_t_next = torch.where(
            t > 0, self.alpha_bar[t_next], torch.ones_like(alpha_bar_t)
        )

        x_start_pred = self.model(x, t.long(), condition)

        coeff1 = torch.sqrt(alpha_bar_t_next) * beta_t / (1 - alpha_bar_t)
        coeff2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_next) / (1 - alpha_bar_t)
        mean = coeff1 * x_start_pred + coeff2 * x
        variance = (1 - alpha_bar_t_next) / (1 - alpha_bar_t) * beta_t
        return mean, variance

    @torch.no_grad()
    def p_sample(self, x, t, t_next, condition):
        """Draw the next sample of the reverse process (no noise once ``t_next`` is 0)."""
        mean, variance = self.p_mean_variance(x, t, t_next, condition)
        if torch.all(t_next == 0):
            return mean
        noise = torch.randn_like(x)
        return mean + torch.sqrt(variance) * noise

    def _decode_skeletons(self, x, train_dataset, batch_size):
        """Map embeddings ``x`` ([B, L, n_embd]) to one skeleton string per sample."""
        logits = self.decoder(x)  # [B, L, vocab_size]
        token_indices = torch.argmax(logits, dim=-1)  # [B, L]
        return [
            get_predicted_skeleton(token_indices[j], train_dataset)
            for j in range(batch_size)
        ]

    @torch.no_grad()
    def sample(self, points, variables, train_dataset, batch_size=16):
        """Run the full reverse process and return ``batch_size`` skeleton strings.

        Intermediate predictions are written out every 250 steps.
        """
        # Imported locally: the notebook binds the global name ``tqdm`` to the
        # class in one cell and to the module in another, so the global cannot
        # be relied upon to be callable.
        from tqdm import tqdm as progress_bar

        device = self.alpha_bar.device  # follows the module when it is moved
        condition = self.tnet(points) + self.vars_emb(variables)
        shape = (batch_size, self.max_seq_len, self.n_embd)
        x = torch.randn(shape, device=device)
        steps = torch.arange(self.timesteps - 1, -1, -1, device=device)

        for i in progress_bar(
            range(self.timesteps), desc="sampling loop", total=self.timesteps
        ):
            t = steps[i]
            t_next = (
                steps[i + 1]
                if i + 1 < self.timesteps
                else torch.tensor(0, device=device)
            )
            x = self.p_sample(x, t, t_next, condition)

            # Print prediction every 250 steps
            if (i + 1) % 250 == 0:
                skeletons = self._decode_skeletons(x, train_dataset, batch_size)
                for j, predicted_skeleton in enumerate(skeletons):
                    progress_bar.write(
                        f" sample {j}: predicted_skeleton: {predicted_skeleton}"
                    )

        return self._decode_skeletons(x, train_dataset, batch_size)

    def p_losses(
        self, x_start, points, tokens, variables, t, noise=None, mse: bool = False
    ):
        """Hybrid loss: MSE on embeddings + CE on tokens.

        Returns ``(total_loss, mse_loss, ce_loss)``; ``mse_loss`` is a zero
        tensor unless ``mse`` is true. ``noise`` is drawn from a standard
        normal when not supplied.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)
        condition = self.tnet(points) + self.vars_emb(variables)
        x_start_pred = self.model(x_t, t.long(), condition)

        # MSE loss on embeddings
        if mse:
            mse_loss = F.mse_loss(x_start_pred, x_start)
        else:
            # zero on the same device/dtype as the data
            mse_loss = x_start.new_zeros(())

        # CE loss on tokens
        logits = self.decoder(x_start_pred)  # [B, L, vocab_size]
        ce_loss = F.cross_entropy(
            logits.view(-1, self.vocab_size),  # [B*L, vocab_size]
            tokens.view(-1),  # [B*L]
            ignore_index=self.padding_idx,
            reduction="mean",
        )

        total_loss = mse_loss + self.ce_weight * ce_loss
        return total_loss, mse_loss, ce_loss

    def forward(self, points, tokens, variables, t, mse=False):
        """Embed ``tokens`` and return the training losses for timesteps ``t``."""
        token_emb = self.tok_emb(tokens)
        total_loss, mse_loss, ce_loss = self.p_losses(
            token_emb, points, tokens, variables, t, mse=mse
        )
        return total_loss, mse_loss, ce_loss


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple


def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    """Run one optimisation pass over ``train_loader``.

    For every batch a random timestep per sample is drawn, the model losses
    are computed and one optimiser step is taken. Returns the per-batch
    averages ``(total, mse, ce)``. ``train_dataset`` is accepted but unused.
    """
    model.train()
    n_batches = len(train_loader)
    running_total, running_mse, running_ce = 0, 0, 0

    progress = tqdm.tqdm(
        enumerate(train_loader),
        total=n_batches,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    )
    for step, batch in progress:
        # the first element of every batch is not needed here
        _, tokens, points, variables = batch
        points = points.to(device)
        tokens = tokens.to(device)
        variables = variables.to(device)
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

        optimizer.zero_grad()

        total_loss, mse_loss, ce_loss = model(points, tokens, variables, t)

        if (step + 1) % 250 == 0:
            print(f"Batch {step + 1}/{n_batches}:")
            print(f"total_loss: {total_loss}, mse: {mse_loss}, ce: {ce_loss}")

        total_loss.backward()
        optimizer.step()

        running_total += total_loss.item()
        running_mse += mse_loss.item()
        running_ce += ce_loss.item()

    return (
        running_total / n_batches,
        running_mse / n_batches,
        running_ce / n_batches,
    )


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    """Evaluate the model on ``val_loader`` without gradient tracking.

    A random timestep is drawn per sample, exactly as during training.
    Returns the per-batch averages ``(total, mse, ce)``. ``train_dataset``,
    ``epoch`` and ``num_epochs`` are accepted but unused.
    """
    model.eval()
    n_batches = len(val_loader)
    running_total, running_mse, running_ce = 0, 0, 0

    with torch.no_grad():
        progress = tqdm.tqdm(val_loader, total=n_batches, desc="Validating")
        for batch in progress:
            # the first element of every batch is not needed here
            _, tokens, points, variables = batch
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

            total_loss, mse_loss, ce_loss = model(points, tokens, variables, t)

            running_total += total_loss.item()
            running_mse += mse_loss.item()
            running_ce += ce_loss.item()

    return (
        running_total / n_batches,
        running_mse / n_batches,
        running_ce / n_batches,
    )


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
):
    """Train ``model`` on one device and keep the best checkpoint.

    Each epoch runs a training pass and a validation pass, feeds the
    validation loss to a plateau scheduler (halving the learning rate) and
    writes the weights to ``best_model.pth`` whenever the validation loss
    improves. ``save_every`` is accepted but unused.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    # options shared by both loaders; only the training data is shuffled
    shared_loader_options = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=4,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_options)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        train_stats = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )
        avg_train_loss, avg_mse_loss, avg_ce_loss = train_stats

        val_stats = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )
        avg_val_loss, val_mse_loss, val_ce_loss = val_stats

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(
            f"Train Total Loss: {avg_train_loss:.4f} (MSE: {avg_mse_loss:.4f}, CE: {avg_ce_loss:.4f})"
        )
        print(
            f"Val Total Loss: {avg_val_loss:.4f} (MSE: {val_mse_loss:.4f}, CE: {val_ce_loss:.4f})"
        )
        print(f"Learning Rate: {current_lr:.6f}")

        improved = avg_val_loss < best_val_loss
        if improved:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_model.pth")
            print(f"New best model saved with val loss: {best_val_loss:.4f}")

        print("-" * 50)


In [5]:
# Model / optimisation settings. They are not consumed in the cells shown
# here; presumably they are passed to the model constructor and the training
# function further down (TODO confirm).
n_embd = 512
timesteps = 1000
batch_size = 32
learning_rate = 1e-4
num_epochs = 10
# Dataset settings, passed to CharDataset in the cells below.
blockSize = 32  # block_size: fixed length of the encoded token sequences
numVars = 1  # rows reserved for x-coordinates in the points matrix
numYs = 1  # rows reserved for y-values in the points matrix
numPoints = 250  # the points matrix gets numPoints - 1 columns
target = 'Skeleton'  # key of the equation text inside every JSON example
const_range = [-2.1, 2.1]  # interval for randomly drawn constants (only used when augmenting)
trainRange = [-3.0, 3.0]  # passed as xRange: interval for regenerated x-values (only used when augmenting)
decimals = 8  # rounding applied when points are regenerated
addVars = False  # when True the token sequence is prefixed with the variable count
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Locations of the one-variable training and validation sets (one JSON example
# per line); both appear in the input-directory listing printed by the first cell.
train_path = "/kaggle/input/1-var-dataset/1_var_train.json"
val_path = "/kaggle/input/1-var-dataset/1_var_val.json"

In [7]:
import numpy as np
import glob
import random

# Build the training set: read the raw text, derive the token vocabulary from
# the equation skeletons, then wrap the shuffled examples in a CharDataset.
files = glob.glob(train_path)
text = processDataFiles(files)
text = text.split('\n')  # convert the raw text to a set of examples (one JSON object per line)

# Drop blank lines once and reuse the result for BOTH the vocabulary pass and
# the dataset. Previously the vocabulary pass skipped every blank line while
# the dataset only lost a single trailing empty string, so any other blank line
# would have been handed to CharDataset as an "example".
trainText = [item for item in text if item.strip()]

# Vocabulary: every token appearing in any skeleton, plus the ten digits and
# the special tokens appended below.
skeletons = [json.loads(item)['Skeleton'] for item in trainText]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # add all integers to the token set
tokens = sorted(list(all_tokens) + ['_', 'T', '<', '>', ':'])  # special tokens

# Shuffle the dataset; this matters especially for the experiment that combines
# different numbers of variables.
random.shuffle(trainText)
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# Sanity check: decode one random example back to text and print it.
idx = np.random.randint(len(train_dataset))
inputs, outputs, points, variables = train_dataset[idx]
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 498795 examples, 27 unique.
id:489771
outputs:C*log(C*x1+C)**3+C>_________________
variables:1


In [8]:
# Build the validation set with the SAME vocabulary (`tokens`) as the training
# set so token ids agree between the two datasets.
files = glob.glob(val_path)
textVal = processDataFiles([files[0]])  # only the first matching file is used
# Convert the raw text to a list of examples, dropping blank lines. The train
# cell strips the trailing empty line left by a final newline; this cell did
# not, so such a line would have been passed to CharDataset as an extra example.
textVal = [item for item in textVal.split('\n') if item.strip()]
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# print a random sample
idx = np.random.randint(len(val_dataset))
inputs, outputs, points, variables = val_dataset[idx]
print(points.min(), points.max())  # value range of the sampled points tensor
# Decoding goes through train_dataset.itos; both datasets were built from the
# same `tokens` list.
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 972 examples, 27 unique.
tensor(-35.1828) tensor(37.9909)
id:410
outputs:C*x1**3+C*x1**2+C*x1+C>______________
variables:1


In [9]:
# Configuration for the point encoder: embedding width plus the dataset
# geometry (points per example, number of variables, number of outputs).
pconfig = PointNetConfig(
    numberofYs=numYs,
    numberofVars=numVars,
    numberofPoints=numPoints,
    embeddingSize=n_embd,
)

# Diffusion model over token sequences, sized from the training vocabulary.
model = SymbolicGaussianDiffusion(
    # conditioning network and vocabulary
    tnet_config=pconfig,
    vocab_size=train_dataset.vocab_size,
    padding_idx=train_dataset.paddingID,
    max_seq_len=blockSize,
    # NOTE(review): fixed at 9 although numVars is 1 in this notebook; confirm this is intended.
    max_num_vars=9,
    # network size
    n_embd=n_embd,
    n_head=8,
    n_layer=8,
    # diffusion settings
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
    # loss weighting
    ce_weight=1.0,
)

# Launch training on the train/validation datasets built above.
# NOTE(review): the visible part of train_single_gpu only writes a checkpoint
# when validation loss improves; confirm that save_every is actually used.
train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    batch_size=batch_size,
    timesteps=timesteps,
    num_epochs=num_epochs,
    save_every=2,
)

Epoch 1/10:   2%|▏         | 251/15588 [00:28<28:17,  9.03it/s]

Batch 250/15588:
total_loss: 1.8881055116653442, mse: 0.0, ce: 1.8881055116653442


Epoch 1/10:   3%|▎         | 500/15588 [00:56<30:31,  8.24it/s]

Batch 500/15588:
total_loss: 1.4779595136642456, mse: 0.0, ce: 1.4779595136642456


Epoch 1/10:   5%|▍         | 750/15588 [01:27<30:15,  8.17it/s]

Batch 750/15588:
total_loss: 1.4312117099761963, mse: 0.0, ce: 1.4312117099761963


Epoch 1/10:   6%|▋         | 1001/15588 [01:57<29:07,  8.35it/s]

Batch 1000/15588:
total_loss: 0.8003225326538086, mse: 0.0, ce: 0.8003225326538086


Epoch 1/10:   8%|▊         | 1250/15588 [02:27<28:57,  8.25it/s]

Batch 1250/15588:
total_loss: 0.6797471046447754, mse: 0.0, ce: 0.6797471046447754


Epoch 1/10:  10%|▉         | 1500/15588 [02:57<28:26,  8.26it/s]

Batch 1500/15588:
total_loss: 0.5210134983062744, mse: 0.0, ce: 0.5210134983062744


Epoch 1/10:  11%|█         | 1750/15588 [03:28<27:58,  8.24it/s]

Batch 1750/15588:
total_loss: 0.7250174283981323, mse: 0.0, ce: 0.7250174283981323


Epoch 1/10:  13%|█▎        | 2000/15588 [03:58<27:31,  8.23it/s]

Batch 2000/15588:
total_loss: 0.615664541721344, mse: 0.0, ce: 0.615664541721344


Epoch 1/10:  14%|█▍        | 2250/15588 [04:28<26:59,  8.23it/s]

Batch 2250/15588:
total_loss: 0.710407555103302, mse: 0.0, ce: 0.710407555103302


Epoch 1/10:  16%|█▌        | 2500/15588 [04:59<26:25,  8.26it/s]

Batch 2500/15588:
total_loss: 0.7401155233383179, mse: 0.0, ce: 0.7401155233383179


Epoch 1/10:  18%|█▊        | 2750/15588 [05:29<25:58,  8.24it/s]

Batch 2750/15588:
total_loss: 0.5831716060638428, mse: 0.0, ce: 0.5831716060638428


Epoch 1/10:  19%|█▉        | 3000/15588 [05:59<25:26,  8.24it/s]

Batch 3000/15588:
total_loss: 0.7434169054031372, mse: 0.0, ce: 0.7434169054031372


Epoch 1/10:  21%|██        | 3250/15588 [06:30<24:54,  8.26it/s]

Batch 3250/15588:
total_loss: 0.7535939812660217, mse: 0.0, ce: 0.7535939812660217


Epoch 1/10:  22%|██▏       | 3500/15588 [07:00<24:28,  8.23it/s]

Batch 3500/15588:
total_loss: 0.4333100914955139, mse: 0.0, ce: 0.4333100914955139


Epoch 1/10:  24%|██▍       | 3750/15588 [07:30<23:55,  8.25it/s]

Batch 3750/15588:
total_loss: 0.6444059014320374, mse: 0.0, ce: 0.6444059014320374


Epoch 1/10:  26%|██▌       | 4000/15588 [08:01<23:23,  8.26it/s]

Batch 4000/15588:
total_loss: 0.33479148149490356, mse: 0.0, ce: 0.33479148149490356


Epoch 1/10:  27%|██▋       | 4250/15588 [08:31<22:56,  8.24it/s]

Batch 4250/15588:
total_loss: 0.6684640049934387, mse: 0.0, ce: 0.6684640049934387


Epoch 1/10:  29%|██▉       | 4500/15588 [09:01<22:28,  8.22it/s]

Batch 4500/15588:
total_loss: 0.6135002374649048, mse: 0.0, ce: 0.6135002374649048


Epoch 1/10:  30%|███       | 4750/15588 [09:32<21:54,  8.24it/s]

Batch 4750/15588:
total_loss: 0.791426956653595, mse: 0.0, ce: 0.791426956653595


Epoch 1/10:  32%|███▏      | 5000/15588 [10:02<21:24,  8.24it/s]

Batch 5000/15588:
total_loss: 0.3923131823539734, mse: 0.0, ce: 0.3923131823539734


Epoch 1/10:  34%|███▎      | 5250/15588 [10:32<20:54,  8.24it/s]

Batch 5250/15588:
total_loss: 0.7235156893730164, mse: 0.0, ce: 0.7235156893730164


Epoch 1/10:  35%|███▌      | 5500/15588 [11:03<20:24,  8.24it/s]

Batch 5500/15588:
total_loss: 0.5003390908241272, mse: 0.0, ce: 0.5003390908241272


Epoch 1/10:  37%|███▋      | 5750/15588 [11:33<19:55,  8.23it/s]

Batch 5750/15588:
total_loss: 0.5136409401893616, mse: 0.0, ce: 0.5136409401893616


Epoch 1/10:  38%|███▊      | 6000/15588 [12:03<19:21,  8.26it/s]

Batch 6000/15588:
total_loss: 0.42643633484840393, mse: 0.0, ce: 0.42643633484840393


Epoch 1/10:  40%|████      | 6250/15588 [12:34<18:53,  8.24it/s]

Batch 6250/15588:
total_loss: 0.4373394846916199, mse: 0.0, ce: 0.4373394846916199


Epoch 1/10:  42%|████▏     | 6500/15588 [13:04<18:19,  8.26it/s]

Batch 6500/15588:
total_loss: 0.3703329563140869, mse: 0.0, ce: 0.3703329563140869


Epoch 1/10:  43%|████▎     | 6750/15588 [13:34<17:51,  8.24it/s]

Batch 6750/15588:
total_loss: 0.5311222076416016, mse: 0.0, ce: 0.5311222076416016


Epoch 1/10:  45%|████▍     | 7000/15588 [14:05<17:22,  8.23it/s]

Batch 7000/15588:
total_loss: 0.12214284390211105, mse: 0.0, ce: 0.12214284390211105


Epoch 1/10:  47%|████▋     | 7250/15588 [14:35<16:52,  8.23it/s]

Batch 7250/15588:
total_loss: 0.7176676392555237, mse: 0.0, ce: 0.7176676392555237


Epoch 1/10:  48%|████▊     | 7500/15588 [15:05<16:21,  8.24it/s]

Batch 7500/15588:
total_loss: 0.7872928977012634, mse: 0.0, ce: 0.7872928977012634


Epoch 1/10:  50%|████▉     | 7750/15588 [15:36<15:51,  8.24it/s]

Batch 7750/15588:
total_loss: 0.235074982047081, mse: 0.0, ce: 0.235074982047081


Epoch 1/10:  51%|█████▏    | 8000/15588 [16:06<15:20,  8.24it/s]

Batch 8000/15588:
total_loss: 0.4636082649230957, mse: 0.0, ce: 0.4636082649230957


Epoch 1/10:  53%|█████▎    | 8250/15588 [16:36<14:50,  8.24it/s]

Batch 8250/15588:
total_loss: 0.38040465116500854, mse: 0.0, ce: 0.38040465116500854


Epoch 1/10:  55%|█████▍    | 8500/15588 [17:07<14:17,  8.26it/s]

Batch 8500/15588:
total_loss: 0.5828723907470703, mse: 0.0, ce: 0.5828723907470703


Epoch 1/10:  56%|█████▌    | 8750/15588 [17:37<13:51,  8.22it/s]

Batch 8750/15588:
total_loss: 0.666487991809845, mse: 0.0, ce: 0.666487991809845


Epoch 1/10:  58%|█████▊    | 9000/15588 [18:07<13:19,  8.24it/s]

Batch 9000/15588:
total_loss: 0.7623967528343201, mse: 0.0, ce: 0.7623967528343201


Epoch 1/10:  59%|█████▉    | 9250/15588 [18:38<12:49,  8.24it/s]

Batch 9250/15588:
total_loss: 0.46110987663269043, mse: 0.0, ce: 0.46110987663269043


Epoch 1/10:  61%|██████    | 9500/15588 [19:08<12:14,  8.29it/s]

Batch 9500/15588:
total_loss: 0.5870435237884521, mse: 0.0, ce: 0.5870435237884521


Epoch 1/10:  63%|██████▎   | 9751/15588 [19:38<11:46,  8.26it/s]

Batch 9750/15588:
total_loss: 0.9682082533836365, mse: 0.0, ce: 0.9682082533836365


Epoch 1/10:  64%|██████▍   | 10000/15588 [20:09<11:17,  8.25it/s]

Batch 10000/15588:
total_loss: 0.46549612283706665, mse: 0.0, ce: 0.46549612283706665


Epoch 1/10:  66%|██████▌   | 10250/15588 [20:39<10:49,  8.22it/s]

Batch 10250/15588:
total_loss: 0.6029150485992432, mse: 0.0, ce: 0.6029150485992432


Epoch 1/10:  67%|██████▋   | 10501/15588 [21:09<10:14,  8.28it/s]

Batch 10500/15588:
total_loss: 0.285274863243103, mse: 0.0, ce: 0.285274863243103


Epoch 1/10:  69%|██████▉   | 10750/15588 [21:40<09:46,  8.24it/s]

Batch 10750/15588:
total_loss: 0.4207443594932556, mse: 0.0, ce: 0.4207443594932556


Epoch 1/10:  71%|███████   | 11000/15588 [22:10<09:18,  8.22it/s]

Batch 11000/15588:
total_loss: 0.46110185980796814, mse: 0.0, ce: 0.46110185980796814


Epoch 1/10:  72%|███████▏  | 11250/15588 [22:40<08:47,  8.23it/s]

Batch 11250/15588:
total_loss: 0.44224342703819275, mse: 0.0, ce: 0.44224342703819275


Epoch 1/10:  74%|███████▍  | 11500/15588 [23:11<08:16,  8.24it/s]

Batch 11500/15588:
total_loss: 0.39307236671447754, mse: 0.0, ce: 0.39307236671447754


Epoch 1/10:  75%|███████▌  | 11750/15588 [23:41<07:45,  8.25it/s]

Batch 11750/15588:
total_loss: 0.3327305316925049, mse: 0.0, ce: 0.3327305316925049


Epoch 1/10:  77%|███████▋  | 12000/15588 [24:11<07:14,  8.26it/s]

Batch 12000/15588:
total_loss: 0.6633751392364502, mse: 0.0, ce: 0.6633751392364502


Epoch 1/10:  79%|███████▊  | 12250/15588 [24:42<06:45,  8.22it/s]

Batch 12250/15588:
total_loss: 0.36988571286201477, mse: 0.0, ce: 0.36988571286201477


Epoch 1/10:  80%|████████  | 12501/15588 [25:12<06:12,  8.28it/s]

Batch 12500/15588:
total_loss: 0.277704656124115, mse: 0.0, ce: 0.277704656124115


Epoch 1/10:  82%|████████▏ | 12750/15588 [25:42<05:43,  8.25it/s]

Batch 12750/15588:
total_loss: 0.49609825015068054, mse: 0.0, ce: 0.49609825015068054


Epoch 1/10:  83%|████████▎ | 13000/15588 [26:13<05:14,  8.23it/s]

Batch 13000/15588:
total_loss: 0.5507823824882507, mse: 0.0, ce: 0.5507823824882507


Epoch 1/10:  85%|████████▌ | 13250/15588 [26:43<04:43,  8.24it/s]

Batch 13250/15588:
total_loss: 0.41169989109039307, mse: 0.0, ce: 0.41169989109039307


Epoch 1/10:  87%|████████▋ | 13500/15588 [27:13<04:13,  8.24it/s]

Batch 13500/15588:
total_loss: 0.3458543121814728, mse: 0.0, ce: 0.3458543121814728


Epoch 1/10:  88%|████████▊ | 13750/15588 [27:44<03:42,  8.25it/s]

Batch 13750/15588:
total_loss: 0.23436084389686584, mse: 0.0, ce: 0.23436084389686584


Epoch 1/10:  90%|████████▉ | 14000/15588 [28:14<03:13,  8.23it/s]

Batch 14000/15588:
total_loss: 0.45459112524986267, mse: 0.0, ce: 0.45459112524986267


Epoch 1/10:  91%|█████████▏| 14250/15588 [28:44<02:42,  8.24it/s]

Batch 14250/15588:
total_loss: 0.17282480001449585, mse: 0.0, ce: 0.17282480001449585


Epoch 1/10:  93%|█████████▎| 14501/15588 [29:15<02:11,  8.29it/s]

Batch 14500/15588:
total_loss: 0.507618248462677, mse: 0.0, ce: 0.507618248462677


Epoch 1/10:  95%|█████████▍| 14750/15588 [29:45<01:41,  8.26it/s]

Batch 14750/15588:
total_loss: 0.17272180318832397, mse: 0.0, ce: 0.17272180318832397


Epoch 1/10:  96%|█████████▌| 15000/15588 [30:15<01:11,  8.25it/s]

Batch 15000/15588:
total_loss: 0.37516406178474426, mse: 0.0, ce: 0.37516406178474426


Epoch 1/10:  98%|█████████▊| 15250/15588 [30:46<00:40,  8.26it/s]

Batch 15250/15588:
total_loss: 0.37800657749176025, mse: 0.0, ce: 0.37800657749176025


Epoch 1/10:  99%|█████████▉| 15500/15588 [31:16<00:10,  8.26it/s]

Batch 15500/15588:
total_loss: 0.36628425121307373, mse: 0.0, ce: 0.36628425121307373


Epoch 1/10: 100%|██████████| 15588/15588 [31:27<00:00,  8.26it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 21.98it/s]



Epoch Summary:
Train Total Loss: 0.6141 (MSE: 0.0000, CE: 0.6141)
Val Total Loss: 0.3595 (MSE: 0.0000, CE: 0.3595)
Learning Rate: 0.000100
New best model saved with val loss: 0.3595
--------------------------------------------------


Epoch 2/10:   2%|▏         | 250/15588 [00:30<30:56,  8.26it/s]

Batch 250/15588:
total_loss: 0.18784067034721375, mse: 0.0, ce: 0.18784067034721375


Epoch 2/10:   3%|▎         | 500/15588 [01:00<30:29,  8.25it/s]

Batch 500/15588:
total_loss: 0.49641868472099304, mse: 0.0, ce: 0.49641868472099304


Epoch 2/10:   5%|▍         | 750/15588 [01:31<29:58,  8.25it/s]

Batch 750/15588:
total_loss: 0.5023239254951477, mse: 0.0, ce: 0.5023239254951477


Epoch 2/10:   6%|▋         | 1000/15588 [02:01<29:23,  8.27it/s]

Batch 1000/15588:
total_loss: 0.5287664532661438, mse: 0.0, ce: 0.5287664532661438


Epoch 2/10:   8%|▊         | 1250/15588 [02:31<28:59,  8.24it/s]

Batch 1250/15588:
total_loss: 0.28096383810043335, mse: 0.0, ce: 0.28096383810043335


Epoch 2/10:  10%|▉         | 1500/15588 [03:02<28:26,  8.26it/s]

Batch 1500/15588:
total_loss: 0.3294494152069092, mse: 0.0, ce: 0.3294494152069092


Epoch 2/10:  11%|█         | 1750/15588 [03:32<27:59,  8.24it/s]

Batch 1750/15588:
total_loss: 0.7195943593978882, mse: 0.0, ce: 0.7195943593978882


Epoch 2/10:  13%|█▎        | 2000/15588 [04:02<27:29,  8.24it/s]

Batch 2000/15588:
total_loss: 0.200326070189476, mse: 0.0, ce: 0.200326070189476


Epoch 2/10:  14%|█▍        | 2250/15588 [04:33<26:50,  8.28it/s]

Batch 2250/15588:
total_loss: 0.5654963254928589, mse: 0.0, ce: 0.5654963254928589


Epoch 2/10:  16%|█▌        | 2500/15588 [05:03<26:31,  8.22it/s]

Batch 2500/15588:
total_loss: 0.14485780894756317, mse: 0.0, ce: 0.14485780894756317


Epoch 2/10:  18%|█▊        | 2750/15588 [05:33<25:56,  8.25it/s]

Batch 2750/15588:
total_loss: 0.4102831184864044, mse: 0.0, ce: 0.4102831184864044


Epoch 2/10:  19%|█▉        | 3000/15588 [06:04<25:29,  8.23it/s]

Batch 3000/15588:
total_loss: 0.2513987123966217, mse: 0.0, ce: 0.2513987123966217


Epoch 2/10:  21%|██        | 3250/15588 [06:34<24:56,  8.24it/s]

Batch 3250/15588:
total_loss: 0.5226007699966431, mse: 0.0, ce: 0.5226007699966431


Epoch 2/10:  22%|██▏       | 3500/15588 [07:04<24:28,  8.23it/s]

Batch 3500/15588:
total_loss: 0.4349147379398346, mse: 0.0, ce: 0.4349147379398346


Epoch 2/10:  24%|██▍       | 3750/15588 [07:35<23:55,  8.24it/s]

Batch 3750/15588:
total_loss: 0.24654226005077362, mse: 0.0, ce: 0.24654226005077362


Epoch 2/10:  26%|██▌       | 4000/15588 [08:05<23:25,  8.24it/s]

Batch 4000/15588:
total_loss: 0.4468362629413605, mse: 0.0, ce: 0.4468362629413605


Epoch 2/10:  27%|██▋       | 4250/15588 [08:35<22:58,  8.23it/s]

Batch 4250/15588:
total_loss: 0.31832972168922424, mse: 0.0, ce: 0.31832972168922424


Epoch 2/10:  29%|██▉       | 4500/15588 [09:06<22:24,  8.25it/s]

Batch 4500/15588:
total_loss: 0.4022160470485687, mse: 0.0, ce: 0.4022160470485687


Epoch 2/10:  30%|███       | 4750/15588 [09:36<21:52,  8.26it/s]

Batch 4750/15588:
total_loss: 0.2518802285194397, mse: 0.0, ce: 0.2518802285194397


Epoch 2/10:  32%|███▏      | 5000/15588 [10:07<21:23,  8.25it/s]

Batch 5000/15588:
total_loss: 0.14275477826595306, mse: 0.0, ce: 0.14275477826595306


Epoch 2/10:  34%|███▎      | 5250/15588 [10:37<20:52,  8.25it/s]

Batch 5250/15588:
total_loss: 0.6174957752227783, mse: 0.0, ce: 0.6174957752227783


Epoch 2/10:  35%|███▌      | 5500/15588 [11:07<20:22,  8.25it/s]

Batch 5500/15588:
total_loss: 0.333468496799469, mse: 0.0, ce: 0.333468496799469


Epoch 2/10:  37%|███▋      | 5750/15588 [11:38<19:54,  8.23it/s]

Batch 5750/15588:
total_loss: 0.3735477030277252, mse: 0.0, ce: 0.3735477030277252


Epoch 2/10:  38%|███▊      | 6000/15588 [12:08<19:26,  8.22it/s]

Batch 6000/15588:
total_loss: 0.5223907232284546, mse: 0.0, ce: 0.5223907232284546


Epoch 2/10:  40%|████      | 6250/15588 [12:38<18:54,  8.23it/s]

Batch 6250/15588:
total_loss: 0.3376755714416504, mse: 0.0, ce: 0.3376755714416504


Epoch 2/10:  42%|████▏     | 6500/15588 [13:09<18:22,  8.25it/s]

Batch 6500/15588:
total_loss: 0.2403395175933838, mse: 0.0, ce: 0.2403395175933838


Epoch 2/10:  43%|████▎     | 6750/15588 [13:39<17:57,  8.21it/s]

Batch 6750/15588:
total_loss: 0.21357358992099762, mse: 0.0, ce: 0.21357358992099762


Epoch 2/10:  45%|████▍     | 7000/15588 [14:09<17:19,  8.26it/s]

Batch 7000/15588:
total_loss: 0.14615046977996826, mse: 0.0, ce: 0.14615046977996826


Epoch 2/10:  47%|████▋     | 7250/15588 [14:40<16:54,  8.22it/s]

Batch 7250/15588:
total_loss: 0.45495620369911194, mse: 0.0, ce: 0.45495620369911194


Epoch 2/10:  48%|████▊     | 7500/15588 [15:10<16:21,  8.24it/s]

Batch 7500/15588:
total_loss: 0.1839555948972702, mse: 0.0, ce: 0.1839555948972702


Epoch 2/10:  50%|████▉     | 7750/15588 [15:40<15:50,  8.24it/s]

Batch 7750/15588:
total_loss: 0.29593396186828613, mse: 0.0, ce: 0.29593396186828613


Epoch 2/10:  51%|█████▏    | 8000/15588 [16:11<15:21,  8.24it/s]

Batch 8000/15588:
total_loss: 0.554622232913971, mse: 0.0, ce: 0.554622232913971


Epoch 2/10:  53%|█████▎    | 8250/15588 [16:41<14:50,  8.24it/s]

Batch 8250/15588:
total_loss: 0.4679303467273712, mse: 0.0, ce: 0.4679303467273712


Epoch 2/10:  55%|█████▍    | 8500/15588 [17:11<14:21,  8.23it/s]

Batch 8500/15588:
total_loss: 0.1906944066286087, mse: 0.0, ce: 0.1906944066286087


Epoch 2/10:  56%|█████▌    | 8750/15588 [17:42<13:48,  8.26it/s]

Batch 8750/15588:
total_loss: 0.2942860424518585, mse: 0.0, ce: 0.2942860424518585


Epoch 2/10:  58%|█████▊    | 9000/15588 [18:12<13:17,  8.26it/s]

Batch 9000/15588:
total_loss: 0.21570537984371185, mse: 0.0, ce: 0.21570537984371185


Epoch 2/10:  59%|█████▉    | 9250/15588 [18:42<12:49,  8.23it/s]

Batch 9250/15588:
total_loss: 0.16843754053115845, mse: 0.0, ce: 0.16843754053115845


Epoch 2/10:  61%|██████    | 9500/15588 [19:13<12:17,  8.26it/s]

Batch 9500/15588:
total_loss: 0.2714032828807831, mse: 0.0, ce: 0.2714032828807831


Epoch 2/10:  63%|██████▎   | 9750/15588 [19:43<11:47,  8.25it/s]

Batch 9750/15588:
total_loss: 0.2516717314720154, mse: 0.0, ce: 0.2516717314720154


Epoch 2/10:  64%|██████▍   | 10000/15588 [20:14<11:18,  8.23it/s]

Batch 10000/15588:
total_loss: 0.3497743010520935, mse: 0.0, ce: 0.3497743010520935


Epoch 2/10:  66%|██████▌   | 10250/15588 [20:44<10:48,  8.23it/s]

Batch 10250/15588:
total_loss: 0.1540152132511139, mse: 0.0, ce: 0.1540152132511139


Epoch 2/10:  67%|██████▋   | 10500/15588 [21:14<10:17,  8.24it/s]

Batch 10500/15588:
total_loss: 0.2003636658191681, mse: 0.0, ce: 0.2003636658191681


Epoch 2/10:  69%|██████▉   | 10750/15588 [21:45<09:48,  8.23it/s]

Batch 10750/15588:
total_loss: 0.25281113386154175, mse: 0.0, ce: 0.25281113386154175


Epoch 2/10:  71%|███████   | 11000/15588 [22:15<09:15,  8.26it/s]

Batch 11000/15588:
total_loss: 0.2115965336561203, mse: 0.0, ce: 0.2115965336561203


Epoch 2/10:  72%|███████▏  | 11250/15588 [22:45<08:47,  8.22it/s]

Batch 11250/15588:
total_loss: 0.42816317081451416, mse: 0.0, ce: 0.42816317081451416


Epoch 2/10:  74%|███████▍  | 11500/15588 [23:16<08:14,  8.27it/s]

Batch 11500/15588:
total_loss: 0.3133881986141205, mse: 0.0, ce: 0.3133881986141205


Epoch 2/10:  75%|███████▌  | 11750/15588 [23:46<07:45,  8.24it/s]

Batch 11750/15588:
total_loss: 0.1696394979953766, mse: 0.0, ce: 0.1696394979953766


Epoch 2/10:  77%|███████▋  | 12000/15588 [24:16<07:15,  8.25it/s]

Batch 12000/15588:
total_loss: 0.43233245611190796, mse: 0.0, ce: 0.43233245611190796


Epoch 2/10:  79%|███████▊  | 12250/15588 [24:47<06:44,  8.25it/s]

Batch 12250/15588:
total_loss: 0.35807275772094727, mse: 0.0, ce: 0.35807275772094727


Epoch 2/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.24it/s]

Batch 12500/15588:
total_loss: 0.5302973985671997, mse: 0.0, ce: 0.5302973985671997


Epoch 2/10:  82%|████████▏ | 12751/15588 [25:48<05:43,  8.27it/s]

Batch 12750/15588:
total_loss: 0.40723028779029846, mse: 0.0, ce: 0.40723028779029846


Epoch 2/10:  83%|████████▎ | 13000/15588 [26:18<05:14,  8.22it/s]

Batch 13000/15588:
total_loss: 0.15333740413188934, mse: 0.0, ce: 0.15333740413188934


Epoch 2/10:  85%|████████▌ | 13251/15588 [26:48<04:42,  8.26it/s]

Batch 13250/15588:
total_loss: 0.2455703765153885, mse: 0.0, ce: 0.2455703765153885


Epoch 2/10:  87%|████████▋ | 13500/15588 [27:18<04:13,  8.25it/s]

Batch 13500/15588:
total_loss: 0.598786473274231, mse: 0.0, ce: 0.598786473274231


Epoch 2/10:  88%|████████▊ | 13751/15588 [27:49<03:41,  8.30it/s]

Batch 13750/15588:
total_loss: 0.28943464159965515, mse: 0.0, ce: 0.28943464159965515


Epoch 2/10:  90%|████████▉ | 14000/15588 [28:19<03:13,  8.22it/s]

Batch 14000/15588:
total_loss: 0.20093414187431335, mse: 0.0, ce: 0.20093414187431335


Epoch 2/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.24it/s]

Batch 14250/15588:
total_loss: 0.44172343611717224, mse: 0.0, ce: 0.44172343611717224


Epoch 2/10:  93%|█████████▎| 14500/15588 [29:20<02:11,  8.27it/s]

Batch 14500/15588:
total_loss: 0.3153049647808075, mse: 0.0, ce: 0.3153049647808075


Epoch 2/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.24it/s]

Batch 14750/15588:
total_loss: 0.11764249205589294, mse: 0.0, ce: 0.11764249205589294


Epoch 2/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.21it/s]

Batch 15000/15588:
total_loss: 0.1592952460050583, mse: 0.0, ce: 0.1592952460050583


Epoch 2/10:  98%|█████████▊| 15250/15588 [30:51<00:41,  8.24it/s]

Batch 15250/15588:
total_loss: 0.046305131167173386, mse: 0.0, ce: 0.046305131167173386


Epoch 2/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.24it/s]

Batch 15500/15588:
total_loss: 0.2403099089860916, mse: 0.0, ce: 0.2403099089860916


Epoch 2/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.88it/s]



Epoch Summary:
Train Total Loss: 0.3415 (MSE: 0.0000, CE: 0.3415)
Val Total Loss: 0.2765 (MSE: 0.0000, CE: 0.2765)
Learning Rate: 0.000100
New best model saved with val loss: 0.2765
--------------------------------------------------


Epoch 3/10:   2%|▏         | 250/15588 [00:30<30:57,  8.26it/s]

Batch 250/15588:
total_loss: 0.44107016921043396, mse: 0.0, ce: 0.44107016921043396


Epoch 3/10:   3%|▎         | 500/15588 [01:00<30:33,  8.23it/s]

Batch 500/15588:
total_loss: 0.5421457886695862, mse: 0.0, ce: 0.5421457886695862


Epoch 3/10:   5%|▍         | 750/15588 [01:31<29:57,  8.25it/s]

Batch 750/15588:
total_loss: 0.470371812582016, mse: 0.0, ce: 0.470371812582016


Epoch 3/10:   6%|▋         | 1000/15588 [02:01<29:29,  8.24it/s]

Batch 1000/15588:
total_loss: 0.1438809633255005, mse: 0.0, ce: 0.1438809633255005


Epoch 3/10:   8%|▊         | 1250/15588 [02:31<29:04,  8.22it/s]

Batch 1250/15588:
total_loss: 0.278715044260025, mse: 0.0, ce: 0.278715044260025


Epoch 3/10:  10%|▉         | 1500/15588 [03:02<28:30,  8.23it/s]

Batch 1500/15588:
total_loss: 0.5193100571632385, mse: 0.0, ce: 0.5193100571632385


Epoch 3/10:  11%|█         | 1750/15588 [03:32<28:00,  8.24it/s]

Batch 1750/15588:
total_loss: 0.2527184784412384, mse: 0.0, ce: 0.2527184784412384


Epoch 3/10:  13%|█▎        | 2000/15588 [04:02<27:29,  8.24it/s]

Batch 2000/15588:
total_loss: 0.4716529846191406, mse: 0.0, ce: 0.4716529846191406


Epoch 3/10:  14%|█▍        | 2250/15588 [04:33<26:52,  8.27it/s]

Batch 2250/15588:
total_loss: 0.2538526952266693, mse: 0.0, ce: 0.2538526952266693


Epoch 3/10:  16%|█▌        | 2500/15588 [05:03<26:27,  8.24it/s]

Batch 2500/15588:
total_loss: 0.5930019617080688, mse: 0.0, ce: 0.5930019617080688


Epoch 3/10:  18%|█▊        | 2750/15588 [05:33<25:50,  8.28it/s]

Batch 2750/15588:
total_loss: 0.11200559139251709, mse: 0.0, ce: 0.11200559139251709


Epoch 3/10:  19%|█▉        | 3000/15588 [06:04<25:25,  8.25it/s]

Batch 3000/15588:
total_loss: 0.18036173284053802, mse: 0.0, ce: 0.18036173284053802


Epoch 3/10:  21%|██        | 3250/15588 [06:34<24:58,  8.23it/s]

Batch 3250/15588:
total_loss: 0.1292218714952469, mse: 0.0, ce: 0.1292218714952469


Epoch 3/10:  22%|██▏       | 3500/15588 [07:05<24:26,  8.25it/s]

Batch 3500/15588:
total_loss: 0.4242439866065979, mse: 0.0, ce: 0.4242439866065979


Epoch 3/10:  24%|██▍       | 3750/15588 [07:35<23:55,  8.25it/s]

Batch 3750/15588:
total_loss: 0.3950757086277008, mse: 0.0, ce: 0.3950757086277008


Epoch 3/10:  26%|██▌       | 4000/15588 [08:05<23:25,  8.25it/s]

Batch 4000/15588:
total_loss: 0.22378118336200714, mse: 0.0, ce: 0.22378118336200714


Epoch 3/10:  27%|██▋       | 4250/15588 [08:36<22:54,  8.25it/s]

Batch 4250/15588:
total_loss: 0.21109122037887573, mse: 0.0, ce: 0.21109122037887573


Epoch 3/10:  29%|██▉       | 4500/15588 [09:06<22:27,  8.23it/s]

Batch 4500/15588:
total_loss: 0.15503975749015808, mse: 0.0, ce: 0.15503975749015808


Epoch 3/10:  30%|███       | 4751/15588 [09:36<21:53,  8.25it/s]

Batch 4750/15588:
total_loss: 0.5536734461784363, mse: 0.0, ce: 0.5536734461784363


Epoch 3/10:  32%|███▏      | 5000/15588 [10:07<21:23,  8.25it/s]

Batch 5000/15588:
total_loss: 0.549224317073822, mse: 0.0, ce: 0.549224317073822


Epoch 3/10:  34%|███▎      | 5250/15588 [10:37<20:55,  8.24it/s]

Batch 5250/15588:
total_loss: 0.34015730023384094, mse: 0.0, ce: 0.34015730023384094


Epoch 3/10:  35%|███▌      | 5500/15588 [11:07<20:24,  8.24it/s]

Batch 5500/15588:
total_loss: 0.2198163866996765, mse: 0.0, ce: 0.2198163866996765


Epoch 3/10:  37%|███▋      | 5750/15588 [11:38<19:56,  8.23it/s]

Batch 5750/15588:
total_loss: 0.36983633041381836, mse: 0.0, ce: 0.36983633041381836


Epoch 3/10:  38%|███▊      | 6000/15588 [12:08<19:20,  8.26it/s]

Batch 6000/15588:
total_loss: 0.15589681267738342, mse: 0.0, ce: 0.15589681267738342


Epoch 3/10:  40%|████      | 6250/15588 [12:38<18:52,  8.25it/s]

Batch 6250/15588:
total_loss: 0.30454888939857483, mse: 0.0, ce: 0.30454888939857483


Epoch 3/10:  42%|████▏     | 6500/15588 [13:09<18:23,  8.24it/s]

Batch 6500/15588:
total_loss: 0.3414357900619507, mse: 0.0, ce: 0.3414357900619507


Epoch 3/10:  43%|████▎     | 6750/15588 [13:39<17:55,  8.22it/s]

Batch 6750/15588:
total_loss: 0.15388992428779602, mse: 0.0, ce: 0.15388992428779602


Epoch 3/10:  45%|████▍     | 7000/15588 [14:09<17:23,  8.23it/s]

Batch 7000/15588:
total_loss: 0.3057640790939331, mse: 0.0, ce: 0.3057640790939331


Epoch 3/10:  47%|████▋     | 7250/15588 [14:40<16:48,  8.27it/s]

Batch 7250/15588:
total_loss: 0.07837588340044022, mse: 0.0, ce: 0.07837588340044022


Epoch 3/10:  48%|████▊     | 7500/15588 [15:10<16:21,  8.24it/s]

Batch 7500/15588:
total_loss: 0.12536205351352692, mse: 0.0, ce: 0.12536205351352692


Epoch 3/10:  50%|████▉     | 7750/15588 [15:40<15:51,  8.24it/s]

Batch 7750/15588:
total_loss: 0.0497506707906723, mse: 0.0, ce: 0.0497506707906723


Epoch 3/10:  51%|█████▏    | 8000/15588 [16:11<15:19,  8.25it/s]

Batch 8000/15588:
total_loss: 0.3785448670387268, mse: 0.0, ce: 0.3785448670387268


Epoch 3/10:  53%|█████▎    | 8250/15588 [16:41<14:49,  8.25it/s]

Batch 8250/15588:
total_loss: 0.14627331495285034, mse: 0.0, ce: 0.14627331495285034


Epoch 3/10:  55%|█████▍    | 8500/15588 [17:12<14:19,  8.25it/s]

Batch 8500/15588:
total_loss: 0.4477422833442688, mse: 0.0, ce: 0.4477422833442688


Epoch 3/10:  56%|█████▌    | 8750/15588 [17:42<13:48,  8.26it/s]

Batch 8750/15588:
total_loss: 0.1666882336139679, mse: 0.0, ce: 0.1666882336139679


Epoch 3/10:  58%|█████▊    | 9000/15588 [18:12<13:20,  8.23it/s]

Batch 9000/15588:
total_loss: 0.262448787689209, mse: 0.0, ce: 0.262448787689209


Epoch 3/10:  59%|█████▉    | 9250/15588 [18:43<12:46,  8.27it/s]

Batch 9250/15588:
total_loss: 0.20753127336502075, mse: 0.0, ce: 0.20753127336502075


Epoch 3/10:  61%|██████    | 9500/15588 [19:13<12:19,  8.23it/s]

Batch 9500/15588:
total_loss: 0.22645677626132965, mse: 0.0, ce: 0.22645677626132965


Epoch 3/10:  63%|██████▎   | 9750/15588 [19:43<11:47,  8.26it/s]

Batch 9750/15588:
total_loss: 0.24897412955760956, mse: 0.0, ce: 0.24897412955760956


Epoch 3/10:  64%|██████▍   | 10000/15588 [20:14<11:16,  8.25it/s]

Batch 10000/15588:
total_loss: 0.21405445039272308, mse: 0.0, ce: 0.21405445039272308


Epoch 3/10:  66%|██████▌   | 10250/15588 [20:44<10:49,  8.22it/s]

Batch 10250/15588:
total_loss: 0.12688231468200684, mse: 0.0, ce: 0.12688231468200684


Epoch 3/10:  67%|██████▋   | 10500/15588 [21:14<10:15,  8.26it/s]

Batch 10500/15588:
total_loss: 0.4217899441719055, mse: 0.0, ce: 0.4217899441719055


Epoch 3/10:  69%|██████▉   | 10750/15588 [21:45<09:48,  8.23it/s]

Batch 10750/15588:
total_loss: 0.2524591088294983, mse: 0.0, ce: 0.2524591088294983


Epoch 3/10:  71%|███████   | 11000/15588 [22:15<09:14,  8.27it/s]

Batch 11000/15588:
total_loss: 0.2578107714653015, mse: 0.0, ce: 0.2578107714653015


Epoch 3/10:  72%|███████▏  | 11250/15588 [22:45<08:47,  8.22it/s]

Batch 11250/15588:
total_loss: 0.3202519118785858, mse: 0.0, ce: 0.3202519118785858


Epoch 3/10:  74%|███████▍  | 11500/15588 [23:16<08:16,  8.23it/s]

Batch 11500/15588:
total_loss: 0.06559130549430847, mse: 0.0, ce: 0.06559130549430847


Epoch 3/10:  75%|███████▌  | 11750/15588 [23:46<07:44,  8.26it/s]

Batch 11750/15588:
total_loss: 0.37863001227378845, mse: 0.0, ce: 0.37863001227378845


Epoch 3/10:  77%|███████▋  | 12000/15588 [24:16<07:15,  8.24it/s]

Batch 12000/15588:
total_loss: 0.2532024383544922, mse: 0.0, ce: 0.2532024383544922


Epoch 3/10:  79%|███████▊  | 12251/15588 [24:47<06:44,  8.26it/s]

Batch 12250/15588:
total_loss: 0.27205464243888855, mse: 0.0, ce: 0.27205464243888855


Epoch 3/10:  80%|████████  | 12500/15588 [25:17<06:15,  8.22it/s]

Batch 12500/15588:
total_loss: 0.20915110409259796, mse: 0.0, ce: 0.20915110409259796


Epoch 3/10:  82%|████████▏ | 12750/15588 [25:47<05:44,  8.24it/s]

Batch 12750/15588:
total_loss: 0.205319344997406, mse: 0.0, ce: 0.205319344997406


Epoch 3/10:  83%|████████▎ | 13000/15588 [26:18<05:14,  8.24it/s]

Batch 13000/15588:
total_loss: 0.2486499398946762, mse: 0.0, ce: 0.2486499398946762


Epoch 3/10:  85%|████████▌ | 13250/15588 [26:48<04:43,  8.24it/s]

Batch 13250/15588:
total_loss: 0.17746135592460632, mse: 0.0, ce: 0.17746135592460632


Epoch 3/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.25it/s]

Batch 13500/15588:
total_loss: 0.222633495926857, mse: 0.0, ce: 0.222633495926857


Epoch 3/10:  88%|████████▊ | 13751/15588 [27:49<03:42,  8.27it/s]

Batch 13750/15588:
total_loss: 0.197670578956604, mse: 0.0, ce: 0.197670578956604


Epoch 3/10:  90%|████████▉ | 14000/15588 [28:19<03:13,  8.22it/s]

Batch 14000/15588:
total_loss: 0.24738986790180206, mse: 0.0, ce: 0.24738986790180206


Epoch 3/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.24it/s]

Batch 14250/15588:
total_loss: 0.21810182929039001, mse: 0.0, ce: 0.21810182929039001


Epoch 3/10:  93%|█████████▎| 14500/15588 [29:20<02:12,  8.22it/s]

Batch 14500/15588:
total_loss: 0.29398053884506226, mse: 0.0, ce: 0.29398053884506226


Epoch 3/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.23it/s]

Batch 14750/15588:
total_loss: 0.2504832446575165, mse: 0.0, ce: 0.2504832446575165


Epoch 3/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.22it/s]

Batch 15000/15588:
total_loss: 0.2786954343318939, mse: 0.0, ce: 0.2786954343318939


Epoch 3/10:  98%|█████████▊| 15250/15588 [30:51<00:40,  8.26it/s]

Batch 15250/15588:
total_loss: 0.3827971816062927, mse: 0.0, ce: 0.3827971816062927


Epoch 3/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.27it/s]

Batch 15500/15588:
total_loss: 0.1445256918668747, mse: 0.0, ce: 0.1445256918668747


Epoch 3/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.63it/s]



Epoch Summary:
Train Total Loss: 0.2686 (MSE: 0.0000, CE: 0.2686)
Val Total Loss: 0.2018 (MSE: 0.0000, CE: 0.2018)
Learning Rate: 0.000100
New best model saved with val loss: 0.2018
--------------------------------------------------


Epoch 4/10:   2%|▏         | 250/15588 [00:30<30:56,  8.26it/s]

Batch 250/15588:
total_loss: 0.3889603912830353, mse: 0.0, ce: 0.3889603912830353


Epoch 4/10:   3%|▎         | 500/15588 [01:00<30:32,  8.23it/s]

Batch 500/15588:
total_loss: 0.3460238575935364, mse: 0.0, ce: 0.3460238575935364


Epoch 4/10:   5%|▍         | 750/15588 [01:31<30:01,  8.24it/s]

Batch 750/15588:
total_loss: 0.4700982868671417, mse: 0.0, ce: 0.4700982868671417


Epoch 4/10:   6%|▋         | 1000/15588 [02:01<29:31,  8.23it/s]

Batch 1000/15588:
total_loss: 0.2931751310825348, mse: 0.0, ce: 0.2931751310825348


Epoch 4/10:   8%|▊         | 1250/15588 [02:31<28:59,  8.24it/s]

Batch 1250/15588:
total_loss: 0.18928605318069458, mse: 0.0, ce: 0.18928605318069458


Epoch 4/10:  10%|▉         | 1500/15588 [03:02<28:31,  8.23it/s]

Batch 1500/15588:
total_loss: 0.1917988657951355, mse: 0.0, ce: 0.1917988657951355


Epoch 4/10:  11%|█         | 1750/15588 [03:32<28:00,  8.24it/s]

Batch 1750/15588:
total_loss: 0.22296875715255737, mse: 0.0, ce: 0.22296875715255737


Epoch 4/10:  13%|█▎        | 2000/15588 [04:02<27:23,  8.27it/s]

Batch 2000/15588:
total_loss: 0.2468649297952652, mse: 0.0, ce: 0.2468649297952652


Epoch 4/10:  14%|█▍        | 2250/15588 [04:33<26:58,  8.24it/s]

Batch 2250/15588:
total_loss: 0.19842743873596191, mse: 0.0, ce: 0.19842743873596191


Epoch 4/10:  16%|█▌        | 2500/15588 [05:03<26:30,  8.23it/s]

Batch 2500/15588:
total_loss: 0.11994954943656921, mse: 0.0, ce: 0.11994954943656921


Epoch 4/10:  18%|█▊        | 2750/15588 [05:33<25:51,  8.27it/s]

Batch 2750/15588:
total_loss: 0.14405034482479095, mse: 0.0, ce: 0.14405034482479095


Epoch 4/10:  19%|█▉        | 3000/15588 [06:04<25:25,  8.25it/s]

Batch 3000/15588:
total_loss: 0.27227821946144104, mse: 0.0, ce: 0.27227821946144104


Epoch 4/10:  21%|██        | 3250/15588 [06:34<24:53,  8.26it/s]

Batch 3250/15588:
total_loss: 0.3872939646244049, mse: 0.0, ce: 0.3872939646244049


Epoch 4/10:  22%|██▏       | 3500/15588 [07:05<24:29,  8.23it/s]

Batch 3500/15588:
total_loss: 0.2241605669260025, mse: 0.0, ce: 0.2241605669260025


Epoch 4/10:  24%|██▍       | 3750/15588 [07:35<23:54,  8.25it/s]

Batch 3750/15588:
total_loss: 0.03877301141619682, mse: 0.0, ce: 0.03877301141619682


Epoch 4/10:  26%|██▌       | 4000/15588 [08:05<23:24,  8.25it/s]

Batch 4000/15588:
total_loss: 0.23482418060302734, mse: 0.0, ce: 0.23482418060302734


Epoch 4/10:  27%|██▋       | 4250/15588 [08:36<22:55,  8.24it/s]

Batch 4250/15588:
total_loss: 0.4877718687057495, mse: 0.0, ce: 0.4877718687057495


Epoch 4/10:  29%|██▉       | 4500/15588 [09:06<22:23,  8.25it/s]

Batch 4500/15588:
total_loss: 0.1875505894422531, mse: 0.0, ce: 0.1875505894422531


Epoch 4/10:  30%|███       | 4750/15588 [09:36<21:49,  8.28it/s]

Batch 4750/15588:
total_loss: 0.07290147244930267, mse: 0.0, ce: 0.07290147244930267


Epoch 4/10:  32%|███▏      | 5000/15588 [10:07<21:26,  8.23it/s]

Batch 5000/15588:
total_loss: 0.08707255125045776, mse: 0.0, ce: 0.08707255125045776


Epoch 4/10:  34%|███▎      | 5250/15588 [10:37<20:55,  8.24it/s]

Batch 5250/15588:
total_loss: 0.03996926173567772, mse: 0.0, ce: 0.03996926173567772


Epoch 4/10:  35%|███▌      | 5500/15588 [11:07<20:24,  8.24it/s]

Batch 5500/15588:
total_loss: 0.1577468067407608, mse: 0.0, ce: 0.1577468067407608


Epoch 4/10:  37%|███▋      | 5750/15588 [11:38<19:52,  8.25it/s]

Batch 5750/15588:
total_loss: 0.2522791028022766, mse: 0.0, ce: 0.2522791028022766


Epoch 4/10:  38%|███▊      | 6001/15588 [12:08<19:21,  8.26it/s]

Batch 6000/15588:
total_loss: 0.2746586203575134, mse: 0.0, ce: 0.2746586203575134


Epoch 4/10:  40%|████      | 6250/15588 [12:38<18:53,  8.24it/s]

Batch 6250/15588:
total_loss: 0.10927180200815201, mse: 0.0, ce: 0.10927180200815201


Epoch 4/10:  42%|████▏     | 6500/15588 [13:09<18:19,  8.27it/s]

Batch 6500/15588:
total_loss: 0.20106495916843414, mse: 0.0, ce: 0.20106495916843414


Epoch 4/10:  43%|████▎     | 6750/15588 [13:39<17:52,  8.24it/s]

Batch 6750/15588:
total_loss: 0.20816053450107574, mse: 0.0, ce: 0.20816053450107574


Epoch 4/10:  45%|████▍     | 7000/15588 [14:09<17:21,  8.25it/s]

Batch 7000/15588:
total_loss: 0.16743586957454681, mse: 0.0, ce: 0.16743586957454681


Epoch 4/10:  47%|████▋     | 7250/15588 [14:40<16:53,  8.22it/s]

Batch 7250/15588:
total_loss: 0.06563535332679749, mse: 0.0, ce: 0.06563535332679749


Epoch 4/10:  48%|████▊     | 7500/15588 [15:10<16:20,  8.25it/s]

Batch 7500/15588:
total_loss: 0.3258880078792572, mse: 0.0, ce: 0.3258880078792572


Epoch 4/10:  50%|████▉     | 7750/15588 [15:41<15:49,  8.25it/s]

Batch 7750/15588:
total_loss: 0.04789073392748833, mse: 0.0, ce: 0.04789073392748833


Epoch 4/10:  51%|█████▏    | 8000/15588 [16:11<15:18,  8.26it/s]

Batch 8000/15588:
total_loss: 0.12063005566596985, mse: 0.0, ce: 0.12063005566596985


Epoch 4/10:  53%|█████▎    | 8250/15588 [16:41<14:49,  8.25it/s]

Batch 8250/15588:
total_loss: 0.25087037682533264, mse: 0.0, ce: 0.25087037682533264


Epoch 4/10:  55%|█████▍    | 8500/15588 [17:12<14:21,  8.23it/s]

Batch 8500/15588:
total_loss: 0.12312360852956772, mse: 0.0, ce: 0.12312360852956772


Epoch 4/10:  56%|█████▌    | 8750/15588 [17:42<13:51,  8.23it/s]

Batch 8750/15588:
total_loss: 0.07189787924289703, mse: 0.0, ce: 0.07189787924289703


Epoch 4/10:  58%|█████▊    | 9000/15588 [18:12<13:20,  8.23it/s]

Batch 9000/15588:
total_loss: 0.20378342270851135, mse: 0.0, ce: 0.20378342270851135


Epoch 4/10:  59%|█████▉    | 9250/15588 [18:43<12:46,  8.27it/s]

Batch 9250/15588:
total_loss: 0.047598280012607574, mse: 0.0, ce: 0.047598280012607574


Epoch 4/10:  61%|██████    | 9500/15588 [19:13<12:21,  8.21it/s]

Batch 9500/15588:
total_loss: 0.34985265135765076, mse: 0.0, ce: 0.34985265135765076


Epoch 4/10:  63%|██████▎   | 9750/15588 [19:43<11:46,  8.26it/s]

Batch 9750/15588:
total_loss: 0.11294972151517868, mse: 0.0, ce: 0.11294972151517868


Epoch 4/10:  64%|██████▍   | 10000/15588 [20:14<11:18,  8.24it/s]

Batch 10000/15588:
total_loss: 0.408340722322464, mse: 0.0, ce: 0.408340722322464


Epoch 4/10:  66%|██████▌   | 10250/15588 [20:44<10:48,  8.24it/s]

Batch 10250/15588:
total_loss: 0.09324496239423752, mse: 0.0, ce: 0.09324496239423752


Epoch 4/10:  67%|██████▋   | 10500/15588 [21:14<10:17,  8.24it/s]

Batch 10500/15588:
total_loss: 0.028179334476590157, mse: 0.0, ce: 0.028179334476590157


Epoch 4/10:  69%|██████▉   | 10750/15588 [21:45<09:45,  8.27it/s]

Batch 10750/15588:
total_loss: 0.23758737742900848, mse: 0.0, ce: 0.23758737742900848


Epoch 4/10:  71%|███████   | 11000/15588 [22:15<09:16,  8.25it/s]

Batch 11000/15588:
total_loss: 0.09659792482852936, mse: 0.0, ce: 0.09659792482852936


Epoch 4/10:  72%|███████▏  | 11250/15588 [22:45<08:47,  8.22it/s]

Batch 11250/15588:
total_loss: 0.15948650240898132, mse: 0.0, ce: 0.15948650240898132


Epoch 4/10:  74%|███████▍  | 11500/15588 [23:16<08:15,  8.25it/s]

Batch 11500/15588:
total_loss: 0.4617890417575836, mse: 0.0, ce: 0.4617890417575836


Epoch 4/10:  75%|███████▌  | 11750/15588 [23:46<07:46,  8.22it/s]

Batch 11750/15588:
total_loss: 0.11941453069448471, mse: 0.0, ce: 0.11941453069448471


Epoch 4/10:  77%|███████▋  | 12000/15588 [24:17<07:13,  8.27it/s]

Batch 12000/15588:
total_loss: 0.19055132567882538, mse: 0.0, ce: 0.19055132567882538


Epoch 4/10:  79%|███████▊  | 12251/15588 [24:47<06:44,  8.25it/s]

Batch 12250/15588:
total_loss: 0.3543267548084259, mse: 0.0, ce: 0.3543267548084259


Epoch 4/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.24it/s]

Batch 12500/15588:
total_loss: 0.0505891777575016, mse: 0.0, ce: 0.0505891777575016


Epoch 4/10:  82%|████████▏ | 12750/15588 [25:48<05:44,  8.25it/s]

Batch 12750/15588:
total_loss: 0.2392888069152832, mse: 0.0, ce: 0.2392888069152832


Epoch 4/10:  83%|████████▎ | 13000/15588 [26:18<05:13,  8.25it/s]

Batch 13000/15588:
total_loss: 0.2343469113111496, mse: 0.0, ce: 0.2343469113111496


Epoch 4/10:  85%|████████▌ | 13250/15588 [26:48<04:42,  8.27it/s]

Batch 13250/15588:
total_loss: 0.18414267897605896, mse: 0.0, ce: 0.18414267897605896


Epoch 4/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.25it/s]

Batch 13500/15588:
total_loss: 0.01973840221762657, mse: 0.0, ce: 0.01973840221762657


Epoch 4/10:  88%|████████▊ | 13750/15588 [27:49<03:43,  8.23it/s]

Batch 13750/15588:
total_loss: 0.3015168011188507, mse: 0.0, ce: 0.3015168011188507


Epoch 4/10:  90%|████████▉ | 14000/15588 [28:19<03:12,  8.24it/s]

Batch 14000/15588:
total_loss: 0.19427722692489624, mse: 0.0, ce: 0.19427722692489624


Epoch 4/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.25it/s]

Batch 14250/15588:
total_loss: 0.08341950923204422, mse: 0.0, ce: 0.08341950923204422


Epoch 4/10:  93%|█████████▎| 14500/15588 [29:20<02:12,  8.23it/s]

Batch 14500/15588:
total_loss: 0.45834410190582275, mse: 0.0, ce: 0.45834410190582275


Epoch 4/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.24it/s]

Batch 14750/15588:
total_loss: 0.5176121592521667, mse: 0.0, ce: 0.5176121592521667


Epoch 4/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.23it/s]

Batch 15000/15588:
total_loss: 0.11532868444919586, mse: 0.0, ce: 0.11532868444919586


Epoch 4/10:  98%|█████████▊| 15250/15588 [30:51<00:41,  8.24it/s]

Batch 15250/15588:
total_loss: 0.24642740190029144, mse: 0.0, ce: 0.24642740190029144


Epoch 4/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.24it/s]

Batch 15500/15588:
total_loss: 0.2542201578617096, mse: 0.0, ce: 0.2542201578617096


Epoch 4/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.62it/s]



Epoch Summary:
Train Total Loss: 0.2216 (MSE: 0.0000, CE: 0.2216)
Val Total Loss: 0.1997 (MSE: 0.0000, CE: 0.1997)
Learning Rate: 0.000100
New best model saved with val loss: 0.1997
--------------------------------------------------


Epoch 5/10:   2%|▏         | 250/15588 [00:30<30:55,  8.27it/s]

Batch 250/15588:
total_loss: 0.20123280584812164, mse: 0.0, ce: 0.20123280584812164


Epoch 5/10:   3%|▎         | 500/15588 [01:00<30:31,  8.24it/s]

Batch 500/15588:
total_loss: 0.1671476811170578, mse: 0.0, ce: 0.1671476811170578


Epoch 5/10:   5%|▍         | 750/15588 [01:31<29:55,  8.26it/s]

Batch 750/15588:
total_loss: 0.1958959400653839, mse: 0.0, ce: 0.1958959400653839


Epoch 5/10:   6%|▋         | 1000/15588 [02:01<29:30,  8.24it/s]

Batch 1000/15588:
total_loss: 0.24457257986068726, mse: 0.0, ce: 0.24457257986068726


Epoch 5/10:   8%|▊         | 1250/15588 [02:31<29:04,  8.22it/s]

Batch 1250/15588:
total_loss: 0.07026302069425583, mse: 0.0, ce: 0.07026302069425583


Epoch 5/10:  10%|▉         | 1500/15588 [03:02<28:32,  8.23it/s]

Batch 1500/15588:
total_loss: 0.04322227090597153, mse: 0.0, ce: 0.04322227090597153


Epoch 5/10:  11%|█         | 1750/15588 [03:32<27:55,  8.26it/s]

Batch 1750/15588:
total_loss: 0.4476970136165619, mse: 0.0, ce: 0.4476970136165619


Epoch 5/10:  13%|█▎        | 2000/15588 [04:02<27:26,  8.25it/s]

Batch 2000/15588:
total_loss: 0.14291530847549438, mse: 0.0, ce: 0.14291530847549438


Epoch 5/10:  14%|█▍        | 2251/15588 [04:33<26:57,  8.24it/s]

Batch 2250/15588:
total_loss: 0.16892245411872864, mse: 0.0, ce: 0.16892245411872864


Epoch 5/10:  16%|█▌        | 2500/15588 [05:03<26:26,  8.25it/s]

Batch 2500/15588:
total_loss: 0.15362532436847687, mse: 0.0, ce: 0.15362532436847687


Epoch 5/10:  18%|█▊        | 2750/15588 [05:33<25:57,  8.24it/s]

Batch 2750/15588:
total_loss: 0.2556341886520386, mse: 0.0, ce: 0.2556341886520386


Epoch 5/10:  19%|█▉        | 3001/15588 [06:04<25:18,  8.29it/s]

Batch 3000/15588:
total_loss: 0.25719285011291504, mse: 0.0, ce: 0.25719285011291504


Epoch 5/10:  21%|██        | 3250/15588 [06:34<24:59,  8.23it/s]

Batch 3250/15588:
total_loss: 0.36239907145500183, mse: 0.0, ce: 0.36239907145500183


Epoch 5/10:  22%|██▏       | 3500/15588 [07:05<24:27,  8.24it/s]

Batch 3500/15588:
total_loss: 0.25277623534202576, mse: 0.0, ce: 0.25277623534202576


Epoch 5/10:  24%|██▍       | 3750/15588 [07:35<23:56,  8.24it/s]

Batch 3750/15588:
total_loss: 0.24690216779708862, mse: 0.0, ce: 0.24690216779708862


Epoch 5/10:  26%|██▌       | 4000/15588 [08:05<23:26,  8.24it/s]

Batch 4000/15588:
total_loss: 0.11566781997680664, mse: 0.0, ce: 0.11566781997680664


Epoch 5/10:  27%|██▋       | 4250/15588 [08:36<22:50,  8.27it/s]

Batch 4250/15588:
total_loss: 0.13184168934822083, mse: 0.0, ce: 0.13184168934822083


Epoch 5/10:  29%|██▉       | 4500/15588 [09:06<22:29,  8.22it/s]

Batch 4500/15588:
total_loss: 0.08172334730625153, mse: 0.0, ce: 0.08172334730625153


Epoch 5/10:  30%|███       | 4750/15588 [09:36<21:48,  8.28it/s]

Batch 4750/15588:
total_loss: 0.19501271843910217, mse: 0.0, ce: 0.19501271843910217


Epoch 5/10:  32%|███▏      | 5000/15588 [10:07<21:27,  8.22it/s]

Batch 5000/15588:
total_loss: 0.0371040403842926, mse: 0.0, ce: 0.0371040403842926


Epoch 5/10:  34%|███▎      | 5250/15588 [10:37<20:51,  8.26it/s]

Batch 5250/15588:
total_loss: 0.3036743104457855, mse: 0.0, ce: 0.3036743104457855


Epoch 5/10:  35%|███▌      | 5500/15588 [11:07<20:23,  8.24it/s]

Batch 5500/15588:
total_loss: 0.29324236512184143, mse: 0.0, ce: 0.29324236512184143


Epoch 5/10:  37%|███▋      | 5750/15588 [11:38<19:49,  8.27it/s]

Batch 5750/15588:
total_loss: 0.3525503873825073, mse: 0.0, ce: 0.3525503873825073


Epoch 5/10:  38%|███▊      | 6000/15588 [12:08<19:22,  8.24it/s]

Batch 6000/15588:
total_loss: 0.09730493277311325, mse: 0.0, ce: 0.09730493277311325


Epoch 5/10:  40%|████      | 6250/15588 [12:38<18:53,  8.24it/s]

Batch 6250/15588:
total_loss: 0.19467078149318695, mse: 0.0, ce: 0.19467078149318695


Epoch 5/10:  42%|████▏     | 6500/15588 [13:09<18:18,  8.27it/s]

Batch 6500/15588:
total_loss: 0.31765031814575195, mse: 0.0, ce: 0.31765031814575195


Epoch 5/10:  43%|████▎     | 6750/15588 [13:39<17:54,  8.22it/s]

Batch 6750/15588:
total_loss: 0.3614760637283325, mse: 0.0, ce: 0.3614760637283325


Epoch 5/10:  45%|████▍     | 7000/15588 [14:09<17:21,  8.24it/s]

Batch 7000/15588:
total_loss: 0.16372711956501007, mse: 0.0, ce: 0.16372711956501007


Epoch 5/10:  47%|████▋     | 7250/15588 [14:40<16:51,  8.24it/s]

Batch 7250/15588:
total_loss: 0.1316092163324356, mse: 0.0, ce: 0.1316092163324356


Epoch 5/10:  48%|████▊     | 7500/15588 [15:10<16:20,  8.25it/s]

Batch 7500/15588:
total_loss: 0.07760056853294373, mse: 0.0, ce: 0.07760056853294373


Epoch 5/10:  50%|████▉     | 7750/15588 [15:40<15:52,  8.23it/s]

Batch 7750/15588:
total_loss: 0.23633967339992523, mse: 0.0, ce: 0.23633967339992523


Epoch 5/10:  51%|█████▏    | 8000/15588 [16:11<15:19,  8.25it/s]

Batch 8000/15588:
total_loss: 0.2289106696844101, mse: 0.0, ce: 0.2289106696844101


Epoch 5/10:  53%|█████▎    | 8250/15588 [16:41<14:50,  8.24it/s]

Batch 8250/15588:
total_loss: 0.2210722714662552, mse: 0.0, ce: 0.2210722714662552


Epoch 5/10:  55%|█████▍    | 8500/15588 [17:12<14:20,  8.23it/s]

Batch 8500/15588:
total_loss: 0.29010239243507385, mse: 0.0, ce: 0.29010239243507385


Epoch 5/10:  56%|█████▌    | 8750/15588 [17:42<13:48,  8.26it/s]

Batch 8750/15588:
total_loss: 0.2302921563386917, mse: 0.0, ce: 0.2302921563386917


Epoch 5/10:  58%|█████▊    | 9000/15588 [18:12<13:19,  8.24it/s]

Batch 9000/15588:
total_loss: 0.38258126378059387, mse: 0.0, ce: 0.38258126378059387


Epoch 5/10:  59%|█████▉    | 9250/15588 [18:43<12:48,  8.25it/s]

Batch 9250/15588:
total_loss: 0.14231209456920624, mse: 0.0, ce: 0.14231209456920624


Epoch 5/10:  61%|██████    | 9500/15588 [19:13<12:18,  8.24it/s]

Batch 9500/15588:
total_loss: 0.2189856916666031, mse: 0.0, ce: 0.2189856916666031


Epoch 5/10:  63%|██████▎   | 9750/15588 [19:43<11:48,  8.24it/s]

Batch 9750/15588:
total_loss: 0.10095132887363434, mse: 0.0, ce: 0.10095132887363434


Epoch 5/10:  64%|██████▍   | 10000/15588 [20:14<11:17,  8.25it/s]

Batch 10000/15588:
total_loss: 0.13643504679203033, mse: 0.0, ce: 0.13643504679203033


Epoch 5/10:  66%|██████▌   | 10250/15588 [20:44<10:47,  8.25it/s]

Batch 10250/15588:
total_loss: 0.09427856653928757, mse: 0.0, ce: 0.09427856653928757


Epoch 5/10:  67%|██████▋   | 10500/15588 [21:14<10:19,  8.21it/s]

Batch 10500/15588:
total_loss: 0.14369472861289978, mse: 0.0, ce: 0.14369472861289978


Epoch 5/10:  69%|██████▉   | 10750/15588 [21:45<09:46,  8.25it/s]

Batch 10750/15588:
total_loss: 0.1988755315542221, mse: 0.0, ce: 0.1988755315542221


Epoch 5/10:  71%|███████   | 11000/15588 [22:15<09:16,  8.24it/s]

Batch 11000/15588:
total_loss: 0.06145906820893288, mse: 0.0, ce: 0.06145906820893288


Epoch 5/10:  72%|███████▏  | 11250/15588 [22:45<08:45,  8.26it/s]

Batch 11250/15588:
total_loss: 0.25224658846855164, mse: 0.0, ce: 0.25224658846855164


Epoch 5/10:  74%|███████▍  | 11500/15588 [23:16<08:16,  8.24it/s]

Batch 11500/15588:
total_loss: 0.09349138289690018, mse: 0.0, ce: 0.09349138289690018


Epoch 5/10:  75%|███████▌  | 11750/15588 [23:46<07:45,  8.24it/s]

Batch 11750/15588:
total_loss: 0.08913377672433853, mse: 0.0, ce: 0.08913377672433853


Epoch 5/10:  77%|███████▋  | 12000/15588 [24:16<07:14,  8.25it/s]

Batch 12000/15588:
total_loss: 0.07758984714746475, mse: 0.0, ce: 0.07758984714746475


Epoch 5/10:  79%|███████▊  | 12250/15588 [24:47<06:44,  8.24it/s]

Batch 12250/15588:
total_loss: 0.23482690751552582, mse: 0.0, ce: 0.23482690751552582


Epoch 5/10:  80%|████████  | 12500/15588 [25:17<06:15,  8.21it/s]

Batch 12500/15588:
total_loss: 0.2546263635158539, mse: 0.0, ce: 0.2546263635158539


Epoch 5/10:  82%|████████▏ | 12750/15588 [25:48<05:44,  8.24it/s]

Batch 12750/15588:
total_loss: 0.2933141589164734, mse: 0.0, ce: 0.2933141589164734


Epoch 5/10:  83%|████████▎ | 13000/15588 [26:18<05:13,  8.26it/s]

Batch 13000/15588:
total_loss: 0.06405265629291534, mse: 0.0, ce: 0.06405265629291534


Epoch 5/10:  85%|████████▌ | 13250/15588 [26:48<04:44,  8.22it/s]

Batch 13250/15588:
total_loss: 0.22650659084320068, mse: 0.0, ce: 0.22650659084320068


Epoch 5/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.24it/s]

Batch 13500/15588:
total_loss: 0.13694417476654053, mse: 0.0, ce: 0.13694417476654053


Epoch 5/10:  88%|████████▊ | 13750/15588 [27:49<03:43,  8.21it/s]

Batch 13750/15588:
total_loss: 0.11670918762683868, mse: 0.0, ce: 0.11670918762683868


Epoch 5/10:  90%|████████▉ | 14000/15588 [28:19<03:12,  8.24it/s]

Batch 14000/15588:
total_loss: 0.14379246532917023, mse: 0.0, ce: 0.14379246532917023


Epoch 5/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.24it/s]

Batch 14250/15588:
total_loss: 0.07978928089141846, mse: 0.0, ce: 0.07978928089141846


Epoch 5/10:  93%|█████████▎| 14500/15588 [29:20<02:11,  8.25it/s]

Batch 14500/15588:
total_loss: 0.17728523910045624, mse: 0.0, ce: 0.17728523910045624


Epoch 5/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.25it/s]

Batch 14750/15588:
total_loss: 0.12449350953102112, mse: 0.0, ce: 0.12449350953102112


Epoch 5/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.22it/s]

Batch 15000/15588:
total_loss: 0.05578715726733208, mse: 0.0, ce: 0.05578715726733208


Epoch 5/10:  98%|█████████▊| 15250/15588 [30:51<00:41,  8.24it/s]

Batch 15250/15588:
total_loss: 0.2997813820838928, mse: 0.0, ce: 0.2997813820838928


Epoch 5/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.23it/s]

Batch 15500/15588:
total_loss: 0.15891395509243011, mse: 0.0, ce: 0.15891395509243011


Epoch 5/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.75it/s]



Epoch Summary:
Train Total Loss: 0.1847 (MSE: 0.0000, CE: 0.1847)
Val Total Loss: 0.1551 (MSE: 0.0000, CE: 0.1551)
Learning Rate: 0.000100
New best model saved with val loss: 0.1551
--------------------------------------------------


Epoch 6/10:   2%|▏         | 250/15588 [00:30<31:01,  8.24it/s]

Batch 250/15588:
total_loss: 0.09874825924634933, mse: 0.0, ce: 0.09874825924634933


Epoch 6/10:   3%|▎         | 500/15588 [01:00<30:32,  8.23it/s]

Batch 500/15588:
total_loss: 0.17904706299304962, mse: 0.0, ce: 0.17904706299304962


Epoch 6/10:   5%|▍         | 750/15588 [01:31<29:58,  8.25it/s]

Batch 750/15588:
total_loss: 0.10651156306266785, mse: 0.0, ce: 0.10651156306266785


Epoch 6/10:   6%|▋         | 1000/15588 [02:01<29:29,  8.25it/s]

Batch 1000/15588:
total_loss: 0.13684919476509094, mse: 0.0, ce: 0.13684919476509094


Epoch 6/10:   8%|▊         | 1250/15588 [02:31<29:00,  8.24it/s]

Batch 1250/15588:
total_loss: 0.07599452883005142, mse: 0.0, ce: 0.07599452883005142


Epoch 6/10:  10%|▉         | 1500/15588 [03:02<28:30,  8.24it/s]

Batch 1500/15588:
total_loss: 0.11385690420866013, mse: 0.0, ce: 0.11385690420866013


Epoch 6/10:  11%|█         | 1750/15588 [03:32<27:59,  8.24it/s]

Batch 1750/15588:
total_loss: 0.24309511482715607, mse: 0.0, ce: 0.24309511482715607


Epoch 6/10:  13%|█▎        | 2000/15588 [04:02<27:24,  8.27it/s]

Batch 2000/15588:
total_loss: 0.02286950871348381, mse: 0.0, ce: 0.02286950871348381


Epoch 6/10:  14%|█▍        | 2250/15588 [04:33<27:05,  8.21it/s]

Batch 2250/15588:
total_loss: 0.09122760593891144, mse: 0.0, ce: 0.09122760593891144


Epoch 6/10:  16%|█▌        | 2500/15588 [05:03<26:26,  8.25it/s]

Batch 2500/15588:
total_loss: 0.2976413667201996, mse: 0.0, ce: 0.2976413667201996


Epoch 6/10:  18%|█▊        | 2750/15588 [05:33<26:00,  8.23it/s]

Batch 2750/15588:
total_loss: 0.2967345416545868, mse: 0.0, ce: 0.2967345416545868


Epoch 6/10:  19%|█▉        | 3000/15588 [06:04<25:27,  8.24it/s]

Batch 3000/15588:
total_loss: 0.09353595227003098, mse: 0.0, ce: 0.09353595227003098


Epoch 6/10:  21%|██        | 3250/15588 [06:34<25:00,  8.23it/s]

Batch 3250/15588:
total_loss: 0.13353954255580902, mse: 0.0, ce: 0.13353954255580902


Epoch 6/10:  22%|██▏       | 3500/15588 [07:05<24:26,  8.24it/s]

Batch 3500/15588:
total_loss: 0.1731114685535431, mse: 0.0, ce: 0.1731114685535431


Epoch 6/10:  24%|██▍       | 3750/15588 [07:35<23:52,  8.26it/s]

Batch 3750/15588:
total_loss: 0.22143664956092834, mse: 0.0, ce: 0.22143664956092834


Epoch 6/10:  26%|██▌       | 4000/15588 [08:05<23:25,  8.24it/s]

Batch 4000/15588:
total_loss: 0.21830567717552185, mse: 0.0, ce: 0.21830567717552185


Epoch 6/10:  27%|██▋       | 4250/15588 [08:36<22:55,  8.24it/s]

Batch 4250/15588:
total_loss: 0.2050786167383194, mse: 0.0, ce: 0.2050786167383194


Epoch 6/10:  29%|██▉       | 4500/15588 [09:06<22:27,  8.23it/s]

Batch 4500/15588:
total_loss: 0.09812697023153305, mse: 0.0, ce: 0.09812697023153305


Epoch 6/10:  30%|███       | 4750/15588 [09:36<21:54,  8.25it/s]

Batch 4750/15588:
total_loss: 0.13499361276626587, mse: 0.0, ce: 0.13499361276626587


Epoch 6/10:  32%|███▏      | 5000/15588 [10:07<21:30,  8.20it/s]

Batch 5000/15588:
total_loss: 0.132181316614151, mse: 0.0, ce: 0.132181316614151


Epoch 6/10:  34%|███▎      | 5250/15588 [10:37<20:54,  8.24it/s]

Batch 5250/15588:
total_loss: 0.3049050569534302, mse: 0.0, ce: 0.3049050569534302


Epoch 6/10:  35%|███▌      | 5500/15588 [11:07<20:27,  8.22it/s]

Batch 5500/15588:
total_loss: 0.17935508489608765, mse: 0.0, ce: 0.17935508489608765


Epoch 6/10:  37%|███▋      | 5750/15588 [11:38<19:53,  8.24it/s]

Batch 5750/15588:
total_loss: 0.15789149701595306, mse: 0.0, ce: 0.15789149701595306


Epoch 6/10:  38%|███▊      | 6000/15588 [12:08<19:23,  8.24it/s]

Batch 6000/15588:
total_loss: 0.299396276473999, mse: 0.0, ce: 0.299396276473999


Epoch 6/10:  40%|████      | 6250/15588 [12:38<18:55,  8.23it/s]

Batch 6250/15588:
total_loss: 0.26859527826309204, mse: 0.0, ce: 0.26859527826309204


Epoch 6/10:  42%|████▏     | 6500/15588 [13:09<18:22,  8.24it/s]

Batch 6500/15588:
total_loss: 0.06730251759290695, mse: 0.0, ce: 0.06730251759290695


Epoch 6/10:  43%|████▎     | 6751/15588 [13:39<17:45,  8.30it/s]

Batch 6750/15588:
total_loss: 0.049404699355363846, mse: 0.0, ce: 0.049404699355363846


Epoch 6/10:  45%|████▍     | 7000/15588 [14:09<17:18,  8.27it/s]

Batch 7000/15588:
total_loss: 0.045266043394804, mse: 0.0, ce: 0.045266043394804


Epoch 6/10:  47%|████▋     | 7250/15588 [14:40<16:54,  8.22it/s]

Batch 7250/15588:
total_loss: 0.0486094169318676, mse: 0.0, ce: 0.0486094169318676


Epoch 6/10:  48%|████▊     | 7500/15588 [15:10<16:20,  8.25it/s]

Batch 7500/15588:
total_loss: 0.06450741738080978, mse: 0.0, ce: 0.06450741738080978


Epoch 6/10:  50%|████▉     | 7750/15588 [15:40<15:53,  8.22it/s]

Batch 7750/15588:
total_loss: 0.13104815781116486, mse: 0.0, ce: 0.13104815781116486


Epoch 6/10:  51%|█████▏    | 8000/15588 [16:11<15:19,  8.25it/s]

Batch 8000/15588:
total_loss: 0.09191123396158218, mse: 0.0, ce: 0.09191123396158218


Epoch 6/10:  53%|█████▎    | 8250/15588 [16:41<14:49,  8.25it/s]

Batch 8250/15588:
total_loss: 0.3290449380874634, mse: 0.0, ce: 0.3290449380874634


Epoch 6/10:  55%|█████▍    | 8500/15588 [17:12<14:19,  8.25it/s]

Batch 8500/15588:
total_loss: 0.0773661732673645, mse: 0.0, ce: 0.0773661732673645


Epoch 6/10:  56%|█████▌    | 8750/15588 [17:42<13:51,  8.23it/s]

Batch 8750/15588:
total_loss: 0.2009262889623642, mse: 0.0, ce: 0.2009262889623642


Epoch 6/10:  58%|█████▊    | 9000/15588 [18:12<13:19,  8.24it/s]

Batch 9000/15588:
total_loss: 0.1875857412815094, mse: 0.0, ce: 0.1875857412815094


Epoch 6/10:  59%|█████▉    | 9250/15588 [18:43<12:49,  8.24it/s]

Batch 9250/15588:
total_loss: 0.19020068645477295, mse: 0.0, ce: 0.19020068645477295


Epoch 6/10:  61%|██████    | 9500/15588 [19:13<12:20,  8.22it/s]

Batch 9500/15588:
total_loss: 0.12187269330024719, mse: 0.0, ce: 0.12187269330024719


Epoch 6/10:  63%|██████▎   | 9750/15588 [19:43<11:47,  8.25it/s]

Batch 9750/15588:
total_loss: 0.15200099349021912, mse: 0.0, ce: 0.15200099349021912


Epoch 6/10:  64%|██████▍   | 10000/15588 [20:14<11:17,  8.25it/s]

Batch 10000/15588:
total_loss: 0.2034289836883545, mse: 0.0, ce: 0.2034289836883545


Epoch 6/10:  66%|██████▌   | 10250/15588 [20:44<10:45,  8.26it/s]

Batch 10250/15588:
total_loss: 0.10442487895488739, mse: 0.0, ce: 0.10442487895488739


Epoch 6/10:  67%|██████▋   | 10500/15588 [21:14<10:19,  8.21it/s]

Batch 10500/15588:
total_loss: 0.2205609828233719, mse: 0.0, ce: 0.2205609828233719


Epoch 6/10:  69%|██████▉   | 10750/15588 [21:45<09:45,  8.26it/s]

Batch 10750/15588:
total_loss: 0.2441788613796234, mse: 0.0, ce: 0.2441788613796234


Epoch 6/10:  71%|███████   | 11000/15588 [22:15<09:15,  8.26it/s]

Batch 11000/15588:
total_loss: 0.2994532883167267, mse: 0.0, ce: 0.2994532883167267


Epoch 6/10:  72%|███████▏  | 11250/15588 [22:45<08:47,  8.23it/s]

Batch 11250/15588:
total_loss: 0.1510723978281021, mse: 0.0, ce: 0.1510723978281021


Epoch 6/10:  74%|███████▍  | 11500/15588 [23:16<08:14,  8.26it/s]

Batch 11500/15588:
total_loss: 0.14043587446212769, mse: 0.0, ce: 0.14043587446212769


Epoch 6/10:  75%|███████▌  | 11750/15588 [23:46<07:46,  8.23it/s]

Batch 11750/15588:
total_loss: 0.12142268568277359, mse: 0.0, ce: 0.12142268568277359


Epoch 6/10:  77%|███████▋  | 12000/15588 [24:16<07:15,  8.25it/s]

Batch 12000/15588:
total_loss: 0.11692774295806885, mse: 0.0, ce: 0.11692774295806885


Epoch 6/10:  79%|███████▊  | 12250/15588 [24:47<06:45,  8.23it/s]

Batch 12250/15588:
total_loss: 0.08585529774427414, mse: 0.0, ce: 0.08585529774427414


Epoch 6/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.25it/s]

Batch 12500/15588:
total_loss: 0.05440047010779381, mse: 0.0, ce: 0.05440047010779381


Epoch 6/10:  82%|████████▏ | 12750/15588 [25:47<05:44,  8.23it/s]

Batch 12750/15588:
total_loss: 0.13763219118118286, mse: 0.0, ce: 0.13763219118118286


Epoch 6/10:  83%|████████▎ | 13001/15588 [26:18<05:11,  8.30it/s]

Batch 13000/15588:
total_loss: 0.16456103324890137, mse: 0.0, ce: 0.16456103324890137


Epoch 6/10:  85%|████████▌ | 13250/15588 [26:48<04:44,  8.22it/s]

Batch 13250/15588:
total_loss: 0.05938032269477844, mse: 0.0, ce: 0.05938032269477844


Epoch 6/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.23it/s]

Batch 13500/15588:
total_loss: 0.07642203569412231, mse: 0.0, ce: 0.07642203569412231


Epoch 6/10:  88%|████████▊ | 13751/15588 [27:49<03:41,  8.27it/s]

Batch 13750/15588:
total_loss: 0.09913943707942963, mse: 0.0, ce: 0.09913943707942963


Epoch 6/10:  90%|████████▉ | 14001/15588 [28:19<03:11,  8.28it/s]

Batch 14000/15588:
total_loss: 0.18682190775871277, mse: 0.0, ce: 0.18682190775871277


Epoch 6/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.23it/s]

Batch 14250/15588:
total_loss: 0.2829393446445465, mse: 0.0, ce: 0.2829393446445465


Epoch 6/10:  93%|█████████▎| 14500/15588 [29:20<02:12,  8.23it/s]

Batch 14500/15588:
total_loss: 0.3116958737373352, mse: 0.0, ce: 0.3116958737373352


Epoch 6/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.23it/s]

Batch 14750/15588:
total_loss: 0.2790102958679199, mse: 0.0, ce: 0.2790102958679199


Epoch 6/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.23it/s]

Batch 15000/15588:
total_loss: 0.0853685513138771, mse: 0.0, ce: 0.0853685513138771


Epoch 6/10:  98%|█████████▊| 15250/15588 [30:51<00:40,  8.25it/s]

Batch 15250/15588:
total_loss: 0.07803672552108765, mse: 0.0, ce: 0.07803672552108765


Epoch 6/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.22it/s]

Batch 15500/15588:
total_loss: 0.03559144586324692, mse: 0.0, ce: 0.03559144586324692


Epoch 6/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 24.01it/s]



Epoch Summary:
Train Total Loss: 0.1612 (MSE: 0.0000, CE: 0.1612)
Val Total Loss: 0.1499 (MSE: 0.0000, CE: 0.1499)
Learning Rate: 0.000100
New best model saved with val loss: 0.1499
--------------------------------------------------


Epoch 7/10:   2%|▏         | 250/15588 [00:30<31:04,  8.23it/s]

Batch 250/15588:
total_loss: 0.13912150263786316, mse: 0.0, ce: 0.13912150263786316


Epoch 7/10:   3%|▎         | 500/15588 [01:00<30:29,  8.25it/s]

Batch 500/15588:
total_loss: 0.21339312195777893, mse: 0.0, ce: 0.21339312195777893


Epoch 7/10:   5%|▍         | 750/15588 [01:31<29:58,  8.25it/s]

Batch 750/15588:
total_loss: 0.2637976109981537, mse: 0.0, ce: 0.2637976109981537


Epoch 7/10:   6%|▋         | 1000/15588 [02:01<29:32,  8.23it/s]

Batch 1000/15588:
total_loss: 0.05732693895697594, mse: 0.0, ce: 0.05732693895697594


Epoch 7/10:   8%|▊         | 1250/15588 [02:31<29:05,  8.22it/s]

Batch 1250/15588:
total_loss: 0.24883583188056946, mse: 0.0, ce: 0.24883583188056946


Epoch 7/10:  10%|▉         | 1500/15588 [03:02<28:25,  8.26it/s]

Batch 1500/15588:
total_loss: 0.06387753784656525, mse: 0.0, ce: 0.06387753784656525


Epoch 7/10:  11%|█         | 1750/15588 [03:32<28:04,  8.22it/s]

Batch 1750/15588:
total_loss: 0.04967360571026802, mse: 0.0, ce: 0.04967360571026802


Epoch 7/10:  13%|█▎        | 2000/15588 [04:02<27:28,  8.24it/s]

Batch 2000/15588:
total_loss: 0.11452826857566833, mse: 0.0, ce: 0.11452826857566833


Epoch 7/10:  14%|█▍        | 2250/15588 [04:33<26:58,  8.24it/s]

Batch 2250/15588:
total_loss: 0.07390515506267548, mse: 0.0, ce: 0.07390515506267548


Epoch 7/10:  16%|█▌        | 2500/15588 [05:03<26:28,  8.24it/s]

Batch 2500/15588:
total_loss: 0.10235855728387833, mse: 0.0, ce: 0.10235855728387833


Epoch 7/10:  18%|█▊        | 2750/15588 [05:33<25:58,  8.24it/s]

Batch 2750/15588:
total_loss: 0.4392741620540619, mse: 0.0, ce: 0.4392741620540619


Epoch 7/10:  19%|█▉        | 3000/15588 [06:04<25:26,  8.25it/s]

Batch 3000/15588:
total_loss: 0.20658008754253387, mse: 0.0, ce: 0.20658008754253387


Epoch 7/10:  21%|██        | 3250/15588 [06:34<25:00,  8.22it/s]

Batch 3250/15588:
total_loss: 0.07376160472631454, mse: 0.0, ce: 0.07376160472631454


Epoch 7/10:  22%|██▏       | 3500/15588 [07:04<24:21,  8.27it/s]

Batch 3500/15588:
total_loss: 0.3088436424732208, mse: 0.0, ce: 0.3088436424732208


Epoch 7/10:  24%|██▍       | 3750/15588 [07:35<24:00,  8.22it/s]

Batch 3750/15588:
total_loss: 0.2748408913612366, mse: 0.0, ce: 0.2748408913612366


Epoch 7/10:  26%|██▌       | 4000/15588 [08:05<23:30,  8.22it/s]

Batch 4000/15588:
total_loss: 0.0800398662686348, mse: 0.0, ce: 0.0800398662686348


Epoch 7/10:  27%|██▋       | 4250/15588 [08:36<22:52,  8.26it/s]

Batch 4250/15588:
total_loss: 0.15360985696315765, mse: 0.0, ce: 0.15360985696315765


Epoch 7/10:  29%|██▉       | 4500/15588 [09:06<22:26,  8.24it/s]

Batch 4500/15588:
total_loss: 0.12067419290542603, mse: 0.0, ce: 0.12067419290542603


Epoch 7/10:  30%|███       | 4750/15588 [09:36<21:55,  8.24it/s]

Batch 4750/15588:
total_loss: 0.2003580778837204, mse: 0.0, ce: 0.2003580778837204


Epoch 7/10:  32%|███▏      | 5000/15588 [10:07<21:23,  8.25it/s]

Batch 5000/15588:
total_loss: 0.2688574492931366, mse: 0.0, ce: 0.2688574492931366


Epoch 7/10:  34%|███▎      | 5250/15588 [10:37<20:55,  8.23it/s]

Batch 5250/15588:
total_loss: 0.15283353626728058, mse: 0.0, ce: 0.15283353626728058


Epoch 7/10:  35%|███▌      | 5500/15588 [11:07<20:23,  8.24it/s]

Batch 5500/15588:
total_loss: 0.20162761211395264, mse: 0.0, ce: 0.20162761211395264


Epoch 7/10:  37%|███▋      | 5750/15588 [11:38<19:56,  8.22it/s]

Batch 5750/15588:
total_loss: 0.24747037887573242, mse: 0.0, ce: 0.24747037887573242


Epoch 7/10:  38%|███▊      | 6000/15588 [12:08<19:23,  8.24it/s]

Batch 6000/15588:
total_loss: 0.13708658516407013, mse: 0.0, ce: 0.13708658516407013


Epoch 7/10:  40%|████      | 6250/15588 [12:38<18:51,  8.25it/s]

Batch 6250/15588:
total_loss: 0.06532704830169678, mse: 0.0, ce: 0.06532704830169678


Epoch 7/10:  42%|████▏     | 6500/15588 [13:09<18:18,  8.27it/s]

Batch 6500/15588:
total_loss: 0.3585783839225769, mse: 0.0, ce: 0.3585783839225769


Epoch 7/10:  43%|████▎     | 6750/15588 [13:39<17:47,  8.28it/s]

Batch 6750/15588:
total_loss: 0.13600938022136688, mse: 0.0, ce: 0.13600938022136688


Epoch 7/10:  45%|████▍     | 7000/15588 [14:09<17:22,  8.24it/s]

Batch 7000/15588:
total_loss: 0.14080050587654114, mse: 0.0, ce: 0.14080050587654114


Epoch 7/10:  47%|████▋     | 7250/15588 [14:40<16:53,  8.22it/s]

Batch 7250/15588:
total_loss: 0.0437176376581192, mse: 0.0, ce: 0.0437176376581192


Epoch 7/10:  48%|████▊     | 7500/15588 [15:10<16:22,  8.24it/s]

Batch 7500/15588:
total_loss: 0.12836097180843353, mse: 0.0, ce: 0.12836097180843353


Epoch 7/10:  50%|████▉     | 7750/15588 [15:41<15:52,  8.23it/s]

Batch 7750/15588:
total_loss: 0.15953408181667328, mse: 0.0, ce: 0.15953408181667328


Epoch 7/10:  51%|█████▏    | 8000/15588 [16:11<15:22,  8.22it/s]

Batch 8000/15588:
total_loss: 0.10101885348558426, mse: 0.0, ce: 0.10101885348558426


Epoch 7/10:  53%|█████▎    | 8250/15588 [16:41<14:52,  8.22it/s]

Batch 8250/15588:
total_loss: 0.288495272397995, mse: 0.0, ce: 0.288495272397995


Epoch 7/10:  55%|█████▍    | 8500/15588 [17:12<14:20,  8.24it/s]

Batch 8500/15588:
total_loss: 0.1390305906534195, mse: 0.0, ce: 0.1390305906534195


Epoch 7/10:  56%|█████▌    | 8750/15588 [17:42<13:49,  8.24it/s]

Batch 8750/15588:
total_loss: 0.12732850015163422, mse: 0.0, ce: 0.12732850015163422


Epoch 7/10:  58%|█████▊    | 9000/15588 [18:12<13:17,  8.26it/s]

Batch 9000/15588:
total_loss: 0.19725549221038818, mse: 0.0, ce: 0.19725549221038818


Epoch 7/10:  59%|█████▉    | 9250/15588 [18:43<12:48,  8.25it/s]

Batch 9250/15588:
total_loss: 0.08230698853731155, mse: 0.0, ce: 0.08230698853731155


Epoch 7/10:  61%|██████    | 9500/15588 [19:13<12:19,  8.23it/s]

Batch 9500/15588:
total_loss: 0.0532044842839241, mse: 0.0, ce: 0.0532044842839241


Epoch 7/10:  63%|██████▎   | 9750/15588 [19:43<11:47,  8.25it/s]

Batch 9750/15588:
total_loss: 0.20594026148319244, mse: 0.0, ce: 0.20594026148319244


Epoch 7/10:  64%|██████▍   | 10000/15588 [20:14<11:20,  8.21it/s]

Batch 10000/15588:
total_loss: 0.280170738697052, mse: 0.0, ce: 0.280170738697052


Epoch 7/10:  66%|██████▌   | 10250/15588 [20:44<10:48,  8.23it/s]

Batch 10250/15588:
total_loss: 0.18405012786388397, mse: 0.0, ce: 0.18405012786388397


Epoch 7/10:  67%|██████▋   | 10500/15588 [21:14<10:19,  8.22it/s]

Batch 10500/15588:
total_loss: 0.1404106169939041, mse: 0.0, ce: 0.1404106169939041


Epoch 7/10:  69%|██████▉   | 10750/15588 [21:45<09:45,  8.26it/s]

Batch 10750/15588:
total_loss: 0.20245787501335144, mse: 0.0, ce: 0.20245787501335144


Epoch 7/10:  71%|███████   | 11000/15588 [22:15<09:16,  8.24it/s]

Batch 11000/15588:
total_loss: 0.25143203139305115, mse: 0.0, ce: 0.25143203139305115


Epoch 7/10:  72%|███████▏  | 11250/15588 [22:45<08:46,  8.24it/s]

Batch 11250/15588:
total_loss: 0.04246583580970764, mse: 0.0, ce: 0.04246583580970764


Epoch 7/10:  74%|███████▍  | 11500/15588 [23:16<08:15,  8.25it/s]

Batch 11500/15588:
total_loss: 0.09223375469446182, mse: 0.0, ce: 0.09223375469446182


Epoch 7/10:  75%|███████▌  | 11750/15588 [23:46<07:44,  8.26it/s]

Batch 11750/15588:
total_loss: 0.20669668912887573, mse: 0.0, ce: 0.20669668912887573


Epoch 7/10:  77%|███████▋  | 12000/15588 [24:17<07:14,  8.25it/s]

Batch 12000/15588:
total_loss: 0.040816403925418854, mse: 0.0, ce: 0.040816403925418854


Epoch 7/10:  79%|███████▊  | 12250/15588 [24:47<06:46,  8.21it/s]

Batch 12250/15588:
total_loss: 0.23671740293502808, mse: 0.0, ce: 0.23671740293502808


Epoch 7/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.26it/s]

Batch 12500/15588:
total_loss: 0.33660000562667847, mse: 0.0, ce: 0.33660000562667847


Epoch 7/10:  82%|████████▏ | 12750/15588 [25:48<05:44,  8.23it/s]

Batch 12750/15588:
total_loss: 0.27951690554618835, mse: 0.0, ce: 0.27951690554618835


Epoch 7/10:  83%|████████▎ | 13000/15588 [26:18<05:13,  8.25it/s]

Batch 13000/15588:
total_loss: 0.19035346806049347, mse: 0.0, ce: 0.19035346806049347


Epoch 7/10:  85%|████████▌ | 13250/15588 [26:48<04:44,  8.21it/s]

Batch 13250/15588:
total_loss: 0.07151027768850327, mse: 0.0, ce: 0.07151027768850327


Epoch 7/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.25it/s]

Batch 13500/15588:
total_loss: 0.11522039771080017, mse: 0.0, ce: 0.11522039771080017


Epoch 7/10:  88%|████████▊ | 13750/15588 [27:49<03:42,  8.25it/s]

Batch 13750/15588:
total_loss: 0.14732354879379272, mse: 0.0, ce: 0.14732354879379272


Epoch 7/10:  90%|████████▉ | 14000/15588 [28:19<03:12,  8.25it/s]

Batch 14000/15588:
total_loss: 0.10303767025470734, mse: 0.0, ce: 0.10303767025470734


Epoch 7/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.24it/s]

Batch 14250/15588:
total_loss: 0.11843124777078629, mse: 0.0, ce: 0.11843124777078629


Epoch 7/10:  93%|█████████▎| 14500/15588 [29:20<02:11,  8.27it/s]

Batch 14500/15588:
total_loss: 0.16730670630931854, mse: 0.0, ce: 0.16730670630931854


Epoch 7/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.25it/s]

Batch 14750/15588:
total_loss: 0.10223925113677979, mse: 0.0, ce: 0.10223925113677979


Epoch 7/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.22it/s]

Batch 15000/15588:
total_loss: 0.14153045415878296, mse: 0.0, ce: 0.14153045415878296


Epoch 7/10:  98%|█████████▊| 15250/15588 [30:51<00:40,  8.26it/s]

Batch 15250/15588:
total_loss: 0.043373819440603256, mse: 0.0, ce: 0.043373819440603256


Epoch 7/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.22it/s]

Batch 15500/15588:
total_loss: 0.05334024876356125, mse: 0.0, ce: 0.05334024876356125


Epoch 7/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.53it/s]



Epoch Summary:
Train Total Loss: 0.1409 (MSE: 0.0000, CE: 0.1409)
Val Total Loss: 0.1301 (MSE: 0.0000, CE: 0.1301)
Learning Rate: 0.000100
New best model saved with val loss: 0.1301
--------------------------------------------------


Epoch 8/10:   2%|▏         | 250/15588 [00:30<31:00,  8.24it/s]

Batch 250/15588:
total_loss: 0.10616699606180191, mse: 0.0, ce: 0.10616699606180191


Epoch 8/10:   3%|▎         | 500/15588 [01:00<30:26,  8.26it/s]

Batch 500/15588:
total_loss: 0.07173410803079605, mse: 0.0, ce: 0.07173410803079605


Epoch 8/10:   5%|▍         | 750/15588 [01:31<29:56,  8.26it/s]

Batch 750/15588:
total_loss: 0.16073660552501678, mse: 0.0, ce: 0.16073660552501678


Epoch 8/10:   6%|▋         | 1000/15588 [02:01<29:30,  8.24it/s]

Batch 1000/15588:
total_loss: 0.22041632235050201, mse: 0.0, ce: 0.22041632235050201


Epoch 8/10:   8%|▊         | 1250/15588 [02:31<28:54,  8.27it/s]

Batch 1250/15588:
total_loss: 0.2117094248533249, mse: 0.0, ce: 0.2117094248533249


Epoch 8/10:  10%|▉         | 1500/15588 [03:02<28:27,  8.25it/s]

Batch 1500/15588:
total_loss: 0.1165105476975441, mse: 0.0, ce: 0.1165105476975441


Epoch 8/10:  11%|█         | 1750/15588 [03:32<27:58,  8.25it/s]

Batch 1750/15588:
total_loss: 0.23790428042411804, mse: 0.0, ce: 0.23790428042411804


Epoch 8/10:  13%|█▎        | 2000/15588 [04:02<27:31,  8.23it/s]

Batch 2000/15588:
total_loss: 0.13584616780281067, mse: 0.0, ce: 0.13584616780281067


Epoch 8/10:  14%|█▍        | 2250/15588 [04:33<26:56,  8.25it/s]

Batch 2250/15588:
total_loss: 0.07103794813156128, mse: 0.0, ce: 0.07103794813156128


Epoch 8/10:  16%|█▌        | 2500/15588 [05:03<26:28,  8.24it/s]

Batch 2500/15588:
total_loss: 0.1712118536233902, mse: 0.0, ce: 0.1712118536233902


Epoch 8/10:  18%|█▊        | 2750/15588 [05:33<25:58,  8.24it/s]

Batch 2750/15588:
total_loss: 0.2527581751346588, mse: 0.0, ce: 0.2527581751346588


Epoch 8/10:  19%|█▉        | 3000/15588 [06:04<25:21,  8.27it/s]

Batch 3000/15588:
total_loss: 0.08628404140472412, mse: 0.0, ce: 0.08628404140472412


Epoch 8/10:  21%|██        | 3250/15588 [06:34<24:59,  8.23it/s]

Batch 3250/15588:
total_loss: 0.18868856132030487, mse: 0.0, ce: 0.18868856132030487


Epoch 8/10:  22%|██▏       | 3500/15588 [07:04<24:25,  8.25it/s]

Batch 3500/15588:
total_loss: 0.07336528599262238, mse: 0.0, ce: 0.07336528599262238


Epoch 8/10:  24%|██▍       | 3750/15588 [07:35<23:57,  8.23it/s]

Batch 3750/15588:
total_loss: 0.020006239414215088, mse: 0.0, ce: 0.020006239414215088


Epoch 8/10:  26%|██▌       | 4000/15588 [08:05<23:20,  8.27it/s]

Batch 4000/15588:
total_loss: 0.19124653935432434, mse: 0.0, ce: 0.19124653935432434


Epoch 8/10:  27%|██▋       | 4250/15588 [08:36<23:01,  8.21it/s]

Batch 4250/15588:
total_loss: 0.17008456587791443, mse: 0.0, ce: 0.17008456587791443


Epoch 8/10:  29%|██▉       | 4500/15588 [09:06<22:21,  8.27it/s]

Batch 4500/15588:
total_loss: 0.07797999680042267, mse: 0.0, ce: 0.07797999680042267


Epoch 8/10:  30%|███       | 4750/15588 [09:36<21:56,  8.23it/s]

Batch 4750/15588:
total_loss: 0.22340725362300873, mse: 0.0, ce: 0.22340725362300873


Epoch 8/10:  32%|███▏      | 5000/15588 [10:07<21:24,  8.25it/s]

Batch 5000/15588:
total_loss: 0.23756073415279388, mse: 0.0, ce: 0.23756073415279388


Epoch 8/10:  34%|███▎      | 5250/15588 [10:37<20:52,  8.26it/s]

Batch 5250/15588:
total_loss: 0.07748397439718246, mse: 0.0, ce: 0.07748397439718246


Epoch 8/10:  35%|███▌      | 5500/15588 [11:07<20:26,  8.23it/s]

Batch 5500/15588:
total_loss: 0.15975110232830048, mse: 0.0, ce: 0.15975110232830048


Epoch 8/10:  37%|███▋      | 5750/15588 [11:38<19:54,  8.23it/s]

Batch 5750/15588:
total_loss: 0.042520590126514435, mse: 0.0, ce: 0.042520590126514435


Epoch 8/10:  38%|███▊      | 6000/15588 [12:08<19:23,  8.24it/s]

Batch 6000/15588:
total_loss: 0.17316409945487976, mse: 0.0, ce: 0.17316409945487976


Epoch 8/10:  40%|████      | 6250/15588 [12:38<18:50,  8.26it/s]

Batch 6250/15588:
total_loss: 0.015447590500116348, mse: 0.0, ce: 0.015447590500116348


Epoch 8/10:  42%|████▏     | 6500/15588 [13:09<18:24,  8.23it/s]

Batch 6500/15588:
total_loss: 0.05608300119638443, mse: 0.0, ce: 0.05608300119638443


Epoch 8/10:  43%|████▎     | 6750/15588 [13:39<17:47,  8.28it/s]

Batch 6750/15588:
total_loss: 0.046667955815792084, mse: 0.0, ce: 0.046667955815792084


Epoch 8/10:  45%|████▍     | 7000/15588 [14:09<17:20,  8.25it/s]

Batch 7000/15588:
total_loss: 0.04646437242627144, mse: 0.0, ce: 0.04646437242627144


Epoch 8/10:  47%|████▋     | 7250/15588 [14:40<16:47,  8.27it/s]

Batch 7250/15588:
total_loss: 0.13764911890029907, mse: 0.0, ce: 0.13764911890029907


Epoch 8/10:  48%|████▊     | 7500/15588 [15:10<16:19,  8.26it/s]

Batch 7500/15588:
total_loss: 0.08663894981145859, mse: 0.0, ce: 0.08663894981145859


Epoch 8/10:  50%|████▉     | 7750/15588 [15:41<15:50,  8.24it/s]

Batch 7750/15588:
total_loss: 0.15753261744976044, mse: 0.0, ce: 0.15753261744976044


Epoch 8/10:  51%|█████▏    | 8000/15588 [16:11<15:20,  8.25it/s]

Batch 8000/15588:
total_loss: 0.19884777069091797, mse: 0.0, ce: 0.19884777069091797


Epoch 8/10:  53%|█████▎    | 8250/15588 [16:41<14:53,  8.21it/s]

Batch 8250/15588:
total_loss: 0.1483343094587326, mse: 0.0, ce: 0.1483343094587326


Epoch 8/10:  55%|█████▍    | 8500/15588 [17:12<14:19,  8.25it/s]

Batch 8500/15588:
total_loss: 0.09888564050197601, mse: 0.0, ce: 0.09888564050197601


Epoch 8/10:  56%|█████▌    | 8750/15588 [17:42<13:53,  8.21it/s]

Batch 8750/15588:
total_loss: 0.18773473799228668, mse: 0.0, ce: 0.18773473799228668


Epoch 8/10:  58%|█████▊    | 9000/15588 [18:12<13:20,  8.23it/s]

Batch 9000/15588:
total_loss: 0.10374171286821365, mse: 0.0, ce: 0.10374171286821365


Epoch 8/10:  59%|█████▉    | 9250/15588 [18:43<12:48,  8.24it/s]

Batch 9250/15588:
total_loss: 0.22051702439785004, mse: 0.0, ce: 0.22051702439785004


Epoch 8/10:  61%|██████    | 9500/15588 [19:13<12:18,  8.24it/s]

Batch 9500/15588:
total_loss: 0.053524795919656754, mse: 0.0, ce: 0.053524795919656754


Epoch 8/10:  63%|██████▎   | 9750/15588 [19:43<11:48,  8.24it/s]

Batch 9750/15588:
total_loss: 0.06097733974456787, mse: 0.0, ce: 0.06097733974456787


Epoch 8/10:  64%|██████▍   | 10000/15588 [20:14<11:18,  8.23it/s]

Batch 10000/15588:
total_loss: 0.0457480251789093, mse: 0.0, ce: 0.0457480251789093


Epoch 8/10:  66%|██████▌   | 10250/15588 [20:44<10:46,  8.26it/s]

Batch 10250/15588:
total_loss: 0.17108194530010223, mse: 0.0, ce: 0.17108194530010223


Epoch 8/10:  67%|██████▋   | 10500/15588 [21:14<10:17,  8.23it/s]

Batch 10500/15588:
total_loss: 0.02239607274532318, mse: 0.0, ce: 0.02239607274532318


Epoch 8/10:  69%|██████▉   | 10750/15588 [21:45<09:47,  8.23it/s]

Batch 10750/15588:
total_loss: 0.07124931365251541, mse: 0.0, ce: 0.07124931365251541


Epoch 8/10:  71%|███████   | 11000/15588 [22:15<09:17,  8.23it/s]

Batch 11000/15588:
total_loss: 0.005366070196032524, mse: 0.0, ce: 0.005366070196032524


Epoch 8/10:  72%|███████▏  | 11250/15588 [22:45<08:46,  8.24it/s]

Batch 11250/15588:
total_loss: 0.1300736516714096, mse: 0.0, ce: 0.1300736516714096


Epoch 8/10:  74%|███████▍  | 11500/15588 [23:16<08:16,  8.23it/s]

Batch 11500/15588:
total_loss: 0.049248404800891876, mse: 0.0, ce: 0.049248404800891876


Epoch 8/10:  75%|███████▌  | 11750/15588 [23:46<07:46,  8.23it/s]

Batch 11750/15588:
total_loss: 0.03403165563941002, mse: 0.0, ce: 0.03403165563941002


Epoch 8/10:  77%|███████▋  | 12000/15588 [24:16<07:15,  8.23it/s]

Batch 12000/15588:
total_loss: 0.15983228385448456, mse: 0.0, ce: 0.15983228385448456


Epoch 8/10:  79%|███████▊  | 12250/15588 [24:47<06:45,  8.23it/s]

Batch 12250/15588:
total_loss: 0.03314942121505737, mse: 0.0, ce: 0.03314942121505737


Epoch 8/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.25it/s]

Batch 12500/15588:
total_loss: 0.32252785563468933, mse: 0.0, ce: 0.32252785563468933


Epoch 8/10:  82%|████████▏ | 12750/15588 [25:48<05:44,  8.24it/s]

Batch 12750/15588:
total_loss: 0.10410337150096893, mse: 0.0, ce: 0.10410337150096893


Epoch 8/10:  83%|████████▎ | 13000/15588 [26:18<05:14,  8.24it/s]

Batch 13000/15588:
total_loss: 0.236379012465477, mse: 0.0, ce: 0.236379012465477


Epoch 8/10:  85%|████████▌ | 13250/15588 [26:48<04:43,  8.24it/s]

Batch 13250/15588:
total_loss: 0.09494666010141373, mse: 0.0, ce: 0.09494666010141373


Epoch 8/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.24it/s]

Batch 13500/15588:
total_loss: 0.3348500430583954, mse: 0.0, ce: 0.3348500430583954


Epoch 8/10:  88%|████████▊ | 13750/15588 [27:49<03:43,  8.23it/s]

Batch 13750/15588:
total_loss: 0.11109260469675064, mse: 0.0, ce: 0.11109260469675064


Epoch 8/10:  90%|████████▉ | 14000/15588 [28:19<03:12,  8.24it/s]

Batch 14000/15588:
total_loss: 0.015565470792353153, mse: 0.0, ce: 0.015565470792353153


Epoch 8/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.24it/s]

Batch 14250/15588:
total_loss: 0.035620611160993576, mse: 0.0, ce: 0.035620611160993576


Epoch 8/10:  93%|█████████▎| 14500/15588 [29:20<02:11,  8.25it/s]

Batch 14500/15588:
total_loss: 0.030191969126462936, mse: 0.0, ce: 0.030191969126462936


Epoch 8/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.25it/s]

Batch 14750/15588:
total_loss: 0.05433792993426323, mse: 0.0, ce: 0.05433792993426323


Epoch 8/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.25it/s]

Batch 15000/15588:
total_loss: 0.03989415988326073, mse: 0.0, ce: 0.03989415988326073


Epoch 8/10:  98%|█████████▊| 15250/15588 [30:51<00:40,  8.27it/s]

Batch 15250/15588:
total_loss: 0.09615546464920044, mse: 0.0, ce: 0.09615546464920044


Epoch 8/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.23it/s]

Batch 15500/15588:
total_loss: 0.15520678460597992, mse: 0.0, ce: 0.15520678460597992


Epoch 8/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.05it/s]



Epoch Summary:
Train Total Loss: 0.1248 (MSE: 0.0000, CE: 0.1248)
Val Total Loss: 0.1119 (MSE: 0.0000, CE: 0.1119)
Learning Rate: 0.000100
New best model saved with val loss: 0.1119
--------------------------------------------------


Epoch 9/10:   2%|▏         | 250/15588 [00:30<30:59,  8.25it/s]

Batch 250/15588:
total_loss: 0.058215390890836716, mse: 0.0, ce: 0.058215390890836716


Epoch 9/10:   3%|▎         | 500/15588 [01:00<30:32,  8.23it/s]

Batch 500/15588:
total_loss: 0.05051074177026749, mse: 0.0, ce: 0.05051074177026749


Epoch 9/10:   5%|▍         | 750/15588 [01:31<30:00,  8.24it/s]

Batch 750/15588:
total_loss: 0.21863345801830292, mse: 0.0, ce: 0.21863345801830292


Epoch 9/10:   6%|▋         | 1000/15588 [02:01<29:32,  8.23it/s]

Batch 1000/15588:
total_loss: 0.33024293184280396, mse: 0.0, ce: 0.33024293184280396


Epoch 9/10:   8%|▊         | 1250/15588 [02:31<28:58,  8.25it/s]

Batch 1250/15588:
total_loss: 0.0773557722568512, mse: 0.0, ce: 0.0773557722568512


Epoch 9/10:  10%|▉         | 1500/15588 [03:02<28:34,  8.22it/s]

Batch 1500/15588:
total_loss: 0.10672532021999359, mse: 0.0, ce: 0.10672532021999359


Epoch 9/10:  11%|█         | 1750/15588 [03:32<28:00,  8.24it/s]

Batch 1750/15588:
total_loss: 0.19603119790554047, mse: 0.0, ce: 0.19603119790554047


Epoch 9/10:  13%|█▎        | 2000/15588 [04:02<27:33,  8.22it/s]

Batch 2000/15588:
total_loss: 0.04392052814364433, mse: 0.0, ce: 0.04392052814364433


Epoch 9/10:  14%|█▍        | 2250/15588 [04:33<26:58,  8.24it/s]

Batch 2250/15588:
total_loss: 0.18876543641090393, mse: 0.0, ce: 0.18876543641090393


Epoch 9/10:  16%|█▌        | 2500/15588 [05:03<26:26,  8.25it/s]

Batch 2500/15588:
total_loss: 0.023392245173454285, mse: 0.0, ce: 0.023392245173454285


Epoch 9/10:  18%|█▊        | 2751/15588 [05:34<25:51,  8.27it/s]

Batch 2750/15588:
total_loss: 0.18187089264392853, mse: 0.0, ce: 0.18187089264392853


Epoch 9/10:  19%|█▉        | 3000/15588 [06:04<25:22,  8.27it/s]

Batch 3000/15588:
total_loss: 0.059925246983766556, mse: 0.0, ce: 0.059925246983766556


Epoch 9/10:  21%|██        | 3250/15588 [06:34<25:01,  8.22it/s]

Batch 3250/15588:
total_loss: 0.1390763372182846, mse: 0.0, ce: 0.1390763372182846


Epoch 9/10:  22%|██▏       | 3500/15588 [07:05<24:27,  8.24it/s]

Batch 3500/15588:
total_loss: 0.24084973335266113, mse: 0.0, ce: 0.24084973335266113


Epoch 9/10:  24%|██▍       | 3750/15588 [07:35<23:58,  8.23it/s]

Batch 3750/15588:
total_loss: 0.15034864842891693, mse: 0.0, ce: 0.15034864842891693


Epoch 9/10:  26%|██▌       | 4000/15588 [08:05<23:25,  8.25it/s]

Batch 4000/15588:
total_loss: 0.09488628059625626, mse: 0.0, ce: 0.09488628059625626


Epoch 9/10:  27%|██▋       | 4250/15588 [08:36<22:53,  8.25it/s]

Batch 4250/15588:
total_loss: 0.062033746391534805, mse: 0.0, ce: 0.062033746391534805


Epoch 9/10:  29%|██▉       | 4500/15588 [09:06<22:26,  8.24it/s]

Batch 4500/15588:
total_loss: 0.2477474957704544, mse: 0.0, ce: 0.2477474957704544


Epoch 9/10:  30%|███       | 4750/15588 [09:36<21:55,  8.24it/s]

Batch 4750/15588:
total_loss: 0.10254382342100143, mse: 0.0, ce: 0.10254382342100143


Epoch 9/10:  32%|███▏      | 5000/15588 [10:07<21:26,  8.23it/s]

Batch 5000/15588:
total_loss: 0.20191943645477295, mse: 0.0, ce: 0.20191943645477295


Epoch 9/10:  34%|███▎      | 5250/15588 [10:37<20:52,  8.25it/s]

Batch 5250/15588:
total_loss: 0.15520764887332916, mse: 0.0, ce: 0.15520764887332916


Epoch 9/10:  35%|███▌      | 5500/15588 [11:07<20:25,  8.23it/s]

Batch 5500/15588:
total_loss: 0.015233068726956844, mse: 0.0, ce: 0.015233068726956844


Epoch 9/10:  37%|███▋      | 5750/15588 [11:38<19:53,  8.24it/s]

Batch 5750/15588:
total_loss: 0.07537487149238586, mse: 0.0, ce: 0.07537487149238586


Epoch 9/10:  38%|███▊      | 6000/15588 [12:08<19:27,  8.21it/s]

Batch 6000/15588:
total_loss: 0.012966791167855263, mse: 0.0, ce: 0.012966791167855263


Epoch 9/10:  40%|████      | 6250/15588 [12:38<18:54,  8.23it/s]

Batch 6250/15588:
total_loss: 0.19022947549819946, mse: 0.0, ce: 0.19022947549819946


Epoch 9/10:  42%|████▏     | 6500/15588 [13:09<18:23,  8.24it/s]

Batch 6500/15588:
total_loss: 0.11196618527173996, mse: 0.0, ce: 0.11196618527173996


Epoch 9/10:  43%|████▎     | 6750/15588 [13:39<17:51,  8.25it/s]

Batch 6750/15588:
total_loss: 0.054853279143571854, mse: 0.0, ce: 0.054853279143571854


Epoch 9/10:  45%|████▍     | 7000/15588 [14:09<17:23,  8.23it/s]

Batch 7000/15588:
total_loss: 0.04301941767334938, mse: 0.0, ce: 0.04301941767334938


Epoch 9/10:  47%|████▋     | 7250/15588 [14:40<16:52,  8.23it/s]

Batch 7250/15588:
total_loss: 0.08392602950334549, mse: 0.0, ce: 0.08392602950334549


Epoch 9/10:  48%|████▊     | 7500/15588 [15:10<16:19,  8.26it/s]

Batch 7500/15588:
total_loss: 0.01688063144683838, mse: 0.0, ce: 0.01688063144683838


Epoch 9/10:  50%|████▉     | 7750/15588 [15:41<15:53,  8.22it/s]

Batch 7750/15588:
total_loss: 0.3578774333000183, mse: 0.0, ce: 0.3578774333000183


Epoch 9/10:  51%|█████▏    | 8000/15588 [16:11<15:21,  8.23it/s]

Batch 8000/15588:
total_loss: 0.2047886848449707, mse: 0.0, ce: 0.2047886848449707


Epoch 9/10:  53%|█████▎    | 8250/15588 [16:41<14:52,  8.22it/s]

Batch 8250/15588:
total_loss: 0.261261522769928, mse: 0.0, ce: 0.261261522769928


Epoch 9/10:  55%|█████▍    | 8500/15588 [17:12<14:20,  8.23it/s]

Batch 8500/15588:
total_loss: 0.09394688159227371, mse: 0.0, ce: 0.09394688159227371


Epoch 9/10:  56%|█████▌    | 8750/15588 [17:42<13:49,  8.24it/s]

Batch 8750/15588:
total_loss: 0.04749446362257004, mse: 0.0, ce: 0.04749446362257004


Epoch 9/10:  58%|█████▊    | 9000/15588 [18:12<13:19,  8.24it/s]

Batch 9000/15588:
total_loss: 0.024622159078717232, mse: 0.0, ce: 0.024622159078717232


Epoch 9/10:  59%|█████▉    | 9250/15588 [18:43<12:51,  8.22it/s]

Batch 9250/15588:
total_loss: 0.12764766812324524, mse: 0.0, ce: 0.12764766812324524


Epoch 9/10:  61%|██████    | 9500/15588 [19:13<12:19,  8.23it/s]

Batch 9500/15588:
total_loss: 0.07061564177274704, mse: 0.0, ce: 0.07061564177274704


Epoch 9/10:  63%|██████▎   | 9750/15588 [19:43<11:48,  8.24it/s]

Batch 9750/15588:
total_loss: 0.23168371617794037, mse: 0.0, ce: 0.23168371617794037


Epoch 9/10:  64%|██████▍   | 10000/15588 [20:14<11:16,  8.26it/s]

Batch 10000/15588:
total_loss: 0.1593315452337265, mse: 0.0, ce: 0.1593315452337265


Epoch 9/10:  66%|██████▌   | 10250/15588 [20:44<10:46,  8.26it/s]

Batch 10250/15588:
total_loss: 0.05457799881696701, mse: 0.0, ce: 0.05457799881696701


Epoch 9/10:  67%|██████▋   | 10500/15588 [21:14<10:16,  8.26it/s]

Batch 10500/15588:
total_loss: 0.18084238469600677, mse: 0.0, ce: 0.18084238469600677


Epoch 9/10:  69%|██████▉   | 10750/15588 [21:45<09:47,  8.23it/s]

Batch 10750/15588:
total_loss: 0.378312349319458, mse: 0.0, ce: 0.378312349319458


Epoch 9/10:  71%|███████   | 11000/15588 [22:15<09:17,  8.23it/s]

Batch 11000/15588:
total_loss: 0.09904104471206665, mse: 0.0, ce: 0.09904104471206665


Epoch 9/10:  72%|███████▏  | 11250/15588 [22:45<08:46,  8.24it/s]

Batch 11250/15588:
total_loss: 0.08502790331840515, mse: 0.0, ce: 0.08502790331840515


Epoch 9/10:  74%|███████▍  | 11500/15588 [23:16<08:17,  8.22it/s]

Batch 11500/15588:
total_loss: 0.16557545959949493, mse: 0.0, ce: 0.16557545959949493


Epoch 9/10:  75%|███████▌  | 11750/15588 [23:46<07:45,  8.24it/s]

Batch 11750/15588:
total_loss: 0.10416305810213089, mse: 0.0, ce: 0.10416305810213089


Epoch 9/10:  77%|███████▋  | 12000/15588 [24:16<07:15,  8.23it/s]

Batch 12000/15588:
total_loss: 0.03764395788311958, mse: 0.0, ce: 0.03764395788311958


Epoch 9/10:  79%|███████▊  | 12250/15588 [24:47<06:44,  8.26it/s]

Batch 12250/15588:
total_loss: 0.18610049784183502, mse: 0.0, ce: 0.18610049784183502


Epoch 9/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.24it/s]

Batch 12500/15588:
total_loss: 0.09512792527675629, mse: 0.0, ce: 0.09512792527675629


Epoch 9/10:  82%|████████▏ | 12750/15588 [25:48<05:44,  8.24it/s]

Batch 12750/15588:
total_loss: 0.02527749165892601, mse: 0.0, ce: 0.02527749165892601


Epoch 9/10:  83%|████████▎ | 13000/15588 [26:18<05:14,  8.23it/s]

Batch 13000/15588:
total_loss: 0.12042059004306793, mse: 0.0, ce: 0.12042059004306793


Epoch 9/10:  85%|████████▌ | 13250/15588 [26:48<04:44,  8.23it/s]

Batch 13250/15588:
total_loss: 0.0672399252653122, mse: 0.0, ce: 0.0672399252653122


Epoch 9/10:  87%|████████▋ | 13500/15588 [27:19<04:13,  8.25it/s]

Batch 13500/15588:
total_loss: 0.04489794746041298, mse: 0.0, ce: 0.04489794746041298


Epoch 9/10:  88%|████████▊ | 13750/15588 [27:49<03:43,  8.22it/s]

Batch 13750/15588:
total_loss: 0.07077725976705551, mse: 0.0, ce: 0.07077725976705551


Epoch 9/10:  90%|████████▉ | 14000/15588 [28:19<03:12,  8.24it/s]

Batch 14000/15588:
total_loss: 0.02119268663227558, mse: 0.0, ce: 0.02119268663227558


Epoch 9/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.21it/s]

Batch 14250/15588:
total_loss: 0.23338715732097626, mse: 0.0, ce: 0.23338715732097626


Epoch 9/10:  93%|█████████▎| 14500/15588 [29:20<02:11,  8.26it/s]

Batch 14500/15588:
total_loss: 0.14102743566036224, mse: 0.0, ce: 0.14102743566036224


Epoch 9/10:  95%|█████████▍| 14750/15588 [29:50<01:41,  8.23it/s]

Batch 14750/15588:
total_loss: 0.0208909809589386, mse: 0.0, ce: 0.0208909809589386


Epoch 9/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.23it/s]

Batch 15000/15588:
total_loss: 0.1514698565006256, mse: 0.0, ce: 0.1514698565006256


Epoch 9/10:  98%|█████████▊| 15250/15588 [30:51<00:40,  8.26it/s]

Batch 15250/15588:
total_loss: 0.08399583399295807, mse: 0.0, ce: 0.08399583399295807


Epoch 9/10:  99%|█████████▉| 15500/15588 [31:21<00:10,  8.23it/s]

Batch 15500/15588:
total_loss: 0.05624689161777496, mse: 0.0, ce: 0.05624689161777496


Epoch 9/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.08it/s]



Epoch Summary:
Train Total Loss: 0.1103 (MSE: 0.0000, CE: 0.1103)
Val Total Loss: 0.1079 (MSE: 0.0000, CE: 0.1079)
Learning Rate: 0.000100
New best model saved with val loss: 0.1079
--------------------------------------------------


Epoch 10/10:   2%|▏         | 250/15588 [00:30<31:05,  8.22it/s]

Batch 250/15588:
total_loss: 0.07066758722066879, mse: 0.0, ce: 0.07066758722066879


Epoch 10/10:   3%|▎         | 500/15588 [01:00<30:30,  8.24it/s]

Batch 500/15588:
total_loss: 0.17320716381072998, mse: 0.0, ce: 0.17320716381072998


Epoch 10/10:   5%|▍         | 750/15588 [01:31<30:05,  8.22it/s]

Batch 750/15588:
total_loss: 0.11744408309459686, mse: 0.0, ce: 0.11744408309459686


Epoch 10/10:   6%|▋         | 1000/15588 [02:01<29:35,  8.22it/s]

Batch 1000/15588:
total_loss: 0.02544846385717392, mse: 0.0, ce: 0.02544846385717392


Epoch 10/10:   8%|▊         | 1250/15588 [02:31<28:54,  8.27it/s]

Batch 1250/15588:
total_loss: 0.019029628485441208, mse: 0.0, ce: 0.019029628485441208


Epoch 10/10:  10%|▉         | 1500/15588 [03:02<28:27,  8.25it/s]

Batch 1500/15588:
total_loss: 0.028219999745488167, mse: 0.0, ce: 0.028219999745488167


Epoch 10/10:  11%|█         | 1750/15588 [03:32<28:00,  8.23it/s]

Batch 1750/15588:
total_loss: 0.09160353988409042, mse: 0.0, ce: 0.09160353988409042


Epoch 10/10:  13%|█▎        | 2000/15588 [04:02<27:29,  8.24it/s]

Batch 2000/15588:
total_loss: 0.04552297666668892, mse: 0.0, ce: 0.04552297666668892


Epoch 10/10:  14%|█▍        | 2250/15588 [04:33<26:55,  8.25it/s]

Batch 2250/15588:
total_loss: 0.17320740222930908, mse: 0.0, ce: 0.17320740222930908


Epoch 10/10:  16%|█▌        | 2500/15588 [05:03<26:31,  8.22it/s]

Batch 2500/15588:
total_loss: 0.09104309976100922, mse: 0.0, ce: 0.09104309976100922


Epoch 10/10:  18%|█▊        | 2750/15588 [05:34<25:58,  8.24it/s]

Batch 2750/15588:
total_loss: 0.13549070060253143, mse: 0.0, ce: 0.13549070060253143


Epoch 10/10:  19%|█▉        | 3000/15588 [06:04<25:26,  8.25it/s]

Batch 3000/15588:
total_loss: 0.05381762981414795, mse: 0.0, ce: 0.05381762981414795


Epoch 10/10:  21%|██        | 3250/15588 [06:34<24:53,  8.26it/s]

Batch 3250/15588:
total_loss: 0.03329109773039818, mse: 0.0, ce: 0.03329109773039818


Epoch 10/10:  22%|██▏       | 3500/15588 [07:05<24:29,  8.23it/s]

Batch 3500/15588:
total_loss: 0.036476846784353256, mse: 0.0, ce: 0.036476846784353256


Epoch 10/10:  24%|██▍       | 3750/15588 [07:35<23:59,  8.22it/s]

Batch 3750/15588:
total_loss: 0.08951829373836517, mse: 0.0, ce: 0.08951829373836517


Epoch 10/10:  26%|██▌       | 4000/15588 [08:05<23:21,  8.27it/s]

Batch 4000/15588:
total_loss: 0.011405924335122108, mse: 0.0, ce: 0.011405924335122108


Epoch 10/10:  27%|██▋       | 4250/15588 [08:36<22:56,  8.24it/s]

Batch 4250/15588:
total_loss: 0.09005965292453766, mse: 0.0, ce: 0.09005965292453766


Epoch 10/10:  29%|██▉       | 4500/15588 [09:06<22:27,  8.23it/s]

Batch 4500/15588:
total_loss: 0.083316370844841, mse: 0.0, ce: 0.083316370844841


Epoch 10/10:  30%|███       | 4750/15588 [09:36<21:57,  8.23it/s]

Batch 4750/15588:
total_loss: 0.043776873499155045, mse: 0.0, ce: 0.043776873499155045


Epoch 10/10:  32%|███▏      | 5000/15588 [10:07<21:23,  8.25it/s]

Batch 5000/15588:
total_loss: 0.055209819227457047, mse: 0.0, ce: 0.055209819227457047


Epoch 10/10:  34%|███▎      | 5250/15588 [10:37<20:53,  8.25it/s]

Batch 5250/15588:
total_loss: 0.03326816111803055, mse: 0.0, ce: 0.03326816111803055


Epoch 10/10:  35%|███▌      | 5500/15588 [11:07<20:21,  8.26it/s]

Batch 5500/15588:
total_loss: 0.09923941642045975, mse: 0.0, ce: 0.09923941642045975


Epoch 10/10:  37%|███▋      | 5750/15588 [11:38<19:54,  8.24it/s]

Batch 5750/15588:
total_loss: 0.017390968278050423, mse: 0.0, ce: 0.017390968278050423


Epoch 10/10:  38%|███▊      | 6000/15588 [12:08<19:23,  8.24it/s]

Batch 6000/15588:
total_loss: 0.03628721460700035, mse: 0.0, ce: 0.03628721460700035


Epoch 10/10:  40%|████      | 6250/15588 [12:38<18:51,  8.25it/s]

Batch 6250/15588:
total_loss: 0.13591822981834412, mse: 0.0, ce: 0.13591822981834412


Epoch 10/10:  42%|████▏     | 6500/15588 [13:09<18:24,  8.23it/s]

Batch 6500/15588:
total_loss: 0.20218496024608612, mse: 0.0, ce: 0.20218496024608612


Epoch 10/10:  43%|████▎     | 6750/15588 [13:39<17:49,  8.26it/s]

Batch 6750/15588:
total_loss: 0.06984636187553406, mse: 0.0, ce: 0.06984636187553406


Epoch 10/10:  45%|████▍     | 7000/15588 [14:10<17:21,  8.25it/s]

Batch 7000/15588:
total_loss: 0.183722585439682, mse: 0.0, ce: 0.183722585439682


Epoch 10/10:  47%|████▋     | 7250/15588 [14:40<16:52,  8.24it/s]

Batch 7250/15588:
total_loss: 0.14742617309093475, mse: 0.0, ce: 0.14742617309093475


Epoch 10/10:  48%|████▊     | 7500/15588 [15:10<16:21,  8.24it/s]

Batch 7500/15588:
total_loss: 0.13066820800304413, mse: 0.0, ce: 0.13066820800304413


Epoch 10/10:  50%|████▉     | 7750/15588 [15:41<15:50,  8.24it/s]

Batch 7750/15588:
total_loss: 0.10684754699468613, mse: 0.0, ce: 0.10684754699468613


Epoch 10/10:  51%|█████▏    | 8000/15588 [16:11<15:19,  8.25it/s]

Batch 8000/15588:
total_loss: 0.07688268274068832, mse: 0.0, ce: 0.07688268274068832


Epoch 10/10:  53%|█████▎    | 8250/15588 [16:41<14:51,  8.23it/s]

Batch 8250/15588:
total_loss: 0.07703609019517899, mse: 0.0, ce: 0.07703609019517899


Epoch 10/10:  55%|█████▍    | 8500/15588 [17:12<14:19,  8.25it/s]

Batch 8500/15588:
total_loss: 0.05933923274278641, mse: 0.0, ce: 0.05933923274278641


Epoch 10/10:  56%|█████▌    | 8750/15588 [17:42<13:50,  8.23it/s]

Batch 8750/15588:
total_loss: 0.09704504162073135, mse: 0.0, ce: 0.09704504162073135


Epoch 10/10:  58%|█████▊    | 9000/15588 [18:12<13:23,  8.20it/s]

Batch 9000/15588:
total_loss: 0.041516952216625214, mse: 0.0, ce: 0.041516952216625214


Epoch 10/10:  59%|█████▉    | 9250/15588 [18:43<12:51,  8.22it/s]

Batch 9250/15588:
total_loss: 0.07206816971302032, mse: 0.0, ce: 0.07206816971302032


Epoch 10/10:  61%|██████    | 9500/15588 [19:13<12:16,  8.26it/s]

Batch 9500/15588:
total_loss: 0.10992339998483658, mse: 0.0, ce: 0.10992339998483658


Epoch 10/10:  63%|██████▎   | 9750/15588 [19:43<11:49,  8.22it/s]

Batch 9750/15588:
total_loss: 0.07160274684429169, mse: 0.0, ce: 0.07160274684429169


Epoch 10/10:  64%|██████▍   | 10000/15588 [20:14<11:18,  8.24it/s]

Batch 10000/15588:
total_loss: 0.1950579434633255, mse: 0.0, ce: 0.1950579434633255


Epoch 10/10:  66%|██████▌   | 10250/15588 [20:44<10:46,  8.25it/s]

Batch 10250/15588:
total_loss: 0.11888530850410461, mse: 0.0, ce: 0.11888530850410461


Epoch 10/10:  67%|██████▋   | 10500/15588 [21:15<10:16,  8.25it/s]

Batch 10500/15588:
total_loss: 0.03882870450615883, mse: 0.0, ce: 0.03882870450615883


Epoch 10/10:  69%|██████▉   | 10750/15588 [21:45<09:46,  8.25it/s]

Batch 10750/15588:
total_loss: 0.09363624453544617, mse: 0.0, ce: 0.09363624453544617


Epoch 10/10:  71%|███████   | 11000/15588 [22:15<09:17,  8.23it/s]

Batch 11000/15588:
total_loss: 0.025416439399123192, mse: 0.0, ce: 0.025416439399123192


Epoch 10/10:  72%|███████▏  | 11250/15588 [22:46<08:47,  8.23it/s]

Batch 11250/15588:
total_loss: 0.041370123624801636, mse: 0.0, ce: 0.041370123624801636


Epoch 10/10:  74%|███████▍  | 11500/15588 [23:16<08:16,  8.23it/s]

Batch 11500/15588:
total_loss: 0.18972942233085632, mse: 0.0, ce: 0.18972942233085632


Epoch 10/10:  75%|███████▌  | 11750/15588 [23:46<07:44,  8.26it/s]

Batch 11750/15588:
total_loss: 0.06604891270399094, mse: 0.0, ce: 0.06604891270399094


Epoch 10/10:  77%|███████▋  | 12000/15588 [24:17<07:17,  8.21it/s]

Batch 12000/15588:
total_loss: 0.08164037764072418, mse: 0.0, ce: 0.08164037764072418


Epoch 10/10:  79%|███████▊  | 12251/15588 [24:47<06:43,  8.26it/s]

Batch 12250/15588:
total_loss: 0.06777989119291306, mse: 0.0, ce: 0.06777989119291306


Epoch 10/10:  80%|████████  | 12500/15588 [25:17<06:14,  8.24it/s]

Batch 12500/15588:
total_loss: 0.017375467345118523, mse: 0.0, ce: 0.017375467345118523


Epoch 10/10:  82%|████████▏ | 12750/15588 [25:48<05:44,  8.24it/s]

Batch 12750/15588:
total_loss: 0.24494647979736328, mse: 0.0, ce: 0.24494647979736328


Epoch 10/10:  83%|████████▎ | 13000/15588 [26:18<05:13,  8.26it/s]

Batch 13000/15588:
total_loss: 0.18115928769111633, mse: 0.0, ce: 0.18115928769111633


Epoch 10/10:  85%|████████▌ | 13250/15588 [26:48<04:43,  8.25it/s]

Batch 13250/15588:
total_loss: 0.07186031341552734, mse: 0.0, ce: 0.07186031341552734


Epoch 10/10:  87%|████████▋ | 13500/15588 [27:19<04:12,  8.25it/s]

Batch 13500/15588:
total_loss: 0.13031426072120667, mse: 0.0, ce: 0.13031426072120667


Epoch 10/10:  88%|████████▊ | 13750/15588 [27:49<03:43,  8.22it/s]

Batch 13750/15588:
total_loss: 0.10324268043041229, mse: 0.0, ce: 0.10324268043041229


Epoch 10/10:  90%|████████▉ | 14000/15588 [28:19<03:12,  8.24it/s]

Batch 14000/15588:
total_loss: 0.12516017258167267, mse: 0.0, ce: 0.12516017258167267


Epoch 10/10:  91%|█████████▏| 14250/15588 [28:50<02:42,  8.25it/s]

Batch 14250/15588:
total_loss: 0.1261335015296936, mse: 0.0, ce: 0.1261335015296936


Epoch 10/10:  93%|█████████▎| 14500/15588 [29:20<02:12,  8.23it/s]

Batch 14500/15588:
total_loss: 0.13813115656375885, mse: 0.0, ce: 0.13813115656375885


Epoch 10/10:  95%|█████████▍| 14750/15588 [29:51<01:41,  8.24it/s]

Batch 14750/15588:
total_loss: 0.07661715894937515, mse: 0.0, ce: 0.07661715894937515


Epoch 10/10:  96%|█████████▌| 15000/15588 [30:21<01:11,  8.25it/s]

Batch 15000/15588:
total_loss: 0.25447309017181396, mse: 0.0, ce: 0.25447309017181396


Epoch 10/10:  98%|█████████▊| 15250/15588 [30:51<00:41,  8.24it/s]

Batch 15250/15588:
total_loss: 0.22570624947547913, mse: 0.0, ce: 0.22570624947547913


Epoch 10/10:  99%|█████████▉| 15500/15588 [31:22<00:10,  8.24it/s]

Batch 15500/15588:
total_loss: 0.03681360185146332, mse: 0.0, ce: 0.03681360185146332


Epoch 10/10: 100%|██████████| 15588/15588 [31:32<00:00,  8.24it/s]
Validating: 100%|██████████| 31/31 [00:01<00:00, 23.72it/s]



Epoch Summary:
Train Total Loss: 0.0993 (MSE: 0.0000, CE: 0.0993)
Val Total Loss: 0.0744 (MSE: 0.0000, CE: 0.0744)
Learning Rate: 0.000100
New best model saved with val loss: 0.0744
--------------------------------------------------
