<a href="https://colab.research.google.com/github/carolynw898/STAT946Proj/blob/main/stat946-proj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Mount the user's Google Drive inside the Colab runtime at /content/drive.
# The dataset directory used by later cells (`dataDir`) lives under this mount.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple
from models import SymbolicGaussianDiffusion, PointNetConfig
from utils import CharDataset, processDataFiles, tokenize_equation

def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> float:
    """Run one optimisation pass over ``train_loader``.

    For every batch a diffusion timestep is drawn uniformly from
    ``[0, timesteps)`` for each sample, the model is called as
    ``model(points, tokens, variables, t)`` and its return value is used
    directly as the loss (it must support ``.backward()`` and ``.item()``).

    Args:
        model: Model to train; switched to training mode by this function.
        train_loader: Iterable of 4-tuples ``(_, tokens, points, variables)``.
            The first element of each batch is ignored.
        optimizer: Optimizer stepped once per batch.
        train_dataset: Unused here; kept so existing call sites keep working.
        timesteps: Exclusive upper bound for the sampled timestep ``t``.
        device: Device the batch tensors and ``t`` are placed on.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        num_epochs: Total number of epochs, used only for the progress-bar label.

    Returns:
        The mean of the per-batch losses over the epoch.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_train_loss = 0.0
    num_batches = len(train_loader)

    for i, (_, tokens, points, variables) in tqdm.tqdm(
        enumerate(train_loader),
        total=num_batches,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    ):
        points, tokens, variables = (
            points.to(device),
            tokens.to(device),
            variables.to(device),
        )
        # One random diffusion step per sample in the batch.
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
        optimizer.zero_grad()

        total_loss = model(points, tokens, variables, t)

        # Periodic progress report every 250 batches.
        if (i + 1) % 250 == 0:
            print(f"Batch {i + 1}/{len(train_loader)}:")
            print(f"total_loss: {total_loss}")

        total_loss.backward()
        optimizer.step()

        total_train_loss += total_loss.item()

    avg_train_loss = total_train_loss / num_batches
    return avg_train_loss


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> float:
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The model is called exactly as during training,
    ``model(points, tokens, variables, t)``, with ``t`` drawn uniformly from
    ``[0, timesteps)`` per sample.

    NOTE: because ``t`` is re-sampled randomly on every call, the returned
    value is itself random; two evaluations of the same weights generally
    differ slightly.

    Args:
        model: Model to evaluate; switched to eval mode by this function.
        val_loader: Iterable of 4-tuples ``(_, tokens, points, variables)``.
            The first element of each batch is ignored.
        train_dataset: Unused here; kept so existing call sites keep working.
        timesteps: Exclusive upper bound for the sampled timestep ``t``.
        device: Device the batch tensors and ``t`` are placed on.
        epoch: Unused here; kept so existing call sites keep working.
        num_epochs: Unused here; kept so existing call sites keep working.

    Returns:
        The mean of the per-batch losses over ``val_loader``.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_val_loss = 0.0
    num_batches = len(val_loader)

    with torch.no_grad():
        for _, tokens, points, variables in tqdm.tqdm(
            val_loader, total=num_batches, desc="Validating"
        ):
            points, tokens, variables = (
                points.to(device),
                tokens.to(device),
                variables.to(device),
            )
            # One random diffusion step per sample in the batch.
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
            total_loss = model(points, tokens, variables, t)

            total_val_loss += total_loss.item()

    avg_val_loss = total_val_loss / num_batches
    return avg_val_loss


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
    path=None,
):
    """Train ``model`` on one device and keep the best checkpoint.

    Uses CUDA when available, otherwise the CPU. After every epoch the
    validation loss drives a ``ReduceLROnPlateau`` scheduler (the learning
    rate is halved after more than one epoch without improvement), and the
    model's ``state_dict`` is written to ``path`` whenever the validation loss
    reaches a new minimum.

    Args:
        model: Model to train; moved to the selected device.
        train_dataset: Dataset for the (shuffled) training loader.
        val_dataset: Dataset for the (unshuffled) validation loader.
        num_epochs: Number of train/validate cycles to run.
        save_every: Currently unused; accepted for call-site compatibility.
        batch_size: Batch size for both loaders.
        timesteps: Exclusive upper bound of the diffusion timestep sampled
            per example in ``train_epoch`` / ``val_epoch``.
        learning_rate: Initial Adam learning rate.
        path: Destination for the best ``state_dict``. When ``None`` no
            checkpoint is written (previously this crashed in ``torch.save``).

    Returns:
        None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA just produces a warning, so enable it only when training on a GPU.
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=True,
        num_workers=4,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=False,
        num_workers=4,
    )

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )

        avg_val_loss = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(
            f"Train Total Loss: {avg_train_loss:.4f}"
        )
        print(
            f"Val Total Loss: {avg_val_loss:.4f}"
        )
        print(f"Learning Rate: {current_lr:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if path is not None:
                torch.save(model.state_dict(), path)
                print(f"New best model saved with val loss: {best_val_loss:.4f}")
            else:
                # No destination given: track the best loss but skip saving.
                print(
                    f"New best val loss: {best_val_loss:.4f} "
                    "(no path given, model not saved)"
                )

        print("-" * 50)

In [3]:
# --- Model / diffusion -------------------------------------------------------
n_embd = 512        # passed as PointNetConfig.embeddingSize and as the model's n_embd
timesteps = 1000    # passed to the diffusion model and to the training loop
blockSize = 32      # CharDataset block size; also the model's max_seq_len

# --- Optimisation ------------------------------------------------------------
num_epochs = 5
batch_size = 64
learning_rate = 1e-4

# --- Dataset -----------------------------------------------------------------
maxNumFiles = 100   # cap on how many training JSON files are read
numPoints = 250
numVars = 1
numYs = 1
addVars = False
target = "Skeleton"
decimals = 8
trainRange = [-3.0, 3.0]    # handed to CharDataset as xRange
const_range = [-2.1, 2.1]

# --- Runtime -----------------------------------------------------------------
# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Sub-folder of `dataDir` holding this experiment's Train/ and Val/ JSON files.
dataFolder = "1_var_dataset"
# Root data directory on the mounted Google Drive.
dataDir = "/content/drive/MyDrive/Colab/STAT946_proj/data"

In [5]:
from torch.utils.data import DataLoader
import numpy as np
import glob
from utils import processDataFiles, CharDataset, tokenize_equation
import random
import json

# Collect at most `maxNumFiles` training files. glob gives no ordering
# guarantee, so sort first to make the selected subset reproducible.
path = '{}/{}/Train/*.json'.format(dataDir, dataFolder)
files = sorted(glob.glob(path))[:maxNumFiles]
text = processDataFiles(files)
text = text.split('\n')  # one example per line

# Build the vocabulary from every skeleton found in the training text
# (blank lines are skipped).
skeletons = [json.loads(item)['Skeleton'] for item in text if item.strip()]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
all_tokens.update(integers)  # add all integers to the token set
# Union (not list concatenation) so a special token that tokenization
# already produced is not listed twice.
special_tokens = {'_', 'T', '<', '>', ':'}
tokens = sorted(all_tokens | special_tokens)

# Drop the empty string left behind when the text ends with a newline.
trainText = text[:-1] if not text[-1] else text
# Shuffle the dataset; this matters especially for the experiment that
# combines different numbers of variables.
random.shuffle(trainText)
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# Print one random sample as a sanity check.
idx = np.random.randint(len(train_dataset))
inputs, outputs, points, variables = train_dataset[idx]
inputs = ''.join(train_dataset.itos[int(i)] for i in inputs)
outputs = ''.join(train_dataset.itos[int(i)] for i in outputs)
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 498795 examples, 27 unique.
id:405137
outputs:C*cos(C*sin(C*x1)+C/x1)+C>____________
variables:1


In [6]:
# Validation uses a single file. glob gives no ordering guarantee, so sort
# first to make "the first file" well defined across runs.
path = '{}/{}/Val/*.json'.format(dataDir,dataFolder)
files = sorted(glob.glob(path))
if not files:
    raise FileNotFoundError('no validation files match {}'.format(path))
textVal = processDataFiles([files[0]])
textVal = textVal.split('\n')  # one example per line
# Drop the empty string left behind by a trailing newline, as is done for
# the training text.
textVal = textVal[:-1] if not textVal[-1] else textVal
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                        numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                        const_range=const_range, xRange=trainRange, decimals=decimals)

