In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/config.txt
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_5_4_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_4_3_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_2_1_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_1_0_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_3_2_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Test/0_5_4_17062021_094123.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Te

In [2]:
import torch
import json
from torch.utils.data import Dataset
import re
import numpy as np
import random
from scipy.optimize import minimize
import math
import matplotlib.pyplot as plt


def generateDataStrEq(
    eq, n_points=2, n_vars=3, decimals=4, supportPoints=None, min_x=0, max_x=3
):
    X = []
    Y = []
    # TODO: Need to make this faster
    for p in range(n_points):
        if supportPoints is None:
            if type(min_x) == list:
                x = []
                for _ in range(n_vars):
                    idx = np.random.randint(len(min_x))
                    x += list(
                        np.round(np.random.uniform(min_x[idx], max_x[idx], 1), decimals)
                    )
            else:
                x = list(np.round(np.random.uniform(min_x, max_x, n_vars), decimals))
            assert (
                len(x) != 0
            ), "For some reason, we didn't generate the points correctly!"
        else:
            x = supportPoints[p]

        tmpEq = eq + ""
        for nVID in range(n_vars):
            tmpEq = tmpEq.replace("x{}".format(nVID + 1), str(x[nVID]))
        y = float(np.round(eval(tmpEq), decimals))
        X.append(x)
        Y.append(y)
    return X, Y


# def processDataFiles(files):
#     text = ""
#     for f in tqdm(files):
#         with open(f, 'r') as h:
#             lines = h.read() # don't worry we won't run out of file handles
#             if lines[-1]==-1:
#                 lines = lines[:-1]
#             #text += lines #json.loads(line)
#             text = ''.join([lines,text])
#     return text


def processDataFiles(files):
    text = ""
    for f in files:
        with open(f, "r") as h:
            lines = h.read()  # don't worry we won't run out of file handles
            if lines[-1] == -1:
                lines = lines[:-1]
            # text += lines #json.loads(line)
            text = "".join([lines, text])
    return text


def tokenize_equation(eq):
    token_spec = [
        (r"\*\*"),  # exponentiation
        (r"exp"),  # exp function
        (r"[+\-*/=()]"),  # operators and parentheses
        (r"sin"),  # sin function
        (r"cos"),  # cos function
        (r"log"),  # log function
        (r"x\d+"),  # variables like x1, x23, etc.
        (r"C"),  # constants placeholder
        (r"-?\d+\.\d+"),  # decimal numbers
        (r"-?\d+"),  # integers
        (r"_"),  # padding token
    ]
    token_regex = "|".join(f"({pattern})" for pattern in token_spec)
    matches = re.finditer(token_regex, eq)
    return [match.group(0) for match in matches]


class CharDataset(Dataset):
    def __init__(
        self,
        data,
        block_size,
        tokens,
        numVars,
        numYs,
        numPoints,
        target="Skeleton",
        addVars=False,
        const_range=[-0.4, 0.4],
        xRange=[-3.0, 3.0],
        decimals=4,
        augment=False,
    ):

        data_size, vocab_size = len(data), len(tokens)
        print("data has %d examples, %d unique." % (data_size, vocab_size))

        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.itos = {i: tok for i, tok in enumerate(tokens)}

        self.numVars = numVars
        self.numYs = numYs
        self.numPoints = numPoints

        # padding token
        self.paddingToken = "_"
        self.paddingID = self.stoi["_"]  # or another ID not already used
        self.stoi[self.paddingToken] = self.paddingID
        self.itos[self.paddingID] = self.paddingToken

        self.threshold = [-1000, 1000]

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data  # it should be a list of examples
        self.target = target
        self.addVars = addVars

        self.const_range = const_range
        self.xRange = xRange
        self.decimals = decimals
        self.augment = augment

    def __len__(self):
        return len(self.data) - 1

    def __getitem__(self, idx):
        # grab an example from the data
        chunk = self.data[idx]  # sequence of tokens including x, y, eq, etc.

        try:
            chunk = json.loads(chunk)  # convert the sequence tokens to a dictionary
        except Exception as e:
            print("Couldn't convert to json: {} \n error is: {}".format(chunk, e))
            # try the previous example
            idx = idx - 1
            idx = idx if idx >= 0 else 0
            chunk = self.data[idx]
            chunk = json.loads(chunk)  # convert the sequence tokens to a dictionary

        # find the number of variables in the equation
        printInfoCondition = random.random() < 0.0000001
        eq = chunk[self.target]
        if printInfoCondition:
            print(f"\nEquation: {eq}")
        vars = re.finditer("x[\d]+", eq)
        numVars = 0
        for v in vars:
            v = v.group(0).strip("x")
            v = eval(v)
            v = int(v)
            if v > numVars:
                numVars = v

        if self.target == "Skeleton" and self.augment:
            threshold = 5000
            # randomly generate the constants
            cleanEqn = ""
            for chr in eq:
                if chr == "C":
                    # genereate a new random number
                    chr = "{}".format(
                        np.random.uniform(self.const_range[0], self.const_range[1])
                    )
                cleanEqn += chr

            # update the points
            nPoints = np.random.randint(
                *self.numPoints
            )  # if supportPoints is None else len(supportPoints)
            try:
                if printInfoCondition:
                    print("Org:", chunk["X"], chunk["Y"])

                X, y = generateDataStrEq(
                    cleanEqn,
                    n_points=nPoints,
                    n_vars=self.numVars,
                    decimals=self.decimals,
                    min_x=self.xRange[0],
                    max_x=self.xRange[1],
                )

                # replace out of threshold with maximum numbers
                y = [e if abs(e) < threshold else np.sign(e) * threshold for e in y]

                # check if there is nan/inf/very large numbers in the y
                conditions = (
                    (np.isnan(y).any() or np.isinf(y).any())
                    or len(y) == 0
                    or (abs(min(y)) > threshold or abs(max(y)) > threshold)
                )
                if not conditions:
                    chunk["X"], chunk["Y"] = X, y

                if printInfoCondition:
                    print("Evd:", chunk["X"], chunk["Y"])
            except Exception as e:
                # for different reason this might happend including but not limited to division by zero
                print(
                    "".join(
                        [
                            f"We just used the original equation and support points because of {e}. ",
                            f"The equation is {eq}, and we update the equation to {cleanEqn}",
                        ]
                    )
                )

        # encode every character in the equation to an integer
        # < is SOS, > is EOS
        if self.addVars:
            dix = [self.stoi[s] for s in "<" + str(numVars) + ":" + eq + ">"]
        else:
            eq_tokens = tokenize_equation(eq)
            if self.addVars:
                token_seq = ["<", str(numVars), ":", *eq_tokens, ">"]
            else:
                token_seq = ["<", *eq_tokens, ">"]
            dix = [self.stoi[tok] for tok in token_seq]

        inputs = dix[:-1]
        outputs = dix[1:]

        # add the padding to the equations
        paddingSize = max(self.block_size - len(inputs), 0)
        paddingList = [self.paddingID] * paddingSize
        inputs += paddingList
        outputs += paddingList

        # make sure it is not more than what should be
        inputs = inputs[: self.block_size]
        outputs = outputs[: self.block_size]

        points = torch.zeros(self.numVars + self.numYs, self.numPoints - 1)
        for idx, xy in enumerate(zip(chunk["X"], chunk["Y"])):

            if not isinstance(xy[0], list) or not isinstance(
                xy[1], (list, float, np.float64)
            ):
                print(f"Unexpected types: {type(xy[0])}, {type(xy[1])}")
                continue  # Skip if types are incorrect

            # don't let to exceed the maximum number of points
            if idx >= self.numPoints - 1:
                break

            x = xy[0]
            x = x + [0] * (max(self.numVars - len(x), 0))  # padding

            y = [xy[1]] if type(xy[1]) == float or type(xy[1]) == np.float64 else xy[1]

            y = y + [0] * (max(self.numYs - len(y), 0))  # padding
            p = x + y  # because it is only one point
            p = torch.tensor(p)
            # replace nan and inf
            p = torch.nan_to_num(
                p,
                nan=self.threshold[1],
                posinf=self.threshold[1],
                neginf=self.threshold[0],
            )

            points[:, idx] = p

        points = torch.nan_to_num(
            points,
            nan=self.threshold[1],
            posinf=self.threshold[1],
            neginf=self.threshold[0],
        )

        inputs = torch.tensor(inputs, dtype=torch.long)
        outputs = torch.tensor(outputs, dtype=torch.long)
        numVars = torch.tensor(numVars, dtype=torch.long)
        return inputs, outputs, points, numVars


# Relative Mean Square Error
def relativeErr(y, yHat, info=False, eps=1e-5):
    yHat = np.reshape(yHat, [1, -1])[0]
    y = np.reshape(y, [1, -1])[0]
    if len(y) > 0 and len(y) == len(yHat):
        err = ((yHat - y)) ** 2 / np.linalg.norm(y + eps)
        if info:
            for _ in range(5):
                i = np.random.randint(len(y))
                # print("yPR,yTrue:{},{}, Err:{}".format(yHat[i], y[i], err[i]))
    else:
        err = 100

    return np.mean(err)


def lossFunc(constants, eq, X, Y, eps=1e-5):
    err = 0
    eq = eq.replace("C", "{}").format(*constants)

    for x, y in zip(X, Y):
        eqTemp = eq + ""
        if type(x) == np.float32:
            x = [x]
        for i, e in enumerate(x):
            # make sure e is not a tensor
            if type(e) == torch.Tensor:
                e = e.item()
            eqTemp = eqTemp.replace("x{}".format(i + 1), str(e))
        try:
            yHat = eval(eqTemp)
        except:
            # print("Exception has been occured! EQ: {}, OR: {}".format(eqTemp, eq))
            yHat = 100
        try:
            # handle overflow
            err += relativeErr(y, yHat)  # (y-yHat)**2
        except:
            # print(
            #    "Exception has been occured! EQ: {}, OR: {}, y:{}-yHat:{}".format(
            #        eqTemp, eq, y, yHat
            #    )
            # )
            err += 10

    err /= len(Y)
    return err


def get_skeleton(tokens, train_dataset: CharDataset):
    skeleton = "".join([train_dataset.itos[int(idx)] for idx in tokens])
    skeleton = skeleton.strip(train_dataset.paddingToken).split(">")
    skeleton = skeleton[0] if len(skeleton[0]) >= 1 else skeleton[1]
    skeleton = skeleton.strip("<").strip(">")
    skeleton = skeleton.replace("Ce", "C*e")

    return skeleton


In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math


# from SymbolicGPT: https://github.com/mojivalipour/symbolicgpt/blob/master/models.py
class PointNetConfig:
    """base PointNet config"""

    def __init__(
        self,
        embeddingSize,
        numberofPoints,
        numberofVars,
        numberofYs,
        **kwargs,
    ):
        self.embeddingSize = embeddingSize
        self.numberofPoints = numberofPoints  # number of points
        self.numberofVars = numberofVars  # input dimension (Xs)
        self.numberofYs = numberofYs  # output dimension (Ys)

        for k, v in kwargs.items():
            setattr(self, k, v)


class tNet(nn.Module):
    """
    The PointNet structure in the orginal PointNet paper:
    PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation by Qi et. al. 2017
    """

    def __init__(self, config):
        super(tNet, self).__init__()

        self.activation_func = F.relu
        self.num_units = config.embeddingSize

        self.conv1 = nn.Conv1d(
            config.numberofVars + config.numberofYs, self.num_units, 1
        )
        self.conv2 = nn.Conv1d(self.num_units, 2 * self.num_units, 1)
        self.conv3 = nn.Conv1d(2 * self.num_units, 4 * self.num_units, 1)
        self.fc1 = nn.Linear(4 * self.num_units, 2 * self.num_units)
        self.fc2 = nn.Linear(2 * self.num_units, self.num_units)

        # self.relu = nn.ReLU()

        self.input_batch_norm = nn.BatchNorm1d(config.numberofVars + config.numberofYs)
        # self.input_layer_norm = nn.LayerNorm(config.numberofPoints)

        self.bn1 = nn.BatchNorm1d(self.num_units)
        self.bn2 = nn.BatchNorm1d(2 * self.num_units)
        self.bn3 = nn.BatchNorm1d(4 * self.num_units)
        self.bn4 = nn.BatchNorm1d(2 * self.num_units)
        self.bn5 = nn.BatchNorm1d(self.num_units)

    def forward(self, x):
        """
        :param x: [batch, #features, #points]
        :return:
            logit: [batch, embedding_size]
        """
        x = self.input_batch_norm(x)
        x = self.activation_func(self.bn1(self.conv1(x)))
        x = self.activation_func(self.bn2(self.conv2(x)))
        x = self.activation_func(self.bn3(self.conv3(x)))
        x, _ = torch.max(x, dim=2)  # global max pooling
        assert x.size(1) == 4 * self.num_units

        x = self.activation_func(self.bn4(self.fc1(x)))
        x = self.activation_func(self.bn5(self.fc2(x)))
        # x = self.fc2(x)

        return x


# from https://github.com/juho-lee/set_transformer/blob/master/modules.py
class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, "ln0", None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, "ln1", None) is None else self.ln1(O)
        return O


class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)


class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)


# from https://github.com/juho-lee/set_transformer/blob/master/models.py
class SetTransformer(nn.Module):
    def __init__(
        self,
        dim_input,
        num_outputs,
        dim_output,
        num_inds=32,
        dim_hidden=128,
        num_heads=4,
        ln=False,
    ):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
            ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, X):
        return self.dec(self.enc(X)).squeeze(1)


class NoisePredictionTransformer(nn.Module):
    def __init__(self, n_embd, max_seq_len, n_layer=6, n_head=8, max_timesteps=1000):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, n_embd))
        self.time_emb = nn.Embedding(max_timesteps, n_embd)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * 4,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)

    def forward(self, x_t, t, condition):
        _, L, _ = x_t.shape
        pos_emb = self.pos_emb[:, :L, :]  # [1, L, n_embd]
        time_emb = self.time_emb(t)
        if time_emb.dim() == 1:  # Scalar t case, [n_embd]
            time_emb = time_emb.unsqueeze(0)  # [1, n_embd]
        time_emb = time_emb.unsqueeze(1)  # [1, 1, n_embd]
        condition = condition.unsqueeze(1)  # [B, 1, n_embd]

        x = x_t + pos_emb + time_emb + condition
        return self.encoder(x)


# influenced by https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
class SymbolicGaussianDiffusion(nn.Module):
    def __init__(
        self,
        tnet_config: PointNetConfig,
        vocab_size,
        max_seq_len,
        padding_idx: int = 0,
        max_num_vars: int = 9,
        n_layer=6,
        n_head=8,
        n_embd=512,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        set_transformer=True,
        ce_weight=1.0,  # Weight for CE loss relative to MSE
        p_uncond: float = 0.1,  # Probability of unconditioned sampling
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.n_embd = n_embd
        self.timesteps = timesteps
        self.set_transformer = set_transformer
        self.ce_weight = ce_weight
        self.p_uncond = p_uncond

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tok_emb = nn.Embedding(vocab_size, n_embd, padding_idx=self.padding_idx)
        self.vars_emb = nn.Embedding(max_num_vars, n_embd)

        self.decoder = nn.Linear(n_embd, vocab_size, bias=False)
        self.decoder.weight = self.tok_emb.weight

        if set_transformer:
            dim_input = tnet_config.numberofVars + tnet_config.numberofYs
            self.tnet = SetTransformer(
                dim_input=dim_input,
                num_outputs=1,
                dim_output=tnet_config.embeddingSize,
                num_inds=tnet_config.numberofPoints,
                num_heads=4,
            )
        else:
            self.tnet = tNet(tnet_config)

        self.model = NoisePredictionTransformer(
            n_embd, max_seq_len, n_layer, n_head, timesteps
        )

        # Noise schedule
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, timesteps))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    def q_sample(self, x_start, t, noise=None):
        noise = torch.randn_like(x_start)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1)

        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise
        return x_t

    def p_mean_variance(self, x, t, t_next, condition):
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        alpha_bar_t_next = self.alpha_bar[t_next]
        beta_t = self.beta[t]

        x_start_pred = self.model(x, t.long(), condition)

        coeff1 = torch.sqrt(alpha_bar_t_next) * beta_t / (1 - alpha_bar_t)
        coeff2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_next) / (1 - alpha_bar_t)
        mean = coeff1 * x_start_pred + coeff2 * x
        variance = (1 - alpha_bar_t_next) / (1 - alpha_bar_t) * beta_t
        return mean, variance

    @torch.no_grad()
    def p_sample(self, x, t, t_next, condition):
        mean, variance = self.p_mean_variance(x, t, t_next, condition)
        if torch.all(t_next == 0):
            return mean
        noise = torch.randn_like(x)
        return mean + torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(
        self,
        points,
        variables,
        train_dataset,
        batch_size=16,
        ddim_step=1,
        guidance_scale=1.0,
    ):
        if self.set_transformer:
            points = points.transpose(1, 2)

        B = batch_size

        condition = self.tnet(points) + self.vars_emb(variables)
        uncondition = torch.zeros_like(condition)
        condition_uncodition = torch.cat(
            [condition, uncondition], dim=0
        )  # [2B, 1, n_embd]

        shape = (B, self.max_seq_len, self.n_embd)
        x = torch.randn(shape, device=self.device)

        steps = torch.arange(self.timesteps - 1, -1, -1, device=self.device)

        for i in range(0, self.timesteps, ddim_step):
            t = steps[i]

            t_next = (
                steps[i + ddim_step]
                if i + ddim_step < self.timesteps
                else torch.tensor(0, device=self.device)
            )
            x = x.repeat(2, 1, 1)
            x = self.p_sample(x, t, t_next, condition_uncodition)

            x_condition = x[:B]  # [B, L, n_embd]
            x_uncondition = x[B:]
            x = x_uncondition + guidance_scale * (x_condition - x_uncondition)

        logits = self.decoder(x)  # [B, L, vocab_size]
        token_indices = torch.argmax(logits, dim=-1)  # [B, L]
        predicted_skeletons = []
        for j in range(batch_size):
            token_indices_j = token_indices[j]  # [L]
            predicted_skeleton = get_skeleton(token_indices_j, train_dataset)
            predicted_skeletons.append(predicted_skeleton)
        return predicted_skeletons

    def p_losses(
        self,
        x_start,
        points,
        tokens,
        variables,
        t,
    ):
        """Hybrid loss: MSE on embeddings + CE on tokens."""
        noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)

        if self.set_transformer:
            points = points.transpose(1, 2)

        condition = self.tnet(points) + self.vars_emb(variables)

        # classifier free guidance
        mask = torch.rand(x_start.shape[0], device=self.device) < self.p_uncond

        condition = torch.where(mask.unsqueeze(1), condition, 0)
        x_start_pred = self.model(x_t, t.long(), condition)

        # CE loss on tokens
        logits = self.decoder(x_start_pred)  # [B, L, vocab_size]

        ce_loss = F.cross_entropy(
            logits.view(-1, self.vocab_size),  # [B*L, vocab_size]
            tokens.view(-1),  # [B*L]
            ignore_index=self.padding_idx,
            reduction="mean",
        )

        ce_loss = ce_loss * self.ce_weight

        return ce_loss

    def forward(self, points, tokens, variables, t):
        token_emb = self.tok_emb(tokens)
        ce_loss = self.p_losses(token_emb, points, tokens, variables, t)
        return ce_loss


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple

