In [1]:
# Code credits: Adapted bits and pieces from https://github.com/webdataset/webdataset/blob/master/docs/gettingstarted.ipynb

import sys
sys.path.append('..')

import gc
import json
import os
from itertools import islice
from datetime import datetime
import pytz
from pytz import timezone
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
import skimage.transform as st
import tqdm

import torch
import torch.optim as optim
from torchvision import transforms
import webdataset as wds

from model.selfattn_3d_cnn import *
from model.baseline_3d_cnn import *
from model.resattn_3d_cnn import *
from utils.model_utils import *
from utils.model_run import *

%load_ext autoreload
%autoreload 2

In [2]:
# Directory layout: the shard archives sit in a sub-directory of the data directory.
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards_new')

# Read the run settings from the JSON file one level above this notebook.
with open('../parameters.json') as json_file:
    parameters = json.loads(json_file.read())

# Pull out the two settings used when the data loaders are built.
batch_size, shard_size = (parameters[key] for key in ('batch_size', 'shard_size'))

# Bare expression: shows the loaded settings as the cell output.
parameters

{'batch_size': 4, 'shard_size': 16}

In [3]:
# Collect every shard archive in shards_dir.
# os.listdir() returns entries in arbitrary order, so the list is sorted: without
# this, which shard lands in train/val/test can differ between runs or machines,
# making results irreproducible and the splits inconsistent across sessions.
urls = sorted(
    os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar')
)

# Try to overfit on smaller data
# urls = urls[:round(len(urls)*0.3)]

# Another shard directory ('shards_new_cont') cannot simply be appended: its sample
# keys would collide with the keys in shards_dir (the original note attributes this
# to the shards having been "refreshed").
# shards_dir2 = os.path.join(data_dir, 'shards_new_cont')
# urls += [os.path.join(shards_dir2, it) for it in os.listdir(shards_dir2) if it.endswith('.tar')]


# Split boundaries are computed on the first 75% of the shards:
# 70% of that budget for training, the next 15% for validation.
total_num_shards = round(len(urls)*0.75)
train_end = round(total_num_shards*0.7)
val_end = round(total_num_shards*0.85)

train_urls = urls[:train_end]
val_urls = urls[train_end:val_end]
# NOTE(review): this slice is open-ended, so the test split receives everything after
# the validation shards, including the 25% of shards outside total_num_shards
# (35 test shards vs. 11 validation shards in the recorded run). If a 70/15/15 split
# of total_num_shards was intended, bound it with urls[val_end:total_num_shards].
test_urls = urls[val_end:]

# Smaller data just to run model once
# train_urls = urls[:2]
# val_urls = urls[2:3]
# test_urls = urls[3:]


print("Number of train shards:", len(train_urls))
print("Number of validation shards:", len(val_urls))
print("Number of test shards:", len(test_urls))

Number of train shards: 50
Number of validation shards: 11
Number of test shards: 35


In [4]:
# Create dataset objects
train_iternum = len(train_urls)*shard_size//batch_size
val_iternum = len(val_urls)*shard_size//batch_size
test_iternum = len(test_urls)*shard_size//batch_size

print("Number of iterations per train epoch:", train_iternum)

train_dataset = (
    wds
    .WebDataset(train_urls, length=train_iternum)
    .shuffle(shard_size)
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
    .batched(batch_size)
#     .map_tuple(pre_transforms, identity, identity)
)
loader_train = torch.utils.data.DataLoader(train_dataset, num_workers=0, batch_size=None) #setting batch_size = None disables batching

val_dataset = (
    wds
    .WebDataset(val_urls, length=val_iternum)
    .shuffle(shard_size)
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
    .batched(batch_size)
)
loader_val = torch.utils.data.DataLoader(val_dataset, num_workers=0, batch_size=None)

test_dataset = (
    wds
    .WebDataset(test_urls, length=test_iternum)
    .shuffle(shard_size)
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
    .batched(batch_size)
)
loader_test = torch.utils.data.DataLoader(test_dataset, num_workers=0, batch_size=None)

# for image, target in islice(dataset, 0, 2):
#     print(image.shape)

Number of iterations per train epoch: 200


In [5]:
# Force a full garbage-collection pass. The number shown as the cell output is
# gc.collect()'s return value (the count of unreachable objects it found).
gc.collect()

66

In [6]:
# Hardware switch and the floating-point dtype setting for this notebook.
USE_GPU = True
dtype = torch.float

# Use CUDA only when it is both requested and actually available; otherwise the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Report the selected device and dtype, one per line.
print(device, dtype, sep='\n')

cuda
torch.float32


In [410]:
# Make a timestamped log directory (US/Pacific wall-clock time) with a 'Checkpoints'
# sub-directory for this run.
# NOTE: the run is currently written under '../runs/baseline'; the
# '../runs/experiment' location is kept below, commented out.
dir_nm = datetime.now(tz=pytz.utc).astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
# dir_nm = "first_mini_c2fc2"
log_dir = os.path.join('../runs/baseline', dir_nm) # running from this notebook since the other one gives cuda memory errors
# log_dir = os.path.join('../runs/experiment', dir_nm)
# os.makedirs creates missing parent directories as well, so the cell also works when
# '../runs/baseline' does not exist yet. exist_ok is deliberately left at its default:
# an already existing run directory still raises FileExistsError instead of being reused.
os.makedirs(log_dir)
os.mkdir(os.path.join(log_dir, 'Checkpoints'))


# Model, optimizer, criterion
model = baseline_3DCNN(in_num_ch=1)
# model = selfattn_3DCNN(in_num_ch=1)
optimizer = optim.Adam(model.parameters(), lr = 1e-4)
# BCEWithLogitsLoss applies the sigmoid itself, so it expects raw logits from the model.
criterion = torch.nn.BCEWithLogitsLoss()

In [7]:
# Run the garbage collector again; the cell output is the value returned by
# gc.collect() (the count of unreachable objects it found).
gc.collect()

44

In [412]:
# Baseline model: run training with the objects built in the cells above.
# The call returns two values, stored as the training and validation loss dictionaries.
train_loss_dict, val_loss_dict = train(
    model, optimizer, criterion,
    loader_train, loader_val, log_dir,
    device=device,
    epochs=10,
    val_every=5,
)

Epoch 1:   2%|▎         | 5/200 [02:46<2:56:10, 54.21s/batch, loss=0.704]

Total iteration 5, validation loss = 0.7028



Epoch 1:   5%|▌         | 10/200 [05:19<2:47:13, 52.81s/batch, loss=0.708]

Total iteration 10, validation loss = 0.7026



Epoch 1:   8%|▊         | 15/200 [07:52<2:41:49, 52.49s/batch, loss=0.628]

Total iteration 15, validation loss = 0.7018



Epoch 1:  10%|█         | 20/200 [10:24<2:36:39, 52.22s/batch, loss=0.66] 

Total iteration 20, validation loss = 0.6990



Epoch 1:  12%|█▎        | 25/200 [12:55<2:31:09, 51.82s/batch, loss=0.658]

Total iteration 25, validation loss = 0.6925



Epoch 1:  15%|█▌        | 30/200 [15:26<2:26:21, 51.66s/batch, loss=0.624]

Total iteration 30, validation loss = 0.6815



Epoch 1:  18%|█▊        | 35/200 [17:50<2:16:26, 49.62s/batch, loss=0.611]

Total iteration 35, validation loss = 0.6680



Epoch 1:  20%|██        | 40/200 [20:15<2:12:07, 49.54s/batch, loss=0.588]

Total iteration 40, validation loss = 0.6541



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 1:  22%|██▎       | 45/200 [22:44<2:11:47, 51.01s/batch, loss=0.622]

Total iteration 45, validation loss = 0.6430



Epoch 1:  25%|██▌       | 50/200 [25:12<2:06:51, 50.75s/batch, loss=0.578]

Total iteration 50, validation loss = 0.6333



Epoch 1:  28%|██▊       | 55/200 [27:41<2:02:40, 50.76s/batch, loss=0.714]

Total iteration 55, validation loss = 0.6286



Epoch 1:  30%|███       | 60/200 [30:08<1:58:15, 50.68s/batch, loss=0.602]

Total iteration 60, validation loss = 0.6248



Epoch 1:  32%|███▎      | 65/200 [32:37<1:54:41, 50.97s/batch, loss=0.684]

Total iteration 65, validation loss = 0.6216



Epoch 1:  35%|███▌      | 70/200 [35:06<1:50:05, 50.81s/batch, loss=0.657]

Total iteration 70, validation loss = 0.6189



Epoch 1:  38%|███▊      | 75/200 [37:32<1:44:36, 50.21s/batch, loss=0.596]

Total iteration 75, validation loss = 0.6123



Epoch 1:  40%|████      | 80/200 [39:58<1:39:56, 49.97s/batch, loss=0.658]

Total iteration 80, validation loss = 0.6099



Epoch 1:  42%|████▎     | 85/200 [42:18<1:32:35, 48.31s/batch, loss=0.775]

Total iteration 85, validation loss = 0.6074



Epoch 1:  45%|████▌     | 90/200 [44:40<1:28:28, 48.26s/batch, loss=0.684]

Total iteration 90, validation loss = 0.6073



Epoch 1:  48%|████▊     | 95/200 [47:01<1:24:21, 48.20s/batch, loss=0.675]

Total iteration 95, validation loss = 0.6061



Epoch 1:  50%|█████     | 100/200 [49:22<1:20:21, 48.21s/batch, loss=0.535]

Total iteration 100, validation loss = 0.6014



Epoch 1:  52%|█████▎    | 105/200 [51:43<1:16:20, 48.22s/batch, loss=0.768]

Total iteration 105, validation loss = 0.6005



Epoch 1:  55%|█████▌    | 110/200 [54:09<1:14:39, 49.77s/batch, loss=0.635]

