In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/symbolic_diffusion_initial/pytorch/default/1/symbolic_diffusion_model.pth
/kaggle/input/1-var-dataset/1_var_test.json
/kaggle/input/1-var-dataset/1_var_val.json
/kaggle/input/1-var-dataset/1_var_train.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/config.txt
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_5_4_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_4_3_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_2_1_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_1_0_17062021_094043.json
/kaggle/input/1-5-var-dataset/1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points/Val/0_3_2_17062021_094043.json
/kaggle/input/

In [2]:
import torch
import json
from torch.utils.data import Dataset
import re
import numpy as np
import random
from scipy.optimize import minimize
import math
import matplotlib.pyplot as plt


def generateDataStrEq(
    eq, n_points=2, n_vars=3, decimals=4, supportPoints=None, min_x=0, max_x=3
):
    X = []
    Y = []
    # TODO: Need to make this faster
    for p in range(n_points):
        if supportPoints is None:
            if type(min_x) == list:
                x = []
                for _ in range(n_vars):
                    idx = np.random.randint(len(min_x))
                    x += list(
                        np.round(np.random.uniform(min_x[idx], max_x[idx], 1), decimals)
                    )
            else:
                x = list(np.round(np.random.uniform(min_x, max_x, n_vars), decimals))
            assert (
                len(x) != 0
            ), "For some reason, we didn't generate the points correctly!"
        else:
            x = supportPoints[p]

        tmpEq = eq + ""
        for nVID in range(n_vars):
            tmpEq = tmpEq.replace("x{}".format(nVID + 1), str(x[nVID]))
        y = float(np.round(eval(tmpEq), decimals))
        X.append(x)
        Y.append(y)
    return X, Y


# def processDataFiles(files):
#     text = ""
#     for f in tqdm(files):
#         with open(f, 'r') as h:
#             lines = h.read() # don't worry we won't run out of file handles
#             if lines[-1]==-1:
#                 lines = lines[:-1]
#             #text += lines #json.loads(line)
#             text = ''.join([lines,text])
#     return text


def processDataFiles(files):
    text = ""
    for f in files:
        with open(f, "r") as h:
            lines = h.read()  # don't worry we won't run out of file handles
            if lines[-1] == -1:
                lines = lines[:-1]
            # text += lines #json.loads(line)
            text = "".join([lines, text])
    return text


def tokenize_equation(eq):
    token_spec = [
        (r"\*\*"),  # exponentiation
        (r"exp"),  # exp function
        (r"[+\-*/=()]"),  # operators and parentheses
        (r"sin"),  # sin function
        (r"cos"),  # cos function
        (r"log"),  # log function
        (r"x\d+"),  # variables like x1, x23, etc.
        (r"C"),  # constants placeholder
        (r"-?\d+\.\d+"),  # decimal numbers
        (r"-?\d+"),  # integers
        (r"_"),  # padding token
    ]
    token_regex = "|".join(f"({pattern})" for pattern in token_spec)
    matches = re.finditer(token_regex, eq)
    return [match.group(0) for match in matches]


class CharDataset(Dataset):
    def __init__(
        self,
        data,
        block_size,
        tokens,
        numVars,
        numYs,
        numPoints,
        target="Skeleton",
        addVars=False,
        const_range=[-0.4, 0.4],
        xRange=[-3.0, 3.0],
        decimals=4,
        augment=False,
    ):

        data_size, vocab_size = len(data), len(tokens)
        print("data has %d examples, %d unique." % (data_size, vocab_size))

        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.itos = {i: tok for i, tok in enumerate(tokens)}

        self.numVars = numVars
        self.numYs = numYs
        self.numPoints = numPoints

        # padding token
        self.paddingToken = "_"
        self.paddingID = self.stoi["_"]  # or another ID not already used
        self.stoi[self.paddingToken] = self.paddingID
        self.itos[self.paddingID] = self.paddingToken

        self.threshold = [-1000, 1000]

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data  # it should be a list of examples
        self.target = target
        self.addVars = addVars

        self.const_range = const_range
        self.xRange = xRange
        self.decimals = decimals
        self.augment = augment

    def __len__(self):
        return len(self.data) - 1

    def __getitem__(self, idx):
        # grab an example from the data
        chunk = self.data[idx]  # sequence of tokens including x, y, eq, etc.

        try:
            chunk = json.loads(chunk)  # convert the sequence tokens to a dictionary
        except Exception as e:
            print("Couldn't convert to json: {} \n error is: {}".format(chunk, e))
            # try the previous example
            idx = idx - 1
            idx = idx if idx >= 0 else 0
            chunk = self.data[idx]
            chunk = json.loads(chunk)  # convert the sequence tokens to a dictionary

        # find the number of variables in the equation
        printInfoCondition = random.random() < 0.0000001
        eq = chunk[self.target]
        if printInfoCondition:
            print(f"\nEquation: {eq}")
        vars = re.finditer("x[\d]+", eq)
        numVars = 0
        for v in vars:
            v = v.group(0).strip("x")
            v = eval(v)
            v = int(v)
            if v > numVars:
                numVars = v

        if self.target == "Skeleton" and self.augment:
            threshold = 5000
            # randomly generate the constants
            cleanEqn = ""
            for chr in eq:
                if chr == "C":
                    # genereate a new random number
                    chr = "{}".format(
                        np.random.uniform(self.const_range[0], self.const_range[1])
                    )
                cleanEqn += chr

            # update the points
            nPoints = np.random.randint(
                *self.numPoints
            )  # if supportPoints is None else len(supportPoints)
            try:
                if printInfoCondition:
                    print("Org:", chunk["X"], chunk["Y"])

                X, y = generateDataStrEq(
                    cleanEqn,
                    n_points=nPoints,
                    n_vars=self.numVars,
                    decimals=self.decimals,
                    min_x=self.xRange[0],
                    max_x=self.xRange[1],
                )

                # replace out of threshold with maximum numbers
                y = [e if abs(e) < threshold else np.sign(e) * threshold for e in y]

                # check if there is nan/inf/very large numbers in the y
                conditions = (
                    (np.isnan(y).any() or np.isinf(y).any())
                    or len(y) == 0
                    or (abs(min(y)) > threshold or abs(max(y)) > threshold)
                )
                if not conditions:
                    chunk["X"], chunk["Y"] = X, y

                if printInfoCondition:
                    print("Evd:", chunk["X"], chunk["Y"])
            except Exception as e:
                # for different reason this might happend including but not limited to division by zero
                print(
                    "".join(
                        [
                            f"We just used the original equation and support points because of {e}. ",
                            f"The equation is {eq}, and we update the equation to {cleanEqn}",
                        ]
                    )
                )

        # encode every character in the equation to an integer
        # < is SOS, > is EOS
        if self.addVars:
            dix = [self.stoi[s] for s in "<" + str(numVars) + ":" + eq + ">"]
        else:
            eq_tokens = tokenize_equation(eq)
            if self.addVars:
                token_seq = ["<", str(numVars), ":", *eq_tokens, ">"]
            else:
                token_seq = ["<", *eq_tokens, ">"]
            dix = [self.stoi[tok] for tok in token_seq]

        inputs = dix[:-1]
        outputs = dix[1:]

        # add the padding to the equations
        paddingSize = max(self.block_size - len(inputs), 0)
        paddingList = [self.paddingID] * paddingSize
        inputs += paddingList
        outputs += paddingList

        # make sure it is not more than what should be
        inputs = inputs[: self.block_size]
        outputs = outputs[: self.block_size]

        points = torch.zeros(self.numVars + self.numYs, self.numPoints - 1)
        for idx, xy in enumerate(zip(chunk["X"], chunk["Y"])):

            if not isinstance(xy[0], list) or not isinstance(
                xy[1], (list, float, np.float64)
            ):
                print(f"Unexpected types: {type(xy[0])}, {type(xy[1])}")
                continue  # Skip if types are incorrect

            # don't let to exceed the maximum number of points
            if idx >= self.numPoints - 1:
                break

            x = xy[0]
            x = x + [0] * (max(self.numVars - len(x), 0))  # padding

            y = [xy[1]] if type(xy[1]) == float or type(xy[1]) == np.float64 else xy[1]

            y = y + [0] * (max(self.numYs - len(y), 0))  # padding
            p = x + y  # because it is only one point
            p = torch.tensor(p)
            # replace nan and inf
            p = torch.nan_to_num(
                p,
                nan=self.threshold[1],
                posinf=self.threshold[1],
                neginf=self.threshold[0],
            )

            points[:, idx] = p

        points = torch.nan_to_num(
            points,
            nan=self.threshold[1],
            posinf=self.threshold[1],
            neginf=self.threshold[0],
        )

        inputs = torch.tensor(inputs, dtype=torch.long)
        outputs = torch.tensor(outputs, dtype=torch.long)
        numVars = torch.tensor(numVars, dtype=torch.long)
        return inputs, outputs, points, numVars


# Relative Mean Square Error
def relativeErr(y, yHat, info=False, eps=1e-5):
    yHat = np.reshape(yHat, [1, -1])[0]
    y = np.reshape(y, [1, -1])[0]
    if len(y) > 0 and len(y) == len(yHat):
        err = ((yHat - y)) ** 2 / np.linalg.norm(y + eps)
        if info:
            for _ in range(5):
                i = np.random.randint(len(y))
                # print("yPR,yTrue:{},{}, Err:{}".format(yHat[i], y[i], err[i]))
    else:
        err = 100

    return np.mean(err)


def lossFunc(constants, eq, X, Y, eps=1e-5):
    err = 0
    eq = eq.replace("C", "{}").format(*constants)

    for x, y in zip(X, Y):
        eqTemp = eq + ""
        if type(x) == np.float32:
            x = [x]
        for i, e in enumerate(x):
            # make sure e is not a tensor
            if type(e) == torch.Tensor:
                e = e.item()
            eqTemp = eqTemp.replace("x{}".format(i + 1), str(e))
        try:
            yHat = eval(eqTemp)
        except:
            # print("Exception has been occured! EQ: {}, OR: {}".format(eqTemp, eq))
            yHat = 100
        try:
            # handle overflow
            err += relativeErr(y, yHat)  # (y-yHat)**2
        except:
            # print(
            #    "Exception has been occured! EQ: {}, OR: {}, y:{}-yHat:{}".format(
            #        eqTemp, eq, y, yHat
            #    )
            # )
            err += 10

    err /= len(Y)
    return err


def get_skeleton(tokens, train_dataset: CharDataset):
    skeleton = "".join([train_dataset.itos[int(idx)] for idx in tokens])
    skeleton = skeleton.strip(train_dataset.paddingToken).split(">")
    skeleton = skeleton[0] if len(skeleton[0]) >= 1 else skeleton[1]
    skeleton = skeleton.strip("<").strip(">")
    skeleton = skeleton.replace("Ce", "C*e")

    return skeleton


In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

# from SymbolicGPT: https://github.com/mojivalipour/symbolicgpt/blob/master/models.py
class PointNetConfig:
    """base PointNet config"""

    def __init__(
        self,
        embeddingSize,
        numberofPoints,
        numberofVars,
        numberofYs,
        **kwargs,
    ):
        self.embeddingSize = embeddingSize
        self.numberofPoints = numberofPoints  # number of points
        self.numberofVars = numberofVars  # input dimension (Xs)
        self.numberofYs = numberofYs  # output dimension (Ys)

        for k, v in kwargs.items():
            setattr(self, k, v)


class tNet(nn.Module):
    """
    The PointNet structure in the orginal PointNet paper:
    PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation by Qi et. al. 2017
    """

    def __init__(self, config):
        super(tNet, self).__init__()

        self.activation_func = F.relu
        self.num_units = config.embeddingSize

        self.conv1 = nn.Conv1d(
            config.numberofVars + config.numberofYs, self.num_units, 1
        )
        self.conv2 = nn.Conv1d(self.num_units, 2 * self.num_units, 1)
        self.conv3 = nn.Conv1d(2 * self.num_units, 4 * self.num_units, 1)
        self.fc1 = nn.Linear(4 * self.num_units, 2 * self.num_units)
        self.fc2 = nn.Linear(2 * self.num_units, self.num_units)

        # self.relu = nn.ReLU()

        self.input_batch_norm = nn.BatchNorm1d(config.numberofVars + config.numberofYs)
        # self.input_layer_norm = nn.LayerNorm(config.numberofPoints)

        self.bn1 = nn.BatchNorm1d(self.num_units)
        self.bn2 = nn.BatchNorm1d(2 * self.num_units)
        self.bn3 = nn.BatchNorm1d(4 * self.num_units)
        self.bn4 = nn.BatchNorm1d(2 * self.num_units)
        self.bn5 = nn.BatchNorm1d(self.num_units)

    def forward(self, x):
        """
        :param x: [batch, #features, #points]
        :return:
            logit: [batch, embedding_size]
        """
        x = self.input_batch_norm(x)
        x = self.activation_func(self.bn1(self.conv1(x)))
        x = self.activation_func(self.bn2(self.conv2(x)))
        x = self.activation_func(self.bn3(self.conv3(x)))
        x, _ = torch.max(x, dim=2)  # global max pooling
        assert x.size(1) == 4 * self.num_units

        x = self.activation_func(self.bn4(self.fc1(x)))
        x = self.activation_func(self.bn5(self.fc2(x)))
        # x = self.fc2(x)

        return x


# from https://github.com/juho-lee/set_transformer/blob/master/modules.py
class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, "ln0", None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, "ln1", None) is None else self.ln1(O)
        return O


class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)


class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)


# from https://github.com/juho-lee/set_transformer/blob/master/models.py
class SetTransformer(nn.Module):
    def __init__(
        self,
        dim_input,
        num_outputs,
        dim_output,
        num_inds=32,
        dim_hidden=128,
        num_heads=4,
        ln=False,
    ):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
            ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, X):
        return self.dec(self.enc(X)).squeeze(1)


class NoisePredictionTransformer(nn.Module):
    def __init__(self, n_embd, max_seq_len, n_layer=6, n_head=8, max_timesteps=1000):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, n_embd))
        self.time_emb = nn.Embedding(max_timesteps, n_embd)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * 4,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)

    def forward(self, x_t, t, condition):
        _, L, _ = x_t.shape
        pos_emb = self.pos_emb[:, :L, :]  # [1, L, n_embd]
        time_emb = self.time_emb(t)
        if time_emb.dim() == 1:  # Scalar t case, [n_embd]
            time_emb = time_emb.unsqueeze(0)  # [1, n_embd]
        time_emb = time_emb.unsqueeze(1)  # [1, 1, n_embd]
        condition = condition.unsqueeze(1)  # [B, 1, n_embd]

        x = x_t + pos_emb + time_emb + condition
        return self.encoder(x)


# influenced by https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
class SymbolicGaussianDiffusion(nn.Module):
    def __init__(
        self,
        tnet_config: PointNetConfig,
        vocab_size,
        max_seq_len,
        padding_idx: int = 0,
        max_num_vars: int = 9,
        n_layer=6,
        n_head=8,
        n_embd=512,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        set_transformer=True,
        ce_weight=1.0,  
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.n_embd = n_embd
        self.timesteps = timesteps
        self.set_transformer = set_transformer
        self.ce_weight = ce_weight

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tok_emb = nn.Embedding(vocab_size, n_embd, padding_idx=self.padding_idx)
        self.vars_emb = nn.Embedding(max_num_vars, n_embd)

        self.decoder = nn.Linear(n_embd, vocab_size, bias=False)
        self.decoder.weight = self.tok_emb.weight

        if set_transformer:
            dim_input = tnet_config.numberofVars + tnet_config.numberofYs
            self.tnet = SetTransformer(
                dim_input=dim_input,
                num_outputs=1,
                dim_output=tnet_config.embeddingSize,
                num_inds=tnet_config.numberofPoints,
                num_heads=4,
            )
        else:
            self.tnet = tNet(tnet_config)

        self.model = NoisePredictionTransformer(
            n_embd, max_seq_len, n_layer, n_head, timesteps
        )

        # Noise schedule
        self.register_buffer("beta", torch.linspace(beta_start, beta_end, timesteps))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

    def q_sample(self, x_start, t, noise=None):
        noise = torch.randn_like(x_start)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1)

        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise
        return x_t

    def p_mean_variance(self, x, t, t_next, condition):
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        alpha_bar_t_next = self.alpha_bar[t_next]
        beta_t = self.beta[t]

        x_start_pred = self.model(x, t.long(), condition)

        coeff1 = torch.sqrt(alpha_bar_t_next) * beta_t / (1 - alpha_bar_t)
        coeff2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_next) / (1 - alpha_bar_t)
        mean = coeff1 * x_start_pred + coeff2 * x
        variance = (1 - alpha_bar_t_next) / (1 - alpha_bar_t) * beta_t
        return mean, variance

    @torch.no_grad()
    def p_sample(self, x, t, t_next, condition):
        mean, variance = self.p_mean_variance(x, t, t_next, condition)
        if torch.all(t_next == 0):
            return mean
        noise = torch.randn_like(x)
        return mean + torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, points, variables, train_dataset, batch_size=16, ddim_step=1):
        if self.set_transformer:
            points = points.transpose(1, 2)

        condition = self.tnet(points) + self.vars_emb(variables)
        shape = (batch_size, self.max_seq_len, self.n_embd)
        x = torch.randn(shape, device=self.device)
        steps = torch.arange(self.timesteps - 1, -1, -1, device=self.device)

        for i in range(0, self.timesteps, ddim_step):
            t = steps[i]
            t_next = (
                steps[i + ddim_step]
                if i + ddim_step < self.timesteps
                else torch.tensor(0, device=self.device)
            )
            x = self.p_sample(x, t, t_next, condition)

            # Print prediction every 250 steps
            # if (i + 1) % 250 == 0:
            #    logits = self.decoder(x)  # [B, L, vocab_size]
            #    token_indices = torch.argmax(logits, dim=-1)  # [B, L]
            #    for j in range(batch_size):
            #       token_indices_j = token_indices[j]  # [L]
            #        predicted_skeleton = get_predicted_skeleton(
            #            token_indices_j, train_dataset
            #        )
            #        tqdm.write(f" sample {j}: predicted_skeleton: {predicted_skeleton}")

        logits = self.decoder(x)  # [B, L, vocab_size]
        token_indices = torch.argmax(logits, dim=-1)  # [B, L]
        predicted_skeletons = []
        for j in range(batch_size):
            token_indices_j = token_indices[j]  # [L]
            predicted_skeleton = get_predicted_skeleton(token_indices_j, train_dataset)
            predicted_skeletons.append(predicted_skeleton)
        return predicted_skeletons

    def p_losses(
        self,
        x_start,
        points,
        tokens,
        variables,
        t,
    ):
        noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)

        if self.set_transformer:
            points = points.transpose(1, 2)

        condition = self.tnet(points) + self.vars_emb(variables)

        x_start_pred = self.model(x_t, t.long(), condition)

        # CE loss on tokens
        logits = self.decoder(x_start_pred)  # [B, L, vocab_size]

        ce_loss = F.cross_entropy(
            logits.view(-1, self.vocab_size),  # [B*L, vocab_size]
            tokens.view(-1),  # [B*L]
            ignore_index=self.padding_idx,
            reduction="mean",
        )

        ce_loss = ce_loss * self.ce_weight

        return ce_loss

    def forward(self, points, tokens, variables, t):
        token_emb = self.tok_emb(tokens)
        ce_loss = self.p_losses(token_emb, points, tokens, variables, t)
        return ce_loss

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple

