<a href="https://colab.research.google.com/github/carolynw898/STAT946Proj/blob/main/stat946-proj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# Mount Google Drive into the Colab VM so the dataset and checkpoint paths
# under /content/drive used by later cells are reachable.
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple
from models import SymbolicGaussianDiffusion, PointNetConfig
from utils import CharDataset, processDataFiles, tokenize_equation

def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> float:
    """Run one optimisation pass over ``train_loader``.

    For every batch a random integer timestep in ``[0, timesteps)`` is drawn
    per sample, the model is called as ``model(points, tokens, variables, t)``
    and its return value is used directly as the loss to back-propagate.

    Args:
        model: Module whose forward call returns the loss; it must be a
            single-element tensor because ``.backward()`` and ``.item()``
            are called on it.
        train_loader: Iterable of ``(_, tokens, points, variables)`` batches
            that also supports ``len()``.
        optimizer: Optimiser stepped once per batch.
        train_dataset: Not used by this function; kept so existing callers
            keep working.
        timesteps: Exclusive upper bound for the sampled timesteps.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        num_epochs: Total number of epochs, used only for the label.

    Returns:
        The mean of the per-batch losses over the epoch, as a Python float.

    Raises:
        ValueError: If ``train_loader`` yields no batches (previously this
            surfaced as a ``ZeroDivisionError`` at the end of the epoch).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot compute an average loss")

    model.train()
    total_train_loss = 0.0

    for i, (_, tokens, points, variables) in tqdm.tqdm(
        enumerate(train_loader),
        total=num_batches,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    ):
        points, tokens, variables = (
            points.to(device),
            tokens.to(device),
            variables.to(device),
        )
        # One diffusion timestep per sample in the batch.
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
        optimizer.zero_grad()

        total_loss = model(points, tokens, variables, t)

        total_loss.backward()
        optimizer.step()

        # Convert once; the scalar is reused for logging and accumulation.
        batch_loss = total_loss.item()
        total_train_loss += batch_loss

        if (i + 1) % 250 == 0:
            print(f"Batch {i + 1}/{num_batches}:")
            # Log the plain scalar rather than the tensor repr.
            print(f"total_loss: {batch_loss}")

    avg_train_loss = total_train_loss / num_batches
    return avg_train_loss


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> float:
    """Compute the mean loss of ``model`` over ``val_loader`` without gradients.

    The model is put in eval mode and called exactly as during training:
    ``model(points, tokens, variables, t)`` with one random timestep per
    sample drawn from ``[0, timesteps)``.

    NOTE(review): the timesteps are drawn with ``torch.randint`` and no seed
    is set in this function, so the returned value can differ between calls
    on the same weights. Confirm this is acceptable for checkpoint selection.

    Args:
        model: Module whose forward call returns a single-element loss tensor.
        val_loader: Iterable of ``(_, tokens, points, variables)`` batches
            that also supports ``len()``.
        train_dataset: Not used by this function; kept for call compatibility.
        timesteps: Exclusive upper bound for the sampled timesteps.
        device: Device the batch tensors are moved to.
        epoch: Not used by this function; kept for call compatibility.
        num_epochs: Not used by this function; kept for call compatibility.

    Returns:
        The mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: If ``val_loader`` yields no batches (previously this
            surfaced as a ``ZeroDivisionError``).
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot compute an average loss")

    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for _, tokens, points, variables in tqdm.tqdm(
            val_loader, total=num_batches, desc="Validating"
        ):
            points, tokens, variables = (
                points.to(device),
                tokens.to(device),
                variables.to(device),
            )
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
            total_loss = model(points, tokens, variables, t)

            total_val_loss += total_loss.item()

    avg_val_loss = total_val_loss / num_batches
    return avg_val_loss


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
    path=None,
):
    """Train ``model`` on one device and checkpoint the best validation epoch.

    Each epoch runs ``train_epoch`` then ``val_epoch``, feeds the validation
    loss to a ``ReduceLROnPlateau`` scheduler (halves the learning rate after
    more than one epoch without improvement) and, whenever the validation
    loss improves, writes ``model.state_dict()`` to ``path``.

    Args:
        model: Model to train; moved to CUDA when available, otherwise CPU.
        train_dataset: Dataset wrapped in a shuffled ``DataLoader``.
        val_dataset: Dataset wrapped in a non-shuffled ``DataLoader``.
        num_epochs: Number of train/validate cycles.
        save_every: Currently unused; accepted so existing callers that pass
            it keep working.
        batch_size: Batch size for both loaders.
        timesteps: Exclusive upper bound of the timesteps sampled per batch.
        learning_rate: Initial Adam learning rate.
        path: Destination passed to ``torch.save`` for the best state dict.
            When ``None`` no checkpoint is written (previously this crashed
            inside ``torch.save`` after the first epoch).

    Returns:
        None.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the checkpoint directory up front so a missing folder is not
    # discovered only after a full epoch of training.
    if isinstance(path, (str, os.PathLike)):
        checkpoint_dir = os.path.dirname(os.path.abspath(path))
        os.makedirs(checkpoint_dir, exist_ok=True)

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    use_pinned_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=use_pinned_memory,
        shuffle=True,
        num_workers=4,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        pin_memory=use_pinned_memory,
        shuffle=False,
        num_workers=4,
    )

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )

        avg_val_loss = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(f"Train Total Loss: {avg_train_loss:.4f}")
        print(f"Val Total Loss: {avg_val_loss:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if path is not None:
                torch.save(model.state_dict(), path)
                print(f"New best model saved with val loss: {best_val_loss:.4f}")
            else:
                print(
                    f"New best val loss: {best_val_loss:.4f} "
                    "(no path given, model not saved)"
                )

        print("-" * 50)

In [8]:
# Experiment configuration shared by the cells below.
n_embd = 512  # passed as both PointNetConfig.embeddingSize and the model's n_embd
timesteps = 1000  # passed to the model and to train_single_gpu
batch_size = 64
learning_rate = 1e-4
num_epochs = 5
blockSize = 32  # second positional arg of CharDataset and the model's max_seq_len
numVars = 2
numYs = 1
numPoints = 250
target = 'Skeleton'  # forwarded to CharDataset as `target`
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]  # forwarded to CharDataset as `xRange`
decimals = 8
addVars = False
maxNumFiles = 100  # upper bound on the number of training JSON files that are read
# NOTE(review): train_single_gpu selects its own device; this global is not
# read by any cell visible here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Dataset location on the mounted Drive; later cells glob
# "<dataDir>/<dataFolder>/Train/*.json" and ".../Val/*.json".
dataDir = "/content/drive/MyDrive/Colab/STAT946_proj/data"
dataFolder = "2_var_dataset"

In [10]:
from torch.utils.data import DataLoader
import numpy as np
import glob
from utils import processDataFiles, CharDataset, tokenize_equation
import random
import json

path = '{}/{}/Train/*.json'.format(dataDir, dataFolder)
# glob returns matches in filesystem order, which is not guaranteed to be
# stable; sort first so the maxNumFiles subset is the same on every run.
files = sorted(glob.glob(path))[:maxNumFiles]
if not files:
    raise FileNotFoundError('no training files match {}'.format(path))
text = processDataFiles(files)
text = text.split('\n')  # one JSON record per line

# Build the vocabulary from every token that occurs in a skeleton equation.
skeletons = [json.loads(item)['Skeleton'] for item in text if item.strip()]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # add all integers to the token set
tokens = sorted(list(all_tokens) + ['_', 'T', '<', '>', ':'])  # special tokens

# Drop the empty string left behind by a trailing newline.
trainText = text[:-1] if len(text[-1]) == 0 else text
# Shuffle the dataset; this matters especially for the experiment that
# combines different numbers of variables.
random.shuffle(trainText)
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# Print a random sample as a sanity check.
idx = np.random.randint(len(train_dataset))
inputs, outputs, points, variables = train_dataset[idx]
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 499035 examples, 30 unique.
id:390191
outputs:C*exp(C*x2**4)+C*exp(C*x1)+C>__________
variables:2


In [11]:
path = '{}/{}/Val/*.json'.format(dataDir,dataFolder)
# Sort so that "the first validation file" is the same file on every run;
# glob alone returns matches in filesystem order.
files = sorted(glob.glob(path))
if not files:
    raise FileNotFoundError('no validation files match {}'.format(path))
textVal = processDataFiles([files[0]])
textVal = textVal.split('\n')  # one JSON record per line
# Drop the empty string left behind by a trailing newline, as the training
# cell does for its own text.
textVal = textVal[:-1] if len(textVal[-1]) == 0 else textVal
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# print a random sample
idx = np.random.randint(len(val_dataset))
inputs, outputs, points, variables = val_dataset[idx]
print(points.min(), points.max())
# NOTE(review): decoding uses train_dataset.itos; both datasets are built
# from the same `tokens` list, so the mappings are presumably identical.
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 949 examples, 30 unique.
tensor(-8.8065) tensor(6.2440)
id:438
outputs:C*x1**4+C*x1**3+C*x1**2+C*x1+C>________
variables:1


In [12]:
# Configuration object handed to the model as `tnet_config`; sized from the
# notebook-level hyperparameters.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

# Vocabulary size and padding index come from the training dataset so the
# model agrees with the token ids the dataset emits.
# NOTE(review): max_num_vars is hard-coded to 9 while numVars is 2 — confirm
# this is intentional.
model = SymbolicGaussianDiffusion(
    tnet_config=pconfig,
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
    set_transformer=True,
)

# Launch training; the state dict with the lowest validation loss is written
# to `path`. `save_every` is accepted by train_single_gpu but its body never
# reads it.
train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
    path="/content/drive/MyDrive/Colab/STAT946_proj/models/stable_diffusym/2_var_set_transformer_sd.pth"
)

Epoch 1/5:   3%|▎         | 249/7798 [00:21<10:11, 12.34it/s]

Batch 250/7798:
total_loss: 2.007009267807007


Epoch 1/5:   6%|▋         | 501/7798 [00:42<09:46, 12.43it/s]

Batch 500/7798:
total_loss: 1.6813679933547974


Epoch 1/5:  10%|▉         | 749/7798 [01:03<08:53, 13.21it/s]

Batch 750/7798:
total_loss: 1.3449336290359497


Epoch 1/5:  13%|█▎        | 999/7798 [01:24<08:45, 12.93it/s]

Batch 1000/7798:
total_loss: 0.8695368766784668


Epoch 1/5:  16%|█▌        | 1251/7798 [01:45<08:23, 12.99it/s]

Batch 1250/7798:
total_loss: 0.7282642126083374


Epoch 1/5:  19%|█▉        | 1501/7798 [02:06<09:06, 11.53it/s]

Batch 1500/7798:
total_loss: 0.8693933486938477


Epoch 1/5:  22%|██▏       | 1751/7798 [02:26<08:36, 11.71it/s]

Batch 1750/7798:
total_loss: 0.7199123501777649


Epoch 1/5:  26%|██▌       | 2001/7798 [02:47<08:19, 11.62it/s]

Batch 2000/7798:
total_loss: 0.794024646282196


Epoch 1/5:  29%|██▉       | 2251/7798 [03:07<08:19, 11.11it/s]

Batch 2250/7798:
total_loss: 0.6710669994354248


Epoch 1/5:  32%|███▏      | 2501/7798 [03:28<06:51, 12.87it/s]

Batch 2500/7798:
total_loss: 0.5889523029327393


Epoch 1/5:  35%|███▌      | 2751/7798 [03:49<08:16, 10.17it/s]

Batch 2750/7798:
total_loss: 0.6715309023857117


Epoch 1/5:  38%|███▊      | 3001/7798 [04:09<06:16, 12.73it/s]

Batch 3000/7798:
total_loss: 0.4075583517551422


Epoch 1/5:  42%|████▏     | 3251/7798 [04:30<05:58, 12.68it/s]

Batch 3250/7798:
total_loss: 0.7180449366569519


Epoch 1/5:  45%|████▍     | 3501/7798 [04:51<06:19, 11.33it/s]

Batch 3500/7798:
total_loss: 0.4229906499385834


Epoch 1/5:  48%|████▊     | 3751/7798 [05:12<05:51, 11.50it/s]

Batch 3750/7798:
total_loss: 0.4974105656147003


Epoch 1/5:  51%|█████▏    | 4001/7798 [05:32<04:55, 12.84it/s]

Batch 4000/7798:
total_loss: 0.5277514457702637


Epoch 1/5:  55%|█████▍    | 4251/7798 [05:53<04:30, 13.11it/s]

Batch 4250/7798:
total_loss: 0.6356792449951172


Epoch 1/5:  58%|█████▊    | 4501/7798 [06:14<04:42, 11.67it/s]

Batch 4500/7798:
total_loss: 0.6408905982971191


Epoch 1/5:  61%|██████    | 4749/7798 [06:34<03:57, 12.85it/s]

Batch 4750/7798:
total_loss: 0.7295777201652527


Epoch 1/5:  64%|██████▍   | 5001/7798 [06:55<03:45, 12.42it/s]

Batch 5000/7798:
total_loss: 0.5725485682487488


Epoch 1/5:  67%|██████▋   | 5251/7798 [07:15<03:19, 12.76it/s]

Batch 5250/7798:
total_loss: 0.5295022130012512


Epoch 1/5:  71%|███████   | 5501/7798 [07:36<02:57, 12.94it/s]

Batch 5500/7798:
total_loss: 0.5209551453590393


Epoch 1/5:  74%|███████▎  | 5751/7798 [07:56<02:34, 13.22it/s]

Batch 5750/7798:
total_loss: 0.37016788125038147


Epoch 1/5:  77%|███████▋  | 5999/7798 [08:17<02:22, 12.63it/s]

Batch 6000/7798:
total_loss: 0.5949460864067078


Epoch 1/5:  80%|████████  | 6251/7798 [08:38<02:00, 12.86it/s]

Batch 6250/7798:
total_loss: 0.5793258547782898


Epoch 1/5:  83%|████████▎ | 6501/7798 [08:59<01:49, 11.89it/s]

Batch 6500/7798:
total_loss: 0.5592092275619507


Epoch 1/5:  87%|████████▋ | 6751/7798 [09:19<01:21, 12.86it/s]

Batch 6750/7798:
total_loss: 0.7993831038475037


Epoch 1/5:  90%|████████▉ | 7001/7798 [09:40<01:20,  9.94it/s]

Batch 7000/7798:
total_loss: 0.4668397307395935


Epoch 1/5:  93%|█████████▎| 7251/7798 [10:01<00:47, 11.41it/s]

Batch 7250/7798:
total_loss: 0.4647594392299652


Epoch 1/5:  96%|█████████▌| 7501/7798 [10:21<00:22, 13.16it/s]

Batch 7500/7798:
total_loss: 0.4721720814704895


Epoch 1/5:  99%|█████████▉| 7751/7798 [10:42<00:04, 11.45it/s]

Batch 7750/7798:
total_loss: 0.6352137923240662


Epoch 1/5: 100%|██████████| 7798/7798 [10:46<00:00, 12.06it/s]
Validating: 100%|██████████| 15/15 [00:01<00:00,  9.01it/s]



Epoch Summary:
Train Total Loss: 0.7509
Val Total Loss: 0.4900
Learning Rate: 0.000100
New best model saved with val loss: 0.4900
--------------------------------------------------


Epoch 2/5:   3%|▎         | 251/7798 [00:20<09:54, 12.70it/s]

Batch 250/7798:
total_loss: 0.6493979096412659


Epoch 2/5:   6%|▋         | 499/7798 [00:41<09:57, 12.22it/s]

Batch 500/7798:
total_loss: 0.5412514209747314


Epoch 2/5:  10%|▉         | 751/7798 [01:02<09:15, 12.68it/s]

Batch 750/7798:
total_loss: 0.2742714285850525


Epoch 2/5:  13%|█▎        | 1001/7798 [01:22<09:29, 11.93it/s]

Batch 1000/7798:
total_loss: 0.5198659300804138


Epoch 2/5:  16%|█▌        | 1251/7798 [01:44<08:29, 12.85it/s]

Batch 1250/7798:
total_loss: 0.5552994608879089


Epoch 2/5:  19%|█▉        | 1501/7798 [02:04<08:52, 11.83it/s]

Batch 1500/7798:
total_loss: 0.4589557945728302


Epoch 2/5:  22%|██▏       | 1751/7798 [02:25<08:36, 11.70it/s]

Batch 1750/7798:
total_loss: 0.5189166069030762


Epoch 2/5:  26%|██▌       | 2001/7798 [02:45<07:23, 13.06it/s]

Batch 2000/7798:
total_loss: 0.44335922598838806


Epoch 2/5:  29%|██▉       | 2251/7798 [03:06<07:10, 12.88it/s]

Batch 2250/7798:
total_loss: 0.5639389157295227


Epoch 2/5:  32%|███▏      | 2499/7798 [03:26<07:12, 12.24it/s]

Batch 2500/7798:
total_loss: 0.43846702575683594


Epoch 2/5:  35%|███▌      | 2751/7798 [03:47<06:55, 12.14it/s]

Batch 2750/7798:
total_loss: 0.36503660678863525


Epoch 2/5:  38%|███▊      | 3001/7798 [04:08<06:30, 12.29it/s]

Batch 3000/7798:
total_loss: 0.5643869042396545


Epoch 2/5:  42%|████▏     | 3251/7798 [04:29<05:44, 13.21it/s]

Batch 3250/7798:
total_loss: 0.38421398401260376


Epoch 2/5:  45%|████▍     | 3501/7798 [04:49<06:06, 11.72it/s]

Batch 3500/7798:
total_loss: 0.46523499488830566


Epoch 2/5:  48%|████▊     | 3751/7798 [05:10<05:21, 12.57it/s]

Batch 3750/7798:
total_loss: 0.8041387796401978


Epoch 2/5:  51%|█████▏    | 4001/7798 [05:31<05:22, 11.76it/s]

Batch 4000/7798:
total_loss: 0.35513293743133545


Epoch 2/5:  55%|█████▍    | 4251/7798 [05:51<04:55, 11.98it/s]

Batch 4250/7798:
total_loss: 0.4928465783596039


Epoch 2/5:  58%|█████▊    | 4501/7798 [06:12<04:30, 12.18it/s]

Batch 4500/7798:
total_loss: 0.35577213764190674


Epoch 2/5:  61%|██████    | 4751/7798 [06:33<04:23, 11.57it/s]

Batch 4750/7798:
total_loss: 0.48274946212768555


Epoch 2/5:  64%|██████▍   | 5001/7798 [06:53<03:56, 11.83it/s]

Batch 5000/7798:
total_loss: 0.4413211941719055


Epoch 2/5:  67%|██████▋   | 5251/7798 [07:15<03:13, 13.15it/s]

Batch 5250/7798:
total_loss: 0.4874500036239624


Epoch 2/5:  71%|███████   | 5499/7798 [07:35<03:14, 11.80it/s]

Batch 5500/7798:
total_loss: 0.3696240186691284


Epoch 2/5:  74%|███████▎  | 5751/7798 [07:56<02:34, 13.29it/s]

Batch 5750/7798:
total_loss: 0.6158024668693542


Epoch 2/5:  77%|███████▋  | 6001/7798 [08:17<02:33, 11.73it/s]

Batch 6000/7798:
total_loss: 0.47296345233917236


Epoch 2/5:  80%|████████  | 6251/7798 [08:37<02:15, 11.42it/s]

Batch 6250/7798:
total_loss: 0.44948798418045044


Epoch 2/5:  83%|████████▎ | 6501/7798 [08:58<01:47, 12.12it/s]

Batch 6500/7798:
total_loss: 0.5451480150222778


Epoch 2/5:  87%|████████▋ | 6751/7798 [09:18<01:28, 11.82it/s]

Batch 6750/7798:
total_loss: 0.5495024919509888


Epoch 2/5:  90%|████████▉ | 7001/7798 [09:38<01:07, 11.83it/s]

Batch 7000/7798:
total_loss: 0.4139281213283539


Epoch 2/5:  93%|█████████▎| 7251/7798 [09:59<00:45, 11.98it/s]

Batch 7250/7798:
total_loss: 0.5818496942520142


Epoch 2/5:  96%|█████████▌| 7501/7798 [10:19<00:22, 13.17it/s]

Batch 7500/7798:
total_loss: 0.5885727405548096


Epoch 2/5:  99%|█████████▉| 7751/7798 [10:40<00:03, 12.65it/s]

Batch 7750/7798:
total_loss: 0.5387837290763855


Epoch 2/5: 100%|██████████| 7798/7798 [10:44<00:00, 12.10it/s]
Validating: 100%|██████████| 15/15 [00:01<00:00,  9.20it/s]



Epoch Summary:
Train Total Loss: 0.5005
Val Total Loss: 0.4157
Learning Rate: 0.000100
New best model saved with val loss: 0.4157
--------------------------------------------------


Epoch 3/5:   3%|▎         | 251/7798 [00:20<09:50, 12.78it/s]

Batch 250/7798:
total_loss: 0.7195716500282288


Epoch 3/5:   6%|▋         | 501/7798 [00:41<10:38, 11.43it/s]

Batch 500/7798:
total_loss: 0.3836797773838043


Epoch 3/5:  10%|▉         | 751/7798 [01:02<09:06, 12.90it/s]

Batch 750/7798:
total_loss: 0.5034289360046387


Epoch 3/5:  13%|█▎        | 1001/7798 [01:23<09:28, 11.96it/s]

Batch 1000/7798:
total_loss: 0.30346670746803284


Epoch 3/5:  16%|█▌        | 1251/7798 [01:44<09:13, 11.83it/s]

Batch 1250/7798:
total_loss: 0.5056633353233337


Epoch 3/5:  19%|█▉        | 1501/7798 [02:05<10:15, 10.23it/s]

Batch 1500/7798:
total_loss: 0.5332229137420654


Epoch 3/5:  22%|██▏       | 1751/7798 [02:26<08:30, 11.85it/s]

Batch 1750/7798:
total_loss: 0.2852718234062195


Epoch 3/5:  26%|██▌       | 2001/7798 [02:47<07:34, 12.75it/s]

Batch 2000/7798:
total_loss: 0.36887744069099426


Epoch 3/5:  29%|██▉       | 2251/7798 [03:08<08:04, 11.45it/s]

Batch 2250/7798:
total_loss: 0.41847172379493713


Epoch 3/5:  32%|███▏      | 2501/7798 [03:29<07:22, 11.98it/s]

Batch 2500/7798:
total_loss: 0.6198106408119202


Epoch 3/5:  35%|███▌      | 2751/7798 [03:49<06:31, 12.88it/s]

Batch 2750/7798:
total_loss: 0.5337623953819275


Epoch 3/5:  38%|███▊      | 3001/7798 [04:10<06:55, 11.55it/s]

Batch 3000/7798:
total_loss: 0.32645300030708313


Epoch 3/5:  42%|████▏     | 3251/7798 [04:31<05:57, 12.72it/s]

Batch 3250/7798:
total_loss: 0.43925362825393677


Epoch 3/5:  45%|████▍     | 3501/7798 [04:52<06:12, 11.55it/s]

Batch 3500/7798:
total_loss: 0.5312628149986267


Epoch 3/5:  48%|████▊     | 3751/7798 [05:13<05:14, 12.86it/s]

Batch 3750/7798:
total_loss: 0.6287570595741272


Epoch 3/5:  51%|█████▏    | 4001/7798 [05:34<05:17, 11.95it/s]

Batch 4000/7798:
total_loss: 0.5067151784896851


Epoch 3/5:  55%|█████▍    | 4251/7798 [05:54<05:05, 11.60it/s]

Batch 4250/7798:
total_loss: 0.3596287667751312


Epoch 3/5:  58%|█████▊    | 4501/7798 [06:15<04:39, 11.78it/s]

Batch 4500/7798:
total_loss: 0.44527262449264526


Epoch 3/5:  61%|██████    | 4751/7798 [06:36<04:05, 12.43it/s]

Batch 4750/7798:
total_loss: 0.4232020676136017


Epoch 3/5:  64%|██████▍   | 5001/7798 [06:57<04:19, 10.79it/s]

Batch 5000/7798:
total_loss: 0.31816738843917847


Epoch 3/5:  67%|██████▋   | 5251/7798 [07:18<03:30, 12.07it/s]

Batch 5250/7798:
total_loss: 0.28041958808898926


Epoch 3/5:  71%|███████   | 5501/7798 [07:39<03:00, 12.76it/s]

Batch 5500/7798:
total_loss: 0.5002937912940979


Epoch 3/5:  74%|███████▎  | 5749/7798 [07:59<02:34, 13.24it/s]

Batch 5750/7798:
total_loss: 0.38144218921661377


Epoch 3/5:  77%|███████▋  | 6001/7798 [08:20<02:31, 11.85it/s]

Batch 6000/7798:
total_loss: 0.3483096957206726


Epoch 3/5:  80%|████████  | 6251/7798 [08:40<02:06, 12.21it/s]

Batch 6250/7798:
total_loss: 0.39376407861709595


Epoch 3/5:  83%|████████▎ | 6501/7798 [09:01<01:47, 12.04it/s]

Batch 6500/7798:
total_loss: 0.5367264747619629


Epoch 3/5:  87%|████████▋ | 6751/7798 [09:22<01:21, 12.84it/s]

Batch 6750/7798:
total_loss: 0.6549822092056274


Epoch 3/5:  90%|████████▉ | 7001/7798 [09:42<01:08, 11.59it/s]

Batch 7000/7798:
total_loss: 0.12954555451869965


Epoch 3/5:  93%|█████████▎| 7251/7798 [10:03<00:43, 12.52it/s]

Batch 7250/7798:
total_loss: 0.3873438835144043


Epoch 3/5:  96%|█████████▌| 7501/7798 [10:23<00:25, 11.60it/s]

Batch 7500/7798:
total_loss: 0.25514471530914307


Epoch 3/5:  99%|█████████▉| 7751/7798 [10:44<00:03, 11.79it/s]

Batch 7750/7798:
total_loss: 0.26105058193206787


Epoch 3/5: 100%|██████████| 7798/7798 [10:48<00:00, 12.02it/s]
Validating: 100%|██████████| 15/15 [00:01<00:00,  9.22it/s]



Epoch Summary:
Train Total Loss: 0.4406
Val Total Loss: 0.3886
Learning Rate: 0.000100
New best model saved with val loss: 0.3886
--------------------------------------------------


Epoch 4/5:   3%|▎         | 251/7798 [00:20<09:34, 13.13it/s]

Batch 250/7798:
total_loss: 0.3164632022380829


Epoch 4/5:   6%|▋         | 501/7798 [00:41<09:32, 12.74it/s]

Batch 500/7798:
total_loss: 0.503399133682251


Epoch 4/5:  10%|▉         | 751/7798 [01:02<10:32, 11.14it/s]

Batch 750/7798:
total_loss: 0.4111625850200653


Epoch 4/5:  13%|█▎        | 1001/7798 [01:23<09:37, 11.77it/s]

Batch 1000/7798:
total_loss: 0.3159514367580414


Epoch 4/5:  16%|█▌        | 1251/7798 [01:43<09:41, 11.26it/s]

Batch 1250/7798:
total_loss: 0.44925519824028015


Epoch 4/5:  19%|█▉        | 1501/7798 [02:04<08:10, 12.83it/s]

Batch 1500/7798:
total_loss: 0.40042346715927124


Epoch 4/5:  22%|██▏       | 1751/7798 [02:25<07:45, 12.99it/s]

Batch 1750/7798:
total_loss: 0.5375398993492126


Epoch 4/5:  26%|██▌       | 2001/7798 [02:46<08:18, 11.63it/s]

Batch 2000/7798:
total_loss: 0.38462167978286743


Epoch 4/5:  29%|██▉       | 2251/7798 [03:06<07:18, 12.65it/s]

Batch 2250/7798:
total_loss: 0.518169105052948


Epoch 4/5:  32%|███▏      | 2501/7798 [03:27<06:39, 13.26it/s]

Batch 2500/7798:
total_loss: 0.3615046441555023


Epoch 4/5:  35%|███▌      | 2751/7798 [03:48<07:14, 11.61it/s]

Batch 2750/7798:
total_loss: 0.28980541229248047


Epoch 4/5:  38%|███▊      | 3001/7798 [04:08<06:38, 12.05it/s]

Batch 3000/7798:
total_loss: 0.33859145641326904


Epoch 4/5:  42%|████▏     | 3251/7798 [04:29<06:23, 11.86it/s]

Batch 3250/7798:
total_loss: 0.29727089405059814


Epoch 4/5:  45%|████▍     | 3501/7798 [04:50<05:32, 12.92it/s]

Batch 3500/7798:
total_loss: 0.4282858967781067


Epoch 4/5:  48%|████▊     | 3751/7798 [05:10<06:03, 11.13it/s]

Batch 3750/7798:
total_loss: 0.3330022394657135


Epoch 4/5:  51%|█████▏    | 4001/7798 [05:32<05:34, 11.34it/s]

Batch 4000/7798:
total_loss: 0.3983711898326874


Epoch 4/5:  55%|█████▍    | 4251/7798 [05:52<04:56, 11.97it/s]

Batch 4250/7798:
total_loss: 0.25864893198013306


Epoch 4/5:  58%|█████▊    | 4501/7798 [06:13<05:09, 10.67it/s]

Batch 4500/7798:
total_loss: 0.31555452942848206


Epoch 4/5:  61%|██████    | 4751/7798 [06:34<04:15, 11.95it/s]

Batch 4750/7798:
total_loss: 0.4042223393917084


Epoch 4/5:  64%|██████▍   | 5001/7798 [06:54<04:00, 11.61it/s]

Batch 5000/7798:
total_loss: 0.2615407407283783


Epoch 4/5:  67%|██████▋   | 5251/7798 [07:15<03:14, 13.11it/s]

Batch 5250/7798:
total_loss: 0.2569924592971802


Epoch 4/5:  71%|███████   | 5501/7798 [07:35<03:16, 11.70it/s]

Batch 5500/7798:
total_loss: 0.3690529763698578


Epoch 4/5:  74%|███████▎  | 5751/7798 [07:56<02:42, 12.62it/s]

Batch 5750/7798:
total_loss: 0.33103707432746887


Epoch 4/5:  77%|███████▋  | 6001/7798 [08:17<02:32, 11.80it/s]

Batch 6000/7798:
total_loss: 0.3591989576816559


Epoch 4/5:  80%|████████  | 6251/7798 [08:38<01:57, 13.20it/s]

Batch 6250/7798:
total_loss: 0.37536051869392395


Epoch 4/5:  83%|████████▎ | 6501/7798 [08:58<01:40, 12.88it/s]

Batch 6500/7798:
total_loss: 0.5423932075500488


Epoch 4/5:  87%|████████▋ | 6751/7798 [09:19<01:19, 13.14it/s]

Batch 6750/7798:
total_loss: 0.40898093581199646


Epoch 4/5:  90%|████████▉ | 7001/7798 [09:39<01:03, 12.47it/s]

Batch 7000/7798:
total_loss: 0.3485669791698456


Epoch 4/5:  93%|█████████▎| 7251/7798 [10:00<00:46, 11.76it/s]

Batch 7250/7798:
total_loss: 0.3706842362880707


Epoch 4/5:  96%|█████████▌| 7501/7798 [10:21<00:26, 11.38it/s]

Batch 7500/7798:
total_loss: 0.3227999210357666


Epoch 4/5:  99%|█████████▉| 7751/7798 [10:41<00:03, 12.87it/s]

Batch 7750/7798:
total_loss: 0.24286222457885742


Epoch 4/5: 100%|██████████| 7798/7798 [10:45<00:00, 12.07it/s]
Validating: 100%|██████████| 15/15 [00:01<00:00,  8.90it/s]



Epoch Summary:
Train Total Loss: 0.3938
Val Total Loss: 0.3371
Learning Rate: 0.000100
New best model saved with val loss: 0.3371
--------------------------------------------------


Epoch 5/5:   3%|▎         | 251/7798 [00:21<10:22, 12.12it/s]

Batch 250/7798:
total_loss: 0.4039788246154785


Epoch 5/5:   6%|▋         | 501/7798 [00:41<09:22, 12.96it/s]

Batch 500/7798:
total_loss: 0.2437777817249298


Epoch 5/5:  10%|▉         | 751/7798 [01:02<10:03, 11.68it/s]

Batch 750/7798:
total_loss: 0.4464111030101776


Epoch 5/5:  13%|█▎        | 1001/7798 [01:23<08:53, 12.74it/s]

Batch 1000/7798:
total_loss: 0.2626149356365204


Epoch 5/5:  16%|█▌        | 1251/7798 [01:44<08:24, 12.98it/s]

Batch 1250/7798:
total_loss: 0.3756188154220581


Epoch 5/5:  19%|█▉        | 1501/7798 [02:05<08:52, 11.82it/s]

Batch 1500/7798:
total_loss: 0.39256107807159424


Epoch 5/5:  22%|██▏       | 1751/7798 [02:25<07:46, 12.96it/s]

Batch 1750/7798:
total_loss: 0.2687416672706604


Epoch 5/5:  26%|██▌       | 2001/7798 [02:46<08:24, 11.50it/s]

Batch 2000/7798:
total_loss: 0.35877829790115356


Epoch 5/5:  29%|██▉       | 2251/7798 [03:06<07:13, 12.78it/s]

Batch 2250/7798:
total_loss: 0.46541598439216614


Epoch 5/5:  32%|███▏      | 2501/7798 [03:27<07:37, 11.58it/s]

Batch 2500/7798:
total_loss: 0.4096839427947998


Epoch 5/5:  35%|███▌      | 2751/7798 [03:48<07:28, 11.25it/s]

Batch 2750/7798:
total_loss: 0.3747316896915436


Epoch 5/5:  38%|███▊      | 3001/7798 [04:09<06:40, 11.99it/s]

Batch 3000/7798:
total_loss: 0.44609978795051575


Epoch 5/5:  42%|████▏     | 3251/7798 [04:30<06:40, 11.34it/s]

Batch 3250/7798:
total_loss: 0.5449820756912231


Epoch 5/5:  45%|████▍     | 3501/7798 [04:50<05:34, 12.85it/s]

Batch 3500/7798:
total_loss: 0.28832879662513733


Epoch 5/5:  48%|████▊     | 3751/7798 [05:11<05:52, 11.48it/s]

Batch 3750/7798:
total_loss: 0.28799310326576233


Epoch 5/5:  51%|█████▏    | 3999/7798 [05:31<04:56, 12.81it/s]

Batch 4000/7798:
total_loss: 0.33802154660224915


Epoch 5/5:  55%|█████▍    | 4251/7798 [05:52<04:30, 13.13it/s]

Batch 4250/7798:
total_loss: 0.3033615052700043


Epoch 5/5:  58%|█████▊    | 4501/7798 [06:13<04:33, 12.03it/s]

Batch 4500/7798:
total_loss: 0.4866296350955963


Epoch 5/5:  61%|██████    | 4751/7798 [06:33<03:59, 12.75it/s]

Batch 4750/7798:
total_loss: 0.25035226345062256


Epoch 5/5:  64%|██████▍   | 5001/7798 [06:54<04:07, 11.31it/s]

Batch 5000/7798:
total_loss: 0.4633203148841858


Epoch 5/5:  67%|██████▋   | 5251/7798 [07:15<03:16, 12.96it/s]

Batch 5250/7798:
total_loss: 0.2952815890312195


Epoch 5/5:  71%|███████   | 5499/7798 [07:35<02:55, 13.07it/s]

Batch 5500/7798:
total_loss: 0.3579320013523102


Epoch 5/5:  74%|███████▎  | 5751/7798 [07:56<02:51, 11.92it/s]

Batch 5750/7798:
total_loss: 0.4820408821105957


Epoch 5/5:  77%|███████▋  | 6001/7798 [08:17<02:21, 12.70it/s]

Batch 6000/7798:
total_loss: 0.32011184096336365


Epoch 5/5:  80%|████████  | 6251/7798 [08:38<02:33, 10.09it/s]

Batch 6250/7798:
total_loss: 0.3289816379547119


Epoch 5/5:  83%|████████▎ | 6501/7798 [08:59<01:41, 12.74it/s]

Batch 6500/7798:
total_loss: 0.3487662971019745


Epoch 5/5:  87%|████████▋ | 6751/7798 [09:19<01:33, 11.24it/s]

Batch 6750/7798:
total_loss: 0.24812376499176025


Epoch 5/5:  90%|████████▉ | 7001/7798 [09:40<01:03, 12.54it/s]

Batch 7000/7798:
total_loss: 0.22473903000354767


Epoch 5/5:  93%|█████████▎| 7249/7798 [10:00<00:43, 12.49it/s]

Batch 7250/7798:
total_loss: 0.318196564912796


Epoch 5/5:  96%|█████████▌| 7501/7798 [10:22<00:22, 13.27it/s]

Batch 7500/7798:
total_loss: 0.49084946513175964


Epoch 5/5:  99%|█████████▉| 7751/7798 [10:43<00:03, 11.93it/s]

Batch 7750/7798:
total_loss: 0.18818658590316772


Epoch 5/5: 100%|██████████| 7798/7798 [10:46<00:00, 12.05it/s]
Validating: 100%|██████████| 15/15 [00:01<00:00,  8.98it/s]



Epoch Summary:
Train Total Loss: 0.3588
Val Total Loss: 0.3314
Learning Rate: 0.000100
New best model saved with val loss: 0.3314
--------------------------------------------------