Total iteration 110, validation loss = 0.5972



Epoch 1:  57%|█████▊    | 115/200 [56:29<1:07:56, 47.95s/batch, loss=0.676]

Total iteration 115, validation loss = 0.5941



Epoch 1:  60%|██████    | 120/200 [58:47<1:03:21, 47.52s/batch, loss=0.668]

Total iteration 120, validation loss = 0.5932



Epoch 1:  62%|██████▎   | 125/200 [1:01:07<59:31, 47.62s/batch, loss=0.716]

Total iteration 125, validation loss = 0.5939



Epoch 1:  65%|██████▌   | 130/200 [1:03:37<59:05, 50.64s/batch, loss=0.71] 

Total iteration 130, validation loss = 0.5937



Epoch 1:  68%|██████▊   | 135/200 [1:06:00<53:30, 49.40s/batch, loss=0.609]

Total iteration 135, validation loss = 0.5932



Epoch 1:  70%|███████   | 140/200 [1:08:27<50:03, 50.06s/batch, loss=0.661]

Total iteration 140, validation loss = 0.5936



Epoch 1:  72%|███████▎  | 145/200 [1:10:53<45:56, 50.12s/batch, loss=0.564]

Total iteration 145, validation loss = 0.5919



Epoch 1:  75%|███████▌  | 150/200 [1:13:22<42:13, 50.66s/batch, loss=0.662]

Total iteration 150, validation loss = 0.5915



Epoch 1:  78%|███████▊  | 155/200 [1:15:45<37:01, 49.36s/batch, loss=0.657]

Total iteration 155, validation loss = 0.5982



Epoch 1:  80%|████████  | 160/200 [1:18:13<33:36, 50.40s/batch, loss=0.576]

Total iteration 160, validation loss = 0.5972



Epoch 1:  82%|████████▎ | 165/200 [1:20:35<28:27, 48.80s/batch, loss=0.654]

Total iteration 165, validation loss = 0.5994



Epoch 1:  85%|████████▌ | 170/200 [1:22:57<24:18, 48.60s/batch, loss=0.529]

Total iteration 170, validation loss = 0.6014



Epoch 1:  88%|████████▊ | 175/200 [1:25:21<20:32, 49.29s/batch, loss=0.714]

Total iteration 175, validation loss = 0.5985



Epoch 1:  90%|█████████ | 180/200 [1:27:49<16:47, 50.35s/batch, loss=0.581]

Total iteration 180, validation loss = 0.5981



Epoch 1:  92%|█████████▎| 185/200 [1:30:15<12:30, 50.04s/batch, loss=0.547]

Total iteration 185, validation loss = 0.5991



Epoch 1:  95%|█████████▌| 190/200 [1:32:42<08:22, 50.22s/batch, loss=0.693]

Total iteration 190, validation loss = 0.5969



Epoch 1:  98%|█████████▊| 195/200 [1:35:11<04:14, 50.90s/batch, loss=0.649]

Total iteration 195, validation loss = 0.5989



Epoch 1: 100%|██████████| 200/200 [1:37:21<00:00, 29.21s/batch, loss=0.508]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 200, validation loss = 0.5942



Epoch 2:   2%|▎         | 5/200 [02:31<2:38:39, 48.82s/batch, loss=0.568]

Total iteration 206, validation loss = 0.5893



Epoch 2:   5%|▌         | 10/200 [04:50<2:30:30, 47.53s/batch, loss=0.635]

Total iteration 211, validation loss = 0.5868



Epoch 2:   8%|▊         | 15/200 [07:11<2:28:13, 48.07s/batch, loss=0.478]

Total iteration 216, validation loss = 0.5870



Epoch 2:  10%|█         | 20/200 [09:40<2:31:45, 50.58s/batch, loss=0.669]

Total iteration 221, validation loss = 0.5846



Epoch 2:  12%|█▎        | 25/200 [12:00<2:21:06, 48.38s/batch, loss=0.608]

Total iteration 226, validation loss = 0.5838



Epoch 2:  15%|█▌        | 30/200 [14:26<2:20:17, 49.51s/batch, loss=0.652]

Total iteration 231, validation loss = 0.5855



Epoch 2:  18%|█▊        | 35/200 [16:52<2:17:02, 49.83s/batch, loss=0.678]

Total iteration 236, validation loss = 0.5909



Epoch 2:  20%|██        | 40/200 [19:18<2:13:40, 50.13s/batch, loss=0.613]

Total iteration 241, validation loss = 0.5928



Epoch 2:  22%|██▎       | 45/200 [21:46<2:10:44, 50.61s/batch, loss=0.669]

Total iteration 246, validation loss = 0.5941



Epoch 2:  25%|██▌       | 50/200 [24:15<2:06:51, 50.74s/batch, loss=0.576]

Total iteration 251, validation loss = 0.5954



Epoch 2:  28%|██▊       | 55/200 [26:43<2:02:56, 50.87s/batch, loss=0.466]

Total iteration 256, validation loss = 0.5944



Epoch 2:  30%|███       | 60/200 [29:11<1:58:27, 50.77s/batch, loss=0.596]

Total iteration 261, validation loss = 0.5898



Epoch 2:  32%|███▎      | 65/200 [31:40<1:54:22, 50.83s/batch, loss=0.635]

Total iteration 266, validation loss = 0.5888



Epoch 2:  35%|███▌      | 70/200 [34:09<1:50:38, 51.07s/batch, loss=0.594]

Total iteration 271, validation loss = 0.5917



Epoch 2:  38%|███▊      | 75/200 [36:38<1:46:36, 51.17s/batch, loss=0.605]

Total iteration 276, validation loss = 0.5915



Epoch 2:  40%|████      | 80/200 [39:07<1:41:47, 50.90s/batch, loss=0.538]

Total iteration 281, validation loss = 0.5900



Epoch 2:  42%|████▎     | 85/200 [41:32<1:35:49, 49.99s/batch, loss=0.599]

Total iteration 286, validation loss = 0.5897



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 2:  45%|████▌     | 90/200 [43:56<1:30:05, 49.14s/batch, loss=0.492]

Total iteration 291, validation loss = 0.5910



Epoch 2:  48%|████▊     | 95/200 [46:25<1:28:47, 50.74s/batch, loss=0.667]

Total iteration 296, validation loss = 0.5882



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 2:  50%|█████     | 100/200 [48:46<1:21:18, 48.78s/batch, loss=0.59]

Total iteration 301, validation loss = 0.5858



Epoch 2:  52%|█████▎    | 105/200 [51:08<1:17:02, 48.66s/batch, loss=0.58] 

Total iteration 306, validation loss = 0.5811



Epoch 2:  55%|█████▌    | 110/200 [53:30<1:12:29, 48.33s/batch, loss=0.639]

Total iteration 311, validation loss = 0.5811



Epoch 2:  57%|█████▊    | 115/200 [55:49<1:07:26, 47.61s/batch, loss=0.65] 

Total iteration 316, validation loss = 0.5837



Epoch 2:  60%|██████    | 120/200 [58:14<1:05:49, 49.37s/batch, loss=0.566]

Total iteration 321, validation loss = 0.5835



Epoch 2:  62%|██████▎   | 125/200 [1:00:41<1:02:45, 50.21s/batch, loss=0.703]

Total iteration 326, validation loss = 0.5893



Epoch 2:  65%|██████▌   | 130/200 [1:03:04<57:07, 48.96s/batch, loss=0.628]  

Total iteration 331, validation loss = 0.5885



Epoch 2:  68%|██████▊   | 135/200 [1:05:22<51:31, 47.56s/batch, loss=0.62] 

Total iteration 336, validation loss = 0.5866



Epoch 2:  70%|███████   | 140/200 [1:07:41<47:25, 47.43s/batch, loss=0.526]

Total iteration 341, validation loss = 0.5829



Epoch 2:  72%|███████▎  | 145/200 [1:10:00<43:25, 47.38s/batch, loss=0.661]

Total iteration 346, validation loss = 0.5794



Epoch 2:  75%|███████▌  | 150/200 [1:12:20<39:41, 47.64s/batch, loss=0.516]

Total iteration 351, validation loss = 0.5788



Epoch 2:  78%|███████▊  | 155/200 [1:14:39<35:41, 47.60s/batch, loss=0.64] 

Total iteration 356, validation loss = 0.5812



Epoch 2:  80%|████████  | 160/200 [1:17:01<32:16, 48.40s/batch, loss=0.68] 

Total iteration 361, validation loss = 0.5842



Epoch 2:  82%|████████▎ | 165/200 [1:19:26<28:51, 49.47s/batch, loss=0.685]

Total iteration 366, validation loss = 0.5869



Epoch 2:  85%|████████▌ | 170/200 [1:21:53<24:57, 49.93s/batch, loss=0.572]

Total iteration 371, validation loss = 0.5906



Epoch 2:  88%|████████▊ | 175/200 [1:24:19<20:53, 50.14s/batch, loss=0.558]

Total iteration 376, validation loss = 0.5922



Epoch 2:  90%|█████████ | 180/200 [1:26:39<16:04, 48.24s/batch, loss=0.583]

Total iteration 381, validation loss = 0.5905



Epoch 2:  92%|█████████▎| 185/200 [1:28:59<11:57, 47.85s/batch, loss=0.592]

Total iteration 386, validation loss = 0.5844



Epoch 2:  95%|█████████▌| 190/200 [1:31:22<08:07, 48.75s/batch, loss=0.558]

Total iteration 391, validation loss = 0.5820



Epoch 2:  98%|█████████▊| 195/200 [1:33:44<04:01, 48.30s/batch, loss=0.521]

Total iteration 396, validation loss = 0.5829



Epoch 2: 100%|██████████| 200/200 [1:35:54<00:00, 28.77s/batch, loss=0.679]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 401, validation loss = 0.5806



