In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import uproot

import numpy as np
import pandas as pd
import awkward as ak
import time
import matplotlib.pyplot as plt

from utils.datasets import *
from utils.training import *
from models.models import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Data Preprocessing

root = "/Users/markmatthewman/Projects/PhD/EgammaDNN/data/GenSim/TICLv4_Mustache/electron"
fname = "HLTAnalyzerTree_IDEAL_Flat_train.root"
tree = "egRegDataEcalHLTV1"
batches = 256

dataset = RegressionDataset(root,fname,tree)
dataloader = DataLoader(dataset,batch_size=batches,shuffle=True)

In [3]:
# Set Trainer 

lr = 0.001
device = torch.device("cpu")

# Select Model
model = RegressionDNN(7)
model.to(device)

# Select Loss function
loss_fn = nn.MSELoss()

# Select optimiser
optimizer = torch.optim.Adam(model.parameters(),lr=lr)

# Initialize Trainer
trainer = Trainer(model, loss_fn, dataloader, optimizer)


In [14]:

for batch, data_batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    #data_batch_device = {key: val.to(self.device) for key, val in data_batch.items()}
    inpt = data_batch["features"].to(device)
    tgt = data_batch["targets"].to(device)
    #print(tgt.shape)
    d = model(inpt)
    logprob = model.log_prob(tgt,inpt)
    loss = -logprob.mean()

    # Backpropagation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 10 == 0:
        tqdm.write(f"loss: {loss.item():>7f}")

  0%|                                         | 11/3802 [00:00<01:12, 52.16it/s]

loss: -1.270175
loss: -1.230481


  1%|▎                                        | 32/3802 [00:00<01:00, 62.20it/s]

loss: -1.190065
loss: -1.290922


  1%|▌                                        | 53/3802 [00:00<00:59, 63.54it/s]

loss: -1.383662
loss: -1.484747


  2%|▊                                        | 74/3802 [00:01<00:57, 64.91it/s]

loss: -1.463675
loss: -1.290262


  2%|▉                                        | 88/3802 [00:01<00:55, 67.04it/s]

loss: -1.422099
loss: -1.345874


  3%|█▏                                      | 109/3802 [00:01<00:54, 67.56it/s]

loss: -1.389878
loss: -1.490925


  3%|█▎                                      | 130/3802 [00:02<00:54, 66.82it/s]

loss: -1.379824
loss: -1.188957


  4%|█▌                                      | 151/3802 [00:02<00:55, 66.34it/s]

loss: -1.307247
loss: -1.550916


  5%|█▊                                      | 172/3802 [00:02<00:54, 66.03it/s]

loss: -1.441195
loss: -1.523895


  5%|██                                      | 193/3802 [00:02<00:53, 66.88it/s]

loss: -1.282200
loss: -1.499165


  6%|██▎                                     | 214/3802 [00:03<00:53, 66.79it/s]

loss: -1.389077
loss: -1.318171


  6%|██▍                                     | 228/3802 [00:03<00:55, 64.10it/s]

loss: -1.409667
loss: -1.495718


  7%|██▌                                     | 249/3802 [00:03<00:55, 64.14it/s]

loss: -1.582566
loss: -1.399485


  7%|██▊                                     | 270/3802 [00:04<00:57, 61.24it/s]

loss: -1.554600
loss: -1.539766


  8%|███                                     | 291/3802 [00:04<00:57, 61.51it/s]

loss: -1.472245
loss: -1.511321


  8%|███▎                                    | 312/3802 [00:04<00:54, 63.53it/s]

loss: -1.469832
loss: -1.494045


  9%|███▌                                    | 333/3802 [00:05<00:54, 63.51it/s]

loss: -1.387250
loss: -1.363161


  9%|███▋                                    | 354/3802 [00:05<00:52, 65.99it/s]

loss: -1.489191
loss: -1.521156


 10%|███▊                                    | 368/3802 [00:05<00:51, 67.02it/s]

loss: -1.403368
loss: -1.551369


 10%|████                                    | 389/3802 [00:06<00:51, 66.12it/s]

loss: -1.357093
loss: -1.394794


 11%|████▎                                   | 411/3802 [00:06<00:50, 67.24it/s]

loss: -1.435943
loss: -1.479480


 11%|████▌                                   | 432/3802 [00:06<00:51, 65.25it/s]

loss: -1.425736
loss: -1.490411


 12%|████▊                                   | 453/3802 [00:07<00:50, 65.92it/s]

loss: -1.532652
loss: -1.433146


 12%|████▉                                   | 474/3802 [00:07<00:50, 66.45it/s]

loss: -1.480434
loss: -1.377119


 13%|█████▏                                  | 488/3802 [00:07<00:50, 66.14it/s]

loss: -1.300243
loss: -1.456789


 13%|█████▎                                  | 510/3802 [00:07<00:49, 66.97it/s]

loss: -1.421783
loss: -1.271810


 14%|█████▌                                  | 532/3802 [00:08<00:47, 68.89it/s]

loss: -1.346556
loss: -1.426232


 15%|█████▊                                  | 554/3802 [00:08<00:46, 69.68it/s]

loss: -1.566311
loss: -1.312644


 15%|█████▉                                  | 569/3802 [00:08<00:46, 70.04it/s]

loss: -1.223027
loss: -1.243328


 16%|██████▏                                 | 593/3802 [00:09<00:45, 70.51it/s]

loss: -1.390381
loss: -1.475520


 16%|██████▍                                 | 609/3802 [00:09<00:45, 70.56it/s]

loss: -1.420006
loss: -1.514978


 17%|██████▋                                 | 633/3802 [00:09<00:45, 70.14it/s]

loss: -1.421415
loss: -1.365024


 17%|██████▊                                 | 649/3802 [00:09<00:45, 69.90it/s]