def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    model.train()
    total_train_loss = 0

    for i, (_, tokens, points, variables) in tqdm.tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc=f"Epoch {epoch+1}/{num_epochs}",
    ):
        points, tokens, variables = (
            points.to(device),
            tokens.to(device),
            variables.to(device),
        )
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
        optimizer.zero_grad()

        total_loss = model(points, tokens, variables, t)

        if (i + 1) % 250 == 0:
            print(f"Batch {i + 1}/{len(train_loader)}:")
            print(f"total_loss: {total_loss}")

        total_loss.backward()
        optimizer.step()

        total_train_loss += total_loss.item()


    avg_train_loss = total_train_loss / len(train_loader)
    return avg_train_loss


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> Tuple[float, float, float]:
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for _, tokens, points, variables in tqdm.tqdm(
            val_loader, total=len(val_loader), desc="Validating"
        ):
            points, tokens, variables = (
                points.to(device),
                tokens.to(device),
                variables.to(device),
            )
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
            total_loss = model(points, tokens, variables, t)

            total_val_loss += total_loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    return avg_val_loss


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
    path=None,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=4,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=4,
    )

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )

        avg_val_loss = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(
            f"Train Total Loss: {avg_train_loss:.4f}"
        )
        print(
            f"Val Total Loss: {avg_val_loss:.4f}"
        )
        print(f"Learning Rate: {current_lr:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            state_dict = model.state_dict()
            torch.save(state_dict, path)
            print(f"New best model saved with val loss: {best_val_loss:.4f}")

        print("-" * 50)

In [5]:
n_embd = 512
timesteps = 1000
batch_size = 64
learning_rate = 1e-4
num_epochs = 5
blockSize = 32
numVars = 5
numYs = 1
numPoints = 250
target = 'Skeleton'
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]
decimals = 8
addVars = False
maxNumFiles = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
dataDir = "/kaggle/input/1-5-var-dataset"
dataFolder = "1-5Var_RandSupport_RandLength_-3to3_-5.0to-3.0-3.0to5.0_10to200Points"

In [7]:
import numpy as np
import glob
import random

path = '{}/{}/Train/*.json'.format(dataDir, dataFolder)
files = glob.glob(path)[:maxNumFiles]
text = processDataFiles(files)
text = text.split('\n') # convert the raw text to a set of examples
# skeletons = []
skeletons = [json.loads(item)['Skeleton'] for item in text if item.strip()]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # add all integers to the token set
tokens = sorted(list(all_tokens) + ['_', 'T', '<', '>', ':'])  # special tokens
trainText = text[:-1] if len(text[-1]) == 0 else text
random.shuffle(trainText) # shuffle the dataset, it's important specailly for the combined number of variables experiment
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

idx = np.random.randint(train_dataset.__len__())
inputs, outputs, points, variables = train_dataset.__getitem__(idx)
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 2262058 examples, 44 unique.
id:330316
outputs:C*log(C*sin(C*x2+C)+C*cos(C*x2+C))+C>___
variables:2


In [8]:
path = '{}/{}/Val/*.json'.format(dataDir,dataFolder)
files = glob.glob(path)
textVal = processDataFiles([files[0]])
textVal = textVal.split('\n') # convert the raw text to a set of examples
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# print a random sample
idx = np.random.randint(val_dataset.__len__())
inputs, outputs, points, variables = val_dataset.__getitem__(idx)
print(points.min(), points.max())
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 195 examples, 44 unique.
tensor(-2.9703) tensor(2.9976)
id:10
outputs:C*cos(C*x2+C*x3+C/x2)+C>_____________
variables:3


In [9]:
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

model = SymbolicGaussianDiffusion(
    tnet_config=pconfig,  
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02
)

train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
    path="1_5_var_set_transformer"
)

Epoch 1/5:   1%|          | 250/35345 [00:25<1:08:54,  8.49it/s]

Batch 250/35345:
total_loss: 2.185614585876465


Epoch 1/5:   1%|▏         | 501/35345 [00:50<58:54,  9.86it/s]

Batch 500/35345:
total_loss: 1.4346182346343994


Epoch 1/5:   2%|▏         | 751/35345 [01:16<1:00:58,  9.46it/s]

Batch 750/35345:
total_loss: 0.980102002620697


Epoch 1/5:   3%|▎         | 1000/35345 [01:43<1:04:45,  8.84it/s]

Batch 1000/35345:
total_loss: 1.0374677181243896


Epoch 1/5:   4%|▎         | 1250/35345 [02:13<1:08:55,  8.24it/s]

Batch 1250/35345:
total_loss: 0.8646742701530457


Epoch 1/5:   4%|▍         | 1501/35345 [02:42<1:03:44,  8.85it/s]

Batch 1500/35345:
total_loss: 1.2908841371536255


Epoch 1/5:   5%|▍         | 1750/35345 [03:10<1:03:51,  8.77it/s]

Batch 1750/35345:
total_loss: 0.8237023949623108


Epoch 1/5:   6%|▌         | 2000/35345 [03:39<1:04:11,  8.66it/s]

Batch 2000/35345:
total_loss: 0.8282701373100281


Epoch 1/5:   6%|▋         | 2250/35345 [04:08<1:03:22,  8.70it/s]

Batch 2250/35345:
total_loss: 0.7722989320755005


Epoch 1/5:   7%|▋         | 2500/35345 [04:36<1:02:49,  8.71it/s]

Batch 2500/35345:
total_loss: 0.6709623336791992


Epoch 1/5:   8%|▊         | 2750/35345 [05:05<1:01:40,  8.81it/s]

Batch 2750/35345:
total_loss: 0.8416247367858887


Epoch 1/5:   8%|▊         | 3000/35345 [05:33<1:01:07,  8.82it/s]

Batch 3000/35345:
total_loss: 0.6055462956428528


Epoch 1/5:   9%|▉         | 3250/35345 [06:02<1:00:40,  8.82it/s]

Batch 3250/35345:
total_loss: 0.5399788022041321


Epoch 1/5:  10%|▉         | 3500/35345 [06:30<1:00:33,  8.76it/s]

Batch 3500/35345:
total_loss: 0.8589791059494019


Epoch 1/5:  11%|█         | 3750/35345 [06:59<59:39,  8.83it/s]

Batch 3750/35345:
total_loss: 0.7139004468917847


Epoch 1/5:  11%|█▏        | 4000/35345 [07:27<59:49,  8.73it/s]

Batch 4000/35345:
total_loss: 0.762412965297699


Epoch 1/5:  12%|█▏        | 4251/35345 [07:56<59:03,  8.78it/s]

Batch 4250/35345:
total_loss: 0.5367165803909302


Epoch 1/5:  13%|█▎        | 4500/35345 [08:24<59:36,  8.63it/s]

Batch 4500/35345:
total_loss: 0.7189046144485474


Epoch 1/5:  13%|█▎        | 4750/35345 [08:53<58:52,  8.66it/s]

Batch 4750/35345:
total_loss: 0.7103501558303833


Epoch 1/5:  14%|█▍        | 5000/35345 [09:22<57:24,  8.81it/s]

Batch 5000/35345:
total_loss: 0.4498799741268158


Epoch 1/5:  15%|█▍        | 5250/35345 [09:50<57:13,  8.77it/s]

Batch 5250/35345:
total_loss: 0.49469560384750366


Epoch 1/5:  16%|█▌        | 5500/35345 [10:19<56:31,  8.80it/s]

Batch 5500/35345:
total_loss: 0.5378633141517639


Epoch 1/5:  16%|█▋        | 5750/35345 [10:47<56:04,  8.80it/s]

Batch 5750/35345:
total_loss: 0.6361140608787537


Epoch 1/5:  17%|█▋        | 6000/35345 [11:16<55:41,  8.78it/s]

Batch 6000/35345:
total_loss: 0.7626102566719055


Epoch 1/5:  18%|█▊        | 6250/35345 [11:44<55:20,  8.76it/s]

Batch 6250/35345:
total_loss: 0.5773890018463135


Epoch 1/5:  18%|█▊        | 6500/35345 [12:13<54:31,  8.82it/s]

Batch 6500/35345:
total_loss: 0.4216547906398773


Epoch 1/5:  19%|█▉        | 6750/35345 [12:41<54:28,  8.75it/s]

Batch 6750/35345:
total_loss: 0.5306946039199829


Epoch 1/5:  20%|█▉        | 7000/35345 [13:10<54:00,  8.75it/s]

Batch 7000/35345:
total_loss: 0.5001635551452637


Epoch 1/5:  21%|██        | 7250/35345 [13:39<53:53,  8.69it/s]

Batch 7250/35345:
total_loss: 0.5651721358299255


Epoch 1/5:  21%|██        | 7500/35345 [14:08<52:52,  8.78it/s]

Batch 7500/35345:
total_loss: 0.4869869351387024


Epoch 1/5:  22%|██▏       | 7751/35345 [14:37<52:51,  8.70it/s]

Batch 7750/35345:
total_loss: 0.5204997062683105


Epoch 1/5:  23%|██▎       | 8000/35345 [15:05<51:57,  8.77it/s]

Batch 8000/35345:
total_loss: 0.3362918794155121


Epoch 1/5:  23%|██▎       | 8250/35345 [15:34<52:04,  8.67it/s]

Batch 8250/35345:
total_loss: 0.4749164581298828


Epoch 1/5:  24%|██▍       | 8500/35345 [16:03<51:01,  8.77it/s]

Batch 8500/35345:
total_loss: 0.28382614254951477


Epoch 1/5:  25%|██▍       | 8751/35345 [16:32<51:18,  8.64it/s]

Batch 8750/35345:
total_loss: 0.6510335803031921


Epoch 1/5:  25%|██▌       | 9000/35345 [17:01<50:25,  8.71it/s]

Batch 9000/35345:
total_loss: 0.5155853629112244


Epoch 1/5:  26%|██▌       | 9250/35345 [17:29<49:44,  8.74it/s]

Batch 9250/35345:
total_loss: 0.28897368907928467


Epoch 1/5:  27%|██▋       | 9500/35345 [17:58<49:33,  8.69it/s]

Batch 9500/35345:
total_loss: 0.4326290786266327


Epoch 1/5:  28%|██▊       | 9750/35345 [18:27<49:02,  8.70it/s]

Batch 9750/35345:
total_loss: 0.797929048538208


Epoch 1/5:  28%|██▊       | 10000/35345 [18:56<48:53,  8.64it/s]

Batch 10000/35345:
total_loss: 0.3404402434825897


Epoch 1/5:  29%|██▉       | 10250/35345 [19:25<48:17,  8.66it/s]

Batch 10250/35345:
total_loss: 0.5611764788627625


Epoch 1/5:  30%|██▉       | 10500/35345 [19:54<47:33,  8.71it/s]

Batch 10500/35345:
total_loss: 0.3537016212940216


Epoch 1/5:  30%|███       | 10750/35345 [20:23<46:57,  8.73it/s]

Batch 10750/35345:
total_loss: 0.5349417924880981


Epoch 1/5:  31%|███       | 11001/35345 [20:52<47:24,  8.56it/s]

Batch 11000/35345:
total_loss: 0.6443161368370056


Epoch 1/5:  32%|███▏      | 11250/35345 [21:21<46:14,  8.69it/s]

Batch 11250/35345:
total_loss: 0.38039037585258484


Epoch 1/5:  33%|███▎      | 11500/35345 [21:49<45:49,  8.67it/s]

Batch 11500/35345:
total_loss: 0.5742515325546265


Epoch 1/5:  33%|███▎      | 11750/35345 [22:18<45:07,  8.72it/s]

Batch 11750/35345:
total_loss: 0.36079567670822144


Epoch 1/5:  34%|███▍      | 12000/35345 [22:47<44:48,  8.68it/s]

Batch 12000/35345:
total_loss: 0.31545668840408325


Epoch 1/5:  35%|███▍      | 12250/35345 [23:16<44:21,  8.68it/s]

Batch 12250/35345:
total_loss: 0.4801715910434723


Epoch 1/5:  35%|███▌      | 12500/35345 [23:45<43:38,  8.72it/s]

Batch 12500/35345:
total_loss: 0.35330477356910706


Epoch 1/5:  36%|███▌      | 12750/35345 [24:13<43:24,  8.67it/s]

Batch 12750/35345:
total_loss: 0.6377407312393188


Epoch 1/5:  37%|███▋      | 13000/35345 [24:42<43:00,  8.66it/s]

Batch 13000/35345:
total_loss: 0.5154014229774475


Epoch 1/5:  37%|███▋      | 13250/35345 [25:11<42:27,  8.67it/s]

Batch 13250/35345:
total_loss: 0.49370187520980835


Epoch 1/5:  38%|███▊      | 13500/35345 [25:40<42:47,  8.51it/s]

Batch 13500/35345:
total_loss: 0.5019680261611938


Epoch 1/5:  39%|███▉      | 13750/35345 [26:09<42:12,  8.53it/s]

Batch 13750/35345:
total_loss: 0.5541069507598877


Epoch 1/5:  40%|███▉      | 14000/35345 [26:38<40:47,  8.72it/s]

Batch 14000/35345:
total_loss: 0.4396671950817108


Epoch 1/5:  40%|████      | 14250/35345 [27:06<40:30,  8.68it/s]

Batch 14250/35345:
total_loss: 0.4015865623950958


Epoch 1/5:  41%|████      | 14500/35345 [27:35<39:47,  8.73it/s]

Batch 14500/35345:
total_loss: 0.3875928819179535


Epoch 1/5:  42%|████▏     | 14750/35345 [28:04<39:05,  8.78it/s]

Batch 14750/35345:
total_loss: 0.3912469446659088


Epoch 1/5:  42%|████▏     | 15000/35345 [28:33<38:54,  8.72it/s]

Batch 15000/35345:
total_loss: 0.5332741737365723


Epoch 1/5:  43%|████▎     | 15250/35345 [29:01<38:31,  8.69it/s]

Batch 15250/35345:
total_loss: 0.4439820647239685


Epoch 1/5:  44%|████▍     | 15500/35345 [29:30<38:42,  8.55it/s]

Batch 15500/35345:
total_loss: 0.7912471294403076


Epoch 1/5:  45%|████▍     | 15751/35345 [30:00<38:07,  8.57it/s]

Batch 15750/35345:
total_loss: 0.3446178436279297


Epoch 1/5:  45%|████▌     | 16000/35345 [30:28<37:32,  8.59it/s]

Batch 16000/35345:
total_loss: 0.553234338760376


Epoch 1/5:  46%|████▌     | 16250/35345 [30:57<36:31,  8.71it/s]

Batch 16250/35345:
total_loss: 0.38948819041252136


Epoch 1/5:  47%|████▋     | 16500/35345 [31:26<36:23,  8.63it/s]

Batch 16500/35345:
total_loss: 0.3504689335823059


Epoch 1/5:  47%|████▋     | 16750/35345 [31:55<35:45,  8.67it/s]

Batch 16750/35345:
total_loss: 0.3852728605270386


Epoch 1/5:  48%|████▊     | 17000/35345 [32:24<35:05,  8.71it/s]

Batch 17000/35345:
total_loss: 0.41532424092292786


Epoch 1/5:  49%|████▉     | 17250/35345 [32:53<34:29,  8.74it/s]

Batch 17250/35345:
total_loss: 0.43428531289100647


Epoch 1/5:  50%|████▉     | 17500/35345 [33:21<34:20,  8.66it/s]

Batch 17500/35345:
total_loss: 0.3148570954799652


Epoch 1/5:  50%|█████     | 17750/35345 [33:50<33:26,  8.77it/s]

Batch 17750/35345:
total_loss: 0.34052902460098267


Epoch 1/5:  51%|█████     | 18001/35345 [34:19<33:06,  8.73it/s]

Batch 18000/35345:
total_loss: 0.4822077751159668


Epoch 1/5:  52%|█████▏    | 18250/35345 [34:47<33:26,  8.52it/s]

Batch 18250/35345:
total_loss: 0.5488578081130981


Epoch 1/5:  52%|█████▏    | 18500/35345 [35:16<32:14,  8.71it/s]

Batch 18500/35345:
total_loss: 0.48118916153907776


Epoch 1/5:  53%|█████▎    | 18750/35345 [35:45<31:39,  8.74it/s]

Batch 18750/35345:
total_loss: 0.4632233679294586


Epoch 1/5:  54%|█████▍    | 19000/35345 [36:14<31:15,  8.71it/s]

Batch 19000/35345:
total_loss: 0.4006049931049347


Epoch 1/5:  54%|█████▍    | 19250/35345 [36:43<31:16,  8.58it/s]

Batch 19250/35345:
total_loss: 0.4393700063228607


Epoch 1/5:  55%|█████▌    | 19500/35345 [37:12<30:53,  8.55it/s]

