In [14]:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler

from tqdm.notebook import tqdm
import h5py
from matplotlib import pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

In [2]:
# Seed function
def set_seed(seed_val):
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed(seed_val)
    
set_seed(42)

In [3]:
batch_size = 64
epochs = 100
lr = 1e-4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build Model

In [4]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=2, padding_mode='replicate')
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=2, padding_mode='replicate')
        self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=2, padding_mode='replicate')
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.relu(x)
        
        x = self.conv3(x)
        
        return x

In [5]:
model = SRCNN()
model.to(device)
summary(model, (1, 33, 33))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 29, 29]           5,248
              ReLU-2           [-1, 64, 29, 29]               0
            Conv2d-3           [-1, 32, 33, 33]           2,080
              ReLU-4           [-1, 32, 33, 33]               0
            Conv2d-5            [-1, 1, 33, 33]             801
Total params: 8,129
Trainable params: 8,129
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.36
Params size (MB): 0.03
Estimated Total Size (MB): 1.40
----------------------------------------------------------------


# Load and build dataloader

In [6]:
with h5py.File('input/train_mscale.h5') as files:
    images = files['data'][:]
    labels = files['label'][:]

images = images.astype(np.float32)
labels = labels.astype(np.float32)

In [7]:
(x_train, x_val, y_train, y_val) = train_test_split(images, labels, test_size=0.2, random_state=2021)

In [8]:
class ResDataset(Dataset):
    def __init__(self, image_data, labels):
        super(ResDataset, self).__init__()
        self.image_data = image_data
        self.labels = labels
            
    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, index):
        image = self.image_data[index]
        label = self.labels[index]

        return {
            'image': torch.tensor(image, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.float32)
        }

In [10]:
train_dataset = ResDataset(x_train, y_train)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

val_dataset = ResDataset(x_val, y_val)
val_sampler = SequentialSampler(val_dataset)
val_dataloader = DataLoader(val_dataset, sampler=val_sampler, batch_size=batch_size)

# Define training functions

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

In [34]:
for epoch in tqdm(range(epochs)):
    model.train()
    curr_loss = 0
    curr_psnr = 0
    
    for step, data in tqdm(enumerate(train_dataloader),  total=int(len(train_dataset)/train_dataloader.batch_size)):
        batch_image = data['image'].to(device)
        batch_label = data['label'].to(device)
        
        optimizer.zero_grad()
        logits = model(batch_image)
        loss = criterion(logits, batch_label)
        
        loss.backward()
        optimizer.step()
        
        curr_loss += loss.item()
#         curr_psnr += psnr(batch_label, logits)
    
    final_loss = curr_loss / len(train_dataloader.dataset)
#     final_psnr = curr_psnr / int(len(train_dataset)/train_dataloader.batch_size)
    
    print('Loss: {} at epoch: {}'.format(final_loss, epoch))
