In [1]:
!pip install segmentation_models_pytorch 
!pip install neptune
!pip install torchmetrics
!pip install albumentations 

[0m

In [2]:
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
import numpy as np
import torch.nn as nn
import time
import copy
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import segmentation_models_pytorch as smp
import neptune
from neptune.types import File

from torchmetrics.functional.classification import dice as calc_dice_score
from torchmetrics.classification import BinaryJaccardIndex

import albumentations as A
import cv2

In [3]:
# Define a custom dataset class
# Training params
BATCH_SIZE = 13
EPOCHS = 110
LEARNING_RATE = 0.0000367

# Model params
ENCODER_NAME = "resnet34"
ENCODER_WEIGHTS = "imagenet"

with open("/notebooks/Testing/NEPTUNE_API_TOKEN.txt", "r") as file:
    # Read the entire content of the file into a string
    token = file.read()

run = neptune.init_run(
    project="Kernel-bois/computer-vision",
    api_token=token,
)
run_id = run["sys/id"].fetch()

# Create the model
model = smp.Unet(
    encoder_name=ENCODER_NAME,           # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=ENCODER_WEIGHTS,     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                            # model output channels (number of classes in your dataset)
    )


# Define optimizer
optimizer = optim.RMSprop(model.parameters(), lr=LEARNING_RATE)
#Scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

run_name = "MODEL-" + model.__class__.__name__ + ENCODER_NAME + str(run_id)

save_path = str(run_id) + "/"
os.makedirs(save_path)

# Proper directories
TRAIN_DATA_DIR = '/notebooks/image_segmentation/network/image_data_all/train'
VAL_DATA_DIR = '/notebooks/image_segmentation/network/image_data_all/val'

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE



# Define loss function
# criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=False)  # Binary dice Loss for binary segmentation
# criterion = smp.losses.SoftBCEWithLogitsLoss()  # Binary dice Loss for binary segmentation
criterion = DiceBCELoss()  # Binary dice Loss for binary segmentation
calc_iou = BinaryJaccardIndex().to(device)


params = {
    "MODEL": model.__class__.__name__,
    "BACKBONE": ENCODER_NAME,
    "ENCODER_WEIGHTS": ENCODER_WEIGHTS,
    "BATCH_SIZE": str(BATCH_SIZE),
    "EPOCHS": str(EPOCHS),
    "CRITERION": criterion.__class__.__name__,
    "OPTIMIZER": optimizer.__class__.__name__,
    "LEARNRATE": str(LEARNING_RATE),
    "MODEL_NAME": run_name,
}

run["params"] = params



  run = neptune.init_run(


https://app.neptune.ai/Kernel-bois/computer-vision/e/CV-151


In [4]:
class SegmentationDataset(Dataset):
    def __init__(self, root_dir, transform = None, target_size = (992, 416)):
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size

        self.image_folder = os.path.join(root_dir, 'images')
        self.mask_folder = os.path.join(root_dir, 'masks')

        self.images = os.listdir(self.image_folder)
        self.masks = os.listdir(self.mask_folder)

        assert len(self.images) == len(self.masks), "Number of images and masks should be the same."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, self.images[idx])
        if "patient" in self.images[idx]:
            mask_path = os.path.join(self.mask_folder, "segmentation_" + self.images[idx][-7:])
        else:
            mask_path = os.path.join(self.mask_folder, "target_seg_" + self.images[idx][-7:])
        
        # Load images
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path)
        
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # Convert to tensors 
        tensor_image = torch.from_numpy(image)
        tensor_image = tensor_image.permute(2, 0, 1)

        tensor_mask = torch.from_numpy(mask)
        tensor_mask = tensor_mask.permute(2, 0, 1) / 255
        tensor_mask = tensor_mask[2:, :, :]
        
        # add padding
        pad_height = max(self.target_size[0] - tensor_image.size(1), 0)
        pad_width = max(self.target_size[1] - tensor_image.size(2), 0)

        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        padded_image = transforms.functional.pad(tensor_image, (pad_left, pad_bottom, pad_right, pad_top), fill=255)
        padded_mask = transforms.functional.pad(tensor_mask, (pad_left, pad_bottom, pad_right, pad_top), fill=0)

        return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)


In [5]:
model = model.to(device)

# Set up dataset and dataloader
transform = A.Compose([])

#transform = A.Compose([A.GaussNoise(p = 0.2), A.RandomGamma(p = 0.2), A.Sharpen(p=0.2),
#                       A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 0.2)])
""" 


"""
# ALL Transforms
# A.CLAHE(p=0.2)
# A.HorizontalFlip(p=0.5),
# A.RandomGamma(p=0.2)
# A.GridDistortion(p=0.2)
# A.RandomBrightnessContrast(p=0.2)
# A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 0.2)
# A.OneOf([ ],p=0.9) for more 
# A.Sharpen(p=0.2)
# A.Blur(p=0.2)
# A.RandomCrop(height = 200, width=200, p=0.2)
# GaussNoise(p = 0.2)

trainDataset = SegmentationDataset(root_dir=TRAIN_DATA_DIR, transform=transform)
valDataset = SegmentationDataset(root_dir=VAL_DATA_DIR)

train_loader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(valDataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)


In [None]:
# Train Loop
for epoch in range(EPOCHS):
    train_loss = torch.tensor(0.0)
    model.train()

    # Use tqdm to add a progress bar
    for images, masks in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{EPOCHS}', leave=False):
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)
        # outputs = torch.argmax(outputs, dim=1).unsqueeze(1).float()

        loss = criterion(outputs, masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        train_loss += loss.detach().cpu()
        optimizer.step()

    train_loss /= len(train_loader)
    run["loss/train_loss"].log(train_loss)

    # Validation
    model.eval()

    val_loss = torch.tensor(0.0)
    iou = torch.tensor(0.0)
    dice_score = torch.tensor(0.0)

    with torch.no_grad():
        for val_images, val_masks in tqdm(val_loader, desc=f'Validation', leave=False):
            val_images, val_masks = val_images.to(device), val_masks.to(device)

            model_outputs = model(val_images)

            val_loss += criterion(model_outputs, val_masks).cpu()

            val_masks_int = torch.tensor(val_masks, dtype=torch.int8)
            dice_score += calc_dice_score(torch.sigmoid(model_outputs), val_masks_int, ignore_index=0).cpu()

            iou += calc_iou(model_outputs, val_masks_int).cpu()


    val_loss /= len(val_loader)
    iou /= len(val_loader)
    dice_score /= len(val_loader)

    if torch.isnan(iou):
        iou = torch.tensor(0.0)

    run["loss/val_loss"].log(val_loss)
    run["val/iou"].log(iou)
    run["val/dice_score"].log(dice_score)

    torch.save(model.state_dict(), save_path + run_name + "_EPOCH_" + str(epoch) + '.pth')

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train Loss: {train_loss}, Validation Loss: {val_loss}\n"
          f"IOU: {iou}, Dice Score: {dice_score}")

