In [13]:
#Starting code for PatRec project
#   Omada 2 -- Grokfast experiment

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!ls /content/drive/MyDrive/PatRec_Project_Shared_Folder/

 grokfast.py				 Grokking_mnist_v1.ipynb   __pycache__
 Groking_algo_v1.ipynb			 Grokking_qm9_v1.ipynb	   requirements.txt
'Grokking and how to avoid it.gslides'	 Grokking_qm9_v2.ipynb	   results


In [4]:
import sys
sys.path.append('/content/drive/MyDrive/PatRec_Project_Shared_Folder')

In [8]:
!pip install -r /content/drive/MyDrive/PatRec_Project_Shared_Folder/requirements.txt

Collecting torch_geometric (from -r /content/drive/MyDrive/PatRec_Project_Shared_Folder/requirements.txt (line 4))
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, global_add_pool

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

import tqdm
import numpy as np
import matplotlib.pyplot as plt

from argparse import ArgumentParser

In [10]:
from grokfast import gradfilter_ma, gradfilter_ema

In [12]:
 # THe Graph Neural Network (GNN) model that will be used in the qm9 dataset
class ExampleNet(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super().__init__()
        conv1_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, num_node_features * 32))
        conv2_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32 * 16))
        self.conv1 = NNConv(num_node_features, 32, conv1_net)
        self.conv2 = NNConv(32, 16, conv2_net)
        self.fc_1 = nn.Linear(16, 32)
        self.out = nn.Linear(32, 1)

    def forward(self, data):
        batch, x, edge_index, edge_attr = (
            data.batch, data.x, data.edge_index, data.edge_attr)
        # First graph conv layer
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        # Second graph conv layer
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = global_add_pool(x,batch)
        x = F.relu(self.fc_1(x))
        output = self.out(x)
        return output



In [13]:
#  L2 norm
def L2(model):
    L2_ = 0.
    for p in model.parameters():
        L2_ += torch.sum(p**2)
    return L2_

# for rescaling the parameters
def rescale(model, alpha):
    for p in model.parameters():
        p.data = alpha * p.data



In [14]:
import os

# Specify the path to save in Google Drive
results_dir = "/content/drive/MyDrive/PatRec_Project_Shared_Folder/results"
os.makedirs(results_dir, exist_ok=True)

In [15]:
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    alpha = args.init_scale

    #size = 1000
    epochs = int(100 * 50000 / args.size)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the QM9 small molecule dataset
    dset = QM9('.')
    dset = dset[:args.size]

    train_set, test_set = random_split(dset, [int(args.size / 2), int(args.size / 2)])


    trainloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    testloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True)


    # initialize a network
    qm9_node_feats, qm9_edge_feats = 11, 4
    net = ExampleNet(qm9_node_feats, qm9_edge_feats)

    # initialize an optimizer with some reasonable parameters
    optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    target_idx = 1 # index position of the polarizability label
    net.to(device)

    rescale(net, alpha)
    L2_ = L2(net)

    train_best = 1e10
    test_best = 1e10

    train_losses, test_losses, train_avg_losses, test_avg_losses = [], [], [], []
    step = 0
    grads = None

    for total_epochs in tqdm.trange(epochs):
        epoch_loss = 0
        total_graphs_train = 0

        for batch in trainloader:
            net.train()
            batch.to(device)
            optimizer.zero_grad()
            output = net(batch)
            loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
            epoch_loss += loss.item() * batch.num_graphs
            total_graphs_train += batch.num_graphs

            loss.backward()

            #######

            trigger = False

            if args.filter == "none":
                pass
            elif args.filter == "ma":
                grads = gradfilter_ma(net, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger)
            elif args.filter == "ema":
                grads = gradfilter_ema(net, grads=grads, alpha=args.alpha, lamb=args.lamb)
            else:
                raise ValueError(f"Invalid gradient filter type `{args.filter}`")

            #######

            optimizer.step()

            train_losses.append(loss.item())

            step += 1

        train_avg_loss = epoch_loss / total_graphs_train
        if train_avg_loss < train_best:
            train_best = train_avg_loss
        train_avg_losses.append(train_avg_loss)

        #######

        test_loss = 0
        total_graphs_test = 0

        net.eval()

        for batch in testloader:
            batch.to(device)
            output = net(batch)
            loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
            test_loss += loss.item() * batch.num_graphs
            total_graphs_test += batch.num_graphs
            test_losses.append(loss.item())

        test_avg_loss = test_loss / total_graphs_test
        if test_avg_loss < test_best:
            test_best = test_avg_loss
        test_avg_losses.append(test_avg_loss)

        #######
        if (total_epochs + 1) % 100 == 0 or total_epochs == epochs - 1:
            tqdm.tqdm.write(f"Epochs: {total_epochs} | epoch avg. loss: {train_avg_loss:.3f} | "
                            f"test avg. loss: {test_avg_loss:.3f}")

        if (total_epochs + 1) % 100 == 0 or total_epochs == epochs - 1:

            plt.plot(np.arange(len(train_avg_losses)), train_avg_losses, label="train")
            plt.plot(np.arange(len(train_avg_losses)), test_avg_losses, label="val")
            plt.legend()
            plt.title("QM9 Molecule Isotropic Polarizability Prediction")
            plt.xlabel("Optimization Steps")
            plt.ylabel("MSE Loss")
            plt.yscale("log", base=10)
            plt.xscale("log", base=10)
            plt.ylim(1e-4, 100)
            plt.grid()
            plt.savefig(f"{results_dir}/qm9_loss_{args.label}.png", dpi=150)
            plt.close()

            torch.save({
                'its': np.arange(len(train_losses)),
                'its_avg': np.arange(len(train_avg_losses)),
                'train_acc': None,
                'train_loss': train_losses,
                'train_avg_loss': train_avg_losses,
                'val_acc': None,
                'val_loss': test_losses,
                'val_avg_loss': test_avg_losses,
                'train_best': train_best,
                'val_best': test_best,
            }, f"{results_dir}/qm9_{args.label}.pt")

    #######

    fig, ax = plt.subplots(1, 1, figsize=(4.2, 4.2))

    ax.plot((np.arange(len(test_losses))+1)[::20], np.mean(np.array(test_losses).reshape(-1, 20), axis=1), color='#ff7f0e')
    ax.plot((np.arange(len(train_losses))+1)[::20], np.mean(np.array(train_losses).reshape(-1, 20), axis=1), color='#1f77b4')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_ylim(1e-2, 1000)

    ax.set_ylabel("MSE", fontsize=15)
    ax.text(1, 0.003, r"$\alpha=3$", fontsize=15)
    ax.set_ylim(1e-3, 1e2)
    ax.grid()

    fig.savefig(f"{results_dir}/qm9_grok_{args.label}.pdf", bbox_inches="tight")
    plt.close()



In [16]:
# run for NONE

In [17]:
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--size", type=int, default=100)
    parser.add_argument("--init_scale", type=float, default=3.0) # init_scale 1.0 no grokking / init_scale 3.0 grokking

    # Grokfast
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)
    args, unknown = parser.parse_known_args()

    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{args.lamb:.2f}'.replace('.', '')

    model_suffix = f'size{args.size}_alpha{args.init_scale:.4f}'

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}'

    args.label = args.label + model_suffix + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    main(args)

Experiment results saved under name: size100_alpha3.0000none


Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting ./raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!
  0%|          | 99/50000 [00:03<24:49, 33.49it/s]

Epochs: 99 | epoch avg. loss: 30.857 | test avg. loss: 122.075


  0%|          | 198/50000 [00:08<17:58, 46.16it/s]

Epochs: 199 | epoch avg. loss: 18.473 | test avg. loss: 107.142


  1%|          | 298/50000 [00:10<17:34, 47.13it/s]

Epochs: 299 | epoch avg. loss: 14.770 | test avg. loss: 93.534


  1%|          | 399/50000 [00:13<18:30, 44.67it/s]

Epochs: 399 | epoch avg. loss: 12.232 | test avg. loss: 89.696


  1%|          | 498/50000 [00:16<25:26, 32.42it/s]

Epochs: 499 | epoch avg. loss: 9.990 | test avg. loss: 83.855


  1%|          | 599/50000 [00:20<17:33, 46.89it/s]

Epochs: 599 | epoch avg. loss: 8.390 | test avg. loss: 80.845


  1%|▏         | 699/50000 [00:23<18:38, 44.06it/s]

Epochs: 699 | epoch avg. loss: 7.133 | test avg. loss: 78.265


  2%|▏         | 799/50000 [00:25<17:12, 47.65it/s]

Epochs: 799 | epoch avg. loss: 5.604 | test avg. loss: 72.028


  2%|▏         | 899/50000 [00:28<17:46, 46.02it/s]

Epochs: 899 | epoch avg. loss: 4.809 | test avg. loss: 69.828


  2%|▏         | 996/50000 [00:32<25:53, 31.53it/s]

Epochs: 999 | epoch avg. loss: 4.489 | test avg. loss: 70.723


  2%|▏         | 1096/50000 [00:35<18:06, 44.99it/s]

Epochs: 1099 | epoch avg. loss: 3.420 | test avg. loss: 65.435


  2%|▏         | 1196/50000 [00:38<17:10, 47.36it/s]

Epochs: 1199 | epoch avg. loss: 4.543 | test avg. loss: 64.603


  3%|▎         | 1296/50000 [00:41<17:38, 46.00it/s]

Epochs: 1299 | epoch avg. loss: 2.765 | test avg. loss: 63.068


  3%|▎         | 1399/50000 [00:44<24:50, 32.61it/s]

Epochs: 1399 | epoch avg. loss: 1.987 | test avg. loss: 63.855


  3%|▎         | 1497/50000 [00:48<17:32, 46.09it/s]

Epochs: 1499 | epoch avg. loss: 1.551 | test avg. loss: 62.690


  3%|▎         | 1597/50000 [00:50<17:26, 46.25it/s]

Epochs: 1599 | epoch avg. loss: 2.201 | test avg. loss: 63.289


  3%|▎         | 1697/50000 [00:53<17:18, 46.51it/s]

Epochs: 1699 | epoch avg. loss: 1.980 | test avg. loss: 62.508


  4%|▎         | 1797/50000 [00:56<18:06, 44.38it/s]

Epochs: 1799 | epoch avg. loss: 3.835 | test avg. loss: 63.926


  4%|▍         | 1898/50000 [01:00<25:47, 31.08it/s]

Epochs: 1899 | epoch avg. loss: 2.085 | test avg. loss: 63.954


  4%|▍         | 1999/50000 [01:03<16:35, 48.23it/s]

Epochs: 1999 | epoch avg. loss: 0.896 | test avg. loss: 64.037


  4%|▍         | 2099/50000 [01:06<17:19, 46.06it/s]

Epochs: 2099 | epoch avg. loss: 1.373 | test avg. loss: 66.061


  4%|▍         | 2195/50000 [01:09<17:25, 45.75it/s]

Epochs: 2199 | epoch avg. loss: 1.428 | test avg. loss: 65.718


  5%|▍         | 2297/50000 [01:12<24:38, 32.27it/s]

Epochs: 2299 | epoch avg. loss: 1.886 | test avg. loss: 67.648


  5%|▍         | 2395/50000 [01:15<16:41, 47.54it/s]

Epochs: 2399 | epoch avg. loss: 1.827 | test avg. loss: 66.366


  5%|▍         | 2498/50000 [01:18<16:49, 47.06it/s]

Epochs: 2499 | epoch avg. loss: 6.655 | test avg. loss: 68.067


  5%|▌         | 2596/50000 [01:21<16:10, 48.86it/s]

Epochs: 2599 | epoch avg. loss: 1.748 | test avg. loss: 69.693


  5%|▌         | 2699/50000 [01:23<15:58, 49.34it/s]

Epochs: 2699 | epoch avg. loss: 0.805 | test avg. loss: 70.012


  6%|▌         | 2798/50000 [01:27<23:31, 33.45it/s]

Epochs: 2799 | epoch avg. loss: 0.772 | test avg. loss: 71.264


  6%|▌         | 2895/50000 [01:30<17:10, 45.71it/s]

Epochs: 2899 | epoch avg. loss: 0.767 | test avg. loss: 71.947


  6%|▌         | 2995/50000 [01:33<16:08, 48.53it/s]