# Print a random sample as a sanity check.
idx = np.random.randint(len(val_dataset))
inputs, outputs, points, variables = val_dataset[idx]
print(points.min(), points.max())
# Decoded with the training dataset's index-to-token table; both datasets are
# constructed from the same `tokens` list.
inputs = ''.join(train_dataset.itos[int(i)] for i in inputs)
outputs = ''.join(train_dataset.itos[int(i)] for i in outputs)
print('id:{}\noutputs:{}\nvariables:{}'.format(idx,outputs,variables))

data has 972 examples, 27 unique.
tensor(-52.6893) tensor(5.2545)
id:803
outputs:C*x1/(C*x1+C)+C>__________________
variables:1


In [13]:
# Configuration for the point encoder, sized from the notebook hyperparameters.
pconfig = PointNetConfig(
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
    embeddingSize=n_embd,
)

# Diffusion model over token sequences; vocabulary size and padding index
# come from the training dataset.
model = SymbolicGaussianDiffusion(
    tnet_config=pconfig,
    set_transformer=True,
    # sequence / vocabulary
    vocab_size=train_dataset.vocab_size,
    padding_idx=train_dataset.paddingID,
    max_seq_len=blockSize,
    max_num_vars=9,
    # network size
    n_embd=n_embd,
    n_layer=4,
    n_head=4,
    # noise schedule
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
)