Epoch 3:   2%|▎         | 5/200 [02:33<2:40:04, 49.26s/batch, loss=0.583]

Total iteration 407, validation loss = 0.5797



Epoch 3:   5%|▌         | 10/200 [04:53<2:32:28, 48.15s/batch, loss=0.64]

Total iteration 412, validation loss = 0.5802



Epoch 3:   8%|▊         | 15/200 [07:12<2:26:48, 47.62s/batch, loss=0.585]

Total iteration 417, validation loss = 0.5826



Epoch 3:  10%|█         | 20/200 [09:32<2:22:44, 47.58s/batch, loss=0.662]

Total iteration 422, validation loss = 0.5880



Epoch 3:  12%|█▎        | 25/200 [11:51<2:18:22, 47.44s/batch, loss=0.586]

Total iteration 427, validation loss = 0.5918



Epoch 3:  15%|█▌        | 30/200 [14:12<2:16:26, 48.15s/batch, loss=0.561]

Total iteration 432, validation loss = 0.5888



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 3:  18%|█▊        | 35/200 [16:36<2:15:06, 49.13s/batch, loss=0.517]

Total iteration 437, validation loss = 0.5897



Epoch 3:  20%|██        | 40/200 [19:02<2:12:29, 49.69s/batch, loss=0.625]

Total iteration 442, validation loss = 0.5915



Epoch 3:  22%|██▎       | 45/200 [21:25<2:06:29, 48.97s/batch, loss=0.713]

Total iteration 447, validation loss = 0.5873



Epoch 3:  25%|██▌       | 50/200 [23:44<1:59:21, 47.74s/batch, loss=0.65] 

Total iteration 452, validation loss = 0.5862



Epoch 3:  28%|██▊       | 55/200 [26:12<2:01:40, 50.35s/batch, loss=0.599]

Total iteration 457, validation loss = 0.5859



Epoch 3:  30%|███       | 60/200 [28:37<1:56:03, 49.74s/batch, loss=0.603]

Total iteration 462, validation loss = 0.5814



Epoch 3:  32%|███▎      | 65/200 [30:59<1:49:20, 48.59s/batch, loss=0.581]

Total iteration 467, validation loss = 0.5807



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 3:  35%|███▌      | 70/200 [33:21<1:45:13, 48.57s/batch, loss=0.524]

Total iteration 472, validation loss = 0.5776



Epoch 3:  38%|███▊      | 75/200 [35:44<1:41:37, 48.78s/batch, loss=0.529]

Total iteration 477, validation loss = 0.5773



Epoch 3:  40%|████      | 80/200 [38:11<1:39:56, 49.97s/batch, loss=0.525]

Total iteration 482, validation loss = 0.5777



Epoch 3:  42%|████▎     | 85/200 [40:40<1:37:32, 50.89s/batch, loss=0.583]

Total iteration 487, validation loss = 0.5787



Epoch 3:  45%|████▌     | 90/200 [43:01<1:29:27, 48.80s/batch, loss=0.618]

Total iteration 492, validation loss = 0.5799



Epoch 3:  48%|████▊     | 95/200 [45:26<1:26:21, 49.34s/batch, loss=0.654]

Total iteration 497, validation loss = 0.5783



Epoch 3:  50%|█████     | 100/200 [47:47<1:20:37, 48.38s/batch, loss=0.55]

Total iteration 502, validation loss = 0.5787



Epoch 3:  52%|█████▎    | 105/200 [50:06<1:15:31, 47.70s/batch, loss=0.515]

Total iteration 507, validation loss = 0.5782



Epoch 3:  55%|█████▌    | 110/200 [52:26<1:11:29, 47.67s/batch, loss=0.525]

Total iteration 512, validation loss = 0.5743



Epoch 3:  57%|█████▊    | 115/200 [54:46<1:07:49, 47.88s/batch, loss=0.503]

Total iteration 517, validation loss = 0.5745



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 3:  60%|██████    | 120/200 [57:06<1:03:51, 47.90s/batch, loss=0.49]

Total iteration 522, validation loss = 0.5731



Epoch 3:  62%|██████▎   | 125/200 [59:27<1:00:13, 48.18s/batch, loss=0.564]

Total iteration 527, validation loss = 0.5716



Epoch 3:  65%|██████▌   | 130/200 [1:01:47<55:39, 47.71s/batch, loss=0.571]

Total iteration 532, validation loss = 0.5717



Epoch 3:  68%|██████▊   | 135/200 [1:04:07<51:52, 47.88s/batch, loss=0.505]

Total iteration 537, validation loss = 0.5708



Epoch 3:  70%|███████   | 140/200 [1:06:29<48:33, 48.55s/batch, loss=0.542]

Total iteration 542, validation loss = 0.5724



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 3:  72%|███████▎  | 145/200 [1:08:50<44:07, 48.15s/batch, loss=0.592]

Total iteration 547, validation loss = 0.5719



Epoch 3:  75%|███████▌  | 150/200 [1:11:11<40:01, 48.03s/batch, loss=0.482]

Total iteration 552, validation loss = 0.5697



Epoch 3:  78%|███████▊  | 155/200 [1:13:33<36:24, 48.54s/batch, loss=0.658]

Total iteration 557, validation loss = 0.5697



Epoch 3:  80%|████████  | 160/200 [1:15:55<32:19, 48.49s/batch, loss=0.55] 

Total iteration 562, validation loss = 0.5707



Epoch 3:  82%|████████▎ | 165/200 [1:18:19<28:38, 49.10s/batch, loss=0.599]

Total iteration 567, validation loss = 0.5683



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 3:  85%|████████▌ | 170/200 [1:20:38<23:56, 47.88s/batch, loss=0.488]

Total iteration 572, validation loss = 0.5721



Epoch 3:  88%|████████▊ | 175/200 [1:23:00<20:10, 48.44s/batch, loss=0.585]

Total iteration 577, validation loss = 0.5712



Epoch 3:  90%|█████████ | 180/200 [1:25:25<16:28, 49.42s/batch, loss=0.637]

Total iteration 582, validation loss = 0.5703



Epoch 3:  92%|█████████▎| 185/200 [1:27:52<12:30, 50.01s/batch, loss=0.656]

Total iteration 587, validation loss = 0.5697



Epoch 3:  95%|█████████▌| 190/200 [1:30:11<07:59, 47.94s/batch, loss=0.63] 

Total iteration 592, validation loss = 0.5684



Epoch 3:  98%|█████████▊| 195/200 [1:32:29<03:56, 47.34s/batch, loss=0.596]

Total iteration 597, validation loss = 0.5703



Epoch 3: 100%|██████████| 200/200 [1:34:34<00:00, 28.37s/batch, loss=0.558]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 602, validation loss = 0.5748



Epoch 4:   2%|▎         | 5/200 [02:30<2:36:52, 48.27s/batch, loss=0.708]

Total iteration 608, validation loss = 0.5749



Epoch 4:   5%|▌         | 10/200 [04:57<2:38:44, 50.13s/batch, loss=0.601]

Total iteration 613, validation loss = 0.5796



Epoch 4:   8%|▊         | 15/200 [07:17<2:28:58, 48.32s/batch, loss=0.47] 

Total iteration 618, validation loss = 0.5811



Epoch 4:  10%|█         | 20/200 [09:34<2:20:24, 46.80s/batch, loss=0.535]

Total iteration 623, validation loss = 0.5830



Epoch 4:  12%|█▎        | 25/200 [11:50<2:15:20, 46.40s/batch, loss=0.543]

Total iteration 628, validation loss = 0.5837



Epoch 4:  15%|█▌        | 30/200 [14:05<2:11:03, 46.26s/batch, loss=0.5]  

Total iteration 633, validation loss = 0.5824



Epoch 4:  18%|█▊        | 35/200 [16:29<2:13:46, 48.65s/batch, loss=0.575]

Total iteration 638, validation loss = 0.5802



Epoch 4:  20%|██        | 40/200 [18:55<2:12:57, 49.86s/batch, loss=0.55] 

Total iteration 643, validation loss = 0.5795



Epoch 4:  22%|██▎       | 45/200 [21:14<2:03:19, 47.74s/batch, loss=0.538]

Total iteration 648, validation loss = 0.5797



Epoch 4:  25%|██▌       | 50/200 [23:37<2:02:24, 48.96s/batch, loss=0.585]

Total iteration 653, validation loss = 0.5803



Epoch 4:  28%|██▊       | 55/200 [26:00<1:57:34, 48.65s/batch, loss=0.614]

Total iteration 658, validation loss = 0.5800



Epoch 4:  30%|███       | 60/200 [28:22<1:53:48, 48.78s/batch, loss=0.564]

Total iteration 663, validation loss = 0.5776



Epoch 4:  32%|███▎      | 65/200 [30:42<1:47:39, 47.85s/batch, loss=0.584]

Total iteration 668, validation loss = 0.5725



Epoch 4:  35%|███▌      | 70/200 [33:07<1:46:47, 49.29s/batch, loss=0.52] 

Total iteration 673, validation loss = 0.5684



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 4:  38%|███▊      | 75/200 [35:22<1:37:20, 46.73s/batch, loss=0.589]

Total iteration 678, validation loss = 0.5648



Epoch 4:  40%|████      | 80/200 [37:47<1:38:21, 49.18s/batch, loss=0.546]

Total iteration 683, validation loss = 0.5648



Epoch 4:  42%|████▎     | 85/200 [40:05<1:30:41, 47.32s/batch, loss=0.728]

Total iteration 688, validation loss = 0.5625



Epoch 4:  45%|████▌     | 90/200 [42:23<1:26:19, 47.08s/batch, loss=0.513]

Total iteration 693, validation loss = 0.5638



Epoch 4:  48%|████▊     | 95/200 [44:42<1:22:46, 47.30s/batch, loss=0.707]

Total iteration 698, validation loss = 0.5646