Epochs: 2999 | epoch avg. loss: 1.391 | test avg. loss: 83.534


  6%|▌         | 3095/50000 [01:36<16:21, 47.78it/s]

Epochs: 3099 | epoch avg. loss: 2.733 | test avg. loss: 85.060


  6%|▋         | 3197/50000 [01:39<21:38, 36.03it/s]

Epochs: 3199 | epoch avg. loss: 0.718 | test avg. loss: 83.446


  7%|▋         | 3299/50000 [01:43<17:28, 44.54it/s]

Epochs: 3299 | epoch avg. loss: 0.266 | test avg. loss: 82.863


  7%|▋         | 3396/50000 [01:45<17:12, 45.14it/s]

Epochs: 3399 | epoch avg. loss: 0.367 | test avg. loss: 83.451


  7%|▋         | 3494/50000 [01:48<16:24, 47.23it/s]

Epochs: 3499 | epoch avg. loss: 0.665 | test avg. loss: 84.992


  7%|▋         | 3598/50000 [01:51<16:34, 46.66it/s]

Epochs: 3599 | epoch avg. loss: 16.837 | test avg. loss: 92.737


  7%|▋         | 3696/50000 [01:54<21:52, 35.28it/s]

Epochs: 3699 | epoch avg. loss: 0.618 | test avg. loss: 91.121


  8%|▊         | 3795/50000 [01:57<16:12, 47.51it/s]

Epochs: 3799 | epoch avg. loss: 0.144 | test avg. loss: 93.200


  8%|▊         | 3897/50000 [02:00<16:38, 46.19it/s]

Epochs: 3899 | epoch avg. loss: 0.547 | test avg. loss: 93.072


  8%|▊         | 3999/50000 [02:03<16:05, 47.63it/s]

Epochs: 3999 | epoch avg. loss: 0.338 | test avg. loss: 96.051


  8%|▊         | 4096/50000 [02:06<15:53, 48.15it/s]

Epochs: 4099 | epoch avg. loss: 0.100 | test avg. loss: 97.844


  8%|▊         | 4197/50000 [02:09<24:49, 30.75it/s]

Epochs: 4199 | epoch avg. loss: 0.523 | test avg. loss: 99.047


  9%|▊         | 4296/50000 [02:13<16:31, 46.10it/s]

Epochs: 4299 | epoch avg. loss: 0.114 | test avg. loss: 106.679


  9%|▉         | 4396/50000 [02:15<15:46, 48.17it/s]

Epochs: 4399 | epoch avg. loss: 0.095 | test avg. loss: 107.445


  9%|▉         | 4499/50000 [02:18<15:36, 48.57it/s]

Epochs: 4499 | epoch avg. loss: 1.868 | test avg. loss: 111.772


  9%|▉         | 4598/50000 [02:22<22:11, 34.09it/s]

Epochs: 4599 | epoch avg. loss: 0.137 | test avg. loss: 112.640


  9%|▉         | 4699/50000 [02:25<15:25, 48.93it/s]

Epochs: 4699 | epoch avg. loss: 30.878 | test avg. loss: 136.055


 10%|▉         | 4798/50000 [02:27<15:58, 47.17it/s]

Epochs: 4799 | epoch avg. loss: 0.218 | test avg. loss: 135.376


 10%|▉         | 4896/50000 [02:30<15:47, 47.58it/s]

Epochs: 4899 | epoch avg. loss: 0.498 | test avg. loss: 139.672


 10%|▉         | 4995/50000 [02:33<15:52, 47.25it/s]

Epochs: 4999 | epoch avg. loss: 0.088 | test avg. loss: 140.886


 10%|█         | 5096/50000 [02:37<22:45, 32.88it/s]

Epochs: 5099 | epoch avg. loss: 0.136 | test avg. loss: 145.576


 10%|█         | 5199/50000 [02:40<16:13, 46.02it/s]

Epochs: 5199 | epoch avg. loss: 0.266 | test avg. loss: 146.271


 11%|█         | 5297/50000 [02:42<15:26, 48.27it/s]

Epochs: 5299 | epoch avg. loss: 0.122 | test avg. loss: 153.732


 11%|█         | 5397/50000 [02:45<15:58, 46.52it/s]

Epochs: 5399 | epoch avg. loss: 0.525 | test avg. loss: 171.327


 11%|█         | 5497/50000 [02:48<21:28, 34.55it/s]

Epochs: 5499 | epoch avg. loss: 1.024 | test avg. loss: 178.341


 11%|█         | 5597/50000 [02:52<16:30, 44.83it/s]

Epochs: 5599 | epoch avg. loss: 0.142 | test avg. loss: 179.691


 11%|█▏        | 5698/50000 [02:55<15:29, 47.68it/s]

Epochs: 5699 | epoch avg. loss: 0.261 | test avg. loss: 180.277


 12%|█▏        | 5799/50000 [02:58<16:10, 45.56it/s]

Epochs: 5799 | epoch avg. loss: 0.609 | test avg. loss: 175.979


 12%|█▏        | 5899/50000 [03:01<15:30, 47.37it/s]

Epochs: 5899 | epoch avg. loss: 1.179 | test avg. loss: 179.797


 12%|█▏        | 5996/50000 [03:05<24:03, 30.49it/s]

Epochs: 5999 | epoch avg. loss: 0.205 | test avg. loss: 184.709


 12%|█▏        | 6097/50000 [03:08<15:15, 47.97it/s]

Epochs: 6099 | epoch avg. loss: 1.362 | test avg. loss: 191.662


 12%|█▏        | 6197/50000 [03:10<15:03, 48.50it/s]

Epochs: 6199 | epoch avg. loss: 0.087 | test avg. loss: 197.352


 13%|█▎        | 6295/50000 [03:13<15:18, 47.60it/s]

Epochs: 6299 | epoch avg. loss: 0.110 | test avg. loss: 194.804


 13%|█▎        | 6397/50000 [03:16<21:15, 34.19it/s]

Epochs: 6399 | epoch avg. loss: 0.097 | test avg. loss: 34.515


 13%|█▎        | 6499/50000 [03:20<15:49, 45.81it/s]

Epochs: 6499 | epoch avg. loss: 0.056 | test avg. loss: 37.217


 13%|█▎        | 6598/50000 [03:23<14:49, 48.79it/s]

Epochs: 6599 | epoch avg. loss: 0.131 | test avg. loss: 36.382


 13%|█▎        | 6695/50000 [03:25<15:05, 47.81it/s]

Epochs: 6699 | epoch avg. loss: 0.223 | test avg. loss: 31.989


 14%|█▎        | 6796/50000 [03:28<15:10, 47.45it/s]

Epochs: 6799 | epoch avg. loss: 0.170 | test avg. loss: 32.703


 14%|█▍        | 6897/50000 [03:32<22:10, 32.38it/s]

Epochs: 6899 | epoch avg. loss: 0.542 | test avg. loss: 32.629


 14%|█▍        | 6995/50000 [03:35<15:02, 47.64it/s]

Epochs: 6999 | epoch avg. loss: 0.101 | test avg. loss: 33.675


 14%|█▍        | 7099/50000 [03:38<14:47, 48.34it/s]

Epochs: 7099 | epoch avg. loss: 0.407 | test avg. loss: 32.296


 14%|█▍        | 7197/50000 [03:40<14:28, 49.27it/s]

Epochs: 7199 | epoch avg. loss: 0.199 | test avg. loss: 33.821


 15%|█▍        | 7296/50000 [03:43<19:13, 37.02it/s]

Epochs: 7299 | epoch avg. loss: 0.410 | test avg. loss: 33.008


 15%|█▍        | 7397/50000 [03:48<14:48, 47.96it/s]

Epochs: 7399 | epoch avg. loss: 0.250 | test avg. loss: 35.737


 15%|█▍        | 7497/50000 [03:51<15:34, 45.49it/s]

Epochs: 7499 | epoch avg. loss: 7.231 | test avg. loss: 18.696


 15%|█▌        | 7595/50000 [03:53<14:48, 47.73it/s]

Epochs: 7599 | epoch avg. loss: 0.318 | test avg. loss: 16.240


 15%|█▌        | 7698/50000 [03:56<15:20, 45.94it/s]

Epochs: 7699 | epoch avg. loss: 0.432 | test avg. loss: 15.473


 16%|█▌        | 7799/50000 [04:00<22:21, 31.46it/s]

Epochs: 7799 | epoch avg. loss: 0.204 | test avg. loss: 15.316


 16%|█▌        | 7895/50000 [04:03<14:28, 48.51it/s]

Epochs: 7899 | epoch avg. loss: 0.199 | test avg. loss: 15.303


 16%|█▌        | 7999/50000 [04:06<14:24, 48.56it/s]

Epochs: 7999 | epoch avg. loss: 0.247 | test avg. loss: 16.015


 16%|█▌        | 8095/50000 [04:08<14:21, 48.65it/s]

Epochs: 8099 | epoch avg. loss: 0.117 | test avg. loss: 15.539


 16%|█▋        | 8198/50000 [04:12<21:02, 33.12it/s]

Epochs: 8199 | epoch avg. loss: 0.477 | test avg. loss: 15.822


 17%|█▋        | 8298/50000 [04:15<14:39, 47.40it/s]

Epochs: 8299 | epoch avg. loss: 0.223 | test avg. loss: 16.085


 17%|█▋        | 8399/50000 [04:18<15:12, 45.61it/s]

Epochs: 8399 | epoch avg. loss: 0.104 | test avg. loss: 15.951


 17%|█▋        | 8495/50000 [04:21<14:27, 47.87it/s]

Epochs: 8499 | epoch avg. loss: 0.147 | test avg. loss: 15.949


 17%|█▋        | 8596/50000 [04:23<14:20, 48.12it/s]

Epochs: 8599 | epoch avg. loss: 0.121 | test avg. loss: 15.896


 17%|█▋        | 8697/50000 [04:27<21:50, 31.51it/s]

Epochs: 8699 | epoch avg. loss: 0.335 | test avg. loss: 15.714


 18%|█▊        | 8798/50000 [04:30<14:31, 47.28it/s]

Epochs: 8799 | epoch avg. loss: 0.522 | test avg. loss: 15.764


 18%|█▊        | 8899/50000 [04:33<14:10, 48.31it/s]

Epochs: 8899 | epoch avg. loss: 0.228 | test avg. loss: 15.264


 18%|█▊        | 8996/50000 [04:36<14:29, 47.17it/s]

Epochs: 8999 | epoch avg. loss: 0.292 | test avg. loss: 15.721


 18%|█▊        | 9099/50000 [04:39<19:12, 35.49it/s]

Epochs: 9099 | epoch avg. loss: 0.447 | test avg. loss: 16.567


 18%|█▊        | 9199/50000 [04:43<14:46, 46.00it/s]

Epochs: 9199 | epoch avg. loss: 0.094 | test avg. loss: 15.239


 19%|█▊        | 9297/50000 [04:45<14:41, 46.15it/s]

Epochs: 9299 | epoch avg. loss: 0.097 | test avg. loss: 15.484


 19%|█▉        | 9399/50000 [04:49<14:44, 45.89it/s]

Epochs: 9399 | epoch avg. loss: 0.379 | test avg. loss: 15.513


 19%|█▉        | 9495/50000 [04:52<14:19, 47.12it/s]

Epochs: 9499 | epoch avg. loss: 0.359 | test avg. loss: 15.979


 19%|█▉        | 9596/50000 [04:56<21:06, 31.89it/s]

Epochs: 9599 | epoch avg. loss: 0.096 | test avg. loss: 15.759


 19%|█▉        | 9697/50000 [04:59<14:12, 47.30it/s]

Epochs: 9699 | epoch avg. loss: 0.428 | test avg. loss: 15.941


 20%|█▉        | 9797/50000 [05:01<14:14, 47.05it/s]

Epochs: 9799 | epoch avg. loss: 0.061 | test avg. loss: 15.238


 20%|█▉        | 9898/50000 [05:04<13:56, 47.93it/s]

Epochs: 9899 | epoch avg. loss: 0.076 | test avg. loss: 15.332


 20%|█▉        | 9998/50000 [05:07<19:32, 34.10it/s]

Epochs: 9999 | epoch avg. loss: 0.322 | test avg. loss: 15.756


 20%|██        | 10098/50000 [05:11<14:23, 46.20it/s]