# Where the best (lowest validation loss) weights are written.
checkpoint_path = "1_var_set_transformer_validity.pth"

# Keyword options forwarded unchanged to the training driver.
train_options = dict(
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
    path=checkpoint_path,
)

train_single_gpu(model, train_dataset, val_dataset, **train_options)

Epoch 1/5:   3%|▎         | 251/7794 [00:14<07:05, 17.74it/s]

Batch 250/7794:
total_loss: 1.760127305984497


Epoch 1/5:   6%|▋         | 501/7794 [00:28<07:00, 17.36it/s]

Batch 500/7794:
total_loss: 1.2598085403442383


Epoch 1/5:  10%|▉         | 751/7794 [00:42<06:40, 17.57it/s]

Batch 750/7794:
total_loss: 1.1648240089416504


Epoch 1/5:  13%|█▎        | 1001/7794 [00:56<06:24, 17.68it/s]

Batch 1000/7794:
total_loss: 0.9402586817741394


Epoch 1/5:  16%|█▌        | 1251/7794 [01:10<06:04, 17.94it/s]

Batch 1250/7794:
total_loss: 0.7823335528373718


Epoch 1/5:  19%|█▉        | 1501/7794 [01:24<05:56, 17.66it/s]

Batch 1500/7794:
total_loss: 0.5902435183525085


Epoch 1/5:  22%|██▏       | 1751/7794 [01:38<05:38, 17.86it/s]

Batch 1750/7794:
total_loss: 0.6718902587890625


Epoch 1/5:  26%|██▌       | 2001/7794 [01:52<05:27, 17.71it/s]

Batch 2000/7794:
total_loss: 0.600426197052002


Epoch 1/5:  29%|██▉       | 2251/7794 [02:07<05:15, 17.57it/s]

Batch 2250/7794:
total_loss: 0.6319841146469116


Epoch 1/5:  32%|███▏      | 2501/7794 [02:21<04:59, 17.69it/s]

Batch 2500/7794:
total_loss: 0.6263789534568787


Epoch 1/5:  35%|███▌      | 2751/7794 [02:35<04:47, 17.57it/s]

Batch 2750/7794:
total_loss: 0.8388538956642151


Epoch 1/5:  39%|███▊      | 3001/7794 [02:49<04:30, 17.70it/s]

Batch 3000/7794:
total_loss: 0.642079770565033


Epoch 1/5:  42%|████▏     | 3251/7794 [03:03<04:15, 17.79it/s]

Batch 3250/7794:
total_loss: 0.6665394306182861


Epoch 1/5:  45%|████▍     | 3501/7794 [03:17<04:03, 17.66it/s]

Batch 3500/7794:
total_loss: 0.7430440187454224


Epoch 1/5:  48%|████▊     | 3751/7794 [03:31<03:48, 17.70it/s]

Batch 3750/7794:
total_loss: 0.3497365415096283


Epoch 1/5:  51%|█████▏    | 4001/7794 [03:45<03:34, 17.66it/s]

Batch 4000/7794:
total_loss: 0.3692811131477356


Epoch 1/5:  55%|█████▍    | 4251/7794 [03:59<03:20, 17.63it/s]

Batch 4250/7794:
total_loss: 0.6714485287666321


Epoch 1/5:  58%|█████▊    | 4501/7794 [04:14<03:05, 17.77it/s]

Batch 4500/7794:
total_loss: 0.5681069493293762


Epoch 1/5:  61%|██████    | 4751/7794 [04:28<02:52, 17.69it/s]

Batch 4750/7794:
total_loss: 0.43899431824684143


Epoch 1/5:  64%|██████▍   | 5001/7794 [04:42<02:36, 17.82it/s]

Batch 5000/7794:
total_loss: 0.37503665685653687


Epoch 1/5:  67%|██████▋   | 5251/7794 [04:56<02:24, 17.59it/s]

Batch 5250/7794:
total_loss: 0.49030330777168274


Epoch 1/5:  71%|███████   | 5501/7794 [05:10<02:09, 17.72it/s]

Batch 5500/7794:
total_loss: 0.2985667884349823


Epoch 1/5:  74%|███████▍  | 5751/7794 [05:24<01:55, 17.76it/s]

Batch 5750/7794:
total_loss: 0.6765999794006348


Epoch 1/5:  77%|███████▋  | 6001/7794 [05:38<01:41, 17.69it/s]

Batch 6000/7794:
total_loss: 0.6968816518783569


Epoch 1/5:  80%|████████  | 6251/7794 [05:52<01:27, 17.58it/s]

Batch 6250/7794:
total_loss: 0.46786585450172424


