<a href="https://colab.research.google.com/github/carolynw898/STAT946Proj/blob/main/stat946-proj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
from google.colab import drive

# Mount Google Drive inside the Colab VM so the dataset/checkpoint paths under it resolve.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import tqdm
from typing import Tuple
from models import SymbolicGaussianDiffusion, PointNetConfig
from utils import CharDataset, processDataFiles, tokenize_equation

def train_epoch(
    model: SymbolicGaussianDiffusion,
    train_loader: DataLoader,
    optimizer: Adam,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> float:
    """Run one training epoch and return the mean per-batch training loss.

    For each batch a diffusion timestep is drawn uniformly from
    ``[0, timesteps)`` for every example, the model's loss is computed and a
    single optimizer step is taken.

    Args:
        model: Model called as ``model(points, tokens, variables, t)``; the
            result is back-propagated, so it is expected to be a scalar loss
            tensor.
        train_loader: Yields ``(_, tokens, points, variables)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        train_dataset: Unused; kept so existing call sites keep working.
        timesteps: Upper bound (exclusive) for the sampled timesteps.
        device: Device every batch tensor is moved to.
        epoch: Zero-based epoch index, only used in the progress-bar label.
        num_epochs: Total number of epochs, only used in the progress-bar label.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot compute an average loss")

    model.train()
    total_train_loss = 0.0

    progress = tqdm.tqdm(
        enumerate(train_loader),
        total=num_batches,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    )
    for i, (_, tokens, points, variables) in progress:
        points = points.to(device)
        tokens = tokens.to(device)
        variables = variables.to(device)
        t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)

        optimizer.zero_grad()
        total_loss = model(points, tokens, variables, t)
        total_loss.backward()
        optimizer.step()

        # Read the scalar once: .item() synchronizes with the device, and the
        # plain float also prints cleanly (no tensor repr / grad_fn noise).
        loss_value = total_loss.item()
        total_train_loss += loss_value

        if (i + 1) % 250 == 0:
            print(f"Batch {i + 1}/{num_batches}:")
            print(f"total_loss: {loss_value}")

    return total_train_loss / num_batches


def val_epoch(
    model: SymbolicGaussianDiffusion,
    val_loader: DataLoader,
    train_dataset: CharDataset,
    timesteps: int,
    device: torch.device,
    epoch: int,
    num_epochs: int,
) -> float:
    """Evaluate the model on the validation loader and return the mean per-batch loss.

    Runs in eval mode without gradient tracking. Timesteps are sampled at
    random per example (as in training), so the returned value is a
    stochastic estimate that can differ between calls with identical weights.

    Args:
        model: Model called as ``model(points, tokens, variables, t)``; its
            result is converted with ``.item()``.
        val_loader: Yields ``(_, tokens, points, variables)`` batches.
        train_dataset: Unused; kept so existing call sites keep working.
        timesteps: Upper bound (exclusive) for the sampled timesteps.
        device: Device every batch tensor is moved to.
        epoch: Unused; kept for call-site compatibility.
        num_epochs: Unused; kept for call-site compatibility.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``val_loader`` contains no batches.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot compute an average loss")

    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for _, tokens, points, variables in tqdm.tqdm(
            val_loader, total=num_batches, desc="Validating"
        ):
            points = points.to(device)
            tokens = tokens.to(device)
            variables = variables.to(device)
            t = torch.randint(0, timesteps, (tokens.shape[0],), device=device)
            total_val_loss += model(points, tokens, variables, t).item()

    return total_val_loss / num_batches


def train_single_gpu(
    model: SymbolicGaussianDiffusion,
    train_dataset: CharDataset,
    val_dataset: CharDataset,
    num_epochs=10,
    save_every=2,
    batch_size=32,
    timesteps=1000,
    learning_rate=1e-3,
    path=None,
):
    """Train ``model`` on a single device, checkpointing the lowest-val-loss weights.

    Optimizes with Adam. A ``ReduceLROnPlateau`` scheduler (factor 0.5,
    patience 1) is stepped on the validation loss after every epoch.

    Args:
        model: Model to train; moved to CUDA when available, otherwise CPU.
        train_dataset: Training examples (shuffled by the DataLoader).
        val_dataset: Validation examples (iterated in order).
        num_epochs: Number of training epochs.
        save_every: Currently unused; accepted for backward compatibility.
        batch_size: Batch size for both the training and validation loaders.
        timesteps: Upper bound (exclusive) for sampled diffusion timesteps.
        learning_rate: Initial Adam learning rate.
        path: Destination for ``torch.save(model.state_dict(), path)``
            whenever the validation loss improves. If ``None``, no
            checkpoint is written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

    # Pinned host memory only speeds up host->GPU copies; on CPU it is useless.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=True,
        num_workers=4,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=False,
        num_workers=4,
    )

    if path is None:
        print("No checkpoint path given; the best model will not be saved.")

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            train_dataset,
            timesteps,
            device,
            epoch,
            num_epochs,
        )

        avg_val_loss = val_epoch(
            model, val_loader, train_dataset, timesteps, device, epoch, num_epochs
        )

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print("\nEpoch Summary:")
        print(f"Train Total Loss: {avg_train_loss:.4f}")
        print(f"Val Total Loss: {avg_val_loss:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if path is not None:
                torch.save(model.state_dict(), path)
                print(f"New best model saved with val loss: {best_val_loss:.4f}")
            else:
                print(f"New best val loss (not saved): {best_val_loss:.4f}")

        print("-" * 50)

In [15]:
# Embedding width (passed to both PointNetConfig and the diffusion model)
# and the number of diffusion timesteps.
n_embd, timesteps = 512, 1000

# Optimisation settings.
batch_size, learning_rate, num_epochs = 64, 1e-4, 5

# Sequence / point-set layout handed to CharDataset and PointNetConfig.
blockSize = 32
numVars, numYs, numPoints = 3, 1, 250

# Dataset construction options.
target = 'Skeleton'
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]
decimals = 8
addVars = False
maxNumFiles = 100

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
# Root of the generated datasets on the mounted Drive, and the sub-folder to use.
dataDir, dataFolder = "/content/drive/MyDrive/Colab/STAT946_proj/data", "3_var_dataset"

In [17]:
from torch.utils.data import DataLoader
import numpy as np
import glob
from utils import processDataFiles, CharDataset, tokenize_equation
import random
import json

# Collect the training files in sorted order: glob returns them in a
# filesystem-dependent order, so slicing an unsorted list would pick a
# different subset of files on different machines.
path = '{}/{}/Train/*.json'.format(dataDir, dataFolder)
files = sorted(glob.glob(path))[:maxNumFiles]
if not files:
    raise FileNotFoundError('No training files matched {}'.format(path))
text = processDataFiles(files)
text = text.split('\n')  # convert the raw text to a list of examples, one per line
# Drop every blank line (not just a trailing one) so the vocabulary and the
# dataset are built from exactly the same examples.
trainText = [item for item in text if item.strip()]

# Build the vocabulary from the same JSON field that CharDataset receives as `target`.
skeletons = [json.loads(item)[target] for item in trainText]
all_tokens = set()
for eq in skeletons:
    all_tokens.update(tokenize_equation(eq))
integers = {str(d) for d in range(10)}
all_tokens.update(integers)  # make sure every digit is in the vocabulary
# Union (not list concatenation) so a special token that also appears in the
# equations cannot be added to the vocabulary twice.
tokens = sorted(all_tokens | {'_', 'T', '<', '>', ':'})

# Shuffle the dataset; this is especially important for the combined
# number-of-variables experiment.
random.shuffle(trainText)
train_dataset = CharDataset(trainText, blockSize, tokens=tokens, numVars=numVars,
                            numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                            const_range=const_range, xRange=trainRange, decimals=decimals)

# Decode one random example as a sanity check of the tokenization.
idx = np.random.randint(len(train_dataset))
inputs, outputs, points, variables = train_dataset[idx]
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx, outputs, variables))

data has 498101 examples, 30 unique.
id:391541
outputs:C*cos(C*x2**5)+C>___________________
variables:2


In [18]:
# Validation uses a single file; sort first so the same file is chosen on
# every machine (glob order is filesystem dependent).
path = '{}/{}/Val/*.json'.format(dataDir, dataFolder)
files = sorted(glob.glob(path))
if not files:
    raise FileNotFoundError('No validation files matched {}'.format(path))
textVal = processDataFiles([files[0]])
# Same preprocessing as the training split: one example per line, blank lines removed.
textVal = [item for item in textVal.split('\n') if item.strip()]
val_dataset = CharDataset(textVal, blockSize, tokens=tokens, numVars=numVars,
                          numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                          const_range=const_range, xRange=trainRange, decimals=decimals)

# Decode a random validation sample with the shared training vocabulary.
idx = np.random.randint(len(val_dataset))
inputs, outputs, points, variables = val_dataset[idx]
print(points.min(), points.max())
inputs = ''.join([train_dataset.itos[int(i)] for i in inputs])
outputs = ''.join([train_dataset.itos[int(i)] for i in outputs])
print('id:{}\noutputs:{}\nvariables:{}'.format(idx, outputs, variables))

data has 948 examples, 30 unique.
tensor(-308.6018) tensor(16.9162)
id:62
outputs:C*x1*C*sin(C*x2+C)*C*cos(C*x1+C)/x2+C*x2+
variables:3


In [19]:
# Point-set encoder configuration; its embedding size matches the model's n_embd.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

# All constructor arguments for the diffusion model, gathered in one mapping.
model_kwargs = dict(
    tnet_config=pconfig,
    vocab_size=train_dataset.vocab_size,
    max_seq_len=blockSize,
    padding_idx=train_dataset.paddingID,
    # NOTE(review): 9 differs from numVars (3); presumably the model's maximum
    # variable capacity rather than this dataset's count — confirm.
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
    set_transformer=True,
)
model = SymbolicGaussianDiffusion(**model_kwargs)

# Where the best (lowest validation loss) weights are written.
checkpoint_path = "/content/drive/MyDrive/Colab/STAT946_proj/models/stable_diffusym/3_var_set_transformer_sd.pth"

train_single_gpu(
    model,
    train_dataset,
    val_dataset,
    num_epochs=num_epochs,
    save_every=2,
    batch_size=batch_size,
    timesteps=timesteps,
    learning_rate=learning_rate,
    path=checkpoint_path,
)

Epoch 1/5:   3%|▎         | 251/7783 [00:27<12:48,  9.81it/s]

Batch 250/7783:
total_loss: 2.487004280090332


Epoch 1/5:   6%|▋         | 500/7783 [00:54<12:11,  9.96it/s]

Batch 500/7783:
total_loss: 1.139195442199707


Epoch 1/5:  10%|▉         | 752/7783 [01:21<11:10, 10.49it/s]

Batch 750/7783:
total_loss: 1.156992793083191


Epoch 1/5:  13%|█▎        | 1001/7783 [01:47<11:41,  9.67it/s]

Batch 1000/7783:
total_loss: 1.1037955284118652


Epoch 1/5:  16%|█▌        | 1251/7783 [02:14<10:33, 10.31it/s]

Batch 1250/7783:
total_loss: 0.7133010625839233


Epoch 1/5:  19%|█▉        | 1499/7783 [02:40<10:05, 10.37it/s]

Batch 1500/7783:
total_loss: 0.8165925145149231


Epoch 1/5:  22%|██▏       | 1751/7783 [03:07<09:49, 10.23it/s]

Batch 1750/7783:
total_loss: 0.8534173965454102


Epoch 1/5:  26%|██▌       | 2001/7783 [03:34<09:38,  9.99it/s]

Batch 2000/7783:
total_loss: 0.6771409511566162


Epoch 1/5:  29%|██▉       | 2250/7783 [04:00<09:20,  9.87it/s]

Batch 2250/7783:
total_loss: 0.7251788973808289


Epoch 1/5:  32%|███▏      | 2501/7783 [04:27<08:46, 10.04it/s]

Batch 2500/7783:
total_loss: 0.4326946437358856


Epoch 1/5:  35%|███▌      | 2751/7783 [04:54<09:22,  8.94it/s]

Batch 2750/7783:
total_loss: 0.6554166078567505


Epoch 1/5:  39%|███▊      | 3000/7783 [05:20<07:46, 10.26it/s]

Batch 3000/7783:
total_loss: 0.45467954874038696


Epoch 1/5:  42%|████▏     | 3251/7783 [05:48<08:32,  8.84it/s]

Batch 3250/7783:
total_loss: 0.7153379917144775


Epoch 1/5:  45%|████▍     | 3499/7783 [06:14<07:31,  9.49it/s]

Batch 3500/7783:
total_loss: 0.538017749786377


Epoch 1/5:  48%|████▊     | 3751/7783 [06:41<06:25, 10.47it/s]

Batch 3750/7783:
total_loss: 0.46045127511024475


Epoch 1/5:  51%|█████▏    | 3999/7783 [07:08<06:46,  9.32it/s]

Batch 4000/7783:
total_loss: 0.7207744121551514


Epoch 1/5:  55%|█████▍    | 4252/7783 [07:34<05:44, 10.24it/s]

Batch 4250/7783:
total_loss: 0.5130143761634827


Epoch 1/5:  58%|█████▊    | 4501/7783 [08:01<05:22, 10.19it/s]

Batch 4500/7783:
total_loss: 0.4884134829044342


Epoch 1/5:  61%|██████    | 4749/7783 [08:28<04:57, 10.19it/s]

Batch 4750/7783:
total_loss: 0.6816918849945068


Epoch 1/5:  64%|██████▍   | 5001/7783 [08:55<04:27, 10.41it/s]

Batch 5000/7783:
total_loss: 0.5933822393417358


Epoch 1/5:  67%|██████▋   | 5250/7783 [09:21<04:08, 10.20it/s]

Batch 5250/7783:
total_loss: 0.5761051177978516


Epoch 1/5:  71%|███████   | 5501/7783 [09:48<03:43, 10.21it/s]

Batch 5500/7783:
total_loss: 0.2744663655757904


Epoch 1/5:  74%|███████▍  | 5749/7783 [10:14<03:21, 10.11it/s]

Batch 5750/7783:
total_loss: 0.4002000093460083


Epoch 1/5:  77%|███████▋  | 5999/7783 [10:42<02:55, 10.14it/s]

Batch 6000/7783:
total_loss: 0.41747236251831055


Epoch 1/5:  80%|████████  | 6251/7783 [11:09<02:27, 10.38it/s]

Batch 6250/7783:
total_loss: 0.8041149377822876


Epoch 1/5:  84%|████████▎ | 6502/7783 [11:36<02:01, 10.51it/s]

Batch 6500/7783:
total_loss: 0.7349252700805664


Epoch 1/5:  87%|████████▋ | 6750/7783 [12:02<01:47,  9.57it/s]

Batch 6750/7783:
total_loss: 0.7058141231536865


Epoch 1/5:  90%|████████▉ | 7002/7783 [12:29<01:15, 10.31it/s]

Batch 7000/7783:
total_loss: 0.6545202732086182


Epoch 1/5:  93%|█████████▎| 7250/7783 [12:56<00:51, 10.33it/s]

Batch 7250/7783:
total_loss: 0.5732424259185791


Epoch 1/5:  96%|█████████▋| 7502/7783 [13:22<00:27, 10.34it/s]

Batch 7500/7783:
total_loss: 0.516656219959259


Epoch 1/5: 100%|█████████▉| 7750/7783 [13:49<00:03, 10.42it/s]

Batch 7750/7783:
total_loss: 0.39046618342399597


Epoch 1/5: 100%|██████████| 7783/7783 [13:53<00:00,  9.33it/s]
Validating: 100%|██████████| 15/15 [00:02<00:00,  5.07it/s]


Epoch Summary:
Train Total Loss: 0.7783
Val Total Loss: 0.5273
Learning Rate: 0.000100
New best model saved with val loss: 0.5273
--------------------------------------------------



Epoch 2/5:   3%|▎         | 252/7783 [00:27<12:16, 10.23it/s]

Batch 250/7783:
total_loss: 0.6174187660217285


Epoch 2/5:   6%|▋         | 500/7783 [00:54<13:22,  9.08it/s]

Batch 500/7783:
total_loss: 0.5130387544631958


Epoch 2/5:  10%|▉         | 752/7783 [01:21<11:18, 10.36it/s]

Batch 750/7783:
total_loss: 0.5422541499137878


Epoch 2/5:  13%|█▎        | 1000/7783 [01:48<11:06, 10.18it/s]

Batch 1000/7783:
total_loss: 0.5579790472984314


Epoch 2/5:  16%|█▌        | 1251/7783 [02:15<12:27,  8.74it/s]

Batch 1250/7783:
total_loss: 0.49765831232070923


Epoch 2/5:  19%|█▉        | 1501/7783 [02:41<09:58, 10.49it/s]

Batch 1500/7783:
total_loss: 0.460420161485672


Epoch 2/5:  23%|██▎       | 1752/7783 [03:08<10:04,  9.97it/s]

Batch 1750/7783:
total_loss: 0.4868209958076477


Epoch 2/5:  26%|██▌       | 2000/7783 [03:35<09:18, 10.35it/s]

Batch 2000/7783:
total_loss: 0.7166166305541992


Epoch 2/5:  29%|██▉       | 2252/7783 [04:02<08:51, 10.41it/s]

Batch 2250/7783:
total_loss: 0.6488161683082581


Epoch 2/5:  32%|███▏      | 2500/7783 [04:29<08:34, 10.27it/s]

Batch 2500/7783:
total_loss: 0.48197680711746216


Epoch 2/5:  35%|███▌      | 2752/7783 [04:56<08:13, 10.20it/s]

Batch 2750/7783:
total_loss: 0.6194902062416077


Epoch 2/5:  39%|███▊      | 3000/7783 [05:23<07:51, 10.14it/s]

Batch 3000/7783:
total_loss: 0.5945083498954773


Epoch 2/5:  42%|████▏     | 3251/7783 [05:50<07:28, 10.10it/s]

Batch 3250/7783:
total_loss: 0.7003620862960815


Epoch 2/5:  45%|████▍     | 3500/7783 [06:17<07:30,  9.50it/s]

Batch 3500/7783:
total_loss: 0.37887275218963623


Epoch 2/5:  48%|████▊     | 3752/7783 [06:44<06:29, 10.34it/s]

Batch 3750/7783:
total_loss: 0.630586564540863


Epoch 2/5:  51%|█████▏    | 4002/7783 [07:11<06:15, 10.07it/s]

Batch 4000/7783:
total_loss: 0.7384699583053589


Epoch 2/5:  55%|█████▍    | 4251/7783 [07:37<07:09,  8.23it/s]

Batch 4250/7783:
total_loss: 0.5160409808158875


Epoch 2/5:  58%|█████▊    | 4501/7783 [08:05<05:24, 10.13it/s]

Batch 4500/7783:
total_loss: 0.6033084392547607


Epoch 2/5:  61%|██████    | 4751/7783 [08:32<05:40,  8.91it/s]

Batch 4750/7783:
total_loss: 0.6495813131332397


Epoch 2/5:  64%|██████▍   | 5000/7783 [08:58<04:43,  9.82it/s]

Batch 5000/7783:
total_loss: 0.23567599058151245


Epoch 2/5:  67%|██████▋   | 5252/7783 [09:26<04:07, 10.23it/s]

Batch 5250/7783:
total_loss: 0.40380895137786865


Epoch 2/5:  71%|███████   | 5501/7783 [09:53<03:39, 10.37it/s]

Batch 5500/7783:
total_loss: 0.626427412033081


Epoch 2/5:  74%|███████▍  | 5752/7783 [10:20<03:14, 10.42it/s]

Batch 5750/7783:
total_loss: 0.47742366790771484


Epoch 2/5:  77%|███████▋  | 6002/7783 [10:47<03:05,  9.58it/s]

Batch 6000/7783:
total_loss: 0.2708381116390228


Epoch 2/5:  80%|████████  | 6251/7783 [11:14<02:29, 10.22it/s]

Batch 6250/7783:
total_loss: 0.3851410448551178


Epoch 2/5:  84%|████████▎ | 6499/7783 [11:41<02:05, 10.24it/s]

Batch 6500/7783:
total_loss: 0.39663615822792053


Epoch 2/5:  87%|████████▋ | 6749/7783 [12:08<01:42, 10.09it/s]

Batch 6750/7783:
total_loss: 0.42378655076026917


Epoch 2/5:  90%|████████▉ | 7001/7783 [12:36<01:18,  9.94it/s]

Batch 7000/7783:
total_loss: 0.309518426656723


Epoch 2/5:  93%|█████████▎| 7249/7783 [13:02<00:53, 10.04it/s]

Batch 7250/7783:
total_loss: 0.325260728597641


Epoch 2/5:  96%|█████████▋| 7501/7783 [13:29<00:27, 10.30it/s]

Batch 7500/7783:
total_loss: 0.5203643441200256


Epoch 2/5: 100%|█████████▉| 7749/7783 [13:56<00:03, 10.03it/s]

Batch 7750/7783:
total_loss: 0.5733766555786133


Epoch 2/5: 100%|██████████| 7783/7783 [14:00<00:00,  9.26it/s]
Validating: 100%|██████████| 15/15 [00:02<00:00,  5.00it/s]


Epoch Summary:
Train Total Loss: 0.5154
Val Total Loss: 0.4498
Learning Rate: 0.000100
New best model saved with val loss: 0.4498
--------------------------------------------------



Epoch 3/5:   3%|▎         | 252/7783 [00:27<12:09, 10.33it/s]

Batch 250/7783:
total_loss: 0.4803950786590576


Epoch 3/5:   6%|▋         | 500/7783 [00:53<12:01, 10.10it/s]

Batch 500/7783:
total_loss: 0.6825218796730042


Epoch 3/5:  10%|▉         | 752/7783 [01:21<11:36, 10.09it/s]

Batch 750/7783:
total_loss: 0.6157584190368652


Epoch 3/5:  13%|█▎        | 1000/7783 [01:47<11:25,  9.89it/s]

Batch 1000/7783:
total_loss: 0.517510175704956


Epoch 3/5:  16%|█▌        | 1252/7783 [02:15<10:45, 10.12it/s]

Batch 1250/7783:
total_loss: 0.3378576636314392


Epoch 3/5:  19%|█▉        | 1500/7783 [02:41<10:07, 10.35it/s]

Batch 1500/7783:
total_loss: 0.41662636399269104


Epoch 3/5:  22%|██▏       | 1751/7783 [03:08<09:57, 10.10it/s]

Batch 1750/7783:
total_loss: 0.4612868130207062


Epoch 3/5:  26%|██▌       | 2001/7783 [03:36<09:38, 10.00it/s]

Batch 2000/7783:
total_loss: 0.6557658314704895


Epoch 3/5:  29%|██▉       | 2249/7783 [04:03<09:12, 10.02it/s]

Batch 2250/7783:
total_loss: 0.8163413405418396


Epoch 3/5:  32%|███▏      | 2501/7783 [04:29<08:43, 10.08it/s]

Batch 2500/7783:
total_loss: 0.556919276714325


Epoch 3/5:  35%|███▌      | 2751/7783 [04:56<09:49,  8.53it/s]

Batch 2750/7783:
total_loss: 0.5780665278434753


Epoch 3/5:  39%|███▊      | 3001/7783 [05:24<08:08,  9.79it/s]

Batch 3000/7783:
total_loss: 0.6945744156837463


Epoch 3/5:  42%|████▏     | 3251/7783 [05:51<08:45,  8.63it/s]

Batch 3250/7783:
total_loss: 0.4697644114494324


Epoch 3/5:  45%|████▍     | 3502/7783 [06:18<07:01, 10.16it/s]

Batch 3500/7783:
total_loss: 0.45015716552734375


Epoch 3/5:  48%|████▊     | 3752/7783 [06:45<06:39, 10.09it/s]

Batch 3750/7783:
total_loss: 0.3002169728279114


Epoch 3/5:  51%|█████▏    | 4001/7783 [07:13<06:15, 10.07it/s]

Batch 4000/7783:
total_loss: 0.5114618539810181


Epoch 3/5:  55%|█████▍    | 4252/7783 [07:40<05:47, 10.17it/s]

Batch 4250/7783:
total_loss: 0.36421725153923035


Epoch 3/5:  58%|█████▊    | 4500/7783 [08:07<05:23, 10.16it/s]

Batch 4500/7783:
total_loss: 0.4007834196090698


Epoch 3/5:  61%|██████    | 4752/7783 [08:34<05:12,  9.70it/s]

Batch 4750/7783:
total_loss: 0.4172327518463135


Epoch 3/5:  64%|██████▍   | 5000/7783 [09:01<04:28, 10.35it/s]

Batch 5000/7783:
total_loss: 0.555371880531311


Epoch 3/5:  67%|██████▋   | 5252/7783 [09:28<04:21,  9.67it/s]

Batch 5250/7783:
total_loss: 0.5435057878494263


Epoch 3/5:  71%|███████   | 5500/7783 [09:56<04:07,  9.21it/s]

Batch 5500/7783:
total_loss: 0.5160278677940369


Epoch 3/5:  74%|███████▍  | 5752/7783 [10:23<03:15, 10.37it/s]

Batch 5750/7783:
total_loss: 0.6242573857307434


Epoch 3/5:  77%|███████▋  | 6000/7783 [10:49<03:10,  9.36it/s]

Batch 6000/7783:
total_loss: 0.3696134388446808


Epoch 3/5:  80%|████████  | 6250/7783 [11:16<02:30, 10.18it/s]

Batch 6250/7783:
total_loss: 0.35211122035980225


Epoch 3/5:  84%|████████▎ | 6502/7783 [11:43<02:03, 10.39it/s]

Batch 6500/7783:
total_loss: 0.4562491178512573


Epoch 3/5:  87%|████████▋ | 6750/7783 [12:09<01:44,  9.85it/s]

Batch 6750/7783:
total_loss: 0.3596145510673523


Epoch 3/5:  90%|████████▉ | 7002/7783 [12:36<01:14, 10.45it/s]

Batch 7000/7783:
total_loss: 0.41959348320961


Epoch 3/5:  93%|█████████▎| 7250/7783 [13:03<00:52, 10.12it/s]

Batch 7250/7783:
total_loss: 0.42669105529785156


Epoch 3/5:  96%|█████████▋| 7500/7783 [13:30<00:27, 10.20it/s]

Batch 7500/7783:
total_loss: 0.38575175404548645


Epoch 3/5: 100%|█████████▉| 7750/7783 [13:57<00:03, 10.07it/s]

Batch 7750/7783:
total_loss: 0.4041978716850281


Epoch 3/5: 100%|██████████| 7783/7783 [14:01<00:00,  9.25it/s]
Validating: 100%|██████████| 15/15 [00:02<00:00,  5.46it/s]


Epoch Summary:
Train Total Loss: 0.4515
Val Total Loss: 0.3829
Learning Rate: 0.000100
New best model saved with val loss: 0.3829
--------------------------------------------------



Epoch 4/5:   3%|▎         | 250/7783 [00:27<13:03,  9.62it/s]

Batch 250/7783:
total_loss: 0.5813769102096558


Epoch 4/5:   6%|▋         | 499/7783 [00:54<11:48, 10.28it/s]

Batch 500/7783:
total_loss: 0.4501926898956299


Epoch 4/5:  10%|▉         | 751/7783 [01:21<11:26, 10.25it/s]

Batch 750/7783:
total_loss: 0.37848398089408875


Epoch 4/5:  13%|█▎        | 999/7783 [01:48<10:56, 10.34it/s]

Batch 1000/7783:
total_loss: 0.5006594061851501


Epoch 4/5:  16%|█▌        | 1252/7783 [02:15<10:45, 10.11it/s]

Batch 1250/7783:
total_loss: 0.3449082672595978


Epoch 4/5:  19%|█▉        | 1500/7783 [02:42<10:27, 10.01it/s]

Batch 1500/7783:
total_loss: 0.3793310821056366


Epoch 4/5:  23%|██▎       | 1752/7783 [03:09<10:06,  9.94it/s]

Batch 1750/7783:
total_loss: 0.21813395619392395


Epoch 4/5:  26%|██▌       | 2000/7783 [03:36<09:18, 10.35it/s]

Batch 2000/7783:
total_loss: 0.4977359473705292


Epoch 4/5:  29%|██▉       | 2252/7783 [04:03<08:55, 10.32it/s]

Batch 2250/7783:
total_loss: 0.24689166247844696


Epoch 4/5:  32%|███▏      | 2500/7783 [04:30<08:43, 10.09it/s]

Batch 2500/7783:
total_loss: 0.3125508427619934


Epoch 4/5:  35%|███▌      | 2752/7783 [04:57<08:05, 10.37it/s]

Batch 2750/7783:
total_loss: 0.30993103981018066


Epoch 4/5:  39%|███▊      | 3000/7783 [05:24<07:39, 10.40it/s]

Batch 3000/7783:
total_loss: 0.2908690571784973


Epoch 4/5:  42%|████▏     | 3252/7783 [05:51<07:55,  9.54it/s]

Batch 3250/7783:
total_loss: 0.40808212757110596


Epoch 4/5:  45%|████▍     | 3500/7783 [06:18<07:03, 10.12it/s]

Batch 3500/7783:
total_loss: 0.37431609630584717


Epoch 4/5:  48%|████▊     | 3752/7783 [06:45<06:45,  9.93it/s]

Batch 3750/7783:
total_loss: 0.42086029052734375


Epoch 4/5:  51%|█████▏    | 4000/7783 [07:12<06:04, 10.39it/s]

Batch 4000/7783:
total_loss: 0.4888419210910797


Epoch 4/5:  55%|█████▍    | 4249/7783 [07:39<05:47, 10.18it/s]

Batch 4250/7783:
total_loss: 0.5705952048301697


Epoch 4/5:  58%|█████▊    | 4501/7783 [08:06<05:24, 10.11it/s]

Batch 4500/7783:
total_loss: 0.4607667624950409


Epoch 4/5:  61%|██████    | 4749/7783 [08:33<05:02, 10.04it/s]

Batch 4750/7783:
total_loss: 0.42896702885627747


Epoch 4/5:  64%|██████▍   | 5001/7783 [09:00<04:30, 10.27it/s]

Batch 5000/7783:
total_loss: 0.32409995794296265


Epoch 4/5:  67%|██████▋   | 5249/7783 [09:27<04:23,  9.61it/s]

Batch 5250/7783:
total_loss: 0.542652428150177


Epoch 4/5:  71%|███████   | 5502/7783 [09:54<03:42, 10.27it/s]

Batch 5500/7783:
total_loss: 0.43653494119644165


Epoch 4/5:  74%|███████▍  | 5750/7783 [10:21<03:25,  9.88it/s]

Batch 5750/7783:
total_loss: 0.441150963306427


Epoch 4/5:  77%|███████▋  | 6001/7783 [10:48<02:53, 10.29it/s]

Batch 6000/7783:
total_loss: 0.3121704161167145


Epoch 4/5:  80%|████████  | 6250/7783 [11:15<02:30, 10.17it/s]

Batch 6250/7783:
total_loss: 0.36377546191215515


Epoch 4/5:  84%|████████▎ | 6502/7783 [11:42<02:04, 10.25it/s]

Batch 6500/7783:
total_loss: 0.4202072024345398


Epoch 4/5:  87%|████████▋ | 6750/7783 [12:09<01:43, 10.00it/s]

Batch 6750/7783:
total_loss: 0.3390291631221771


Epoch 4/5:  90%|████████▉ | 7001/7783 [12:36<01:29,  8.73it/s]

Batch 7000/7783:
total_loss: 0.42732465267181396


Epoch 4/5:  93%|█████████▎| 7249/7783 [13:03<00:51, 10.36it/s]

Batch 7250/7783:
total_loss: 0.4019427001476288


Epoch 4/5:  96%|█████████▋| 7502/7783 [13:30<00:27, 10.25it/s]

Batch 7500/7783:
total_loss: 0.535003662109375


Epoch 4/5: 100%|█████████▉| 7750/7783 [13:56<00:03, 10.22it/s]

Batch 7750/7783:
total_loss: 0.4573497772216797


Epoch 4/5: 100%|██████████| 7783/7783 [14:00<00:00,  9.26it/s]
Validating: 100%|██████████| 15/15 [00:02<00:00,  5.54it/s]


Epoch Summary:
Train Total Loss: 0.4067
Val Total Loss: 0.4036
Learning Rate: 0.000100
--------------------------------------------------



Epoch 5/5:   3%|▎         | 251/7783 [00:27<12:12, 10.29it/s]

Batch 250/7783:
total_loss: 0.3088120222091675


Epoch 5/5:   6%|▋         | 499/7783 [00:53<11:38, 10.43it/s]

Batch 500/7783:
total_loss: 0.5017609000205994


Epoch 5/5:  10%|▉         | 751/7783 [01:20<11:39, 10.05it/s]

Batch 750/7783:
total_loss: 0.25263720750808716


Epoch 5/5:  13%|█▎        | 999/7783 [01:47<10:49, 10.45it/s]

Batch 1000/7783:
total_loss: 0.3914138674736023


Epoch 5/5:  16%|█▌        | 1249/7783 [02:13<10:38, 10.24it/s]

Batch 1250/7783:
total_loss: 0.47338658571243286


Epoch 5/5:  19%|█▉        | 1501/7783 [02:40<10:17, 10.17it/s]

Batch 1500/7783:
total_loss: 0.32836511731147766


Epoch 5/5:  22%|██▏       | 1751/7783 [03:08<10:13,  9.84it/s]

Batch 1750/7783:
total_loss: 0.3774169385433197


Epoch 5/5:  26%|██▌       | 2001/7783 [03:35<11:06,  8.67it/s]

Batch 2000/7783:
total_loss: 0.2512301802635193


Epoch 5/5:  29%|██▉       | 2250/7783 [04:01<09:23,  9.82it/s]

Batch 2250/7783:
total_loss: 0.2536848187446594


Epoch 5/5:  32%|███▏      | 2500/7783 [04:28<08:47, 10.01it/s]

Batch 2500/7783:
total_loss: 0.40198156237602234


Epoch 5/5:  35%|███▌      | 2752/7783 [04:55<08:07, 10.31it/s]

Batch 2750/7783:
total_loss: 0.387783944606781


Epoch 5/5:  39%|███▊      | 3002/7783 [05:22<08:05,  9.85it/s]

Batch 3000/7783:
total_loss: 0.3073616027832031


Epoch 5/5:  42%|████▏     | 3250/7783 [05:49<07:18, 10.33it/s]

Batch 3250/7783:
total_loss: 0.29058846831321716


Epoch 5/5:  45%|████▍     | 3502/7783 [06:16<07:06, 10.04it/s]

Batch 3500/7783:
total_loss: 0.40234777331352234


Epoch 5/5:  48%|████▊     | 3750/7783 [06:43<07:23,  9.08it/s]

Batch 3750/7783:
total_loss: 0.5115523934364319


Epoch 5/5:  51%|█████▏    | 4002/7783 [07:10<06:43,  9.38it/s]

Batch 4000/7783:
total_loss: 0.33954551815986633


Epoch 5/5:  55%|█████▍    | 4250/7783 [07:37<06:06,  9.65it/s]

Batch 4250/7783:
total_loss: 0.26551955938339233


Epoch 5/5:  58%|█████▊    | 4501/7783 [08:04<05:24, 10.10it/s]

Batch 4500/7783:
total_loss: 0.2739455997943878


Epoch 5/5:  61%|██████    | 4752/7783 [08:31<04:51, 10.40it/s]

Batch 4750/7783:
total_loss: 0.33426621556282043


Epoch 5/5:  64%|██████▍   | 5000/7783 [08:58<04:35, 10.12it/s]

Batch 5000/7783:
total_loss: 0.498263418674469


Epoch 5/5:  67%|██████▋   | 5252/7783 [09:25<04:13,  9.97it/s]

Batch 5250/7783:
total_loss: 0.21309711039066315


Epoch 5/5:  71%|███████   | 5500/7783 [09:52<03:38, 10.46it/s]

Batch 5500/7783:
total_loss: 0.42148464918136597


Epoch 5/5:  74%|███████▍  | 5752/7783 [10:19<03:25,  9.89it/s]

Batch 5750/7783:
total_loss: 0.5086964964866638


Epoch 5/5:  77%|███████▋  | 6000/7783 [10:46<02:52, 10.31it/s]

Batch 6000/7783:
total_loss: 0.43739643692970276


Epoch 5/5:  80%|████████  | 6252/7783 [11:13<02:31, 10.09it/s]

Batch 6250/7783:
total_loss: 0.25562742352485657


Epoch 5/5:  84%|████████▎ | 6500/7783 [11:40<02:06, 10.16it/s]

Batch 6500/7783:
total_loss: 0.37235236167907715


Epoch 5/5:  87%|████████▋ | 6749/7783 [12:06<01:49,  9.41it/s]

Batch 6750/7783:
total_loss: 0.29373326897621155


Epoch 5/5:  90%|████████▉ | 7000/7783 [12:33<01:18,  9.96it/s]

Batch 7000/7783:
total_loss: 0.23771759867668152


Epoch 5/5:  93%|█████████▎| 7250/7783 [13:00<00:52, 10.16it/s]

Batch 7250/7783:
total_loss: 0.42935481667518616


Epoch 5/5:  96%|█████████▋| 7502/7783 [13:27<00:28,  9.95it/s]

Batch 7500/7783:
total_loss: 0.1438828557729721


Epoch 5/5: 100%|█████████▉| 7751/7783 [13:54<00:03,  8.75it/s]

Batch 7750/7783:
total_loss: 0.2663438618183136


Epoch 5/5: 100%|██████████| 7783/7783 [13:58<00:00,  9.28it/s]
Validating: 100%|██████████| 15/15 [00:02<00:00,  5.47it/s]


Epoch Summary:
Train Total Loss: 0.3688
Val Total Loss: 0.3259
Learning Rate: 0.000100
New best model saved with val loss: 0.3259
--------------------------------------------------