Epochs: 10099 | epoch avg. loss: 0.195 | test avg. loss: 15.345


 20%|██        | 10199/50000 [05:14<13:40, 48.53it/s]

Epochs: 10199 | epoch avg. loss: 0.078 | test avg. loss: 15.434


 21%|██        | 10296/50000 [05:17<13:54, 47.58it/s]

Epochs: 10299 | epoch avg. loss: 0.076 | test avg. loss: 15.205


 21%|██        | 10394/50000 [05:19<13:48, 47.80it/s]

Epochs: 10399 | epoch avg. loss: 0.510 | test avg. loss: 16.447


 21%|██        | 10496/50000 [05:23<21:43, 30.30it/s]

Epochs: 10499 | epoch avg. loss: 0.805 | test avg. loss: 15.631


 21%|██        | 10598/50000 [05:26<13:44, 47.78it/s]

Epochs: 10599 | epoch avg. loss: 0.365 | test avg. loss: 15.450


 21%|██▏       | 10696/50000 [05:29<13:45, 47.62it/s]

Epochs: 10699 | epoch avg. loss: 1.010 | test avg. loss: 15.739


 22%|██▏       | 10799/50000 [05:32<13:37, 47.94it/s]

Epochs: 10799 | epoch avg. loss: 0.054 | test avg. loss: 15.281


 22%|██▏       | 10896/50000 [05:35<19:27, 33.48it/s]

Epochs: 10899 | epoch avg. loss: 0.182 | test avg. loss: 15.393


 22%|██▏       | 10997/50000 [05:38<13:58, 46.51it/s]

Epochs: 10999 | epoch avg. loss: 0.114 | test avg. loss: 15.571


 22%|██▏       | 11098/50000 [05:41<13:31, 47.97it/s]

Epochs: 11099 | epoch avg. loss: 0.521 | test avg. loss: 15.494


 22%|██▏       | 11198/50000 [05:44<13:46, 46.92it/s]

Epochs: 11199 | epoch avg. loss: 0.083 | test avg. loss: 15.376


 23%|██▎       | 11299/50000 [05:47<14:41, 43.90it/s]

Epochs: 11299 | epoch avg. loss: 0.106 | test avg. loss: 15.377


 23%|██▎       | 11396/50000 [05:51<20:01, 32.14it/s]

Epochs: 11399 | epoch avg. loss: 0.126 | test avg. loss: 15.351


 23%|██▎       | 11497/50000 [05:54<13:19, 48.19it/s]

Epochs: 11499 | epoch avg. loss: 0.069 | test avg. loss: 15.491


 23%|██▎       | 11597/50000 [05:56<13:32, 47.28it/s]

Epochs: 11599 | epoch avg. loss: 0.392 | test avg. loss: 15.174


 23%|██▎       | 11697/50000 [06:00<13:54, 45.93it/s]

Epochs: 11699 | epoch avg. loss: 0.389 | test avg. loss: 15.192


 24%|██▎       | 11798/50000 [06:04<20:06, 31.66it/s]

Epochs: 11799 | epoch avg. loss: 0.044 | test avg. loss: 15.250


 24%|██▍       | 11897/50000 [06:07<13:37, 46.61it/s]

Epochs: 11899 | epoch avg. loss: 0.084 | test avg. loss: 15.390


 24%|██▍       | 11997/50000 [06:10<14:34, 43.47it/s]

Epochs: 11999 | epoch avg. loss: 0.070 | test avg. loss: 7.103


 24%|██▍       | 12097/50000 [06:13<13:12, 47.85it/s]

Epochs: 12099 | epoch avg. loss: 0.087 | test avg. loss: 7.258


 24%|██▍       | 12199/50000 [06:16<19:06, 32.98it/s]

Epochs: 12199 | epoch avg. loss: 0.093 | test avg. loss: 7.367


 25%|██▍       | 12298/50000 [06:20<13:00, 48.31it/s]

Epochs: 12299 | epoch avg. loss: 0.061 | test avg. loss: 7.243


 25%|██▍       | 12398/50000 [06:22<13:06, 47.78it/s]

Epochs: 12399 | epoch avg. loss: 0.146 | test avg. loss: 7.216


 25%|██▍       | 12496/50000 [06:25<13:23, 46.69it/s]

Epochs: 12499 | epoch avg. loss: 0.054 | test avg. loss: 7.203


 25%|██▌       | 12598/50000 [06:28<13:00, 47.94it/s]

Epochs: 12599 | epoch avg. loss: 0.556 | test avg. loss: 7.276


 25%|██▌       | 12699/50000 [06:32<19:16, 32.26it/s]

Epochs: 12699 | epoch avg. loss: 0.061 | test avg. loss: 7.342


 26%|██▌       | 12799/50000 [06:35<13:15, 46.76it/s]

Epochs: 12799 | epoch avg. loss: 0.103 | test avg. loss: 7.555


 26%|██▌       | 12896/50000 [06:38<12:57, 47.70it/s]

Epochs: 12899 | epoch avg. loss: 0.050 | test avg. loss: 7.455


 26%|██▌       | 12996/50000 [06:40<12:54, 47.79it/s]

Epochs: 12999 | epoch avg. loss: 0.070 | test avg. loss: 7.521


 26%|██▌       | 13099/50000 [06:44<18:12, 33.76it/s]

Epochs: 13099 | epoch avg. loss: 0.037 | test avg. loss: 7.771


 26%|██▋       | 13196/50000 [06:47<13:02, 47.02it/s]

Epochs: 13199 | epoch avg. loss: 0.111 | test avg. loss: 7.745


 27%|██▋       | 13298/50000 [06:50<13:09, 46.51it/s]

Epochs: 13299 | epoch avg. loss: 0.088 | test avg. loss: 8.025


 27%|██▋       | 13395/50000 [06:53<13:04, 46.68it/s]

Epochs: 13399 | epoch avg. loss: 0.085 | test avg. loss: 8.045


 27%|██▋       | 13499/50000 [06:56<12:15, 49.61it/s]

Epochs: 13499 | epoch avg. loss: 0.075 | test avg. loss: 8.082


 27%|██▋       | 13596/50000 [07:00<19:52, 30.52it/s]

Epochs: 13599 | epoch avg. loss: 0.027 | test avg. loss: 8.266


 27%|██▋       | 13698/50000 [07:02<13:13, 45.75it/s]

Epochs: 13699 | epoch avg. loss: 0.060 | test avg. loss: 8.600


 28%|██▊       | 13798/50000 [07:05<12:41, 47.56it/s]

Epochs: 13799 | epoch avg. loss: 0.261 | test avg. loss: 9.485


 28%|██▊       | 13898/50000 [07:08<13:42, 43.89it/s]

Epochs: 13899 | epoch avg. loss: 1.308 | test avg. loss: 9.445


 28%|██▊       | 13996/50000 [07:12<18:33, 32.34it/s]

Epochs: 13999 | epoch avg. loss: 0.167 | test avg. loss: 9.265


 28%|██▊       | 14097/50000 [07:15<12:48, 46.69it/s]

Epochs: 14099 | epoch avg. loss: 0.249 | test avg. loss: 9.129


 28%|██▊       | 14197/50000 [07:18<13:10, 45.31it/s]

Epochs: 14199 | epoch avg. loss: 0.088 | test avg. loss: 9.522


 29%|██▊       | 14297/50000 [07:21<12:24, 47.93it/s]

Epochs: 14299 | epoch avg. loss: 0.050 | test avg. loss: 9.097


 29%|██▉       | 14397/50000 [07:24<13:23, 44.33it/s]

Epochs: 14399 | epoch avg. loss: 0.031 | test avg. loss: 9.073


 29%|██▉       | 14495/50000 [07:28<16:42, 35.42it/s]

Epochs: 14499 | epoch avg. loss: 0.102 | test avg. loss: 9.001


 29%|██▉       | 14595/50000 [07:32<13:35, 43.44it/s]

Epochs: 14599 | epoch avg. loss: 0.238 | test avg. loss: 9.705


 29%|██▉       | 14695/50000 [07:35<12:39, 46.47it/s]

Epochs: 14699 | epoch avg. loss: 0.170 | test avg. loss: 9.350


 30%|██▉       | 14797/50000 [07:38<12:17, 47.74it/s]

Epochs: 14799 | epoch avg. loss: 0.163 | test avg. loss: 9.110


 30%|██▉       | 14895/50000 [07:42<17:18, 33.79it/s]

Epochs: 14899 | epoch avg. loss: 0.064 | test avg. loss: 8.851


 30%|██▉       | 14995/50000 [07:44<12:15, 47.59it/s]

Epochs: 14999 | epoch avg. loss: 0.072 | test avg. loss: 8.938


 30%|███       | 15097/50000 [07:47<12:15, 47.45it/s]

Epochs: 15099 | epoch avg. loss: 0.112 | test avg. loss: 8.770


 30%|███       | 15199/50000 [07:50<12:16, 47.28it/s]

Epochs: 15199 | epoch avg. loss: 0.039 | test avg. loss: 9.096


 31%|███       | 15297/50000 [07:54<16:40, 34.69it/s]

Epochs: 15299 | epoch avg. loss: 0.031 | test avg. loss: 8.741


 31%|███       | 15396/50000 [07:57<12:16, 46.98it/s]

Epochs: 15399 | epoch avg. loss: 0.048 | test avg. loss: 8.808


 31%|███       | 15496/50000 [08:00<12:25, 46.31it/s]

Epochs: 15499 | epoch avg. loss: 0.044 | test avg. loss: 8.198


 31%|███       | 15596/50000 [08:03<12:21, 46.41it/s]

Epochs: 15599 | epoch avg. loss: 0.030 | test avg. loss: 8.339


 31%|███▏      | 15699/50000 [08:06<15:36, 36.62it/s]

Epochs: 15699 | epoch avg. loss: 0.072 | test avg. loss: 8.379


 32%|███▏      | 15796/50000 [08:10<12:44, 44.76it/s]

Epochs: 15799 | epoch avg. loss: 0.045 | test avg. loss: 8.496


 32%|███▏      | 15895/50000 [08:13<11:56, 47.61it/s]

Epochs: 15899 | epoch avg. loss: 0.086 | test avg. loss: 8.529


 32%|███▏      | 15996/50000 [08:15<11:47, 48.04it/s]

Epochs: 15999 | epoch avg. loss: 0.043 | test avg. loss: 8.419


 32%|███▏      | 16097/50000 [08:18<12:00, 47.05it/s]

Epochs: 16099 | epoch avg. loss: 0.032 | test avg. loss: 8.667


 32%|███▏      | 16199/50000 [08:22<17:25, 32.32it/s]

Epochs: 16199 | epoch avg. loss: 0.068 | test avg. loss: 8.477


 33%|███▎      | 16299/50000 [08:25<11:43, 47.88it/s]

Epochs: 16299 | epoch avg. loss: 0.090 | test avg. loss: 8.591


 33%|███▎      | 16396/50000 [08:28<11:54, 47.04it/s]

Epochs: 16399 | epoch avg. loss: 0.050 | test avg. loss: 8.463


 33%|███▎      | 16496/50000 [08:31<12:04, 46.24it/s]

Epochs: 16499 | epoch avg. loss: 0.075 | test avg. loss: 8.281


 33%|███▎      | 16597/50000 [08:34<15:42, 35.44it/s]

Epochs: 16599 | epoch avg. loss: 0.493 | test avg. loss: 9.617


 33%|███▎      | 16698/50000 [08:38<12:21, 44.90it/s]

Epochs: 16699 | epoch avg. loss: 0.033 | test avg. loss: 8.957


 34%|███▎      | 16799/50000 [08:41<11:52, 46.62it/s]

Epochs: 16799 | epoch avg. loss: 0.141 | test avg. loss: 8.496


 34%|███▍      | 16897/50000 [08:44<11:33, 47.72it/s]

Epochs: 16899 | epoch avg. loss: 0.024 | test avg. loss: 8.695


 34%|███▍      | 16999/50000 [08:46<11:36, 47.39it/s]

Epochs: 16999 | epoch avg. loss: 0.260 | test avg. loss: 8.782


 34%|███▍      | 17096/50000 [08:50<17:27, 31.40it/s]

Epochs: 17099 | epoch avg. loss: 0.139 | test avg. loss: 8.332


 34%|███▍      | 17199/50000 [08:53<11:29, 47.57it/s]