Epoch 1/5:  83%|████████▎ | 6501/7794 [06:06<01:13, 17.62it/s]

Batch 6500/7794:
total_loss: 0.6111153960227966


Epoch 1/5:  87%|████████▋ | 6751/7794 [06:21<00:59, 17.45it/s]

Batch 6750/7794:
total_loss: 0.7107164263725281


Epoch 1/5:  90%|████████▉ | 7001/7794 [06:35<00:45, 17.61it/s]

Batch 7000/7794:
total_loss: 0.5944525003433228


Epoch 1/5:  93%|█████████▎| 7251/7794 [06:49<00:30, 17.54it/s]

Batch 7250/7794:
total_loss: 0.2973841428756714


Epoch 1/5:  96%|█████████▌| 7501/7794 [07:03<00:16, 17.59it/s]

Batch 7500/7794:
total_loss: 0.41381147503852844


Epoch 1/5:  99%|█████████▉| 7751/7794 [07:17<00:02, 17.62it/s]

Batch 7750/7794:
total_loss: 0.6039680242538452


Epoch 1/5: 100%|██████████| 7794/7794 [07:20<00:00, 17.71it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 31.33it/s]


Epoch Summary:
Train Total Loss: 0.7195
Val Total Loss: 0.4258
Learning Rate: 0.000100
New best model saved with val loss: 0.4258
--------------------------------------------------



Epoch 2/5:   3%|▎         | 251/7794 [00:14<07:08, 17.59it/s]

Batch 250/7794:
total_loss: 0.49523064494132996


Epoch 2/5:   6%|▋         | 501/7794 [00:28<06:51, 17.74it/s]

Batch 500/7794:
total_loss: 0.7479625344276428


Epoch 2/5:  10%|▉         | 751/7794 [00:42<06:38, 17.69it/s]

Batch 750/7794:
total_loss: 0.5059036612510681


Epoch 2/5:  13%|█▎        | 1001/7794 [00:56<06:23, 17.71it/s]

Batch 1000/7794:
total_loss: 0.5702353715896606


Epoch 2/5:  16%|█▌        | 1251/7794 [01:10<06:09, 17.72it/s]

Batch 1250/7794:
total_loss: 0.5608868598937988


Epoch 2/5:  19%|█▉        | 1501/7794 [01:24<05:56, 17.66it/s]

Batch 1500/7794:
total_loss: 0.6750186085700989


Epoch 2/5:  22%|██▏       | 1751/7794 [01:38<05:41, 17.72it/s]

Batch 1750/7794:
total_loss: 0.4750761389732361


Epoch 2/5:  26%|██▌       | 2001/7794 [01:52<05:26, 17.75it/s]

Batch 2000/7794:
total_loss: 0.4640514850616455


Epoch 2/5:  29%|██▉       | 2251/7794 [02:07<05:12, 17.74it/s]

Batch 2250/7794:
total_loss: 0.6608189940452576


Epoch 2/5:  32%|███▏      | 2501/7794 [02:21<04:58, 17.73it/s]

Batch 2500/7794:
total_loss: 0.5352658033370972


Epoch 2/5:  35%|███▌      | 2751/7794 [02:35<04:43, 17.80it/s]

Batch 2750/7794:
total_loss: 0.4523034989833832


Epoch 2/5:  39%|███▊      | 3001/7794 [02:49<04:30, 17.69it/s]

Batch 3000/7794:
total_loss: 0.4280036389827728


Epoch 2/5:  42%|████▏     | 3251/7794 [03:03<04:18, 17.60it/s]

Batch 3250/7794:
total_loss: 0.4339623749256134


Epoch 2/5:  45%|████▍     | 3501/7794 [03:17<04:03, 17.66it/s]

Batch 3500/7794:
total_loss: 0.6780082583427429


Epoch 2/5:  48%|████▊     | 3751/7794 [03:31<03:49, 17.65it/s]

Batch 3750/7794:
total_loss: 0.4959571659564972


Epoch 2/5:  51%|█████▏    | 4001/7794 [03:45<03:34, 17.72it/s]

Batch 4000/7794:
total_loss: 0.4241112470626831


Epoch 2/5:  55%|█████▍    | 4251/7794 [04:00<03:19, 17.76it/s]

Batch 4250/7794:
total_loss: 0.502284586429596


Epoch 2/5:  58%|█████▊    | 4501/7794 [04:14<03:06, 17.63it/s]

Batch 4500/7794:
total_loss: 0.5120351314544678


Epoch 2/5:  61%|██████    | 4751/7794 [04:28<02:52, 17.67it/s]

Batch 4750/7794:
total_loss: 0.37465614080429077


Epoch 2/5:  64%|██████▍   | 5001/7794 [04:42<02:36, 17.81it/s]

Batch 5000/7794:
total_loss: 0.4926343262195587