def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    model.train()
    total_train_loss = 0

    for i, (_, tokens, points, variables) in tqdm.tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc=f"Epoch {epoch+1}/{num_epochs}",
    ):
        points, tokens, variables = (
            points.to(device),
            tokens.to(device),
            variables.to(device),
        )
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
        optimizer.zero_grad()

        total_loss = model(points, tokens, variables, t)

        if (i + 1) % 250 == 0:
            print(f"Batch {i + 1}/{len(train_loader)}:")
            print(f"total_loss: {total_loss}")

        total_loss.backward()
        optimizer.step()

        total_train_loss += total_loss.item()


    avg_train_loss = total_train_loss / len(train_loader)
    return avg_train_loss


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for _, tokens, points, variables in tqdm.tqdm(
            val_loader, total=len(val_loader), desc="Validating"
        ):
            points, tokens, variables = (
                points.to(device),
                tokens.to(device),
                variables.to(device),
            )
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
            total_loss = model(points, tokens, variables, t)

            total_val_loss += total_loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    return avg_val_loss


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
    path=None,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=4,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=4,
    )

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )

        avg_val_loss = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(
            f"Train Total Loss: {avg_train_loss:.4f}"
        )
        print(
            f"Val Total Loss: {avg_val_loss:.4f}"
        )
        print(f"Learning Rate: {current_lr:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            state_dict = model.state_dict()
            torch.save(state_dict, path)
            print(f"New best model saved with val loss: {best_val_loss:.4f}")

        print("-" * 50)

In [5]:
n_embd = 512
timesteps = 1000
batch_size = 64
learning_rate = 1e-4
num_epochs = 5
blockSize = 32
numVars = 5
numYs = 1
numPoints = 250
target = 'Skeleton'
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]
decimals = 8
addVars = False
maxNumFiles = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
dataDir = "/kaggle/input/1-5-var-dataset"
dataFolder = "1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points"

In [7]:
import numpy as np
import glob
import random

path = '{}/{}/Train/*.json'.format(dataDir, dataFolder)
files = glob.glob(path)[:maxNumFiles]
text = processDataFiles(files)
text = text.split('\n') # convert the raw text to a set of examples
# skeletons = []
skeletons = [json.loads(item)['Skeleton'] for item in text if item.strip()]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # add all integers to the token set
tokens = sorted(list(all_tokens) + ['_', 'T', '<', '>', ':'])  # special tokens
trainText = text[:-1] if len(text[-1]) == 0 else text
random.shuffle(trainText) # shuffle the dataset, it's important specailly for the combined number of variables experiment
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

idx = np.random.randint(train_dataset.__len__())
inputs, outputs, points, variables = train_dataset.__getitem__(idx)
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 2262058 examples, 44 unique.
id:578695
outputs:C*exp(C*x4**5)+C>___________________
variables:4


In [8]:
path = '{}/{}/Val/*.json'.format(dataDir,dataFolder)
files = glob.glob(path)
textVal = processDataFiles([files[0]])
textVal = textVal.split('\n') # convert the raw text to a set of examples
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# print a random sample
idx = np.random.randint(val_dataset.__len__())
inputs, outputs, points, variables = val_dataset.__getitem__(idx)
print(points.min(), points.max())
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 195 examples, 44 unique.
tensor(-2.9966) tensor(2.9932)
id:149
outputs:C*cos(C*x3*x4+C*x3+C*x4+C)+C>_________
variables:4


In [9]:
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

model = SymbolicGaussianDiffusion(
    tnet_config=pconfig,  
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02
)

train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
    path="1_5_var_set_transformer_cfg"
)

Epoch 1/5:   1%|          | 251/35345 [00:26<59:59,  9.75it/s]  

Batch 250/35345:
total_loss: 1.8027149438858032


Epoch 1/5:   1%|▏         | 500/35345 [00:51<56:59, 10.19it/s]

Batch 500/35345:
total_loss: 1.4999831914901733


Epoch 1/5:   2%|▏         | 750/35345 [01:16<57:08, 10.09it/s]

Batch 750/35345:
total_loss: 1.211735486984253


Epoch 1/5:   3%|▎         | 1001/35345 [01:41<58:02,  9.86it/s]

Batch 1000/35345:
total_loss: 1.0646628141403198


Epoch 1/5:   4%|▎         | 1251/35345 [02:07<59:14,  9.59it/s]

Batch 1250/35345:
total_loss: 0.9269109964370728


Epoch 1/5:   4%|▍         | 1501/35345 [02:33<1:00:55,  9.26it/s]

Batch 1500/35345:
total_loss: 0.8156073093414307


Epoch 1/5:   5%|▍         | 1751/35345 [02:59<57:14,  9.78it/s]

Batch 1750/35345:
total_loss: 0.8554414510726929


Epoch 1/5:   6%|▌         | 2001/35345 [03:25<58:07,  9.56it/s]

Batch 2000/35345:
total_loss: 1.0244609117507935


Epoch 1/5:   6%|▋         | 2251/35345 [03:51<56:56,  9.69it/s]

Batch 2250/35345:
total_loss: 0.7088034749031067


Epoch 1/5:   7%|▋         | 2501/35345 [04:17<56:10,  9.74it/s]

Batch 2500/35345:
total_loss: 0.9111936688423157


Epoch 1/5:   8%|▊         | 2751/35345 [04:43<58:13,  9.33it/s]

Batch 2750/35345:
total_loss: 0.5746651887893677


Epoch 1/5:   8%|▊         | 3001/35345 [05:09<56:11,  9.59it/s]

Batch 3000/35345:
total_loss: 0.8704826831817627


Epoch 1/5:   9%|▉         | 3251/35345 [05:35<56:01,  9.55it/s]

Batch 3250/35345:
total_loss: 0.6062706112861633


Epoch 1/5:  10%|▉         | 3501/35345 [06:02<55:34,  9.55it/s]

Batch 3500/35345:
total_loss: 0.7253472208976746


Epoch 1/5:  11%|█         | 3751/35345 [06:28<54:50,  9.60it/s]

Batch 3750/35345:
total_loss: 0.6032565832138062


Epoch 1/5:  11%|█▏        | 4001/35345 [06:54<55:04,  9.49it/s]

Batch 4000/35345:
total_loss: 0.4637488126754761


Epoch 1/5:  12%|█▏        | 4251/35345 [07:20<53:55,  9.61it/s]

Batch 4250/35345:
total_loss: 0.6634306907653809


Epoch 1/5:  13%|█▎        | 4501/35345 [07:46<53:24,  9.62it/s]

Batch 4500/35345:
total_loss: 0.6022717356681824


Epoch 1/5:  13%|█▎        | 4751/35345 [08:12<53:26,  9.54it/s]

Batch 4750/35345:
total_loss: 0.7943505048751831


Epoch 1/5:  14%|█▍        | 5001/35345 [08:39<52:27,  9.64it/s]

Batch 5000/35345:
total_loss: 0.7897017002105713


Epoch 1/5:  15%|█▍        | 5251/35345 [09:05<53:14,  9.42it/s]

Batch 5250/35345:
total_loss: 0.6043593883514404


Epoch 1/5:  16%|█▌        | 5501/35345 [09:31<53:14,  9.34it/s]

Batch 5500/35345:
total_loss: 0.6636842489242554


Epoch 1/5:  16%|█▋        | 5751/35345 [09:57<51:38,  9.55it/s]

Batch 5750/35345:
total_loss: 0.7422832250595093


Epoch 1/5:  17%|█▋        | 6001/35345 [10:23<51:04,  9.58it/s]

Batch 6000/35345:
total_loss: 0.5581920742988586


Epoch 1/5:  18%|█▊        | 6251/35345 [10:49<50:24,  9.62it/s]

Batch 6250/35345:
total_loss: 0.44054874777793884


Epoch 1/5:  18%|█▊        | 6501/35345 [11:16<50:06,  9.59it/s]

Batch 6500/35345:
total_loss: 0.762383759021759


Epoch 1/5:  19%|█▉        | 6751/35345 [11:42<51:50,  9.19it/s]

Batch 6750/35345:
total_loss: 0.5923398733139038


Epoch 1/5:  20%|█▉        | 7001/35345 [12:08<49:02,  9.63it/s]

Batch 7000/35345:
total_loss: 0.46048498153686523


Epoch 1/5:  21%|██        | 7251/35345 [12:34<48:53,  9.58it/s]

Batch 7250/35345:
total_loss: 0.584358811378479


Epoch 1/5:  21%|██        | 7501/35345 [13:00<49:09,  9.44it/s]

Batch 7500/35345:
total_loss: 0.8989440202713013


Epoch 1/5:  22%|██▏       | 7751/35345 [13:27<48:06,  9.56it/s]

Batch 7750/35345:
total_loss: 0.26936084032058716


Epoch 1/5:  23%|██▎       | 8001/35345 [13:53<48:34,  9.38it/s]

Batch 8000/35345:
total_loss: 0.5681480169296265


Epoch 1/5:  23%|██▎       | 8251/35345 [14:19<46:28,  9.72it/s]

Batch 8250/35345:
total_loss: 0.5702928304672241


Epoch 1/5:  24%|██▍       | 8501/35345 [14:45<46:23,  9.64it/s]

Batch 8500/35345:
total_loss: 0.7391029000282288


Epoch 1/5:  25%|██▍       | 8751/35345 [15:11<46:15,  9.58it/s]

Batch 8750/35345:
total_loss: 0.4092826247215271


Epoch 1/5:  25%|██▌       | 9001/35345 [15:37<45:55,  9.56it/s]

Batch 9000/35345:
total_loss: 0.49424275755882263


Epoch 1/5:  26%|██▌       | 9251/35345 [16:04<46:39,  9.32it/s]

Batch 9250/35345:
total_loss: 0.4954473674297333


Epoch 1/5:  27%|██▋       | 9501/35345 [16:30<44:41,  9.64it/s]

Batch 9500/35345:
total_loss: 0.6372300386428833


Epoch 1/5:  28%|██▊       | 9751/35345 [16:56<44:28,  9.59it/s]

Batch 9750/35345:
total_loss: 0.470116525888443


Epoch 1/5:  28%|██▊       | 10001/35345 [17:22<44:01,  9.59it/s]

Batch 10000/35345:
total_loss: 0.6329379081726074


Epoch 1/5:  29%|██▉       | 10251/35345 [17:48<43:28,  9.62it/s]

Batch 10250/35345:
total_loss: 0.6197075247764587


Epoch 1/5:  30%|██▉       | 10501/35345 [18:15<43:19,  9.56it/s]

Batch 10500/35345:
total_loss: 0.6267568469047546


Epoch 1/5:  30%|███       | 10751/35345 [18:41<42:40,  9.61it/s]

Batch 10750/35345:
total_loss: 0.47721540927886963


Epoch 1/5:  31%|███       | 11001/35345 [19:07<42:23,  9.57it/s]

Batch 11000/35345:
total_loss: 0.741616427898407


Epoch 1/5:  32%|███▏      | 11251/35345 [19:33<41:56,  9.57it/s]

Batch 11250/35345:
total_loss: 0.26287803053855896


Epoch 1/5:  33%|███▎      | 11501/35345 [19:59<41:13,  9.64it/s]

Batch 11500/35345:
total_loss: 0.7664706707000732


Epoch 1/5:  33%|███▎      | 11751/35345 [20:26<41:09,  9.55it/s]

Batch 11750/35345:
total_loss: 0.5732376575469971


Epoch 1/5:  34%|███▍      | 12000/35345 [20:52<50:19,  7.73it/s]

Batch 12000/35345:
total_loss: 0.5179423689842224


Epoch 1/5:  35%|███▍      | 12251/35345 [21:18<40:10,  9.58it/s]

Batch 12250/35345:
total_loss: 0.4379156827926636


Epoch 1/5:  35%|███▌      | 12501/35345 [21:44<39:49,  9.56it/s]

Batch 12500/35345:
total_loss: 0.45414578914642334


Epoch 1/5:  36%|███▌      | 12751/35345 [22:11<39:25,  9.55it/s]

Batch 12750/35345:
total_loss: 0.6107162237167358


Epoch 1/5:  37%|███▋      | 13001/35345 [22:37<38:38,  9.64it/s]

Batch 13000/35345:
total_loss: 0.5961046814918518


Epoch 1/5:  37%|███▋      | 13251/35345 [23:04<39:12,  9.39it/s]

Batch 13250/35345:
total_loss: 0.5674253106117249


Epoch 1/5:  38%|███▊      | 13501/35345 [23:30<37:49,  9.63it/s]

Batch 13500/35345:
total_loss: 0.6587365865707397


Epoch 1/5:  39%|███▉      | 13751/35345 [23:56<37:39,  9.56it/s]

Batch 13750/35345:
total_loss: 0.5916584134101868


Epoch 1/5:  40%|███▉      | 14001/35345 [24:23<37:08,  9.58it/s]

Batch 14000/35345:
total_loss: 0.46168744564056396


Epoch 1/5:  40%|████      | 14251/35345 [24:49<36:35,  9.61it/s]

Batch 14250/35345:
total_loss: 0.5737860798835754


Epoch 1/5:  41%|████      | 14501/35345 [25:15<36:25,  9.54it/s]

Batch 14500/35345:
total_loss: 0.4443151354789734


Epoch 1/5:  42%|████▏     | 14751/35345 [25:41<36:28,  9.41it/s]

Batch 14750/35345:
total_loss: 0.3702394664287567


Epoch 1/5:  42%|████▏     | 15001/35345 [26:08<35:22,  9.58it/s]

Batch 15000/35345:
total_loss: 0.5254489779472351


Epoch 1/5:  43%|████▎     | 15251/35345 [26:34<34:50,  9.61it/s]

Batch 15250/35345:
total_loss: 0.4904455542564392


Epoch 1/5:  44%|████▍     | 15501/35345 [27:01<34:33,  9.57it/s]

Batch 15500/35345:
total_loss: 0.47921890020370483


Epoch 1/5:  45%|████▍     | 15751/35345 [27:27<33:48,  9.66it/s]

Batch 15750/35345:
total_loss: 0.4711994230747223


Epoch 1/5:  45%|████▌     | 16001/35345 [27:53<34:30,  9.34it/s]

Batch 16000/35345:
total_loss: 0.46863338351249695


Epoch 1/5:  46%|████▌     | 16251/35345 [28:19<33:07,  9.61it/s]

Batch 16250/35345:
total_loss: 0.5831630825996399


Epoch 1/5:  47%|████▋     | 16501/35345 [28:45<32:32,  9.65it/s]

Batch 16500/35345:
total_loss: 0.6108360886573792


Epoch 1/5:  47%|████▋     | 16751/35345 [29:11<32:19,  9.59it/s]

Batch 16750/35345:
total_loss: 0.5285031795501709


Epoch 1/5:  48%|████▊     | 17001/35345 [29:38<31:59,  9.56it/s]

Batch 17000/35345:
total_loss: 0.4236215353012085


Epoch 1/5:  49%|████▉     | 17251/35345 [30:04<31:40,  9.52it/s]

Batch 17250/35345:
total_loss: 0.682169497013092


Epoch 1/5:  50%|████▉     | 17501/35345 [30:30<31:09,  9.55it/s]

Batch 17500/35345:
total_loss: 0.4105146527290344


Epoch 1/5:  50%|█████     | 17751/35345 [30:56<30:37,  9.58it/s]

Batch 17750/35345:
total_loss: 0.6415276527404785


Epoch 1/5:  51%|█████     | 18001/35345 [31:23<29:49,  9.69it/s]

Batch 18000/35345:
total_loss: 0.4373925030231476


Epoch 1/5:  52%|█████▏    | 18251/35345 [31:49<29:33,  9.64it/s]

Batch 18250/35345:
total_loss: 0.4690210521221161


Epoch 1/5:  52%|█████▏    | 18501/35345 [32:15<29:08,  9.63it/s]

Batch 18500/35345:
total_loss: 0.6852198243141174


Epoch 1/5:  53%|█████▎    | 18751/35345 [32:41<28:27,  9.72it/s]

Batch 18750/35345:
total_loss: 0.4431312680244446


Epoch 1/5:  54%|█████▍    | 19001/35345 [33:07<28:13,  9.65it/s]

Batch 19000/35345:
total_loss: 0.5066320896148682


Epoch 1/5:  54%|█████▍    | 19251/35345 [33:33<27:53,  9.62it/s]

Batch 19250/35345:
total_loss: 0.6598734855651855


Epoch 1/5:  55%|█████▌    | 19501/35345 [34:00<27:42,  9.53it/s]

Batch 19500/35345:
total_loss: 0.4825492203235626