Epochs: 17199 | epoch avg. loss: 0.122 | test avg. loss: 8.874


 35%|███▍      | 17299/50000 [08:56<11:37, 46.86it/s]

Epochs: 17299 | epoch avg. loss: 0.062 | test avg. loss: 8.953


 35%|███▍      | 17399/50000 [08:59<11:43, 46.33it/s]

Epochs: 17399 | epoch avg. loss: 0.017 | test avg. loss: 8.543


 35%|███▍      | 17496/50000 [09:03<16:04, 33.71it/s]

Epochs: 17499 | epoch avg. loss: 0.078 | test avg. loss: 8.483


 35%|███▌      | 17599/50000 [09:06<11:32, 46.76it/s]

Epochs: 17599 | epoch avg. loss: 0.076 | test avg. loss: 9.100


 35%|███▌      | 17695/50000 [09:09<11:57, 45.01it/s]

Epochs: 17699 | epoch avg. loss: 0.018 | test avg. loss: 8.547


 36%|███▌      | 17796/50000 [09:12<12:41, 42.27it/s]

Epochs: 17799 | epoch avg. loss: 0.040 | test avg. loss: 8.589


 36%|███▌      | 17896/50000 [09:15<13:38, 39.21it/s]

Epochs: 17899 | epoch avg. loss: 0.055 | test avg. loss: 8.460


 36%|███▌      | 17997/50000 [09:19<11:47, 45.26it/s]

Epochs: 17999 | epoch avg. loss: 0.023 | test avg. loss: 8.412


 36%|███▌      | 18097/50000 [09:23<11:23, 46.68it/s]

Epochs: 18099 | epoch avg. loss: 0.033 | test avg. loss: 8.320


 36%|███▋      | 18197/50000 [09:26<11:33, 45.84it/s]

Epochs: 18199 | epoch avg. loss: 0.257 | test avg. loss: 10.405


 37%|███▋      | 18298/50000 [09:29<15:13, 34.71it/s]

Epochs: 18299 | epoch avg. loss: 0.023 | test avg. loss: 8.538


 37%|███▋      | 18399/50000 [09:33<11:44, 44.87it/s]

Epochs: 18399 | epoch avg. loss: 0.030 | test avg. loss: 8.505


 37%|███▋      | 18495/50000 [09:36<11:15, 46.61it/s]

Epochs: 18499 | epoch avg. loss: 0.090 | test avg. loss: 8.471


 37%|███▋      | 18599/50000 [09:39<10:52, 48.12it/s]

Epochs: 18599 | epoch avg. loss: 0.136 | test avg. loss: 8.808


 37%|███▋      | 18695/50000 [09:42<11:48, 44.16it/s]

Epochs: 18699 | epoch avg. loss: 0.025 | test avg. loss: 8.500


 38%|███▊      | 18796/50000 [09:45<16:52, 30.82it/s]

Epochs: 18799 | epoch avg. loss: 0.065 | test avg. loss: 8.767


 38%|███▊      | 18896/50000 [09:48<11:06, 46.64it/s]

Epochs: 18899 | epoch avg. loss: 0.018 | test avg. loss: 8.552


 38%|███▊      | 18996/50000 [09:51<11:09, 46.28it/s]

Epochs: 18999 | epoch avg. loss: 0.019 | test avg. loss: 8.775


 38%|███▊      | 19096/50000 [09:54<11:01, 46.74it/s]

Epochs: 19099 | epoch avg. loss: 0.028 | test avg. loss: 8.516


 38%|███▊      | 19196/50000 [09:58<14:49, 34.63it/s]

Epochs: 19199 | epoch avg. loss: 0.194 | test avg. loss: 9.023


 39%|███▊      | 19299/50000 [10:01<10:54, 46.91it/s]

Epochs: 19299 | epoch avg. loss: 0.045 | test avg. loss: 8.447


 39%|███▉      | 19395/50000 [10:04<11:05, 45.98it/s]

Epochs: 19399 | epoch avg. loss: 0.032 | test avg. loss: 8.629


 39%|███▉      | 19497/50000 [10:07<10:39, 47.72it/s]

Epochs: 19499 | epoch avg. loss: 0.016 | test avg. loss: 8.666


 39%|███▉      | 19597/50000 [10:10<13:43, 36.93it/s]

Epochs: 19599 | epoch avg. loss: 0.248 | test avg. loss: 8.723


 39%|███▉      | 19697/50000 [10:14<11:38, 43.37it/s]

Epochs: 19699 | epoch avg. loss: 0.043 | test avg. loss: 8.719


 40%|███▉      | 19795/50000 [10:17<10:23, 48.45it/s]

Epochs: 19799 | epoch avg. loss: 0.080 | test avg. loss: 8.861


 40%|███▉      | 19897/50000 [10:20<10:20, 48.49it/s]

Epochs: 19899 | epoch avg. loss: 0.099 | test avg. loss: 8.933


 40%|███▉      | 19999/50000 [10:22<10:40, 46.81it/s]

Epochs: 19999 | epoch avg. loss: 0.013 | test avg. loss: 8.635


 40%|████      | 20097/50000 [10:26<14:53, 33.47it/s]

Epochs: 20099 | epoch avg. loss: 0.061 | test avg. loss: 8.500


 40%|████      | 20198/50000 [10:29<10:27, 47.46it/s]

Epochs: 20199 | epoch avg. loss: 0.051 | test avg. loss: 8.697


 41%|████      | 20298/50000 [10:32<10:32, 46.94it/s]

Epochs: 20299 | epoch avg. loss: 0.033 | test avg. loss: 8.703


 41%|████      | 20395/50000 [10:35<10:35, 46.56it/s]

Epochs: 20399 | epoch avg. loss: 0.014 | test avg. loss: 8.446


 41%|████      | 20496/50000 [10:38<13:56, 35.28it/s]

Epochs: 20499 | epoch avg. loss: 0.065 | test avg. loss: 8.446


 41%|████      | 20597/50000 [10:42<10:42, 45.80it/s]

Epochs: 20599 | epoch avg. loss: 0.079 | test avg. loss: 8.783


 41%|████▏     | 20697/50000 [10:45<09:58, 48.94it/s]

Epochs: 20699 | epoch avg. loss: 0.067 | test avg. loss: 8.427


 42%|████▏     | 20798/50000 [10:48<10:03, 48.42it/s]

Epochs: 20799 | epoch avg. loss: 0.029 | test avg. loss: 8.542


 42%|████▏     | 20895/50000 [10:50<10:10, 47.65it/s]

Epochs: 20899 | epoch avg. loss: 0.034 | test avg. loss: 8.450


 42%|████▏     | 20996/50000 [10:54<15:34, 31.02it/s]

Epochs: 20999 | epoch avg. loss: 0.010 | test avg. loss: 8.692


 42%|████▏     | 21097/50000 [10:57<09:49, 49.03it/s]

Epochs: 21099 | epoch avg. loss: 0.458 | test avg. loss: 8.328


 42%|████▏     | 21199/50000 [11:00<10:10, 47.21it/s]

Epochs: 21199 | epoch avg. loss: 0.017 | test avg. loss: 8.297


 43%|████▎     | 21295/50000 [11:03<10:20, 46.29it/s]

Epochs: 21299 | epoch avg. loss: 0.029 | test avg. loss: 8.295


 43%|████▎     | 21398/50000 [11:06<13:52, 34.38it/s]

Epochs: 21399 | epoch avg. loss: 0.132 | test avg. loss: 8.229


 43%|████▎     | 21495/50000 [11:10<10:18, 46.06it/s]

Epochs: 21499 | epoch avg. loss: 0.031 | test avg. loss: 8.258


 43%|████▎     | 21595/50000 [11:13<10:12, 46.35it/s]

Epochs: 21599 | epoch avg. loss: 0.023 | test avg. loss: 8.294


 43%|████▎     | 21696/50000 [11:16<09:48, 48.08it/s]

Epochs: 21699 | epoch avg. loss: 0.037 | test avg. loss: 8.393


 44%|████▎     | 21796/50000 [11:19<13:01, 36.10it/s]

Epochs: 21799 | epoch avg. loss: 0.019 | test avg. loss: 8.091


 44%|████▍     | 21898/50000 [11:23<10:15, 45.64it/s]

Epochs: 21899 | epoch avg. loss: 0.017 | test avg. loss: 8.213


 44%|████▍     | 21998/50000 [11:26<09:48, 47.60it/s]

Epochs: 21999 | epoch avg. loss: 0.033 | test avg. loss: 8.318


 44%|████▍     | 22095/50000 [11:29<09:41, 47.98it/s]

Epochs: 22099 | epoch avg. loss: 0.010 | test avg. loss: 8.318


 44%|████▍     | 22197/50000 [11:31<09:35, 48.29it/s]

Epochs: 22199 | epoch avg. loss: 0.087 | test avg. loss: 8.148


 45%|████▍     | 22298/50000 [11:37<10:28, 44.10it/s]

Epochs: 22299 | epoch avg. loss: 0.008 | test avg. loss: 8.437


 45%|████▍     | 22398/50000 [11:40<09:54, 46.41it/s]

Epochs: 22399 | epoch avg. loss: 0.038 | test avg. loss: 8.541


 45%|████▍     | 22499/50000 [11:43<09:44, 47.07it/s]

Epochs: 22499 | epoch avg. loss: 0.038 | test avg. loss: 7.929


 45%|████▌     | 22599/50000 [11:46<11:14, 40.60it/s]

Epochs: 22599 | epoch avg. loss: 0.013 | test avg. loss: 8.074


 45%|████▌     | 22696/50000 [11:50<09:46, 46.56it/s]

Epochs: 22699 | epoch avg. loss: 0.030 | test avg. loss: 7.880


 46%|████▌     | 22796/50000 [11:53<09:55, 45.68it/s]

Epochs: 22799 | epoch avg. loss: 0.014 | test avg. loss: 7.981


 46%|████▌     | 22898/50000 [11:55<09:33, 47.24it/s]

Epochs: 22899 | epoch avg. loss: 0.044 | test avg. loss: 8.004


 46%|████▌     | 22998/50000 [11:58<09:34, 46.96it/s]

Epochs: 22999 | epoch avg. loss: 0.025 | test avg. loss: 8.021


 46%|████▌     | 23099/50000 [12:02<13:28, 33.29it/s]

Epochs: 23099 | epoch avg. loss: 0.013 | test avg. loss: 8.033


 46%|████▋     | 23198/50000 [12:05<09:35, 46.58it/s]

Epochs: 23199 | epoch avg. loss: 0.014 | test avg. loss: 7.908


 47%|████▋     | 23298/50000 [12:08<09:38, 46.16it/s]

Epochs: 23299 | epoch avg. loss: 0.057 | test avg. loss: 8.109


 47%|████▋     | 23398/50000 [12:11<09:29, 46.67it/s]

Epochs: 23399 | epoch avg. loss: 0.010 | test avg. loss: 7.915


 47%|████▋     | 23497/50000 [12:15<12:34, 35.10it/s]

Epochs: 23499 | epoch avg. loss: 0.029 | test avg. loss: 8.156


 47%|████▋     | 23599/50000 [12:18<09:24, 46.79it/s]

Epochs: 23599 | epoch avg. loss: 0.084 | test avg. loss: 7.843


 47%|████▋     | 23699/50000 [12:21<09:32, 45.91it/s]

Epochs: 23699 | epoch avg. loss: 0.047 | test avg. loss: 7.584


 48%|████▊     | 23799/50000 [12:24<09:05, 48.00it/s]

Epochs: 23799 | epoch avg. loss: 0.009 | test avg. loss: 7.566


 48%|████▊     | 23896/50000 [12:27<12:58, 33.51it/s]

Epochs: 23899 | epoch avg. loss: 0.038 | test avg. loss: 7.728


 48%|████▊     | 23996/50000 [12:31<09:23, 46.14it/s]

Epochs: 23999 | epoch avg. loss: 0.008 | test avg. loss: 7.584


 48%|████▊     | 24095/50000 [12:34<09:02, 47.76it/s]

Epochs: 24099 | epoch avg. loss: 0.033 | test avg. loss: 7.547


 48%|████▊     | 24195/50000 [12:37<08:59, 47.86it/s]

Epochs: 24199 | epoch avg. loss: 0.024 | test avg. loss: 7.678


 49%|████▊     | 24295/50000 [12:40<08:52, 48.30it/s]