Epoch 2/5:  67%|██████▋   | 5251/7794 [04:56<02:23, 17.72it/s]

Batch 5250/7794:
total_loss: 0.40980419516563416


Epoch 2/5:  71%|███████   | 5501/7794 [05:10<02:10, 17.62it/s]

Batch 5500/7794:
total_loss: 0.4092652499675751


Epoch 2/5:  74%|███████▍  | 5751/7794 [05:24<01:56, 17.59it/s]

Batch 5750/7794:
total_loss: 0.5252406597137451


Epoch 2/5:  77%|███████▋  | 6001/7794 [05:39<01:41, 17.72it/s]

Batch 6000/7794:
total_loss: 0.5693560838699341


Epoch 2/5:  80%|████████  | 6251/7794 [05:53<01:27, 17.63it/s]

Batch 6250/7794:
total_loss: 0.4397949278354645


Epoch 2/5:  83%|████████▎ | 6501/7794 [06:07<01:13, 17.71it/s]

Batch 6500/7794:
total_loss: 0.4274633228778839


Epoch 2/5:  87%|████████▋ | 6751/7794 [06:21<00:59, 17.54it/s]

Batch 6750/7794:
total_loss: 0.7094650864601135


Epoch 2/5:  90%|████████▉ | 7001/7794 [06:35<00:44, 17.76it/s]

Batch 7000/7794:
total_loss: 0.5777230858802795


Epoch 2/5:  93%|█████████▎| 7251/7794 [06:49<00:30, 17.62it/s]

Batch 7250/7794:
total_loss: 0.5101160407066345


Epoch 2/5:  96%|█████████▌| 7501/7794 [07:03<00:16, 17.65it/s]

Batch 7500/7794:
total_loss: 0.4266366958618164


Epoch 2/5:  99%|█████████▉| 7751/7794 [07:18<00:02, 17.67it/s]

Batch 7750/7794:
total_loss: 0.3235415518283844


Epoch 2/5: 100%|██████████| 7794/7794 [07:20<00:00, 17.69it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 33.23it/s]



Epoch Summary:
Train Total Loss: 0.4620
Val Total Loss: 0.3669
Learning Rate: 0.000100
New best model saved with val loss: 0.3669
--------------------------------------------------


Epoch 3/5:   3%|▎         | 251/7794 [00:14<07:06, 17.68it/s]

Batch 250/7794:
total_loss: 0.4311765134334564


Epoch 3/5:   6%|▋         | 501/7794 [00:28<06:52, 17.67it/s]

Batch 500/7794:
total_loss: 0.3070276081562042


Epoch 3/5:  10%|▉         | 751/7794 [00:42<06:35, 17.81it/s]

Batch 750/7794:
total_loss: 0.5010106563568115


Epoch 3/5:  13%|█▎        | 1001/7794 [00:56<06:22, 17.74it/s]

Batch 1000/7794:
total_loss: 0.6703079342842102


Epoch 3/5:  16%|█▌        | 1251/7794 [01:10<06:10, 17.67it/s]

Batch 1250/7794:
total_loss: 0.42082729935646057


Epoch 3/5:  19%|█▉        | 1501/7794 [01:24<05:54, 17.75it/s]

Batch 1500/7794:
total_loss: 0.385206937789917


Epoch 3/5:  22%|██▏       | 1751/7794 [01:38<05:39, 17.78it/s]

Batch 1750/7794:
total_loss: 0.2857153117656708


Epoch 3/5:  26%|██▌       | 2001/7794 [01:52<05:25, 17.77it/s]

Batch 2000/7794:
total_loss: 0.643294632434845


Epoch 3/5:  29%|██▉       | 2251/7794 [02:07<05:13, 17.70it/s]

Batch 2250/7794:
total_loss: 0.4485008418560028


Epoch 3/5:  32%|███▏      | 2501/7794 [02:21<05:01, 17.58it/s]

Batch 2500/7794:
total_loss: 0.610541045665741


Epoch 3/5:  35%|███▌      | 2751/7794 [02:35<04:46, 17.58it/s]

Batch 2750/7794:
total_loss: 0.40324923396110535


Epoch 3/5:  39%|███▊      | 3001/7794 [02:49<04:31, 17.68it/s]

Batch 3000/7794:
total_loss: 0.4330770671367645


Epoch 3/5:  42%|████▏     | 3251/7794 [03:03<04:16, 17.68it/s]

Batch 3250/7794:
total_loss: 0.2668856680393219


Epoch 3/5:  45%|████▍     | 3501/7794 [03:17<04:03, 17.61it/s]

Batch 3500/7794:
total_loss: 0.564409077167511


Epoch 3/5:  48%|████▊     | 3751/7794 [03:31<03:47, 17.80it/s]

Batch 3750/7794:
total_loss: 0.5253268480300903