Epoch 1/5:  56%|█████▌    | 19751/35345 [34:26<27:02,  9.61it/s]

Batch 19750/35345:
total_loss: 0.4951431453227997


Epoch 1/5:  57%|█████▋    | 20001/35345 [34:52<28:04,  9.11it/s]

Batch 20000/35345:
total_loss: 0.4956195056438446


Epoch 1/5:  57%|█████▋    | 20251/35345 [35:18<26:41,  9.43it/s]

Batch 20250/35345:
total_loss: 0.5155373215675354


Epoch 1/5:  58%|█████▊    | 20501/35345 [35:45<26:11,  9.44it/s]

Batch 20500/35345:
total_loss: 0.36823806166648865


Epoch 1/5:  59%|█████▊    | 20751/35345 [36:11<25:22,  9.59it/s]

Batch 20750/35345:
total_loss: 0.6087230443954468


Epoch 1/5:  59%|█████▉    | 21001/35345 [36:37<25:01,  9.55it/s]

Batch 21000/35345:
total_loss: 0.40752410888671875


Epoch 1/5:  60%|██████    | 21251/35345 [37:04<25:26,  9.23it/s]

Batch 21250/35345:
total_loss: 0.5038299560546875


Epoch 1/5:  61%|██████    | 21501/35345 [37:30<24:28,  9.42it/s]

Batch 21500/35345:
total_loss: 0.40327373147010803


Epoch 1/5:  62%|██████▏   | 21751/35345 [37:56<23:50,  9.50it/s]

Batch 21750/35345:
total_loss: 0.23550495505332947


Epoch 1/5:  62%|██████▏   | 22001/35345 [38:22<23:02,  9.65it/s]

Batch 22000/35345:
total_loss: 0.32055023312568665


Epoch 1/5:  63%|██████▎   | 22251/35345 [38:48<22:49,  9.56it/s]

Batch 22250/35345:
total_loss: 0.41123074293136597


Epoch 1/5:  64%|██████▎   | 22501/35345 [39:15<22:55,  9.34it/s]

Batch 22500/35345:
total_loss: 0.46040594577789307


Epoch 1/5:  64%|██████▍   | 22751/35345 [39:41<22:28,  9.34it/s]

Batch 22750/35345:
total_loss: 0.22574681043624878


Epoch 1/5:  65%|██████▌   | 23001/35345 [40:07<21:29,  9.57it/s]

Batch 23000/35345:
total_loss: 0.5056415796279907


Epoch 1/5:  66%|██████▌   | 23251/35345 [40:34<21:10,  9.52it/s]

Batch 23250/35345:
total_loss: 0.4712775647640228


Epoch 1/5:  66%|██████▋   | 23501/35345 [41:00<20:23,  9.68it/s]

Batch 23500/35345:
total_loss: 0.5281290411949158


Epoch 1/5:  67%|██████▋   | 23751/35345 [41:26<20:05,  9.62it/s]

Batch 23750/35345:
total_loss: 0.32694312930107117


Epoch 1/5:  68%|██████▊   | 24000/35345 [41:52<20:29,  9.22it/s]

Batch 24000/35345:
total_loss: 0.5322115421295166


Epoch 1/5:  69%|██████▊   | 24251/35345 [42:19<19:16,  9.59it/s]

Batch 24250/35345:
total_loss: 0.3966643214225769


Epoch 1/5:  69%|██████▉   | 24501/35345 [42:45<18:43,  9.65it/s]

Batch 24500/35345:
total_loss: 0.31565383076667786


Epoch 1/5:  70%|███████   | 24751/35345 [43:12<18:17,  9.65it/s]

Batch 24750/35345:
total_loss: 0.4255695343017578


Epoch 1/5:  71%|███████   | 25001/35345 [43:38<18:10,  9.49it/s]

Batch 25000/35345:
total_loss: 0.5857208967208862


Epoch 1/5:  71%|███████▏  | 25251/35345 [44:05<17:37,  9.54it/s]

Batch 25250/35345:
total_loss: 0.4732074439525604


Epoch 1/5:  72%|███████▏  | 25501/35345 [44:31<16:50,  9.74it/s]

Batch 25500/35345:
total_loss: 0.30186727643013


Epoch 1/5:  73%|███████▎  | 25751/35345 [44:57<16:35,  9.64it/s]

Batch 25750/35345:
total_loss: 0.4261530935764313


Epoch 1/5:  74%|███████▎  | 26001/35345 [45:24<16:30,  9.43it/s]

Batch 26000/35345:
total_loss: 0.46451419591903687


Epoch 1/5:  74%|███████▍  | 26251/35345 [45:50<15:49,  9.58it/s]

Batch 26250/35345:
total_loss: 0.22994396090507507


Epoch 1/5:  75%|███████▍  | 26501/35345 [46:16<17:41,  8.33it/s]

Batch 26500/35345:
total_loss: 0.47799623012542725


Epoch 1/5:  76%|███████▌  | 26751/35345 [46:42<15:20,  9.33it/s]

Batch 26750/35345:
total_loss: 0.3726099729537964


Epoch 1/5:  76%|███████▋  | 27001/35345 [47:09<14:29,  9.60it/s]

Batch 27000/35345:
total_loss: 0.3890466094017029


Epoch 1/5:  77%|███████▋  | 27251/35345 [47:35<13:59,  9.64it/s]

Batch 27250/35345:
total_loss: 0.6401770114898682


Epoch 1/5:  78%|███████▊  | 27501/35345 [48:01<13:35,  9.62it/s]

Batch 27500/35345:
total_loss: 0.5176008343696594


Epoch 1/5:  79%|███████▊  | 27751/35345 [48:27<13:23,  9.45it/s]

Batch 27750/35345:
total_loss: 0.4482760727405548


Epoch 1/5:  79%|███████▉  | 28001/35345 [48:53<12:47,  9.57it/s]

Batch 28000/35345:
total_loss: 0.2754921317100525


Epoch 1/5:  80%|███████▉  | 28251/35345 [49:19<12:20,  9.58it/s]

Batch 28250/35345:
total_loss: 0.4101526439189911


Epoch 1/5:  81%|████████  | 28501/35345 [49:46<11:45,  9.70it/s]

Batch 28500/35345:
total_loss: 0.42141321301460266


Epoch 1/5:  81%|████████▏ | 28751/35345 [50:12<11:26,  9.60it/s]

Batch 28750/35345:
total_loss: 0.4022425413131714


Epoch 1/5:  82%|████████▏ | 29001/35345 [50:38<10:57,  9.65it/s]

Batch 29000/35345:
total_loss: 0.5893812775611877


Epoch 1/5:  83%|████████▎ | 29251/35345 [51:04<10:56,  9.29it/s]

Batch 29250/35345:
total_loss: 0.31341832876205444


Epoch 1/5:  83%|████████▎ | 29501/35345 [51:30<10:06,  9.63it/s]

Batch 29500/35345:
total_loss: 0.5071956515312195


Epoch 1/5:  84%|████████▍ | 29751/35345 [51:56<09:40,  9.64it/s]

Batch 29750/35345:
total_loss: 0.3533589839935303


Epoch 1/5:  85%|████████▍ | 30001/35345 [52:22<09:23,  9.49it/s]

Batch 30000/35345:
total_loss: 0.40756332874298096


Epoch 1/5:  86%|████████▌ | 30251/35345 [52:49<08:48,  9.64it/s]

Batch 30250/35345:
total_loss: 0.3955462872982025


Epoch 1/5:  86%|████████▋ | 30501/35345 [53:15<08:27,  9.55it/s]

Batch 30500/35345:
total_loss: 0.5256631970405579


Epoch 1/5:  87%|████████▋ | 30751/35345 [53:41<07:57,  9.62it/s]

Batch 30750/35345:
total_loss: 0.2954269349575043


Epoch 1/5:  88%|████████▊ | 31001/35345 [54:08<07:30,  9.64it/s]

Batch 31000/35345:
total_loss: 0.4855959713459015


Epoch 1/5:  88%|████████▊ | 31251/35345 [54:34<07:05,  9.61it/s]

Batch 31250/35345:
total_loss: 0.3051591217517853


Epoch 1/5:  89%|████████▉ | 31501/35345 [55:00<06:39,  9.63it/s]

Batch 31500/35345:
total_loss: 0.29262441396713257


Epoch 1/5:  90%|████████▉ | 31751/35345 [55:26<06:14,  9.59it/s]

Batch 31750/35345:
total_loss: 0.43094801902770996


Epoch 1/5:  91%|█████████ | 32001/35345 [55:52<05:59,  9.29it/s]

Batch 32000/35345:
total_loss: 0.27268707752227783


Epoch 1/5:  91%|█████████ | 32251/35345 [56:19<05:29,  9.40it/s]

Batch 32250/35345:
total_loss: 0.35425621271133423


Epoch 1/5:  92%|█████████▏| 32501/35345 [56:45<04:55,  9.61it/s]

Batch 32500/35345:
total_loss: 0.2884497046470642


Epoch 1/5:  93%|█████████▎| 32751/35345 [57:11<04:31,  9.55it/s]

Batch 32750/35345:
total_loss: 0.5273545980453491


Epoch 1/5:  93%|█████████▎| 33001/35345 [57:38<04:02,  9.65it/s]

Batch 33000/35345:
total_loss: 0.2960396409034729


Epoch 1/5:  94%|█████████▍| 33251/35345 [58:05<03:39,  9.53it/s]

Batch 33250/35345:
total_loss: 0.443998247385025


Epoch 1/5:  95%|█████████▍| 33501/35345 [58:31<03:13,  9.54it/s]

Batch 33500/35345:
total_loss: 0.47666722536087036


Epoch 1/5:  95%|█████████▌| 33751/35345 [58:58<02:46,  9.58it/s]

Batch 33750/35345:
total_loss: 0.3696814477443695


Epoch 1/5:  96%|█████████▌| 34001/35345 [59:24<02:20,  9.55it/s]

Batch 34000/35345:
total_loss: 0.259086936712265


Epoch 1/5:  97%|█████████▋| 34251/35345 [59:51<01:55,  9.49it/s]

Batch 34250/35345:
total_loss: 0.33316007256507874


Epoch 1/5:  98%|█████████▊| 34501/35345 [1:00:17<01:28,  9.49it/s]

Batch 34500/35345:
total_loss: 0.4679219722747803


Epoch 1/5:  98%|█████████▊| 34751/35345 [1:00:44<01:03,  9.29it/s]

Batch 34750/35345:
total_loss: 0.3378826379776001


Epoch 1/5:  99%|█████████▉| 35001/35345 [1:01:10<00:35,  9.61it/s]

Batch 35000/35345:
total_loss: 0.2687230110168457


Epoch 1/5: 100%|█████████▉| 35251/35345 [1:01:37<00:09,  9.47it/s]

Batch 35250/35345:
total_loss: 0.25776365399360657


Epoch 1/5: 100%|██████████| 35345/35345 [1:01:48<00:00,  9.53it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  2.95it/s]


Epoch Summary:
Train Total Loss: 0.5497
Val Total Loss: 0.3439
Learning Rate: 0.000100
New best model saved with val loss: 0.3439
--------------------------------------------------



Epoch 2/5:   1%|          | 251/35345 [00:26<1:01:57,  9.44it/s]

Batch 250/35345:
total_loss: 0.23203979432582855


Epoch 2/5:   1%|▏         | 501/35345 [00:53<1:00:18,  9.63it/s]

Batch 500/35345:
total_loss: 0.2817099988460541


Epoch 2/5:   2%|▏         | 751/35345 [01:19<1:01:15,  9.41it/s]

Batch 750/35345:
total_loss: 0.4479381740093231


Epoch 2/5:   3%|▎         | 1001/35345 [01:46<59:24,  9.63it/s]

Batch 1000/35345:
total_loss: 0.3160417377948761


Epoch 2/5:   4%|▎         | 1251/35345 [02:12<59:30,  9.55it/s]

Batch 1250/35345:
total_loss: 0.3096984028816223


Epoch 2/5:   4%|▍         | 1501/35345 [02:38<59:15,  9.52it/s]

Batch 1500/35345:
total_loss: 0.4870072901248932


Epoch 2/5:   5%|▍         | 1751/35345 [03:05<1:04:30,  8.68it/s]

Batch 1750/35345:
total_loss: 0.28920069336891174


Epoch 2/5:   6%|▌         | 2001/35345 [03:31<58:03,  9.57it/s]

Batch 2000/35345:
total_loss: 0.32783421874046326


Epoch 2/5:   6%|▋         | 2251/35345 [03:57<57:23,  9.61it/s]

Batch 2250/35345:
total_loss: 0.48665666580200195


Epoch 2/5:   7%|▋         | 2501/35345 [04:24<57:08,  9.58it/s]

Batch 2500/35345:
total_loss: 0.4843777120113373


Epoch 2/5:   8%|▊         | 2751/35345 [04:51<57:04,  9.52it/s]

Batch 2750/35345:
total_loss: 0.5081274509429932


Epoch 2/5:   8%|▊         | 3001/35345 [05:17<56:03,  9.62it/s]

Batch 3000/35345:
total_loss: 0.3781970739364624


Epoch 2/5:   9%|▉         | 3251/35345 [05:43<55:51,  9.58it/s]

Batch 3250/35345:
total_loss: 0.2756538391113281


Epoch 2/5:  10%|▉         | 3501/35345 [06:10<55:25,  9.57it/s]

Batch 3500/35345:
total_loss: 0.4742562472820282


Epoch 2/5:  11%|█         | 3751/35345 [06:37<54:47,  9.61it/s]

Batch 3750/35345:
total_loss: 0.24878746271133423


Epoch 2/5:  11%|█▏        | 4001/35345 [07:03<54:56,  9.51it/s]

Batch 4000/35345:
total_loss: 0.3061128258705139


Epoch 2/5:  12%|█▏        | 4251/35345 [07:29<53:35,  9.67it/s]

Batch 4250/35345:
total_loss: 0.190854012966156


Epoch 2/5:  13%|█▎        | 4501/35345 [07:56<53:45,  9.56it/s]

Batch 4500/35345:
total_loss: 0.5260723829269409


Epoch 2/5:  13%|█▎        | 4751/35345 [08:22<52:47,  9.66it/s]

Batch 4750/35345:
total_loss: 0.3308979272842407


Epoch 2/5:  14%|█▍        | 5001/35345 [08:48<53:32,  9.45it/s]

Batch 5000/35345:
total_loss: 0.4482291340827942


Epoch 2/5:  15%|█▍        | 5251/35345 [09:15<52:12,  9.61it/s]

Batch 5250/35345:
total_loss: 0.36763063073158264


Epoch 2/5:  16%|█▌        | 5501/35345 [09:41<55:06,  9.03it/s]

Batch 5500/35345:
total_loss: 0.3017757833003998


Epoch 2/5:  16%|█▋        | 5751/35345 [10:08<51:41,  9.54it/s]

Batch 5750/35345:
total_loss: 0.37267759442329407


Epoch 2/5:  17%|█▋        | 6001/35345 [10:34<53:34,  9.13it/s]

Batch 6000/35345:
total_loss: 0.40474244952201843


Epoch 2/5:  18%|█▊        | 6251/35345 [11:01<50:55,  9.52it/s]

Batch 6250/35345:
total_loss: 0.3799101710319519


Epoch 2/5:  18%|█▊        | 6501/35345 [11:27<50:09,  9.58it/s]

Batch 6500/35345:
total_loss: 0.4197917580604553


Epoch 2/5:  19%|█▉        | 6751/35345 [11:54<49:58,  9.54it/s]

Batch 6750/35345:
total_loss: 0.4503942131996155


Epoch 2/5:  20%|█▉        | 7001/35345 [12:20<54:08,  8.72it/s]

Batch 7000/35345:
total_loss: 0.3499585688114166


Epoch 2/5:  21%|██        | 7251/35345 [12:46<49:17,  9.50it/s]

Batch 7250/35345:
total_loss: 0.5948711037635803


Epoch 2/5:  21%|██        | 7501/35345 [13:13<48:16,  9.61it/s]

Batch 7500/35345:
total_loss: 0.3452635109424591


Epoch 2/5:  22%|██▏       | 7751/35345 [13:40<47:35,  9.66it/s]

Batch 7750/35345:
total_loss: 0.42343562841415405


Epoch 2/5:  23%|██▎       | 8001/35345 [14:06<52:57,  8.61it/s]

Batch 8000/35345:
total_loss: 0.22007234394550323


Epoch 2/5:  23%|██▎       | 8251/35345 [14:32<48:56,  9.23it/s]

Batch 8250/35345:
total_loss: 0.48754337430000305


Epoch 2/5:  24%|██▍       | 8501/35345 [14:59<46:34,  9.60it/s]

Batch 8500/35345:
total_loss: 0.3464258909225464


Epoch 2/5:  25%|██▍       | 8751/35345 [15:25<46:19,  9.57it/s]

Batch 8750/35345:
total_loss: 0.26380622386932373


Epoch 2/5:  25%|██▌       | 9001/35345 [15:51<45:56,  9.56it/s]

Batch 9000/35345:
total_loss: 0.23344729840755463


Epoch 2/5:  26%|██▌       | 9251/35345 [16:17<45:18,  9.60it/s]

Batch 9250/35345:
total_loss: 0.24501922726631165


Epoch 2/5:  27%|██▋       | 9501/35345 [16:44<45:02,  9.56it/s]

Batch 9500/35345:
total_loss: 0.45744284987449646


Epoch 2/5:  28%|██▊       | 9751/35345 [17:10<44:17,  9.63it/s]

Batch 9750/35345:
total_loss: 0.4545100927352905


Epoch 2/5:  28%|██▊       | 10001/35345 [17:36<44:36,  9.47it/s]

Batch 10000/35345:
total_loss: 0.414695680141449


Epoch 2/5:  29%|██▉       | 10251/35345 [18:03<43:31,  9.61it/s]

Batch 10250/35345:
total_loss: 0.27304932475090027


Epoch 2/5:  30%|██▉       | 10501/35345 [18:30<43:03,  9.62it/s]