Epochs: 24299 | epoch avg. loss: 0.025 | test avg. loss: 7.622


 49%|████▉     | 24396/50000 [12:44<13:18, 32.05it/s]

Epochs: 24399 | epoch avg. loss: 0.052 | test avg. loss: 7.324


 49%|████▉     | 24495/50000 [12:47<09:20, 45.48it/s]

Epochs: 24499 | epoch avg. loss: 0.007 | test avg. loss: 7.407


 49%|████▉     | 24595/50000 [12:50<08:57, 47.26it/s]

Epochs: 24599 | epoch avg. loss: 0.152 | test avg. loss: 7.331


 49%|████▉     | 24695/50000 [12:53<09:38, 43.78it/s]

Epochs: 24699 | epoch avg. loss: 0.017 | test avg. loss: 7.289


 50%|████▉     | 24796/50000 [12:56<12:13, 34.35it/s]

Epochs: 24799 | epoch avg. loss: 0.023 | test avg. loss: 7.289


 50%|████▉     | 24895/50000 [12:59<09:01, 46.38it/s]

Epochs: 24899 | epoch avg. loss: 0.025 | test avg. loss: 7.440


 50%|████▉     | 24995/50000 [13:02<08:49, 47.23it/s]

Epochs: 24999 | epoch avg. loss: 0.016 | test avg. loss: 7.476


 50%|█████     | 25096/50000 [13:05<08:55, 46.50it/s]

Epochs: 25099 | epoch avg. loss: 0.129 | test avg. loss: 7.666


 50%|█████     | 25199/50000 [13:09<12:30, 33.07it/s]

Epochs: 25199 | epoch avg. loss: 0.137 | test avg. loss: 8.007


 51%|█████     | 25298/50000 [13:12<08:53, 46.28it/s]

Epochs: 25299 | epoch avg. loss: 0.017 | test avg. loss: 7.486


 51%|█████     | 25398/50000 [13:15<08:55, 45.93it/s]

Epochs: 25399 | epoch avg. loss: 0.008 | test avg. loss: 7.244


 51%|█████     | 25498/50000 [13:18<08:40, 47.04it/s]

Epochs: 25499 | epoch avg. loss: 0.011 | test avg. loss: 7.250


 51%|█████     | 25598/50000 [13:21<09:19, 43.65it/s]

Epochs: 25599 | epoch avg. loss: 0.020 | test avg. loss: 7.468


 51%|█████▏    | 25695/50000 [13:25<09:47, 41.36it/s]

Epochs: 25699 | epoch avg. loss: 0.011 | test avg. loss: 7.189


 52%|█████▏    | 25796/50000 [13:28<08:30, 47.37it/s]

Epochs: 25799 | epoch avg. loss: 0.011 | test avg. loss: 7.363


 52%|█████▏    | 25896/50000 [13:31<08:31, 47.13it/s]

Epochs: 25899 | epoch avg. loss: 0.009 | test avg. loss: 7.218


 52%|█████▏    | 25996/50000 [13:34<08:22, 47.80it/s]

Epochs: 25999 | epoch avg. loss: 0.050 | test avg. loss: 7.125


 52%|█████▏    | 26097/50000 [13:38<12:40, 31.43it/s]

Epochs: 26099 | epoch avg. loss: 0.042 | test avg. loss: 7.222


 52%|█████▏    | 26197/50000 [13:41<08:15, 48.04it/s]

Epochs: 26199 | epoch avg. loss: 0.008 | test avg. loss: 7.123


 53%|█████▎    | 26296/50000 [13:44<08:18, 47.55it/s]

Epochs: 26299 | epoch avg. loss: 0.035 | test avg. loss: 7.163


 53%|█████▎    | 26397/50000 [13:47<08:48, 44.64it/s]

Epochs: 26399 | epoch avg. loss: 0.020 | test avg. loss: 7.064


 53%|█████▎    | 26496/50000 [13:50<11:41, 33.48it/s]

Epochs: 26499 | epoch avg. loss: 0.015 | test avg. loss: 7.106


 53%|█████▎    | 26598/50000 [13:54<08:28, 46.03it/s]

Epochs: 26599 | epoch avg. loss: 0.123 | test avg. loss: 7.170


 53%|█████▎    | 26699/50000 [13:57<08:23, 46.30it/s]

Epochs: 26699 | epoch avg. loss: 0.016 | test avg. loss: 7.102


 54%|█████▎    | 26795/50000 [13:59<08:00, 48.29it/s]

Epochs: 26799 | epoch avg. loss: 0.045 | test avg. loss: 7.284


 54%|█████▍    | 26898/50000 [14:03<10:30, 36.64it/s]

Epochs: 26899 | epoch avg. loss: 0.026 | test avg. loss: 7.034


 54%|█████▍    | 26995/50000 [14:06<08:28, 45.27it/s]

Epochs: 26999 | epoch avg. loss: 0.031 | test avg. loss: 7.011


 54%|█████▍    | 27095/50000 [14:09<08:02, 47.43it/s]

Epochs: 27099 | epoch avg. loss: 0.035 | test avg. loss: 7.062


 54%|█████▍    | 27196/50000 [14:12<08:12, 46.26it/s]

Epochs: 27199 | epoch avg. loss: 0.010 | test avg. loss: 7.147


 55%|█████▍    | 27298/50000 [14:18<11:14, 33.63it/s]

Epochs: 27299 | epoch avg. loss: 0.046 | test avg. loss: 7.346


 55%|█████▍    | 27396/50000 [14:21<08:15, 45.65it/s]

Epochs: 27399 | epoch avg. loss: 0.012 | test avg. loss: 7.012


 55%|█████▍    | 27496/50000 [14:24<08:08, 46.03it/s]

Epochs: 27499 | epoch avg. loss: 0.067 | test avg. loss: 7.218


 55%|█████▌    | 27596/50000 [14:27<08:06, 46.07it/s]

Epochs: 27599 | epoch avg. loss: 0.034 | test avg. loss: 6.915


 55%|█████▌    | 27698/50000 [14:30<11:03, 33.64it/s]

Epochs: 27699 | epoch avg. loss: 0.051 | test avg. loss: 6.974


 56%|█████▌    | 27797/50000 [14:34<08:10, 45.27it/s]

Epochs: 27799 | epoch avg. loss: 0.013 | test avg. loss: 6.943


 56%|█████▌    | 27897/50000 [14:37<07:58, 46.23it/s]

Epochs: 27899 | epoch avg. loss: 0.066 | test avg. loss: 7.283


 56%|█████▌    | 27997/50000 [14:40<08:10, 44.88it/s]

Epochs: 27999 | epoch avg. loss: 0.028 | test avg. loss: 6.774


 56%|█████▌    | 28097/50000 [14:43<09:46, 37.32it/s]

Epochs: 28099 | epoch avg. loss: 0.022 | test avg. loss: 6.890


 56%|█████▋    | 28197/50000 [14:47<08:39, 41.98it/s]

Epochs: 28199 | epoch avg. loss: 0.035 | test avg. loss: 6.777


 57%|█████▋    | 28297/50000 [14:50<07:50, 46.09it/s]

Epochs: 28299 | epoch avg. loss: 0.012 | test avg. loss: 6.857


 57%|█████▋    | 28397/50000 [14:53<07:46, 46.28it/s]

Epochs: 28399 | epoch avg. loss: 0.058 | test avg. loss: 7.095


 57%|█████▋    | 28497/50000 [14:56<07:37, 46.98it/s]

Epochs: 28499 | epoch avg. loss: 0.029 | test avg. loss: 6.741


 57%|█████▋    | 28598/50000 [15:00<11:10, 31.92it/s]

Epochs: 28599 | epoch avg. loss: 0.018 | test avg. loss: 6.615


 57%|█████▋    | 28697/50000 [15:03<07:53, 44.99it/s]

Epochs: 28699 | epoch avg. loss: 0.136 | test avg. loss: 6.537


 58%|█████▊    | 28797/50000 [15:06<07:48, 45.24it/s]

Epochs: 28799 | epoch avg. loss: 0.106 | test avg. loss: 6.940


 58%|█████▊    | 28898/50000 [15:09<07:46, 45.20it/s]

Epochs: 28899 | epoch avg. loss: 0.007 | test avg. loss: 6.675


 58%|█████▊    | 28999/50000 [15:12<10:07, 34.57it/s]

Epochs: 28999 | epoch avg. loss: 0.013 | test avg. loss: 6.515


 58%|█████▊    | 29095/50000 [15:16<07:35, 45.87it/s]

Epochs: 29099 | epoch avg. loss: 0.996 | test avg. loss: 6.378


 58%|█████▊    | 29195/50000 [15:19<07:18, 47.48it/s]

Epochs: 29199 | epoch avg. loss: 0.010 | test avg. loss: 6.522


 59%|█████▊    | 29295/50000 [15:22<07:21, 46.95it/s]

Epochs: 29299 | epoch avg. loss: 0.006 | test avg. loss: 6.522


 59%|█████▉    | 29397/50000 [15:25<09:47, 35.08it/s]

Epochs: 29399 | epoch avg. loss: 0.016 | test avg. loss: 6.502


 59%|█████▉    | 29498/50000 [15:29<07:20, 46.56it/s]

Epochs: 29499 | epoch avg. loss: 0.013 | test avg. loss: 6.448


 59%|█████▉    | 29599/50000 [15:32<07:14, 46.94it/s]

Epochs: 29599 | epoch avg. loss: 0.300 | test avg. loss: 6.918


 59%|█████▉    | 29696/50000 [15:35<07:05, 47.77it/s]

Epochs: 29699 | epoch avg. loss: 0.032 | test avg. loss: 6.400


 60%|█████▉    | 29797/50000 [15:37<07:11, 46.83it/s]

Epochs: 29799 | epoch avg. loss: 0.058 | test avg. loss: 6.422


 60%|█████▉    | 29895/50000 [15:42<09:40, 34.62it/s]

Epochs: 29899 | epoch avg. loss: 0.077 | test avg. loss: 6.503


 60%|█████▉    | 29995/50000 [15:45<07:02, 47.31it/s]

Epochs: 29999 | epoch avg. loss: 0.091 | test avg. loss: 6.561


 60%|██████    | 30095/50000 [15:48<07:09, 46.36it/s]

Epochs: 30099 | epoch avg. loss: 0.042 | test avg. loss: 6.404


 60%|██████    | 30195/50000 [15:50<07:23, 44.68it/s]

Epochs: 30199 | epoch avg. loss: 0.028 | test avg. loss: 6.313


 61%|██████    | 30296/50000 [15:54<10:18, 31.84it/s]

Epochs: 30299 | epoch avg. loss: 0.007 | test avg. loss: 6.400


 61%|██████    | 30395/50000 [15:57<06:56, 47.05it/s]

Epochs: 30399 | epoch avg. loss: 0.011 | test avg. loss: 6.422


 61%|██████    | 30495/50000 [16:00<06:50, 47.46it/s]

Epochs: 30499 | epoch avg. loss: 0.025 | test avg. loss: 6.441


 61%|██████    | 30595/50000 [16:03<06:46, 47.71it/s]

Epochs: 30599 | epoch avg. loss: 0.054 | test avg. loss: 6.509


 61%|██████▏   | 30699/50000 [16:07<09:16, 34.71it/s]

Epochs: 30699 | epoch avg. loss: 0.335 | test avg. loss: 6.365


 62%|██████▏   | 30796/50000 [16:10<06:46, 47.24it/s]

Epochs: 30799 | epoch avg. loss: 0.017 | test avg. loss: 6.390


 62%|██████▏   | 30896/50000 [16:13<06:51, 46.40it/s]

Epochs: 30899 | epoch avg. loss: 0.004 | test avg. loss: 6.345


 62%|██████▏   | 30996/50000 [16:16<06:50, 46.24it/s]

Epochs: 30999 | epoch avg. loss: 0.018 | test avg. loss: 6.559


 62%|██████▏   | 31099/50000 [16:20<09:27, 33.29it/s]

Epochs: 31099 | epoch avg. loss: 0.017 | test avg. loss: 6.355


 62%|██████▏   | 31198/50000 [16:23<07:02, 44.54it/s]

Epochs: 31199 | epoch avg. loss: 0.019 | test avg. loss: 6.341


 63%|██████▎   | 31298/50000 [16:26<06:39, 46.77it/s]