Batch 19500/35345:
total_loss: 0.6964597105979919


Epoch 1/5:  56%|█████▌    | 19751/35345 [37:41<30:22,  8.56it/s]

Batch 19750/35345:
total_loss: 0.5394505858421326


Epoch 1/5:  57%|█████▋    | 20000/35345 [38:10<29:39,  8.62it/s]

Batch 20000/35345:
total_loss: 0.29681870341300964


Epoch 1/5:  57%|█████▋    | 20250/35345 [38:39<29:11,  8.62it/s]

Batch 20250/35345:
total_loss: 0.4081485867500305


Epoch 1/5:  58%|█████▊    | 20501/35345 [39:08<28:33,  8.66it/s]

Batch 20500/35345:
total_loss: 0.4332698881626129


Epoch 1/5:  59%|█████▊    | 20750/35345 [39:37<28:14,  8.61it/s]

Batch 20750/35345:
total_loss: 0.2742215394973755


Epoch 1/5:  59%|█████▉    | 21000/35345 [40:06<27:31,  8.68it/s]

Batch 21000/35345:
total_loss: 0.32336825132369995


Epoch 1/5:  60%|██████    | 21250/35345 [40:35<27:23,  8.58it/s]

Batch 21250/35345:
total_loss: 0.45391881465911865


Epoch 1/5:  61%|██████    | 21500/35345 [41:04<26:32,  8.69it/s]

Batch 21500/35345:
total_loss: 0.5121743679046631


Epoch 1/5:  62%|██████▏   | 21750/35345 [41:33<26:13,  8.64it/s]

Batch 21750/35345:
total_loss: 0.29714271426200867


Epoch 1/5:  62%|██████▏   | 22000/35345 [42:01<25:41,  8.66it/s]

Batch 22000/35345:
total_loss: 0.1687062531709671


Epoch 1/5:  63%|██████▎   | 22250/35345 [42:30<25:13,  8.65it/s]

Batch 22250/35345:
total_loss: 0.4417828321456909


Epoch 1/5:  64%|██████▎   | 22500/35345 [42:59<24:52,  8.61it/s]

Batch 22500/35345:
total_loss: 0.3367003798484802


Epoch 1/5:  64%|██████▍   | 22750/35345 [43:28<24:13,  8.67it/s]

Batch 22750/35345:
total_loss: 0.3718493580818176


Epoch 1/5:  65%|██████▌   | 23000/35345 [43:57<23:24,  8.79it/s]

Batch 23000/35345:
total_loss: 0.520052433013916


Epoch 1/5:  66%|██████▌   | 23250/35345 [44:26<23:28,  8.58it/s]

Batch 23250/35345:
total_loss: 0.5135177969932556


Epoch 1/5:  66%|██████▋   | 23500/35345 [44:55<23:01,  8.57it/s]

Batch 23500/35345:
total_loss: 0.5118731260299683


Epoch 1/5:  67%|██████▋   | 23750/35345 [45:24<22:19,  8.65it/s]

Batch 23750/35345:
total_loss: 0.4549266993999481


Epoch 1/5:  68%|██████▊   | 24000/35345 [45:52<21:30,  8.79it/s]

Batch 24000/35345:
total_loss: 0.2979355454444885


Epoch 1/5:  69%|██████▊   | 24250/35345 [46:21<21:06,  8.76it/s]

Batch 24250/35345:
total_loss: 0.3497985601425171


Epoch 1/5:  69%|██████▉   | 24500/35345 [46:50<21:09,  8.54it/s]

Batch 24500/35345:
total_loss: 0.32332292199134827


Epoch 1/5:  70%|███████   | 24750/35345 [47:19<20:01,  8.82it/s]

Batch 24750/35345:
total_loss: 0.304090678691864


Epoch 1/5:  71%|███████   | 25000/35345 [47:47<19:47,  8.71it/s]

Batch 25000/35345:
total_loss: 0.5792734026908875


Epoch 1/5:  71%|███████▏  | 25250/35345 [48:16<19:50,  8.48it/s]

Batch 25250/35345:
total_loss: 0.4245140850543976


Epoch 1/5:  72%|███████▏  | 25500/35345 [48:45<18:49,  8.72it/s]

Batch 25500/35345:
total_loss: 0.40692150592803955


Epoch 1/5:  73%|███████▎  | 25750/35345 [49:13<18:31,  8.63it/s]

Batch 25750/35345:
total_loss: 0.33495157957077026


Epoch 1/5:  74%|███████▎  | 26000/35345 [49:42<17:46,  8.77it/s]

Batch 26000/35345:
total_loss: 0.628962516784668


Epoch 1/5:  74%|███████▍  | 26250/35345 [50:11<17:21,  8.73it/s]

Batch 26250/35345:
total_loss: 0.3371160328388214


Epoch 1/5:  75%|███████▍  | 26500/35345 [50:40<16:48,  8.77it/s]

Batch 26500/35345:
total_loss: 0.4499507248401642


Epoch 1/5:  76%|███████▌  | 26750/35345 [51:09<16:29,  8.68it/s]

Batch 26750/35345:
total_loss: 0.32314541935920715


Epoch 1/5:  76%|███████▋  | 27000/35345 [51:37<16:04,  8.65it/s]

Batch 27000/35345:
total_loss: 0.4146150052547455


Epoch 1/5:  77%|███████▋  | 27250/35345 [52:06<15:40,  8.61it/s]

Batch 27250/35345:
total_loss: 0.35982370376586914


Epoch 1/5:  78%|███████▊  | 27501/35345 [52:35<15:06,  8.65it/s]

Batch 27500/35345:
total_loss: 0.5150923728942871


Epoch 1/5:  79%|███████▊  | 27750/35345 [53:04<14:39,  8.64it/s]

Batch 27750/35345:
total_loss: 0.43536895513534546


Epoch 1/5:  79%|███████▉  | 28000/35345 [53:33<14:02,  8.71it/s]

Batch 28000/35345:
total_loss: 0.36235713958740234


Epoch 1/5:  80%|███████▉  | 28250/35345 [54:02<13:36,  8.69it/s]

Batch 28250/35345:
total_loss: 0.32352685928344727


Epoch 1/5:  81%|████████  | 28501/35345 [54:31<13:11,  8.64it/s]

Batch 28500/35345:
total_loss: 0.441256046295166


Epoch 1/5:  81%|████████▏ | 28750/35345 [55:00<12:48,  8.58it/s]

Batch 28750/35345:
total_loss: 0.3425244688987732


Epoch 1/5:  82%|████████▏ | 29000/35345 [55:29<12:13,  8.64it/s]

Batch 29000/35345:
total_loss: 0.33695176243782043


Epoch 1/5:  83%|████████▎ | 29250/35345 [55:57<11:41,  8.69it/s]

Batch 29250/35345:
total_loss: 0.5018501281738281


Epoch 1/5:  83%|████████▎ | 29500/35345 [56:26<11:15,  8.65it/s]

Batch 29500/35345:
total_loss: 0.38342908024787903


Epoch 1/5:  84%|████████▍ | 29750/35345 [56:55<10:46,  8.66it/s]

Batch 29750/35345:
total_loss: 0.26289865374565125


Epoch 1/5:  85%|████████▍ | 30000/35345 [57:24<10:20,  8.62it/s]

Batch 30000/35345:
total_loss: 0.5283681154251099


Epoch 1/5:  86%|████████▌ | 30250/35345 [57:53<09:42,  8.75it/s]

Batch 30250/35345:
total_loss: 0.6086551547050476


Epoch 1/5:  86%|████████▋ | 30500/35345 [58:22<09:15,  8.71it/s]

Batch 30500/35345:
total_loss: 0.39650848507881165


Epoch 1/5:  87%|████████▋ | 30750/35345 [58:50<08:46,  8.74it/s]

Batch 30750/35345:
total_loss: 0.38900548219680786


Epoch 1/5:  88%|████████▊ | 31000/35345 [59:19<08:12,  8.82it/s]

Batch 31000/35345:
total_loss: 0.5439746379852295


Epoch 1/5:  88%|████████▊ | 31250/35345 [59:48<07:45,  8.80it/s]

Batch 31250/35345:
total_loss: 0.26604557037353516


Epoch 1/5:  89%|████████▉ | 31500/35345 [1:00:16<07:16,  8.80it/s]

Batch 31500/35345:
total_loss: 0.35725095868110657


Epoch 1/5:  90%|████████▉ | 31750/35345 [1:00:45<06:52,  8.72it/s]

Batch 31750/35345:
total_loss: 0.3540562093257904


Epoch 1/5:  91%|█████████ | 32000/35345 [1:01:14<06:24,  8.71it/s]

Batch 32000/35345:
total_loss: 0.3406241834163666


Epoch 1/5:  91%|█████████ | 32250/35345 [1:01:43<06:28,  7.97it/s]

Batch 32250/35345:
total_loss: 0.40293923020362854


Epoch 1/5:  92%|█████████▏| 32500/35345 [1:02:11<05:27,  8.69it/s]

Batch 32500/35345:
total_loss: 0.3431962728500366


Epoch 1/5:  93%|█████████▎| 32750/35345 [1:02:40<04:57,  8.74it/s]

Batch 32750/35345:
total_loss: 0.3916983902454376


Epoch 1/5:  93%|█████████▎| 33000/35345 [1:03:09<04:27,  8.78it/s]

Batch 33000/35345:
total_loss: 0.40927231311798096


Epoch 1/5:  94%|█████████▍| 33250/35345 [1:03:38<04:02,  8.64it/s]

Batch 33250/35345:
total_loss: 0.44851982593536377


Epoch 1/5:  95%|█████████▍| 33500/35345 [1:04:06<03:33,  8.63it/s]

Batch 33500/35345:
total_loss: 0.2504175305366516


Epoch 1/5:  95%|█████████▌| 33751/35345 [1:04:35<03:03,  8.67it/s]

Batch 33750/35345:
total_loss: 0.45741748809814453


Epoch 1/5:  96%|█████████▌| 34000/35345 [1:05:04<02:34,  8.69it/s]

Batch 34000/35345:
total_loss: 0.23832057416439056


Epoch 1/5:  97%|█████████▋| 34250/35345 [1:05:33<02:05,  8.72it/s]

Batch 34250/35345:
total_loss: 0.5682705044746399


Epoch 1/5:  98%|█████████▊| 34501/35345 [1:06:01<01:38,  8.54it/s]

Batch 34500/35345:
total_loss: 0.3193194568157196


Epoch 1/5:  98%|█████████▊| 34750/35345 [1:06:30<01:09,  8.61it/s]

Batch 34750/35345:
total_loss: 0.45371538400650024


Epoch 1/5:  99%|█████████▉| 35000/35345 [1:06:59<00:39,  8.69it/s]

Batch 35000/35345:
total_loss: 0.41046589612960815


Epoch 1/5: 100%|█████████▉| 35250/35345 [1:07:28<00:10,  8.73it/s]

Batch 35250/35345:
total_loss: 0.27960872650146484


Epoch 1/5: 100%|██████████| 35345/35345 [1:07:40<00:00,  8.70it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]


Epoch Summary:
Train Total Loss: 0.5192
Val Total Loss: 0.2533
Learning Rate: 0.000100
New best model saved with val loss: 0.2533
--------------------------------------------------



Epoch 2/5:   1%|          | 250/35345 [00:29<1:08:48,  8.50it/s]

Batch 250/35345:
total_loss: 0.2915315330028534


Epoch 2/5:   1%|▏         | 500/35345 [00:58<1:07:19,  8.63it/s]

Batch 500/35345:
total_loss: 0.3305414915084839


Epoch 2/5:   2%|▏         | 750/35345 [01:26<1:06:36,  8.66it/s]

Batch 750/35345:
total_loss: 0.3456326723098755


Epoch 2/5:   3%|▎         | 1000/35345 [01:55<1:04:40,  8.85it/s]

Batch 1000/35345:
total_loss: 0.25384315848350525


Epoch 2/5:   4%|▎         | 1250/35345 [02:23<1:05:00,  8.74it/s]

Batch 1250/35345:
total_loss: 0.32047390937805176


Epoch 2/5:   4%|▍         | 1500/35345 [02:52<1:04:41,  8.72it/s]

Batch 1500/35345:
total_loss: 0.4527130722999573


Epoch 2/5:   5%|▍         | 1750/35345 [03:20<1:04:10,  8.72it/s]

Batch 1750/35345:
total_loss: 0.33546528220176697


Epoch 2/5:   6%|▌         | 2000/35345 [03:50<1:03:55,  8.69it/s]

Batch 2000/35345:
total_loss: 0.42650488018989563


Epoch 2/5:   6%|▋         | 2251/35345 [04:19<1:03:31,  8.68it/s]

Batch 2250/35345:
total_loss: 0.30946221947669983


Epoch 2/5:   7%|▋         | 2500/35345 [04:47<1:02:54,  8.70it/s]

Batch 2500/35345:
total_loss: 0.45018985867500305


Epoch 2/5:   8%|▊         | 2750/35345 [05:16<1:02:31,  8.69it/s]

Batch 2750/35345:
total_loss: 0.3247793912887573


Epoch 2/5:   8%|▊         | 3000/35345 [05:45<1:02:11,  8.67it/s]

Batch 3000/35345:
total_loss: 0.29222169518470764


Epoch 2/5:   9%|▉         | 3250/35345 [06:14<1:01:41,  8.67it/s]

Batch 3250/35345:
total_loss: 0.3140348494052887


Epoch 2/5:  10%|▉         | 3500/35345 [06:43<1:01:38,  8.61it/s]

Batch 3500/35345:
total_loss: 0.21801084280014038


Epoch 2/5:  11%|█         | 3750/35345 [07:12<1:00:55,  8.64it/s]

Batch 3750/35345:
total_loss: 0.3049382269382477


Epoch 2/5:  11%|█▏        | 4001/35345 [07:40<1:00:41,  8.61it/s]

Batch 4000/35345:
total_loss: 0.38083234429359436


Epoch 2/5:  12%|█▏        | 4250/35345 [08:09<59:48,  8.67it/s]

Batch 4250/35345:
total_loss: 0.3700900673866272


Epoch 2/5:  13%|█▎        | 4500/35345 [08:38<59:18,  8.67it/s]

Batch 4500/35345:
total_loss: 0.2375403493642807


Epoch 2/5:  13%|█▎        | 4750/35345 [09:07<59:11,  8.61it/s]

Batch 4750/35345:
total_loss: 0.25102558732032776


Epoch 2/5:  14%|█▍        | 5000/35345 [09:35<57:53,  8.73it/s]

Batch 5000/35345:
total_loss: 0.3466677665710449


Epoch 2/5:  15%|█▍        | 5250/35345 [10:04<57:59,  8.65it/s]

Batch 5250/35345:
total_loss: 0.2521390914916992


Epoch 2/5:  16%|█▌        | 5500/35345 [10:33<56:59,  8.73it/s]

Batch 5500/35345:
total_loss: 0.3789118826389313


Epoch 2/5:  16%|█▋        | 5750/35345 [11:02<57:38,  8.56it/s]

Batch 5750/35345:
total_loss: 0.2533651292324066


Epoch 2/5:  17%|█▋        | 6000/35345 [11:31<55:58,  8.74it/s]

Batch 6000/35345:
total_loss: 0.23821447789669037


Epoch 2/5:  18%|█▊        | 6250/35345 [12:00<56:05,  8.65it/s]

Batch 6250/35345:
total_loss: 0.18325044214725494


Epoch 2/5:  18%|█▊        | 6500/35345 [12:29<55:01,  8.74it/s]

Batch 6500/35345:
total_loss: 0.27543485164642334


Epoch 2/5:  19%|█▉        | 6750/35345 [12:57<54:09,  8.80it/s]

Batch 6750/35345:
total_loss: 0.3708803653717041


Epoch 2/5:  20%|█▉        | 7000/35345 [13:26<54:29,  8.67it/s]

Batch 7000/35345:
total_loss: 0.42300015687942505


Epoch 2/5:  21%|██        | 7251/35345 [13:55<56:00,  8.36it/s]

Batch 7250/35345:
total_loss: 0.29140686988830566


Epoch 2/5:  21%|██        | 7500/35345 [14:23<53:27,  8.68it/s]

Batch 7500/35345:
total_loss: 0.44690990447998047


Epoch 2/5:  22%|██▏       | 7750/35345 [14:52<53:00,  8.68it/s]

Batch 7750/35345:
total_loss: 0.33931878209114075


Epoch 2/5:  23%|██▎       | 8000/35345 [15:21<52:16,  8.72it/s]

Batch 8000/35345:
total_loss: 0.23095497488975525


Epoch 2/5:  23%|██▎       | 8250/35345 [15:50<52:32,  8.59it/s]

Batch 8250/35345:
total_loss: 0.37828007340431213


Epoch 2/5:  24%|██▍       | 8500/35345 [16:19<52:13,  8.57it/s]

Batch 8500/35345:
total_loss: 0.16734758019447327


Epoch 2/5:  25%|██▍       | 8750/35345 [16:48<51:16,  8.64it/s]

Batch 8750/35345:
total_loss: 0.38540172576904297


Epoch 2/5:  25%|██▌       | 9000/35345 [17:17<50:26,  8.70it/s]

Batch 9000/35345:
total_loss: 0.3518279790878296


Epoch 2/5:  26%|██▌       | 9250/35345 [17:46<49:55,  8.71it/s]

Batch 9250/35345:
total_loss: 0.28624722361564636


Epoch 2/5:  27%|██▋       | 9501/35345 [18:15<50:29,  8.53it/s]

Batch 9500/35345:
total_loss: 0.4173251688480377


Epoch 2/5:  28%|██▊       | 9751/35345 [18:44<48:52,  8.73it/s]

Batch 9750/35345:
total_loss: 0.27635321021080017


Epoch 2/5:  28%|██▊       | 10000/35345 [19:13<49:00,  8.62it/s]

Batch 10000/35345:
total_loss: 0.2806066572666168


Epoch 2/5:  29%|██▉       | 10250/35345 [19:42<48:00,  8.71it/s]

Batch 10250/35345:
total_loss: 0.3959089517593384