Batch 10500/35345:
total_loss: 0.4661630094051361


Epoch 2/5:  30%|███       | 10751/35345 [18:57<42:28,  9.65it/s]

Batch 10750/35345:
total_loss: 0.4467254877090454


Epoch 2/5:  31%|███       | 11000/35345 [19:23<49:12,  8.25it/s]

Batch 11000/35345:
total_loss: 0.27765393257141113


Epoch 2/5:  32%|███▏      | 11251/35345 [19:50<41:44,  9.62it/s]

Batch 11250/35345:
total_loss: 0.38124462962150574


Epoch 2/5:  33%|███▎      | 11501/35345 [20:16<41:11,  9.65it/s]

Batch 11500/35345:
total_loss: 0.30800363421440125


Epoch 2/5:  33%|███▎      | 11751/35345 [20:43<40:51,  9.62it/s]

Batch 11750/35345:
total_loss: 0.29685917496681213


Epoch 2/5:  34%|███▍      | 12001/35345 [21:09<40:54,  9.51it/s]

Batch 12000/35345:
total_loss: 0.35846948623657227


Epoch 2/5:  35%|███▍      | 12251/35345 [21:36<40:16,  9.56it/s]

Batch 12250/35345:
total_loss: 0.18237078189849854


Epoch 2/5:  35%|███▌      | 12501/35345 [22:02<39:20,  9.68it/s]

Batch 12500/35345:
total_loss: 0.2392653226852417


Epoch 2/5:  36%|███▌      | 12751/35345 [22:28<39:11,  9.61it/s]

Batch 12750/35345:
total_loss: 0.4837948679924011


Epoch 2/5:  37%|███▋      | 13001/35345 [22:54<38:36,  9.65it/s]

Batch 13000/35345:
total_loss: 0.4235680401325226


Epoch 2/5:  37%|███▋      | 13251/35345 [23:21<38:01,  9.69it/s]

Batch 13250/35345:
total_loss: 0.2142820805311203


Epoch 2/5:  38%|███▊      | 13501/35345 [23:47<37:50,  9.62it/s]

Batch 13500/35345:
total_loss: 0.4857197403907776


Epoch 2/5:  39%|███▉      | 13751/35345 [24:14<41:25,  8.69it/s]

Batch 13750/35345:
total_loss: 0.2600705623626709


Epoch 2/5:  40%|███▉      | 14001/35345 [24:40<36:57,  9.63it/s]

Batch 14000/35345:
total_loss: 0.43554240465164185


Epoch 2/5:  40%|████      | 14251/35345 [25:07<36:55,  9.52it/s]

Batch 14250/35345:
total_loss: 0.36432939767837524


Epoch 2/5:  41%|████      | 14501/35345 [25:34<36:24,  9.54it/s]

Batch 14500/35345:
total_loss: 0.23115754127502441


Epoch 2/5:  42%|████▏     | 14751/35345 [26:00<38:05,  9.01it/s]

Batch 14750/35345:
total_loss: 0.4910087287425995


Epoch 2/5:  42%|████▏     | 15001/35345 [26:26<35:22,  9.58it/s]

Batch 15000/35345:
total_loss: 0.36836302280426025


Epoch 2/5:  43%|████▎     | 15251/35345 [26:53<34:48,  9.62it/s]

Batch 15250/35345:
total_loss: 0.24972841143608093


Epoch 2/5:  44%|████▍     | 15501/35345 [27:20<34:48,  9.50it/s]

Batch 15500/35345:
total_loss: 0.3147754371166229


Epoch 2/5:  45%|████▍     | 15751/35345 [27:46<34:10,  9.55it/s]

Batch 15750/35345:
total_loss: 0.214797705411911


Epoch 2/5:  45%|████▌     | 16001/35345 [28:12<33:44,  9.55it/s]

Batch 16000/35345:
total_loss: 0.3079574704170227


Epoch 2/5:  46%|████▌     | 16251/35345 [28:39<33:03,  9.63it/s]

Batch 16250/35345:
total_loss: 0.27176761627197266


Epoch 2/5:  47%|████▋     | 16501/35345 [29:06<33:07,  9.48it/s]

Batch 16500/35345:
total_loss: 0.3196493089199066


Epoch 2/5:  47%|████▋     | 16751/35345 [29:32<32:25,  9.56it/s]

Batch 16750/35345:
total_loss: 0.2273491621017456


Epoch 2/5:  48%|████▊     | 17001/35345 [29:59<31:38,  9.66it/s]

Batch 17000/35345:
total_loss: 0.34750327467918396


Epoch 2/5:  49%|████▉     | 17251/35345 [30:25<31:21,  9.62it/s]

Batch 17250/35345:
total_loss: 0.337258517742157


Epoch 2/5:  50%|████▉     | 17501/35345 [30:52<31:04,  9.57it/s]

Batch 17500/35345:
total_loss: 0.2686620056629181


Epoch 2/5:  50%|█████     | 17751/35345 [31:18<30:26,  9.63it/s]

Batch 17750/35345:
total_loss: 0.3980637490749359


Epoch 2/5:  51%|█████     | 18001/35345 [31:45<30:15,  9.55it/s]

Batch 18000/35345:
total_loss: 0.2669616639614105


Epoch 2/5:  52%|█████▏    | 18251/35345 [32:11<29:37,  9.62it/s]

Batch 18250/35345:
total_loss: 0.2765498459339142


Epoch 2/5:  52%|█████▏    | 18501/35345 [32:37<29:55,  9.38it/s]

Batch 18500/35345:
total_loss: 0.336688756942749


Epoch 2/5:  53%|█████▎    | 18751/35345 [33:04<28:58,  9.54it/s]

Batch 18750/35345:
total_loss: 0.3335287570953369


Epoch 2/5:  54%|█████▍    | 19001/35345 [33:31<28:32,  9.55it/s]

Batch 19000/35345:
total_loss: 0.4433160424232483


Epoch 2/5:  54%|█████▍    | 19251/35345 [33:58<28:28,  9.42it/s]

Batch 19250/35345:
total_loss: 0.28168126940727234


Epoch 2/5:  55%|█████▌    | 19501/35345 [34:24<27:43,  9.52it/s]

Batch 19500/35345:
total_loss: 0.3813449442386627


Epoch 2/5:  56%|█████▌    | 19751/35345 [34:51<27:03,  9.61it/s]

Batch 19750/35345:
total_loss: 0.48589345812797546


Epoch 2/5:  57%|█████▋    | 20001/35345 [35:18<26:33,  9.63it/s]

Batch 20000/35345:
total_loss: 0.20583263039588928


Epoch 2/5:  57%|█████▋    | 20251/35345 [35:44<32:15,  7.80it/s]

Batch 20250/35345:
total_loss: 0.19299593567848206


Epoch 2/5:  58%|█████▊    | 20501/35345 [36:11<25:43,  9.62it/s]

Batch 20500/35345:
total_loss: 0.3804525136947632


Epoch 2/5:  59%|█████▊    | 20751/35345 [36:37<29:27,  8.25it/s]

Batch 20750/35345:
total_loss: 0.31032294034957886


Epoch 2/5:  59%|█████▉    | 21001/35345 [37:04<24:52,  9.61it/s]

Batch 21000/35345:
total_loss: 0.3076488971710205


Epoch 2/5:  60%|██████    | 21251/35345 [37:30<24:25,  9.62it/s]

Batch 21250/35345:
total_loss: 0.23332415521144867


Epoch 2/5:  61%|██████    | 21501/35345 [37:57<23:51,  9.67it/s]

Batch 21500/35345:
total_loss: 0.3662266433238983


Epoch 2/5:  62%|██████▏   | 21751/35345 [38:24<23:27,  9.66it/s]

Batch 21750/35345:
total_loss: 0.13694195449352264


Epoch 2/5:  62%|██████▏   | 22001/35345 [38:51<24:19,  9.14it/s]

Batch 22000/35345:
total_loss: 0.27342256903648376


Epoch 2/5:  63%|██████▎   | 22251/35345 [39:17<24:03,  9.07it/s]

Batch 22250/35345:
total_loss: 0.2740055024623871


Epoch 2/5:  64%|██████▎   | 22501/35345 [39:44<22:07,  9.67it/s]

Batch 22500/35345:
total_loss: 0.3008790910243988


Epoch 2/5:  64%|██████▍   | 22751/35345 [40:10<21:39,  9.69it/s]

Batch 22750/35345:
total_loss: 0.2126980870962143


Epoch 2/5:  65%|██████▌   | 23001/35345 [40:37<21:47,  9.44it/s]

Batch 23000/35345:
total_loss: 0.25320199131965637


Epoch 2/5:  66%|██████▌   | 23251/35345 [41:03<21:20,  9.44it/s]

Batch 23250/35345:
total_loss: 0.3147207498550415


Epoch 2/5:  66%|██████▋   | 23501/35345 [41:29<20:42,  9.53it/s]

Batch 23500/35345:
total_loss: 0.2924109399318695


Epoch 2/5:  67%|██████▋   | 23751/35345 [41:56<19:53,  9.71it/s]

Batch 23750/35345:
total_loss: 0.23246373236179352


Epoch 2/5:  68%|██████▊   | 24001/35345 [42:22<19:40,  9.61it/s]

Batch 24000/35345:
total_loss: 0.39852219820022583


Epoch 2/5:  69%|██████▊   | 24251/35345 [42:49<19:33,  9.46it/s]

Batch 24250/35345:
total_loss: 0.2584221363067627


Epoch 2/5:  69%|██████▉   | 24501/35345 [43:15<18:42,  9.66it/s]

Batch 24500/35345:
total_loss: 0.3393940031528473


Epoch 2/5:  70%|███████   | 24751/35345 [43:41<18:24,  9.60it/s]

Batch 24750/35345:
total_loss: 0.3081529438495636


Epoch 2/5:  71%|███████   | 25001/35345 [44:07<17:59,  9.58it/s]

Batch 25000/35345:
total_loss: 0.27606427669525146


Epoch 2/5:  71%|███████▏  | 25251/35345 [44:33<17:27,  9.64it/s]

Batch 25250/35345:
total_loss: 0.36695972084999084


Epoch 2/5:  72%|███████▏  | 25501/35345 [44:59<17:24,  9.43it/s]

Batch 25500/35345:
total_loss: 0.34424281120300293


Epoch 2/5:  73%|███████▎  | 25751/35345 [45:26<16:41,  9.58it/s]

Batch 25750/35345:
total_loss: 0.2834092974662781


Epoch 2/5:  74%|███████▎  | 26001/35345 [45:52<16:08,  9.65it/s]

Batch 26000/35345:
total_loss: 0.2322937250137329


Epoch 2/5:  74%|███████▍  | 26251/35345 [46:19<16:14,  9.33it/s]

Batch 26250/35345:
total_loss: 0.21391794085502625


Epoch 2/5:  75%|███████▍  | 26501/35345 [46:45<15:21,  9.59it/s]

Batch 26500/35345:
total_loss: 0.13453234732151031


Epoch 2/5:  76%|███████▌  | 26751/35345 [47:11<14:51,  9.64it/s]

Batch 26750/35345:
total_loss: 0.3611655533313751


Epoch 2/5:  76%|███████▋  | 27001/35345 [47:38<14:19,  9.70it/s]

Batch 27000/35345:
total_loss: 0.29624292254447937


Epoch 2/5:  77%|███████▋  | 27251/35345 [48:04<14:08,  9.54it/s]

Batch 27250/35345:
total_loss: 0.23370420932769775


Epoch 2/5:  78%|███████▊  | 27501/35345 [48:31<13:46,  9.50it/s]

Batch 27500/35345:
total_loss: 0.2823128402233124


Epoch 2/5:  79%|███████▊  | 27751/35345 [48:56<13:12,  9.58it/s]

Batch 27750/35345:
total_loss: 0.3527529835700989


Epoch 2/5:  79%|███████▉  | 28001/35345 [49:23<12:40,  9.66it/s]

Batch 28000/35345:
total_loss: 0.2660616338253021


Epoch 2/5:  80%|███████▉  | 28251/35345 [49:50<12:11,  9.69it/s]

Batch 28250/35345:
total_loss: 0.18723765015602112


Epoch 2/5:  81%|████████  | 28501/35345 [50:16<11:50,  9.64it/s]

Batch 28500/35345:
total_loss: 0.37569770216941833


Epoch 2/5:  81%|████████▏ | 28751/35345 [50:43<11:35,  9.48it/s]

Batch 28750/35345:
total_loss: 0.34253013134002686


Epoch 2/5:  82%|████████▏ | 29001/35345 [51:09<11:34,  9.14it/s]

Batch 29000/35345:
total_loss: 0.09848526120185852


Epoch 2/5:  83%|████████▎ | 29251/35345 [51:35<10:29,  9.68it/s]

Batch 29250/35345:
total_loss: 0.20527340471744537


Epoch 2/5:  83%|████████▎ | 29501/35345 [52:02<10:08,  9.60it/s]

Batch 29500/35345:
total_loss: 0.21771523356437683


Epoch 2/5:  84%|████████▍ | 29751/35345 [52:28<09:38,  9.67it/s]

Batch 29750/35345:
total_loss: 0.4061836004257202


Epoch 2/5:  85%|████████▍ | 30001/35345 [52:54<09:48,  9.08it/s]

Batch 30000/35345:
total_loss: 0.3594833314418793


Epoch 2/5:  86%|████████▌ | 30251/35345 [53:20<08:52,  9.56it/s]

Batch 30250/35345:
total_loss: 0.33820584416389465


Epoch 2/5:  86%|████████▋ | 30501/35345 [53:46<08:21,  9.65it/s]

Batch 30500/35345:
total_loss: 0.36370691657066345


Epoch 2/5:  87%|████████▋ | 30751/35345 [54:13<07:55,  9.66it/s]

Batch 30750/35345:
total_loss: 0.30863890051841736


Epoch 2/5:  88%|████████▊ | 31001/35345 [54:39<07:31,  9.63it/s]

Batch 31000/35345:
total_loss: 0.2318686842918396


Epoch 2/5:  88%|████████▊ | 31251/35345 [55:05<07:09,  9.53it/s]

Batch 31250/35345:
total_loss: 0.29185348749160767


Epoch 2/5:  89%|████████▉ | 31501/35345 [55:32<06:42,  9.55it/s]

Batch 31500/35345:
total_loss: 0.2899448275566101


Epoch 2/5:  90%|████████▉ | 31751/35345 [55:58<06:12,  9.66it/s]

Batch 31750/35345:
total_loss: 0.3241116404533386


Epoch 2/5:  91%|█████████ | 32001/35345 [56:25<05:44,  9.70it/s]

Batch 32000/35345:
total_loss: 0.3095822036266327


Epoch 2/5:  91%|█████████ | 32251/35345 [56:52<05:22,  9.60it/s]

Batch 32250/35345:
total_loss: 0.18993030488491058


Epoch 2/5:  92%|█████████▏| 32501/35345 [57:18<04:56,  9.59it/s]

Batch 32500/35345:
total_loss: 0.2551872730255127


Epoch 2/5:  93%|█████████▎| 32751/35345 [57:44<04:30,  9.59it/s]

Batch 32750/35345:
total_loss: 0.2511644959449768


Epoch 2/5:  93%|█████████▎| 33001/35345 [58:11<04:08,  9.45it/s]

Batch 33000/35345:
total_loss: 0.1614418476819992


Epoch 2/5:  94%|█████████▍| 33251/35345 [58:37<04:20,  8.03it/s]

Batch 33250/35345:
total_loss: 0.218337282538414


Epoch 2/5:  95%|█████████▍| 33501/35345 [59:04<03:11,  9.63it/s]

Batch 33500/35345:
total_loss: 0.19084949791431427


Epoch 2/5:  95%|█████████▌| 33751/35345 [59:30<02:47,  9.53it/s]

Batch 33750/35345:
total_loss: 0.2769787013530731


Epoch 2/5:  96%|█████████▌| 34000/35345 [59:57<02:19,  9.63it/s]

Batch 34000/35345:
total_loss: 0.1171543225646019


Epoch 2/5:  97%|█████████▋| 34251/35345 [1:00:23<01:54,  9.51it/s]

Batch 34250/35345:
total_loss: 0.27132681012153625


Epoch 2/5:  98%|█████████▊| 34500/35345 [1:00:49<01:32,  9.13it/s]

Batch 34500/35345:
total_loss: 0.27862659096717834


Epoch 2/5:  98%|█████████▊| 34751/35345 [1:01:16<01:01,  9.67it/s]

Batch 34750/35345:
total_loss: 0.29696834087371826


Epoch 2/5:  99%|█████████▉| 35001/35345 [1:01:42<00:35,  9.60it/s]

Batch 35000/35345:
total_loss: 0.31677430868148804


Epoch 2/5: 100%|█████████▉| 35251/35345 [1:02:09<00:09,  9.55it/s]

Batch 35250/35345:
total_loss: 0.5380210876464844


Epoch 2/5: 100%|██████████| 35345/35345 [1:02:20<00:00,  9.45it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.62it/s]


Epoch Summary:
Train Total Loss: 0.3207
Val Total Loss: 0.1865
Learning Rate: 0.000100
New best model saved with val loss: 0.1865
--------------------------------------------------



Epoch 3/5:   1%|          | 251/35345 [00:27<1:01:14,  9.55it/s]

Batch 250/35345:
total_loss: 0.23180115222930908


Epoch 3/5:   1%|▏         | 501/35345 [00:53<1:00:51,  9.54it/s]

Batch 500/35345:
total_loss: 0.22936907410621643


Epoch 3/5:   2%|▏         | 751/35345 [01:20<59:56,  9.62it/s]  

Batch 750/35345:
total_loss: 0.38528063893318176


Epoch 3/5:   3%|▎         | 1001/35345 [01:46<59:14,  9.66it/s]

Batch 1000/35345:
total_loss: 0.16745850443840027


Epoch 3/5:   4%|▎         | 1251/35345 [02:13<58:53,  9.65it/s]