loss: -1.368657
loss: -1.335135


 18%|███████                                 | 671/3802 [00:10<00:44, 69.91it/s]

loss: -1.186678
loss: -1.456948


 18%|███████▏                                | 687/3802 [00:10<00:44, 70.01it/s]

loss: -1.583101
loss: -1.333737


 19%|███████▍                                | 711/3802 [00:10<00:43, 70.57it/s]

loss: -1.432779
loss: -1.417841


 19%|███████▋                                | 735/3802 [00:11<00:43, 71.11it/s]

loss: -1.312701
loss: -1.505918


 20%|███████▉                                | 751/3802 [00:11<00:43, 70.74it/s]

loss: -1.291131
loss: -1.415388


 20%|████████▏                               | 775/3802 [00:11<00:42, 70.56it/s]

loss: -1.251537
loss: -1.389075


 21%|████████▎                               | 791/3802 [00:11<00:42, 70.41it/s]

loss: -1.370810
loss: -1.359791


 21%|████████▍                               | 807/3802 [00:12<00:42, 70.65it/s]

loss: -1.625158
loss: -1.443351


 22%|████████▋                               | 831/3802 [00:12<00:42, 70.31it/s]

loss: -1.305503
loss: -1.555517


 22%|████████▉                               | 854/3802 [00:12<00:42, 69.67it/s]

loss: -1.516043
loss: -1.241970


 23%|█████████▏                              | 869/3802 [00:13<00:41, 70.12it/s]

loss: -1.331554
loss: -1.362522


 23%|█████████▎                              | 891/3802 [00:13<00:41, 69.74it/s]

loss: -1.521779
loss: -1.312100


 24%|█████████▋                              | 915/3802 [00:13<00:41, 70.18it/s]

loss: -1.440286
loss: -1.448449


 24%|█████████▊                              | 931/3802 [00:13<00:41, 70.00it/s]

loss: -1.313336
loss: -1.280041


 25%|██████████                              | 954/3802 [00:14<00:41, 68.96it/s]

loss: -1.304480
loss: -1.383428


 25%|██████████▏                             | 968/3802 [00:14<00:40, 69.35it/s]

loss: -1.466916
loss: -1.397693


 26%|██████████▍                             | 990/3802 [00:14<00:40, 69.90it/s]

loss: -1.458210
loss: -1.421335


 27%|██████████▎                            | 1011/3802 [00:15<00:39, 69.79it/s]

loss: -1.255343
loss: -1.371920


 27%|██████████▌                            | 1032/3802 [00:15<00:39, 69.65it/s]

loss: -1.325181
loss: -1.424194


 28%|██████████▊                            | 1048/3802 [00:15<00:39, 69.78it/s]

loss: -1.305066
loss: -1.506339


 28%|██████████▉                            | 1070/3802 [00:15<00:39, 69.23it/s]

loss: -1.349962
loss: -1.374468


 29%|███████████▏                           | 1091/3802 [00:16<00:39, 69.38it/s]

loss: -1.522952
loss: -1.403215


 29%|███████████▍                           | 1114/3802 [00:16<00:38, 70.01it/s]

loss: -1.509225
loss: -1.621358


 30%|███████████▌                           | 1130/3802 [00:16<00:37, 70.65it/s]

loss: -1.422669
loss: -1.490735


 30%|███████████▊                           | 1154/3802 [00:17<00:37, 70.20it/s]

loss: -1.485520
loss: -1.338782


 31%|████████████                           | 1170/3802 [00:17<00:37, 70.25it/s]

loss: -1.368229
loss: -1.387755


 31%|████████████▏                          | 1194/3802 [00:17<00:37, 70.09it/s]

loss: -1.519755
loss: -1.161150


 32%|████████████▍                          | 1210/3802 [00:17<00:36, 70.70it/s]

loss: -1.449285
loss: -1.468556


 32%|████████████▋                          | 1234/3802 [00:18<00:36, 70.21it/s]

loss: -1.431342
loss: -1.464530


 33%|████████████▊                          | 1250/3802 [00:18<00:36, 70.53it/s]

loss: -1.441528
loss: -1.466814


 34%|█████████████                          | 1274/3802 [00:18<00:35, 70.33it/s]

loss: -1.306199
loss: -1.546292


 34%|█████████████▏                         | 1290/3802 [00:19<00:35, 70.50it/s]

loss: -1.251554
loss: -1.399756


 35%|█████████████▍                         | 1312/3802 [00:19<00:35, 69.79it/s]

loss: -1.277295
loss: -1.436903


 35%|█████████████▌                         | 1328/3802 [00:19<00:35, 70.29it/s]

loss: -1.453049
loss: -1.194836


 36%|█████████████▊                         | 1352/3802 [00:19<00:35, 69.65it/s]

loss: -1.454751
loss: -1.482011


 36%|██████████████                         | 1368/3802 [00:20<00:34, 69.93it/s]

loss: -1.319474
loss: -1.418439


 37%|██████████████▎                        | 1392/3802 [00:20<00:34, 70.14it/s]

loss: -1.416274
loss: -1.560800


 37%|██████████████▍                        | 1408/3802 [00:20<00:34, 70.18it/s]

loss: -1.398128
loss: -1.313587


 38%|██████████████▋                        | 1432/3802 [00:21<00:33, 70.01it/s]

loss: -1.430360
loss: -1.446876


 38%|██████████████▊                        | 1448/3802 [00:21<00:33, 70.03it/s]

loss: -1.344194
loss: -1.422024


 39%|███████████████                        | 1471/3802 [00:21<00:33, 69.68it/s]

loss: -1.413548
loss: -1.315737


 39%|███████████████▎                       | 1493/3802 [00:21<00:33, 69.53it/s]

loss: -1.242683
loss: -1.152575


 40%|███████████████▍                       | 1508/3802 [00:22<00:32, 69.81it/s]