Epoch 2/5:  30%|██▉       | 10500/35345 [20:11<47:31,  8.71it/s]

Batch 10500/35345:
total_loss: 0.4361545145511627


Epoch 2/5:  30%|███       | 10750/35345 [20:40<47:02,  8.71it/s]

Batch 10750/35345:
total_loss: 0.3741002380847931


Epoch 2/5:  31%|███       | 11001/35345 [21:09<46:24,  8.74it/s]

Batch 11000/35345:
total_loss: 0.32708871364593506


Epoch 2/5:  32%|███▏      | 11250/35345 [21:37<46:07,  8.71it/s]

Batch 11250/35345:
total_loss: 0.33949533104896545


Epoch 2/5:  33%|███▎      | 11500/35345 [22:06<45:36,  8.71it/s]

Batch 11500/35345:
total_loss: 0.3400016129016876


Epoch 2/5:  33%|███▎      | 11750/35345 [22:35<45:17,  8.68it/s]

Batch 11750/35345:
total_loss: 0.24951843917369843


Epoch 2/5:  34%|███▍      | 12000/35345 [23:04<44:46,  8.69it/s]

Batch 12000/35345:
total_loss: 0.28425195813179016


Epoch 2/5:  35%|███▍      | 12250/35345 [23:33<44:04,  8.73it/s]

Batch 12250/35345:
total_loss: 0.16851717233657837


Epoch 2/5:  35%|███▌      | 12500/35345 [24:02<44:39,  8.52it/s]

Batch 12500/35345:
total_loss: 0.25577235221862793


Epoch 2/5:  36%|███▌      | 12750/35345 [24:31<43:20,  8.69it/s]

Batch 12750/35345:
total_loss: 0.29987531900405884


Epoch 2/5:  37%|███▋      | 13000/35345 [24:59<43:02,  8.65it/s]

Batch 13000/35345:
total_loss: 0.42988401651382446


Epoch 2/5:  37%|███▋      | 13250/35345 [25:29<42:21,  8.69it/s]

Batch 13250/35345:
total_loss: 0.22770178318023682


Epoch 2/5:  38%|███▊      | 13500/35345 [25:58<41:23,  8.80it/s]

Batch 13500/35345:
total_loss: 0.2286805361509323


Epoch 2/5:  39%|███▉      | 13750/35345 [26:27<41:43,  8.63it/s]

Batch 13750/35345:
total_loss: 0.1984795480966568


Epoch 2/5:  40%|███▉      | 14000/35345 [26:55<40:55,  8.69it/s]

Batch 14000/35345:
total_loss: 0.44798287749290466


Epoch 2/5:  40%|████      | 14250/35345 [27:24<40:25,  8.70it/s]

Batch 14250/35345:
total_loss: 0.3591800928115845


Epoch 2/5:  41%|████      | 14501/35345 [27:53<40:32,  8.57it/s]

Batch 14500/35345:
total_loss: 0.3408309817314148


Epoch 2/5:  42%|████▏     | 14750/35345 [28:22<39:31,  8.68it/s]

Batch 14750/35345:
total_loss: 0.24907566606998444


Epoch 2/5:  42%|████▏     | 15000/35345 [28:51<39:03,  8.68it/s]

Batch 15000/35345:
total_loss: 0.3100699186325073


Epoch 2/5:  43%|████▎     | 15251/35345 [29:20<38:20,  8.74it/s]

Batch 15250/35345:
total_loss: 0.292671263217926


Epoch 2/5:  44%|████▍     | 15500/35345 [29:49<37:43,  8.77it/s]

Batch 15500/35345:
total_loss: 0.2441597431898117


Epoch 2/5:  45%|████▍     | 15750/35345 [30:17<37:44,  8.65it/s]

Batch 15750/35345:
total_loss: 0.22674044966697693


Epoch 2/5:  45%|████▌     | 16000/35345 [30:46<37:22,  8.63it/s]

Batch 16000/35345:
total_loss: 0.2758038341999054


Epoch 2/5:  46%|████▌     | 16250/35345 [31:15<36:22,  8.75it/s]

Batch 16250/35345:
total_loss: 0.31284019351005554


Epoch 2/5:  47%|████▋     | 16500/35345 [31:44<36:18,  8.65it/s]

Batch 16500/35345:
total_loss: 0.19202759861946106


Epoch 2/5:  47%|████▋     | 16750/35345 [32:13<36:01,  8.60it/s]

Batch 16750/35345:
total_loss: 0.12393329292535782


Epoch 2/5:  48%|████▊     | 17000/35345 [32:41<34:55,  8.76it/s]

Batch 17000/35345:
total_loss: 0.3764950931072235


Epoch 2/5:  49%|████▉     | 17250/35345 [33:10<34:19,  8.79it/s]

Batch 17250/35345:
total_loss: 0.19388440251350403


Epoch 2/5:  50%|████▉     | 17500/35345 [33:38<34:08,  8.71it/s]

Batch 17500/35345:
total_loss: 0.16998682916164398


Epoch 2/5:  50%|█████     | 17751/35345 [34:07<34:05,  8.60it/s]

Batch 17750/35345:
total_loss: 0.1621275246143341


Epoch 2/5:  51%|█████     | 18000/35345 [34:36<33:12,  8.71it/s]

Batch 18000/35345:
total_loss: 0.31658563017845154


Epoch 2/5:  52%|█████▏    | 18250/35345 [35:04<32:37,  8.73it/s]

Batch 18250/35345:
total_loss: 0.31501510739326477


Epoch 2/5:  52%|█████▏    | 18500/35345 [35:33<32:24,  8.66it/s]

Batch 18500/35345:
total_loss: 0.14337126910686493


Epoch 2/5:  53%|█████▎    | 18750/35345 [36:02<31:34,  8.76it/s]

Batch 18750/35345:
total_loss: 0.19648151099681854


Epoch 2/5:  54%|█████▍    | 19001/35345 [36:31<31:54,  8.54it/s]

Batch 19000/35345:
total_loss: 0.30688267946243286


Epoch 2/5:  54%|█████▍    | 19250/35345 [36:59<30:26,  8.81it/s]

Batch 19250/35345:
total_loss: 0.28059151768684387


Epoch 2/5:  55%|█████▌    | 19500/35345 [37:28<30:09,  8.76it/s]

Batch 19500/35345:
total_loss: 0.3817252516746521


Epoch 2/5:  56%|█████▌    | 19750/35345 [37:57<29:47,  8.72it/s]

Batch 19750/35345:
total_loss: 0.24235688149929047


Epoch 2/5:  57%|█████▋    | 20000/35345 [38:26<29:24,  8.70it/s]

Batch 20000/35345:
total_loss: 0.23222528398036957


Epoch 2/5:  57%|█████▋    | 20250/35345 [38:54<28:46,  8.74it/s]

Batch 20250/35345:
total_loss: 0.21184778213500977


Epoch 2/5:  58%|█████▊    | 20500/35345 [39:23<28:35,  8.65it/s]

Batch 20500/35345:
total_loss: 0.4276920258998871


Epoch 2/5:  59%|█████▊    | 20750/35345 [39:52<27:53,  8.72it/s]

Batch 20750/35345:
total_loss: 0.366703599691391


Epoch 2/5:  59%|█████▉    | 21000/35345 [40:21<27:36,  8.66it/s]

Batch 21000/35345:
total_loss: 0.20559512078762054


Epoch 2/5:  60%|██████    | 21250/35345 [40:49<27:10,  8.65it/s]

Batch 21250/35345:
total_loss: 0.3018095791339874


Epoch 2/5:  61%|██████    | 21500/35345 [41:18<26:32,  8.69it/s]

Batch 21500/35345:
total_loss: 0.2764597237110138


Epoch 2/5:  62%|██████▏   | 21750/35345 [41:47<25:50,  8.77it/s]

Batch 21750/35345:
total_loss: 0.12204007804393768


Epoch 2/5:  62%|██████▏   | 22000/35345 [42:16<25:41,  8.66it/s]

Batch 22000/35345:
total_loss: 0.31182947754859924


Epoch 2/5:  63%|██████▎   | 22250/35345 [42:45<25:05,  8.70it/s]

Batch 22250/35345:
total_loss: 0.256146639585495


Epoch 2/5:  64%|██████▎   | 22501/35345 [43:14<25:01,  8.56it/s]

Batch 22500/35345:
total_loss: 0.28176575899124146


Epoch 2/5:  64%|██████▍   | 22750/35345 [43:42<24:13,  8.66it/s]

Batch 22750/35345:
total_loss: 0.18342077732086182


Epoch 2/5:  65%|██████▌   | 23000/35345 [44:11<23:45,  8.66it/s]

Batch 23000/35345:
total_loss: 0.34504789113998413


Epoch 2/5:  66%|██████▌   | 23250/35345 [44:40<23:21,  8.63it/s]

Batch 23250/35345:
total_loss: 0.16455073654651642


Epoch 2/5:  66%|██████▋   | 23500/35345 [45:09<23:28,  8.41it/s]

Batch 23500/35345:
total_loss: 0.2864595949649811


Epoch 2/5:  67%|██████▋   | 23750/35345 [45:38<22:29,  8.59it/s]

Batch 23750/35345:
total_loss: 0.3070562779903412


Epoch 2/5:  68%|██████▊   | 24000/35345 [46:07<21:55,  8.62it/s]

Batch 24000/35345:
total_loss: 0.20880770683288574


Epoch 2/5:  69%|██████▊   | 24250/35345 [46:36<21:25,  8.63it/s]

Batch 24250/35345:
total_loss: 0.1774384081363678


Epoch 2/5:  69%|██████▉   | 24500/35345 [47:05<20:39,  8.75it/s]

Batch 24500/35345:
total_loss: 0.2992947995662689


Epoch 2/5:  70%|███████   | 24750/35345 [47:34<20:16,  8.71it/s]

Batch 24750/35345:
total_loss: 0.26378756761550903


Epoch 2/5:  71%|███████   | 25000/35345 [48:02<19:51,  8.68it/s]

Batch 25000/35345:
total_loss: 0.2863904535770416


Epoch 2/5:  71%|███████▏  | 25250/35345 [48:31<19:22,  8.68it/s]

Batch 25250/35345:
total_loss: 0.28041312098503113


Epoch 2/5:  72%|███████▏  | 25500/35345 [49:00<19:08,  8.58it/s]

Batch 25500/35345:
total_loss: 0.31292298436164856


Epoch 2/5:  73%|███████▎  | 25750/35345 [49:29<18:17,  8.74it/s]

Batch 25750/35345:
total_loss: 0.2696428894996643


Epoch 2/5:  74%|███████▎  | 26001/35345 [49:57<18:00,  8.65it/s]

Batch 26000/35345:
total_loss: 0.26757973432540894


Epoch 2/5:  74%|███████▍  | 26250/35345 [50:26<17:17,  8.76it/s]

Batch 26250/35345:
total_loss: 0.3374488055706024


Epoch 2/5:  75%|███████▍  | 26500/35345 [50:55<17:02,  8.65it/s]

Batch 26500/35345:
total_loss: 0.2461402267217636


Epoch 2/5:  76%|███████▌  | 26750/35345 [51:24<16:40,  8.59it/s]

Batch 26750/35345:
total_loss: 0.1130780503153801


Epoch 2/5:  76%|███████▋  | 27000/35345 [51:52<16:01,  8.68it/s]

Batch 27000/35345:
total_loss: 0.21125978231430054


Epoch 2/5:  77%|███████▋  | 27250/35345 [52:21<15:29,  8.71it/s]

Batch 27250/35345:
total_loss: 0.11760004609823227


Epoch 2/5:  78%|███████▊  | 27500/35345 [52:50<15:00,  8.71it/s]

Batch 27500/35345:
total_loss: 0.32400059700012207


Epoch 2/5:  79%|███████▊  | 27750/35345 [53:19<14:45,  8.57it/s]

Batch 27750/35345:
total_loss: 0.1788526028394699


Epoch 2/5:  79%|███████▉  | 28000/35345 [53:48<14:07,  8.66it/s]

Batch 28000/35345:
total_loss: 0.24773308634757996


Epoch 2/5:  80%|███████▉  | 28250/35345 [54:17<13:44,  8.61it/s]

Batch 28250/35345:
total_loss: 0.17946946620941162


Epoch 2/5:  81%|████████  | 28500/35345 [54:46<13:17,  8.58it/s]

Batch 28500/35345:
total_loss: 0.33225828409194946


Epoch 2/5:  81%|████████▏ | 28750/35345 [55:15<12:52,  8.54it/s]

Batch 28750/35345:
total_loss: 0.1768617480993271


Epoch 2/5:  82%|████████▏ | 29000/35345 [55:44<12:08,  8.71it/s]

Batch 29000/35345:
total_loss: 0.20206166803836823


Epoch 2/5:  83%|████████▎ | 29250/35345 [56:13<11:49,  8.59it/s]

Batch 29250/35345:
total_loss: 0.33203190565109253


Epoch 2/5:  83%|████████▎ | 29500/35345 [56:42<11:18,  8.62it/s]

Batch 29500/35345:
total_loss: 0.2809961438179016


Epoch 2/5:  84%|████████▍ | 29750/35345 [57:11<10:40,  8.74it/s]

Batch 29750/35345:
total_loss: 0.2272346019744873


Epoch 2/5:  85%|████████▍ | 30000/35345 [57:40<10:14,  8.70it/s]

Batch 30000/35345:
total_loss: 0.22899332642555237


Epoch 2/5:  86%|████████▌ | 30250/35345 [58:09<09:54,  8.56it/s]

Batch 30250/35345:
total_loss: 0.34236347675323486


Epoch 2/5:  86%|████████▋ | 30500/35345 [58:37<09:17,  8.69it/s]

Batch 30500/35345:
total_loss: 0.15375977754592896


Epoch 2/5:  87%|████████▋ | 30751/35345 [59:06<08:50,  8.67it/s]

Batch 30750/35345:
total_loss: 0.19616921246051788


Epoch 2/5:  88%|████████▊ | 31000/35345 [59:35<08:21,  8.66it/s]

Batch 31000/35345:
total_loss: 0.3293546140193939


Epoch 2/5:  88%|████████▊ | 31250/35345 [1:00:04<07:54,  8.62it/s]

Batch 31250/35345:
total_loss: 0.2964771091938019


Epoch 2/5:  89%|████████▉ | 31500/35345 [1:00:33<07:22,  8.69it/s]

Batch 31500/35345:
total_loss: 0.32833993434906006


Epoch 2/5:  90%|████████▉ | 31750/35345 [1:01:02<06:55,  8.66it/s]

Batch 31750/35345:
total_loss: 0.1926048994064331


Epoch 2/5:  91%|█████████ | 32001/35345 [1:01:31<06:28,  8.60it/s]

Batch 32000/35345:
total_loss: 0.19776269793510437


Epoch 2/5:  91%|█████████ | 32250/35345 [1:01:59<05:58,  8.64it/s]

Batch 32250/35345:
total_loss: 0.22972962260246277


Epoch 2/5:  92%|█████████▏| 32500/35345 [1:02:28<05:31,  8.58it/s]

Batch 32500/35345:
total_loss: 0.3466702103614807


Epoch 2/5:  93%|█████████▎| 32750/35345 [1:02:57<04:57,  8.72it/s]

Batch 32750/35345:
total_loss: 0.12242483347654343


Epoch 2/5:  93%|█████████▎| 33000/35345 [1:03:26<04:31,  8.64it/s]

Batch 33000/35345:
total_loss: 0.23866987228393555


Epoch 2/5:  94%|█████████▍| 33250/35345 [1:03:55<04:00,  8.72it/s]

Batch 33250/35345:
total_loss: 0.15671730041503906


Epoch 2/5:  95%|█████████▍| 33500/35345 [1:04:23<03:32,  8.69it/s]

Batch 33500/35345:
total_loss: 0.2871347963809967


Epoch 2/5:  95%|█████████▌| 33751/35345 [1:04:52<03:04,  8.65it/s]

Batch 33750/35345:
total_loss: 0.40761473774909973


Epoch 2/5:  96%|█████████▌| 34000/35345 [1:05:21<02:33,  8.75it/s]

Batch 34000/35345:
total_loss: 0.2504594027996063


Epoch 2/5:  97%|█████████▋| 34250/35345 [1:05:50<02:06,  8.65it/s]

Batch 34250/35345:
total_loss: 0.1743326336145401


Epoch 2/5:  98%|█████████▊| 34500/35345 [1:06:19<01:37,  8.70it/s]

Batch 34500/35345:
total_loss: 0.3197016716003418


Epoch 2/5:  98%|█████████▊| 34750/35345 [1:06:48<01:07,  8.78it/s]

Batch 34750/35345:
total_loss: 0.18793100118637085


Epoch 2/5:  99%|█████████▉| 35000/35345 [1:07:17<00:39,  8.72it/s]

Batch 35000/35345:
total_loss: 0.09664501994848251


Epoch 2/5: 100%|█████████▉| 35250/35345 [1:07:45<00:10,  8.67it/s]

Batch 35250/35345:
total_loss: 0.07575275003910065


Epoch 2/5: 100%|██████████| 35345/35345 [1:07:57<00:00,  8.67it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.26it/s]


Epoch Summary:
Train Total Loss: 0.2795
Val Total Loss: 0.1776
Learning Rate: 0.000100
New best model saved with val loss: 0.1776
--------------------------------------------------



Epoch 3/5:   1%|          | 250/35345 [00:29<1:08:04,  8.59it/s]

Batch 250/35345:
total_loss: 0.2724069058895111


Epoch 3/5:   1%|▏         | 500/35345 [00:58<1:06:56,  8.67it/s]

Batch 500/35345:
total_loss: 0.20035706460475922


Epoch 3/5:   2%|▏         | 751/35345 [01:27<1:06:04,  8.73it/s]

Batch 750/35345:
total_loss: 0.2426607608795166


Epoch 3/5:   3%|▎         | 1000/35345 [01:56<1:06:39,  8.59it/s]

Batch 1000/35345:
total_loss: 0.13474544882774353


Epoch 3/5:   4%|▎         | 1251/35345 [02:25<1:07:15,  8.45it/s]