Batch 1250/35345:
total_loss: 0.2533724009990692


Epoch 3/5:   4%|▍         | 1501/35345 [02:39<58:34,  9.63it/s]

Batch 1500/35345:
total_loss: 0.32840844988822937


Epoch 3/5:   5%|▍         | 1751/35345 [03:06<1:01:00,  9.18it/s]

Batch 1750/35345:
total_loss: 0.3615659773349762


Epoch 3/5:   6%|▌         | 2001/35345 [03:32<58:13,  9.55it/s]

Batch 2000/35345:
total_loss: 0.32527536153793335


Epoch 3/5:   6%|▋         | 2251/35345 [03:58<57:44,  9.55it/s]

Batch 2250/35345:
total_loss: 0.09576872736215591


Epoch 3/5:   7%|▋         | 2501/35345 [04:25<57:23,  9.54it/s]

Batch 2500/35345:
total_loss: 0.28408628702163696


Epoch 3/5:   8%|▊         | 2751/35345 [04:51<57:25,  9.46it/s]

Batch 2750/35345:
total_loss: 0.509041965007782


Epoch 3/5:   8%|▊         | 3001/35345 [05:17<56:49,  9.49it/s]

Batch 3000/35345:
total_loss: 0.2945655882358551


Epoch 3/5:   9%|▉         | 3251/35345 [05:44<56:15,  9.51it/s]

Batch 3250/35345:
total_loss: 0.22840771079063416


Epoch 3/5:  10%|▉         | 3501/35345 [06:10<54:48,  9.68it/s]

Batch 3500/35345:
total_loss: 0.2617906630039215


Epoch 3/5:  11%|█         | 3751/35345 [06:37<55:01,  9.57it/s]

Batch 3750/35345:
total_loss: 0.159208744764328


Epoch 3/5:  11%|█▏        | 4001/35345 [07:03<54:19,  9.62it/s]

Batch 4000/35345:
total_loss: 0.20673921704292297


Epoch 3/5:  12%|█▏        | 4251/35345 [07:29<53:32,  9.68it/s]

Batch 4250/35345:
total_loss: 0.36043447256088257


Epoch 3/5:  13%|█▎        | 4501/35345 [07:57<53:41,  9.58it/s]

Batch 4500/35345:
total_loss: 0.32632049918174744


Epoch 3/5:  13%|█▎        | 4751/35345 [08:23<53:23,  9.55it/s]

Batch 4750/35345:
total_loss: 0.23129501938819885


Epoch 3/5:  14%|█▍        | 5001/35345 [08:49<58:59,  8.57it/s]  

Batch 5000/35345:
total_loss: 0.30295613408088684


Epoch 3/5:  15%|█▍        | 5251/35345 [09:16<53:08,  9.44it/s]

Batch 5250/35345:
total_loss: 0.2570227384567261


Epoch 3/5:  16%|█▌        | 5501/35345 [09:42<52:26,  9.48it/s]

Batch 5500/35345:
total_loss: 0.1813923716545105


Epoch 3/5:  16%|█▋        | 5751/35345 [10:09<51:43,  9.54it/s]

Batch 5750/35345:
total_loss: 0.18266652524471283


Epoch 3/5:  17%|█▋        | 6001/35345 [10:35<51:08,  9.56it/s]

Batch 6000/35345:
total_loss: 0.1253420114517212


Epoch 3/5:  18%|█▊        | 6251/35345 [11:01<51:04,  9.49it/s]

Batch 6250/35345:
total_loss: 0.4316577613353729


Epoch 3/5:  18%|█▊        | 6501/35345 [11:28<50:26,  9.53it/s]

Batch 6500/35345:
total_loss: 0.23286762833595276


Epoch 3/5:  19%|█▉        | 6751/35345 [11:54<49:21,  9.65it/s]

Batch 6750/35345:
total_loss: 0.1778329610824585


Epoch 3/5:  20%|█▉        | 7001/35345 [12:21<50:08,  9.42it/s]

Batch 7000/35345:
total_loss: 0.33176112174987793


Epoch 3/5:  21%|██        | 7251/35345 [12:47<48:47,  9.60it/s]

Batch 7250/35345:
total_loss: 0.1748771369457245


Epoch 3/5:  21%|██        | 7501/35345 [13:13<48:26,  9.58it/s]

Batch 7500/35345:
total_loss: 0.3407248854637146


Epoch 3/5:  22%|██▏       | 7751/35345 [13:40<48:08,  9.55it/s]

Batch 7750/35345:
total_loss: 0.2642499506473541


Epoch 3/5:  23%|██▎       | 8001/35345 [14:06<49:16,  9.25it/s]

Batch 8000/35345:
total_loss: 0.16221551597118378


Epoch 3/5:  23%|██▎       | 8251/35345 [14:33<48:12,  9.37it/s]

Batch 8250/35345:
total_loss: 0.2770310938358307


Epoch 3/5:  24%|██▍       | 8501/35345 [15:00<46:20,  9.65it/s]

Batch 8500/35345:
total_loss: 0.13724902272224426


Epoch 3/5:  25%|██▍       | 8751/35345 [15:26<46:10,  9.60it/s]

Batch 8750/35345:
total_loss: 0.2958134412765503


Epoch 3/5:  25%|██▌       | 9001/35345 [15:53<45:45,  9.59it/s]

Batch 9000/35345:
total_loss: 0.17891770601272583


Epoch 3/5:  26%|██▌       | 9251/35345 [16:19<45:15,  9.61it/s]

Batch 9250/35345:
total_loss: 0.2649571895599365


Epoch 3/5:  27%|██▋       | 9501/35345 [16:45<46:19,  9.30it/s]

Batch 9500/35345:
total_loss: 0.3196297287940979


Epoch 3/5:  28%|██▊       | 9751/35345 [17:12<44:39,  9.55it/s]

Batch 9750/35345:
total_loss: 0.28994446992874146


Epoch 3/5:  28%|██▊       | 10001/35345 [17:39<43:37,  9.68it/s]

Batch 10000/35345:
total_loss: 0.27860864996910095


Epoch 3/5:  29%|██▉       | 10251/35345 [18:05<48:33,  8.61it/s]

Batch 10250/35345:
total_loss: 0.2652706205844879


Epoch 3/5:  30%|██▉       | 10501/35345 [18:32<43:17,  9.56it/s]

Batch 10500/35345:
total_loss: 0.17943815886974335


Epoch 3/5:  30%|███       | 10751/35345 [18:58<43:14,  9.48it/s]

Batch 10750/35345:
total_loss: 0.12867577373981476


Epoch 3/5:  31%|███       | 11001/35345 [19:25<42:24,  9.57it/s]

Batch 11000/35345:
total_loss: 0.15673334896564484


Epoch 3/5:  32%|███▏      | 11251/35345 [19:51<42:16,  9.50it/s]

Batch 11250/35345:
total_loss: 0.1986091136932373


Epoch 3/5:  33%|███▎      | 11501/35345 [20:18<40:56,  9.71it/s]

Batch 11500/35345:
total_loss: 0.17329925298690796


Epoch 3/5:  33%|███▎      | 11751/35345 [20:45<41:11,  9.55it/s]

Batch 11750/35345:
total_loss: 0.1384458988904953


Epoch 3/5:  34%|███▍      | 12001/35345 [21:11<41:31,  9.37it/s]

Batch 12000/35345:
total_loss: 0.260783851146698


Epoch 3/5:  35%|███▍      | 12251/35345 [21:38<53:11,  7.24it/s]

Batch 12250/35345:
total_loss: 0.35707399249076843


Epoch 3/5:  35%|███▌      | 12501/35345 [22:05<39:46,  9.57it/s]

Batch 12500/35345:
total_loss: 0.17768417298793793


Epoch 3/5:  36%|███▌      | 12751/35345 [22:31<39:31,  9.53it/s]

Batch 12750/35345:
total_loss: 0.23281525075435638


Epoch 3/5:  37%|███▋      | 13001/35345 [22:58<38:48,  9.60it/s]

Batch 13000/35345:
total_loss: 0.13080449402332306


Epoch 3/5:  37%|███▋      | 13251/35345 [23:25<38:50,  9.48it/s]

Batch 13250/35345:
total_loss: 0.13816742599010468


Epoch 3/5:  38%|███▊      | 13501/35345 [23:51<38:52,  9.36it/s]

Batch 13500/35345:
total_loss: 0.12810036540031433


Epoch 3/5:  39%|███▉      | 13751/35345 [24:18<38:04,  9.45it/s]

Batch 13750/35345:
total_loss: 0.20347359776496887


Epoch 3/5:  40%|███▉      | 14001/35345 [24:44<37:04,  9.60it/s]

Batch 14000/35345:
total_loss: 0.24882689118385315


Epoch 3/5:  40%|████      | 14251/35345 [25:11<38:47,  9.06it/s]

Batch 14250/35345:
total_loss: 0.19697211682796478


Epoch 3/5:  41%|████      | 14501/35345 [25:38<35:57,  9.66it/s]

Batch 14500/35345:
total_loss: 0.1512041538953781


Epoch 3/5:  42%|████▏     | 14751/35345 [26:04<35:36,  9.64it/s]

Batch 14750/35345:
total_loss: 0.2210673838853836


Epoch 3/5:  42%|████▏     | 15001/35345 [26:31<35:32,  9.54it/s]

Batch 15000/35345:
total_loss: 0.3083557188510895


Epoch 3/5:  43%|████▎     | 15251/35345 [26:57<34:37,  9.67it/s]

Batch 15250/35345:
total_loss: 0.13879930973052979


Epoch 3/5:  44%|████▍     | 15501/35345 [27:24<34:30,  9.58it/s]

Batch 15500/35345:
total_loss: 0.2523709833621979


Epoch 3/5:  45%|████▍     | 15751/35345 [27:51<34:13,  9.54it/s]

Batch 15750/35345:
total_loss: 0.18657980859279633


Epoch 3/5:  45%|████▌     | 16001/35345 [28:17<33:28,  9.63it/s]

Batch 16000/35345:
total_loss: 0.2188311219215393


Epoch 3/5:  46%|████▌     | 16251/35345 [28:44<36:25,  8.74it/s]

Batch 16250/35345:
total_loss: 0.26911789178848267


Epoch 3/5:  47%|████▋     | 16501/35345 [29:11<33:30,  9.37it/s]

Batch 16500/35345:
total_loss: 0.2659549415111542


Epoch 3/5:  47%|████▋     | 16751/35345 [29:37<32:21,  9.58it/s]

Batch 16750/35345:
total_loss: 0.10261540114879608


Epoch 3/5:  48%|████▊     | 17001/35345 [30:04<31:51,  9.60it/s]

Batch 17000/35345:
total_loss: 0.11265905201435089


Epoch 3/5:  49%|████▉     | 17251/35345 [30:30<31:08,  9.68it/s]

Batch 17250/35345:
total_loss: 0.23289790749549866


Epoch 3/5:  50%|████▉     | 17501/35345 [30:56<31:06,  9.56it/s]

Batch 17500/35345:
total_loss: 0.24317483603954315


Epoch 3/5:  50%|█████     | 17751/35345 [31:23<30:28,  9.62it/s]

Batch 17750/35345:
total_loss: 0.11224992573261261


Epoch 3/5:  50%|█████     | 17787/35345 [31:27<30:37,  9.55it/s]


Equation: C*log(C*x1)+C

Epoch 3/5:  50%|█████     | 17788/35345 [31:27<30:27,  9.61it/s]




Epoch 3/5:  51%|█████     | 18001/35345 [31:49<29:57,  9.65it/s]

Batch 18000/35345:
total_loss: 0.14869530498981476


Epoch 3/5:  52%|█████▏    | 18251/35345 [32:16<36:52,  7.73it/s]

Batch 18250/35345:
total_loss: 0.18475425243377686


Epoch 3/5:  52%|█████▏    | 18501/35345 [32:42<29:12,  9.61it/s]

Batch 18500/35345:
total_loss: 0.1635267436504364


Epoch 3/5:  53%|█████▎    | 18751/35345 [33:08<28:38,  9.66it/s]

Batch 18750/35345:
total_loss: 0.1437264233827591


Epoch 3/5:  54%|█████▍    | 19001/35345 [33:36<28:12,  9.66it/s]

Batch 19000/35345:
total_loss: 0.13178929686546326


Epoch 3/5:  54%|█████▍    | 19251/35345 [34:02<29:03,  9.23it/s]

Batch 19250/35345:
total_loss: 0.09050727635622025


Epoch 3/5:  55%|█████▌    | 19501/35345 [34:28<27:16,  9.68it/s]

Batch 19500/35345:
total_loss: 0.46363839507102966


Epoch 3/5:  56%|█████▌    | 19751/35345 [34:55<26:54,  9.66it/s]

Batch 19750/35345:
total_loss: 0.24638798832893372


Epoch 3/5:  57%|█████▋    | 20001/35345 [35:22<26:46,  9.55it/s]

Batch 20000/35345:
total_loss: 0.35599640011787415


Epoch 3/5:  57%|█████▋    | 20251/35345 [35:49<27:54,  9.01it/s]

Batch 20250/35345:
total_loss: 0.2691531181335449


Epoch 3/5:  58%|█████▊    | 20501/35345 [36:15<25:47,  9.59it/s]

Batch 20500/35345:
total_loss: 0.23768101632595062


Epoch 3/5:  59%|█████▊    | 20750/35345 [36:41<27:17,  8.91it/s]

Batch 20750/35345:
total_loss: 0.24826323986053467


Epoch 3/5:  59%|█████▉    | 21001/35345 [37:08<24:53,  9.61it/s]

Batch 21000/35345:
total_loss: 0.2109934538602829


Epoch 3/5:  60%|██████    | 21251/35345 [37:35<24:25,  9.62it/s]

Batch 21250/35345:
total_loss: 0.19191522896289825


Epoch 3/5:  61%|██████    | 21501/35345 [38:01<24:05,  9.57it/s]

Batch 21500/35345:
total_loss: 0.21839790046215057


Epoch 3/5:  62%|██████▏   | 21751/35345 [38:28<23:36,  9.60it/s]

Batch 21750/35345:
total_loss: 0.16939276456832886


Epoch 3/5:  62%|██████▏   | 22001/35345 [38:54<23:14,  9.57it/s]

Batch 22000/35345:
total_loss: 0.1199757382273674


Epoch 3/5:  63%|██████▎   | 22251/35345 [39:20<24:44,  8.82it/s]

Batch 22250/35345:
total_loss: 0.1505720466375351


Epoch 3/5:  64%|██████▎   | 22501/35345 [39:47<22:14,  9.62it/s]

Batch 22500/35345:
total_loss: 0.17838844656944275


Epoch 3/5:  64%|██████▍   | 22751/35345 [40:13<21:54,  9.58it/s]

Batch 22750/35345:
total_loss: 0.12689721584320068


Epoch 3/5:  65%|██████▌   | 23001/35345 [40:40<21:41,  9.48it/s]

Batch 23000/35345:
total_loss: 0.1371184140443802


Epoch 3/5:  66%|██████▌   | 23251/35345 [41:06<21:02,  9.58it/s]

Batch 23250/35345:
total_loss: 0.11134745925664902


Epoch 3/5:  66%|██████▋   | 23501/35345 [41:32<21:28,  9.19it/s]

Batch 23500/35345:
total_loss: 0.19311924278736115


Epoch 3/5:  67%|██████▋   | 23751/35345 [41:59<20:05,  9.62it/s]

Batch 23750/35345:
total_loss: 0.24585875868797302


Epoch 3/5:  68%|██████▊   | 24001/35345 [42:25<19:37,  9.63it/s]

Batch 24000/35345:
total_loss: 0.12485271692276001


Epoch 3/5:  69%|██████▊   | 24251/35345 [42:52<19:30,  9.48it/s]

Batch 24250/35345:
total_loss: 0.25411030650138855


Epoch 3/5:  69%|██████▉   | 24501/35345 [43:18<18:51,  9.58it/s]

Batch 24500/35345:
total_loss: 0.14826931059360504


Epoch 3/5:  70%|███████   | 24751/35345 [43:45<18:37,  9.48it/s]

Batch 24750/35345:
total_loss: 0.11536576598882675


Epoch 3/5:  71%|███████   | 25001/35345 [44:11<18:21,  9.39it/s]

Batch 25000/35345:
total_loss: 0.11354225873947144


Epoch 3/5:  71%|███████▏  | 25251/35345 [44:38<17:28,  9.63it/s]

Batch 25250/35345:
total_loss: 0.13963644206523895


Epoch 3/5:  72%|███████▏  | 25501/35345 [45:04<16:57,  9.67it/s]

Batch 25500/35345:
total_loss: 0.11956597864627838


Epoch 3/5:  73%|███████▎  | 25751/35345 [45:31<16:49,  9.50it/s]

Batch 25750/35345:
total_loss: 0.24603040516376495


Epoch 3/5:  74%|███████▎  | 26001/35345 [45:58<16:14,  9.59it/s]

Batch 26000/35345:
total_loss: 0.3639426827430725


Epoch 3/5:  74%|███████▍  | 26251/35345 [46:24<16:19,  9.29it/s]

Batch 26250/35345:
total_loss: 0.0582868717610836


Epoch 3/5:  75%|███████▍  | 26501/35345 [46:50<15:20,  9.61it/s]

Batch 26500/35345:
total_loss: 0.11170448362827301


Epoch 3/5:  76%|███████▌  | 26751/35345 [47:16<14:53,  9.62it/s]

Batch 26750/35345:
total_loss: 0.19766929745674133


Epoch 3/5:  76%|███████▋  | 27001/35345 [47:43<15:40,  8.87it/s]

Batch 27000/35345:
total_loss: 0.18209625780582428


Epoch 3/5:  77%|███████▋  | 27251/35345 [48:09<14:02,  9.60it/s]

Batch 27250/35345:
total_loss: 0.25377607345581055


Epoch 3/5:  78%|███████▊  | 27501/35345 [48:35<13:44,  9.52it/s]