Epoch 4:  50%|█████     | 100/200 [47:00<1:18:49, 47.29s/batch, loss=0.642]

Total iteration 703, validation loss = 0.5644



Epoch 4:  52%|█████▎    | 105/200 [49:19<1:15:12, 47.50s/batch, loss=0.634]

Total iteration 708, validation loss = 0.5633



Epoch 4:  55%|█████▍    | 109/200 [51:38<21:03, 13.89s/batch, loss=0.56]   

Total iteration 713, validation loss = 0.5612



Epoch 4:  57%|█████▊    | 115/200 [53:57<1:07:11, 47.43s/batch, loss=0.699]

Total iteration 718, validation loss = 0.5605



Epoch 4:  60%|██████    | 120/200 [56:16<1:03:08, 47.36s/batch, loss=0.573]

Total iteration 723, validation loss = 0.5656



Epoch 4:  62%|██████▎   | 125/200 [58:39<1:00:59, 48.79s/batch, loss=0.563]

Total iteration 728, validation loss = 0.5681



Epoch 4:  65%|██████▌   | 130/200 [1:01:00<56:27, 48.39s/batch, loss=0.626]

Total iteration 733, validation loss = 0.5669



Epoch 4:  68%|██████▊   | 135/200 [1:03:21<52:05, 48.08s/batch, loss=0.462]

Total iteration 738, validation loss = 0.5677



Epoch 4:  70%|███████   | 140/200 [1:05:42<48:04, 48.07s/batch, loss=0.509]

Total iteration 743, validation loss = 0.5655



Epoch 4:  72%|███████▎  | 145/200 [1:08:04<44:25, 48.46s/batch, loss=0.547]

Total iteration 748, validation loss = 0.5645



Epoch 4:  75%|███████▌  | 150/200 [1:10:31<41:45, 50.10s/batch, loss=0.554]

Total iteration 753, validation loss = 0.5657



Epoch 4:  78%|███████▊  | 155/200 [1:12:50<35:52, 47.83s/batch, loss=0.648]

Total iteration 758, validation loss = 0.5656



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 4:  80%|████████  | 160/200 [1:15:11<32:11, 48.28s/batch, loss=0.662]

Total iteration 763, validation loss = 0.5641



Epoch 4:  82%|████████▎ | 165/200 [1:17:28<27:22, 46.94s/batch, loss=0.551]

Total iteration 768, validation loss = 0.5605



Epoch 4:  85%|████████▌ | 170/200 [1:19:45<23:25, 46.84s/batch, loss=0.546]

Total iteration 773, validation loss = 0.5594



Epoch 4:  88%|████████▊ | 175/200 [1:22:02<19:26, 46.65s/batch, loss=0.551]

Total iteration 778, validation loss = 0.5584



Epoch 4:  90%|█████████ | 180/200 [1:24:20<15:41, 47.09s/batch, loss=0.436]

Total iteration 783, validation loss = 0.5573



Epoch 4:  92%|█████████▎| 185/200 [1:26:38<11:44, 46.97s/batch, loss=0.428]

Total iteration 788, validation loss = 0.5563



Epoch 4:  95%|█████████▌| 190/200 [1:28:57<07:52, 47.25s/batch, loss=0.489]

Total iteration 793, validation loss = 0.5544



Epoch 4:  98%|█████████▊| 195/200 [1:31:15<03:55, 47.19s/batch, loss=0.598]

Total iteration 798, validation loss = 0.5542



Epoch 4: 100%|██████████| 200/200 [1:33:18<00:00, 27.99s/batch, loss=0.541]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 803, validation loss = 0.5533



Epoch 5:   2%|▎         | 5/200 [02:29<2:36:00, 48.00s/batch, loss=0.561]

Total iteration 809, validation loss = 0.5536



Epoch 5:   5%|▌         | 10/200 [04:56<2:37:50, 49.84s/batch, loss=0.388]

Total iteration 814, validation loss = 0.5565



Epoch 5:   8%|▊         | 15/200 [07:18<2:30:17, 48.75s/batch, loss=0.521]

Total iteration 819, validation loss = 0.5567



Epoch 5:  10%|█         | 20/200 [09:39<2:25:31, 48.51s/batch, loss=0.497]

Total iteration 824, validation loss = 0.5597



Epoch 5:  12%|█▎        | 25/200 [12:02<2:22:08, 48.74s/batch, loss=0.584]

Total iteration 829, validation loss = 0.5615



Epoch 5:  15%|█▌        | 30/200 [14:28<2:20:30, 49.59s/batch, loss=0.654]

Total iteration 834, validation loss = 0.5634



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 5:  18%|█▊        | 35/200 [16:51<2:15:32, 49.29s/batch, loss=0.533]

Total iteration 839, validation loss = 0.5658



Epoch 5:  20%|██        | 40/200 [19:14<2:10:15, 48.85s/batch, loss=0.616]

Total iteration 844, validation loss = 0.5676



Epoch 5:  22%|██▎       | 45/200 [21:37<2:06:10, 48.84s/batch, loss=0.601]

Total iteration 849, validation loss = 0.5658



Epoch 5:  25%|██▌       | 50/200 [23:59<2:01:14, 48.50s/batch, loss=0.673]

Total iteration 854, validation loss = 0.5655



Epoch 5:  28%|██▊       | 55/200 [26:22<1:58:37, 49.09s/batch, loss=0.649]

Total iteration 859, validation loss = 0.5629



Epoch 5:  30%|███       | 60/200 [28:50<1:57:03, 50.17s/batch, loss=0.644]

Total iteration 864, validation loss = 0.5613



Epoch 5:  32%|███▎      | 65/200 [31:10<1:48:53, 48.40s/batch, loss=0.546]

Total iteration 869, validation loss = 0.5627



Epoch 5:  35%|███▌      | 70/200 [33:30<1:43:36, 47.82s/batch, loss=0.556]

Total iteration 874, validation loss = 0.5633



Epoch 5:  38%|███▊      | 75/200 [35:50<1:39:52, 47.94s/batch, loss=0.547]

Total iteration 879, validation loss = 0.5621



Epoch 5:  40%|████      | 80/200 [38:12<1:36:38, 48.32s/batch, loss=0.627]

Total iteration 884, validation loss = 0.5616



Epoch 5:  42%|████▎     | 85/200 [40:33<1:32:18, 48.16s/batch, loss=0.485]

Total iteration 889, validation loss = 0.5592



Epoch 5:  45%|████▌     | 90/200 [42:58<1:30:43, 49.49s/batch, loss=0.56] 

Total iteration 894, validation loss = 0.5579



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 5:  48%|████▊     | 95/200 [45:20<1:24:59, 48.57s/batch, loss=0.791]

Total iteration 899, validation loss = 0.5563



Epoch 5:  50%|█████     | 100/200 [47:44<1:22:06, 49.26s/batch, loss=0.453]

Total iteration 904, validation loss = 0.5559



Epoch 5:  52%|█████▎    | 105/200 [50:03<1:15:42, 47.82s/batch, loss=0.505]

Total iteration 909, validation loss = 0.5567



Epoch 5:  55%|█████▌    | 110/200 [52:25<1:12:18, 48.21s/batch, loss=0.795]

Total iteration 914, validation loss = 0.5578



Epoch 5:  57%|█████▊    | 115/200 [54:51<1:10:29, 49.76s/batch, loss=0.679]

Total iteration 919, validation loss = 0.5569



Epoch 5:  60%|██████    | 120/200 [57:13<1:04:53, 48.67s/batch, loss=0.554]

Total iteration 924, validation loss = 0.5561



Epoch 5:  62%|██████▎   | 125/200 [59:35<1:00:45, 48.60s/batch, loss=0.675]

Total iteration 929, validation loss = 0.5563



Epoch 5:  65%|██████▌   | 130/200 [1:01:59<57:07, 48.96s/batch, loss=0.458]

Total iteration 934, validation loss = 0.5596



Epoch 5:  68%|██████▊   | 135/200 [1:04:24<53:36, 49.48s/batch, loss=0.568]

Total iteration 939, validation loss = 0.5619



Epoch 5:  70%|███████   | 140/200 [1:06:47<49:10, 49.18s/batch, loss=0.613]

Total iteration 944, validation loss = 0.5623



Epoch 5:  72%|███████▎  | 145/200 [1:09:12<45:16, 49.39s/batch, loss=0.633]

Total iteration 949, validation loss = 0.5616



Epoch 5:  75%|███████▌  | 150/200 [1:11:37<41:22, 49.65s/batch, loss=0.55] 

Total iteration 954, validation loss = 0.5597



Epoch 5:  78%|███████▊  | 155/200 [1:14:02<37:08, 49.52s/batch, loss=0.474]

Total iteration 959, validation loss = 0.5577



Epoch 5:  80%|████████  | 160/200 [1:16:26<32:58, 49.47s/batch, loss=0.407]

Total iteration 964, validation loss = 0.5541



Epoch 5:  82%|████████▎ | 165/200 [1:18:52<29:10, 50.00s/batch, loss=0.471]

Total iteration 969, validation loss = 0.5527



Epoch 5:  85%|████████▌ | 170/200 [1:21:17<24:45, 49.51s/batch, loss=0.509]

Total iteration 974, validation loss = 0.5521



Epoch 5:  88%|████████▊ | 175/200 [1:23:42<20:36, 49.44s/batch, loss=0.445]

Total iteration 979, validation loss = 0.5523



Epoch 5:  90%|█████████ | 180/200 [1:25:58<15:43, 47.16s/batch, loss=0.499]

Total iteration 984, validation loss = 0.5525



Epoch 5:  92%|█████████▎| 185/200 [1:28:18<11:55, 47.70s/batch, loss=0.528]

Total iteration 989, validation loss = 0.5547



Epoch 5:  95%|█████████▌| 190/200 [1:30:45<08:16, 49.67s/batch, loss=0.665]