#     print('PSNR: {} at epoch: {}'.format(final_psnr, epoch))
    
    # Validation data
    model.eval()
    cur_val_loss = 0
    for data in val_dataloader:
        batch_image = data['image'].to(device)
        batch_label = data['label'].to(device)

        with torch.no_grad():
            logits = model(batch_image)

        loss = criterion(logits, batch_label)
        cur_val_loss += loss.item()

    final_loss = cur_val_loss / len(val_dataloader.dataset)
    print('Validation Loss: {} at epoch: {}'.format(final_loss, epoch))

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.7121828863490806e-05 at epoch: 0
Loss: 2.7694796982252244e-05 at epoch: 0


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.703065868996046e-05 at epoch: 1
Loss: 2.7683172598993246e-05 at epoch: 1


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.703059808518011e-05 at epoch: 2
Loss: 2.7656705328876757e-05 at epoch: 2


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.7033520717206035e-05 at epoch: 3
Loss: 2.7844425208678837e-05 at epoch: 3


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.7018432627371825e-05 at epoch: 4
Loss: 2.7862734631779266e-05 at epoch: 4


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.694589510650506e-05 at epoch: 5
Loss: 2.7574904936174824e-05 at epoch: 5


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6953293363885858e-05 at epoch: 6
Loss: 2.7573831168185803e-05 at epoch: 6


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6926861184716542e-05 at epoch: 7
Loss: 2.753678399894519e-05 at epoch: 7


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.693828190185842e-05 at epoch: 8
Loss: 2.75529491931695e-05 at epoch: 8


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6879823873655646e-05 at epoch: 9
Loss: 2.7502969521837137e-05 at epoch: 9


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6870019121736833e-05 at epoch: 10
Loss: 2.7587567322715724e-05 at epoch: 10


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.679147648662576e-05 at epoch: 11
Loss: 2.7484391183343815e-05 at epoch: 11


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6811377452181267e-05 at epoch: 12
Loss: 2.745180078734135e-05 at epoch: 12


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6818447287529802e-05 at epoch: 13
Loss: 2.7437953551524967e-05 at epoch: 13


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6791815243943484e-05 at epoch: 14
Loss: 2.745662920949351e-05 at epoch: 14


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6781453533493164e-05 at epoch: 15
Loss: 2.7426012326565327e-05 at epoch: 15


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.682471297099035e-05 at epoch: 16
Loss: 2.7378622409412883e-05 at epoch: 16


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6686431152358914e-05 at epoch: 17
Loss: 2.7330946374603494e-05 at epoch: 17


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6654275308772124e-05 at epoch: 18
Loss: 2.7650564814035495e-05 at epoch: 18


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6711690003243645e-05 at epoch: 19
Loss: 2.742632554112192e-05 at epoch: 19


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6649817423434647e-05 at epoch: 20
Loss: 2.7451998333341067e-05 at epoch: 20


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.664392429367527e-05 at epoch: 21
Loss: 2.7319180879438933e-05 at epoch: 21


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6658362457418288e-05 at epoch: 22
Loss: 2.7417774762390746e-05 at epoch: 22


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6605315897998905e-05 at epoch: 23
Loss: 2.722644918136307e-05 at epoch: 23


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6610561348758185e-05 at epoch: 24
Loss: 2.718602716794743e-05 at epoch: 24


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6578808491287297e-05 at epoch: 25
Loss: 2.745162700446918e-05 at epoch: 25


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6532235314558633e-05 at epoch: 26
Loss: 2.7151711531546723e-05 at epoch: 26


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6567670057463014e-05 at epoch: 27
Loss: 2.716241662048085e-05 at epoch: 27


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.655301990701334e-05 at epoch: 28
Loss: 2.771502932510152e-05 at epoch: 28


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6532360784924108e-05 at epoch: 29
Loss: 2.7120072060698837e-05 at epoch: 29


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6461175456125675e-05 at epoch: 30
Loss: 2.7171057517739504e-05 at epoch: 30


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6497306970636082e-05 at epoch: 31
Loss: 2.7592830042027254e-05 at epoch: 31


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6480326596721114e-05 at epoch: 32
Loss: 2.713448982493779e-05 at epoch: 32


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.645216116350859e-05 at epoch: 33
Loss: 2.704305845800305e-05 at epoch: 33


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.64620358553043e-05 at epoch: 34
Loss: 2.7017284212431275e-05 at epoch: 34


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6382441181484448e-05 at epoch: 35
Loss: 2.7084828035442964e-05 at epoch: 35


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.636778442645397e-05 at epoch: 36
Loss: 2.713849483205076e-05 at epoch: 36


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6405518373631416e-05 at epoch: 37
Loss: 2.6969055091494957e-05 at epoch: 37


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6332997189183605e-05 at epoch: 38
Loss: 2.6973230040974633e-05 at epoch: 38


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6375751084527463e-05 at epoch: 39
Loss: 2.6922630688083814e-05 at epoch: 39


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6326075201760544e-05 at epoch: 40
Loss: 2.6953224768010815e-05 at epoch: 40


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6384722632404263e-05 at epoch: 41
Loss: 2.6998586231524727e-05 at epoch: 41


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6274803404224356e-05 at epoch: 42
Loss: 2.706630407744018e-05 at epoch: 42


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6249514701034467e-05 at epoch: 43
Loss: 2.6911252843939356e-05 at epoch: 43


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.627874347509051e-05 at epoch: 44
Loss: 2.6972220574781932e-05 at epoch: 44


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.630109178289934e-05 at epoch: 45
Loss: 2.6892096722735367e-05 at epoch: 45


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.622432312619015e-05 at epoch: 46
Loss: 2.682765077411241e-05 at epoch: 46


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.626214328298418e-05 at epoch: 47
Loss: 2.693889589547121e-05 at epoch: 47


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.620816271714992e-05 at epoch: 48
Loss: 2.7039781386696256e-05 at epoch: 48


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6339507952268056e-05 at epoch: 49
Loss: 2.679825762557751e-05 at epoch: 49


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6178322487602332e-05 at epoch: 50
Loss: 2.695202594356268e-05 at epoch: 50


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6202421879488765e-05 at epoch: 51
Loss: 2.67791468703443e-05 at epoch: 51


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.618569865414395e-05 at epoch: 52
Loss: 2.6795652882758216e-05 at epoch: 52


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6167150797496613e-05 at epoch: 53
Loss: 2.687834181907114e-05 at epoch: 53


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6164304689924606e-05 at epoch: 54
Loss: 2.676890464232891e-05 at epoch: 54


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6118600453987902e-05 at epoch: 55
Loss: 2.676941148236923e-05 at epoch: 55


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6098286383738633e-05 at epoch: 56
Loss: 2.716838807307124e-05 at epoch: 56


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.614369013291948e-05 at epoch: 57
Loss: 2.6727047693389728e-05 at epoch: 57


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6120838340012526e-05 at epoch: 58
Loss: 2.7015245944167672e-05 at epoch: 58


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6103883647614943e-05 at epoch: 59
Loss: 2.6703712728860246e-05 at epoch: 59


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.609867473042278e-05 at epoch: 60
Loss: 2.6662049992148682e-05 at epoch: 60


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.614296498261996e-05 at epoch: 61
Loss: 2.666135040674047e-05 at epoch: 61


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6028460021621448e-05 at epoch: 62
Loss: 2.692049275336943e-05 at epoch: 62


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6098654913346403e-05 at epoch: 63
Loss: 2.6800608414973114e-05 at epoch: 63


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6073447405492665e-05 at epoch: 64
Loss: 2.6728920953269336e-05 at epoch: 64


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.6128621586776677e-05 at epoch: 65
Loss: 2.6827476191134934e-05 at epoch: 65


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5953004869673336e-05 at epoch: 66
Loss: 2.665471686701048e-05 at epoch: 66


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5995153590712042e-05 at epoch: 67
Loss: 2.6600412759857353e-05 at epoch: 67


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.600426891241147e-05 at epoch: 68
Loss: 2.659769781586743e-05 at epoch: 68


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.60245704802993e-05 at epoch: 69
Loss: 2.6591435631679506e-05 at epoch: 69


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.604120954641481e-05 at epoch: 70
Loss: 2.6614042100329102e-05 at epoch: 70


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.599227691403282e-05 at epoch: 71
Loss: 2.6857018345947734e-05 at epoch: 71


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.594534684322743e-05 at epoch: 72
Loss: 2.793451989302145e-05 at epoch: 72


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.602406519152716e-05 at epoch: 73
Loss: 2.6673888910310035e-05 at epoch: 73


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5940537854907055e-05 at epoch: 74
Loss: 2.651879385770869e-05 at epoch: 74


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5941914488258513e-05 at epoch: 75
Loss: 2.655582054425351e-05 at epoch: 75


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5962192028274632e-05 at epoch: 76
Loss: 2.670432408932353e-05 at epoch: 76


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5898521912280834e-05 at epoch: 77
Loss: 2.6520940726932313e-05 at epoch: 77


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.592593771056919e-05 at epoch: 78
Loss: 2.6504455917308967e-05 at epoch: 78


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5935153577926318e-05 at epoch: 79
Loss: 2.6524783766061443e-05 at epoch: 79


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5909527544354892e-05 at epoch: 80
Loss: 2.6496913617973953e-05 at epoch: 80


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5994347458450004e-05 at epoch: 81
Loss: 2.653189019469666e-05 at epoch: 81


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.583198951566027e-05 at epoch: 82
Loss: 2.6705582601627395e-05 at epoch: 82


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.589810247305727e-05 at epoch: 83
Loss: 2.648120664405761e-05 at epoch: 83


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.586461190736854e-05 at epoch: 84
Loss: 2.665225390285117e-05 at epoch: 84


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5933182785683627e-05 at epoch: 85
Loss: 2.6651284335242992e-05 at epoch: 85


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5881852090356175e-05 at epoch: 86
Loss: 2.661279153573793e-05 at epoch: 86


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.589680808857429e-05 at epoch: 87
Loss: 2.6423999568167967e-05 at epoch: 87


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5867832222287392e-05 at epoch: 88
Loss: 2.6499915213024356e-05 at epoch: 88


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5810364751180772e-05 at epoch: 89
Loss: 2.6425480083023698e-05 at epoch: 89


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5796852148924214e-05 at epoch: 90
Loss: 2.640341411217757e-05 at epoch: 90


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5928071036844596e-05 at epoch: 91
Loss: 2.6434845315614654e-05 at epoch: 91


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5811122751018237e-05 at epoch: 92
Loss: 2.64136167083102e-05 at epoch: 92


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5852253553126028e-05 at epoch: 93
Loss: 2.6408194581351806e-05 at epoch: 93


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5720330918948918e-05 at epoch: 94
Loss: 2.6382158327962936e-05 at epoch: 94


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5818049439329196e-05 at epoch: 95
Loss: 2.637952443464037e-05 at epoch: 95


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.57633755858803e-05 at epoch: 96
Loss: 2.647910135363966e-05 at epoch: 96


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5791792716633385e-05 at epoch: 97
Loss: 2.6506915147643524e-05 at epoch: 97


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5816786807510317e-05 at epoch: 98
Loss: 2.6341431501096397e-05 at epoch: 98


HBox(children=(FloatProgress(value=0.0, max=272.0), HTML(value='')))


Loss: 2.5821615212831503e-05 at epoch: 99
Loss: 2.836267341713486e-05 at epoch: 99

