In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as T
import os

from DatasetCH import UpscaleDataset
#from models import *
import Network

In [2]:
# Make dirs
mdir="./Model_dif/Test_1"
rdir="./Results_dif/Test_1"
os.makedirs(mdir, exist_ok=True)
os.makedirs(rdir, exist_ok=True)

# Define the tensorboard writer
writer = SummaryWriter("./Runs_dif/Test_1") # was runs_unet

In [3]:
import sys
sys.path.append('/home/mpyrina/Notebooks/ANEMOI/ClimateDiffuse/src/')
from DatasetCH import *
from TrainDiffusion import *
#from TrainUnet import *

In [4]:
import Evaluation

### TRAIN DIFFUSION

In [7]:
batch_size = 16
learning_rate = 1e-5
num_epochs = 10
accum = 4

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# a tensor of shape [B, C, H, W] mean that c=8, image resol=(H, W) 

network = Network.EDMPrecond(
        img_resolution=(256, 128),
        in_channels=2,
        out_channels=1,
        label_dim=0
    ).to(device)

# define the datasets
ifs_dir = '/s2s/mpyrina/ECMWF_MCH/Europe_eval/s2s_hind_2022/all/'
obs_dir = '/net/cfc/s2s_nobackup/mpyrina/TABSD_ifs_like/'

dataset_train = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2002, year_end=2012, month=815,  
constant_variables=None, constant_variables_filename=None)

dataset_test = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2012, year_end=2015, month=815,  
constant_variables=None, constant_variables_filename=None)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)

Test - new upscale
Loaded coarse data shape: (460, 11, 16, 32)
Loaded high-resolution data shape: (460, 128, 256)
Final coarse shape: torch.Size([5060, 1, 16, 32])
Final fine shape: torch.Size([5060, 1, 128, 256])
Input shape (should be [N, 1, H, W]): torch.Size([5060, 1, 128, 256])
Dataset ready.
Test - new upscale
Loaded coarse data shape: (138, 11, 16, 32)
Loaded high-resolution data shape: (138, 128, 256)
Final coarse shape: torch.Size([1518, 1, 16, 32])
Final fine shape: torch.Size([1518, 1, 128, 256])
Input shape (should be [N, 1, H, W]): torch.Size([1518, 1, 128, 256])
Dataset ready.


In [None]:
# Train
scaler = torch.cuda.amp.GradScaler()
optimiser = torch.optim.AdamW(network.parameters(), lr=learning_rate)
writer = SummaryWriter("./runs_dif")

loss_fn = EDMLoss()
losses = []