Total iteration 994, validation loss = 0.5557



Epoch 5:  98%|█████████▊| 195/200 [1:33:06<04:02, 48.47s/batch, loss=0.521]

Total iteration 999, validation loss = 0.5558



Epoch 5: 100%|██████████| 200/200 [1:35:14<00:00, 28.57s/batch, loss=0.508]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 1004, validation loss = 0.5554



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 6:   2%|▏         | 4/200 [02:28<16:24,  5.02s/batch, loss=0.513]

Total iteration 1010, validation loss = 0.5583



Epoch 6:   5%|▌         | 10/200 [04:43<2:25:56, 46.09s/batch, loss=0.499]

Total iteration 1015, validation loss = 0.5584



Epoch 6:   8%|▊         | 15/200 [06:58<2:21:54, 46.02s/batch, loss=0.557]

Total iteration 1020, validation loss = 0.5591



Epoch 6:  10%|█         | 20/200 [09:13<2:18:03, 46.02s/batch, loss=0.564]

Total iteration 1025, validation loss = 0.5605



Epoch 6:  12%|█▎        | 25/200 [11:27<2:13:54, 45.91s/batch, loss=0.469]

Total iteration 1030, validation loss = 0.5636



Epoch 6:  15%|█▌        | 30/200 [13:52<2:18:52, 49.01s/batch, loss=0.447]

Total iteration 1035, validation loss = 0.5645



Epoch 6:  18%|█▊        | 35/200 [16:11<2:11:12, 47.71s/batch, loss=0.548]

Total iteration 1040, validation loss = 0.5635



Epoch 6:  20%|██        | 40/200 [18:37<2:11:46, 49.42s/batch, loss=0.501]

Total iteration 1045, validation loss = 0.5621



Epoch 6:  22%|██▎       | 45/200 [20:55<2:03:14, 47.71s/batch, loss=0.505]

Total iteration 1050, validation loss = 0.5590



Epoch 6:  25%|██▌       | 50/200 [23:12<1:57:20, 46.94s/batch, loss=0.417]

Total iteration 1055, validation loss = 0.5556



Epoch 6:  28%|██▊       | 55/200 [25:29<1:52:51, 46.70s/batch, loss=0.494]

Total iteration 1060, validation loss = 0.5537



Epoch 6:  30%|███       | 60/200 [27:47<1:49:29, 46.93s/batch, loss=0.553]

Total iteration 1065, validation loss = 0.5529



Epoch 6:  32%|███▎      | 65/200 [30:06<1:46:17, 47.24s/batch, loss=0.694]

Total iteration 1070, validation loss = 0.5523



Epoch 6:  35%|███▌      | 70/200 [32:25<1:43:14, 47.65s/batch, loss=0.458]

Total iteration 1075, validation loss = 0.5536



Epoch 6:  38%|███▊      | 75/200 [34:46<1:39:58, 47.99s/batch, loss=0.457]

Total iteration 1080, validation loss = 0.5542



Epoch 6:  40%|████      | 80/200 [37:06<1:35:38, 47.82s/batch, loss=0.653]

Total iteration 1085, validation loss = 0.5564



Epoch 6:  42%|████▎     | 85/200 [39:29<1:33:31, 48.80s/batch, loss=0.479]

Total iteration 1090, validation loss = 0.5564



Epoch 6:  45%|████▌     | 90/200 [41:47<1:26:43, 47.30s/batch, loss=0.628]

Total iteration 1095, validation loss = 0.5560



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 6:  48%|████▊     | 95/200 [44:15<1:27:38, 50.08s/batch, loss=0.773]

Total iteration 1100, validation loss = 0.5552



Epoch 6:  50%|█████     | 100/200 [46:40<1:23:04, 49.85s/batch, loss=0.452]

Total iteration 1105, validation loss = 0.5573



Epoch 6:  52%|█████▎    | 105/200 [48:58<1:15:13, 47.51s/batch, loss=0.595]

Total iteration 1110, validation loss = 0.5597



Epoch 6:  55%|█████▌    | 110/200 [51:21<1:12:43, 48.49s/batch, loss=0.494]

Total iteration 1115, validation loss = 0.5604



Epoch 6:  57%|█████▊    | 115/200 [53:42<1:08:24, 48.29s/batch, loss=0.464]

Total iteration 1120, validation loss = 0.5603



Epoch 6:  60%|██████    | 120/200 [56:08<1:06:23, 49.79s/batch, loss=0.457]

Total iteration 1125, validation loss = 0.5614



Epoch 6:  62%|██████▎   | 125/200 [58:28<1:00:03, 48.05s/batch, loss=0.615]

Total iteration 1130, validation loss = 0.5595



Epoch 6:  65%|██████▌   | 130/200 [1:00:49<56:16, 48.24s/batch, loss=0.653]

Total iteration 1135, validation loss = 0.5578



Epoch 6:  68%|██████▊   | 135/200 [1:03:12<52:50, 48.78s/batch, loss=0.617]

Total iteration 1140, validation loss = 0.5577



Epoch 6:  70%|███████   | 140/200 [1:05:35<48:44, 48.74s/batch, loss=0.548]

Total iteration 1145, validation loss = 0.5575



Epoch 6:  72%|███████▎  | 145/200 [1:07:59<45:02, 49.13s/batch, loss=0.53] 

Total iteration 1150, validation loss = 0.5587



Epoch 6:  75%|███████▌  | 150/200 [1:10:22<40:48, 48.97s/batch, loss=0.534]

Total iteration 1155, validation loss = 0.5569



Epoch 6:  78%|███████▊  | 155/200 [1:12:45<36:35, 48.80s/batch, loss=0.517]

Total iteration 1160, validation loss = 0.5571



Epoch 6:  80%|████████  | 160/200 [1:15:10<33:09, 49.74s/batch, loss=0.518]

Total iteration 1165, validation loss = 0.5555



Epoch 6:  82%|████████▎ | 165/200 [1:17:29<27:45, 47.59s/batch, loss=0.584]

Total iteration 1170, validation loss = 0.5517



Epoch 6:  85%|████████▌ | 170/200 [1:19:47<23:34, 47.14s/batch, loss=0.598]

Total iteration 1175, validation loss = 0.5516



Epoch 6:  88%|████████▊ | 175/200 [1:22:05<19:37, 47.09s/batch, loss=0.607]

Total iteration 1180, validation loss = 0.5506



Epoch 6:  90%|█████████ | 180/200 [1:24:21<15:34, 46.71s/batch, loss=0.523]

Total iteration 1185, validation loss = 0.5494



Epoch 6:  92%|█████████▎| 185/200 [1:26:42<11:57, 47.83s/batch, loss=0.662]

Total iteration 1190, validation loss = 0.5499



Epoch 6:  95%|█████████▌| 190/200 [1:29:05<08:07, 48.78s/batch, loss=0.542]

Total iteration 1195, validation loss = 0.5511



Epoch 6:  98%|█████████▊| 195/200 [1:31:28<04:04, 48.83s/batch, loss=0.464]

Total iteration 1200, validation loss = 0.5499



Epoch 6: 100%|██████████| 200/200 [1:33:40<00:00, 28.10s/batch, loss=0.504]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 1205, validation loss = 0.5505



Epoch 7:   2%|▎         | 5/200 [02:32<2:39:28, 49.07s/batch, loss=0.447]

Total iteration 1211, validation loss = 0.5514



Epoch 7:   5%|▌         | 10/200 [04:52<2:31:50, 47.95s/batch, loss=0.743]

Total iteration 1216, validation loss = 0.5516



Epoch 7:   8%|▊         | 15/200 [07:18<2:32:49, 49.57s/batch, loss=0.39] 

Total iteration 1221, validation loss = 0.5515



Epoch 7:  10%|█         | 20/200 [09:37<2:23:24, 47.81s/batch, loss=0.553]

Total iteration 1226, validation loss = 0.5492



Epoch 7:  12%|█▎        | 25/200 [11:55<2:18:07, 47.36s/batch, loss=0.553]

Total iteration 1231, validation loss = 0.5534



Epoch 7:  15%|█▌        | 30/200 [14:18<2:17:45, 48.62s/batch, loss=0.525]

Total iteration 1236, validation loss = 0.5580



Epoch 7:  18%|█▊        | 35/200 [16:38<2:11:31, 47.83s/batch, loss=0.435]

Total iteration 1241, validation loss = 0.5580



Epoch 7:  20%|██        | 40/200 [18:59<2:08:43, 48.27s/batch, loss=0.433]

Total iteration 1246, validation loss = 0.5568



Epoch 7:  22%|██▎       | 45/200 [21:20<2:04:28, 48.18s/batch, loss=0.49] 

Total iteration 1251, validation loss = 0.5526



Epoch 7:  25%|██▌       | 50/200 [23:48<2:05:26, 50.18s/batch, loss=0.667]

Total iteration 1256, validation loss = 0.5500



Epoch 7:  28%|██▊       | 55/200 [26:10<1:57:59, 48.83s/batch, loss=0.453]

Total iteration 1261, validation loss = 0.5484



Epoch 7:  30%|███       | 60/200 [28:32<1:53:09, 48.50s/batch, loss=0.603]

Total iteration 1266, validation loss = 0.5472



Epoch 7:  32%|███▎      | 65/200 [30:54<1:49:15, 48.56s/batch, loss=0.379]

Total iteration 1271, validation loss = 0.5465



Epoch 7:  35%|███▌      | 70/200 [33:16<1:45:08, 48.52s/batch, loss=0.586]

Total iteration 1276, validation loss = 0.5461



Epoch 7:  38%|███▊      | 75/200 [35:39<1:41:42, 48.82s/batch, loss=0.576]

Total iteration 1281, validation loss = 0.5471



Epoch 7:  40%|████      | 80/200 [38:00<1:36:48, 48.40s/batch, loss=0.454]