Batch 27500/35345:
total_loss: 0.24415691196918488


Epoch 3/5:  79%|███████▊  | 27751/35345 [49:02<13:37,  9.29it/s]

Batch 27750/35345:
total_loss: 0.07355312258005142


Epoch 3/5:  79%|███████▉  | 28001/35345 [49:28<12:37,  9.69it/s]

Batch 28000/35345:
total_loss: 0.19012925028800964


Epoch 3/5:  80%|███████▉  | 28251/35345 [49:54<12:16,  9.63it/s]

Batch 28250/35345:
total_loss: 0.14946751296520233


Epoch 3/5:  81%|████████  | 28501/35345 [50:22<11:48,  9.65it/s]

Batch 28500/35345:
total_loss: 0.10283154249191284


Epoch 3/5:  81%|████████▏ | 28751/35345 [50:48<11:29,  9.56it/s]

Batch 28750/35345:
total_loss: 0.36097341775894165


Epoch 3/5:  82%|████████▏ | 29001/35345 [51:14<14:27,  7.31it/s]

Batch 29000/35345:
total_loss: 0.09513839334249496


Epoch 3/5:  83%|████████▎ | 29251/35345 [51:41<10:33,  9.61it/s]

Batch 29250/35345:
total_loss: 0.11317618191242218


Epoch 3/5:  83%|████████▎ | 29501/35345 [52:08<10:08,  9.61it/s]

Batch 29500/35345:
total_loss: 0.2640466094017029


Epoch 3/5:  84%|████████▍ | 29751/35345 [52:34<10:46,  8.65it/s]

Batch 29750/35345:
total_loss: 0.1747308224439621


Epoch 3/5:  85%|████████▍ | 30001/35345 [53:01<09:15,  9.62it/s]

Batch 30000/35345:
total_loss: 0.0898924320936203


Epoch 3/5:  86%|████████▌ | 30251/35345 [53:27<08:50,  9.60it/s]

Batch 30250/35345:
total_loss: 0.3251541256904602


Epoch 3/5:  86%|████████▋ | 30501/35345 [53:54<08:35,  9.40it/s]

Batch 30500/35345:
total_loss: 0.1031811460852623


Epoch 3/5:  87%|████████▋ | 30751/35345 [54:20<07:57,  9.62it/s]

Batch 30750/35345:
total_loss: 0.2088097631931305


Epoch 3/5:  88%|████████▊ | 31001/35345 [54:47<07:29,  9.66it/s]

Batch 31000/35345:
total_loss: 0.3080216646194458


Epoch 3/5:  88%|████████▊ | 31251/35345 [55:14<07:11,  9.49it/s]

Batch 31250/35345:
total_loss: 0.2704303562641144


Epoch 3/5:  89%|████████▉ | 31501/35345 [55:40<06:42,  9.56it/s]

Batch 31500/35345:
total_loss: 0.3288723826408386


Epoch 3/5:  90%|████████▉ | 31751/35345 [56:07<07:08,  8.38it/s]

Batch 31750/35345:
total_loss: 0.29464760422706604


Epoch 3/5:  91%|█████████ | 32001/35345 [56:34<05:45,  9.67it/s]

Batch 32000/35345:
total_loss: 0.24913817644119263


Epoch 3/5:  91%|█████████ | 32251/35345 [57:00<05:19,  9.68it/s]

Batch 32250/35345:
total_loss: 0.23278012871742249


Epoch 3/5:  92%|█████████▏| 32501/35345 [57:27<04:59,  9.49it/s]

Batch 32500/35345:
total_loss: 0.15109245479106903


Epoch 3/5:  93%|█████████▎| 32751/35345 [57:54<04:31,  9.57it/s]

Batch 32750/35345:
total_loss: 0.156524196267128


Epoch 3/5:  93%|█████████▎| 33001/35345 [58:20<04:03,  9.62it/s]

Batch 33000/35345:
total_loss: 0.15798340737819672


Epoch 3/5:  94%|█████████▍| 33251/35345 [58:47<04:01,  8.68it/s]

Batch 33250/35345:
total_loss: 0.20527009665966034


Epoch 3/5:  95%|█████████▍| 33501/35345 [59:13<03:11,  9.65it/s]

Batch 33500/35345:
total_loss: 0.10289563983678818


Epoch 3/5:  95%|█████████▌| 33751/35345 [59:39<02:45,  9.61it/s]

Batch 33750/35345:
total_loss: 0.21860703825950623


Epoch 3/5:  96%|█████████▌| 34001/35345 [1:00:07<02:20,  9.56it/s]

Batch 34000/35345:
total_loss: 0.30476438999176025


Epoch 3/5:  97%|█████████▋| 34251/35345 [1:00:33<01:54,  9.58it/s]

Batch 34250/35345:
total_loss: 0.12642782926559448


Epoch 3/5:  98%|█████████▊| 34501/35345 [1:01:00<01:30,  9.35it/s]

Batch 34500/35345:
total_loss: 0.146053284406662


Epoch 3/5:  98%|█████████▊| 34751/35345 [1:01:27<01:03,  9.40it/s]

Batch 34750/35345:
total_loss: 0.2615317404270172


Epoch 3/5:  99%|█████████▉| 35001/35345 [1:01:53<00:35,  9.57it/s]

Batch 35000/35345:
total_loss: 0.1690283715724945


Epoch 3/5: 100%|█████████▉| 35251/35345 [1:02:20<00:12,  7.56it/s]

Batch 35250/35345:
total_loss: 0.2338051199913025


Epoch 3/5: 100%|██████████| 35345/35345 [1:02:31<00:00,  9.42it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  2.88it/s]


Epoch Summary:
Train Total Loss: 0.2195
Val Total Loss: 0.2294
Learning Rate: 0.000100
--------------------------------------------------



Epoch 4/5:   1%|          | 251/35345 [00:26<1:02:18,  9.39it/s]

Batch 250/35345:
total_loss: 0.14406846463680267


Epoch 4/5:   1%|▏         | 500/35345 [00:53<1:05:02,  8.93it/s]

Batch 500/35345:
total_loss: 0.17272108793258667


Epoch 4/5:   2%|▏         | 751/35345 [01:20<1:00:04,  9.60it/s]

Batch 750/35345:
total_loss: 0.2005934864282608


Epoch 4/5:   3%|▎         | 1001/35345 [01:46<59:54,  9.55it/s]

Batch 1000/35345:
total_loss: 0.18091820180416107


Epoch 4/5:   4%|▎         | 1251/35345 [02:14<1:08:31,  8.29it/s]

Batch 1250/35345:
total_loss: 0.281623512506485


Epoch 4/5:   4%|▍         | 1501/35345 [02:40<1:00:51,  9.27it/s]

Batch 1500/35345:
total_loss: 0.23250481486320496


Epoch 4/5:   5%|▍         | 1751/35345 [03:06<58:20,  9.60it/s]

Batch 1750/35345:
total_loss: 0.11022298038005829


Epoch 4/5:   6%|▌         | 2001/35345 [03:33<58:12,  9.55it/s]

Batch 2000/35345:
total_loss: 0.16416826844215393


Epoch 4/5:   6%|▋         | 2251/35345 [03:59<56:59,  9.68it/s]

Batch 2250/35345:
total_loss: 0.15616557002067566


Epoch 4/5:   7%|▋         | 2501/35345 [04:26<57:11,  9.57it/s]

Batch 2500/35345:
total_loss: 0.12807807326316833


Epoch 4/5:   8%|▊         | 2751/35345 [04:53<56:27,  9.62it/s]

Batch 2750/35345:
total_loss: 0.16976259648799896


Epoch 4/5:   8%|▊         | 3001/35345 [05:19<56:05,  9.61it/s]

Batch 3000/35345:
total_loss: 0.12236981093883514


Epoch 4/5:   9%|▉         | 3251/35345 [05:45<56:33,  9.46it/s]

Batch 3250/35345:
total_loss: 0.17636141180992126


Epoch 4/5:  10%|▉         | 3501/35345 [06:12<54:55,  9.66it/s]

Batch 3500/35345:
total_loss: 0.29525941610336304


Epoch 4/5:  11%|█         | 3751/35345 [06:39<54:30,  9.66it/s]

Batch 3750/35345:
total_loss: 0.1926904171705246


Epoch 4/5:  11%|█▏        | 4001/35345 [07:06<55:51,  9.35it/s]

Batch 4000/35345:
total_loss: 0.27370062470436096


Epoch 4/5:  12%|█▏        | 4251/35345 [07:33<55:29,  9.34it/s]

Batch 4250/35345:
total_loss: 0.19969169795513153


Epoch 4/5:  13%|█▎        | 4501/35345 [07:59<53:18,  9.64it/s]

Batch 4500/35345:
total_loss: 0.1916472315788269


Epoch 4/5:  13%|█▎        | 4751/35345 [08:26<53:06,  9.60it/s]

Batch 4750/35345:
total_loss: 0.2491644024848938


Epoch 4/5:  14%|█▍        | 5001/35345 [08:52<52:48,  9.58it/s]

Batch 5000/35345:
total_loss: 0.21608905494213104


Epoch 4/5:  15%|█▍        | 5251/35345 [09:18<52:18,  9.59it/s]

Batch 5250/35345:
total_loss: 0.21124477684497833


Epoch 4/5:  16%|█▌        | 5501/35345 [09:46<52:09,  9.54it/s]

Batch 5500/35345:
total_loss: 0.12842589616775513


Epoch 4/5:  16%|█▋        | 5751/35345 [10:12<55:01,  8.96it/s]

Batch 5750/35345:
total_loss: 0.17673632502555847


Epoch 4/5:  17%|█▋        | 6001/35345 [10:38<51:21,  9.52it/s]

Batch 6000/35345:
total_loss: 0.1489485204219818


Epoch 4/5:  18%|█▊        | 6251/35345 [11:06<51:05,  9.49it/s]

Batch 6250/35345:
total_loss: 0.25439199805259705


Epoch 4/5:  18%|█▊        | 6501/35345 [11:32<50:01,  9.61it/s]

Batch 6500/35345:
total_loss: 0.2113015502691269


Epoch 4/5:  19%|█▉        | 6751/35345 [11:59<56:06,  8.49it/s]

Batch 6750/35345:
total_loss: 0.22691570222377777


Epoch 4/5:  20%|█▉        | 7001/35345 [12:26<49:31,  9.54it/s]

Batch 7000/35345:
total_loss: 0.15496553480625153


Epoch 4/5:  21%|██        | 7251/35345 [12:52<48:26,  9.67it/s]

Batch 7250/35345:
total_loss: 0.227733314037323


Epoch 4/5:  21%|██        | 7501/35345 [13:19<49:19,  9.41it/s]

Batch 7500/35345:
total_loss: 0.14129117131233215


Epoch 4/5:  22%|██▏       | 7751/35345 [13:46<47:27,  9.69it/s]

Batch 7750/35345:
total_loss: 0.16970986127853394


Epoch 4/5:  23%|██▎       | 8001/35345 [14:12<47:37,  9.57it/s]

Batch 8000/35345:
total_loss: 0.11935843527317047


Epoch 4/5:  23%|██▎       | 8251/35345 [14:40<47:08,  9.58it/s]

Batch 8250/35345:
total_loss: 0.18085841834545135


Epoch 4/5:  24%|██▍       | 8501/35345 [15:06<47:08,  9.49it/s]

Batch 8500/35345:
total_loss: 0.24178864061832428


Epoch 4/5:  25%|██▍       | 8751/35345 [15:32<46:08,  9.61it/s]

Batch 8750/35345:
total_loss: 0.2273273766040802


Epoch 4/5:  25%|██▌       | 9001/35345 [15:58<45:55,  9.56it/s]

Batch 9000/35345:
total_loss: 0.21025173366069794


Epoch 4/5:  26%|██▌       | 9251/35345 [16:24<45:18,  9.60it/s]

Batch 9250/35345:
total_loss: 0.1361558586359024


Epoch 4/5:  27%|██▋       | 9501/35345 [16:52<44:40,  9.64it/s]

Batch 9500/35345:
total_loss: 0.18638385832309723


Epoch 4/5:  28%|██▊       | 9751/35345 [17:18<45:19,  9.41it/s]

Batch 9750/35345:
total_loss: 0.09411804378032684


Epoch 4/5:  28%|██▊       | 10001/35345 [17:44<44:53,  9.41it/s]

Batch 10000/35345:
total_loss: 0.17057913541793823


Epoch 4/5:  29%|██▉       | 10251/35345 [18:10<43:37,  9.59it/s]

Batch 10250/35345:
total_loss: 0.18577581644058228


Epoch 4/5:  30%|██▉       | 10501/35345 [18:36<43:24,  9.54it/s]

Batch 10500/35345:
total_loss: 0.33845990896224976


Epoch 4/5:  30%|███       | 10751/35345 [19:04<43:14,  9.48it/s]

Batch 10750/35345:
total_loss: 0.2285139262676239


Epoch 4/5:  31%|███       | 11001/35345 [19:30<42:40,  9.51it/s]

Batch 11000/35345:
total_loss: 0.22753280401229858


Epoch 4/5:  32%|███▏      | 11251/35345 [19:56<42:58,  9.35it/s]

Batch 11250/35345:
total_loss: 0.2009255290031433


Epoch 4/5:  33%|███▎      | 11501/35345 [20:22<41:25,  9.59it/s]

Batch 11500/35345:
total_loss: 0.2439524233341217


Epoch 4/5:  33%|███▎      | 11751/35345 [20:48<41:01,  9.58it/s]

Batch 11750/35345:
total_loss: 0.1476125717163086


Epoch 4/5:  34%|███▍      | 12001/35345 [21:16<41:22,  9.40it/s]

Batch 12000/35345:
total_loss: 0.17182300984859467


Epoch 4/5:  35%|███▍      | 12251/35345 [21:42<40:04,  9.60it/s]

Batch 12250/35345:
total_loss: 0.15735575556755066


Epoch 4/5:  35%|███▌      | 12501/35345 [22:08<39:52,  9.55it/s]

Batch 12500/35345:
total_loss: 0.13057208061218262


Epoch 4/5:  36%|███▌      | 12751/35345 [22:34<39:13,  9.60it/s]

Batch 12750/35345:
total_loss: 0.16245140135288239


Epoch 4/5:  37%|███▋      | 13001/35345 [23:00<39:01,  9.54it/s]

Batch 13000/35345:
total_loss: 0.19689086079597473


Epoch 4/5:  37%|███▋      | 13251/35345 [23:28<38:45,  9.50it/s]

Batch 13250/35345:
total_loss: 0.13274624943733215


Epoch 4/5:  38%|███▊      | 13501/35345 [23:54<37:48,  9.63it/s]

Batch 13500/35345:
total_loss: 0.09193596988916397


Epoch 4/5:  39%|███▉      | 13751/35345 [24:20<37:36,  9.57it/s]

Batch 13750/35345:
total_loss: 0.1335003823041916


Epoch 4/5:  40%|███▉      | 14001/35345 [24:46<37:36,  9.46it/s]

Batch 14000/35345:
total_loss: 0.04764021933078766


Epoch 4/5:  40%|████      | 14251/35345 [25:12<36:51,  9.54it/s]

Batch 14250/35345:
total_loss: 0.11200029402971268


Epoch 4/5:  41%|████      | 14501/35345 [25:39<36:23,  9.55it/s]

Batch 14500/35345:
total_loss: 0.16278477013111115


Epoch 4/5:  42%|████▏     | 14751/35345 [26:06<35:49,  9.58it/s]

Batch 14750/35345:
total_loss: 0.17720958590507507


Epoch 4/5:  42%|████▏     | 15001/35345 [26:32<35:13,  9.63it/s]

Batch 15000/35345:
total_loss: 0.20210321247577667


Epoch 4/5:  43%|████▎     | 15251/35345 [26:59<35:09,  9.53it/s]

Batch 15250/35345:
total_loss: 0.14523647725582123


Epoch 4/5:  44%|████▍     | 15501/35345 [27:25<34:17,  9.64it/s]

Batch 15500/35345:
total_loss: 0.14348278939723969


Epoch 4/5:  45%|████▍     | 15751/35345 [27:51<36:16,  9.00it/s]

Batch 15750/35345:
total_loss: 0.15501444041728973


Epoch 4/5:  45%|████▌     | 16001/35345 [28:19<33:44,  9.55it/s]

Batch 16000/35345:
total_loss: 0.1069144457578659


Epoch 4/5:  46%|████▌     | 16251/35345 [28:45<33:13,  9.58it/s]

Batch 16250/35345:
total_loss: 0.13236269354820251


Epoch 4/5:  47%|████▋     | 16501/35345 [29:12<32:45,  9.59it/s]

Batch 16500/35345:
total_loss: 0.06501227617263794


Epoch 4/5:  47%|████▋     | 16751/35345 [29:38<33:00,  9.39it/s]

Batch 16750/35345:
total_loss: 0.18935470283031464


Epoch 4/5:  48%|████▊     | 17000/35345 [30:04<40:26,  7.56it/s]

Batch 17000/35345:
total_loss: 0.1258067786693573


Epoch 4/5:  49%|████▉     | 17251/35345 [30:32<31:36,  9.54it/s]

Batch 17250/35345:
total_loss: 0.08122275769710541


Epoch 4/5:  50%|████▉     | 17501/35345 [30:58<31:05,  9.56it/s]

Batch 17500/35345:
total_loss: 0.17065469920635223


Epoch 4/5:  50%|█████     | 17751/35345 [31:24<30:29,  9.62it/s]

Batch 17750/35345:
total_loss: 0.2523425221443176


Epoch 4/5:  51%|█████     | 18001/35345 [31:50<30:19,  9.53it/s]

Batch 18000/35345:
total_loss: 0.2558002769947052


Epoch 4/5:  52%|█████▏    | 18251/35345 [32:16<32:03,  8.89it/s]

Batch 18250/35345:
total_loss: 0.1264275461435318


Epoch 4/5:  52%|█████▏    | 18501/35345 [32:44<29:13,  9.60it/s]