loss: -1.435422
loss: -1.412110


 40%|███████████████▋                       | 1530/3802 [00:22<00:33, 68.12it/s]

loss: -1.464130
loss: -1.421698


 41%|███████████████▉                       | 1552/3802 [00:22<00:32, 69.14it/s]

loss: -1.465618
loss: -1.396948


 41%|████████████████▏                      | 1574/3802 [00:23<00:32, 69.26it/s]

loss: -1.343415
loss: -1.418956


 42%|████████████████▎                      | 1589/3802 [00:23<00:31, 69.69it/s]

loss: -1.424214
loss: -1.416392


 42%|████████████████▌                      | 1612/3802 [00:23<00:31, 69.80it/s]

loss: -1.493028
loss: -1.263265


 43%|████████████████▋                      | 1628/3802 [00:23<00:30, 70.17it/s]

loss: -1.285800
loss: -1.246157


 43%|████████████████▉                      | 1650/3802 [00:24<00:31, 68.83it/s]

loss: -1.318126
loss: -1.462534


 44%|█████████████████▏                     | 1671/3802 [00:24<00:30, 68.98it/s]

loss: -1.500891
loss: -1.306796


 45%|█████████████████▎                     | 1693/3802 [00:24<00:30, 69.49it/s]

loss: -1.272650
loss: -1.376595


 45%|█████████████████▌                     | 1709/3802 [00:25<00:29, 70.21it/s]

loss: -1.530321
loss: -1.506292


 46%|█████████████████▊                     | 1732/3802 [00:25<00:29, 69.83it/s]

loss: -1.339517
loss: -1.157130


 46%|█████████████████▉                     | 1748/3802 [00:25<00:29, 70.28it/s]

loss: -1.291795
loss: -1.503025


 47%|██████████████████▏                    | 1772/3802 [00:25<00:28, 70.55it/s]

loss: -1.469760
loss: -1.351255


 47%|██████████████████▎                    | 1788/3802 [00:26<00:28, 70.35it/s]

loss: -1.494430
loss: -1.454517


 48%|██████████████████▌                    | 1812/3802 [00:26<00:28, 70.40it/s]

loss: -1.220127
loss: -1.416148


 48%|██████████████████▊                    | 1828/3802 [00:26<00:28, 70.35it/s]

loss: -1.525632
loss: -1.463606


 49%|██████████████████▉                    | 1852/3802 [00:27<00:27, 69.97it/s]

loss: -1.503128
loss: -1.483820


 49%|███████████████████▏                   | 1873/3802 [00:27<00:27, 69.52it/s]

loss: -1.563280
loss: -1.407609


 50%|███████████████████▍                   | 1889/3802 [00:27<00:27, 69.87it/s]

loss: -1.542128
loss: -1.365997


 50%|███████████████████▌                   | 1911/3802 [00:27<00:27, 69.30it/s]

loss: -1.404619
loss: -1.513947


 51%|███████████████████▊                   | 1933/3802 [00:28<00:26, 69.23it/s]

loss: -1.216488
loss: -1.377867


 51%|███████████████████▉                   | 1948/3802 [00:28<00:26, 69.62it/s]

loss: -1.380049
loss: -1.426147


 52%|████████████████████▏                  | 1970/3802 [00:28<00:26, 69.97it/s]

loss: -1.420817
loss: -1.415710


 52%|████████████████████▍                  | 1992/3802 [00:29<00:26, 69.61it/s]

loss: -1.274779
loss: -1.423167


 53%|████████████████████▋                  | 2014/3802 [00:29<00:25, 69.04it/s]

loss: -1.469211
loss: -1.308621


 53%|████████████████████▊                  | 2030/3802 [00:29<00:25, 69.68it/s]

loss: -1.473648
loss: -1.408223


 54%|█████████████████████                  | 2051/3802 [00:29<00:25, 69.10it/s]

loss: -1.172691
loss: -1.501368


 55%|█████████████████████▎                 | 2073/3802 [00:30<00:25, 68.75it/s]

loss: -1.371759
loss: -1.213959


 55%|█████████████████████▍                 | 2089/3802 [00:30<00:24, 69.78it/s]

loss: -1.414478
loss: -1.310123


 56%|█████████████████████▋                 | 2111/3802 [00:30<00:24, 69.84it/s]

loss: -1.326071
loss: -1.461606


 56%|█████████████████████▉                 | 2133/3802 [00:31<00:24, 68.98it/s]

loss: -1.478070
loss: -1.426969


 57%|██████████████████████                 | 2149/3802 [00:31<00:23, 69.81it/s]

loss: -1.461978
loss: -1.434147


 57%|██████████████████████▎                | 2172/3802 [00:31<00:23, 69.73it/s]

loss: -1.556980
loss: -1.295818


 58%|██████████████████████▍                | 2188/3802 [00:31<00:22, 70.24it/s]

loss: -1.375540
loss: -1.321253


 58%|██████████████████████▋                | 2212/3802 [00:32<00:22, 70.37it/s]

loss: -1.373929
loss: -1.532448


 59%|██████████████████████▊                | 2228/3802 [00:32<00:22, 70.21it/s]

loss: -1.204849
loss: -1.346397


 59%|███████████████████████                | 2252/3802 [00:32<00:22, 70.19it/s]

loss: -1.277015
loss: -1.431075


 60%|███████████████████████▎               | 2274/3802 [00:33<00:22, 68.00it/s]

loss: -1.369928
loss: -1.471885


 60%|███████████████████████▍               | 2288/3802 [00:33<00:21, 68.88it/s]

loss: -1.257602
loss: -1.411980


 61%|███████████████████████▋               | 2310/3802 [00:33<00:21, 70.01it/s]

loss: -1.430625
loss: -1.401792


 61%|███████████████████████▉               | 2333/3802 [00:33<00:20, 70.08it/s]

loss: -1.272278
loss: -1.328473


 62%|████████████████████████               | 2349/3802 [00:34<00:20, 70.41it/s]