Total iteration 1286, validation loss = 0.5498



Epoch 7:  42%|████▎     | 85/200 [40:22<1:32:34, 48.30s/batch, loss=0.48] 

Total iteration 1291, validation loss = 0.5536



Epoch 7:  45%|████▌     | 90/200 [42:42<1:27:45, 47.87s/batch, loss=0.452]

Total iteration 1296, validation loss = 0.5556



Epoch 7:  48%|████▊     | 95/200 [45:07<1:26:29, 49.42s/batch, loss=0.53] 

Total iteration 1301, validation loss = 0.5559



Epoch 7:  50%|█████     | 100/200 [47:28<1:20:52, 48.52s/batch, loss=0.592]

Total iteration 1306, validation loss = 0.5534



Epoch 7:  52%|█████▎    | 105/200 [49:51<1:16:53, 48.56s/batch, loss=0.504]

Total iteration 1311, validation loss = 0.5527



Epoch 7:  55%|█████▌    | 110/200 [52:20<1:16:05, 50.73s/batch, loss=0.481]

Total iteration 1316, validation loss = 0.5522



Epoch 7:  57%|█████▊    | 115/200 [54:40<1:08:27, 48.32s/batch, loss=0.577]

Total iteration 1321, validation loss = 0.5547



Epoch 7:  60%|██████    | 120/200 [57:07<1:06:39, 49.99s/batch, loss=0.492]

Total iteration 1326, validation loss = 0.5533



Epoch 7:  62%|██████▎   | 125/200 [59:27<1:00:24, 48.33s/batch, loss=0.609]

Total iteration 1331, validation loss = 0.5527



Epoch 7:  65%|██████▌   | 130/200 [1:01:49<56:17, 48.26s/batch, loss=0.444]

Total iteration 1336, validation loss = 0.5535



Epoch 7:  68%|██████▊   | 135/200 [1:04:09<51:57, 47.97s/batch, loss=0.474]

Total iteration 1341, validation loss = 0.5539



Epoch 7:  70%|███████   | 140/200 [1:06:34<49:24, 49.40s/batch, loss=0.476]

Total iteration 1346, validation loss = 0.5550



Epoch 7:  72%|███████▎  | 145/200 [1:08:52<43:32, 47.50s/batch, loss=0.371]

Total iteration 1351, validation loss = 0.5521



Epoch 7:  75%|███████▌  | 150/200 [1:11:22<42:13, 50.67s/batch, loss=0.4]  

Total iteration 1356, validation loss = 0.5495



Epoch 7:  78%|███████▊  | 155/200 [1:13:47<37:16, 49.69s/batch, loss=0.304]

Total iteration 1361, validation loss = 0.5482



Epoch 7:  80%|████████  | 160/200 [1:16:12<33:13, 49.83s/batch, loss=0.552]

Total iteration 1366, validation loss = 0.5494



Epoch 7:  82%|████████▎ | 165/200 [1:18:42<29:50, 51.16s/batch, loss=0.451]

Total iteration 1371, validation loss = 0.5508



Epoch 7:  85%|████████▌ | 170/200 [1:21:02<24:07, 48.25s/batch, loss=0.616]

Total iteration 1376, validation loss = 0.5529



Epoch 7:  88%|████████▊ | 175/200 [1:23:28<20:42, 49.72s/batch, loss=0.688]

Total iteration 1381, validation loss = 0.5537



Epoch 7:  90%|█████████ | 180/200 [1:25:54<16:35, 49.77s/batch, loss=0.524]

Total iteration 1386, validation loss = 0.5528



Epoch 7:  92%|█████████▎| 185/200 [1:28:15<12:09, 48.65s/batch, loss=0.65] 

Total iteration 1391, validation loss = 0.5529



Epoch 7:  95%|█████████▌| 190/200 [1:30:37<08:03, 48.40s/batch, loss=0.626]

Total iteration 1396, validation loss = 0.5524



Epoch 7:  98%|█████████▊| 195/200 [1:32:58<04:01, 48.21s/batch, loss=0.549]

Total iteration 1401, validation loss = 0.5495



Epoch 7: 100%|██████████| 200/200 [1:35:09<00:00, 28.55s/batch, loss=0.532]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 1406, validation loss = 0.5499



Epoch 8:   2%|▎         | 5/200 [02:35<2:43:10, 50.21s/batch, loss=0.585]

Total iteration 1412, validation loss = 0.5502



Epoch 8:   5%|▌         | 10/200 [04:58<2:34:55, 48.92s/batch, loss=0.6] 

Total iteration 1417, validation loss = 0.5505



Epoch 8:   8%|▊         | 15/200 [07:21<2:31:06, 49.01s/batch, loss=0.548]

Total iteration 1422, validation loss = 0.5490



Epoch 8:  10%|█         | 20/200 [09:51<2:32:23, 50.79s/batch, loss=0.541]

Total iteration 1427, validation loss = 0.5488



Epoch 8:  12%|█▎        | 25/200 [12:17<2:26:55, 50.37s/batch, loss=0.422]

Total iteration 1432, validation loss = 0.5492



Epoch 8:  15%|█▌        | 30/200 [14:42<2:20:30, 49.59s/batch, loss=0.509]

Total iteration 1437, validation loss = 0.5506



Epoch 8:  18%|█▊        | 35/200 [17:05<2:14:53, 49.05s/batch, loss=0.737]

Total iteration 1442, validation loss = 0.5538



Epoch 8:  20%|██        | 40/200 [19:29<2:10:44, 49.03s/batch, loss=0.472]

Total iteration 1447, validation loss = 0.5571



Epoch 8:  22%|██▎       | 45/200 [21:52<2:06:47, 49.08s/batch, loss=0.379]

Total iteration 1452, validation loss = 0.5562



Epoch 8:  25%|██▌       | 50/200 [24:15<2:02:14, 48.89s/batch, loss=0.812]

Total iteration 1457, validation loss = 0.5577



Epoch 8:  28%|██▊       | 55/200 [26:38<1:57:56, 48.81s/batch, loss=0.577]

Total iteration 1462, validation loss = 0.5597



Epoch 8:  30%|███       | 60/200 [29:06<1:57:46, 50.48s/batch, loss=0.384]

Total iteration 1467, validation loss = 0.5586



Epoch 8:  32%|███▎      | 65/200 [31:28<1:49:54, 48.85s/batch, loss=0.549]

Total iteration 1472, validation loss = 0.5542



Epoch 8:  35%|███▌      | 70/200 [33:52<1:46:36, 49.20s/batch, loss=0.416]

Total iteration 1477, validation loss = 0.5516



Epoch 8:  38%|███▊      | 75/200 [36:14<1:41:09, 48.56s/batch, loss=0.641]

Total iteration 1482, validation loss = 0.5508



Epoch 8:  40%|████      | 80/200 [38:36<1:37:08, 48.57s/batch, loss=0.573]

Total iteration 1487, validation loss = 0.5506



Epoch 8:  42%|████▎     | 85/200 [40:58<1:32:42, 48.37s/batch, loss=0.51] 

Total iteration 1492, validation loss = 0.5518



Epoch 8:  45%|████▌     | 90/200 [43:21<1:29:33, 48.85s/batch, loss=0.424]

Total iteration 1497, validation loss = 0.5548



Epoch 8:  48%|████▊     | 95/200 [45:44<1:25:34, 48.90s/batch, loss=0.439]

Total iteration 1502, validation loss = 0.5556



Epoch 8:  50%|█████     | 100/200 [48:07<1:21:21, 48.82s/batch, loss=0.459]

Total iteration 1507, validation loss = 0.5548



Epoch 8:  52%|█████▎    | 105/200 [50:33<1:19:01, 49.91s/batch, loss=0.532]

Total iteration 1512, validation loss = 0.5539



Epoch 8:  55%|█████▌    | 110/200 [52:56<1:13:37, 49.08s/batch, loss=0.495]

Total iteration 1517, validation loss = 0.5529



Epoch 8:  57%|█████▊    | 115/200 [55:18<1:09:02, 48.73s/batch, loss=0.418]

Total iteration 1522, validation loss = 0.5532



Epoch 8:  60%|██████    | 120/200 [57:48<1:07:47, 50.84s/batch, loss=0.57] 

Total iteration 1527, validation loss = 0.5547



Epoch 8:  62%|██████▎   | 125/200 [1:00:08<1:00:25, 48.34s/batch, loss=0.469]

Total iteration 1532, validation loss = 0.5587



Epoch 8:  65%|██████▌   | 130/200 [1:02:28<55:58, 47.98s/batch, loss=0.44]   

Total iteration 1537, validation loss = 0.5646



Epoch 8:  68%|██████▊   | 135/200 [1:04:49<52:10, 48.16s/batch, loss=0.521]

Total iteration 1542, validation loss = 0.5634



Epoch 8:  70%|███████   | 140/200 [1:07:10<48:03, 48.05s/batch, loss=0.413]

Total iteration 1547, validation loss = 0.5632



Epoch 8:  72%|███████▎  | 145/200 [1:09:30<43:53, 47.89s/batch, loss=0.352]

Total iteration 1552, validation loss = 0.5571



Epoch 8:  75%|███████▌  | 150/200 [1:11:55<41:08, 49.38s/batch, loss=0.301]

Total iteration 1557, validation loss = 0.5513



Epoch 8:  78%|███████▊  | 155/200 [1:14:14<35:40, 47.58s/batch, loss=0.453]

Total iteration 1562, validation loss = 0.5503



Epoch 8:  80%|████████  | 160/200 [1:16:32<31:35, 47.38s/batch, loss=0.432]

Total iteration 1567, validation loss = 0.5478



Epoch 8:  82%|████████▎ | 165/200 [1:18:51<27:31, 47.19s/batch, loss=0.859]

Total iteration 1572, validation loss = 0.5469