Epochs: 31299 | epoch avg. loss: 0.046 | test avg. loss: 6.239


 63%|██████▎   | 31395/50000 [16:29<06:32, 47.39it/s]

Epochs: 31399 | epoch avg. loss: 0.008 | test avg. loss: 6.278


 63%|██████▎   | 31496/50000 [16:32<06:40, 46.22it/s]

Epochs: 31499 | epoch avg. loss: 0.034 | test avg. loss: 6.448


 63%|██████▎   | 31598/50000 [16:36<06:56, 44.14it/s]

Epochs: 31599 | epoch avg. loss: 0.022 | test avg. loss: 6.186


 63%|██████▎   | 31698/50000 [16:39<06:31, 46.71it/s]

Epochs: 31699 | epoch avg. loss: 0.005 | test avg. loss: 6.221


 64%|██████▎   | 31795/50000 [16:42<06:25, 47.22it/s]

Epochs: 31799 | epoch avg. loss: 0.131 | test avg. loss: 6.354


 64%|██████▍   | 31895/50000 [16:45<06:27, 46.70it/s]

Epochs: 31899 | epoch avg. loss: 0.006 | test avg. loss: 6.267


 64%|██████▍   | 31996/50000 [16:49<09:07, 32.91it/s]

Epochs: 31999 | epoch avg. loss: 0.017 | test avg. loss: 6.282


 64%|██████▍   | 32096/50000 [16:52<06:27, 46.18it/s]

Epochs: 32099 | epoch avg. loss: 0.006 | test avg. loss: 6.326


 64%|██████▍   | 32196/50000 [16:55<06:13, 47.63it/s]

Epochs: 32199 | epoch avg. loss: 0.014 | test avg. loss: 6.270


 65%|██████▍   | 32296/50000 [16:58<06:17, 46.84it/s]

Epochs: 32299 | epoch avg. loss: 0.015 | test avg. loss: 6.223


 65%|██████▍   | 32396/50000 [17:02<08:45, 33.49it/s]

Epochs: 32399 | epoch avg. loss: 0.007 | test avg. loss: 6.175


 65%|██████▍   | 32498/50000 [17:05<06:10, 47.22it/s]

Epochs: 32499 | epoch avg. loss: 0.019 | test avg. loss: 6.262


 65%|██████▌   | 32598/50000 [17:08<06:13, 46.56it/s]

Epochs: 32599 | epoch avg. loss: 0.354 | test avg. loss: 6.322


 65%|██████▌   | 32698/50000 [17:11<06:06, 47.27it/s]

Epochs: 32699 | epoch avg. loss: 0.063 | test avg. loss: 6.130


 66%|██████▌   | 32798/50000 [17:14<08:00, 35.80it/s]

Epochs: 32799 | epoch avg. loss: 0.026 | test avg. loss: 6.172


 66%|██████▌   | 32897/50000 [17:18<06:15, 45.61it/s]

Epochs: 32899 | epoch avg. loss: 0.190 | test avg. loss: 6.611


 66%|██████▌   | 32997/50000 [17:21<06:11, 45.77it/s]

Epochs: 32999 | epoch avg. loss: 0.011 | test avg. loss: 6.337


 66%|██████▌   | 33097/50000 [17:24<05:59, 47.02it/s]

Epochs: 33099 | epoch avg. loss: 0.009 | test avg. loss: 6.105


 66%|██████▋   | 33197/50000 [17:27<07:02, 39.76it/s]

Epochs: 33199 | epoch avg. loss: 0.005 | test avg. loss: 6.095


 67%|██████▋   | 33297/50000 [17:31<06:03, 45.96it/s]

Epochs: 33299 | epoch avg. loss: 0.055 | test avg. loss: 6.071


 67%|██████▋   | 33397/50000 [17:37<05:59, 46.19it/s]

Epochs: 33399 | epoch avg. loss: 0.020 | test avg. loss: 6.326


 67%|██████▋   | 33497/50000 [17:40<05:57, 46.21it/s]

Epochs: 33499 | epoch avg. loss: 0.007 | test avg. loss: 6.151


 67%|██████▋   | 33598/50000 [17:44<08:40, 31.51it/s]

Epochs: 33599 | epoch avg. loss: 0.239 | test avg. loss: 6.107


 67%|██████▋   | 33697/50000 [17:47<06:17, 43.21it/s]

Epochs: 33699 | epoch avg. loss: 0.017 | test avg. loss: 6.183


 68%|██████▊   | 33796/50000 [17:50<05:50, 46.17it/s]

Epochs: 33799 | epoch avg. loss: 0.013 | test avg. loss: 6.172


 68%|██████▊   | 33896/50000 [17:53<05:57, 45.05it/s]

Epochs: 33899 | epoch avg. loss: 0.109 | test avg. loss: 6.055


 68%|██████▊   | 33997/50000 [17:57<08:01, 33.22it/s]

Epochs: 33999 | epoch avg. loss: 0.011 | test avg. loss: 6.143


 68%|██████▊   | 34096/50000 [18:00<05:49, 45.51it/s]

Epochs: 34099 | epoch avg. loss: 0.044 | test avg. loss: 6.141


 68%|██████▊   | 34196/50000 [18:03<05:42, 46.14it/s]

Epochs: 34199 | epoch avg. loss: 0.013 | test avg. loss: 5.993


 69%|██████▊   | 34296/50000 [18:06<05:43, 45.77it/s]

Epochs: 34299 | epoch avg. loss: 0.045 | test avg. loss: 6.065


 69%|██████▉   | 34399/50000 [18:10<07:52, 33.01it/s]

Epochs: 34399 | epoch avg. loss: 0.016 | test avg. loss: 6.119


 69%|██████▉   | 34499/50000 [18:14<05:37, 45.96it/s]

Epochs: 34499 | epoch avg. loss: 0.008 | test avg. loss: 6.061


 69%|██████▉   | 34599/50000 [18:17<05:26, 47.14it/s]

Epochs: 34599 | epoch avg. loss: 0.182 | test avg. loss: 6.264


 69%|██████▉   | 34695/50000 [18:19<05:27, 46.70it/s]

Epochs: 34699 | epoch avg. loss: 0.005 | test avg. loss: 6.080


 70%|██████▉   | 34798/50000 [18:23<07:13, 35.09it/s]

Epochs: 34799 | epoch avg. loss: 0.031 | test avg. loss: 5.912


 70%|██████▉   | 34899/50000 [18:26<05:26, 46.27it/s]

Epochs: 34899 | epoch avg. loss: 0.227 | test avg. loss: 6.803


 70%|██████▉   | 34999/50000 [18:29<05:28, 45.62it/s]

Epochs: 34999 | epoch avg. loss: 0.025 | test avg. loss: 6.011


 70%|███████   | 35099/50000 [18:32<05:21, 46.38it/s]

Epochs: 35099 | epoch avg. loss: 0.006 | test avg. loss: 5.989


 70%|███████   | 35199/50000 [18:35<05:16, 46.81it/s]

Epochs: 35199 | epoch avg. loss: 0.004 | test avg. loss: 5.962


 71%|███████   | 35297/50000 [18:40<06:36, 37.05it/s]

Epochs: 35299 | epoch avg. loss: 0.020 | test avg. loss: 6.008


 71%|███████   | 35396/50000 [18:43<05:09, 47.21it/s]

Epochs: 35399 | epoch avg. loss: 0.039 | test avg. loss: 5.966


 71%|███████   | 35497/50000 [18:46<05:05, 47.42it/s]

Epochs: 35499 | epoch avg. loss: 0.020 | test avg. loss: 5.906


 71%|███████   | 35597/50000 [18:49<05:04, 47.33it/s]

Epochs: 35599 | epoch avg. loss: 0.028 | test avg. loss: 6.063


 71%|███████▏  | 35698/50000 [18:52<07:39, 31.15it/s]

Epochs: 35699 | epoch avg. loss: 0.021 | test avg. loss: 5.923


 72%|███████▏  | 35798/50000 [18:56<05:01, 47.11it/s]

Epochs: 35799 | epoch avg. loss: 0.071 | test avg. loss: 5.978


 72%|███████▏  | 35898/50000 [18:59<05:03, 46.41it/s]

Epochs: 35899 | epoch avg. loss: 0.010 | test avg. loss: 5.841


 72%|███████▏  | 35998/50000 [19:02<04:55, 47.43it/s]

Epochs: 35999 | epoch avg. loss: 0.009 | test avg. loss: 5.885


 72%|███████▏  | 36097/50000 [19:05<06:52, 33.72it/s]

Epochs: 36099 | epoch avg. loss: 0.297 | test avg. loss: 5.977


 72%|███████▏  | 36196/50000 [19:09<04:51, 47.29it/s]

Epochs: 36199 | epoch avg. loss: 0.019 | test avg. loss: 5.793


 73%|███████▎  | 36296/50000 [19:12<04:59, 45.81it/s]

Epochs: 36299 | epoch avg. loss: 0.037 | test avg. loss: 5.788


 73%|███████▎  | 36396/50000 [19:15<04:53, 46.36it/s]

Epochs: 36399 | epoch avg. loss: 0.017 | test avg. loss: 5.760


 73%|███████▎  | 36497/50000 [19:18<06:48, 33.05it/s]

Epochs: 36499 | epoch avg. loss: 0.006 | test avg. loss: 5.881


 73%|███████▎  | 36596/50000 [19:22<04:59, 44.83it/s]

Epochs: 36599 | epoch avg. loss: 0.006 | test avg. loss: 5.753


 73%|███████▎  | 36696/50000 [19:25<04:54, 45.13it/s]

Epochs: 36699 | epoch avg. loss: 0.050 | test avg. loss: 5.848


 74%|███████▎  | 36796/50000 [19:28<04:48, 45.81it/s]

Epochs: 36799 | epoch avg. loss: 0.447 | test avg. loss: 6.930


 74%|███████▍  | 36896/50000 [19:31<05:59, 36.43it/s]

Epochs: 36899 | epoch avg. loss: 0.004 | test avg. loss: 5.660


 74%|███████▍  | 36996/50000 [19:35<04:45, 45.51it/s]

Epochs: 36999 | epoch avg. loss: 0.020 | test avg. loss: 5.627


 74%|███████▍  | 37096/50000 [19:38<04:46, 44.99it/s]

Epochs: 37099 | epoch avg. loss: 0.009 | test avg. loss: 5.587


 74%|███████▍  | 37196/50000 [19:41<04:40, 45.72it/s]

Epochs: 37199 | epoch avg. loss: 0.046 | test avg. loss: 5.797


 75%|███████▍  | 37296/50000 [19:44<04:56, 42.83it/s]

Epochs: 37299 | epoch avg. loss: 0.010 | test avg. loss: 5.633


 75%|███████▍  | 37398/50000 [19:48<04:51, 43.18it/s]

Epochs: 37399 | epoch avg. loss: 0.129 | test avg. loss: 5.874


 75%|███████▍  | 37498/50000 [19:51<04:29, 46.33it/s]

Epochs: 37499 | epoch avg. loss: 0.007 | test avg. loss: 5.610


 75%|███████▌  | 37598/50000 [19:54<04:19, 47.75it/s]

Epochs: 37599 | epoch avg. loss: 0.025 | test avg. loss: 5.660


 75%|███████▌  | 37698/50000 [19:57<04:18, 47.67it/s]

Epochs: 37699 | epoch avg. loss: 0.262 | test avg. loss: 5.769


 76%|███████▌  | 37799/50000 [20:02<06:30, 31.28it/s]

Epochs: 37799 | epoch avg. loss: 0.024 | test avg. loss: 5.875


 76%|███████▌  | 37899/50000 [20:04<04:11, 48.10it/s]

Epochs: 37899 | epoch avg. loss: 0.017 | test avg. loss: 5.752


 76%|███████▌  | 37995/50000 [20:07<04:11, 47.82it/s]

Epochs: 37999 | epoch avg. loss: 0.057 | test avg. loss: 5.818


 76%|███████▌  | 38097/50000 [20:10<04:05, 48.40it/s]

Epochs: 38099 | epoch avg. loss: 0.011 | test avg. loss: 5.591


 76%|███████▋  | 38199/50000 [20:14<06:08, 32.00it/s]