Batch 18500/35345:
total_loss: 0.14536090195178986


Epoch 4/5:  53%|█████▎    | 18751/35345 [33:10<28:56,  9.55it/s]

Batch 18750/35345:
total_loss: 0.20822015404701233


Epoch 4/5:  54%|█████▍    | 19001/35345 [33:37<28:23,  9.59it/s]

Batch 19000/35345:
total_loss: 0.13036170601844788


Epoch 4/5:  54%|█████▍    | 19251/35345 [34:03<27:53,  9.61it/s]

Batch 19250/35345:
total_loss: 0.16188685595989227


Epoch 4/5:  55%|█████▌    | 19501/35345 [34:29<28:24,  9.30it/s]

Batch 19500/35345:
total_loss: 0.13567429780960083


Epoch 4/5:  56%|█████▌    | 19751/35345 [34:57<27:07,  9.58it/s]

Batch 19750/35345:
total_loss: 0.21285755932331085


Epoch 4/5:  57%|█████▋    | 20001/35345 [35:23<26:38,  9.60it/s]

Batch 20000/35345:
total_loss: 0.1061554104089737


Epoch 4/5:  57%|█████▋    | 20251/35345 [35:49<26:07,  9.63it/s]

Batch 20250/35345:
total_loss: 0.18072359263896942


Epoch 4/5:  58%|█████▊    | 20501/35345 [36:15<25:56,  9.54it/s]

Batch 20500/35345:
total_loss: 0.09794788062572479


Epoch 4/5:  59%|█████▊    | 20751/35345 [36:41<25:31,  9.53it/s]

Batch 20750/35345:
total_loss: 0.18574275076389313


Epoch 4/5:  59%|█████▉    | 21001/35345 [37:09<25:25,  9.40it/s]

Batch 21000/35345:
total_loss: 0.2572907507419586


Epoch 4/5:  60%|██████    | 21251/35345 [37:35<24:28,  9.60it/s]

Batch 21250/35345:
total_loss: 0.23390047252178192


Epoch 4/5:  61%|██████    | 21501/35345 [38:01<23:45,  9.71it/s]

Batch 21500/35345:
total_loss: 0.172649085521698


Epoch 4/5:  62%|██████▏   | 21751/35345 [38:27<23:22,  9.69it/s]

Batch 21750/35345:
total_loss: 0.22603382170200348


Epoch 4/5:  62%|██████▏   | 22001/35345 [38:53<23:05,  9.63it/s]

Batch 22000/35345:
total_loss: 0.09869107604026794


Epoch 4/5:  63%|██████▎   | 22251/35345 [39:21<22:27,  9.71it/s]

Batch 22250/35345:
total_loss: 0.10065831989049911


Epoch 4/5:  64%|██████▎   | 22501/35345 [39:47<22:19,  9.59it/s]

Batch 22500/35345:
total_loss: 0.08880910277366638


Epoch 4/5:  64%|██████▍   | 22751/35345 [40:13<21:49,  9.62it/s]

Batch 22750/35345:
total_loss: 0.20032809674739838


Epoch 4/5:  65%|██████▌   | 23001/35345 [40:39<21:19,  9.65it/s]

Batch 23000/35345:
total_loss: 0.10502252727746964


Epoch 4/5:  66%|██████▌   | 23251/35345 [41:05<20:54,  9.64it/s]

Batch 23250/35345:
total_loss: 0.11083362996578217


Epoch 4/5:  66%|██████▋   | 23501/35345 [41:33<21:30,  9.18it/s]

Batch 23500/35345:
total_loss: 0.14972862601280212


Epoch 4/5:  67%|██████▋   | 23751/35345 [41:59<20:30,  9.42it/s]

Batch 23750/35345:
total_loss: 0.10187364369630814


Epoch 4/5:  68%|██████▊   | 24001/35345 [42:25<19:40,  9.61it/s]

Batch 24000/35345:
total_loss: 0.12501418590545654


Epoch 4/5:  69%|██████▊   | 24251/35345 [42:51<19:11,  9.64it/s]

Batch 24250/35345:
total_loss: 0.14380668103694916


Epoch 4/5:  69%|██████▉   | 24501/35345 [43:17<18:57,  9.53it/s]

Batch 24500/35345:
total_loss: 0.17842072248458862


Epoch 4/5:  70%|███████   | 24751/35345 [43:44<19:18,  9.14it/s]

Batch 24750/35345:
total_loss: 0.04873982071876526


Epoch 4/5:  71%|███████   | 25001/35345 [44:11<18:05,  9.53it/s]

Batch 25000/35345:
total_loss: 0.1438867300748825


Epoch 4/5:  71%|███████▏  | 25251/35345 [44:37<17:24,  9.67it/s]

Batch 25250/35345:
total_loss: 0.10761475563049316


Epoch 4/5:  72%|███████▏  | 25501/35345 [45:03<17:01,  9.64it/s]

Batch 25500/35345:
total_loss: 0.15348821878433228


Epoch 4/5:  73%|███████▎  | 25751/35345 [45:29<16:32,  9.66it/s]

Batch 25750/35345:
total_loss: 0.10934188216924667


Epoch 4/5:  74%|███████▎  | 26001/35345 [45:56<16:31,  9.42it/s]

Batch 26000/35345:
total_loss: 0.05944406986236572


Epoch 4/5:  74%|███████▍  | 26251/35345 [46:23<15:38,  9.69it/s]

Batch 26250/35345:
total_loss: 0.17068704962730408


Epoch 4/5:  75%|███████▍  | 26501/35345 [46:49<15:25,  9.55it/s]

Batch 26500/35345:
total_loss: 0.10844986885786057


Epoch 4/5:  76%|███████▌  | 26751/35345 [47:15<15:01,  9.53it/s]

Batch 26750/35345:
total_loss: 0.20215828716754913


Epoch 4/5:  76%|███████▋  | 27001/35345 [47:41<14:33,  9.55it/s]

Batch 27000/35345:
total_loss: 0.10425098240375519


Epoch 4/5:  77%|███████▋  | 27251/35345 [48:07<14:34,  9.25it/s]

Batch 27250/35345:
total_loss: 0.12123043090105057


Epoch 4/5:  78%|███████▊  | 27501/35345 [48:34<13:34,  9.63it/s]

Batch 27500/35345:
total_loss: 0.21908245980739594


Epoch 4/5:  79%|███████▊  | 27751/35345 [49:01<13:11,  9.60it/s]

Batch 27750/35345:
total_loss: 0.28778091073036194


Epoch 4/5:  79%|███████▉  | 28001/35345 [49:27<12:43,  9.62it/s]

Batch 28000/35345:
total_loss: 0.1041114404797554


Epoch 4/5:  80%|███████▉  | 28251/35345 [49:53<12:15,  9.64it/s]

Batch 28250/35345:
total_loss: 0.2192043662071228


Epoch 4/5:  81%|████████  | 28501/35345 [50:19<11:48,  9.65it/s]

Batch 28500/35345:
total_loss: 0.11011683940887451


Epoch 4/5:  81%|████████▏ | 28751/35345 [50:47<11:33,  9.50it/s]

Batch 28750/35345:
total_loss: 0.1524105966091156


Epoch 4/5:  82%|████████▏ | 29001/35345 [51:13<11:02,  9.57it/s]

Batch 29000/35345:
total_loss: 0.15295445919036865


Epoch 4/5:  83%|████████▎ | 29251/35345 [51:39<10:39,  9.52it/s]

Batch 29250/35345:
total_loss: 0.1126442551612854


Epoch 4/5:  83%|████████▎ | 29501/35345 [52:06<10:08,  9.60it/s]

Batch 29500/35345:
total_loss: 0.1273181140422821


Epoch 4/5:  84%|████████▍ | 29751/35345 [52:32<09:45,  9.56it/s]

Batch 29750/35345:
total_loss: 0.2200654298067093


Epoch 4/5:  85%|████████▍ | 30001/35345 [53:00<09:16,  9.60it/s]

Batch 30000/35345:
total_loss: 0.0977972149848938


Epoch 4/5:  86%|████████▌ | 30251/35345 [53:26<08:49,  9.63it/s]

Batch 30250/35345:
total_loss: 0.12950293719768524


Epoch 4/5:  86%|████████▋ | 30501/35345 [53:52<08:44,  9.24it/s]

Batch 30500/35345:
total_loss: 0.15696167945861816


Epoch 4/5:  87%|████████▋ | 30751/35345 [54:18<07:58,  9.59it/s]

Batch 30750/35345:
total_loss: 0.10147322714328766


Epoch 4/5:  88%|████████▊ | 31001/35345 [54:44<07:39,  9.46it/s]

Batch 31000/35345:
total_loss: 0.12579073011875153


Epoch 4/5:  88%|████████▊ | 31251/35345 [55:13<07:14,  9.42it/s]

Batch 31250/35345:
total_loss: 0.13835248351097107


Epoch 4/5:  89%|████████▉ | 31501/35345 [55:39<06:45,  9.49it/s]

Batch 31500/35345:
total_loss: 0.26495131850242615


Epoch 4/5:  90%|████████▉ | 31751/35345 [56:05<06:14,  9.60it/s]

Batch 31750/35345:
total_loss: 0.27492979168891907


Epoch 4/5:  91%|█████████ | 32001/35345 [56:31<05:53,  9.46it/s]

Batch 32000/35345:
total_loss: 0.10981902480125427


Epoch 4/5:  91%|█████████ | 32251/35345 [56:57<05:23,  9.56it/s]

Batch 32250/35345:
total_loss: 0.11745244264602661


Epoch 4/5:  92%|█████████▏| 32501/35345 [57:25<04:58,  9.54it/s]

Batch 32500/35345:
total_loss: 0.07323893904685974


Epoch 4/5:  93%|█████████▎| 32751/35345 [57:51<04:32,  9.52it/s]

Batch 32750/35345:
total_loss: 0.28011026978492737


Epoch 4/5:  93%|█████████▎| 33001/35345 [58:18<04:05,  9.56it/s]

Batch 33000/35345:
total_loss: 0.06280028074979782


Epoch 4/5:  94%|█████████▍| 33251/35345 [58:44<03:39,  9.53it/s]

Batch 33250/35345:
total_loss: 0.17453216016292572


Epoch 4/5:  95%|█████████▍| 33501/35345 [59:10<03:12,  9.56it/s]

Batch 33500/35345:
total_loss: 0.1704755574464798


Epoch 4/5:  95%|█████████▌| 33751/35345 [59:37<03:04,  8.65it/s]

Batch 33750/35345:
total_loss: 0.1902993619441986


Epoch 4/5:  96%|█████████▌| 34001/35345 [1:00:04<02:20,  9.56it/s]

Batch 34000/35345:
total_loss: 0.06046849861741066


Epoch 4/5:  97%|█████████▋| 34251/35345 [1:00:30<01:54,  9.59it/s]

Batch 34250/35345:
total_loss: 0.039558932185173035


Epoch 4/5:  98%|█████████▊| 34501/35345 [1:00:56<01:27,  9.68it/s]

Batch 34500/35345:
total_loss: 0.18325749039649963


Epoch 4/5:  98%|█████████▊| 34751/35345 [1:01:22<01:04,  9.23it/s]

Batch 34750/35345:
total_loss: 0.13054218888282776


Epoch 4/5:  99%|█████████▉| 35000/35345 [1:01:48<00:36,  9.34it/s]

Batch 35000/35345:
total_loss: 0.07430531829595566


Epoch 4/5: 100%|█████████▉| 35251/35345 [1:02:16<00:09,  9.58it/s]

Batch 35250/35345:
total_loss: 0.12317847460508347


Epoch 4/5: 100%|██████████| 35345/35345 [1:02:27<00:00,  9.43it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.46it/s]


Epoch Summary:
Train Total Loss: 0.1633
Val Total Loss: 0.1133
Learning Rate: 0.000100
New best model saved with val loss: 0.1133
--------------------------------------------------



Epoch 5/5:   1%|          | 251/35345 [00:26<1:01:22,  9.53it/s]

Batch 250/35345:
total_loss: 0.10212451964616776


Epoch 5/5:   1%|▏         | 501/35345 [00:52<1:00:01,  9.67it/s]

Batch 500/35345:
total_loss: 0.16792939603328705


Epoch 5/5:   2%|▏         | 751/35345 [01:18<59:51,  9.63it/s]

Batch 750/35345:
total_loss: 0.15504342317581177


Epoch 5/5:   3%|▎         | 1001/35345 [01:46<1:07:03,  8.54it/s]

Batch 1000/35345:
total_loss: 0.0729055404663086


Epoch 5/5:   4%|▎         | 1251/35345 [02:13<59:06,  9.61it/s]

Batch 1250/35345:
total_loss: 0.17816275358200073


Epoch 5/5:   4%|▍         | 1501/35345 [02:39<59:57,  9.41it/s]  

Batch 1500/35345:
total_loss: 0.23517471551895142


Epoch 5/5:   5%|▍         | 1751/35345 [03:05<58:21,  9.59it/s]

Batch 1750/35345:
total_loss: 0.18862184882164001


Epoch 5/5:   6%|▌         | 2001/35345 [03:31<58:02,  9.58it/s]

Batch 2000/35345:
total_loss: 0.2081206887960434


Epoch 5/5:   6%|▋         | 2250/35345 [03:57<1:00:49,  9.07it/s]

Batch 2250/35345:
total_loss: 0.17077048122882843


Epoch 5/5:   7%|▋         | 2501/35345 [04:25<57:24,  9.54it/s]

Batch 2500/35345:
total_loss: 0.06577865034341812


Epoch 5/5:   8%|▊         | 2751/35345 [04:51<57:15,  9.49it/s]

Batch 2750/35345:
total_loss: 0.0693797692656517


Epoch 5/5:   8%|▊         | 3001/35345 [05:17<56:01,  9.62it/s]

Batch 3000/35345:
total_loss: 0.08850640803575516


Epoch 5/5:   9%|▉         | 3251/35345 [05:43<55:56,  9.56it/s]

Batch 3250/35345:
total_loss: 0.14654287695884705


Epoch 5/5:  10%|▉         | 3501/35345 [06:10<55:46,  9.52it/s]

Batch 3500/35345:
total_loss: 0.2032860666513443


Epoch 5/5:  11%|█         | 3751/35345 [06:38<55:25,  9.50it/s]

Batch 3750/35345:
total_loss: 0.19995401799678802


Epoch 5/5:  11%|█▏        | 4001/35345 [07:04<54:45,  9.54it/s]

Batch 4000/35345:
total_loss: 0.1627168506383896


Epoch 5/5:  12%|█▏        | 4251/35345 [07:30<55:37,  9.32it/s]

Batch 4250/35345:
total_loss: 0.21808761358261108


Epoch 5/5:  13%|█▎        | 4501/35345 [07:56<53:17,  9.64it/s]

Batch 4500/35345:
total_loss: 0.18580476939678192


Epoch 5/5:  13%|█▎        | 4751/35345 [08:22<52:51,  9.65it/s]

Batch 4750/35345:
total_loss: 0.1456405073404312


Epoch 5/5:  14%|█▍        | 5001/35345 [08:50<53:16,  9.49it/s]

Batch 5000/35345:
total_loss: 0.21038639545440674


Epoch 5/5:  15%|█▍        | 5251/35345 [09:17<52:40,  9.52it/s]

Batch 5250/35345:
total_loss: 0.12172583490610123


Epoch 5/5:  16%|█▌        | 5501/35345 [09:43<52:00,  9.56it/s]

Batch 5500/35345:
total_loss: 0.1715480387210846


Epoch 5/5:  16%|█▋        | 5751/35345 [10:09<51:11,  9.63it/s]

Batch 5750/35345:
total_loss: 0.17939099669456482


Epoch 5/5:  17%|█▋        | 6001/35345 [10:35<51:22,  9.52it/s]

Batch 6000/35345:
total_loss: 0.18344353139400482


Epoch 5/5:  18%|█▊        | 6251/35345 [11:03<59:01,  8.21it/s]  

Batch 6250/35345:
total_loss: 0.05972372367978096


Epoch 5/5:  18%|█▊        | 6501/35345 [11:29<49:44,  9.67it/s]

Batch 6500/35345:
total_loss: 0.0813121646642685


Epoch 5/5:  19%|█▉        | 6751/35345 [11:55<49:30,  9.63it/s]

Batch 6750/35345:
total_loss: 0.17022961378097534


Epoch 5/5:  20%|█▉        | 7001/35345 [12:22<54:58,  8.59it/s]

Batch 7000/35345:
total_loss: 0.1215018555521965


Epoch 5/5:  21%|██        | 7251/35345 [12:48<48:09,  9.72it/s]

Batch 7250/35345:
total_loss: 0.20965063571929932


Epoch 5/5:  21%|██        | 7500/35345 [13:15<49:51,  9.31it/s]

Batch 7500/35345:
total_loss: 0.08852161467075348


Epoch 5/5:  22%|██▏       | 7751/35345 [13:42<47:27,  9.69it/s]

Batch 7750/35345:
total_loss: 0.1240602359175682


Epoch 5/5:  23%|██▎       | 8001/35345 [14:08<48:05,  9.48it/s]

Batch 8000/35345:
total_loss: 0.1350841522216797


Epoch 5/5:  23%|██▎       | 8251/35345 [14:35<47:08,  9.58it/s]

Batch 8250/35345:
total_loss: 0.18901272118091583


Epoch 5/5:  24%|██▍       | 8501/35345 [15:01<46:21,  9.65it/s]

Batch 8500/35345:
total_loss: 0.13203099370002747


Epoch 5/5:  25%|██▍       | 8751/35345 [15:27<48:46,  9.09it/s]

Batch 8750/35345:
total_loss: 0.0719250813126564


Epoch 5/5:  25%|██▌       | 9001/35345 [15:55<45:51,  9.57it/s]

Batch 9000/35345:
total_loss: 0.1547940969467163


Epoch 5/5:  26%|██▌       | 9251/35345 [16:21<45:25,  9.57it/s]

Batch 9250/35345:
total_loss: 0.14994807541370392