loss: -1.405688
loss: -1.554468


 62%|████████████████████████▎              | 2373/3802 [00:34<00:20, 70.03it/s]

loss: -1.379054
loss: -1.538694


 63%|████████████████████████▌              | 2389/3802 [00:34<00:20, 70.46it/s]

loss: -1.324794
loss: -1.415174


 63%|████████████████████████▊              | 2413/3802 [00:35<00:19, 69.82it/s]

loss: -1.489518
loss: -1.371073


 64%|████████████████████████▉              | 2429/3802 [00:35<00:19, 70.09it/s]

loss: -1.348253
loss: -1.502922


 64%|█████████████████████████▏             | 2452/3802 [00:35<00:19, 70.01it/s]

loss: -1.340810
loss: -1.451247


 65%|█████████████████████████▎             | 2468/3802 [00:35<00:18, 70.47it/s]

loss: -1.487764
loss: -1.282921


 66%|█████████████████████████▌             | 2492/3802 [00:36<00:18, 69.85it/s]

loss: -1.433269
loss: -1.388826


 66%|█████████████████████████▋             | 2508/3802 [00:36<00:18, 70.17it/s]

loss: -1.345581
loss: -1.414495


 67%|█████████████████████████▉             | 2532/3802 [00:36<00:18, 70.30it/s]

loss: -1.466302
loss: -1.154019


 67%|██████████████████████████▏            | 2548/3802 [00:37<00:17, 70.37it/s]

loss: -1.303498
loss: -1.428768


 68%|██████████████████████████▎            | 2571/3802 [00:37<00:17, 70.03it/s]

loss: -1.401102
loss: -1.348457


 68%|██████████████████████████▌            | 2595/3802 [00:37<00:17, 70.28it/s]

loss: -1.402654
loss: -1.403637


 69%|██████████████████████████▊            | 2611/3802 [00:37<00:16, 70.40it/s]

loss: -1.260435
loss: -1.490254


 69%|███████████████████████████            | 2634/3802 [00:38<00:18, 62.27it/s]

loss: -1.286798
loss: -1.451900


 70%|███████████████████████████▏           | 2650/3802 [00:38<00:17, 66.34it/s]

loss: -1.209965
loss: -1.380406


 70%|███████████████████████████▍           | 2672/3802 [00:38<00:16, 68.51it/s]

loss: -1.455294
loss: -1.417257


 71%|███████████████████████████▋           | 2695/3802 [00:39<00:15, 69.83it/s]

loss: -1.271640
loss: -1.358944


 71%|███████████████████████████▊           | 2711/3802 [00:39<00:15, 70.26it/s]

loss: -1.242097
loss: -1.587803


 72%|███████████████████████████▉           | 2727/3802 [00:39<00:15, 70.51it/s]

loss: -1.330218
loss: -1.289332


 72%|████████████████████████████▏          | 2750/3802 [00:40<00:15, 68.87it/s]

loss: -1.288755
loss: -1.331039


 73%|████████████████████████████▍          | 2772/3802 [00:40<00:14, 69.49it/s]

loss: -1.392509
loss: -1.444806


 73%|████████████████████████████▌          | 2788/3802 [00:40<00:14, 70.30it/s]

loss: -1.430267
loss: -1.317489


 74%|████████████████████████████▊          | 2812/3802 [00:40<00:14, 70.23it/s]

loss: -1.425355
loss: -1.249767


 74%|█████████████████████████████          | 2828/3802 [00:41<00:13, 70.60it/s]

loss: -1.462294
loss: -1.574286


 75%|█████████████████████████████▎         | 2852/3802 [00:41<00:13, 70.44it/s]

loss: -1.433344
loss: -1.517176


 75%|█████████████████████████████▍         | 2868/3802 [00:41<00:13, 70.62it/s]

loss: -1.316634
loss: -1.397420


 76%|█████████████████████████████▋         | 2892/3802 [00:42<00:12, 70.16it/s]

loss: -1.314995
loss: -1.404154


 76%|█████████████████████████████▊         | 2908/3802 [00:42<00:12, 69.92it/s]

loss: -1.378261
loss: -1.578752


 77%|██████████████████████████████         | 2931/3802 [00:42<00:12, 70.28it/s]

loss: -1.453761
loss: -1.436568


 78%|██████████████████████████████▎        | 2955/3802 [00:42<00:12, 69.92it/s]

loss: -1.589622
loss: -1.292894


 78%|██████████████████████████████▍        | 2970/3802 [00:43<00:11, 69.87it/s]

loss: -1.335556
loss: -1.221991


 79%|██████████████████████████████▋        | 2992/3802 [00:43<00:11, 69.56it/s]

loss: -1.388036
loss: -1.425080


 79%|██████████████████████████████▉        | 3014/3802 [00:43<00:11, 69.76it/s]

loss: -1.368166
loss: -1.305811


 80%|███████████████████████████████        | 3029/3802 [00:44<00:11, 69.98it/s]

loss: -1.546268
loss: -1.358555


 80%|███████████████████████████████▎       | 3050/3802 [00:44<00:10, 68.95it/s]

loss: -1.095402
loss: -1.312176


 81%|███████████████████████████████▌       | 3072/3802 [00:44<00:10, 69.37it/s]

loss: -1.332928
loss: -1.362772


 81%|███████████████████████████████▋       | 3094/3802 [00:44<00:10, 69.75it/s]

loss: -1.381974
loss: -1.391397


 82%|███████████████████████████████▉       | 3110/3802 [00:45<00:09, 70.16it/s]

loss: -1.391315
loss: -1.431217


 82%|████████████████████████████████▏      | 3132/3802 [00:45<00:09, 69.63it/s]

loss: -1.496845
loss: -1.426082


 83%|████████████████████████████████▎      | 3148/3802 [00:45<00:09, 70.05it/s]