Epochs: 38199 | epoch avg. loss: 0.016 | test avg. loss: 5.699


 77%|███████▋  | 38295/50000 [20:18<04:03, 48.12it/s]

Epochs: 38299 | epoch avg. loss: 0.223 | test avg. loss: 5.745


 77%|███████▋  | 38396/50000 [20:20<04:14, 45.60it/s]

Epochs: 38399 | epoch avg. loss: 0.009 | test avg. loss: 5.623


 77%|███████▋  | 38496/50000 [20:23<04:09, 46.18it/s]

Epochs: 38499 | epoch avg. loss: 0.009 | test avg. loss: 5.591


 77%|███████▋  | 38596/50000 [20:27<05:34, 34.06it/s]

Epochs: 38599 | epoch avg. loss: 0.022 | test avg. loss: 5.569


 77%|███████▋  | 38698/50000 [20:31<03:59, 47.26it/s]

Epochs: 38699 | epoch avg. loss: 0.082 | test avg. loss: 5.434


 78%|███████▊  | 38799/50000 [20:34<03:52, 48.09it/s]

Epochs: 38799 | epoch avg. loss: 0.117 | test avg. loss: 5.617


 78%|███████▊  | 38899/50000 [20:37<03:58, 46.64it/s]

Epochs: 38899 | epoch avg. loss: 0.009 | test avg. loss: 5.625


 78%|███████▊  | 38996/50000 [20:40<05:48, 31.55it/s]

Epochs: 38999 | epoch avg. loss: 0.031 | test avg. loss: 5.771


 78%|███████▊  | 39095/50000 [20:44<04:02, 44.88it/s]

Epochs: 39099 | epoch avg. loss: 0.010 | test avg. loss: 5.584


 78%|███████▊  | 39199/50000 [20:47<03:57, 45.41it/s]

Epochs: 39199 | epoch avg. loss: 0.010 | test avg. loss: 5.544


 79%|███████▊  | 39299/50000 [20:50<03:55, 45.45it/s]

Epochs: 39299 | epoch avg. loss: 0.015 | test avg. loss: 5.567


 79%|███████▉  | 39396/50000 [20:53<05:03, 34.93it/s]

Epochs: 39399 | epoch avg. loss: 0.064 | test avg. loss: 5.603


 79%|███████▉  | 39498/50000 [20:57<03:50, 45.58it/s]

Epochs: 39499 | epoch avg. loss: 0.014 | test avg. loss: 5.677


 79%|███████▉  | 39598/50000 [21:00<03:47, 45.68it/s]

Epochs: 39599 | epoch avg. loss: 0.019 | test avg. loss: 5.541


 79%|███████▉  | 39698/50000 [21:03<03:48, 45.03it/s]

Epochs: 39699 | epoch avg. loss: 0.078 | test avg. loss: 5.610


 80%|███████▉  | 39798/50000 [21:06<03:46, 45.10it/s]

Epochs: 39799 | epoch avg. loss: 0.008 | test avg. loss: 5.558


 80%|███████▉  | 39899/50000 [21:10<04:05, 41.11it/s]

Epochs: 39899 | epoch avg. loss: 0.014 | test avg. loss: 5.507


 80%|███████▉  | 39999/50000 [21:13<03:37, 45.90it/s]

Epochs: 39999 | epoch avg. loss: 0.017 | test avg. loss: 5.514


 80%|████████  | 40099/50000 [21:16<03:32, 46.65it/s]

Epochs: 40099 | epoch avg. loss: 0.028 | test avg. loss: 5.446


 80%|████████  | 40199/50000 [21:19<03:31, 46.39it/s]

Epochs: 40199 | epoch avg. loss: 0.012 | test avg. loss: 5.516


 81%|████████  | 40296/50000 [21:23<05:03, 31.98it/s]

Epochs: 40299 | epoch avg. loss: 0.037 | test avg. loss: 5.590


 81%|████████  | 40396/50000 [21:26<03:21, 47.63it/s]

Epochs: 40399 | epoch avg. loss: 0.012 | test avg. loss: 5.425


 81%|████████  | 40497/50000 [21:29<03:15, 48.66it/s]

Epochs: 40499 | epoch avg. loss: 0.288 | test avg. loss: 5.740


 81%|████████  | 40597/50000 [21:32<03:16, 47.93it/s]

Epochs: 40599 | epoch avg. loss: 0.009 | test avg. loss: 5.596


 81%|████████▏ | 40697/50000 [21:39<03:38, 42.49it/s]

Epochs: 40699 | epoch avg. loss: 0.008 | test avg. loss: 5.522


 82%|████████▏ | 40797/50000 [21:42<03:30, 43.72it/s]

Epochs: 40799 | epoch avg. loss: 0.010 | test avg. loss: 5.571


 82%|████████▏ | 40897/50000 [21:45<03:26, 44.14it/s]

Epochs: 40899 | epoch avg. loss: 0.021 | test avg. loss: 5.580


 82%|████████▏ | 40998/50000 [21:49<04:36, 32.57it/s]

Epochs: 40999 | epoch avg. loss: 0.019 | test avg. loss: 5.510


 82%|████████▏ | 41098/50000 [21:53<03:16, 45.26it/s]

Epochs: 41099 | epoch avg. loss: 0.023 | test avg. loss: 5.472


 82%|████████▏ | 41198/50000 [21:56<03:10, 46.14it/s]

Epochs: 41199 | epoch avg. loss: 0.014 | test avg. loss: 5.517


 83%|████████▎ | 41298/50000 [21:59<03:32, 40.98it/s]

Epochs: 41299 | epoch avg. loss: 0.030 | test avg. loss: 5.523


 83%|████████▎ | 41396/50000 [22:03<04:20, 32.98it/s]

Epochs: 41399 | epoch avg. loss: 0.102 | test avg. loss: 5.561


 83%|████████▎ | 41495/50000 [22:06<03:07, 45.35it/s]

Epochs: 41499 | epoch avg. loss: 0.040 | test avg. loss: 5.339


 83%|████████▎ | 41595/50000 [22:09<03:11, 43.90it/s]

Epochs: 41599 | epoch avg. loss: 0.023 | test avg. loss: 5.716


 83%|████████▎ | 41699/50000 [22:12<03:04, 44.91it/s]

Epochs: 41699 | epoch avg. loss: 0.034 | test avg. loss: 5.477


 84%|████████▎ | 41796/50000 [22:16<04:09, 32.82it/s]

Epochs: 41799 | epoch avg. loss: 0.016 | test avg. loss: 5.400


 84%|████████▍ | 41897/50000 [22:20<02:57, 45.56it/s]

Epochs: 41899 | epoch avg. loss: 0.190 | test avg. loss: 5.717


 84%|████████▍ | 41997/50000 [22:23<02:54, 45.90it/s]

Epochs: 41999 | epoch avg. loss: 0.044 | test avg. loss: 5.466


 84%|████████▍ | 42097/50000 [22:26<02:49, 46.59it/s]

Epochs: 42099 | epoch avg. loss: 0.010 | test avg. loss: 5.544


 84%|████████▍ | 42199/50000 [22:29<03:55, 33.19it/s]

Epochs: 42199 | epoch avg. loss: 0.006 | test avg. loss: 5.445


 85%|████████▍ | 42296/50000 [22:33<02:47, 45.95it/s]

Epochs: 42299 | epoch avg. loss: 0.072 | test avg. loss: 5.428


 85%|████████▍ | 42398/50000 [22:36<02:41, 47.04it/s]

Epochs: 42399 | epoch avg. loss: 0.018 | test avg. loss: 5.486


 85%|████████▍ | 42498/50000 [22:39<02:40, 46.61it/s]

Epochs: 42499 | epoch avg. loss: 0.005 | test avg. loss: 5.335


 85%|████████▌ | 42598/50000 [22:42<03:09, 38.97it/s]

Epochs: 42599 | epoch avg. loss: 0.025 | test avg. loss: 5.326


 85%|████████▌ | 42699/50000 [22:46<02:52, 42.29it/s]

Epochs: 42699 | epoch avg. loss: 0.004 | test avg. loss: 5.386


 86%|████████▌ | 42799/50000 [22:49<02:36, 46.13it/s]

Epochs: 42799 | epoch avg. loss: 0.009 | test avg. loss: 5.367


 86%|████████▌ | 42899/50000 [22:52<02:34, 45.97it/s]

Epochs: 42899 | epoch avg. loss: 0.016 | test avg. loss: 5.439


 86%|████████▌ | 42999/50000 [22:55<02:29, 46.82it/s]

Epochs: 42999 | epoch avg. loss: 0.021 | test avg. loss: 5.521


 86%|████████▌ | 43095/50000 [22:59<03:16, 35.12it/s]

Epochs: 43099 | epoch avg. loss: 0.032 | test avg. loss: 5.191


 86%|████████▋ | 43195/50000 [23:02<02:25, 46.79it/s]

Epochs: 43199 | epoch avg. loss: 0.102 | test avg. loss: 5.384


 87%|████████▋ | 43295/50000 [23:06<02:21, 47.23it/s]

Epochs: 43299 | epoch avg. loss: 0.013 | test avg. loss: 5.292


 87%|████████▋ | 43395/50000 [23:09<02:20, 47.15it/s]

Epochs: 43399 | epoch avg. loss: 0.082 | test avg. loss: 5.601


 87%|████████▋ | 43496/50000 [23:13<03:18, 32.70it/s]

Epochs: 43499 | epoch avg. loss: 0.005 | test avg. loss: 5.344


 87%|████████▋ | 43595/50000 [23:16<02:15, 47.28it/s]

Epochs: 43599 | epoch avg. loss: 0.009 | test avg. loss: 5.205


 87%|████████▋ | 43696/50000 [23:19<02:15, 46.58it/s]

Epochs: 43699 | epoch avg. loss: 0.008 | test avg. loss: 5.293


 88%|████████▊ | 43795/50000 [23:22<02:14, 45.98it/s]

Epochs: 43799 | epoch avg. loss: 0.020 | test avg. loss: 5.297


 88%|████████▊ | 43898/50000 [23:26<03:05, 32.84it/s]

Epochs: 43899 | epoch avg. loss: 0.109 | test avg. loss: 5.352


 88%|████████▊ | 43997/50000 [23:29<02:12, 45.17it/s]

Epochs: 43999 | epoch avg. loss: 0.013 | test avg. loss: 5.331


 88%|████████▊ | 44098/50000 [23:32<02:10, 45.13it/s]

Epochs: 44099 | epoch avg. loss: 0.021 | test avg. loss: 5.208


 88%|████████▊ | 44198/50000 [23:35<02:05, 46.18it/s]

Epochs: 44199 | epoch avg. loss: 0.013 | test avg. loss: 5.255


 89%|████████▊ | 44299/50000 [23:39<02:47, 34.11it/s]

Epochs: 44299 | epoch avg. loss: 0.006 | test avg. loss: 5.188


 89%|████████▉ | 44398/50000 [23:42<02:01, 46.23it/s]

Epochs: 44399 | epoch avg. loss: 0.005 | test avg. loss: 5.218


 89%|████████▉ | 44499/50000 [23:45<01:58, 46.50it/s]

Epochs: 44499 | epoch avg. loss: 0.008 | test avg. loss: 5.227


 89%|████████▉ | 44599/50000 [23:48<01:56, 46.47it/s]

Epochs: 44599 | epoch avg. loss: 0.041 | test avg. loss: 5.158


 89%|████████▉ | 44697/50000 [23:52<02:36, 33.78it/s]

Epochs: 44699 | epoch avg. loss: 0.053 | test avg. loss: 5.150


 90%|████████▉ | 44799/50000 [23:56<01:50, 47.04it/s]

Epochs: 44799 | epoch avg. loss: 0.044 | test avg. loss: 5.057


 90%|████████▉ | 44899/50000 [23:59<01:47, 47.40it/s]

Epochs: 44899 | epoch avg. loss: 0.017 | test avg. loss: 5.219


 90%|████████▉ | 44995/50000 [24:02<01:43, 48.38it/s]

Epochs: 44999 | epoch avg. loss: 0.169 | test avg. loss: 5.089


 90%|█████████ | 45099/50000 [24:05<02:16, 36.01it/s]

Epochs: 45099 | epoch avg. loss: 0.018 | test avg. loss: 5.173


 90%|█████████ | 45198/50000 [24:09<01:48, 44.38it/s]