Epoch 5/5:  27%|██▋       | 9501/35345 [16:47<44:45,  9.62it/s]

Batch 9500/35345:
total_loss: 0.09101354330778122


Epoch 5/5:  28%|██▊       | 9751/35345 [17:14<45:07,  9.45it/s]

Batch 9750/35345:
total_loss: 0.1117265447974205


Epoch 5/5:  28%|██▊       | 10001/35345 [17:40<44:32,  9.48it/s]

Batch 10000/35345:
total_loss: 0.11153628677129745


Epoch 5/5:  29%|██▉       | 10251/35345 [18:08<44:03,  9.49it/s]

Batch 10250/35345:
total_loss: 0.16789013147354126


Epoch 5/5:  30%|██▉       | 10501/35345 [18:34<43:13,  9.58it/s]

Batch 10500/35345:
total_loss: 0.09029577672481537


Epoch 5/5:  30%|███       | 10751/35345 [19:00<42:23,  9.67it/s]

Batch 10750/35345:
total_loss: 0.14524568617343903


Epoch 5/5:  31%|███       | 11001/35345 [19:26<42:06,  9.63it/s]

Batch 11000/35345:
total_loss: 0.12897686660289764


Epoch 5/5:  32%|███▏      | 11251/35345 [19:52<43:43,  9.18it/s]

Batch 11250/35345:
total_loss: 0.05448915809392929


Epoch 5/5:  33%|███▎      | 11501/35345 [20:20<41:24,  9.60it/s]

Batch 11500/35345:
total_loss: 0.15344293415546417


Epoch 5/5:  33%|███▎      | 11751/35345 [20:47<41:01,  9.59it/s]

Batch 11750/35345:
total_loss: 0.1554507613182068


Epoch 5/5:  34%|███▍      | 12001/35345 [21:13<40:34,  9.59it/s]

Batch 12000/35345:
total_loss: 0.24084027111530304


Epoch 5/5:  35%|███▍      | 12251/35345 [21:39<40:15,  9.56it/s]

Batch 12250/35345:
total_loss: 0.10970324277877808


Epoch 5/5:  35%|███▌      | 12501/35345 [22:05<39:42,  9.59it/s]

Batch 12500/35345:
total_loss: 0.09908037632703781


Epoch 5/5:  36%|███▌      | 12751/35345 [22:32<42:48,  8.80it/s]

Batch 12750/35345:
total_loss: 0.1543799638748169


Epoch 5/5:  37%|███▋      | 13001/35345 [22:59<38:34,  9.65it/s]

Batch 13000/35345:
total_loss: 0.060174934566020966


Epoch 5/5:  37%|███▋      | 13251/35345 [23:26<38:09,  9.65it/s]

Batch 13250/35345:
total_loss: 0.11725633591413498


Epoch 5/5:  38%|███▊      | 13501/35345 [23:52<37:52,  9.61it/s]

Batch 13500/35345:
total_loss: 0.11199337244033813


Epoch 5/5:  39%|███▉      | 13751/35345 [24:18<37:41,  9.55it/s]

Batch 13750/35345:
total_loss: 0.20193931460380554


Epoch 5/5:  40%|███▉      | 14001/35345 [24:45<48:31,  7.33it/s]

Batch 14000/35345:
total_loss: 0.11447013169527054


Epoch 5/5:  40%|████      | 14251/35345 [25:12<36:36,  9.60it/s]

Batch 14250/35345:
total_loss: 0.14637041091918945


Epoch 5/5:  41%|████      | 14501/35345 [25:38<36:24,  9.54it/s]

Batch 14500/35345:
total_loss: 0.13753318786621094


Epoch 5/5:  42%|████▏     | 14751/35345 [26:05<35:51,  9.57it/s]

Batch 14750/35345:
total_loss: 0.1999076008796692


Epoch 5/5:  42%|████▏     | 15001/35345 [26:31<35:15,  9.62it/s]

Batch 15000/35345:
total_loss: 0.1790524274110794


Epoch 5/5:  43%|████▎     | 15251/35345 [26:57<35:06,  9.54it/s]

Batch 15250/35345:
total_loss: 0.07164523005485535


Epoch 5/5:  44%|████▍     | 15501/35345 [27:25<35:32,  9.31it/s]

Batch 15500/35345:
total_loss: 0.14851762354373932


Epoch 5/5:  45%|████▍     | 15751/35345 [27:52<34:10,  9.56it/s]

Batch 15750/35345:
total_loss: 0.15541383624076843


Epoch 5/5:  45%|████▌     | 16001/35345 [28:18<33:39,  9.58it/s]

Batch 16000/35345:
total_loss: 0.14302411675453186


Epoch 5/5:  46%|████▌     | 16251/35345 [28:44<33:15,  9.57it/s]

Batch 16250/35345:
total_loss: 0.11942823976278305


Epoch 5/5:  47%|████▋     | 16501/35345 [29:10<32:34,  9.64it/s]

Batch 16500/35345:
total_loss: 0.09615100920200348


Epoch 5/5:  47%|████▋     | 16751/35345 [29:38<40:58,  7.56it/s]

Batch 16750/35345:
total_loss: 0.14658963680267334


Epoch 5/5:  48%|████▊     | 17001/35345 [30:05<31:53,  9.59it/s]

Batch 17000/35345:
total_loss: 0.07534370571374893


Epoch 5/5:  49%|████▉     | 17251/35345 [30:31<31:21,  9.61it/s]

Batch 17250/35345:
total_loss: 0.09664397686719894


Epoch 5/5:  50%|████▉     | 17501/35345 [30:57<31:06,  9.56it/s]

Batch 17500/35345:
total_loss: 0.06523517519235611


Epoch 5/5:  50%|█████     | 17751/35345 [31:23<30:17,  9.68it/s]

Batch 17750/35345:
total_loss: 0.09332986176013947


Epoch 5/5:  51%|█████     | 18001/35345 [31:51<32:57,  8.77it/s]

Batch 18000/35345:
total_loss: 0.10241185873746872


Epoch 5/5:  52%|█████▏    | 18251/35345 [32:18<34:37,  8.23it/s]

Batch 18250/35345:
total_loss: 0.1538507342338562


Epoch 5/5:  52%|█████▏    | 18501/35345 [32:44<29:01,  9.67it/s]

Batch 18500/35345:
total_loss: 0.14024507999420166


Epoch 5/5:  53%|█████▎    | 18751/35345 [33:10<28:34,  9.68it/s]

Batch 18750/35345:
total_loss: 0.1144162192940712


Epoch 5/5:  54%|█████▍    | 19001/35345 [33:36<28:12,  9.66it/s]

Batch 19000/35345:
total_loss: 0.1130119115114212


Epoch 5/5:  54%|█████▍    | 19251/35345 [34:03<42:09,  6.36it/s]

Batch 19250/35345:
total_loss: 0.14343342185020447


Epoch 5/5:  55%|█████▌    | 19501/35345 [34:30<27:42,  9.53it/s]

Batch 19500/35345:
total_loss: 0.0886492133140564


Epoch 5/5:  56%|█████▌    | 19751/35345 [34:56<26:49,  9.69it/s]

Batch 19750/35345:
total_loss: 0.20765766501426697


Epoch 5/5:  57%|█████▋    | 20001/35345 [35:22<26:50,  9.53it/s]

Batch 20000/35345:
total_loss: 0.03824830800294876


Epoch 5/5:  57%|█████▋    | 20251/35345 [35:48<26:26,  9.52it/s]

Batch 20250/35345:
total_loss: 0.08455944061279297


Epoch 5/5:  58%|█████▊    | 20501/35345 [36:15<26:02,  9.50it/s]

Batch 20500/35345:
total_loss: 0.15622098743915558


Epoch 5/5:  59%|█████▊    | 20751/35345 [36:42<25:22,  9.59it/s]

Batch 20750/35345:
total_loss: 0.06639357656240463


Epoch 5/5:  59%|█████▉    | 21001/35345 [37:09<26:04,  9.17it/s]

Batch 21000/35345:
total_loss: 0.14196564257144928


Epoch 5/5:  60%|██████    | 21251/35345 [37:35<24:34,  9.56it/s]

Batch 21250/35345:
total_loss: 0.11157107353210449


Epoch 5/5:  61%|██████    | 21501/35345 [38:01<24:06,  9.57it/s]

Batch 21500/35345:
total_loss: 0.17892961204051971


Epoch 5/5:  62%|██████▏   | 21751/35345 [38:27<23:38,  9.59it/s]

Batch 21750/35345:
total_loss: 0.24763047695159912


Epoch 5/5:  62%|██████▏   | 22001/35345 [38:55<24:22,  9.12it/s]

Batch 22000/35345:
total_loss: 0.20028160512447357


Epoch 5/5:  63%|██████▎   | 22251/35345 [39:21<22:34,  9.66it/s]

Batch 22250/35345:
total_loss: 0.1219339445233345


Epoch 5/5:  64%|██████▎   | 22501/35345 [39:47<22:12,  9.64it/s]

Batch 22500/35345:
total_loss: 0.15435972809791565


Epoch 5/5:  64%|██████▍   | 22751/35345 [40:13<21:44,  9.66it/s]

Batch 22750/35345:
total_loss: 0.08909202367067337


Epoch 5/5:  65%|██████▌   | 23001/35345 [40:39<21:09,  9.73it/s]

Batch 23000/35345:
total_loss: 0.09042023867368698


Epoch 5/5:  66%|██████▌   | 23251/35345 [41:06<23:24,  8.61it/s]

Batch 23250/35345:
total_loss: 0.039000943303108215


Epoch 5/5:  66%|██████▋   | 23501/35345 [41:33<20:36,  9.58it/s]

Batch 23500/35345:
total_loss: 0.04779808595776558


Epoch 5/5:  67%|██████▋   | 23750/35345 [41:59<20:49,  9.28it/s]

Batch 23750/35345:
total_loss: 0.10892461240291595


Epoch 5/5:  68%|██████▊   | 24001/35345 [42:25<19:42,  9.60it/s]

Batch 24000/35345:
total_loss: 0.08772953599691391


Epoch 5/5:  69%|██████▊   | 24251/35345 [42:51<19:08,  9.66it/s]

Batch 24250/35345:
total_loss: 0.17491799592971802


Epoch 5/5:  69%|██████▉   | 24501/35345 [43:17<19:02,  9.50it/s]

Batch 24500/35345:
total_loss: 0.1184326782822609


Epoch 5/5:  70%|███████   | 24751/35345 [43:45<18:07,  9.74it/s]

Batch 24750/35345:
total_loss: 0.019738508388400078


Epoch 5/5:  71%|███████   | 25001/35345 [44:11<17:52,  9.65it/s]

Batch 25000/35345:
total_loss: 0.09115845710039139


Epoch 5/5:  71%|███████▏  | 25251/35345 [44:37<17:47,  9.45it/s]

Batch 25250/35345:
total_loss: 0.12896867096424103


Epoch 5/5:  72%|███████▏  | 25501/35345 [45:03<17:03,  9.61it/s]

Batch 25500/35345:
total_loss: 0.1639946848154068


Epoch 5/5:  73%|███████▎  | 25751/35345 [45:30<16:39,  9.60it/s]

Batch 25750/35345:
total_loss: 0.06797702610492706


Epoch 5/5:  74%|███████▎  | 26001/35345 [45:57<16:39,  9.35it/s]

Batch 26000/35345:
total_loss: 0.03749606013298035


Epoch 5/5:  74%|███████▍  | 26251/35345 [46:24<15:42,  9.65it/s]

Batch 26250/35345:
total_loss: 0.06150996312499046


Epoch 5/5:  75%|███████▍  | 26501/35345 [46:50<15:56,  9.25it/s]

Batch 26500/35345:
total_loss: 0.07306258380413055


Epoch 5/5:  76%|███████▌  | 26751/35345 [47:16<14:58,  9.57it/s]

Batch 26750/35345:
total_loss: 0.09855680167675018


Epoch 5/5:  76%|███████▋  | 27001/35345 [47:42<14:33,  9.56it/s]

Batch 27000/35345:
total_loss: 0.12723740935325623


Epoch 5/5:  77%|███████▋  | 27250/35345 [48:09<14:31,  9.29it/s]

Batch 27250/35345:
total_loss: 0.15594026446342468


Epoch 5/5:  78%|███████▊  | 27501/35345 [48:37<13:36,  9.61it/s]

Batch 27500/35345:
total_loss: 0.12382436543703079


Epoch 5/5:  79%|███████▊  | 27751/35345 [49:03<13:15,  9.54it/s]

Batch 27750/35345:
total_loss: 0.15171460807323456


Epoch 5/5:  79%|███████▉  | 28001/35345 [49:29<12:44,  9.61it/s]

Batch 28000/35345:
total_loss: 0.0875350758433342


Epoch 5/5:  80%|███████▉  | 28251/35345 [49:55<12:20,  9.58it/s]

Batch 28250/35345:
total_loss: 0.113502636551857


Epoch 5/5:  81%|████████  | 28501/35345 [50:21<11:54,  9.58it/s]

Batch 28500/35345:
total_loss: 0.1899874359369278


Epoch 5/5:  81%|████████▏ | 28751/35345 [50:50<11:25,  9.62it/s]

Batch 28750/35345:
total_loss: 0.1236330047249794


Epoch 5/5:  82%|████████▏ | 29001/35345 [51:16<11:02,  9.58it/s]

Batch 29000/35345:
total_loss: 0.026468344032764435


Epoch 5/5:  83%|████████▎ | 29251/35345 [51:42<10:51,  9.35it/s]

Batch 29250/35345:
total_loss: 0.1545872688293457


Epoch 5/5:  83%|████████▎ | 29501/35345 [52:08<10:09,  9.59it/s]

Batch 29500/35345:
total_loss: 0.10359250009059906


Epoch 5/5:  84%|████████▍ | 29751/35345 [52:34<09:41,  9.61it/s]

Batch 29750/35345:
total_loss: 0.0903216078877449


Epoch 5/5:  85%|████████▍ | 30001/35345 [53:02<09:23,  9.48it/s]

Batch 30000/35345:
total_loss: 0.2003050595521927


Epoch 5/5:  86%|████████▌ | 30251/35345 [53:29<08:53,  9.55it/s]

Batch 30250/35345:
total_loss: 0.056588806211948395


Epoch 5/5:  86%|████████▋ | 30501/35345 [53:55<08:22,  9.63it/s]

Batch 30500/35345:
total_loss: 0.10812298208475113


Epoch 5/5:  87%|████████▋ | 30751/35345 [54:21<08:15,  9.28it/s]

Batch 30750/35345:
total_loss: 0.13129763305187225


Epoch 5/5:  88%|████████▊ | 31001/35345 [54:47<07:34,  9.56it/s]

Batch 31000/35345:
total_loss: 0.06340983510017395


Epoch 5/5:  88%|████████▊ | 31251/35345 [55:14<07:53,  8.64it/s]

Batch 31250/35345:
total_loss: 0.062204934656620026


Epoch 5/5:  89%|████████▉ | 31501/35345 [55:41<06:42,  9.55it/s]

Batch 31500/35345:
total_loss: 0.07155485451221466


Epoch 5/5:  90%|████████▉ | 31751/35345 [56:08<06:14,  9.60it/s]

Batch 31750/35345:
total_loss: 0.12680940330028534


Epoch 5/5:  91%|█████████ | 32001/35345 [56:34<05:50,  9.54it/s]

Batch 32000/35345:
total_loss: 0.18117092549800873


Epoch 5/5:  91%|█████████ | 32251/35345 [57:00<05:24,  9.52it/s]

Batch 32250/35345:
total_loss: 0.10235387086868286


Epoch 5/5:  92%|█████████▏| 32501/35345 [57:26<04:56,  9.58it/s]

Batch 32500/35345:
total_loss: 0.07223644107580185


Epoch 5/5:  93%|█████████▎| 32751/35345 [57:55<04:30,  9.61it/s]

Batch 32750/35345:
total_loss: 0.10000923275947571


Epoch 5/5:  93%|█████████▎| 33001/35345 [58:21<04:03,  9.63it/s]

Batch 33000/35345:
total_loss: 0.12841585278511047


Epoch 5/5:  94%|█████████▍| 33251/35345 [58:47<03:41,  9.46it/s]

Batch 33250/35345:
total_loss: 0.11113990098237991


Epoch 5/5:  95%|█████████▍| 33501/35345 [59:13<03:18,  9.27it/s]

Batch 33500/35345:
total_loss: 0.11723168939352036


Epoch 5/5:  95%|█████████▌| 33751/35345 [59:40<02:45,  9.64it/s]

Batch 33750/35345:
total_loss: 0.1061181128025055


Epoch 5/5:  96%|█████████▌| 34001/35345 [1:00:08<02:59,  7.50it/s]

Batch 34000/35345:
total_loss: 0.10814207792282104


Epoch 5/5:  97%|█████████▋| 34251/35345 [1:00:34<01:53,  9.66it/s]

Batch 34250/35345:
total_loss: 0.0886775329709053


Epoch 5/5:  98%|█████████▊| 34501/35345 [1:01:00<01:27,  9.60it/s]

Batch 34500/35345:
total_loss: 0.03885165974497795


Epoch 5/5:  98%|█████████▊| 34751/35345 [1:01:26<01:01,  9.62it/s]

Batch 34750/35345:
total_loss: 0.1713119000196457


Epoch 5/5:  99%|█████████▉| 35001/35345 [1:01:52<00:37,  9.12it/s]

Batch 35000/35345:
total_loss: 0.10065291821956635


Epoch 5/5: 100%|█████████▉| 35251/35345 [1:02:19<00:10,  9.24it/s]

Batch 35250/35345:
total_loss: 0.13127300143241882


Epoch 5/5: 100%|██████████| 35345/35345 [1:02:32<00:00,  9.42it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.43it/s]


Epoch Summary:
Train Total Loss: 0.1250
Val Total Loss: 0.1044
Learning Rate: 0.000100
New best model saved with val loss: 0.1044
--------------------------------------------------