Epoch 8:  85%|████████▌ | 170/200 [1:21:08<23:26, 46.87s/batch, loss=0.351]

Total iteration 1577, validation loss = 0.5482



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 8:  88%|████████▊ | 175/200 [1:23:25<19:30, 46.81s/batch, loss=0.529]

Total iteration 1582, validation loss = 0.5502



Epoch 8:  90%|█████████ | 180/200 [1:25:45<15:52, 47.65s/batch, loss=0.553]

Total iteration 1587, validation loss = 0.5516



Epoch 8:  92%|█████████▎| 185/200 [1:28:10<12:18, 49.21s/batch, loss=0.492]

Total iteration 1592, validation loss = 0.5528



Epoch 8:  95%|█████████▌| 190/200 [1:30:32<08:07, 48.75s/batch, loss=0.653]

Total iteration 1597, validation loss = 0.5515



Epoch 8:  98%|█████████▊| 195/200 [1:32:55<04:04, 48.82s/batch, loss=0.469]

Total iteration 1602, validation loss = 0.5494



Epoch 8: 100%|██████████| 200/200 [1:35:04<00:00, 28.52s/batch, loss=0.539]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 1607, validation loss = 0.5513



Epoch 9:   2%|▎         | 5/200 [02:29<2:35:57, 47.99s/batch, loss=0.423]

Total iteration 1613, validation loss = 0.5556



Epoch 9:   5%|▌         | 10/200 [04:46<2:29:00, 47.06s/batch, loss=0.3] 

Total iteration 1618, validation loss = 0.5558



Epoch 9:   8%|▊         | 15/200 [07:04<2:25:14, 47.10s/batch, loss=0.347]

Total iteration 1623, validation loss = 0.5554



Epoch 9:  10%|█         | 20/200 [09:23<2:21:27, 47.15s/batch, loss=0.743]

Total iteration 1628, validation loss = 0.5561



Epoch 9:  12%|█▎        | 25/200 [11:47<2:22:49, 48.97s/batch, loss=0.487]

Total iteration 1633, validation loss = 0.5561



Epoch 9:  15%|█▌        | 30/200 [14:10<2:18:12, 48.78s/batch, loss=0.466]

Total iteration 1638, validation loss = 0.5593



Epoch 9:  18%|█▊        | 35/200 [16:34<2:15:27, 49.26s/batch, loss=0.603]

Total iteration 1643, validation loss = 0.5611



Epoch 9:  20%|██        | 40/200 [18:56<2:09:47, 48.67s/batch, loss=0.52] 

Total iteration 1648, validation loss = 0.5592



Epoch 9:  22%|██▎       | 45/200 [21:25<2:10:38, 50.57s/batch, loss=0.368]

Total iteration 1653, validation loss = 0.5591



Epoch 9:  25%|██▌       | 50/200 [23:47<2:02:15, 48.90s/batch, loss=0.46] 

Total iteration 1658, validation loss = 0.5604



Epoch 9:  28%|██▊       | 55/200 [26:13<2:00:06, 49.70s/batch, loss=0.589]

Total iteration 1663, validation loss = 0.5602



Epoch 9:  30%|███       | 60/200 [28:32<1:52:01, 48.01s/batch, loss=0.478]

Total iteration 1668, validation loss = 0.5599



Epoch 9:  32%|███▎      | 65/200 [30:52<1:47:29, 47.77s/batch, loss=0.346]

Total iteration 1673, validation loss = 0.5585



Epoch 9:  35%|███▌      | 70/200 [33:12<1:43:30, 47.77s/batch, loss=0.362]

Total iteration 1678, validation loss = 0.5580



Epoch 9:  38%|███▊      | 75/200 [35:33<1:40:24, 48.20s/batch, loss=0.402]

Total iteration 1683, validation loss = 0.5571



Epoch 9:  40%|████      | 80/200 [37:55<1:36:31, 48.27s/batch, loss=0.357]

Total iteration 1688, validation loss = 0.5552



Epoch 9:  42%|████▎     | 85/200 [40:24<1:37:08, 50.68s/batch, loss=0.408]

Total iteration 1693, validation loss = 0.5539



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 9:  45%|████▌     | 90/200 [42:46<1:29:47, 48.97s/batch, loss=0.513]

Total iteration 1698, validation loss = 0.5498



Epoch 9:  48%|████▊     | 95/200 [45:06<1:23:56, 47.97s/batch, loss=0.54] 

Total iteration 1703, validation loss = 0.5475



Epoch 9:  50%|█████     | 100/200 [47:26<1:19:35, 47.76s/batch, loss=0.662]

Total iteration 1708, validation loss = 0.5476



Epoch 9:  52%|█████▎    | 105/200 [49:45<1:15:20, 47.58s/batch, loss=0.609]

Total iteration 1713, validation loss = 0.5487



Epoch 9:  55%|█████▌    | 110/200 [52:04<1:11:20, 47.56s/batch, loss=0.435]

Total iteration 1718, validation loss = 0.5494



Epoch 9:  57%|█████▊    | 115/200 [54:24<1:07:24, 47.59s/batch, loss=0.557]

Total iteration 1723, validation loss = 0.5505



Epoch 9:  60%|██████    | 120/200 [56:43<1:03:21, 47.52s/batch, loss=0.51] 

Total iteration 1728, validation loss = 0.5541



Epoch 9:  62%|██████▎   | 125/200 [59:07<1:01:15, 49.00s/batch, loss=0.466]

Total iteration 1733, validation loss = 0.5541



Epoch 9:  65%|██████▌   | 130/200 [1:01:27<55:54, 47.92s/batch, loss=0.524]

Total iteration 1738, validation loss = 0.5537



Epoch 9:  68%|██████▊   | 135/200 [1:03:52<53:27, 49.34s/batch, loss=0.493]

Total iteration 1743, validation loss = 0.5556



Epoch 9:  70%|███████   | 140/200 [1:06:22<51:09, 51.17s/batch, loss=0.611]

Total iteration 1748, validation loss = 0.5614



Epoch 9:  72%|███████▎  | 145/200 [1:08:48<46:05, 50.28s/batch, loss=0.47] 

Total iteration 1753, validation loss = 0.5611



Epoch 9:  75%|███████▌  | 150/200 [1:11:08<40:12, 48.24s/batch, loss=0.483]

Total iteration 1758, validation loss = 0.5567



Epoch 9:  78%|███████▊  | 155/200 [1:13:28<35:55, 47.90s/batch, loss=0.374]

Total iteration 1763, validation loss = 0.5508



Epoch 9:  80%|████████  | 160/200 [1:15:56<33:27, 50.18s/batch, loss=0.578]

Total iteration 1768, validation loss = 0.5483



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 9:  82%|████████▎ | 165/200 [1:18:20<28:52, 49.49s/batch, loss=0.553]

Total iteration 1773, validation loss = 0.5481



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 9:  85%|████████▌ | 170/200 [1:20:47<25:01, 50.05s/batch, loss=0.454]

Total iteration 1778, validation loss = 0.5479



Epoch 9:  88%|████████▊ | 175/200 [1:23:04<19:37, 47.08s/batch, loss=0.376]

Total iteration 1783, validation loss = 0.5479



Epoch 9:  90%|█████████ | 180/200 [1:25:19<15:29, 46.49s/batch, loss=0.48] 

Total iteration 1788, validation loss = 0.5493



Epoch 9:  92%|█████████▎| 185/200 [1:27:36<11:40, 46.69s/batch, loss=0.416]

Total iteration 1793, validation loss = 0.5519



Epoch 9:  95%|█████████▌| 190/200 [1:29:55<07:51, 47.15s/batch, loss=0.59] 

Total iteration 1798, validation loss = 0.5525



Epoch 9:  98%|█████████▊| 195/200 [1:32:13<03:56, 47.20s/batch, loss=0.53] 

Total iteration 1803, validation loss = 0.5554



Epoch 9: 100%|██████████| 200/200 [1:34:22<00:00, 28.31s/batch, loss=0.575]
  0%|          | 0/200 [00:00<?, ?batch/s]

Total iteration 1808, validation loss = 0.5589



Epoch 10:   2%|▎         | 5/200 [02:29<2:35:44, 47.92s/batch, loss=0.458]

Total iteration 1814, validation loss = 0.5613



Epoch 10:   5%|▌         | 10/200 [04:50<2:32:52, 48.28s/batch, loss=0.392]

Total iteration 1819, validation loss = 0.5639



Epoch 10:   8%|▊         | 15/200 [07:11<2:28:45, 48.25s/batch, loss=0.289]

Total iteration 1824, validation loss = 0.5623



Epoch 10:  10%|█         | 20/200 [09:34<2:25:46, 48.59s/batch, loss=0.45] 

Total iteration 1829, validation loss = 0.5623



Epoch 10:  12%|█▎        | 25/200 [11:55<2:20:47, 48.27s/batch, loss=0.435]

Total iteration 1834, validation loss = 0.5622



Epoch 10:  15%|█▌        | 30/200 [14:17<2:17:18, 48.46s/batch, loss=0.339]

Total iteration 1839, validation loss = 0.5597



Epoch 10:  18%|█▊        | 35/200 [16:39<2:13:53, 48.69s/batch, loss=0.451]

Total iteration 1844, validation loss = 0.5609



Epoch 10:  20%|██        | 40/200 [19:03<2:10:43, 49.02s/batch, loss=0.448]

Total iteration 1849, validation loss = 0.5602



Epoch 10:  22%|██▎       | 45/200 [21:24<2:05:13, 48.47s/batch, loss=0.429]

Total iteration 1854, validation loss = 0.5603



Epoch 10:  25%|██▌       | 50/200 [23:46<2:00:42, 48.28s/batch, loss=0.602]

Total iteration 1859, validation loss = 0.5624



Epoch 10:  28%|██▊       | 55/200 [26:07<1:56:59, 48.41s/batch, loss=0.407]