Batch 1250/35345:
total_loss: 0.20791147649288177


Epoch 3/5:   4%|▍         | 1500/35345 [02:54<1:05:26,  8.62it/s]

Batch 1500/35345:
total_loss: 0.17109546065330505


Epoch 3/5:   5%|▍         | 1750/35345 [03:23<1:04:45,  8.65it/s]

Batch 1750/35345:
total_loss: 0.20792201161384583


Epoch 3/5:   6%|▌         | 2000/35345 [03:52<1:04:13,  8.65it/s]

Batch 2000/35345:
total_loss: 0.16744066774845123


Epoch 3/5:   6%|▋         | 2250/35345 [04:21<1:03:42,  8.66it/s]

Batch 2250/35345:
total_loss: 0.2539869248867035


Epoch 3/5:   7%|▋         | 2501/35345 [04:50<1:03:40,  8.60it/s]

Batch 2500/35345:
total_loss: 0.1636509895324707


Epoch 3/5:   8%|▊         | 2750/35345 [05:19<1:02:40,  8.67it/s]

Batch 2750/35345:
total_loss: 0.23335634171962738


Epoch 3/5:   8%|▊         | 3000/35345 [05:48<1:02:05,  8.68it/s]

Batch 3000/35345:
total_loss: 0.19146308302879333


Epoch 3/5:   9%|▉         | 3250/35345 [06:17<1:02:33,  8.55it/s]

Batch 3250/35345:
total_loss: 0.11450322717428207


Epoch 3/5:  10%|▉         | 3500/35345 [06:46<1:02:19,  8.52it/s]

Batch 3500/35345:
total_loss: 0.12484023720026016


Epoch 3/5:  11%|█         | 3750/35345 [07:15<1:05:06,  8.09it/s]

Batch 3750/35345:
total_loss: 0.20723651349544525


Epoch 3/5:  11%|█▏        | 4000/35345 [07:44<1:00:55,  8.57it/s]

Batch 4000/35345:
total_loss: 0.15536373853683472


Epoch 3/5:  12%|█▏        | 4250/35345 [08:13<59:40,  8.69it/s]

Batch 4250/35345:
total_loss: 0.31899386644363403


Epoch 3/5:  13%|█▎        | 4500/35345 [08:43<59:48,  8.60it/s]

Batch 4500/35345:
total_loss: 0.14090940356254578


Epoch 3/5:  13%|█▎        | 4750/35345 [09:12<59:01,  8.64it/s]

Batch 4750/35345:
total_loss: 0.2518308162689209


Epoch 3/5:  14%|█▍        | 5000/35345 [09:40<58:12,  8.69it/s]

Batch 5000/35345:
total_loss: 0.18200179934501648


Epoch 3/5:  15%|█▍        | 5250/35345 [10:09<57:46,  8.68it/s]

Batch 5250/35345:
total_loss: 0.14905431866645813


Epoch 3/5:  16%|█▌        | 5500/35345 [10:38<58:05,  8.56it/s]

Batch 5500/35345:
total_loss: 0.17456716299057007


Epoch 3/5:  16%|█▋        | 5750/35345 [11:07<56:40,  8.70it/s]

Batch 5750/35345:
total_loss: 0.2492297738790512


Epoch 3/5:  17%|█▋        | 6000/35345 [11:36<56:05,  8.72it/s]

Batch 6000/35345:
total_loss: 0.2873014807701111


Epoch 3/5:  18%|█▊        | 6250/35345 [12:04<55:30,  8.74it/s]

Batch 6250/35345:
total_loss: 0.2528197467327118


Epoch 3/5:  18%|█▊        | 6500/35345 [12:33<55:24,  8.68it/s]

Batch 6500/35345:
total_loss: 0.2534494698047638


Epoch 3/5:  19%|█▉        | 6750/35345 [13:02<54:46,  8.70it/s]

Batch 6750/35345:
total_loss: 0.27452990412712097


Epoch 3/5:  20%|█▉        | 7000/35345 [13:31<54:09,  8.72it/s]

Batch 7000/35345:
total_loss: 0.23455065488815308


Epoch 3/5:  21%|██        | 7250/35345 [14:00<53:47,  8.71it/s]

Batch 7250/35345:
total_loss: 0.18761655688285828


Epoch 3/5:  21%|██        | 7501/35345 [14:29<53:04,  8.74it/s]

Batch 7500/35345:
total_loss: 0.26819300651550293


Epoch 3/5:  22%|██▏       | 7750/35345 [14:57<53:14,  8.64it/s]

Batch 7750/35345:
total_loss: 0.2721019387245178


Epoch 3/5:  23%|██▎       | 8000/35345 [15:26<52:26,  8.69it/s]

Batch 8000/35345:
total_loss: 0.14419396221637726


Epoch 3/5:  23%|██▎       | 8250/35345 [15:55<52:12,  8.65it/s]

Batch 8250/35345:
total_loss: 0.2245546579360962


Epoch 3/5:  24%|██▍       | 8500/35345 [16:24<1:22:38,  5.41it/s]

Batch 8500/35345:
total_loss: 0.19565416872501373


Epoch 3/5:  25%|██▍       | 8750/35345 [16:53<51:16,  8.64it/s]

Batch 8750/35345:
total_loss: 0.23072269558906555


Epoch 3/5:  25%|██▌       | 9000/35345 [17:22<51:09,  8.58it/s]

Batch 9000/35345:
total_loss: 0.31577688455581665


Epoch 3/5:  26%|██▌       | 9250/35345 [17:51<50:35,  8.60it/s]

Batch 9250/35345:
total_loss: 0.08221776783466339


Epoch 3/5:  27%|██▋       | 9500/35345 [18:19<50:06,  8.60it/s]

Batch 9500/35345:
total_loss: 0.29410219192504883


Epoch 3/5:  28%|██▊       | 9750/35345 [18:48<48:57,  8.71it/s]

Batch 9750/35345:
total_loss: 0.3086695075035095


Epoch 3/5:  28%|██▊       | 10000/35345 [19:17<48:47,  8.66it/s]

Batch 10000/35345:
total_loss: 0.14174197614192963


Epoch 3/5:  29%|██▉       | 10250/35345 [19:46<48:01,  8.71it/s]

Batch 10250/35345:
total_loss: 0.17192897200584412


Epoch 3/5:  30%|██▉       | 10500/35345 [20:15<47:29,  8.72it/s]

Batch 10500/35345:
total_loss: 0.23929192125797272


Epoch 3/5:  30%|███       | 10750/35345 [20:44<47:33,  8.62it/s]

Batch 10750/35345:
total_loss: 0.13835503160953522


Epoch 3/5:  31%|███       | 11000/35345 [21:13<47:11,  8.60it/s]

Batch 11000/35345:
total_loss: 0.1896415501832962


Epoch 3/5:  32%|███▏      | 11250/35345 [21:42<46:52,  8.57it/s]

Batch 11250/35345:
total_loss: 0.14875838160514832


Epoch 3/5:  33%|███▎      | 11500/35345 [22:11<46:28,  8.55it/s]

Batch 11500/35345:
total_loss: 0.10533978790044785


Epoch 3/5:  33%|███▎      | 11750/35345 [22:40<45:22,  8.67it/s]

Batch 11750/35345:
total_loss: 0.13926377892494202


Epoch 3/5:  34%|███▍      | 12000/35345 [23:09<44:52,  8.67it/s]

Batch 12000/35345:
total_loss: 0.09664889425039291


Epoch 3/5:  35%|███▍      | 12251/35345 [23:38<44:06,  8.73it/s]

Batch 12250/35345:
total_loss: 0.1392734795808792


Epoch 3/5:  35%|███▌      | 12500/35345 [24:07<44:06,  8.63it/s]

Batch 12500/35345:
total_loss: 0.21124474704265594


Epoch 3/5:  36%|███▌      | 12750/35345 [24:37<43:46,  8.60it/s]

Batch 12750/35345:
total_loss: 0.17196224629878998


Epoch 3/5:  37%|███▋      | 13000/35345 [25:06<43:53,  8.49it/s]

Batch 13000/35345:
total_loss: 0.2602313160896301


Epoch 3/5:  37%|███▋      | 13250/35345 [25:35<42:19,  8.70it/s]

Batch 13250/35345:
total_loss: 0.3799281120300293


Epoch 3/5:  38%|███▊      | 13500/35345 [26:04<42:39,  8.53it/s]

Batch 13500/35345:
total_loss: 0.1439221352338791


Epoch 3/5:  39%|███▉      | 13750/35345 [26:33<41:46,  8.61it/s]

Batch 13750/35345:
total_loss: 0.17331989109516144


Epoch 3/5:  40%|███▉      | 14000/35345 [27:02<40:55,  8.69it/s]

Batch 14000/35345:
total_loss: 0.28958097100257874


Epoch 3/5:  40%|████      | 14250/35345 [27:31<40:48,  8.61it/s]

Batch 14250/35345:
total_loss: 0.1985456347465515


Epoch 3/5:  41%|████      | 14501/35345 [28:00<40:02,  8.67it/s]

Batch 14500/35345:
total_loss: 0.14405958354473114


Epoch 3/5:  42%|████▏     | 14750/35345 [28:29<39:50,  8.62it/s]

Batch 14750/35345:
total_loss: 0.23470638692378998


Epoch 3/5:  42%|████▏     | 15000/35345 [28:58<39:00,  8.69it/s]

Batch 15000/35345:
total_loss: 0.25804415345191956


Epoch 3/5:  43%|████▎     | 15250/35345 [29:27<38:31,  8.69it/s]

Batch 15250/35345:
total_loss: 0.42011716961860657


Epoch 3/5:  44%|████▍     | 15500/35345 [29:56<38:21,  8.62it/s]

Batch 15500/35345:
total_loss: 0.2956583797931671


Epoch 3/5:  45%|████▍     | 15750/35345 [30:25<37:49,  8.63it/s]

Batch 15750/35345:
total_loss: 0.23042485117912292


Epoch 3/5:  45%|████▌     | 16000/35345 [30:53<36:54,  8.73it/s]

Batch 16000/35345:
total_loss: 0.0882086306810379


Epoch 3/5:  46%|████▌     | 16250/35345 [31:22<36:19,  8.76it/s]

Batch 16250/35345:
total_loss: 0.1381002515554428


Epoch 3/5:  47%|████▋     | 16500/35345 [31:51<36:00,  8.72it/s]

Batch 16500/35345:
total_loss: 0.194815531373024


Epoch 3/5:  47%|████▋     | 16750/35345 [32:20<35:30,  8.73it/s]

Batch 16750/35345:
total_loss: 0.13412156701087952


Epoch 3/5:  48%|████▊     | 17000/35345 [32:48<34:59,  8.74it/s]

Batch 17000/35345:
total_loss: 0.17637382447719574


Epoch 3/5:  49%|████▉     | 17250/35345 [33:17<34:46,  8.67it/s]

Batch 17250/35345:
total_loss: 0.14361561834812164


Epoch 3/5:  50%|████▉     | 17500/35345 [33:46<34:03,  8.73it/s]

Batch 17500/35345:
total_loss: 0.2122785747051239


Epoch 3/5:  50%|█████     | 17750/35345 [34:14<33:17,  8.81it/s]

Batch 17750/35345:
total_loss: 0.14327460527420044


Epoch 3/5:  51%|█████     | 18000/35345 [34:43<33:03,  8.74it/s]

Batch 18000/35345:
total_loss: 0.10461012274026871


Epoch 3/5:  52%|█████▏    | 18250/35345 [35:12<33:24,  8.53it/s]

Batch 18250/35345:
total_loss: 0.15304654836654663


Epoch 3/5:  52%|█████▏    | 18500/35345 [35:41<32:24,  8.66it/s]

Batch 18500/35345:
total_loss: 0.1921895295381546


Epoch 3/5:  53%|█████▎    | 18751/35345 [36:10<31:26,  8.79it/s]

Batch 18750/35345:
total_loss: 0.1716439425945282


Epoch 3/5:  54%|█████▍    | 19000/35345 [36:38<31:17,  8.71it/s]

Batch 19000/35345:
total_loss: 0.19588962197303772


Epoch 3/5:  54%|█████▍    | 19250/35345 [37:07<30:59,  8.66it/s]

Batch 19250/35345:
total_loss: 0.10338188707828522


Epoch 3/5:  55%|█████▌    | 19500/35345 [37:36<30:21,  8.70it/s]

Batch 19500/35345:
total_loss: 0.22120317816734314


Epoch 3/5:  56%|█████▌    | 19750/35345 [38:05<29:48,  8.72it/s]

Batch 19750/35345:
total_loss: 0.14864259958267212


Epoch 3/5:  57%|█████▋    | 20001/35345 [38:34<29:33,  8.65it/s]

Batch 20000/35345:
total_loss: 0.119456447660923


Epoch 3/5:  57%|█████▋    | 20250/35345 [39:02<28:37,  8.79it/s]

Batch 20250/35345:
total_loss: 0.2619079351425171


Epoch 3/5:  58%|█████▊    | 20501/35345 [39:31<28:49,  8.58it/s]

Batch 20500/35345:
total_loss: 0.18592938780784607


Epoch 3/5:  59%|█████▊    | 20750/35345 [40:00<28:01,  8.68it/s]

Batch 20750/35345:
total_loss: 0.21316379308700562


Epoch 3/5:  59%|█████▉    | 21000/35345 [40:29<27:28,  8.70it/s]

Batch 21000/35345:
total_loss: 0.1958429515361786


Epoch 3/5:  60%|██████    | 21251/35345 [40:57<26:59,  8.70it/s]

Batch 21250/35345:
total_loss: 0.15123164653778076


Epoch 3/5:  61%|██████    | 21500/35345 [41:26<26:39,  8.66it/s]

Batch 21500/35345:
total_loss: 0.17360809445381165


Epoch 3/5:  62%|██████▏   | 21750/35345 [41:55<26:01,  8.71it/s]

Batch 21750/35345:
total_loss: 0.23791907727718353


Epoch 3/5:  62%|██████▏   | 22000/35345 [42:24<25:29,  8.73it/s]

Batch 22000/35345:
total_loss: 0.24801364541053772


Epoch 3/5:  63%|██████▎   | 22250/35345 [42:52<25:00,  8.72it/s]

Batch 22250/35345:
total_loss: 0.2104354053735733


Epoch 3/5:  64%|██████▎   | 22501/35345 [43:22<24:36,  8.70it/s]

Batch 22500/35345:
total_loss: 0.06506110727787018


Epoch 3/5:  64%|██████▍   | 22750/35345 [43:50<23:50,  8.81it/s]

Batch 22750/35345:
total_loss: 0.17065457999706268


Epoch 3/5:  65%|██████▌   | 23000/35345 [44:19<23:28,  8.76it/s]

Batch 23000/35345:
total_loss: 0.15767326951026917


Epoch 3/5:  66%|██████▌   | 23250/35345 [44:48<23:05,  8.73it/s]

Batch 23250/35345:
total_loss: 0.20855669677257538


Epoch 3/5:  66%|██████▋   | 23500/35345 [45:16<22:55,  8.61it/s]

Batch 23500/35345:
total_loss: 0.14467793703079224


Epoch 3/5:  67%|██████▋   | 23750/35345 [45:45<22:17,  8.67it/s]

Batch 23750/35345:
total_loss: 0.21748405694961548


Epoch 3/5:  68%|██████▊   | 24000/35345 [46:14<21:36,  8.75it/s]

Batch 24000/35345:
total_loss: 0.17406590282917023


Epoch 3/5:  69%|██████▊   | 24250/35345 [46:43<21:24,  8.64it/s]

Batch 24250/35345:
total_loss: 0.16009356081485748


Epoch 3/5:  69%|██████▉   | 24500/35345 [47:11<20:35,  8.78it/s]

Batch 24500/35345:
total_loss: 0.2898392975330353


Epoch 3/5:  70%|███████   | 24750/35345 [47:40<20:06,  8.78it/s]

Batch 24750/35345:
total_loss: 0.05740617960691452


Epoch 3/5:  71%|███████   | 25001/35345 [48:09<19:48,  8.70it/s]

Batch 25000/35345:
total_loss: 0.19250470399856567


Epoch 3/5:  71%|███████▏  | 25250/35345 [48:38<19:22,  8.68it/s]

Batch 25250/35345:
total_loss: 0.22762411832809448


Epoch 3/5:  72%|███████▏  | 25500/35345 [49:06<18:56,  8.66it/s]

Batch 25500/35345:
total_loss: 0.12625837326049805


Epoch 3/5:  73%|███████▎  | 25750/35345 [49:35<18:21,  8.71it/s]

Batch 25750/35345:
total_loss: 0.2312401980161667


Epoch 3/5:  74%|███████▎  | 26000/35345 [50:04<17:51,  8.72it/s]

Batch 26000/35345:
total_loss: 0.1214708462357521


Epoch 3/5:  74%|███████▍  | 26250/35345 [50:33<17:18,  8.76it/s]

Batch 26250/35345:
total_loss: 0.15647917985916138


Epoch 3/5:  75%|███████▍  | 26500/35345 [51:01<16:46,  8.79it/s]

Batch 26500/35345:
total_loss: 0.15844205021858215


Epoch 3/5:  76%|███████▌  | 26750/35345 [51:30<16:46,  8.54it/s]

Batch 26750/35345:
total_loss: 0.2042236626148224


Epoch 3/5:  76%|███████▋  | 27000/35345 [51:59<15:51,  8.77it/s]

Batch 27000/35345:
total_loss: 0.3014007806777954


Epoch 3/5:  77%|███████▋  | 27250/35345 [52:27<15:24,  8.75it/s]

Batch 27250/35345:
total_loss: 0.1910601556301117


Epoch 3/5:  78%|███████▊  | 27501/35345 [52:56<15:06,  8.66it/s]

Batch 27500/35345:
total_loss: 0.20494484901428223


Epoch 3/5:  79%|███████▊  | 27750/35345 [53:25<14:33,  8.69it/s]

Batch 27750/35345:
total_loss: 0.08263251185417175