for step in range(num_epochs):
    # model_save
    model_save_path = f"{mdir}/dif_model_epoch_{step}.pt"
    # fig_save
    fig_save_path = f"{rdir}/{step}.png"
    # best modes
    mbest = f"{mdir}/best_dif_model_epoch_{step}.pt"

    epoch_loss = training_step(network, loss_fn, optimiser,
                                   dataloader_train, scaler, step,
                                   accum, writer, device=device)
    losses.append(epoch_loss)

    torch.save(network.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    if step % 5 == 0:
        (fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model_dif(network, dataloader_test, device=device)
        plt.show()
        fig.savefig(fig_save_path, dpi=300)
        plt.close(fig)
        writer.add_scalar("Error/base", base_error, step)
        writer.add_scalar("Error/pred", pred_error, step)

    if losses[-1] == min(losses):
        torch.save(network.state_dict(), mbest)


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   1%|          | 2/317 [01:08<3:29:27, 39.90s/it, Loss: 14.7726]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   1%|          | 3/317 [02:15<4:33:27, 52.25s/it, Loss: 17.7021]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   1%|▏         | 4/317 [03:23<5:05:09, 58.50s/it, Loss: 9.0066] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   2%|▏         | 5/317 [04:31<5:22:55, 62.10s/it, Loss: 13.6548]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   2%|▏         | 6/317 [05:39<5:32:29, 64.15s/it, Loss: 14.0677]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   2%|▏         | 7/317 [06:46<5:36:38, 65.16s/it, Loss: 13.4379]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   3%|▎         | 8/317 [07:57<5:44:10, 66.83s/it, Loss: 21.1464]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   3%|▎         | 9/317 [09:08<5:50:45, 68.33s/it, Loss: 10.9940]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   3%|▎         | 10/317 [10:18<5:51:06, 68.62s/it, Loss: 15.7088]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   3%|▎         | 11/317 [11:26<5:49:18, 68.49s/it, Loss: 11.9244]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   4%|▍         | 12/317 [12:36<5:49:53, 68.83s/it, Loss: 15.3637]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   4%|▍         | 13/317 [13:43<5:46:20, 68.36s/it, Loss: 13.0069]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   4%|▍         | 14/317 [14:50<5:43:44, 68.07s/it, Loss: 14.4735]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   5%|▍         | 15/317 [15:59<5:43:51, 68.32s/it, Loss: 19.2428]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   5%|▌         | 16/317 [17:08<5:44:14, 68.62s/it, Loss: 12.8029]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   5%|▌         | 17/317 [18:16<5:40:54, 68.18s/it, Loss: 20.0510]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   6%|▌         | 18/317 [19:21<5:34:56, 67.21s/it, Loss: 16.0737]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   6%|▌         | 19/317 [20:30<5:37:08, 67.88s/it, Loss: 12.1173]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   6%|▋         | 20/317 [21:39<5:37:03, 68.09s/it, Loss: 12.3336]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   7%|▋         | 21/317 [22:49<5:39:21, 68.79s/it, Loss: 14.0126]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   7%|▋         | 22/317 [23:57<5:37:40, 68.68s/it, Loss: 9.6742] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   7%|▋         | 23/317 [25:05<5:35:19, 68.43s/it, Loss: 21.9942]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   8%|▊         | 24/317 [26:12<5:31:12, 67.82s/it, Loss: 22.6765]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   8%|▊         | 25/317 [27:25<5:37:53, 69.43s/it, Loss: 19.1685]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   8%|▊         | 26/317 [28:33<5:34:44, 69.02s/it, Loss: 13.0740]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   9%|▊         | 27/317 [29:40<5:31:00, 68.48s/it, Loss: 18.6605]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   9%|▉         | 28/317 [30:45<5:24:13, 67.31s/it, Loss: 15.8955]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   9%|▉         | 29/317 [31:54<5:26:37, 68.05s/it, Loss: 10.2418]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   9%|▉         | 30/317 [33:03<5:26:03, 68.17s/it, Loss: 18.0066]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  10%|▉         | 31/317 [34:10<5:22:57, 67.75s/it, Loss: 7.5847] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  10%|█         | 32/317 [35:17<5:21:47, 67.75s/it, Loss: 23.9806]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  10%|█         | 33/317 [36:24<5:18:30, 67.29s/it, Loss: 21.3802]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  11%|█         | 34/317 [37:31<5:17:09, 67.24s/it, Loss: 23.5859]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  11%|█         | 35/317 [38:40<5:18:19, 67.73s/it, Loss: 20.4009]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  11%|█▏        | 36/317 [39:48<5:17:27, 67.78s/it, Loss: 17.9800]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  12%|█▏        | 37/317 [40:56<5:17:16, 67.99s/it, Loss: 25.8669]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  12%|█▏        | 38/317 [42:00<5:10:47, 66.84s/it, Loss: 22.4715]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  12%|█▏        | 39/317 [43:08<5:10:18, 66.97s/it, Loss: 4.4828] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  13%|█▎        | 40/317 [44:15<5:10:17, 67.21s/it, Loss: 17.3503]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  13%|█▎        | 41/317 [45:25<5:12:27, 67.93s/it, Loss: 12.3058]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  13%|█▎        | 42/317 [46:32<5:10:41, 67.79s/it, Loss: 14.6911]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  14%|█▎        | 43/317 [47:37<5:04:44, 66.73s/it, Loss: 7.0037] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  14%|█▍        | 44/317 [48:42<5:02:00, 66.38s/it, Loss: 19.7555]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  14%|█▍        | 45/317 [49:51<5:03:40, 66.99s/it, Loss: 7.1802] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  15%|█▍        | 46/317 [50:57<5:02:04, 66.88s/it, Loss: 13.7736]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  15%|█▍        | 47/317 [52:03<5:00:07, 66.70s/it, Loss: 9.1068] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  15%|█▌        | 48/317 [53:09<4:57:15, 66.30s/it, Loss: 17.1751]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  15%|█▌        | 49/317 [54:17<4:58:01, 66.72s/it, Loss: 11.3448]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  16%|█▌        | 50/317 [55:25<4:58:58, 67.18s/it, Loss: 17.1650]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  16%|█▌        | 51/317 [56:31<4:56:47, 66.94s/it, Loss: 10.0578]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  16%|█▋        | 52/317 [57:37<4:54:45, 66.74s/it, Loss: 19.0407]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  17%|█▋        | 53/317 [58:44<4:52:53, 66.57s/it, Loss: 9.1244] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  17%|█▋        | 54/317 [59:48<4:48:45, 65.88s/it, Loss: 13.5276]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  17%|█▋        | 55/317 [1:00:57<4:51:26, 66.74s/it, Loss: 12.0854]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  18%|█▊        | 56/317 [1:02:03<4:50:10, 66.71s/it, Loss: 18.0892]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  18%|█▊        | 57/317 [1:03:10<4:48:53, 66.67s/it, Loss: 9.2298] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  18%|█▊        | 58/317 [1:04:17<4:48:21, 66.80s/it, Loss: 11.2091]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  19%|█▊        | 59/317 [1:05:25<4:49:08, 67.24s/it, Loss: 13.4422]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  19%|█▉        | 60/317 [1:06:32<4:47:10, 67.04s/it, Loss: 12.4556]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  19%|█▉        | 61/317 [1:07:41<4:48:25, 67.60s/it, Loss: 20.6386]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  20%|█▉        | 62/317 [1:08:50<4:49:52, 68.21s/it, Loss: 24.3729]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  20%|█▉        | 63/317 [1:09:58<4:47:48, 67.99s/it, Loss: 15.7266]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  20%|██        | 64/317 [1:11:03<4:42:40, 67.04s/it, Loss: 6.4368] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  21%|██        | 65/317 [1:12:11<4:42:50, 67.34s/it, Loss: 11.6163]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  21%|██        | 66/317 [1:13:18<4:42:10, 67.45s/it, Loss: 19.2666]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  21%|██        | 67/317 [1:14:27<4:42:29, 67.80s/it, Loss: 25.8656]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  21%|██▏       | 68/317 [1:15:33<4:38:46, 67.17s/it, Loss: 22.2962]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  22%|██▏       | 69/317 [1:16:39<4:36:47, 66.96s/it, Loss: 9.8405] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  22%|██▏       | 70/317 [1:17:43<4:32:21, 66.16s/it, Loss: 17.8335]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  22%|██▏       | 71/317 [1:18:52<4:34:20, 66.91s/it, Loss: 19.3842]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  23%|██▎       | 72/317 [1:19:59<4:32:58, 66.85s/it, Loss: 14.5910]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  23%|██▎       | 73/317 [1:21:06<4:32:44, 67.07s/it, Loss: 11.8605]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  23%|██▎       | 74/317 [1:22:14<4:32:30, 67.28s/it, Loss: 8.5332] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  24%|██▎       | 75/317 [1:23:21<4:31:19, 67.27s/it, Loss: 9.3108]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  24%|██▍       | 76/317 [1:24:26<4:26:49, 66.43s/it, Loss: 10.1472]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  24%|██▍       | 77/317 [1:25:31<4:24:21, 66.09s/it, Loss: 26.5868]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  25%|██▍       | 78/317 [1:26:40<4:26:34, 66.92s/it, Loss: 17.7342]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  25%|██▍       | 79/317 [1:27:50<4:28:49, 67.77s/it, Loss: 20.3350]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  25%|██▌       | 80/317 [1:28:55<4:25:07, 67.12s/it, Loss: 24.3841]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  26%|██▌       | 81/317 [1:30:03<4:23:59, 67.12s/it, Loss: 20.7659]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  26%|██▌       | 82/317 [1:31:10<4:23:30, 67.28s/it, Loss: 19.7803]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  26%|██▌       | 83/317 [1:32:15<4:19:54, 66.64s/it, Loss: 11.4795]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  26%|██▋       | 84/317 [1:33:24<4:21:13, 67.27s/it, Loss: 8.9900] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  27%|██▋       | 85/317 [1:34:30<4:18:11, 66.78s/it, Loss: 8.8945]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  27%|██▋       | 86/317 [1:35:36<4:16:46, 66.69s/it, Loss: 20.7767]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  27%|██▋       | 87/317 [1:36:45<4:18:08, 67.34s/it, Loss: 17.9931]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  28%|██▊       | 88/317 [1:37:51<4:15:44, 67.01s/it, Loss: 10.5584]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  28%|██▊       | 89/317 [1:38:57<4:13:17, 66.66s/it, Loss: 12.0679]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  28%|██▊       | 90/317 [1:40:02<4:10:42, 66.27s/it, Loss: 15.4826]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  29%|██▊       | 91/317 [1:41:10<4:11:15, 66.71s/it, Loss: 16.7835]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  29%|██▉       | 92/317 [1:42:21<4:14:21, 67.83s/it, Loss: 13.9905]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  29%|██▉       | 93/317 [1:43:30<4:15:07, 68.34s/it, Loss: 16.7528]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  30%|██▉       | 94/317 [1:44:38<4:13:38, 68.25s/it, Loss: 25.4948]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  30%|██▉       | 95/317 [1:45:44<4:09:52, 67.53s/it, Loss: 18.6702]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  30%|███       | 96/317 [1:46:50<4:07:28, 67.19s/it, Loss: 15.7103]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  31%|███       | 97/317 [1:48:05<4:14:39, 69.45s/it, Loss: 23.7323]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  31%|███       | 98/317 [1:49:32<4:31:59, 74.52s/it, Loss: 15.9470]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  31%|███       | 99/317 [1:51:01<4:46:56, 78.98s/it, Loss: 12.0914]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  32%|███▏      | 100/317 [1:52:27<4:52:55, 80.99s/it, Loss: 12.0065]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  32%|███▏      | 101/317 [1:53:51<4:55:15, 82.02s/it, Loss: 10.6123]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  32%|███▏      | 102/317 [1:55:17<4:57:44, 83.09s/it, Loss: 16.1263]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  32%|███▏      | 103/317 [1:56:45<5:02:15, 84.75s/it, Loss: 15.3483]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  33%|███▎      | 104/317 [1:58:13<5:04:13, 85.70s/it, Loss: 10.3199]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  33%|███▎      | 105/317 [1:59:42<5:05:57, 86.59s/it, Loss: 16.0138]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  33%|███▎      | 106/317 [2:01:07<5:03:25, 86.28s/it, Loss: 11.2391]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  34%|███▍      | 107/317 [2:02:36<5:04:46, 87.08s/it, Loss: 17.4685]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  34%|███▍      | 108/317 [2:04:03<5:02:48, 86.93s/it, Loss: 15.8208]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  34%|███▍      | 109/317 [2:05:32<5:03:15, 87.48s/it, Loss: 22.1653]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  35%|███▍      | 110/317 [2:06:57<4:59:37, 86.85s/it, Loss: 10.6398]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  35%|███▌      | 111/317 [2:08:25<4:59:24, 87.20s/it, Loss: 30.0802]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  35%|███▌      | 112/317 [2:09:55<5:00:25, 87.93s/it, Loss: 15.5429]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  36%|███▌      | 113/317 [2:11:26<5:02:10, 88.88s/it, Loss: 22.0721]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  36%|███▌      | 114/317 [2:12:50<4:55:37, 87.38s/it, Loss: 14.9921]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  36%|███▋      | 115/317 [2:14:11<4:48:11, 85.60s/it, Loss: 13.5146]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  37%|███▋      | 116/317 [2:15:35<4:44:50, 85.03s/it, Loss: 25.2906]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  37%|███▋      | 117/317 [2:17:06<4:49:46, 86.93s/it, Loss: 11.1349]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  37%|███▋      | 118/317 [2:18:31<4:46:03, 86.25s/it, Loss: 14.7020]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  38%|███▊      | 119/317 [2:19:56<4:43:41, 85.96s/it, Loss: 19.1101]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  38%|███▊      | 120/317 [2:21:20<4:40:03, 85.30s/it, Loss: 20.2528]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  38%|███▊      | 121/317 [2:22:50<4:43:18, 86.73s/it, Loss: 14.6734]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  38%|███▊      | 122/317 [2:24:16<4:41:38, 86.66s/it, Loss: 9.0016] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  39%|███▉      | 123/317 [2:25:43<4:40:12, 86.66s/it, Loss: 7.3073]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  39%|███▉      | 124/317 [2:27:09<4:38:16, 86.51s/it, Loss: 14.9673]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  39%|███▉      | 125/317 [2:28:35<4:36:30, 86.41s/it, Loss: 10.7241]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  40%|███▉      | 126/317 [2:30:04<4:36:39, 86.91s/it, Loss: 10.4752]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  40%|████      | 127/317 [2:31:29<4:34:16, 86.61s/it, Loss: 16.5180]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  40%|████      | 128/317 [2:32:56<4:32:35, 86.54s/it, Loss: 9.7697] 

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  41%|████      | 129/317 [2:34:22<4:31:00, 86.49s/it, Loss: 6.9371]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  41%|████      | 130/317 [2:35:49<4:30:02, 86.64s/it, Loss: 14.4831]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  41%|████▏     | 131/317 [2:37:17<4:29:51, 87.05s/it, Loss: 13.8962]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  42%|████▏     | 132/317 [2:38:46<4:29:59, 87.56s/it, Loss: 16.9237]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  42%|████▏     | 133/317 [2:40:12<4:27:08, 87.11s/it, Loss: 12.7106]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:  42%|████▏     | 134/317 [2:41:36<4:23:09, 86.28s/it, Loss: 13.7865]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])