loss: -1.474412
loss: -1.290651


 83%|████████████████████████████████▌      | 3171/3802 [00:46<00:09, 69.58it/s]

loss: -1.541066
loss: -1.404800


 84%|████████████████████████████████▊      | 3193/3802 [00:46<00:08, 69.53it/s]

loss: -1.239700
loss: -1.190517


 84%|████████████████████████████████▉      | 3209/3802 [00:46<00:08, 70.25it/s]

loss: -1.393227
loss: -1.565876


 85%|█████████████████████████████████▏     | 3233/3802 [00:46<00:08, 70.07it/s]

loss: -1.406929
loss: -1.385128


 85%|█████████████████████████████████▎     | 3249/3802 [00:47<00:07, 70.56it/s]

loss: -1.429109
loss: -1.304243


 86%|█████████████████████████████████▌     | 3273/3802 [00:47<00:07, 69.32it/s]

loss: -1.470135
loss: -1.522667


 87%|█████████████████████████████████▊     | 3294/3802 [00:47<00:07, 67.17it/s]

loss: -1.307768
loss: -1.274128


 87%|█████████████████████████████████▉     | 3309/3802 [00:48<00:07, 67.65it/s]

loss: -1.416498
loss: -1.535036


 88%|██████████████████████████████████▏    | 3331/3802 [00:48<00:06, 68.39it/s]

loss: -1.486983
loss: -1.358352


 88%|██████████████████████████████████▍    | 3353/3802 [00:48<00:06, 69.45it/s]

loss: -1.267838
loss: -1.369902


 89%|██████████████████████████████████▌    | 3369/3802 [00:48<00:06, 70.35it/s]

loss: -1.397210
loss: -1.526139


 89%|██████████████████████████████████▊    | 3393/3802 [00:49<00:05, 70.07it/s]

loss: -1.520859
loss: -1.244836


 90%|██████████████████████████████████▉    | 3409/3802 [00:49<00:05, 70.33it/s]

loss: -1.449934
loss: -1.058563


 90%|███████████████████████████████████▏   | 3433/3802 [00:49<00:05, 69.55it/s]

loss: -1.423887
loss: -1.434176


 91%|███████████████████████████████████▎   | 3448/3802 [00:50<00:05, 69.82it/s]

loss: -1.559925
loss: -1.528120


 91%|███████████████████████████████████▌   | 3470/3802 [00:50<00:04, 70.23it/s]

loss: -1.445184
loss: -1.350678


 92%|███████████████████████████████████▊   | 3493/3802 [00:50<00:04, 69.18it/s]

loss: -1.506295
loss: -1.552165


 92%|███████████████████████████████████▉   | 3509/3802 [00:50<00:04, 70.12it/s]

loss: -1.289361
loss: -1.441376


 93%|████████████████████████████████████▏  | 3531/3802 [00:51<00:03, 69.70it/s]

loss: -1.357123
loss: -1.416317


 93%|████████████████████████████████████▍  | 3552/3802 [00:51<00:03, 69.48it/s]

loss: -1.394237
loss: -1.438125


 94%|████████████████████████████████████▋  | 3573/3802 [00:51<00:03, 69.43it/s]

loss: -1.305345
loss: -1.365613


 94%|████████████████████████████████████▊  | 3589/3802 [00:52<00:03, 70.16it/s]

loss: -1.428353
loss: -1.256166


 95%|█████████████████████████████████████  | 3612/3802 [00:52<00:02, 69.71it/s]

loss: -1.576139
loss: -1.460261


 95%|█████████████████████████████████████▏ | 3628/3802 [00:52<00:02, 70.12it/s]

loss: -1.385097
loss: -1.423247


 96%|█████████████████████████████████████▍ | 3651/3802 [00:52<00:02, 70.15it/s]

loss: -1.453527
loss: -1.388152


 96%|█████████████████████████████████████▌ | 3667/3802 [00:53<00:01, 70.17it/s]

loss: -1.362866
loss: -1.397547


 97%|█████████████████████████████████████▊ | 3689/3802 [00:53<00:01, 69.74it/s]

loss: -1.575396
loss: -1.380143


 98%|██████████████████████████████████████ | 3711/3802 [00:53<00:01, 69.64it/s]

loss: -1.569782
loss: -1.495477


 98%|██████████████████████████████████████▎| 3733/3802 [00:54<00:00, 69.68it/s]

loss: -1.456215
loss: -1.332924


 99%|██████████████████████████████████████▍| 3748/3802 [00:54<00:00, 68.59it/s]

loss: -1.322665
loss: -1.228875


 99%|██████████████████████████████████████▋| 3770/3802 [00:54<00:00, 69.74it/s]

loss: -1.464279
loss: -1.518256


100%|██████████████████████████████████████▉| 3792/3802 [00:54<00:00, 69.60it/s]

loss: -1.456773
loss: -1.559768


100%|███████████████████████████████████████| 3802/3802 [00:55<00:00, 68.95it/s]

loss: -1.409008





In [8]:
data_batch = next(iter(dataloader))

In [9]:
inpt = data_batch["features"]
tgt = data_batch["targets"]

In [10]:
d = model(inpt)

In [11]:
d.mu