Epoch 3/5:  79%|███████▉  | 28001/35345 [53:54<14:11,  8.63it/s]

Batch 28000/35345:
total_loss: 0.10602555423974991


Epoch 3/5:  80%|███████▉  | 28250/35345 [54:23<13:41,  8.64it/s]

Batch 28250/35345:
total_loss: 0.3341173827648163


Epoch 3/5:  81%|████████  | 28500/35345 [54:52<13:01,  8.75it/s]

Batch 28500/35345:
total_loss: 0.15596894919872284


Epoch 3/5:  81%|████████▏ | 28750/35345 [55:21<12:35,  8.73it/s]

Batch 28750/35345:
total_loss: 0.10891053825616837


Epoch 3/5:  82%|████████▏ | 29000/35345 [55:50<12:01,  8.80it/s]

Batch 29000/35345:
total_loss: 0.07739847153425217


Epoch 3/5:  83%|████████▎ | 29250/35345 [56:18<11:44,  8.65it/s]

Batch 29250/35345:
total_loss: 0.15945136547088623


Epoch 3/5:  83%|████████▎ | 29500/35345 [56:47<11:14,  8.66it/s]

Batch 29500/35345:
total_loss: 0.26007020473480225


Epoch 3/5:  84%|████████▍ | 29751/35345 [57:16<10:48,  8.62it/s]

Batch 29750/35345:
total_loss: 0.17951512336730957


Epoch 3/5:  85%|████████▍ | 30000/35345 [57:44<10:15,  8.69it/s]

Batch 30000/35345:
total_loss: 0.20112933218479156


Epoch 3/5:  86%|████████▌ | 30250/35345 [58:13<09:44,  8.72it/s]

Batch 30250/35345:
total_loss: 0.1254984587430954


Epoch 3/5:  86%|████████▋ | 30500/35345 [58:42<09:20,  8.64it/s]

Batch 30500/35345:
total_loss: 0.275168776512146


Epoch 3/5:  87%|████████▋ | 30750/35345 [59:11<08:45,  8.74it/s]

Batch 30750/35345:
total_loss: 0.2749919593334198


Epoch 3/5:  88%|████████▊ | 31000/35345 [59:39<08:17,  8.74it/s]

Batch 31000/35345:
total_loss: 0.29024556279182434


Epoch 3/5:  88%|████████▊ | 31250/35345 [1:00:08<07:52,  8.66it/s]

Batch 31250/35345:
total_loss: 0.05405344441533089


Epoch 3/5:  89%|████████▉ | 31500/35345 [1:00:37<07:19,  8.76it/s]

Batch 31500/35345:
total_loss: 0.1608670949935913


Epoch 3/5:  90%|████████▉ | 31751/35345 [1:01:06<06:54,  8.68it/s]

Batch 31750/35345:
total_loss: 0.08079253137111664


Epoch 3/5:  91%|█████████ | 32000/35345 [1:01:35<06:23,  8.72it/s]

Batch 32000/35345:
total_loss: 0.18143662810325623


Epoch 3/5:  91%|█████████ | 32250/35345 [1:02:03<05:54,  8.72it/s]

Batch 32250/35345:
total_loss: 0.184737890958786


Epoch 3/5:  92%|█████████▏| 32500/35345 [1:02:32<05:27,  8.69it/s]

Batch 32500/35345:
total_loss: 0.14454883337020874


Epoch 3/5:  93%|█████████▎| 32750/35345 [1:03:01<04:55,  8.77it/s]

Batch 32750/35345:
total_loss: 0.12073580175638199


Epoch 3/5:  93%|█████████▎| 33001/35345 [1:03:29<04:32,  8.61it/s]

Batch 33000/35345:
total_loss: 0.23163974285125732


Epoch 3/5:  94%|█████████▍| 33250/35345 [1:03:58<03:59,  8.75it/s]

Batch 33250/35345:
total_loss: 0.2577945590019226


Epoch 3/5:  95%|█████████▍| 33500/35345 [1:04:27<03:31,  8.71it/s]

Batch 33500/35345:
total_loss: 0.07115046679973602


Epoch 3/5:  95%|█████████▌| 33750/35345 [1:04:56<03:00,  8.82it/s]

Batch 33750/35345:
total_loss: 0.1853562742471695


Epoch 3/5:  96%|█████████▌| 34000/35345 [1:05:25<02:34,  8.71it/s]

Batch 34000/35345:
total_loss: 0.15503273904323578


Epoch 3/5:  97%|█████████▋| 34251/35345 [1:05:54<02:06,  8.61it/s]

Batch 34250/35345:
total_loss: 0.12015627324581146


Epoch 3/5:  98%|█████████▊| 34500/35345 [1:06:22<01:37,  8.68it/s]

Batch 34500/35345:
total_loss: 0.09447314590215683


Epoch 3/5:  98%|█████████▊| 34750/35345 [1:06:51<01:08,  8.64it/s]

Batch 34750/35345:
total_loss: 0.08066671341657639


Epoch 3/5:  99%|█████████▉| 35000/35345 [1:07:20<00:39,  8.63it/s]

Batch 35000/35345:
total_loss: 0.1358250230550766


Epoch 3/5: 100%|█████████▉| 35250/35345 [1:07:49<00:10,  8.65it/s]

Batch 35250/35345:
total_loss: 0.24225115776062012


Epoch 3/5: 100%|██████████| 35345/35345 [1:08:01<00:00,  8.66it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.45it/s]


Epoch Summary:
Train Total Loss: 0.1854
Val Total Loss: 0.0838
Learning Rate: 0.000100
New best model saved with val loss: 0.0838
--------------------------------------------------



Epoch 4/5:   1%|          | 250/35345 [00:29<1:09:39,  8.40it/s]

Batch 250/35345:
total_loss: 0.18178583681583405


Epoch 4/5:   1%|▏         | 500/35345 [00:58<1:06:32,  8.73it/s]

Batch 500/35345:
total_loss: 0.13306133449077606


Epoch 4/5:   2%|▏         | 750/35345 [01:26<1:05:53,  8.75it/s]

Batch 750/35345:
total_loss: 0.09000534564256668


Epoch 4/5:   3%|▎         | 1000/35345 [01:55<1:05:51,  8.69it/s]

Batch 1000/35345:
total_loss: 0.21920326352119446


Epoch 4/5:   4%|▎         | 1250/35345 [02:24<1:05:50,  8.63it/s]

Batch 1250/35345:
total_loss: 0.12117370218038559


Epoch 4/5:   4%|▍         | 1500/35345 [02:53<1:05:12,  8.65it/s]

Batch 1500/35345:
total_loss: 0.08633583784103394


Epoch 4/5:   5%|▍         | 1750/35345 [03:22<1:05:26,  8.56it/s]

Batch 1750/35345:
total_loss: 0.097286157310009


Epoch 4/5:   6%|▌         | 2000/35345 [03:51<1:05:01,  8.55it/s]

Batch 2000/35345:
total_loss: 0.21319285035133362


Epoch 4/5:   6%|▋         | 2250/35345 [04:20<1:03:37,  8.67it/s]

Batch 2250/35345:
total_loss: 0.12428736686706543


Epoch 4/5:   7%|▋         | 2500/35345 [04:49<1:03:22,  8.64it/s]

Batch 2500/35345:
total_loss: 0.22055375576019287


Epoch 4/5:   8%|▊         | 2750/35345 [05:18<1:02:23,  8.71it/s]

Batch 2750/35345:
total_loss: 0.0948365107178688


Epoch 4/5:   8%|▊         | 2816/35345 [05:25<1:02:19,  8.70it/s]


Equation: C*x5/(C*x3+C)+C


Epoch 4/5:   8%|▊         | 3000/35345 [05:47<1:01:43,  8.73it/s]

Batch 3000/35345:
total_loss: 0.11168920993804932


Epoch 4/5:   9%|▉         | 3250/35345 [06:15<1:01:20,  8.72it/s]

Batch 3250/35345:
total_loss: 0.12781424820423126


Epoch 4/5:  10%|▉         | 3500/35345 [06:44<1:00:41,  8.74it/s]

Batch 3500/35345:
total_loss: 0.16386273503303528


Epoch 4/5:  11%|█         | 3750/35345 [07:13<1:01:24,  8.57it/s]

Batch 3750/35345:
total_loss: 0.09796541184186935


Epoch 4/5:  11%|█▏        | 4000/35345 [07:42<1:00:26,  8.64it/s]

Batch 4000/35345:
total_loss: 0.34496062994003296


Epoch 4/5:  12%|█▏        | 4250/35345 [08:11<1:00:02,  8.63it/s]

Batch 4250/35345:
total_loss: 0.12534192204475403


Epoch 4/5:  13%|█▎        | 4500/35345 [08:40<59:03,  8.70it/s]

Batch 4500/35345:
total_loss: 0.08467650413513184


Epoch 4/5:  13%|█▎        | 4750/35345 [09:09<59:27,  8.58it/s]

Batch 4750/35345:
total_loss: 0.18506339192390442


Epoch 4/5:  14%|█▍        | 5000/35345 [09:38<58:55,  8.58it/s]

Batch 5000/35345:
total_loss: 0.2256341129541397


Epoch 4/5:  15%|█▍        | 5250/35345 [10:07<57:57,  8.65it/s]

Batch 5250/35345:
total_loss: 0.07829352468252182


Epoch 4/5:  16%|█▌        | 5500/35345 [10:36<57:06,  8.71it/s]

Batch 5500/35345:
total_loss: 0.14555636048316956


Epoch 4/5:  16%|█▋        | 5750/35345 [11:05<56:24,  8.75it/s]

Batch 5750/35345:
total_loss: 0.18119126558303833


Epoch 4/5:  17%|█▋        | 6000/35345 [11:34<56:50,  8.60it/s]

Batch 6000/35345:
total_loss: 0.19081002473831177


Epoch 4/5:  18%|█▊        | 6250/35345 [12:03<56:07,  8.64it/s]

Batch 6250/35345:
total_loss: 0.07914742082357407


Epoch 4/5:  18%|█▊        | 6500/35345 [12:32<55:58,  8.59it/s]

Batch 6500/35345:
total_loss: 0.1941278725862503


Epoch 4/5:  19%|█▉        | 6750/35345 [13:01<54:55,  8.68it/s]

Batch 6750/35345:
total_loss: 0.1528046429157257


Epoch 4/5:  20%|█▉        | 7000/35345 [13:30<55:10,  8.56it/s]

Batch 7000/35345:
total_loss: 0.14195707440376282


Epoch 4/5:  21%|██        | 7250/35345 [13:59<53:55,  8.68it/s]

Batch 7250/35345:
total_loss: 0.22386467456817627


Epoch 4/5:  21%|██        | 7501/35345 [14:29<53:54,  8.61it/s]

Batch 7500/35345:
total_loss: 0.11082957684993744


Epoch 4/5:  22%|██▏       | 7751/35345 [14:57<53:02,  8.67it/s]

Batch 7750/35345:
total_loss: 0.20269370079040527


Epoch 4/5:  23%|██▎       | 8000/35345 [15:26<52:55,  8.61it/s]

Batch 8000/35345:
total_loss: 0.16354723274707794


Epoch 4/5:  23%|██▎       | 8250/35345 [15:55<52:11,  8.65it/s]

Batch 8250/35345:
total_loss: 0.20832817256450653


Epoch 4/5:  24%|██▍       | 8500/35345 [16:24<51:36,  8.67it/s]

Batch 8500/35345:
total_loss: 0.1628185361623764


Epoch 4/5:  25%|██▍       | 8750/35345 [16:53<50:57,  8.70it/s]

Batch 8750/35345:
total_loss: 0.09043203294277191


Epoch 4/5:  25%|██▌       | 9000/35345 [17:22<51:11,  8.58it/s]

Batch 9000/35345:
total_loss: 0.20139802992343903


Epoch 4/5:  26%|██▌       | 9250/35345 [17:51<49:52,  8.72it/s]

Batch 9250/35345:
total_loss: 0.1646346002817154


Epoch 4/5:  27%|██▋       | 9500/35345 [18:20<49:14,  8.75it/s]

Batch 9500/35345:
total_loss: 0.10078026354312897


Epoch 4/5:  28%|██▊       | 9750/35345 [18:49<48:57,  8.71it/s]

Batch 9750/35345:
total_loss: 0.20537486672401428


Epoch 4/5:  28%|██▊       | 10001/35345 [19:17<48:51,  8.65it/s]

Batch 10000/35345:
total_loss: 0.26507675647735596


Epoch 4/5:  29%|██▉       | 10251/35345 [19:46<48:24,  8.64it/s]

Batch 10250/35345:
total_loss: 0.07627232372760773


Epoch 4/5:  30%|██▉       | 10500/35345 [20:15<47:26,  8.73it/s]

Batch 10500/35345:
total_loss: 0.21888116002082825


Epoch 4/5:  30%|███       | 10750/35345 [20:44<47:28,  8.63it/s]

Batch 10750/35345:
total_loss: 0.11877613514661789


Epoch 4/5:  31%|███       | 11000/35345 [21:13<46:18,  8.76it/s]

Batch 11000/35345:
total_loss: 0.17990373075008392


Epoch 4/5:  32%|███▏      | 11250/35345 [21:41<45:54,  8.75it/s]

Batch 11250/35345:
total_loss: 0.2001056969165802


Epoch 4/5:  33%|███▎      | 11500/35345 [22:10<45:20,  8.77it/s]

Batch 11500/35345:
total_loss: 0.15340781211853027


Epoch 4/5:  33%|███▎      | 11750/35345 [22:39<44:29,  8.84it/s]

Batch 11750/35345:
total_loss: 0.10409484803676605


Epoch 4/5:  34%|███▍      | 12000/35345 [23:07<44:28,  8.75it/s]

Batch 12000/35345:
total_loss: 0.08202818036079407


Epoch 4/5:  35%|███▍      | 12250/35345 [23:36<49:12,  7.82it/s]

Batch 12250/35345:
total_loss: 0.0782189667224884


Epoch 4/5:  35%|███▌      | 12501/35345 [24:05<43:40,  8.72it/s]

Batch 12500/35345:
total_loss: 0.06914028525352478


Epoch 4/5:  36%|███▌      | 12750/35345 [24:34<42:59,  8.76it/s]

Batch 12750/35345:
total_loss: 0.20138631761074066


Epoch 4/5:  37%|███▋      | 13000/35345 [25:02<42:31,  8.76it/s]

Batch 13000/35345:
total_loss: 0.17507798969745636


Epoch 4/5:  37%|███▋      | 13250/35345 [25:31<42:27,  8.67it/s]

Batch 13250/35345:
total_loss: 0.2123616486787796


Epoch 4/5:  38%|███▊      | 13501/35345 [26:00<41:54,  8.69it/s]

Batch 13500/35345:
total_loss: 0.2139938473701477


Epoch 4/5:  39%|███▉      | 13750/35345 [26:29<41:02,  8.77it/s]

Batch 13750/35345:
total_loss: 0.1895061433315277


Epoch 4/5:  40%|███▉      | 14000/35345 [26:57<40:52,  8.70it/s]

Batch 14000/35345:
total_loss: 0.17539988458156586


Epoch 4/5:  40%|████      | 14250/35345 [27:26<40:32,  8.67it/s]

Batch 14250/35345:
total_loss: 0.17269307374954224


Epoch 4/5:  41%|████      | 14501/35345 [27:55<40:20,  8.61it/s]

Batch 14500/35345:
total_loss: 0.056147538125514984


Epoch 4/5:  42%|████▏     | 14751/35345 [28:24<39:55,  8.60it/s]

Batch 14750/35345:
total_loss: 0.15894992649555206


Epoch 4/5:  42%|████▏     | 15000/35345 [28:54<38:45,  8.75it/s]

Batch 15000/35345:
total_loss: 0.12417349219322205


Epoch 4/5:  43%|████▎     | 15250/35345 [29:23<38:17,  8.74it/s]

Batch 15250/35345:
total_loss: 0.09370002150535583


Epoch 4/5:  44%|████▍     | 15500/35345 [29:51<37:39,  8.78it/s]

Batch 15500/35345:
total_loss: 0.14481312036514282


Epoch 4/5:  45%|████▍     | 15750/35345 [30:20<37:03,  8.81it/s]

Batch 15750/35345:
total_loss: 0.15262135863304138


Epoch 4/5:  45%|████▌     | 16000/35345 [30:49<36:55,  8.73it/s]

Batch 16000/35345:
total_loss: 0.19008129835128784


Epoch 4/5:  46%|████▌     | 16250/35345 [31:17<36:26,  8.73it/s]

Batch 16250/35345:
total_loss: 0.13918602466583252


Epoch 4/5:  47%|████▋     | 16500/35345 [31:46<35:53,  8.75it/s]

Batch 16500/35345:
total_loss: 0.1048668697476387


Epoch 4/5:  47%|████▋     | 16750/35345 [32:15<35:51,  8.64it/s]

Batch 16750/35345:
total_loss: 0.2072250247001648


Epoch 4/5:  48%|████▊     | 17000/35345 [32:43<34:48,  8.78it/s]

Batch 17000/35345:
total_loss: 0.09082410484552383


Epoch 4/5:  49%|████▉     | 17250/35345 [33:12<34:31,  8.73it/s]

Batch 17250/35345:
total_loss: 0.15232232213020325


Epoch 4/5:  50%|████▉     | 17500/35345 [33:41<33:54,  8.77it/s]

Batch 17500/35345:
total_loss: 0.1489497870206833


Epoch 4/5:  50%|█████     | 17750/35345 [34:09<33:43,  8.69it/s]

Batch 17750/35345:
total_loss: 0.16633547842502594


Epoch 4/5:  51%|█████     | 18000/35345 [34:38<33:20,  8.67it/s]

Batch 18000/35345:
total_loss: 0.07212036103010178


Epoch 4/5:  52%|█████▏    | 18250/35345 [35:07<32:19,  8.81it/s]

Batch 18250/35345:
total_loss: 0.08947371691465378


Epoch 4/5:  52%|█████▏    | 18500/35345 [35:35<32:13,  8.71it/s]

Batch 18500/35345:
total_loss: 0.17661306262016296