Total iteration 1864, validation loss = 0.5652



Epoch 10:  30%|███       | 60/200 [28:29<1:53:06, 48.47s/batch, loss=0.294]

Total iteration 1869, validation loss = 0.5704



Epoch 10:  32%|███▎      | 65/200 [30:54<1:50:52, 49.28s/batch, loss=0.408]

Total iteration 1874, validation loss = 0.5684



Epoch 10:  35%|███▌      | 70/200 [33:20<1:48:13, 49.95s/batch, loss=0.361]

Total iteration 1879, validation loss = 0.5703



Epoch 10:  38%|███▊      | 75/200 [35:46<1:43:51, 49.85s/batch, loss=0.482]

Total iteration 1884, validation loss = 0.5720



Epoch 10:  40%|████      | 80/200 [38:10<1:38:41, 49.34s/batch, loss=0.384]

Total iteration 1889, validation loss = 0.5740



Epoch 10:  42%|████▎     | 85/200 [40:29<1:31:55, 47.96s/batch, loss=0.505]

Total iteration 1894, validation loss = 0.5752



Epoch 10:  45%|████▌     | 90/200 [42:54<1:30:10, 49.19s/batch, loss=0.278]

Total iteration 1899, validation loss = 0.5774



Epoch 10:  47%|████▋     | 94/200 [45:17<25:15, 14.30s/batch, loss=0.471]  

Total iteration 1904, validation loss = 0.5775



Epoch 10:  50%|█████     | 100/200 [47:36<1:19:35, 47.76s/batch, loss=0.501]

Total iteration 1909, validation loss = 0.5766



Epoch 10:  52%|█████▎    | 105/200 [49:56<1:15:40, 47.79s/batch, loss=0.378]

Total iteration 1914, validation loss = 0.5769



Epoch 10:  55%|█████▌    | 110/200 [52:15<1:11:32, 47.69s/batch, loss=0.38] 

Total iteration 1919, validation loss = 0.5772



Epoch 10:  57%|█████▊    | 115/200 [54:37<1:08:22, 48.26s/batch, loss=0.534]

Total iteration 1924, validation loss = 0.5730



Epoch 10:  60%|█████▉    | 119/200 [57:07<18:59, 14.07s/batch, loss=0.413]  

Total iteration 1929, validation loss = 0.5661



Epoch 10:  62%|██████▎   | 125/200 [59:31<1:01:59, 49.60s/batch, loss=0.322]

Total iteration 1934, validation loss = 0.5622



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 10:  65%|██████▌   | 130/200 [1:01:54<57:04, 48.92s/batch, loss=0.28] 

Total iteration 1939, validation loss = 0.5597



Epoch 10:  68%|██████▊   | 135/200 [1:04:11<50:54, 46.99s/batch, loss=0.422]

Total iteration 1944, validation loss = 0.5562



Epoch 10:  70%|███████   | 140/200 [1:06:33<48:23, 48.40s/batch, loss=0.432]

Total iteration 1949, validation loss = 0.5549



Epoch 10:  72%|███████▎  | 145/200 [1:08:52<43:44, 47.72s/batch, loss=0.64] 

Total iteration 1954, validation loss = 0.5570



Epoch 10:  75%|███████▌  | 150/200 [1:11:12<39:46, 47.73s/batch, loss=0.603]

Total iteration 1959, validation loss = 0.5576



Epoch 10:  78%|███████▊  | 155/200 [1:13:32<35:47, 47.72s/batch, loss=0.508]

Total iteration 1964, validation loss = 0.5589



  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
Epoch 10:  80%|████████  | 160/200 [1:15:52<31:46, 47.66s/batch, loss=0.276]

Total iteration 1969, validation loss = 0.5561



Epoch 10:  82%|████████▎ | 165/200 [1:18:10<27:40, 47.44s/batch, loss=0.308]

Total iteration 1974, validation loss = 0.5565



Epoch 10:  85%|████████▌ | 170/200 [1:20:28<23:35, 47.17s/batch, loss=0.323]

Total iteration 1979, validation loss = 0.5570



Epoch 10:  88%|████████▊ | 175/200 [1:22:50<20:03, 48.16s/batch, loss=0.227]

Total iteration 1984, validation loss = 0.5572



Epoch 10:  90%|█████████ | 180/200 [1:25:10<15:57, 47.88s/batch, loss=0.393]

Total iteration 1989, validation loss = 0.5548



Epoch 10:  92%|█████████▎| 185/200 [1:27:34<12:15, 49.01s/batch, loss=0.565]

Total iteration 1994, validation loss = 0.5572



Epoch 10:  95%|█████████▌| 190/200 [1:29:57<08:08, 48.81s/batch, loss=0.342]

Total iteration 1999, validation loss = 0.5581



Epoch 10:  98%|█████████▊| 195/200 [1:32:19<04:02, 48.57s/batch, loss=0.501]

Total iteration 2004, validation loss = 0.5571



Epoch 10: 100%|██████████| 200/200 [1:34:29<00:00, 28.35s/batch, loss=0.32] 

Total iteration 2009, validation loss = 0.5541






In [413]:
# Save the state_dict of `model` (parameters only, not the module object) to the baseline run directory.
baseline_weights_path = '../runs/baseline/baseline_final_model.pt'
torch.save(model.state_dict(), baseline_weights_path)

In [414]:
# Common errors and how to fix them:

# Error:
#   RuntimeError: running_mean should contain 1 elements not 8
# Fix: The num_features of one of your BatchNorm3d layers does not match the number of channels of its input; make them agree

# Error:
#   RuntimeError: CUDA out of memory.
# Fix: Make the batch size or the model smaller.
# First try reducing the batch size: training may take longer, but the expressivity of the model is unchanged.
# Only if that is not enough, reduce the complexity of the model itself. (TODO: finish this note.)

# Error:
#   RuntimeError: Given groups=1, weight of size [1, 1, 1, 1, 1], expected input[8, 4, 5, 32, 32] to have 1 channels, but got 4 channels instead
# Fix: The self-attention layer was built with the wrong in_channels (1 here); set it to the number of channels of its input (4 here)

# Experiment 1: 3D Self-Attention after 3D Conv Layers

In [9]:
# Make a timestamped log directory with a Checkpoints/ subdirectory for Experiment 1
# (DIFFERENT DIRECTORY FROM BASELINE: ../runs/experiment_att instead of ../runs/baseline).
dir_nm = datetime.now(tz=pytz.utc).astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
# dir_nm = "first_mini_c2fc2"  # fixed name, handy for quick debugging runs
log_dir = os.path.join('../runs/experiment_att', dir_nm)
# os.makedirs creates missing parents too (e.g. ../runs/experiment_att on a fresh checkout),
# where the previous pair of os.mkdir calls raised FileNotFoundError. exist_ok=True avoids a
# FileExistsError when the directory is already there (e.g. with the fixed dir_nm above).
os.makedirs(os.path.join(log_dir, 'Checkpoints'), exist_ok=True)


# Model, optimizer, criterion for the self-attention model of Experiment 1
model2 = selfattn_3DCNN(in_num_ch=1)
optimizer2 = optim.Adam(model2.parameters(), lr=1e-4)
criterion2 = torch.nn.BCEWithLogitsLoss()

In [None]:
# Experiment 1: train model2, logging under log_dir and validating every 5 iterations.
# The two values returned by train() are kept as train_loss_dict2 / val_loss_dict2.
train_loss_dict2, val_loss_dict2 = train(
    model2,
    optimizer2,
    criterion2,
    loader_train,
    loader_val,
    log_dir,
    device=device,
    epochs=10,
    val_every=5,
)

Epoch 1:   2%|▏         | 4/200 [00:28<16:18,  4.99s/batch, loss=0.669]

In [None]:
# Save the state_dict of model2 (parameters only) to the Experiment 1 run directory.
experiment_att_weights_path = '../runs/experiment_att/experiment_final_model.pt'
torch.save(model2.state_dict(), experiment_att_weights_path)

In [None]:
# a = torch.randn(4, 40, 512, 512)
# print(a.size())
# transforms.Resize(size=(256, 256))(a).size()

# Experiment 2: Residual/Skip Connection after Self-Attention

In [None]:
# Make a timestamped log directory with a Checkpoints/ subdirectory for Experiment 2
# (DIFFERENT DIRECTORY FROM BASELINE AND EXPERIMENT 1: ../runs/experiment_res).
dir_nm = datetime.now(tz=pytz.utc).astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
# dir_nm = "first_mini_c2fc2"  # fixed name, handy for quick debugging runs
log_dir = os.path.join('../runs/experiment_res', dir_nm)
# os.makedirs creates missing parents too (e.g. ../runs/experiment_res on a fresh checkout),
# where the previous pair of os.mkdir calls raised FileNotFoundError. exist_ok=True avoids a
# FileExistsError when the directory is already there (e.g. with the fixed dir_nm above).
os.makedirs(os.path.join(log_dir, 'Checkpoints'), exist_ok=True)


# Model, optimizer, criterion for the residual self-attention model of Experiment 2
model3 = resattn_3DCNN(in_num_ch=1)
optimizer3 = optim.Adam(model3.parameters(), lr=1e-4)
criterion3 = torch.nn.BCEWithLogitsLoss()

In [None]:
# Experiment 2: train model3, logging under log_dir and validating every 5 iterations.
# The two values returned by train() are kept as train_loss_dict3 / val_loss_dict3.
train_loss_dict3, val_loss_dict3 = train(
    model3,
    optimizer3,
    criterion3,
    loader_train,
    loader_val,
    log_dir,
    device=device,
    epochs=10,
    val_every=5,
)

In [None]:
# Save the state_dict of model3 (parameters only) to the Experiment 2 run directory.
experiment_res_weights_path = '../runs/experiment_res/experimentres_final_model.pt'
torch.save(model3.state_dict(), experiment_res_weights_path)