# Save the trained model
torch.save(model.state_dict(), save_path + run_name + "_FINAL" + '.pth')
run[f"network/network_weights"].upload(File(run_name + '.pth'))

run.stop()


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  val_masks_int = torch.tensor(val_masks, dtype=torch.int8)
                                                          

Epoch [1/110], Train Loss: 1.8247190713882446, Validation Loss: 1.7104558944702148
IOU: 0.01065882109105587, Dice Score: 0.02095317840576172


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [2/110], Train Loss: 1.6140100955963135, Validation Loss: 1.5243184566497803
IOU: 0.020627550780773163, Dice Score: 0.03998257964849472


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [3/110], Train Loss: 1.5249232053756714, Validation Loss: 1.5278584957122803
IOU: 0.06291592121124268, Dice Score: 0.11481727659702301


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [4/110], Train Loss: 1.4748064279556274, Validation Loss: 1.4663931131362915
IOU: 0.10031600296497345, Dice Score: 0.1770683079957962


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [5/110], Train Loss: 1.4168775081634521, Validation Loss: 1.4175342321395874
IOU: 0.12665539979934692, Dice Score: 0.21781054139137268


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [6/110], Train Loss: 1.3917344808578491, Validation Loss: 1.3803331851959229
IOU: 0.18144741654396057, Dice Score: 0.29516810178756714


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [7/110], Train Loss: 1.3679026365280151, Validation Loss: 1.3558037281036377
IOU: 0.20357978343963623, Dice Score: 0.3252891004085541


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [8/110], Train Loss: 1.345407485961914, Validation Loss: 1.3571960926055908
IOU: 0.1812358796596527, Dice Score: 0.2951108515262604


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [9/110], Train Loss: 1.3215934038162231, Validation Loss: 1.3237907886505127
IOU: 0.17658288776874542, Dice Score: 0.2894582450389862


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [10/110], Train Loss: 1.3005708456039429, Validation Loss: 1.4328480958938599
IOU: 0.2024664580821991, Dice Score: 0.3116248846054077


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [11/110], Train Loss: 1.287680983543396, Validation Loss: 1.2779290676116943
IOU: 0.24960562586784363, Dice Score: 0.3782554566860199


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [12/110], Train Loss: 1.2711519002914429, Validation Loss: 1.2647461891174316
IOU: 0.3075675368309021, Dice Score: 0.4430972635746002


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [13/110], Train Loss: 1.2578096389770508, Validation Loss: 1.252870798110962
IOU: 0.31852924823760986, Dice Score: 0.45545274019241333


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [14/110], Train Loss: 1.2405598163604736, Validation Loss: 1.2370259761810303
IOU: 0.3917141854763031, Dice Score: 0.5323654413223267


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [15/110], Train Loss: 1.2232125997543335, Validation Loss: 1.245605230331421
IOU: 0.35608261823654175, Dice Score: 0.4933839738368988


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [16/110], Train Loss: 1.2091789245605469, Validation Loss: 1.207174301147461
IOU: 0.38523241877555847, Dice Score: 0.524076521396637


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [17/110], Train Loss: 1.2031294107437134, Validation Loss: 1.4368808269500732
IOU: 0.13454392552375793, Dice Score: 0.22971530258655548


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [18/110], Train Loss: 1.1892164945602417, Validation Loss: 1.1883124113082886
IOU: 0.3551744520664215, Dice Score: 0.49328580498695374


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [19/110], Train Loss: 1.1733943223953247, Validation Loss: 1.2414228916168213
IOU: 0.060035042464733124, Dice Score: 0.11144056171178818


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [20/110], Train Loss: 1.1680433750152588, Validation Loss: 1.1720243692398071
IOU: 0.2918093800544739, Dice Score: 0.42608633637428284


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [21/110], Train Loss: 1.1535311937332153, Validation Loss: 1.1602048873901367
IOU: 0.375556617975235, Dice Score: 0.5147019624710083


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [22/110], Train Loss: 1.1413488388061523, Validation Loss: 1.1476166248321533
IOU: 0.4241347312927246, Dice Score: 0.5607422590255737


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [23/110], Train Loss: 1.159265160560608, Validation Loss: 1.246962308883667
IOU: 0.13673552870750427, Dice Score: 0.2323029488325119


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [24/110], Train Loss: 1.1286687850952148, Validation Loss: 1.130663514137268
IOU: 0.39009517431259155, Dice Score: 0.527160108089447


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [25/110], Train Loss: 1.115174651145935, Validation Loss: 1.1244146823883057
IOU: 0.2980736196041107, Dice Score: 0.4368125796318054


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [26/110], Train Loss: 1.1068310737609863, Validation Loss: 1.1153748035430908
IOU: 0.32781094312667847, Dice Score: 0.4676799774169922


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [27/110], Train Loss: 1.0984923839569092, Validation Loss: 1.1077708005905151
IOU: 0.4008568823337555, Dice Score: 0.539330780506134


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [28/110], Train Loss: 1.0895898342132568, Validation Loss: 1.0999451875686646
IOU: 0.3959546387195587, Dice Score: 0.5321620106697083


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [29/110], Train Loss: 1.0808212757110596, Validation Loss: 1.0923852920532227
IOU: 0.41688767075538635, Dice Score: 0.5533701181411743


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [30/110], Train Loss: 1.0742923021316528, Validation Loss: 1.0865488052368164
IOU: 0.38199684023857117, Dice Score: 0.5224844217300415


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [31/110], Train Loss: 1.065133810043335, Validation Loss: 1.0805078744888306
IOU: 0.4304068982601166, Dice Score: 0.5634545087814331


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [32/110], Train Loss: 1.055924654006958, Validation Loss: 1.072073221206665
IOU: 0.39567962288856506, Dice Score: 0.5327767133712769


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [33/110], Train Loss: 1.0493476390838623, Validation Loss: 1.069801688194275
IOU: 0.4019598066806793, Dice Score: 0.539448082447052


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [34/110], Train Loss: 1.042913794517517, Validation Loss: 1.066338300704956
IOU: 0.0, Dice Score: 0.5250110030174255


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [35/110], Train Loss: 1.0356889963150024, Validation Loss: 1.0545575618743896
IOU: 0.0, Dice Score: 0.5510304570198059


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [36/110], Train Loss: 1.0297703742980957, Validation Loss: 1.2611525058746338
IOU: 0.0, Dice Score: 0.33567577600479126


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [37/110], Train Loss: 1.0263772010803223, Validation Loss: 1.0448919534683228
IOU: 0.41463595628738403, Dice Score: 0.5536563992500305


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [38/110], Train Loss: 1.0139355659484863, Validation Loss: 1.035209059715271
IOU: 0.3800840973854065, Dice Score: 0.5212694406509399


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [39/110], Train Loss: 1.003511667251587, Validation Loss: 1.0274012088775635
IOU: 0.39761000871658325, Dice Score: 0.5369656682014465


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [40/110], Train Loss: 0.9978570342063904, Validation Loss: 1.0214585065841675
IOU: 0.45416775345802307, Dice Score: 0.5872234106063843


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [41/110], Train Loss: 0.9909974932670593, Validation Loss: 1.0141972303390503
IOU: 0.43080735206604004, Dice Score: 0.5674703121185303


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [42/110], Train Loss: 0.9822422862052917, Validation Loss: 1.00824773311615
IOU: 0.4160084128379822, Dice Score: 0.5549716353416443


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                          