In [None]:
(fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model_dif(network, dataloader_test, device=device)
plt.show()

In [12]:
model_save_path = f"./Model_dif/dif_model_epoch_{step}.pt"
torch.save(network.state_dict(), model_save_path)
(f"Model saved to {model_save_path}")

'Model saved to ./Model_dif/dif_model_epoch_0.pt'

In [17]:
plt.show()

### unet only

In [None]:
# define the datasets
ifs_dir = '/s2s/mpyrina/ECMWF_MCH/Europe_eval/s2s_hind_2022/all/'
obs_dir = '/net/cfc/s2s_nobackup/mpyrina/TABSD_ifs_like/'

# Run training for small number of epochs 
num_epochs = 1
## Select hyperparameters of training
batch_size = 8
learning_rate = 1e-5
accum = 8

dataset_train = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2005, year_end=2008, month=815,  
constant_variables=None, constant_variables_filename=None)

dataset_test = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2009, year_end=2010, month=815,  
constant_variables=None, constant_variables_filename=None)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)

# Define device
device =  'cuda' if torch.cuda.is_available() else 'cpu'

# define the ml model : 1, 1, : 1 input var, one output
unet_model = UNet((256, 128), 1, 1, label_dim=0, use_diffuse=True)
unet_model.to(device)


In [None]:
#
scaler = torch.cuda.amp.GradScaler()

# define the optimiser
optimiser = torch.optim.AdamW(unet_model.parameters(), lr=learning_rate)

# Define the tensorboard writer
writer = SummaryWriter("./runs_unet")

loss_fn = torch.nn.MSELoss()

# train the model
losses = []
for step in range(num_epochs):
    epoch_loss = train_step(
        unet_model, loss_fn, dataloader_train, optimiser,
        scaler, step, accum, writer, device=device)
    losses.append(epoch_loss)

    # Save the model weights
    model_save_path = f"./Model/dif_model_epoch_{step}.pt"
    torch.save(unet_model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    (fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model(
        unet_model, dataloader_test, device=device)
    plt.show()
    fig.savefig(f"./results_unet/{step}.png", dpi=300)
    plt.close(fig)


    writer.add_scalar("Error/base", base_error, step)
    writer.add_scalar("Error/pred", pred_error, step)

    # save the model
    if losses[-1] == min(losses):
        torch.save(unet_model.state_dict(), f"./Model/Models_dif/best_unet_model_epoch_{step}.pt")