Epochs: 45199 | epoch avg. loss: 0.008 | test avg. loss: 5.186


 91%|█████████ | 45298/50000 [24:12<01:42, 46.06it/s]

Epochs: 45299 | epoch avg. loss: 0.032 | test avg. loss: 5.089


 91%|█████████ | 45398/50000 [24:15<01:40, 45.63it/s]

Epochs: 45399 | epoch avg. loss: 0.005 | test avg. loss: 5.137


 91%|█████████ | 45498/50000 [24:18<01:37, 46.13it/s]

Epochs: 45499 | epoch avg. loss: 0.006 | test avg. loss: 5.090


 91%|█████████ | 45595/50000 [24:22<01:56, 37.69it/s]

Epochs: 45599 | epoch avg. loss: 0.006 | test avg. loss: 5.168


 91%|█████████▏| 45696/50000 [24:25<01:31, 46.94it/s]

Epochs: 45699 | epoch avg. loss: 0.076 | test avg. loss: 5.264


 92%|█████████▏| 45796/50000 [24:29<01:29, 47.05it/s]

Epochs: 45799 | epoch avg. loss: 0.030 | test avg. loss: 5.326


 92%|█████████▏| 45896/50000 [24:32<01:25, 47.86it/s]

Epochs: 45899 | epoch avg. loss: 0.067 | test avg. loss: 4.977


 92%|█████████▏| 45997/50000 [24:36<02:14, 29.83it/s]

Epochs: 45999 | epoch avg. loss: 0.010 | test avg. loss: 5.006


 92%|█████████▏| 46095/50000 [24:39<01:22, 47.38it/s]

Epochs: 46099 | epoch avg. loss: 0.400 | test avg. loss: 5.556


 92%|█████████▏| 46195/50000 [24:42<01:32, 41.35it/s]

Epochs: 46199 | epoch avg. loss: 0.030 | test avg. loss: 4.991


 93%|█████████▎| 46295/50000 [24:45<01:25, 43.16it/s]

Epochs: 46299 | epoch avg. loss: 0.004 | test avg. loss: 4.999


 93%|█████████▎| 46396/50000 [24:49<01:58, 30.44it/s]

Epochs: 46399 | epoch avg. loss: 0.100 | test avg. loss: 4.928


 93%|█████████▎| 46495/50000 [24:52<01:16, 46.06it/s]

Epochs: 46499 | epoch avg. loss: 0.026 | test avg. loss: 4.997


 93%|█████████▎| 46595/50000 [24:55<01:12, 47.25it/s]

Epochs: 46599 | epoch avg. loss: 0.015 | test avg. loss: 5.046


 93%|█████████▎| 46695/50000 [24:58<01:10, 47.18it/s]

Epochs: 46699 | epoch avg. loss: 0.039 | test avg. loss: 4.790


 94%|█████████▎| 46796/50000 [25:02<01:43, 30.89it/s]

Epochs: 46799 | epoch avg. loss: 0.009 | test avg. loss: 4.968


 94%|█████████▍| 46895/50000 [25:06<01:08, 45.45it/s]

Epochs: 46899 | epoch avg. loss: 0.016 | test avg. loss: 5.070


 94%|█████████▍| 46996/50000 [25:09<01:04, 46.80it/s]

Epochs: 46999 | epoch avg. loss: 0.016 | test avg. loss: 4.865


 94%|█████████▍| 47096/50000 [25:12<01:05, 44.13it/s]

Epochs: 47099 | epoch avg. loss: 0.018 | test avg. loss: 4.821


 94%|█████████▍| 47199/50000 [25:15<01:19, 35.28it/s]

Epochs: 47199 | epoch avg. loss: 0.004 | test avg. loss: 4.901


 95%|█████████▍| 47298/50000 [25:19<00:58, 45.94it/s]

Epochs: 47299 | epoch avg. loss: 0.012 | test avg. loss: 4.741


 95%|█████████▍| 47398/50000 [25:22<00:57, 44.97it/s]

Epochs: 47399 | epoch avg. loss: 0.009 | test avg. loss: 4.853


 95%|█████████▍| 47498/50000 [25:25<00:55, 45.20it/s]

Epochs: 47499 | epoch avg. loss: 0.057 | test avg. loss: 4.717


 95%|█████████▌| 47597/50000 [25:29<01:11, 33.62it/s]

Epochs: 47599 | epoch avg. loss: 0.012 | test avg. loss: 4.710


 95%|█████████▌| 47698/50000 [25:32<00:49, 46.61it/s]

Epochs: 47699 | epoch avg. loss: 0.018 | test avg. loss: 4.835


 96%|█████████▌| 47798/50000 [25:35<00:47, 46.07it/s]

Epochs: 47799 | epoch avg. loss: 0.055 | test avg. loss: 4.597


 96%|█████████▌| 47899/50000 [25:38<00:44, 47.16it/s]

Epochs: 47899 | epoch avg. loss: 0.052 | test avg. loss: 4.730


 96%|█████████▌| 47996/50000 [25:42<00:59, 33.76it/s]

Epochs: 47999 | epoch avg. loss: 0.010 | test avg. loss: 4.664


 96%|█████████▌| 48096/50000 [25:46<00:42, 45.15it/s]

Epochs: 48099 | epoch avg. loss: 0.020 | test avg. loss: 4.642


 96%|█████████▋| 48197/50000 [25:49<00:39, 46.14it/s]

Epochs: 48199 | epoch avg. loss: 0.006 | test avg. loss: 4.669


 97%|█████████▋| 48297/50000 [25:52<00:37, 45.01it/s]

Epochs: 48299 | epoch avg. loss: 0.014 | test avg. loss: 4.576


 97%|█████████▋| 48396/50000 [25:55<00:45, 35.54it/s]

Epochs: 48399 | epoch avg. loss: 0.019 | test avg. loss: 4.542


 97%|█████████▋| 48496/50000 [25:59<00:33, 44.39it/s]

Epochs: 48499 | epoch avg. loss: 0.009 | test avg. loss: 4.692


 97%|█████████▋| 48596/50000 [26:02<00:31, 44.53it/s]

Epochs: 48599 | epoch avg. loss: 0.017 | test avg. loss: 4.887


 97%|█████████▋| 48696/50000 [26:05<00:27, 46.87it/s]

Epochs: 48699 | epoch avg. loss: 0.022 | test avg. loss: 4.491


 98%|█████████▊| 48796/50000 [26:08<00:29, 40.49it/s]

Epochs: 48799 | epoch avg. loss: 0.012 | test avg. loss: 4.584


 98%|█████████▊| 48898/50000 [26:12<00:24, 44.16it/s]

Epochs: 48899 | epoch avg. loss: 0.114 | test avg. loss: 4.415


 98%|█████████▊| 48998/50000 [26:16<00:21, 45.72it/s]

Epochs: 48999 | epoch avg. loss: 0.039 | test avg. loss: 4.604


 98%|█████████▊| 49099/50000 [26:19<00:19, 46.89it/s]

Epochs: 49099 | epoch avg. loss: 0.008 | test avg. loss: 4.540


 98%|█████████▊| 49199/50000 [26:22<00:18, 43.89it/s]

Epochs: 49199 | epoch avg. loss: 0.010 | test avg. loss: 4.506


 99%|█████████▊| 49297/50000 [26:26<00:19, 36.86it/s]

Epochs: 49299 | epoch avg. loss: 0.014 | test avg. loss: 4.522


 99%|█████████▉| 49396/50000 [26:29<00:13, 44.72it/s]

Epochs: 49399 | epoch avg. loss: 0.028 | test avg. loss: 4.575


 99%|█████████▉| 49497/50000 [26:32<00:11, 44.85it/s]

Epochs: 49499 | epoch avg. loss: 0.004 | test avg. loss: 4.465


 99%|█████████▉| 49597/50000 [26:40<00:09, 42.45it/s]

Epochs: 49599 | epoch avg. loss: 0.012 | test avg. loss: 4.403


 99%|█████████▉| 49697/50000 [26:43<00:06, 45.69it/s]

Epochs: 49699 | epoch avg. loss: 0.020 | test avg. loss: 4.434


100%|█████████▉| 49797/50000 [26:46<00:04, 43.62it/s]

Epochs: 49799 | epoch avg. loss: 0.010 | test avg. loss: 4.406


100%|█████████▉| 49897/50000 [26:49<00:02, 43.64it/s]

Epochs: 49899 | epoch avg. loss: 0.045 | test avg. loss: 4.430


100%|█████████▉| 49995/50000 [26:53<00:00, 43.63it/s]

Epochs: 49999 | epoch avg. loss: 0.011 | test avg. loss: 4.447


100%|██████████| 50000/50000 [26:54<00:00, 30.96it/s]


In [5]:
# Save the results to create the plots

qm9_none_results_filename = os.path.join(results_dir, "qm9_size100_alpha3.0000none.pt")

if os.path.exists(qm9_none_results_filename):
    qm9_none_data = torch.load(qm9_none_results_filename)
else:
    print(f"File {qm9_none_results_filename} not found!")

NameError: name 'os' is not defined

In [18]:
# for ma

In [19]:
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--size", type=int, default=100)
    parser.add_argument("--init_scale", type=float, default=3.0) # init_scale 1.0 no grokking / init_scale 3.0 grokking

    # Grokfast
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="ma")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)
    args, unknown = parser.parse_known_args()

    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{args.lamb:.2f}'.replace('.', '')

    model_suffix = f'size{args.size}_alpha{args.init_scale:.4f}'

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}'

    args.label = args.label + model_suffix + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    main(args)

Experiment results saved under name: size100_alpha3.0000ma_w100_l500


  0%|          | 97/50000 [00:04<57:18, 14.51it/s]

Epochs: 99 | epoch avg. loss: 82.260 | test avg. loss: 114.463


  0%|          | 199/50000 [00:09<39:16, 21.13it/s]

Epochs: 199 | epoch avg. loss: 9.391 | test avg. loss: 114.326


  1%|          | 298/50000 [00:15<50:56, 16.26it/s]

Epochs: 299 | epoch avg. loss: 10.336 | test avg. loss: 100.006


  1%|          | 399/50000 [00:21<39:49, 20.76it/s]

Epochs: 399 | epoch avg. loss: 8.863 | test avg. loss: 69.784


  1%|          | 498/50000 [00:27<39:29, 20.89it/s]

Epochs: 499 | epoch avg. loss: 5.127 | test avg. loss: 28.776


  1%|          | 598/50000 [00:33<38:46, 21.24it/s]

Epochs: 599 | epoch avg. loss: 15.391 | test avg. loss: 33.702


  1%|▏         | 699/50000 [00:38<39:52, 20.60it/s]

Epochs: 699 | epoch avg. loss: 11.911 | test avg. loss: 14.730


  2%|▏         | 798/50000 [00:45<58:18, 14.06it/s]

Epochs: 799 | epoch avg. loss: 9.925 | test avg. loss: 15.538


  2%|▏         | 897/50000 [00:50<40:29, 20.21it/s]

Epochs: 899 | epoch avg. loss: 6.531 | test avg. loss: 16.211


  2%|▏         | 999/50000 [00:56<49:42, 16.43it/s]

Epochs: 999 | epoch avg. loss: 2.824 | test avg. loss: 6.733


  2%|▏         | 1097/50000 [01:03<38:14, 21.31it/s]

Epochs: 1099 | epoch avg. loss: 10.869 | test avg. loss: 13.659


  2%|▏         | 1198/50000 [01:08<38:18, 21.23it/s]

Epochs: 1199 | epoch avg. loss: 7.749 | test avg. loss: 14.113


  3%|▎         | 1297/50000 [01:14<38:24, 21.14it/s]

Epochs: 1299 | epoch avg. loss: 15.636 | test avg. loss: 18.792


  3%|▎         | 1399/50000 [01:20<38:01, 21.30it/s]

Epochs: 1399 | epoch avg. loss: 11.635 | test avg. loss: 11.494


  3%|▎         | 1498/50000 [01:26<57:01, 14.18it/s]

Epochs: 1499 | epoch avg. loss: 9.226 | test avg. loss: 10.864


  3%|▎         | 1597/50000 [01:32<38:32, 20.93it/s]

Epochs: 1599 | epoch avg. loss: 10.205 | test avg. loss: 14.243


  3%|▎         | 1609/50000 [01:33<46:42, 17.27it/s]


KeyboardInterrupt: 