Epoch [43/110], Train Loss: 0.9704117774963379, Validation Loss: 1.0015987157821655
IOU: 0.0, Dice Score: 0.5896446704864502


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
Epoch 44/110:  83%|████████▎ | 35/42 [00:21<00:04,  1.72it/s]

In [None]:
"""
# Declare an augmentation pipeline
transform = A.Compose([
    A.RandomCrop(height = 200, width=200, p=1)
])

image_path = "/notebooks/image_segmentation/network/image_data/train/images/patient_116.png"
mask_path = "/notebooks/image_segmentation/network/image_data/train/masks/segmentation_116.png" 

image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path)

transformed = transform(image=image, mask=mask)
transformed_image = transformed['image']
transformed_mask = transformed['mask']

image_torch, mask_torch = convert_to_torch(image, mask)
plot_image_and_mask(image_torch, mask_torch)
    
image_torch, mask_torch = convert_to_torch(transformed_image, transformed_mask)
plot_image_and_mask(image_torch, mask_torch)
"""

In [7]:
def plot_image_and_mask(image, mask): 
    image = image/ 255.0
    mask = mask

    # Plot side by side with the mask and mask overlain
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Plot the original image
    axes[0].imshow(image.permute(1, 2, 0))
    axes[0].set_title('Original Image')

    # Plot the mask
    axes[1].imshow(mask.permute(1, 2, 0), cmap='viridis')
    axes[1].set_title('Mask')

    # Overlay the mask on the image
    axes[2].imshow(image.permute(1, 2, 0))
    axes[2].imshow(mask.permute(1, 2, 0), cmap='viridis', alpha=0.6)  # Set alpha to less than 1
    axes[2].set_title('Mask Overlain on Image')

    # Display the plots
    plt.show()

def convert_to_torch(image, mask):
    
    tensor_image = torch.from_numpy(image)
    tensor_image = tensor_image.permute(2,0, 1)

    tensor_mask = torch.from_numpy(mask)
    tensor_mask = tensor_mask.permute(2,0, 1) / 255
    
    tensor_mask = tensor_mask[2:, :, :]

    # add padding
    pad_height = max(992 - tensor_image.size(1), 0)
    pad_width = max(416 - tensor_image.size(2), 0)

    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left

    padded_image = transforms.functional.pad(tensor_image, (pad_left, pad_bottom, pad_right, pad_top), fill=255)
    padded_mask = transforms.functional.pad(tensor_mask, (pad_left, pad_bottom, pad_right, pad_top), fill=0)


    return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)

In [8]:
image_path = "/notebooks/image_segmentation/network/image_data/train/images/patient_116.png"
mask_path = "/notebooks/image_segmentation/network/image_data/train/masks/segmentation_116.png" 
model_path = "/notebooks/Testing/CV-131/MODEL-UnetPlusPlusresnet34CV-131_FINAL.pth" 

model = smp.UnetPlusPlus(
    encoder_name=ENCODER_NAME,           # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    in_channels=3,                        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                            # model output channels (number of classes in your dataset)
)
model.load_state_dict(torch.load(model_path))
model.eval()

def predict_image(model, image, mask): 

    if type(image) == str: 
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path)
    
        image, mask = convert_to_torch(image, mask)
    
    out = model(image.unsqueeze(0)) 

    segment_map = torch.sigmoid(out).squeeze(0).detach()
    segment_map = (segment_map > 0.5) * 1
    
    plot_image_and_mask(image, mask)
    plot_image_and_mask(image, segment_map)
    
#predict_image(model, image_path, mask_path)