Epoch 4/5:  53%|█████▎    | 18750/35345 [36:05<32:12,  8.59it/s]

Batch 18750/35345:
total_loss: 0.15051668882369995


Epoch 4/5:  54%|█████▍    | 19000/35345 [36:33<31:22,  8.68it/s]

Batch 19000/35345:
total_loss: 0.14034585654735565


Epoch 4/5:  54%|█████▍    | 19251/35345 [37:02<31:33,  8.50it/s]

Batch 19250/35345:
total_loss: 0.15037336945533752


Epoch 4/5:  55%|█████▌    | 19500/35345 [37:31<30:23,  8.69it/s]

Batch 19500/35345:
total_loss: 0.1865118443965912


Epoch 4/5:  56%|█████▌    | 19750/35345 [38:00<29:48,  8.72it/s]

Batch 19750/35345:
total_loss: 0.11101249605417252


Epoch 4/5:  57%|█████▋    | 20000/35345 [38:30<29:28,  8.68it/s]

Batch 20000/35345:
total_loss: 0.22896654903888702


Epoch 4/5:  57%|█████▋    | 20250/35345 [38:58<29:08,  8.63it/s]

Batch 20250/35345:
total_loss: 0.07827162742614746


Epoch 4/5:  58%|█████▊    | 20500/35345 [39:27<28:44,  8.61it/s]

Batch 20500/35345:
total_loss: 0.12216497212648392


Epoch 4/5:  59%|█████▊    | 20750/35345 [39:56<28:09,  8.64it/s]

Batch 20750/35345:
total_loss: 0.1541125774383545


Epoch 4/5:  59%|█████▉    | 21000/35345 [40:25<27:51,  8.58it/s]

Batch 21000/35345:
total_loss: 0.11739745736122131


Epoch 4/5:  60%|██████    | 21250/35345 [40:54<26:59,  8.70it/s]

Batch 21250/35345:
total_loss: 0.1530533879995346


Epoch 4/5:  61%|██████    | 21500/35345 [41:23<26:37,  8.66it/s]

Batch 21500/35345:
total_loss: 0.2700955867767334


Epoch 4/5:  62%|██████▏   | 21750/35345 [41:52<26:07,  8.67it/s]

Batch 21750/35345:
total_loss: 0.18103963136672974


Epoch 4/5:  62%|██████▏   | 22000/35345 [42:21<25:31,  8.71it/s]

Batch 22000/35345:
total_loss: 0.09664411097764969


Epoch 4/5:  63%|██████▎   | 22250/35345 [42:50<24:57,  8.74it/s]

Batch 22250/35345:
total_loss: 0.028486182913184166


Epoch 4/5:  64%|██████▎   | 22500/35345 [43:18<24:34,  8.71it/s]

Batch 22500/35345:
total_loss: 0.06120553985238075


Epoch 4/5:  64%|██████▍   | 22750/35345 [43:47<24:09,  8.69it/s]

Batch 22750/35345:
total_loss: 0.13191843032836914


Epoch 4/5:  65%|██████▌   | 23000/35345 [44:16<23:24,  8.79it/s]

Batch 23000/35345:
total_loss: 0.09788063168525696


Epoch 4/5:  66%|██████▌   | 23250/35345 [44:44<23:15,  8.67it/s]

Batch 23250/35345:
total_loss: 0.05189535394310951


Epoch 4/5:  66%|██████▋   | 23500/35345 [45:13<22:49,  8.65it/s]

Batch 23500/35345:
total_loss: 0.050961460918188095


Epoch 4/5:  67%|██████▋   | 23750/35345 [45:43<22:10,  8.72it/s]

Batch 23750/35345:
total_loss: 0.12944330275058746


Epoch 4/5:  68%|██████▊   | 24000/35345 [46:11<21:52,  8.65it/s]

Batch 24000/35345:
total_loss: 0.17749430239200592


Epoch 4/5:  69%|██████▊   | 24250/35345 [46:40<21:13,  8.71it/s]

Batch 24250/35345:
total_loss: 0.18044544756412506


Epoch 4/5:  69%|██████▉   | 24501/35345 [47:09<20:52,  8.66it/s]

Batch 24500/35345:
total_loss: 0.15684391558170319


Epoch 4/5:  70%|███████   | 24750/35345 [47:38<20:30,  8.61it/s]

Batch 24750/35345:
total_loss: 0.07582925260066986


Epoch 4/5:  71%|███████   | 25000/35345 [48:07<20:02,  8.60it/s]

Batch 25000/35345:
total_loss: 0.24707551300525665


Epoch 4/5:  71%|███████▏  | 25250/35345 [48:36<19:15,  8.74it/s]

Batch 25250/35345:
total_loss: 0.20943069458007812


Epoch 4/5:  72%|███████▏  | 25500/35345 [49:04<18:48,  8.73it/s]

Batch 25500/35345:
total_loss: 0.07795654982328415


Epoch 4/5:  73%|███████▎  | 25750/35345 [49:33<18:48,  8.50it/s]

Batch 25750/35345:
total_loss: 0.27355852723121643


Epoch 4/5:  74%|███████▎  | 26000/35345 [50:02<17:54,  8.69it/s]

Batch 26000/35345:
total_loss: 0.08717107772827148


Epoch 4/5:  74%|███████▍  | 26250/35345 [50:31<17:32,  8.64it/s]

Batch 26250/35345:
total_loss: 0.08601151406764984


Epoch 4/5:  75%|███████▍  | 26500/35345 [51:00<17:03,  8.64it/s]

Batch 26500/35345:
total_loss: 0.1604679524898529


Epoch 4/5:  76%|███████▌  | 26750/35345 [51:29<16:39,  8.60it/s]

Batch 26750/35345:
total_loss: 0.09064256399869919


Epoch 4/5:  76%|███████▋  | 27000/35345 [51:58<16:00,  8.69it/s]

Batch 27000/35345:
total_loss: 0.268564909696579


Epoch 4/5:  77%|███████▋  | 27250/35345 [52:27<15:28,  8.72it/s]

Batch 27250/35345:
total_loss: 0.21850483119487762


Epoch 4/5:  78%|███████▊  | 27500/35345 [52:56<15:10,  8.61it/s]

Batch 27500/35345:
total_loss: 0.15937018394470215


Epoch 4/5:  79%|███████▊  | 27750/35345 [53:25<14:50,  8.53it/s]

Batch 27750/35345:
total_loss: 0.23687109351158142


Epoch 4/5:  79%|███████▉  | 28000/35345 [53:54<14:13,  8.61it/s]

Batch 28000/35345:
total_loss: 0.09853793680667877


Epoch 4/5:  80%|███████▉  | 28250/35345 [54:23<13:43,  8.62it/s]

Batch 28250/35345:
total_loss: 0.18832872807979584


Epoch 4/5:  81%|████████  | 28500/35345 [54:51<13:02,  8.75it/s]

Batch 28500/35345:
total_loss: 0.1493382751941681


Epoch 4/5:  81%|████████▏ | 28750/35345 [55:20<12:54,  8.52it/s]

Batch 28750/35345:
total_loss: 0.22261637449264526


Epoch 4/5:  82%|████████▏ | 29000/35345 [55:49<12:14,  8.64it/s]

Batch 29000/35345:
total_loss: 0.14026042819023132


Epoch 4/5:  83%|████████▎ | 29250/35345 [56:18<11:41,  8.69it/s]

Batch 29250/35345:
total_loss: 0.11103474348783493


Epoch 4/5:  83%|████████▎ | 29500/35345 [56:47<11:12,  8.69it/s]

Batch 29500/35345:
total_loss: 0.1171325147151947


Epoch 4/5:  84%|████████▍ | 29750/35345 [57:16<10:47,  8.64it/s]

Batch 29750/35345:
total_loss: 0.08364398777484894


Epoch 4/5:  85%|████████▍ | 30000/35345 [57:44<10:17,  8.66it/s]

Batch 30000/35345:
total_loss: 0.0831671953201294


Epoch 4/5:  86%|████████▌ | 30250/35345 [58:13<09:58,  8.51it/s]

Batch 30250/35345:
total_loss: 0.11779814213514328


Epoch 4/5:  86%|████████▋ | 30500/35345 [58:43<09:26,  8.55it/s]

Batch 30500/35345:
total_loss: 0.18253202736377716


Epoch 4/5:  87%|████████▋ | 30750/35345 [59:12<08:53,  8.62it/s]

Batch 30750/35345:
total_loss: 0.10979336500167847


Epoch 4/5:  88%|████████▊ | 31000/35345 [59:41<08:22,  8.65it/s]

Batch 31000/35345:
total_loss: 0.2234554886817932


Epoch 4/5:  88%|████████▊ | 31250/35345 [1:00:10<07:49,  8.72it/s]

Batch 31250/35345:
total_loss: 0.1642128974199295


Epoch 4/5:  89%|████████▉ | 31500/35345 [1:00:39<07:24,  8.64it/s]

Batch 31500/35345:
total_loss: 0.10813944041728973


Epoch 4/5:  90%|████████▉ | 31750/35345 [1:01:08<06:50,  8.75it/s]

Batch 31750/35345:
total_loss: 0.20754504203796387


Epoch 4/5:  91%|█████████ | 32000/35345 [1:01:36<06:26,  8.66it/s]

Batch 32000/35345:
total_loss: 0.08567081391811371


Epoch 4/5:  91%|█████████ | 32250/35345 [1:02:06<06:00,  8.59it/s]

Batch 32250/35345:
total_loss: 0.11429708451032639


Epoch 4/5:  92%|█████████▏| 32500/35345 [1:02:35<05:27,  8.68it/s]

Batch 32500/35345:
total_loss: 0.18832139670848846


Epoch 4/5:  93%|█████████▎| 32750/35345 [1:03:03<05:30,  7.86it/s]

Batch 32750/35345:
total_loss: 0.14471866190433502


Epoch 4/5:  93%|█████████▎| 33000/35345 [1:03:32<04:31,  8.63it/s]

Batch 33000/35345:
total_loss: 0.03954415023326874


Epoch 4/5:  94%|█████████▍| 33250/35345 [1:04:01<04:04,  8.56it/s]

Batch 33250/35345:
total_loss: 0.12070807814598083


Epoch 4/5:  95%|█████████▍| 33500/35345 [1:04:30<03:31,  8.71it/s]

Batch 33500/35345:
total_loss: 0.058415476232767105


Epoch 4/5:  95%|█████████▌| 33750/35345 [1:04:59<03:04,  8.64it/s]

Batch 33750/35345:
total_loss: 0.09529924392700195


Epoch 4/5:  96%|█████████▌| 34001/35345 [1:05:28<02:33,  8.75it/s]

Batch 34000/35345:
total_loss: 0.0987706184387207


Epoch 4/5:  97%|█████████▋| 34250/35345 [1:05:57<02:06,  8.67it/s]

Batch 34250/35345:
total_loss: 0.1311001032590866


Epoch 4/5:  98%|█████████▊| 34500/35345 [1:06:25<01:37,  8.64it/s]

Batch 34500/35345:
total_loss: 0.08322770893573761


Epoch 4/5:  98%|█████████▊| 34750/35345 [1:06:54<01:09,  8.60it/s]

Batch 34750/35345:
total_loss: 0.08037807792425156


Epoch 4/5:  99%|█████████▉| 35000/35345 [1:07:23<00:39,  8.75it/s]

Batch 35000/35345:
total_loss: 0.09601246565580368


Epoch 4/5: 100%|█████████▉| 35250/35345 [1:07:53<00:11,  8.54it/s]

Batch 35250/35345:
total_loss: 0.09143015742301941


Epoch 4/5: 100%|██████████| 35345/35345 [1:08:05<00:00,  8.65it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.42it/s]


Epoch Summary:
Train Total Loss: 0.1345
Val Total Loss: 0.0774
Learning Rate: 0.000100
New best model saved with val loss: 0.0774
--------------------------------------------------



Epoch 5/5:   1%|          | 250/35345 [00:29<1:09:01,  8.47it/s]

Batch 250/35345:
total_loss: 0.16984102129936218


Epoch 5/5:   1%|▏         | 500/35345 [00:58<1:08:06,  8.53it/s]

Batch 500/35345:
total_loss: 0.15773265063762665


Epoch 5/5:   2%|▏         | 750/35345 [01:27<1:06:56,  8.61it/s]

Batch 750/35345:
total_loss: 0.06299013644456863


Epoch 5/5:   3%|▎         | 1000/35345 [01:56<1:05:55,  8.68it/s]

Batch 1000/35345:
total_loss: 0.08082500845193863


Epoch 5/5:   4%|▎         | 1250/35345 [02:25<1:05:15,  8.71it/s]

Batch 1250/35345:
total_loss: 0.08919773995876312


Epoch 5/5:   4%|▍         | 1500/35345 [02:54<1:05:15,  8.64it/s]

Batch 1500/35345:
total_loss: 0.10965613275766373


Epoch 5/5:   5%|▍         | 1750/35345 [03:23<1:05:07,  8.60it/s]

Batch 1750/35345:
total_loss: 0.056708645075559616


Epoch 5/5:   6%|▌         | 2001/35345 [03:52<1:04:22,  8.63it/s]

Batch 2000/35345:
total_loss: 0.07317689061164856


Epoch 5/5:   6%|▋         | 2250/35345 [04:21<1:03:31,  8.68it/s]

Batch 2250/35345:
total_loss: 0.07964646816253662


Epoch 5/5:   7%|▋         | 2500/35345 [04:50<1:03:31,  8.62it/s]

Batch 2500/35345:
total_loss: 0.1052543967962265


Epoch 5/5:   8%|▊         | 2750/35345 [05:19<1:03:46,  8.52it/s]

Batch 2750/35345:
total_loss: 0.17300696671009064


Epoch 5/5:   8%|▊         | 3000/35345 [05:48<1:02:34,  8.62it/s]

Batch 3000/35345:
total_loss: 0.07106547057628632


Epoch 5/5:   9%|▉         | 3251/35345 [06:17<1:01:39,  8.68it/s]

Batch 3250/35345:
total_loss: 0.029682254418730736


Epoch 5/5:  10%|▉         | 3500/35345 [06:45<1:02:11,  8.53it/s]

Batch 3500/35345:
total_loss: 0.11643759906291962


Epoch 5/5:  11%|█         | 3750/35345 [07:14<1:00:35,  8.69it/s]

Batch 3750/35345:
total_loss: 0.09904549270868301


Epoch 5/5:  11%|█▏        | 4000/35345 [07:43<1:00:00,  8.71it/s]

Batch 4000/35345:
total_loss: 0.10707471519708633


Epoch 5/5:  12%|█▏        | 4250/35345 [08:12<59:58,  8.64it/s]  

Batch 4250/35345:
total_loss: 0.16582868993282318


Epoch 5/5:  13%|█▎        | 4501/35345 [08:41<59:33,  8.63it/s]

Batch 4500/35345:
total_loss: 0.05930206924676895


Epoch 5/5:  13%|█▎        | 4750/35345 [09:10<58:38,  8.70it/s]

Batch 4750/35345:
total_loss: 0.28302812576293945


Epoch 5/5:  14%|█▍        | 5000/35345 [09:39<58:37,  8.63it/s]

Batch 5000/35345:
total_loss: 0.12605655193328857


Epoch 5/5:  15%|█▍        | 5250/35345 [10:08<57:57,  8.65it/s]

Batch 5250/35345:
total_loss: 0.09077901393175125


Epoch 5/5:  16%|█▌        | 5500/35345 [10:37<58:18,  8.53it/s]

Batch 5500/35345:
total_loss: 0.11206191033124924


Epoch 5/5:  16%|█▋        | 5750/35345 [11:07<56:53,  8.67it/s]

Batch 5750/35345:
total_loss: 0.05184096097946167


Epoch 5/5:  17%|█▋        | 6000/35345 [11:36<56:51,  8.60it/s]

Batch 6000/35345:
total_loss: 0.11101612448692322


Epoch 5/5:  18%|█▊        | 6250/35345 [12:05<55:42,  8.70it/s]

Batch 6250/35345:
total_loss: 0.08885625004768372


Epoch 5/5:  18%|█▊        | 6500/35345 [12:33<55:32,  8.66it/s]

Batch 6500/35345:
total_loss: 0.08913195133209229


Epoch 5/5:  19%|█▉        | 6751/35345 [13:02<56:16,  8.47it/s]

Batch 6750/35345:
total_loss: 0.11345508694648743


Epoch 5/5:  20%|█▉        | 7000/35345 [13:31<53:06,  8.90it/s]

Batch 7000/35345:
total_loss: 0.17529720067977905


Epoch 5/5:  21%|██        | 7250/35345 [13:59<54:21,  8.62it/s]

Batch 7250/35345:
total_loss: 0.09005599468946457


Epoch 5/5:  21%|██        | 7500/35345 [14:29<54:19,  8.54it/s]

Batch 7500/35345:
total_loss: 0.1586053967475891


Epoch 5/5:  22%|██▏       | 7750/35345 [14:57<52:36,  8.74it/s]

Batch 7750/35345:
total_loss: 0.14561320841312408


Epoch 5/5:  23%|██▎       | 8000/35345 [15:26<52:52,  8.62it/s]

Batch 8000/35345:
total_loss: 0.15322257578372955


Epoch 5/5:  23%|██▎       | 8250/35345 [15:55<52:05,  8.67it/s]

Batch 8250/35345:
total_loss: 0.11681853979825974


Epoch 5/5:  24%|██▍       | 8500/35345 [16:24<51:24,  8.70it/s]

Batch 8500/35345:
total_loss: 0.06207403913140297


Epoch 5/5:  25%|██▍       | 8750/35345 [16:53<51:51,  8.55it/s]

Batch 8750/35345:
total_loss: 0.06757830083370209


Epoch 5/5:  25%|██▌       | 9000/35345 [17:22<51:01,  8.60it/s]

Batch 9000/35345:
total_loss: 0.06392349302768707


Epoch 5/5:  26%|██▌       | 9250/35345 [17:50<50:05,  8.68it/s]

Batch 9250/35345:
total_loss: 0.12323203682899475