tensor([0.7064, 2.2544, 2.2564, 2.2541, 2.2627, 2.2435, 2.2705, 0.5148, 2.2448,
        0.1772, 2.2395, 2.2401, 2.2464, 2.2442, 2.2435, 0.7989, 2.2387, 2.2446,
        2.2517, 2.2487, 0.9486, 0.7595, 2.2322, 2.2416, 2.2510, 2.2430, 0.6918,
        2.2359, 2.2512, 2.2553, 2.2435, 2.2503, 2.2428, 0.9796, 2.2392, 2.2483,
        0.5863, 2.2623, 2.2413, 0.6864, 2.2427, 0.8264, 0.7166, 0.4156, 2.2610,
        0.7752, 2.2420, 2.2393, 2.2407, 2.2468, 0.4041, 2.2654, 2.2444, 2.2573,
        2.2388, 2.2430, 2.2565, 2.2513, 0.5737, 2.2416, 2.2476, 0.8724, 2.2397,
        2.2695, 2.2613, 2.2442, 2.2552, 2.2473, 2.2567, 2.2641, 2.2338, 2.2373,
        2.2402, 2.2471, 2.2434, 0.7695, 2.2495, 2.2525, 0.8694, 2.2558, 2.2365,
        2.2533, 0.2950, 2.2398, 2.2404, 2.2522, 2.2449, 2.2491, 2.2422, 2.2593,
        2.2404, 2.2415, 2.2525, 0.7305, 2.2520, 0.6263, 2.2403, 0.6599, 2.2572,
        2.2516, 0.3333, 2.2397, 0.5167, 0.5013, 1.0876, 2.2510, 2.2388, 0.3587,
        2.2505, 0.7153, 2.2483, 2.2584, 

In [12]:
d.log_prob(tgt)

tensor([[-1.4748, -2.4889, -2.4778,  ..., -1.4939, -1.4503, -2.4499],
        [-1.4488, -2.5473, -2.5356,  ..., -1.4689, -1.4237, -2.5082],
        [-1.4333, -2.5792, -2.5671,  ..., -1.4540, -1.4079, -2.5400],
        ...,
        [-1.4611, -2.5204, -2.5090,  ..., -1.4808, -1.4363, -2.4813],
        [-1.4396, -2.5663, -2.5543,  ..., -1.4602, -1.4144, -2.5271],
        [-1.4492, -2.5463, -2.5346,  ..., -1.4694, -1.4242, -2.5072]],
       grad_fn=<IndexPutBackward0>)

In [3]:
# Training
epochs = 1
trainer.full_train(epochs)

########## Epoch 0


  0%|                                          | 3/3802 [00:00<02:10, 29.02it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  0%|▏                                        | 13/3802 [00:00<00:54, 69.39it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  1%|▎                                        | 24/3802 [00:00<00:43, 86.30it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


  1%|▍                                        | 35/3802 [00:00<00:39, 95.10it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  1%|▍                                       | 46/3802 [00:00<00:37, 100.12it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  2%|▌                                       | 58/3802 [00:00<00:35, 104.33it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  2%|▋                                       | 69/3802 [00:00<00:35, 104.11it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  2%|▊                                       | 80/3802 [00:00<00:35, 105.74it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


  2%|▉                                       | 92/3802 [00:00<00:34, 107.88it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  3%|█                                      | 103/3802 [00:01<00:34, 108.02it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  3%|█▏                                     | 114/3802 [00:01<00:34, 108.22it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


  3%|█▎                                     | 126/3802 [00:01<00:33, 109.31it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  4%|█▍                                     | 137/3802 [00:01<00:33, 109.34it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  4%|█▌                                     | 149/3802 [00:01<00:32, 110.91it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


  4%|█▋                                     | 161/3802 [00:01<00:32, 111.11it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  5%|█▊                                     | 173/3802 [00:01<00:32, 110.81it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  5%|█▉                                     | 185/3802 [00:01<00:32, 110.22it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


  5%|██                                     | 197/3802 [00:01<00:32, 111.11it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  5%|██▏                                    | 209/3802 [00:02<00:32, 110.45it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  6%|██▍                                    | 233/3802 [00:02<00:32, 109.86it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  6%|██▍                                    | 233/3802 [00:02<00:32, 109.86it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


  7%|██▋                                    | 256/3802 [00:02<00:32, 109.22it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  7%|██▋                                    | 256/3802 [00:02<00:32, 109.22it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  7%|██▊                                    | 279/3802 [00:02<00:32, 109.61it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  7%|██▊                                    | 279/3802 [00:02<00:32, 109.61it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  8%|███                                    | 301/3802 [00:02<00:32, 106.74it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  9%|███▎                                   | 325/3802 [00:03<00:31, 109.89it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  9%|███▎                                   | 325/3802 [00:03<00:31, 109.89it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


  9%|███▌                                   | 349/3802 [00:03<00:30, 114.50it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  9%|███▌                                   | 349/3802 [00:03<00:30, 114.50it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


  9%|███▋                                   | 361/3802 [00:03<00:30, 111.80it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 10%|███▊                                   | 373/3802 [00:03<00:31, 109.99it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 10%|███▉                                   | 386/3802 [00:03<00:30, 113.23it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 10%|████                                   | 399/3802 [00:03<00:29, 115.70it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 11%|████▏                                  | 412/3802 [00:03<00:28, 116.98it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 11%|████▎                                  | 424/3802 [00:03<00:28, 117.61it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 11%|████▍                                  | 436/3802 [00:04<00:29, 116.06it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 12%|████▌                                  | 448/3802 [00:04<00:29, 114.82it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 12%|████▋                                  | 460/3802 [00:04<00:28, 115.79it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 12%|████▊                                  | 472/3802 [00:04<00:29, 112.34it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 13%|████▉                                  | 484/3802 [00:04<00:30, 110.55it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 13%|████▉                                  | 484/3802 [00:04<00:30, 110.55it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 13%|█████▏                                 | 509/3802 [00:04<00:29, 113.23it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 14%|█████▍                                 | 534/3802 [00:04<00:28, 116.60it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 14%|█████▍                                 | 534/3802 [00:04<00:28, 116.60it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 15%|█████▋                                 | 560/3802 [00:05<00:27, 119.44it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 15%|██████                                 | 585/3802 [00:05<00:27, 117.96it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 15%|██████                                 | 585/3802 [00:05<00:27, 117.96it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 16%|██████▎                                | 610/3802 [00:05<00:26, 118.99it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 17%|██████▌                                | 634/3802 [00:05<00:27, 115.57it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 17%|██████▋                                | 646/3802 [00:05<00:27, 115.03it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 17%|██████▋                                | 658/3802 [00:05<00:27, 114.27it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 18%|██████▊                                | 670/3802 [00:06<00:27, 114.81it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 18%|██████▉                                | 682/3802 [00:06<00:27, 115.01it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 18%|███████                                | 694/3802 [00:06<00:27, 113.77it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 19%|███████▏                               | 706/3802 [00:06<00:26, 115.16it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 19%|███████▍                               | 719/3802 [00:06<00:26, 116.79it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 19%|███████▍                               | 731/3802 [00:06<00:26, 117.50it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 20%|███████▌                               | 743/3802 [00:06<00:25, 117.70it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 20%|███████▋                               | 755/3802 [00:06<00:25, 117.90it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 20%|███████▊                               | 767/3802 [00:06<00:26, 116.36it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 20%|███████▉                               | 779/3802 [00:06<00:26, 116.11it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 21%|████████                               | 791/3802 [00:07<00:25, 116.23it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 21%|████████▏                              | 803/3802 [00:07<00:25, 117.19it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 21%|████████▎                              | 815/3802 [00:07<00:25, 117.37it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 22%|████████▍                              | 827/3802 [00:07<00:25, 114.89it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 22%|████████▌                              | 839/3802 [00:07<00:26, 112.04it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 22%|████████▋                              | 851/3802 [00:07<00:26, 111.44it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 23%|████████▊                              | 863/3802 [00:07<00:25, 113.42it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 23%|████████▉                              | 876/3802 [00:07<00:25, 115.71it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 23%|█████████                              | 888/3802 [00:07<00:25, 116.33it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 24%|█████████▏                             | 900/3802 [00:08<00:25, 115.66it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 24%|█████████▎                             | 912/3802 [00:08<00:25, 113.20it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 24%|█████████▎                             | 912/3802 [00:08<00:25, 113.20it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 25%|█████████▌                             | 936/3802 [00:08<00:25, 113.55it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 25%|█████████▊                             | 960/3802 [00:08<00:24, 115.67it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 26%|██████████▎                             | 985/3802 [00:08<00:28, 98.55it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 27%|██████████                            | 1009/3802 [00:09<00:26, 104.31it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 27%|██████████▎                           | 1031/3802 [00:09<00:26, 105.26it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 27%|██████████▍                           | 1043/3802 [00:09<00:25, 106.96it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 28%|██████████▋                           | 1066/3802 [00:09<00:25, 108.17it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 29%|██████████▉                           | 1089/3802 [00:09<00:25, 108.37it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 29%|███████████                           | 1113/3802 [00:10<00:24, 109.97it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 30%|███████████▋                           | 1136/3802 [00:10<00:26, 99.54it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 30%|███████████▌                          | 1159/3802 [00:10<00:24, 106.71it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 31%|███████████▊                          | 1183/3802 [00:10<00:23, 112.00it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 32%|████████████                          | 1207/3802 [00:10<00:22, 112.91it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 32%|████████████▎                         | 1231/3802 [00:11<00:22, 115.56it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 33%|████████████▍                         | 1243/3802 [00:11<00:23, 109.57it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 33%|████████████▋                         | 1267/3802 [00:11<00:24, 103.75it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 34%|████████████▉                         | 1291/3802 [00:11<00:22, 110.75it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 35%|█████████████▏                        | 1316/3802 [00:11<00:21, 114.97it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 35%|█████████████▍                        | 1340/3802 [00:12<00:22, 110.45it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 36%|█████████████▋                        | 1364/3802 [00:12<00:21, 112.09it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 37%|█████████████▊                        | 1388/3802 [00:12<00:20, 115.01it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 37%|██████████████                        | 1412/3802 [00:12<00:20, 114.74it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 38%|██████████████▎                       | 1437/3802 [00:12<00:20, 116.84it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 38%|██████████████▌                       | 1462/3802 [00:13<00:21, 110.78it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 39%|██████████████▋                       | 1474/3802 [00:13<00:22, 105.36it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 39%|██████████████▉                       | 1497/3802 [00:13<00:21, 107.30it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 40%|███████████████▏                      | 1519/3802 [00:13<00:21, 108.24it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 41%|███████████████▍                      | 1543/3802 [00:14<00:20, 109.19it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 41%|███████████████▋                      | 1568/3802 [00:14<00:19, 114.91it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 42%|███████████████▉                      | 1594/3802 [00:14<00:18, 117.78it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 43%|████████████████▏                     | 1618/3802 [00:14<00:18, 116.07it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 43%|████████████████▍                     | 1642/3802 [00:14<00:18, 117.23it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 44%|████████████████▋                     | 1667/3802 [00:15<00:18, 118.46it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 44%|████████████████▉                     | 1691/3802 [00:15<00:18, 116.81it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 45%|█████████████████▏                    | 1715/3802 [00:15<00:17, 115.98it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 46%|█████████████████▍                    | 1739/3802 [00:15<00:17, 116.26it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 46%|█████████████████▌                    | 1763/3802 [00:15<00:17, 117.25it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 47%|█████████████████▊                    | 1788/3802 [00:16<00:16, 118.64it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 48%|██████████████████                    | 1813/3802 [00:16<00:16, 117.25it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 48%|██████████████████▎                   | 1838/3802 [00:16<00:16, 118.60it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 49%|██████████████████▌                   | 1862/3802 [00:16<00:16, 115.04it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 50%|██████████████████▊                   | 1886/3802 [00:16<00:16, 116.31it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 50%|███████████████████                   | 1910/3802 [00:17<00:16, 115.71it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 51%|███████████████████▏                  | 1922/3802 [00:17<00:16, 112.72it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 51%|███████████████████▍                  | 1946/3802 [00:17<00:16, 114.12it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 52%|███████████████████▊                  | 1982/3802 [00:17<00:15, 117.19it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 52%|███████████████████▉                  | 1994/3802 [00:17<00:15, 117.82it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 53%|████████████████████▏                 | 2018/3802 [00:18<00:15, 117.22it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 54%|████████████████████▍                 | 2042/3802 [00:18<00:14, 117.48it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 54%|████████████████████▋                 | 2066/3802 [00:18<00:14, 116.97it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 55%|████████████████████▉                 | 2091/3802 [00:18<00:14, 117.73it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 56%|█████████████████████▏                | 2116/3802 [00:18<00:14, 118.18it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 56%|█████████████████████▍                | 2140/3802 [00:19<00:14, 117.98it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 57%|█████████████████████▋                | 2164/3802 [00:19<00:14, 117.00it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 58%|█████████████████████▊                | 2188/3802 [00:19<00:13, 117.35it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 58%|██████████████████████                | 2212/3802 [00:19<00:13, 115.53it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 59%|██████████████████████▎               | 2236/3802 [00:19<00:13, 116.45it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 59%|██████████████████████▌               | 2260/3802 [00:20<00:13, 116.22it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 60%|██████████████████████▊               | 2284/3802 [00:20<00:13, 114.70it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 61%|███████████████████████               | 2308/3802 [00:20<00:12, 115.25it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 61%|███████████████████████▎              | 2332/3802 [00:20<00:12, 113.14it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 62%|███████████████████████▌              | 2356/3802 [00:20<00:12, 114.73it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 63%|███████████████████████▊              | 2380/3802 [00:21<00:12, 116.14it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 63%|████████████████████████              | 2404/3802 [00:21<00:12, 115.44it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 64%|████████████████████████▎             | 2428/3802 [00:21<00:11, 115.30it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 64%|████████████████████████▌             | 2452/3802 [00:21<00:11, 116.20it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 65%|████████████████████████▊             | 2477/3802 [00:21<00:11, 118.65it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 66%|████████████████████████▉             | 2501/3802 [00:22<00:10, 118.34it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 66%|█████████████████████████▏            | 2525/3802 [00:22<00:10, 118.51it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 67%|█████████████████████████▍            | 2549/3802 [00:22<00:10, 118.39it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 68%|█████████████████████████▋            | 2573/3802 [00:22<00:10, 115.89it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 68%|█████████████████████████▉            | 2597/3802 [00:23<00:10, 116.72it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 69%|██████████████████████████▏           | 2622/3802 [00:23<00:10, 117.36it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 70%|██████████████████████████▍           | 2646/3802 [00:23<00:09, 117.47it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 70%|██████████████████████████▋           | 2671/3802 [00:23<00:09, 118.74it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 71%|██████████████████████████▉           | 2696/3802 [00:23<00:09, 119.24it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 72%|███████████████████████████▏          | 2720/3802 [00:24<00:09, 117.50it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 72%|███████████████████████████▍          | 2745/3802 [00:24<00:08, 118.26it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 73%|███████████████████████████▋          | 2770/3802 [00:24<00:08, 119.23it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 73%|███████████████████████████▉          | 2794/3802 [00:24<00:08, 113.18it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 74%|████████████████████████████▏         | 2818/3802 [00:24<00:08, 112.56it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 75%|████████████████████████████▍         | 2843/3802 [00:25<00:08, 115.65it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 75%|████████████████████████████▌         | 2855/3802 [00:25<00:08, 113.34it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


                                                                                

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 76%|█████████████████████████████         | 2904/3802 [00:25<00:07, 116.65it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 77%|█████████████████████████████▎        | 2928/3802 [00:25<00:07, 112.57it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 78%|█████████████████████████████▌        | 2952/3802 [00:26<00:07, 113.56it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 78%|█████████████████████████████▋        | 2976/3802 [00:26<00:07, 115.66it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 79%|█████████████████████████████▉        | 3000/3802 [00:26<00:06, 114.91it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 80%|██████████████████████████████▏       | 3024/3802 [00:26<00:06, 112.42it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 80%|██████████████████████████████▍       | 3049/3802 [00:26<00:06, 116.05it/s]

loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 81%|██████████████████████████████▋       | 3073/3802 [00:27<00:06, 115.38it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 81%|██████████████████████████████▉       | 3097/3802 [00:27<00:06, 112.66it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan


 82%|███████████████████████████████▏      | 3121/3802 [00:27<00:05, 115.30it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 83%|███████████████████████████████▍      | 3145/3802 [00:27<00:05, 117.11it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 83%|███████████████████████████████▋      | 3169/3802 [00:27<00:05, 116.39it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 84%|███████████████████████████████▊      | 3181/3802 [00:28<00:05, 110.56it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 84%|████████████████████████████████▉      | 3205/3802 [00:28<00:06, 95.47it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 85%|████████████████████████████████▎     | 3229/3802 [00:28<00:05, 103.67it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 86%|████████████████████████████████▌     | 3252/3802 [00:28<00:05, 107.91it/s]

torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])


 86%|████████████████████████████████▌     | 3264/3802 [00:28<00:04, 108.96it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 86%|████████████████████████████████▊     | 3288/3802 [00:29<00:04, 110.94it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 87%|█████████████████████████████████     | 3312/3802 [00:29<00:04, 110.40it/s]

torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])


 88%|█████████████████████████████████▎    | 3336/3802 [00:29<00:04, 110.81it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


 88%|█████████████████████████████████▌    | 3359/3802 [00:29<00:03, 112.98it/s]


torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
loss:     nan
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])



KeyboardInterrupt

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x10579b650>>
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/EgammaDNN/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 

KeyboardInterrupt



In [None]:
def create_dir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)

outdir = "saved_models/v1"
create_dir(outdir)
fname = "v1"
trainer.save(os.path.join(outdir,fname))