RuntimeError: Error(s) in loading state_dict for UnetPlusPlus:
	Missing key(s) in state_dict: "encoder.stem.conv.weight", "encoder.stem.bn.weight", "encoder.stem.bn.bias", "encoder.stem.bn.running_mean", "encoder.stem.bn.running_var", "encoder.s1.b1.conv1.conv.weight", "encoder.s1.b1.conv1.bn.weight", "encoder.s1.b1.conv1.bn.bias", "encoder.s1.b1.conv1.bn.running_mean", "encoder.s1.b1.conv1.bn.running_var", "encoder.s1.b1.conv2.conv.weight", "encoder.s1.b1.conv2.bn.weight", "encoder.s1.b1.conv2.bn.bias", "encoder.s1.b1.conv2.bn.running_mean", "encoder.s1.b1.conv2.bn.running_var", "encoder.s1.b1.conv3.conv.weight", "encoder.s1.b1.conv3.bn.weight", "encoder.s1.b1.conv3.bn.bias", "encoder.s1.b1.conv3.bn.running_mean", "encoder.s1.b1.conv3.bn.running_var", "encoder.s1.b1.downsample.conv.weight", "encoder.s1.b1.downsample.bn.weight", "encoder.s1.b1.downsample.bn.bias", "encoder.s1.b1.downsample.bn.running_mean", "encoder.s1.b1.downsample.bn.running_var", "encoder.s1.b2.conv1.conv.weight", "encoder.s1.b2.conv1.bn.weight", "encoder.s1.b2.conv1.bn.bias", "encoder.s1.b2.conv1.bn.running_mean", "encoder.s1.b2.conv1.bn.running_var", "encoder.s1.b2.conv2.conv.weight", "encoder.s1.b2.conv2.bn.weight", "encoder.s1.b2.conv2.bn.bias", "encoder.s1.b2.conv2.bn.running_mean", "encoder.s1.b2.conv2.bn.running_var", "encoder.s1.b2.conv3.conv.weight", "encoder.s1.b2.conv3.bn.weight", "encoder.s1.b2.conv3.bn.bias", "encoder.s1.b2.conv3.bn.running_mean", "encoder.s1.b2.conv3.bn.running_var", "encoder.s2.b1.conv1.conv.weight", "encoder.s2.b1.conv1.bn.weight", "encoder.s2.b1.conv1.bn.bias", "encoder.s2.b1.conv1.bn.running_mean", "encoder.s2.b1.conv1.bn.running_var", "encoder.s2.b1.conv2.conv.weight", "encoder.s2.b1.conv2.bn.weight", "encoder.s2.b1.conv2.bn.bias", "encoder.s2.b1.conv2.bn.running_mean", "encoder.s2.b1.conv2.bn.running_var", "encoder.s2.b1.conv3.conv.weight", "encoder.s2.b1.conv3.bn.weight", "encoder.s2.b1.conv3.bn.bias", "encoder.s2.b1.conv3.bn.running_mean", "encoder.s2.b1.conv3.bn.running_var", "encoder.s2.b1.downsample.conv.weight", "encoder.s2.b1.downsample.bn.weight", "encoder.s2.b1.downsample.bn.bias", "encoder.s2.b1.downsample.bn.running_mean", "encoder.s2.b1.downsample.bn.running_var", "encoder.s2.b2.conv1.conv.weight", "encoder.s2.b2.conv1.bn.weight", "encoder.s2.b2.conv1.bn.bias", "encoder.s2.b2.conv1.bn.running_mean", "encoder.s2.b2.conv1.bn.running_var", "encoder.s2.b2.conv2.conv.weight", "encoder.s2.b2.conv2.bn.weight", "encoder.s2.b2.conv2.bn.bias", "encoder.s2.b2.conv2.bn.running_mean", "encoder.s2.b2.conv2.bn.running_var", "encoder.s2.b2.conv3.conv.weight", "encoder.s2.b2.conv3.bn.weight", "encoder.s2.b2.conv3.bn.bias", "encoder.s2.b2.conv3.bn.running_mean", "encoder.s2.b2.conv3.bn.running_var", "encoder.s2.b3.conv1.conv.weight", "encoder.s2.b3.conv1.bn.weight", "encoder.s2.b3.conv1.bn.bias", "encoder.s2.b3.conv1.bn.running_mean", "encoder.s2.b3.conv1.bn.running_var", "encoder.s2.b3.conv2.conv.weight", "encoder.s2.b3.conv2.bn.weight", "encoder.s2.b3.conv2.bn.bias", "encoder.s2.b3.conv2.bn.running_mean", "encoder.s2.b3.conv2.bn.running_var", "encoder.s2.b3.conv3.conv.weight", "encoder.s2.b3.conv3.bn.weight", "encoder.s2.b3.conv3.bn.bias", "encoder.s2.b3.conv3.bn.running_mean", "encoder.s2.b3.conv3.bn.running_var", "encoder.s2.b4.conv1.conv.weight", "encoder.s2.b4.conv1.bn.weight", "encoder.s2.b4.conv1.bn.bias", "encoder.s2.b4.conv1.bn.running_mean", "encoder.s2.b4.conv1.bn.running_var", "encoder.s2.b4.conv2.conv.weight", "encoder.s2.b4.conv2.bn.weight", "encoder.s2.b4.conv2.bn.bias", "encoder.s2.b4.conv2.bn.running_mean", "encoder.s2.b4.conv2.bn.running_var", "encoder.s2.b4.conv3.conv.weight", "encoder.s2.b4.conv3.bn.weight", "encoder.s2.b4.conv3.bn.bias", "encoder.s2.b4.conv3.bn.running_mean", "encoder.s2.b4.conv3.bn.running_var", "encoder.s2.b5.conv1.conv.weight", "encoder.s2.b5.conv1.bn.weight", "encoder.s2.b5.conv1.bn.bias", "encoder.s2.b5.conv1.bn.running_mean", "encoder.s2.b5.conv1.bn.running_var", "encoder.s2.b5.conv2.conv.weight", "encoder.s2.b5.conv2.bn.weight", "encoder.s2.b5.conv2.bn.bias", "encoder.s2.b5.conv2.bn.running_mean", "encoder.s2.b5.conv2.bn.running_var", "encoder.s2.b5.conv3.conv.weight", "encoder.s2.b5.conv3.bn.weight", "encoder.s2.b5.conv3.bn.bias", "encoder.s2.b5.conv3.bn.running_mean", "encoder.s2.b5.conv3.bn.running_var", "encoder.s3.b1.conv1.conv.weight", "encoder.s3.b1.conv1.bn.weight", "encoder.s3.b1.conv1.bn.bias", "encoder.s3.b1.conv1.bn.running_mean", "encoder.s3.b1.conv1.bn.running_var", "encoder.s3.b1.conv2.conv.weight", "encoder.s3.b1.conv2.bn.weight", "encoder.s3.b1.conv2.bn.bias", "encoder.s3.b1.conv2.bn.running_mean", "encoder.s3.b1.conv2.bn.running_var", "encoder.s3.b1.conv3.conv.weight", "encoder.s3.b1.conv3.bn.weight", "encoder.s3.b1.conv3.bn.bias", "encoder.s3.b1.conv3.bn.running_mean", "encoder.s3.b1.conv3.bn.running_var", "encoder.s3.b1.downsample.conv.weight", "encoder.s3.b1.downsample.bn.weight", "encoder.s3.b1.downsample.bn.bias", "encoder.s3.b1.downsample.bn.running_mean", "encoder.s3.b1.downsample.bn.running_var", "encoder.s3.b2.conv1.conv.weight", "encoder.s3.b2.conv1.bn.weight", "encoder.s3.b2.conv1.bn.bias", "encoder.s3.b2.conv1.bn.running_mean", "encoder.s3.b2.conv1.bn.running_var", "encoder.s3.b2.conv2.conv.weight", "encoder.s3.b2.conv2.bn.weight", "encoder.s3.b2.conv2.bn.bias", "encoder.s3.b2.conv2.bn.running_mean", "encoder.s3.b2.conv2.bn.running_var", "encoder.s3.b2.conv3.conv.weight", "encoder.s3.b2.conv3.bn.weight", "encoder.s3.b2.conv3.bn.bias", "encoder.s3.b2.conv3.bn.running_mean", "encoder.s3.b2.conv3.bn.running_var", "encoder.s3.b3.conv1.conv.weight", "encoder.s3.b3.conv1.bn.weight", "encoder.s3.b3.conv1.bn.bias", "encoder.s3.b3.conv1.bn.running_mean", "encoder.s3.b3.conv1.bn.running_var", "encoder.s3.b3.conv2.conv.weight", "encoder.s3.b3.conv2.bn.weight", "encoder.s3.b3.conv2.bn.bias", "encoder.s3.b3.conv2.bn.running_mean", "encoder.s3.b3.conv2.bn.running_var", "encoder.s3.b3.conv3.conv.weight", "encoder.s3.b3.conv3.bn.weight", "encoder.s3.b3.conv3.bn.bias", "encoder.s3.b3.conv3.bn.running_mean", "encoder.s3.b3.conv3.bn.running_var", "encoder.s3.b4.conv1.conv.weight", "encoder.s3.b4.conv1.bn.weight", "encoder.s3.b4.conv1.bn.bias", "encoder.s3.b4.conv1.bn.running_mean", "encoder.s3.b4.conv1.bn.running_var", "encoder.s3.b4.conv2.conv.weight", "encoder.s3.b4.conv2.bn.weight", "encoder.s3.b4.conv2.bn.bias", "encoder.s3.b4.conv2.bn.running_mean", "encoder.s3.b4.conv2.bn.running_var", "encoder.s3.b4.conv3.conv.weight", "encoder.s3.b4.conv3.bn.weight", "encoder.s3.b4.conv3.bn.bias", "encoder.s3.b4.conv3.bn.running_mean", "encoder.s3.b4.conv3.bn.running_var", "encoder.s3.b5.conv1.conv.weight", "encoder.s3.b5.conv1.bn.weight", "encoder.s3.b5.conv1.bn.bias", "encoder.s3.b5.conv1.bn.running_mean", "encoder.s3.b5.conv1.bn.running_var", "encoder.s3.b5.conv2.conv.weight", "encoder.s3.b5.conv2.bn.weight", "encoder.s3.b5.conv2.bn.bias", "encoder.s3.b5.conv2.bn.running_mean", "encoder.s3.b5.conv2.bn.running_var", "encoder.s3.b5.conv3.conv.weight", "encoder.s3.b5.conv3.bn.weight", "encoder.s3.b5.conv3.bn.bias", "encoder.s3.b5.conv3.bn.running_mean", "encoder.s3.b5.conv3.bn.running_var", "encoder.s3.b6.conv1.conv.weight", "encoder.s3.b6.conv1.bn.weight", "encoder.s3.b6.conv1.bn.bias", "encoder.s3.b6.conv1.bn.running_mean", "encoder.s3.b6.conv1.bn.running_var", "encoder.s3.b6.conv2.conv.weight", "encoder.s3.b6.conv2.bn.weight", "encoder.s3.b6.conv2.bn.bias", "encoder.s3.b6.conv2.bn.running_mean", "encoder.s3.b6.conv2.bn.running_var", "encoder.s3.b6.conv3.conv.weight", "encoder.s3.b6.conv3.bn.weight", "encoder.s3.b6.conv3.bn.bias", "encoder.s3.b6.conv3.bn.running_mean", "encoder.s3.b6.conv3.bn.running_var", "encoder.s3.b7.conv1.conv.weight", "encoder.s3.b7.conv1.bn.weight", "encoder.s3.b7.conv1.bn.bias", "encoder.s3.b7.conv1.bn.running_mean", "encoder.s3.b7.conv1.bn.running_var", "encoder.s3.b7.conv2.conv.weight", "encoder.s3.b7.conv2.bn.weight", "encoder.s3.b7.conv2.bn.bias", "encoder.s3.b7.conv2.bn.running_mean", "encoder.s3.b7.conv2.bn.running_var", "encoder.s3.b7.conv3.conv.weight", "encoder.s3.b7.conv3.bn.weight", "encoder.s3.b7.conv3.bn.bias", "encoder.s3.b7.conv3.bn.running_mean", "encoder.s3.b7.conv3.bn.running_var", "encoder.s3.b8.conv1.conv.weight", "encoder.s3.b8.conv1.bn.weight", "encoder.s3.b8.conv1.bn.bias", "encoder.s3.b8.conv1.bn.running_mean", "encoder.s3.b8.conv1.bn.running_var", "encoder.s3.b8.conv2.conv.weight", "encoder.s3.b8.conv2.bn.weight", "encoder.s3.b8.conv2.bn.bias", "encoder.s3.b8.conv2.bn.running_mean", "encoder.s3.b8.conv2.bn.running_var", "encoder.s3.b8.conv3.conv.weight", "encoder.s3.b8.conv3.bn.weight", "encoder.s3.b8.conv3.bn.bias", "encoder.s3.b8.conv3.bn.running_mean", "encoder.s3.b8.conv3.bn.running_var", "encoder.s3.b9.conv1.conv.weight", "encoder.s3.b9.conv1.bn.weight", "encoder.s3.b9.conv1.bn.bias", "encoder.s3.b9.conv1.bn.running_mean", "encoder.s3.b9.conv1.bn.running_var", "encoder.s3.b9.conv2.conv.weight", "encoder.s3.b9.conv2.bn.weight", "encoder.s3.b9.conv2.bn.bias", "encoder.s3.b9.conv2.bn.running_mean", "encoder.s3.b9.conv2.bn.running_var", "encoder.s3.b9.conv3.conv.weight", "encoder.s3.b9.conv3.bn.weight", "encoder.s3.b9.conv3.bn.bias", "encoder.s3.b9.conv3.bn.running_mean", "encoder.s3.b9.conv3.bn.running_var", "encoder.s3.b10.conv1.conv.weight", "encoder.s3.b10.conv1.bn.weight", "encoder.s3.b10.conv1.bn.bias", "encoder.s3.b10.conv1.bn.running_mean", "encoder.s3.b10.conv1.bn.running_var", "encoder.s3.b10.conv2.conv.weight", "encoder.s3.b10.conv2.bn.weight", "encoder.s3.b10.conv2.bn.bias", "encoder.s3.b10.conv2.bn.running_mean", "encoder.s3.b10.conv2.bn.running_var", "encoder.s3.b10.conv3.conv.weight", "encoder.s3.b10.conv3.bn.weight", "encoder.s3.b10.conv3.bn.bias", "encoder.s3.b10.conv3.bn.running_mean", "encoder.s3.b10.conv3.bn.running_var", "encoder.s3.b11.conv1.conv.weight", "encoder.s3.b11.conv1.bn.weight", "encoder.s3.b11.conv1.bn.bias", "encoder.s3.b11.conv1.bn.running_mean", "encoder.s3.b11.conv1.bn.running_var", "encoder.s3.b11.conv2.conv.weight", "encoder.s3.b11.conv2.bn.weight", "encoder.s3.b11.conv2.bn.bias", "encoder.s3.b11.conv2.bn.running_mean", "encoder.s3.b11.conv2.bn.running_var", "encoder.s3.b11.conv3.conv.weight", "encoder.s3.b11.conv3.bn.weight", "encoder.s3.b11.conv3.bn.bias", "encoder.s3.b11.conv3.bn.running_mean", "encoder.s3.b11.conv3.bn.running_var", "encoder.s3.b12.conv1.conv.weight", "encoder.s3.b12.conv1.bn.weight", "encoder.s3.b12.conv1.bn.bias", "encoder.s3.b12.conv1.bn.running_mean", "encoder.s3.b12.conv1.bn.running_var", "encoder.s3.b12.conv2.conv.weight", "encoder.s3.b12.conv2.bn.weight", "encoder.s3.b12.conv2.bn.bias", "encoder.s3.b12.conv2.bn.running_mean", "encoder.s3.b12.conv2.bn.running_var", "encoder.s3.b12.conv3.conv.weight", "encoder.s3.b12.conv3.bn.weight", "encoder.s3.b12.conv3.bn.bias", "encoder.s3.b12.conv3.bn.running_mean", "encoder.s3.b12.conv3.bn.running_var", "encoder.s3.b13.conv1.conv.weight", "encoder.s3.b13.conv1.bn.weight", "encoder.s3.b13.conv1.bn.bias", "encoder.s3.b13.conv1.bn.running_mean", "encoder.s3.b13.conv1.bn.running_var", "encoder.s3.b13.conv2.conv.weight", "encoder.s3.b13.conv2.bn.weight", "encoder.s3.b13.conv2.bn.bias", "encoder.s3.b13.conv2.bn.running_mean", "encoder.s3.b13.conv2.bn.running_var", "encoder.s3.b13.conv3.conv.weight", "encoder.s3.b13.conv3.bn.weight", "encoder.s3.b13.conv3.bn.bias", "encoder.s3.b13.conv3.bn.running_mean", "encoder.s3.b13.conv3.bn.running_var", "encoder.s3.b14.conv1.conv.weight", "encoder.s3.b14.conv1.bn.weight", "encoder.s3.b14.conv1.bn.bias", "encoder.s3.b14.conv1.bn.running_mean", "encoder.s3.b14.conv1.bn.running_var", "encoder.s3.b14.conv2.conv.weight", "encoder.s3.b14.conv2.bn.weight", "encoder.s3.b14.conv2.bn.bias", "encoder.s3.b14.conv2.bn.running_mean", "encoder.s3.b14.conv2.bn.running_var", "encoder.s3.b14.conv3.conv.weight", "encoder.s3.b14.conv3.bn.weight", "encoder.s3.b14.conv3.bn.bias", "encoder.s3.b14.conv3.bn.running_mean", "encoder.s3.b14.conv3.bn.running_var", "encoder.s4.b1.conv1.conv.weight", "encoder.s4.b1.conv1.bn.weight", "encoder.s4.b1.conv1.bn.bias", "encoder.s4.b1.conv1.bn.running_mean", "encoder.s4.b1.conv1.bn.running_var", "encoder.s4.b1.conv2.conv.weight", "encoder.s4.b1.conv2.bn.weight", "encoder.s4.b1.conv2.bn.bias", "encoder.s4.b1.conv2.bn.running_mean", "encoder.s4.b1.conv2.bn.running_var", "encoder.s4.b1.conv3.conv.weight", "encoder.s4.b1.conv3.bn.weight", "encoder.s4.b1.conv3.bn.bias", "encoder.s4.b1.conv3.bn.running_mean", "encoder.s4.b1.conv3.bn.running_var", "encoder.s4.b1.downsample.conv.weight", "encoder.s4.b1.downsample.bn.weight", "encoder.s4.b1.downsample.bn.bias", "encoder.s4.b1.downsample.bn.running_mean", "encoder.s4.b1.downsample.bn.running_var", "encoder.s4.b2.conv1.conv.weight", "encoder.s4.b2.conv1.bn.weight", "encoder.s4.b2.conv1.bn.bias", "encoder.s4.b2.conv1.bn.running_mean", "encoder.s4.b2.conv1.bn.running_var", "encoder.s4.b2.conv2.conv.weight", "encoder.s4.b2.conv2.bn.weight", "encoder.s4.b2.conv2.bn.bias", "encoder.s4.b2.conv2.bn.running_mean", "encoder.s4.b2.conv2.bn.running_var", "encoder.s4.b2.conv3.conv.weight", "encoder.s4.b2.conv3.bn.weight", "encoder.s4.b2.conv3.bn.bias", "encoder.s4.b2.conv3.bn.running_mean", "encoder.s4.b2.conv3.bn.running_var". 
	Unexpected key(s) in state_dict: "encoder.conv1.weight", "encoder.bn1.weight", "encoder.bn1.bias", "encoder.bn1.running_mean", "encoder.bn1.running_var", "encoder.bn1.num_batches_tracked", "encoder.layer1.0.conv1.weight", "encoder.layer1.0.bn1.weight", "encoder.layer1.0.bn1.bias", "encoder.layer1.0.bn1.running_mean", "encoder.layer1.0.bn1.running_var", "encoder.layer1.0.bn1.num_batches_tracked", "encoder.layer1.0.conv2.weight", "encoder.layer1.0.bn2.weight", "encoder.layer1.0.bn2.bias", "encoder.layer1.0.bn2.running_mean", "encoder.layer1.0.bn2.running_var", "encoder.layer1.0.bn2.num_batches_tracked", "encoder.layer1.1.conv1.weight", "encoder.layer1.1.bn1.weight", "encoder.layer1.1.bn1.bias", "encoder.layer1.1.bn1.running_mean", "encoder.layer1.1.bn1.running_var", "encoder.layer1.1.bn1.num_batches_tracked", "encoder.layer1.1.conv2.weight", "encoder.layer1.1.bn2.weight", "encoder.layer1.1.bn2.bias", "encoder.layer1.1.bn2.running_mean", "encoder.layer1.1.bn2.running_var", "encoder.layer1.1.bn2.num_batches_tracked", "encoder.layer1.2.conv1.weight", "encoder.layer1.2.bn1.weight", "encoder.layer1.2.bn1.bias", "encoder.layer1.2.bn1.running_mean", "encoder.layer1.2.bn1.running_var", "encoder.layer1.2.bn1.num_batches_tracked", "encoder.layer1.2.conv2.weight", "encoder.layer1.2.bn2.weight", "encoder.layer1.2.bn2.bias", "encoder.layer1.2.bn2.running_mean", "encoder.layer1.2.bn2.running_var", "encoder.layer1.2.bn2.num_batches_tracked", "encoder.layer2.0.conv1.weight", "encoder.layer2.0.bn1.weight", "encoder.layer2.0.bn1.bias", "encoder.layer2.0.bn1.running_mean", "encoder.layer2.0.bn1.running_var", "encoder.layer2.0.bn1.num_batches_tracked", "encoder.layer2.0.conv2.weight", "encoder.layer2.0.bn2.weight", "encoder.layer2.0.bn2.bias", "encoder.layer2.0.bn2.running_mean", "encoder.layer2.0.bn2.running_var", "encoder.layer2.0.bn2.num_batches_tracked", "encoder.layer2.0.downsample.0.weight", "encoder.layer2.0.downsample.1.weight", "encoder.layer2.0.downsample.1.bias", "encoder.layer2.0.downsample.1.running_mean", "encoder.layer2.0.downsample.1.running_var", "encoder.layer2.0.downsample.1.num_batches_tracked", "encoder.layer2.1.conv1.weight", "encoder.layer2.1.bn1.weight", "encoder.layer2.1.bn1.bias", "encoder.layer2.1.bn1.running_mean", "encoder.layer2.1.bn1.running_var", "encoder.layer2.1.bn1.num_batches_tracked", "encoder.layer2.1.conv2.weight", "encoder.layer2.1.bn2.weight", "encoder.layer2.1.bn2.bias", "encoder.layer2.1.bn2.running_mean", "encoder.layer2.1.bn2.running_var", "encoder.layer2.1.bn2.num_batches_tracked", "encoder.layer2.2.conv1.weight", "encoder.layer2.2.bn1.weight", "encoder.layer2.2.bn1.bias", "encoder.layer2.2.bn1.running_mean", "encoder.layer2.2.bn1.running_var", "encoder.layer2.2.bn1.num_batches_tracked", "encoder.layer2.2.conv2.weight", "encoder.layer2.2.bn2.weight", "encoder.layer2.2.bn2.bias", "encoder.layer2.2.bn2.running_mean", "encoder.layer2.2.bn2.running_var", "encoder.layer2.2.bn2.num_batches_tracked", "encoder.layer2.3.conv1.weight", "encoder.layer2.3.bn1.weight", "encoder.layer2.3.bn1.bias", "encoder.layer2.3.bn1.running_mean", "encoder.layer2.3.bn1.running_var", "encoder.layer2.3.bn1.num_batches_tracked", "encoder.layer2.3.conv2.weight", "encoder.layer2.3.bn2.weight", "encoder.layer2.3.bn2.bias", "encoder.layer2.3.bn2.running_mean", "encoder.layer2.3.bn2.running_var", "encoder.layer2.3.bn2.num_batches_tracked", "encoder.layer3.0.conv1.weight", "encoder.layer3.0.bn1.weight", "encoder.layer3.0.bn1.bias", "encoder.layer3.0.bn1.running_mean", "encoder.layer3.0.bn1.running_var", "encoder.layer3.0.bn1.num_batches_tracked", "encoder.layer3.0.conv2.weight", "encoder.layer3.0.bn2.weight", "encoder.layer3.0.bn2.bias", "encoder.layer3.0.bn2.running_mean", "encoder.layer3.0.bn2.running_var", "encoder.layer3.0.bn2.num_batches_tracked", "encoder.layer3.0.downsample.0.weight", "encoder.layer3.0.downsample.1.weight", "encoder.layer3.0.downsample.1.bias", "encoder.layer3.0.downsample.1.running_mean", "encoder.layer3.0.downsample.1.running_var", "encoder.layer3.0.downsample.1.num_batches_tracked", "encoder.layer3.1.conv1.weight", "encoder.layer3.1.bn1.weight", "encoder.layer3.1.bn1.bias", "encoder.layer3.1.bn1.running_mean", "encoder.layer3.1.bn1.running_var", "encoder.layer3.1.bn1.num_batches_tracked", "encoder.layer3.1.conv2.weight", "encoder.layer3.1.bn2.weight", "encoder.layer3.1.bn2.bias", "encoder.layer3.1.bn2.running_mean", "encoder.layer3.1.bn2.running_var", "encoder.layer3.1.bn2.num_batches_tracked", "encoder.layer3.2.conv1.weight", "encoder.layer3.2.bn1.weight", "encoder.layer3.2.bn1.bias", "encoder.layer3.2.bn1.running_mean", "encoder.layer3.2.bn1.running_var", "encoder.layer3.2.bn1.num_batches_tracked", "encoder.layer3.2.conv2.weight", "encoder.layer3.2.bn2.weight", "encoder.layer3.2.bn2.bias", "encoder.layer3.2.bn2.running_mean", "encoder.layer3.2.bn2.running_var", "encoder.layer3.2.bn2.num_batches_tracked", "encoder.layer3.3.conv1.weight", "encoder.layer3.3.bn1.weight", "encoder.layer3.3.bn1.bias", "encoder.layer3.3.bn1.running_mean", "encoder.layer3.3.bn1.running_var", "encoder.layer3.3.bn1.num_batches_tracked", "encoder.layer3.3.conv2.weight", "encoder.layer3.3.bn2.weight", "encoder.layer3.3.bn2.bias", "encoder.layer3.3.bn2.running_mean", "encoder.layer3.3.bn2.running_var", "encoder.layer3.3.bn2.num_batches_tracked", "encoder.layer3.4.conv1.weight", "encoder.layer3.4.bn1.weight", "encoder.layer3.4.bn1.bias", "encoder.layer3.4.bn1.running_mean", "encoder.layer3.4.bn1.running_var", "encoder.layer3.4.bn1.num_batches_tracked", "encoder.layer3.4.conv2.weight", "encoder.layer3.4.bn2.weight", "encoder.layer3.4.bn2.bias", "encoder.layer3.4.bn2.running_mean", "encoder.layer3.4.bn2.running_var", "encoder.layer3.4.bn2.num_batches_tracked", "encoder.layer3.5.conv1.weight", "encoder.layer3.5.bn1.weight", "encoder.layer3.5.bn1.bias", "encoder.layer3.5.bn1.running_mean", "encoder.layer3.5.bn1.running_var", "encoder.layer3.5.bn1.num_batches_tracked", "encoder.layer3.5.conv2.weight", "encoder.layer3.5.bn2.weight", "encoder.layer3.5.bn2.bias", "encoder.layer3.5.bn2.running_mean", "encoder.layer3.5.bn2.running_var", "encoder.layer3.5.bn2.num_batches_tracked", "encoder.layer4.0.conv1.weight", "encoder.layer4.0.bn1.weight", "encoder.layer4.0.bn1.bias", "encoder.layer4.0.bn1.running_mean", "encoder.layer4.0.bn1.running_var", "encoder.layer4.0.bn1.num_batches_tracked", "encoder.layer4.0.conv2.weight", "encoder.layer4.0.bn2.weight", "encoder.layer4.0.bn2.bias", "encoder.layer4.0.bn2.running_mean", "encoder.layer4.0.bn2.running_var", "encoder.layer4.0.bn2.num_batches_tracked", "encoder.layer4.0.downsample.0.weight", "encoder.layer4.0.downsample.1.weight", "encoder.layer4.0.downsample.1.bias", "encoder.layer4.0.downsample.1.running_mean", "encoder.layer4.0.downsample.1.running_var", "encoder.layer4.0.downsample.1.num_batches_tracked", "encoder.layer4.1.conv1.weight", "encoder.layer4.1.bn1.weight", "encoder.layer4.1.bn1.bias", "encoder.layer4.1.bn1.running_mean", "encoder.layer4.1.bn1.running_var", "encoder.layer4.1.bn1.num_batches_tracked", "encoder.layer4.1.conv2.weight", "encoder.layer4.1.bn2.weight", "encoder.layer4.1.bn2.bias", "encoder.layer4.1.bn2.running_mean", "encoder.layer4.1.bn2.running_var", "encoder.layer4.1.bn2.num_batches_tracked", "encoder.layer4.2.conv1.weight", "encoder.layer4.2.bn1.weight", "encoder.layer4.2.bn1.bias", "encoder.layer4.2.bn1.running_mean", "encoder.layer4.2.bn1.running_var", "encoder.layer4.2.bn1.num_batches_tracked", "encoder.layer4.2.conv2.weight", "encoder.layer4.2.bn2.weight", "encoder.layer4.2.bn2.bias", "encoder.layer4.2.bn2.running_mean", "encoder.layer4.2.bn2.running_var", "encoder.layer4.2.bn2.num_batches_tracked". 
	size mismatch for decoder.blocks.x_0_0.conv1.0.weight: copying a param with shape torch.Size([256, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1920, 3, 3]).
	size mismatch for decoder.blocks.x_0_1.conv1.0.weight: copying a param with shape torch.Size([128, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 736, 3, 3]).
	size mismatch for decoder.blocks.x_1_1.conv1.0.weight: copying a param with shape torch.Size([128, 384, 3, 3]) from checkpoint, the shape in current model is torch.Size([240, 800, 3, 3]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.0.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([240, 240, 3, 3]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_0_2.conv1.0.weight: copying a param with shape torch.Size([64, 320, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 368, 3, 3]).
	size mismatch for decoder.blocks.x_1_2.conv1.0.weight: copying a param with shape torch.Size([64, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 400, 3, 3]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.0.weight: copying a param with shape torch.Size([64, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 320, 3, 3]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_0_3.conv1.0.weight: copying a param with shape torch.Size([32, 320, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 192, 3, 3]).
	size mismatch for decoder.blocks.x_1_3.conv1.0.weight: copying a param with shape torch.Size([64, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 176, 3, 3]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.0.weight: copying a param with shape torch.Size([64, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 144, 3, 3]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.0.weight: copying a param with shape torch.Size([64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 112, 3, 3]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).

In [None]:
count = 0 

for image, mask in DataLoader(valDataset, batch_size=1, shuffle=False, num_workers=0): 
    predict_image(model, image.squeeze(0), mask.squeeze(0))
    print("-----------------------")
    count += 1
    if count == 50:
        break