Epoch 3/5:  51%|█████▏    | 4001/7794 [03:46<03:34, 17.67it/s]

Batch 4000/7794:
total_loss: 0.5704588294029236


Epoch 3/5:  55%|█████▍    | 4251/7794 [04:00<03:20, 17.63it/s]

Batch 4250/7794:
total_loss: 0.47228285670280457


Epoch 3/5:  58%|█████▊    | 4501/7794 [04:14<03:06, 17.65it/s]

Batch 4500/7794:
total_loss: 0.48994702100753784


Epoch 3/5:  61%|██████    | 4751/7794 [04:28<02:52, 17.68it/s]

Batch 4750/7794:
total_loss: 0.4265270233154297


Epoch 3/5:  64%|██████▍   | 5001/7794 [04:42<02:37, 17.68it/s]

Batch 5000/7794:
total_loss: 0.3683379888534546


Epoch 3/5:  67%|██████▋   | 5251/7794 [04:56<02:23, 17.67it/s]

Batch 5250/7794:
total_loss: 0.3322710394859314


Epoch 3/5:  71%|███████   | 5501/7794 [05:10<02:10, 17.62it/s]

Batch 5500/7794:
total_loss: 0.35403677821159363


Epoch 3/5:  74%|███████▍  | 5751/7794 [05:24<01:55, 17.64it/s]

Batch 5750/7794:
total_loss: 0.5693978667259216


Epoch 3/5:  77%|███████▋  | 6001/7794 [05:39<01:41, 17.59it/s]

Batch 6000/7794:
total_loss: 0.49416542053222656


Epoch 3/5:  80%|████████  | 6251/7794 [05:53<01:28, 17.44it/s]

Batch 6250/7794:
total_loss: 0.2775140106678009


Epoch 3/5:  83%|████████▎ | 6501/7794 [06:07<01:13, 17.66it/s]

Batch 6500/7794:
total_loss: 0.523048996925354


Epoch 3/5:  87%|████████▋ | 6751/7794 [06:21<00:59, 17.49it/s]

Batch 6750/7794:
total_loss: 0.432698518037796


Epoch 3/5:  90%|████████▉ | 7001/7794 [06:35<00:45, 17.56it/s]

Batch 7000/7794:
total_loss: 0.27505019307136536


Epoch 3/5:  93%|█████████▎| 7251/7794 [06:49<00:30, 17.63it/s]

Batch 7250/7794:
total_loss: 0.3851536214351654


Epoch 3/5:  96%|█████████▌| 7501/7794 [07:04<00:16, 17.63it/s]

Batch 7500/7794:
total_loss: 0.28292545676231384


Epoch 3/5:  99%|█████████▉| 7751/7794 [07:18<00:02, 17.71it/s]

Batch 7750/7794:
total_loss: 0.37029948830604553


Epoch 3/5: 100%|██████████| 7794/7794 [07:20<00:00, 17.69it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 33.52it/s]



Epoch Summary:
Train Total Loss: 0.4025
Val Total Loss: 0.2994
Learning Rate: 0.000100
New best model saved with val loss: 0.2994
--------------------------------------------------


Epoch 4/5:   3%|▎         | 251/7794 [00:14<07:06, 17.68it/s]

Batch 250/7794:
total_loss: 0.33545979857444763


Epoch 4/5:   6%|▋         | 501/7794 [00:28<06:54, 17.60it/s]

Batch 500/7794:
total_loss: 0.2773694396018982


Epoch 4/5:  10%|▉         | 751/7794 [00:42<06:38, 17.68it/s]

Batch 750/7794:
total_loss: 0.41336575150489807


Epoch 4/5:  13%|█▎        | 1001/7794 [00:56<06:25, 17.64it/s]

Batch 1000/7794:
total_loss: 0.4911518692970276


Epoch 4/5:  16%|█▌        | 1251/7794 [01:10<06:12, 17.55it/s]

Batch 1250/7794:
total_loss: 0.4201255738735199


Epoch 4/5:  19%|█▉        | 1501/7794 [01:25<05:54, 17.74it/s]

Batch 1500/7794:
total_loss: 0.37399476766586304


Epoch 4/5:  22%|██▏       | 1751/7794 [01:39<05:42, 17.64it/s]

Batch 1750/7794:
total_loss: 0.26680415868759155


Epoch 4/5:  26%|██▌       | 2001/7794 [01:53<05:27, 17.68it/s]

Batch 2000/7794:
total_loss: 0.36984983086586


Epoch 4/5:  29%|██▉       | 2251/7794 [02:07<05:14, 17.62it/s]

Batch 2250/7794:
total_loss: 0.2986784279346466


Epoch 4/5:  32%|███▏      | 2501/7794 [02:21<04:58, 17.70it/s]

Batch 2500/7794:
total_loss: 0.27087533473968506