Epoch 5/5:  27%|██▋       | 9500/35345 [18:20<50:03,  8.60it/s]

Batch 9500/35345:
total_loss: 0.09563902765512466


Epoch 5/5:  28%|██▊       | 9750/35345 [18:49<49:11,  8.67it/s]

Batch 9750/35345:
total_loss: 0.05469777435064316


Epoch 5/5:  28%|██▊       | 10000/35345 [19:18<48:59,  8.62it/s]

Batch 10000/35345:
total_loss: 0.08623380213975906


Epoch 5/5:  29%|██▉       | 10250/35345 [19:46<48:10,  8.68it/s]

Batch 10250/35345:
total_loss: 0.09561063349246979


Epoch 5/5:  30%|██▉       | 10500/35345 [20:16<47:29,  8.72it/s]

Batch 10500/35345:
total_loss: 0.11528313159942627


Epoch 5/5:  30%|███       | 10750/35345 [20:44<47:19,  8.66it/s]

Batch 10750/35345:
total_loss: 0.10817214101552963


Epoch 5/5:  31%|███       | 11000/35345 [21:13<46:28,  8.73it/s]

Batch 11000/35345:
total_loss: 0.20387327671051025


Epoch 5/5:  32%|███▏      | 11250/35345 [21:42<46:41,  8.60it/s]

Batch 11250/35345:
total_loss: 0.05395762249827385


Epoch 5/5:  33%|███▎      | 11500/35345 [22:11<46:02,  8.63it/s]

Batch 11500/35345:
total_loss: 0.10818614810705185


Epoch 5/5:  33%|███▎      | 11750/35345 [22:40<44:52,  8.76it/s]

Batch 11750/35345:
total_loss: 0.152547225356102


Epoch 5/5:  34%|███▍      | 12000/35345 [23:09<45:09,  8.61it/s]

Batch 12000/35345:
total_loss: 0.15982495248317719


Epoch 5/5:  35%|███▍      | 12250/35345 [23:38<45:03,  8.54it/s]

Batch 12250/35345:
total_loss: 0.14631307125091553


Epoch 5/5:  35%|███▌      | 12500/35345 [24:07<43:41,  8.71it/s]

Batch 12500/35345:
total_loss: 0.1117371991276741


Epoch 5/5:  36%|███▌      | 12750/35345 [24:35<43:10,  8.72it/s]

Batch 12750/35345:
total_loss: 0.14870385825634003


Epoch 5/5:  37%|███▋      | 13000/35345 [25:05<43:04,  8.65it/s]

Batch 13000/35345:
total_loss: 0.23911522328853607


Epoch 5/5:  37%|███▋      | 13250/35345 [25:34<43:03,  8.55it/s]

Batch 13250/35345:
total_loss: 0.1662146896123886


Epoch 5/5:  38%|███▊      | 13500/35345 [26:03<41:33,  8.76it/s]

Batch 13500/35345:
total_loss: 0.11176227033138275


Epoch 5/5:  39%|███▉      | 13750/35345 [26:32<41:16,  8.72it/s]

Batch 13750/35345:
total_loss: 0.13076981902122498


Epoch 5/5:  40%|███▉      | 14000/35345 [27:01<43:14,  8.23it/s]

Batch 14000/35345:
total_loss: 0.13508211076259613


Epoch 5/5:  40%|████      | 14250/35345 [27:30<41:01,  8.57it/s]

Batch 14250/35345:
total_loss: 0.11480264365673065


Epoch 5/5:  41%|████      | 14500/35345 [27:59<39:54,  8.71it/s]

Batch 14500/35345:
total_loss: 0.09142819792032242


Epoch 5/5:  42%|████▏     | 14750/35345 [28:28<39:57,  8.59it/s]

Batch 14750/35345:
total_loss: 0.09562301635742188


Epoch 5/5:  42%|████▏     | 15000/35345 [28:57<39:07,  8.67it/s]

Batch 15000/35345:
total_loss: 0.21199388802051544


Epoch 5/5:  43%|████▎     | 15250/35345 [29:26<38:50,  8.62it/s]

Batch 15250/35345:
total_loss: 0.0888240784406662


Epoch 5/5:  44%|████▍     | 15500/35345 [29:55<38:27,  8.60it/s]

Batch 15500/35345:
total_loss: 0.015066352672874928


Epoch 5/5:  45%|████▍     | 15750/35345 [30:23<37:51,  8.63it/s]

Batch 15750/35345:
total_loss: 0.13544130325317383


Epoch 5/5:  45%|████▌     | 16000/35345 [30:52<37:12,  8.67it/s]

Batch 16000/35345:
total_loss: 0.03129519894719124


Epoch 5/5:  46%|████▌     | 16250/35345 [31:21<36:51,  8.64it/s]

Batch 16250/35345:
total_loss: 0.12854310870170593


Epoch 5/5:  47%|████▋     | 16501/35345 [31:50<37:18,  8.42it/s]

Batch 16500/35345:
total_loss: 0.15843349695205688


Epoch 5/5:  47%|████▋     | 16750/35345 [32:19<35:32,  8.72it/s]

Batch 16750/35345:
total_loss: 0.19585373997688293


Epoch 5/5:  48%|████▊     | 17000/35345 [32:48<34:55,  8.75it/s]

Batch 17000/35345:
total_loss: 0.18214181065559387


Epoch 5/5:  49%|████▉     | 17250/35345 [33:17<35:25,  8.52it/s]

Batch 17250/35345:
total_loss: 0.09668697416782379


Epoch 5/5:  50%|████▉     | 17500/35345 [33:46<34:33,  8.61it/s]

Batch 17500/35345:
total_loss: 0.08406971395015717


Epoch 5/5:  50%|█████     | 17750/35345 [34:15<34:11,  8.58it/s]

Batch 17750/35345:
total_loss: 0.21773993968963623


Epoch 5/5:  51%|█████     | 18000/35345 [34:44<33:20,  8.67it/s]

Batch 18000/35345:
total_loss: 0.09140888601541519


Epoch 5/5:  52%|█████▏    | 18250/35345 [35:13<32:39,  8.72it/s]

Batch 18250/35345:
total_loss: 0.06554369628429413


Epoch 5/5:  52%|█████▏    | 18500/35345 [35:42<32:46,  8.57it/s]

Batch 18500/35345:
total_loss: 0.06456868350505829


Epoch 5/5:  53%|█████▎    | 18750/35345 [36:11<31:51,  8.68it/s]

Batch 18750/35345:
total_loss: 0.15534579753875732


Epoch 5/5:  54%|█████▍    | 19000/35345 [36:40<31:38,  8.61it/s]

Batch 19000/35345:
total_loss: 0.16909460723400116


Epoch 5/5:  54%|█████▍    | 19250/35345 [37:09<30:59,  8.65it/s]

Batch 19250/35345:
total_loss: 0.08437986671924591


Epoch 5/5:  55%|█████▌    | 19500/35345 [37:38<30:26,  8.67it/s]

Batch 19500/35345:
total_loss: 0.12966126203536987


Epoch 5/5:  56%|█████▌    | 19750/35345 [38:07<30:00,  8.66it/s]

Batch 19750/35345:
total_loss: 0.02343260496854782


Epoch 5/5:  57%|█████▋    | 20000/35345 [38:36<29:38,  8.63it/s]

Batch 20000/35345:
total_loss: 0.07071984559297562


Epoch 5/5:  57%|█████▋    | 20250/35345 [39:05<28:50,  8.72it/s]

Batch 20250/35345:
total_loss: 0.08700777590274811


Epoch 5/5:  58%|█████▊    | 20500/35345 [39:33<28:16,  8.75it/s]

Batch 20500/35345:
total_loss: 0.06484277546405792


Epoch 5/5:  59%|█████▊    | 20750/35345 [40:02<28:15,  8.61it/s]

Batch 20750/35345:
total_loss: 0.1381705105304718


Epoch 5/5:  59%|█████▉    | 21000/35345 [40:31<27:31,  8.69it/s]

Batch 21000/35345:
total_loss: 0.12345670163631439


Epoch 5/5:  60%|██████    | 21251/35345 [41:00<32:44,  7.17it/s]

Batch 21250/35345:
total_loss: 0.02808026596903801


Epoch 5/5:  61%|██████    | 21500/35345 [41:29<26:36,  8.67it/s]

Batch 21500/35345:
total_loss: 0.08844690769910812


Epoch 5/5:  62%|██████▏   | 21750/35345 [41:58<26:16,  8.62it/s]

Batch 21750/35345:
total_loss: 0.027162399142980576


Epoch 5/5:  62%|██████▏   | 22000/35345 [42:27<25:48,  8.62it/s]

Batch 22000/35345:
total_loss: 0.1578889787197113


Epoch 5/5:  63%|██████▎   | 22250/35345 [42:56<25:09,  8.67it/s]

Batch 22250/35345:
total_loss: 0.13514676690101624


Epoch 5/5:  64%|██████▎   | 22500/35345 [43:25<24:43,  8.66it/s]

Batch 22500/35345:
total_loss: 0.11765561997890472


Epoch 5/5:  64%|██████▍   | 22750/35345 [43:54<24:28,  8.58it/s]

Batch 22750/35345:
total_loss: 0.07656845450401306


Epoch 5/5:  65%|██████▌   | 23000/35345 [44:23<23:38,  8.71it/s]

Batch 23000/35345:
total_loss: 0.09467189013957977


Epoch 5/5:  66%|██████▌   | 23250/35345 [44:52<23:16,  8.66it/s]

Batch 23250/35345:
total_loss: 0.025728043168783188


Epoch 5/5:  66%|██████▋   | 23500/35345 [45:21<22:59,  8.59it/s]

Batch 23500/35345:
total_loss: 0.19586722552776337


Epoch 5/5:  67%|██████▋   | 23751/35345 [45:50<22:26,  8.61it/s]

Batch 23750/35345:
total_loss: 0.21702027320861816


Epoch 5/5:  68%|██████▊   | 24000/35345 [46:19<21:51,  8.65it/s]

Batch 24000/35345:
total_loss: 0.05507959797978401


Epoch 5/5:  69%|██████▊   | 24250/35345 [46:48<21:21,  8.66it/s]

Batch 24250/35345:
total_loss: 0.06437251716852188


Epoch 5/5:  69%|██████▉   | 24500/35345 [47:17<20:54,  8.64it/s]

Batch 24500/35345:
total_loss: 0.04281412065029144


Epoch 5/5:  70%|███████   | 24750/35345 [47:46<20:32,  8.60it/s]

Batch 24750/35345:
total_loss: 0.08572566509246826


Epoch 5/5:  71%|███████   | 25000/35345 [48:15<19:59,  8.62it/s]

Batch 25000/35345:
total_loss: 0.062449466437101364


Epoch 5/5:  71%|███████▏  | 25250/35345 [48:44<19:23,  8.68it/s]

Batch 25250/35345:
total_loss: 0.08182127773761749


Epoch 5/5:  72%|███████▏  | 25500/35345 [49:13<19:03,  8.61it/s]

Batch 25500/35345:
total_loss: 0.14521703124046326


Epoch 5/5:  73%|███████▎  | 25750/35345 [49:42<18:20,  8.72it/s]

Batch 25750/35345:
total_loss: 0.12229694426059723


Epoch 5/5:  74%|███████▎  | 26000/35345 [50:11<17:51,  8.72it/s]

Batch 26000/35345:
total_loss: 0.08103455603122711


Epoch 5/5:  74%|███████▍  | 26251/35345 [50:40<17:40,  8.57it/s]

Batch 26250/35345:
total_loss: 0.08676743507385254


Epoch 5/5:  75%|███████▍  | 26500/35345 [51:09<17:01,  8.66it/s]

Batch 26500/35345:
total_loss: 0.1079607680439949


Epoch 5/5:  76%|███████▌  | 26750/35345 [51:38<16:29,  8.69it/s]

Batch 26750/35345:
total_loss: 0.026753155514597893


Epoch 5/5:  76%|███████▋  | 27000/35345 [52:07<16:22,  8.49it/s]

Batch 27000/35345:
total_loss: 0.14621634781360626


Epoch 5/5:  77%|███████▋  | 27250/35345 [52:36<15:22,  8.78it/s]

Batch 27250/35345:
total_loss: 0.11169673502445221


Epoch 5/5:  78%|███████▊  | 27500/35345 [53:05<15:06,  8.66it/s]

Batch 27500/35345:
total_loss: 0.16062003374099731


Epoch 5/5:  79%|███████▊  | 27750/35345 [53:34<14:37,  8.65it/s]

Batch 27750/35345:
total_loss: 0.08342219889163971


Epoch 5/5:  79%|███████▉  | 28000/35345 [54:03<14:12,  8.62it/s]

Batch 28000/35345:
total_loss: 0.06156298145651817


Epoch 5/5:  80%|███████▉  | 28250/35345 [54:32<13:37,  8.68it/s]

Batch 28250/35345:
total_loss: 0.0599842295050621


Epoch 5/5:  81%|████████  | 28500/35345 [55:01<13:18,  8.57it/s]

Batch 28500/35345:
total_loss: 0.15420381724834442


Epoch 5/5:  81%|████████▏ | 28750/35345 [55:30<12:53,  8.52it/s]

Batch 28750/35345:
total_loss: 0.06347867101430893


Epoch 5/5:  82%|████████▏ | 29000/35345 [55:59<12:16,  8.61it/s]

Batch 29000/35345:
total_loss: 0.11324560642242432


Epoch 5/5:  83%|████████▎ | 29250/35345 [56:28<11:45,  8.63it/s]

Batch 29250/35345:
total_loss: 0.11991183459758759


Epoch 5/5:  83%|████████▎ | 29500/35345 [56:57<11:24,  8.54it/s]

Batch 29500/35345:
total_loss: 0.08172875642776489


Epoch 5/5:  84%|████████▍ | 29750/35345 [57:26<10:52,  8.57it/s]

Batch 29750/35345:
total_loss: 0.1578901708126068


Epoch 5/5:  85%|████████▍ | 30000/35345 [57:55<10:18,  8.64it/s]

Batch 30000/35345:
total_loss: 0.06508714705705643


Epoch 5/5:  86%|████████▌ | 30250/35345 [58:23<10:00,  8.49it/s]

Batch 30250/35345:
total_loss: 0.037047334015369415


Epoch 5/5:  86%|████████▋ | 30500/35345 [58:53<09:17,  8.69it/s]

Batch 30500/35345:
total_loss: 0.09508714824914932


Epoch 5/5:  87%|████████▋ | 30750/35345 [59:21<08:55,  8.59it/s]

Batch 30750/35345:
total_loss: 0.06261944025754929


Epoch 5/5:  88%|████████▊ | 31001/35345 [59:50<08:43,  8.29it/s]

Batch 31000/35345:
total_loss: 0.11920956522226334


Epoch 5/5:  88%|████████▊ | 31250/35345 [1:00:19<08:00,  8.52it/s]

Batch 31250/35345:
total_loss: 0.11443773657083511


Epoch 5/5:  89%|████████▉ | 31500/35345 [1:00:49<07:22,  8.68it/s]

Batch 31500/35345:
total_loss: 0.1210152730345726


Epoch 5/5:  90%|████████▉ | 31750/35345 [1:01:17<06:50,  8.76it/s]

Batch 31750/35345:
total_loss: 0.05735772103071213


Epoch 5/5:  91%|█████████ | 32000/35345 [1:01:46<06:26,  8.66it/s]

Batch 32000/35345:
total_loss: 0.06086983159184456


Epoch 5/5:  91%|█████████ | 32250/35345 [1:02:15<05:56,  8.68it/s]

Batch 32250/35345:
total_loss: 0.14058050513267517


Epoch 5/5:  92%|█████████▏| 32500/35345 [1:02:45<05:30,  8.61it/s]

Batch 32500/35345:
total_loss: 0.058618154376745224


Epoch 5/5:  93%|█████████▎| 32750/35345 [1:03:14<04:56,  8.75it/s]

Batch 32750/35345:
total_loss: 0.061470963060855865


Epoch 5/5:  93%|█████████▎| 33000/35345 [1:03:42<04:28,  8.73it/s]

Batch 33000/35345:
total_loss: 0.08086761087179184


Epoch 5/5:  94%|█████████▍| 33250/35345 [1:04:11<04:03,  8.60it/s]

Batch 33250/35345:
total_loss: 0.048014793545007706


Epoch 5/5:  95%|█████████▍| 33500/35345 [1:04:40<03:32,  8.69it/s]

Batch 33500/35345:
total_loss: 0.13959960639476776


Epoch 5/5:  95%|█████████▌| 33750/35345 [1:05:09<03:05,  8.58it/s]

Batch 33750/35345:
total_loss: 0.1156013086438179


Epoch 5/5:  96%|█████████▌| 34000/35345 [1:05:38<02:35,  8.67it/s]

Batch 34000/35345:
total_loss: 0.11061560362577438


Epoch 5/5:  97%|█████████▋| 34250/35345 [1:06:07<02:05,  8.69it/s]

Batch 34250/35345:
total_loss: 0.050433848053216934


Epoch 5/5:  98%|█████████▊| 34500/35345 [1:06:36<01:36,  8.76it/s]

Batch 34500/35345:
total_loss: 0.06458617001771927


Epoch 5/5:  98%|█████████▊| 34750/35345 [1:07:05<01:09,  8.62it/s]

Batch 34750/35345:
total_loss: 0.12457692623138428


Epoch 5/5:  99%|█████████▉| 35000/35345 [1:07:34<00:40,  8.60it/s]

Batch 35000/35345:
total_loss: 0.05625021085143089


Epoch 5/5: 100%|█████████▉| 35250/35345 [1:08:03<00:10,  8.69it/s]

Batch 35250/35345:
total_loss: 0.052457764744758606


Epoch 5/5: 100%|██████████| 35345/35345 [1:08:15<00:00,  8.63it/s]
Validating: 100%|██████████| 4/4 [00:01<00:00,  3.58it/s]


Epoch Summary:
Train Total Loss: 0.1014
Val Total Loss: 0.0673
Learning Rate: 0.000100
New best model saved with val loss: 0.0673
--------------------------------------------------