Epoch 4/5:  35%|███▌      | 2751/7794 [02:35<04:46, 17.61it/s]

Batch 2750/7794:
total_loss: 0.23219624161720276


Epoch 4/5:  39%|███▊      | 3001/7794 [02:49<04:30, 17.69it/s]

Batch 3000/7794:
total_loss: 0.2211928367614746


Epoch 4/5:  42%|████▏     | 3251/7794 [03:03<04:18, 17.58it/s]

Batch 3250/7794:
total_loss: 0.4725821614265442


Epoch 4/5:  45%|████▍     | 3501/7794 [03:18<04:03, 17.61it/s]

Batch 3500/7794:
total_loss: 0.33953362703323364


Epoch 4/5:  48%|████▊     | 3751/7794 [03:32<03:50, 17.58it/s]

Batch 3750/7794:
total_loss: 0.22926954925060272


Epoch 4/5:  51%|█████▏    | 4001/7794 [03:46<03:37, 17.46it/s]

Batch 4000/7794:
total_loss: 0.22325992584228516


Epoch 4/5:  55%|█████▍    | 4251/7794 [04:00<03:19, 17.72it/s]

Batch 4250/7794:
total_loss: 0.37449368834495544


Epoch 4/5:  58%|█████▊    | 4501/7794 [04:14<03:07, 17.61it/s]

Batch 4500/7794:
total_loss: 0.32956281304359436


Epoch 4/5:  61%|██████    | 4751/7794 [04:28<02:52, 17.65it/s]

Batch 4750/7794:
total_loss: 0.4278548061847687


Epoch 4/5:  64%|██████▍   | 5001/7794 [04:43<02:38, 17.62it/s]

Batch 5000/7794:
total_loss: 0.4098535478115082


Epoch 4/5:  67%|██████▋   | 5251/7794 [04:57<02:24, 17.62it/s]

Batch 5250/7794:
total_loss: 0.40436142683029175


Epoch 4/5:  71%|███████   | 5501/7794 [05:11<02:09, 17.69it/s]

Batch 5500/7794:
total_loss: 0.21059221029281616


Epoch 4/5:  74%|███████▍  | 5751/7794 [05:25<01:55, 17.73it/s]

Batch 5750/7794:
total_loss: 0.28528887033462524


Epoch 4/5:  77%|███████▋  | 6001/7794 [05:39<01:41, 17.58it/s]

Batch 6000/7794:
total_loss: 0.2939842641353607


Epoch 4/5:  80%|████████  | 6251/7794 [05:53<01:27, 17.71it/s]

Batch 6250/7794:
total_loss: 0.33765339851379395


Epoch 4/5:  83%|████████▎ | 6501/7794 [06:07<01:12, 17.73it/s]

Batch 6500/7794:
total_loss: 0.3082755208015442


Epoch 4/5:  87%|████████▋ | 6751/7794 [06:22<00:59, 17.66it/s]

Batch 6750/7794:
total_loss: 0.3715333640575409


Epoch 4/5:  90%|████████▉ | 7001/7794 [06:36<00:44, 17.62it/s]

Batch 7000/7794:
total_loss: 0.3587143123149872


Epoch 4/5:  93%|█████████▎| 7251/7794 [06:50<00:31, 17.49it/s]

Batch 7250/7794:
total_loss: 0.3646736741065979


Epoch 4/5:  96%|█████████▌| 7501/7794 [07:04<00:16, 17.56it/s]

Batch 7500/7794:
total_loss: 0.4415661096572876


Epoch 4/5:  99%|█████████▉| 7751/7794 [07:18<00:02, 17.56it/s]

Batch 7750/7794:
total_loss: 0.28538283705711365


Epoch 4/5: 100%|██████████| 7794/7794 [07:21<00:00, 17.66it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 32.87it/s]


Epoch Summary:
Train Total Loss: 0.3577
Val Total Loss: 0.3060
Learning Rate: 0.000100
--------------------------------------------------



Epoch 5/5:   3%|▎         | 251/7794 [00:14<07:06, 17.69it/s]

Batch 250/7794:
total_loss: 0.45593273639678955


Epoch 5/5:   6%|▋         | 501/7794 [00:28<06:49, 17.81it/s]

Batch 500/7794:
total_loss: 0.2640936076641083


Epoch 5/5:  10%|▉         | 751/7794 [00:42<06:38, 17.67it/s]

Batch 750/7794:
total_loss: 0.5390698313713074


Epoch 5/5:  13%|█▎        | 1001/7794 [00:56<06:26, 17.60it/s]

Batch 1000/7794:
total_loss: 0.18360091745853424


Epoch 5/5:  16%|█▌        | 1251/7794 [01:10<06:08, 17.75it/s]

Batch 1250/7794:
total_loss: 0.30243751406669617


Epoch 5/5:  19%|█▉        | 1501/7794 [01:24<05:55, 17.68it/s]

Batch 1500/7794:
total_loss: 0.30511561036109924


Epoch 5/5:  22%|██▏       | 1751/7794 [01:38<05:42, 17.64it/s]

Batch 1750/7794:
total_loss: 0.402237206697464


Epoch 5/5:  26%|██▌       | 2001/7794 [01:52<05:28, 17.61it/s]

Batch 2000/7794:
total_loss: 0.3117322623729706


Epoch 5/5:  29%|██▉       | 2251/7794 [02:07<05:12, 17.73it/s]

Batch 2250/7794:
total_loss: 0.2203267514705658


Epoch 5/5:  32%|███▏      | 2501/7794 [02:21<05:00, 17.64it/s]

Batch 2500/7794:
total_loss: 0.18605409562587738


Epoch 5/5:  35%|███▌      | 2751/7794 [02:35<04:43, 17.76it/s]

Batch 2750/7794:
total_loss: 0.17528551816940308


Epoch 5/5:  39%|███▊      | 3001/7794 [02:49<04:30, 17.69it/s]

Batch 3000/7794:
total_loss: 0.28981468081474304


Epoch 5/5:  42%|████▏     | 3251/7794 [03:03<04:16, 17.74it/s]

Batch 3250/7794:
total_loss: 0.30644509196281433


Epoch 5/5:  45%|████▍     | 3501/7794 [03:17<04:03, 17.62it/s]

Batch 3500/7794:
total_loss: 0.33486202359199524


Epoch 5/5:  48%|████▊     | 3751/7794 [03:31<03:48, 17.67it/s]

Batch 3750/7794:
total_loss: 0.259114146232605


Epoch 5/5:  51%|█████▏    | 4001/7794 [03:45<03:33, 17.73it/s]

Batch 4000/7794:
total_loss: 0.2655238211154938


Epoch 5/5:  55%|█████▍    | 4251/7794 [03:59<03:20, 17.67it/s]

Batch 4250/7794:
total_loss: 0.3943725824356079


Epoch 5/5:  58%|█████▊    | 4501/7794 [04:14<03:06, 17.64it/s]

Batch 4500/7794:
total_loss: 0.38045376539230347


Epoch 5/5:  61%|██████    | 4751/7794 [04:28<02:51, 17.70it/s]

Batch 4750/7794:
total_loss: 0.28380244970321655


Epoch 5/5:  64%|██████▍   | 5001/7794 [04:42<02:39, 17.56it/s]

Batch 5000/7794:
total_loss: 0.18494218587875366


Epoch 5/5:  67%|██████▋   | 5251/7794 [04:56<02:23, 17.73it/s]

Batch 5250/7794:
total_loss: 0.21959492564201355


Epoch 5/5:  71%|███████   | 5501/7794 [05:10<02:10, 17.61it/s]

Batch 5500/7794:
total_loss: 0.3479990065097809


Epoch 5/5:  74%|███████▍  | 5751/7794 [05:24<01:55, 17.62it/s]

Batch 5750/7794:
total_loss: 0.33242133259773254


Epoch 5/5:  77%|███████▋  | 6001/7794 [05:38<01:41, 17.64it/s]

Batch 6000/7794:
total_loss: 0.4794304072856903


Epoch 5/5:  80%|████████  | 6251/7794 [05:52<01:27, 17.72it/s]

Batch 6250/7794:
total_loss: 0.38104018568992615


Epoch 5/5:  83%|████████▎ | 6501/7794 [06:07<01:13, 17.59it/s]

Batch 6500/7794:
total_loss: 0.36106765270233154


Epoch 5/5:  87%|████████▋ | 6751/7794 [06:21<00:59, 17.65it/s]

Batch 6750/7794:
total_loss: 0.3501778244972229


Epoch 5/5:  90%|████████▉ | 7001/7794 [06:35<00:44, 17.68it/s]

Batch 7000/7794:
total_loss: 0.2216043919324875


Epoch 5/5:  93%|█████████▎| 7251/7794 [06:49<00:30, 17.66it/s]

Batch 7250/7794:
total_loss: 0.23670180141925812


Epoch 5/5:  96%|█████████▌| 7501/7794 [07:03<00:16, 17.70it/s]

Batch 7500/7794:
total_loss: 0.30171626806259155


Epoch 5/5:  99%|█████████▉| 7751/7794 [07:17<00:02, 17.69it/s]

Batch 7750/7794:
total_loss: 0.3648933172225952


Epoch 5/5: 100%|██████████| 7794/7794 [07:20<00:00, 17.70it/s]
Validating: 100%|██████████| 16/16 [00:00<00:00, 33.53it/s]



Epoch Summary:
Train Total Loss: 0.3181
Val Total Loss: 0.2718
Learning Rate: 0.000100
New best model saved with val loss: 0.2718
--------------------------------------------------
