In [1]:
!pip install segmentation_models_pytorch 
!pip install neptune
!pip install torchmetrics
!pip install albumentations 

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.3-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
Collecting timm==0.9.2
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m64.5 MB/s[0m eta [36m0:00:00[0m
Collecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting munch
  Downloading munch-4.0.0-py2.py3-none-any.whl (9.9 kB)
Collecting safetensors
  Downloading safetensors-0.4.0-cp39-cp39-manylinux_2_17_x86_64.many

In [2]:
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
import numpy as np
import torch.nn as nn
import time
import copy
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import segmentation_models_pytorch as smp
import neptune
from neptune.types import File

from torchmetrics.functional.classification import dice as calc_dice_score
from torchmetrics.classification import BinaryJaccardIndex

import albumentations as A
import cv2

In [3]:
def identify_axis(shape):
    # Three dimensional
    if len(shape) == 5 : return [2,3,4]

    # Two dimensional
    elif len(shape) == 4 : return [2,3]
    
    # Exception - Unknown
    else : raise ValueError('Metric: Shape of tensor is neither 2D or 3D.')

class AsymmetricFocalLoss(nn.Module):
    """For Imbalanced datasets
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.25
    gamma : float, optional
        Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
    epsilon : float, optional
        clip values to prevent division by zero error
    """
    def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
        super(AsymmetricFocalLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred_raw, y_true):
        # Apply sigmoid function to raw output
        y_pred = torch.sigmoid(y_pred_raw)
        
        # Rest of the code remains the same
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred) - (1 - y_true) * torch.log(1 - y_pred)
        
        # Calculate losses separately for each class, only suppressing background class
        back_ce = torch.pow(1 - y_pred, self.gamma) * cross_entropy
        back_ce =  (1 - self.delta) * back_ce

        fore_ce = cross_entropy
        fore_ce = self.delta * fore_ce

        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1))

        return loss


class AsymmetricFocalTverskyLoss(nn.Module):
    """This is the implementation for binary segmentation.
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    smooth : float, optional
        smooithing constant to prevent division by 0 errors, by default 0.000001
    epsilon : float, optional
        clip values to prevent division by zero error
    """
    def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07):
        super(AsymmetricFocalTverskyLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        # Clip values to prevent division by zero error
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        axis = identify_axis(y_true.size())

        # Calculate true positives (tp), false negatives (fn) and false positives (fp)     
        tp = torch.sum(y_true * y_pred, axis=axis)
        fn = torch.sum(y_true * (1-y_pred), axis=axis)
        fp = torch.sum((1-y_true) * y_pred, axis=axis)
        dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon)

        # Calculate losses separately for each class, only enhancing foreground class
        back_dice = (1-dice_class[:,0]) 
        fore_dice = (1-dice_class[:,0]) * torch.pow(1-dice_class[:,0], -self.gamma)

        # Average class scores
        loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1))
        return loss

class AsymmetricUnifiedFocalLoss(nn.Module):
    """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework.
    Parameters
    ----------
    weight : float, optional
        represents lambda parameter and controls weight given to asymmetric Focal Tversky loss and asymmetric Focal loss, by default 0.5
    delta : float, optional
        controls weight given to each class, by default 0.6
    gamma : float, optional
        focal parameter controls the degree of background suppression and foreground enhancement, by default 0.5
    epsilon : float, optional
        clip values to prevent division by zero error
    """
    def __init__(self, weight=0.5, delta=0.7, gamma=0.5):
        super(AsymmetricUnifiedFocalLoss, self).__init__()
        self.weight = weight
        self.delta = delta
        self.gamma = gamma

    def forward(self, y_pred, y_true):
      # Obtain Asymmetric Focal Tversky loss
      asymmetric_ftl = AsymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true)

      # Obtain Asymmetric Focal loss
      asymmetric_fl = AsymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true)

      # Return weighted sum of Asymmetrical Focal loss and Asymmetric Focal Tversky loss
      if self.weight is not None:
        return (self.weight * asymmetric_ftl) + ((1-self.weight) * asymmetric_fl)  
      else:
        return asymmetric_ftl + asymmetric_fl

In [None]:
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

In [4]:
# Define a custom dataset class
# Training params
BATCH_SIZE = 8
EPOCHS = 100
LEARNING_RATE = 0.000045

# Model params
ENCODER_NAME = "resnet34"
ENCODER_WEIGHTS = "imagenet"

with open("/notebooks/NEPTUNE_API_TOKEN.txt", "r") as file:
    # Read the entire content of the file into a string
    token = file.read()

run = neptune.init_run(
    project="Kernel-bois/computer-vision",
    api_token=token,
)
run_id = run["sys/id"].fetch()

# Create the model
model = smp.Unet(
    encoder_name=ENCODER_NAME,           # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=ENCODER_WEIGHTS,     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                            # model output channels (number of classes in your dataset)
    )


# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', threshold = 0.001, patience = 5)

run_name = "MODEL-" + model.__class__.__name__ + ENCODER_NAME + str(run_id)

save_path = str(run_id) + "/"
os.makedirs(save_path)

# Proper directories
TRAIN_DATA_DIR = '/notebooks/image_segmentation/network/image_data_all3/train'
VAL_DATA_DIR = '/notebooks/image_segmentation/network/image_data_all3/val'

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define loss function
# criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=False)  # Binary dice Loss for binary segmentation
# criterion = smp.losses.SoftBCEWithLogitsLoss()  # Binary dice Loss for binary segmentation
criterion = AsymmetricUnifiedFocalLoss()  # Binary dice Loss for binary segmentation
calc_iou = BinaryJaccardIndex().to(device)


params = {
    "MODEL": model.__class__.__name__,
    "BACKBONE": ENCODER_NAME,
    "ENCODER_WEIGHTS": ENCODER_WEIGHTS,
    "BATCH_SIZE": str(BATCH_SIZE),
    "EPOCHS": str(EPOCHS),
    "CRITERION": criterion.__class__.__name__,
    "OPTIMIZER": optimizer.__class__.__name__,
    "LEARNRATE": str(LEARNING_RATE),
    "MODEL_NAME": run_name,
}

run["params"] = params



  run = neptune.init_run(


https://app.neptune.ai/Kernel-bois/computer-vision/e/CV-265


In [5]:
class SegmentationDataset(Dataset):
    def __init__(self, root_dir, transform = None, target_size = (992, 416)):
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size

        self.image_folder = os.path.join(root_dir, 'images')
        self.mask_folder = os.path.join(root_dir, 'masks')

        self.images = os.listdir(self.image_folder)
        self.masks = os.listdir(self.mask_folder)
        
        assert len(self.images) == len(self.masks), "Number of images and masks should be the same."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, self.images[idx])
        if "patient" in self.images[idx]:
            mask_path = os.path.join(self.mask_folder, "segmentation_" + self.images[idx][-7:])
        else:
            mask_path = os.path.join(self.mask_folder, "target_seg_" + self.images[idx][-7:])
        
        # Load images
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path)
        
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # Convert to tensors 
        tensor_image = torch.from_numpy(image)
        tensor_image = tensor_image.permute(2, 0, 1)

        tensor_mask = torch.from_numpy(mask)
        tensor_mask = tensor_mask.permute(2, 0, 1) / 255
        tensor_mask = tensor_mask[2:, :, :]
        
        # add padding
        pad_height = max(self.target_size[0] - tensor_image.size(1), 0)
        pad_width = max(self.target_size[1] - tensor_image.size(2), 0)

        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        padded_image = transforms.functional.pad(tensor_image, (pad_left, pad_bottom, pad_right, pad_top), fill=255)
        padded_mask = transforms.functional.pad(tensor_mask, (pad_left, pad_bottom, pad_right, pad_top), fill=0)

        return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)


In [6]:
model = model.to(device)

# Set up dataset and dataloader
transform = A.Compose([])
#transform = A.Compose([A.OneOf([A.GaussNoise(p = 1), A.RandomGamma(p = 1), A.Sharpen(p=1), A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 1) ],p=0.7), A.OneOf([A.GaussNoise(p = 1), A.RandomGamma(p = 1), A.Sharpen(p=1),A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 1) ],p=0.3)])
#
#transform = A.Compose([A.GaussNoise(p = 0.2), A.RandomGamma(p = 0.2), A.Sharpen(p=0.2),
#                       A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 0.2)])
""" 

 A.Compose([A.HorizontalFlip(p=0.5), A.OneOf([A.GaussNoise(p = 1), A.RandomGamma(p = 1), A.Sharpen(p=1), 
                       A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 1)],p=0.5)])
"""
# ALL Transforms
# A.CLAHE(p=0.2)
# A.HorizontalFlip(p=0.5),
# A.RandomGamma(p=0.2)
# A.GridDistortion(p=0.2)
# A.RandomBrightnessContrast(p=0.2)
# A.Resize(width = np.random.randint(200, 416), height = np.random.randint(200,992), p = 0.2)
# A.OneOf([ ],p=0.9) for more 
# A.Sharpen(p=0.2)
# A.Blur(p=0.2)
# A.RandomCrop(height = 200, width=200, p=0.2)
# GaussNoise(p = 0.2)

trainDataset = SegmentationDataset(root_dir=TRAIN_DATA_DIR, transform=transform)
valDataset = SegmentationDataset(root_dir=VAL_DATA_DIR)

train_loader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(valDataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)


In [None]:
# Train Loop
for epoch in range(EPOCHS):
    train_loss = torch.tensor(0.0)
    model.train()

    # Use tqdm to add a progress bar
    for images, masks in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{EPOCHS}', leave=False):
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)
        # outputs = torch.argmax(outputs, dim=1).unsqueeze(1).float()

        loss = criterion(outputs, masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        train_loss += loss.detach().cpu()
        optimizer.step()

    train_loss /= len(train_loader)
    run["loss/train_loss"].log(train_loss)

    # Validation
    model.eval()

    val_loss = torch.tensor(0.0)
    iou = torch.tensor(0.0)
    dice_score = torch.tensor(0.0)

    with torch.no_grad():
        for val_images, val_masks in tqdm(val_loader, desc=f'Validation', leave=False):
            val_images, val_masks = val_images.to(device), val_masks.to(device)

            model_outputs = model(val_images)

            val_loss += criterion(model_outputs, val_masks).cpu()

            val_masks_int = torch.tensor(val_masks, dtype=torch.int8)
            dice_score += calc_dice_score(torch.sigmoid(model_outputs), val_masks_int, ignore_index=0).cpu()

            iou += calc_iou(model_outputs, val_masks_int).cpu()


    val_loss /= len(val_loader)
    iou /= len(val_loader)
    dice_score /= len(val_loader)
    
    scheduler.step(val_loss)

    if torch.isnan(iou):
        iou = torch.tensor(0.0)

    run["loss/val_loss"].log(val_loss)
    run["val/iou"].log(iou)
    run["val/dice_score"].log(dice_score)

    torch.save(model.state_dict(), save_path + run_name + "_EPOCH_" + str(epoch) + '.pth')

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train Loss: {train_loss}, Validation Loss: {val_loss}\n"
          f"IOU: {iou}, Dice Score: {dice_score}")

# Save the trained model
torch.save(model.state_dict(), save_path + run_name + "_FINAL" + '.pth')
run[f"network/network_weights"].upload(File(run_name + '.pth'))

run.stop()


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  val_masks_int = torch.tensor(val_masks, dtype=torch.int8)
                                                           

Epoch [1/100], Train Loss: 0.9927887320518494, Validation Loss: 0.9802963733673096
IOU: 0.01371900737285614, Dice Score: 0.026982637122273445


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [2/100], Train Loss: 0.9901798367500305, Validation Loss: 0.9748908877372742
IOU: 0.02892202138900757, Dice Score: 0.055961623787879944


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [3/100], Train Loss: 0.988048255443573, Validation Loss: 0.9761552214622498
IOU: 0.05618658289313316, Dice Score: 0.10538797825574875


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [4/100], Train Loss: 0.9858904480934143, Validation Loss: 0.9624328017234802
IOU: 0.10557689517736435, Dice Score: 0.18986183404922485


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [5/100], Train Loss: 0.9828876852989197, Validation Loss: 0.9575397372245789
IOU: 0.14457498490810394, Dice Score: 0.25073641538619995


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [6/100], Train Loss: 0.9793909788131714, Validation Loss: 0.9497098922729492
IOU: 0.12249438464641571, Dice Score: 0.2161562591791153


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [7/100], Train Loss: 0.9741501212120056, Validation Loss: 0.9348689913749695
IOU: 0.20168690383434296, Dice Score: 0.33247584104537964


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [8/100], Train Loss: 0.9689496755599976, Validation Loss: 0.9216614365577698
IOU: 0.1912747472524643, Dice Score: 0.31848862767219543


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [9/100], Train Loss: 0.960737407207489, Validation Loss: 0.8958742022514343
IOU: 0.4160194396972656, Dice Score: 0.5826720595359802


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [10/100], Train Loss: 0.9480918049812317, Validation Loss: 0.8724563717842102
IOU: 0.45275115966796875, Dice Score: 0.6161086559295654


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [11/100], Train Loss: 0.935570240020752, Validation Loss: 0.8647222518920898
IOU: 0.17224180698394775, Dice Score: 0.2906143069267273


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [12/100], Train Loss: 0.9098482728004456, Validation Loss: 0.8019484877586365
IOU: 0.5125585794448853, Dice Score: 0.6714134812355042


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [13/100], Train Loss: 0.8810474872589111, Validation Loss: 0.7373595833778381
IOU: 0.4305485785007477, Dice Score: 0.599831223487854


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [14/100], Train Loss: 0.834848165512085, Validation Loss: 0.6705369353294373
IOU: 0.46636712551116943, Dice Score: 0.6334206461906433


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [15/100], Train Loss: 0.787290632724762, Validation Loss: 0.7360660433769226
IOU: 0.2146247774362564, Dice Score: 0.3404843211174011


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [16/100], Train Loss: 0.7204568982124329, Validation Loss: 0.5025706887245178
IOU: 0.5841635465621948, Dice Score: 0.7342085242271423


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [17/100], Train Loss: 0.6366013288497925, Validation Loss: 0.49566027522087097
IOU: 0.4676094055175781, Dice Score: 0.6223363876342773


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [18/100], Train Loss: 0.5652099847793579, Validation Loss: 0.39448773860931396
IOU: 0.5829125642776489, Dice Score: 0.7272955179214478


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [19/100], Train Loss: 0.49485230445861816, Validation Loss: 0.3484908640384674
IOU: 0.5897874236106873, Dice Score: 0.7325799465179443


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [20/100], Train Loss: 0.4393582046031952, Validation Loss: 0.3337964117527008
IOU: 0.6017965078353882, Dice Score: 0.7406086325645447


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [21/100], Train Loss: 0.3905489146709442, Validation Loss: 0.2907753586769104
IOU: 0.6320322751998901, Dice Score: 0.770823061466217


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [22/100], Train Loss: 0.3762131631374359, Validation Loss: 0.2845476269721985
IOU: 0.615077018737793, Dice Score: 0.7542890310287476


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [23/100], Train Loss: 0.3595917224884033, Validation Loss: 0.23303738236427307
IOU: 0.6740050911903381, Dice Score: 0.8027042746543884


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [24/100], Train Loss: 0.30445563793182373, Validation Loss: 0.22547802329063416
IOU: 0.6762659549713135, Dice Score: 0.8050290942192078


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [25/100], Train Loss: 0.2860942482948303, Validation Loss: 0.39270612597465515
IOU: 0.469266414642334, Dice Score: 0.6222317218780518


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [26/100], Train Loss: 0.27139729261398315, Validation Loss: 0.2895395755767822
IOU: 0.5811893939971924, Dice Score: 0.7281773090362549


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [27/100], Train Loss: 0.28564658761024475, Validation Loss: 0.23072881996631622
IOU: 0.6532652378082275, Dice Score: 0.7866107821464539


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [28/100], Train Loss: 0.245477557182312, Validation Loss: 0.2374807894229889
IOU: 0.6424047946929932, Dice Score: 0.7781686186790466


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [29/100], Train Loss: 0.23535723984241486, Validation Loss: 0.21684755384922028
IOU: 0.6694793701171875, Dice Score: 0.7967092990875244


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [30/100], Train Loss: 0.23030538856983185, Validation Loss: 0.21404482424259186
IOU: 0.6696136593818665, Dice Score: 0.7984399795532227


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [31/100], Train Loss: 0.22832220792770386, Validation Loss: 0.20061159133911133
IOU: 0.6849263310432434, Dice Score: 0.8101657629013062


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [32/100], Train Loss: 0.21331787109375, Validation Loss: 0.22272950410842896
IOU: 0.6522389054298401, Dice Score: 0.7865728735923767


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [33/100], Train Loss: 0.2037600725889206, Validation Loss: 0.2118222713470459
IOU: 0.6668439507484436, Dice Score: 0.7972186803817749


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [34/100], Train Loss: 0.23204964399337769, Validation Loss: 0.20859995484352112
IOU: 0.6724777817726135, Dice Score: 0.7994716167449951


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [35/100], Train Loss: 0.19261759519577026, Validation Loss: 0.2080453485250473
IOU: 0.6729398369789124, Dice Score: 0.7996243834495544


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [36/100], Train Loss: 0.14927054941654205, Validation Loss: 0.21430836617946625
IOU: 0.661419689655304, Dice Score: 0.7929033637046814


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
                                                           

Epoch [37/100], Train Loss: 0.16637198626995087, Validation Loss: 0.18912415206432343
IOU: 0.6964561343193054, Dice Score: 0.8174368143081665


  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
  return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)
Epoch 38/100:  54%|█████▍    | 41/76 [00:24<00:19,  1.75it/s]

In [None]:
!nvidia-smi

In [None]:
"""
# Declare an augmentation pipeline
transform = A.Compose([
    A.RandomCrop(height = 200, width=200, p=1)
])

image_path = "/notebooks/image_segmentation/network/image_data/train/images/patient_116.png"
mask_path = "/notebooks/image_segmentation/network/image_data/train/masks/segmentation_116.png" 

image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path)

transformed = transform(image=image, mask=mask)
transformed_image = transformed['image']
transformed_mask = transformed['mask']

image_torch, mask_torch = convert_to_torch(image, mask)
plot_image_and_mask(image_torch, mask_torch)
    
image_torch, mask_torch = convert_to_torch(transformed_image, transformed_mask)
plot_image_and_mask(image_torch, mask_torch)
"""

In [7]:
def plot_image_and_mask(image, mask): 
    image = image/ 255.0
    mask = mask

    # Plot side by side with the mask and mask overlain
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Plot the original image
    axes[0].imshow(image.permute(1, 2, 0))
    axes[0].set_title('Original Image')

    # Plot the mask
    axes[1].imshow(mask.permute(1, 2, 0), cmap='viridis')
    axes[1].set_title('Mask')

    # Overlay the mask on the image
    axes[2].imshow(image.permute(1, 2, 0))
    axes[2].imshow(mask.permute(1, 2, 0), cmap='viridis', alpha=0.6)  # Set alpha to less than 1
    axes[2].set_title('Mask Overlain on Image')

    # Display the plots
    plt.show()

def convert_to_torch(image, mask):
    
    tensor_image = torch.from_numpy(image)
    tensor_image = tensor_image.permute(2,0, 1)

    tensor_mask = torch.from_numpy(mask)
    tensor_mask = tensor_mask.permute(2,0, 1) / 255
    
    tensor_mask = tensor_mask[2:, :, :]

    # add padding
    pad_height = max(992 - tensor_image.size(1), 0)
    pad_width = max(416 - tensor_image.size(2), 0)

    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left

    padded_image = transforms.functional.pad(tensor_image, (pad_left, pad_bottom, pad_right, pad_top), fill=255)
    padded_mask = transforms.functional.pad(tensor_mask, (pad_left, pad_bottom, pad_right, pad_top), fill=0)


    return torch.tensor(padded_image, dtype=torch.float32), torch.tensor(padded_mask, dtype=torch.float32)

In [8]:
image_path = "/notebooks/image_segmentation/network/image_data/train/images/patient_116.png"
mask_path = "/notebooks/image_segmentation/network/image_data/train/masks/segmentation_116.png" 
model_path = "/notebooks/Testing/CV-131/MODEL-UnetPlusPlusresnet34CV-131_FINAL.pth" 

model = smp.UnetPlusPlus(
    encoder_name=ENCODER_NAME,           # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    in_channels=3,                        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                            # model output channels (number of classes in your dataset)
)
model.load_state_dict(torch.load(model_path))
model.eval()

def predict_image(model, image, mask): 

    if type(image) == str: 
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path)
    
        image, mask = convert_to_torch(image, mask)
    
    out = model(image.unsqueeze(0)) 

    segment_map = torch.sigmoid(out).squeeze(0).detach()
    segment_map = (segment_map > 0.5) * 1
    
    plot_image_and_mask(image, mask)
    plot_image_and_mask(image, segment_map)
    
#predict_image(model, image_path, mask_path)

RuntimeError: Error(s) in loading state_dict for UnetPlusPlus:
	Missing key(s) in state_dict: "encoder.stem.conv.weight", "encoder.stem.bn.weight", "encoder.stem.bn.bias", "encoder.stem.bn.running_mean", "encoder.stem.bn.running_var", "encoder.s1.b1.conv1.conv.weight", "encoder.s1.b1.conv1.bn.weight", "encoder.s1.b1.conv1.bn.bias", "encoder.s1.b1.conv1.bn.running_mean", "encoder.s1.b1.conv1.bn.running_var", "encoder.s1.b1.conv2.conv.weight", "encoder.s1.b1.conv2.bn.weight", "encoder.s1.b1.conv2.bn.bias", "encoder.s1.b1.conv2.bn.running_mean", "encoder.s1.b1.conv2.bn.running_var", "encoder.s1.b1.conv3.conv.weight", "encoder.s1.b1.conv3.bn.weight", "encoder.s1.b1.conv3.bn.bias", "encoder.s1.b1.conv3.bn.running_mean", "encoder.s1.b1.conv3.bn.running_var", "encoder.s1.b1.downsample.conv.weight", "encoder.s1.b1.downsample.bn.weight", "encoder.s1.b1.downsample.bn.bias", "encoder.s1.b1.downsample.bn.running_mean", "encoder.s1.b1.downsample.bn.running_var", "encoder.s1.b2.conv1.conv.weight", "encoder.s1.b2.conv1.bn.weight", "encoder.s1.b2.conv1.bn.bias", "encoder.s1.b2.conv1.bn.running_mean", "encoder.s1.b2.conv1.bn.running_var", "encoder.s1.b2.conv2.conv.weight", "encoder.s1.b2.conv2.bn.weight", "encoder.s1.b2.conv2.bn.bias", "encoder.s1.b2.conv2.bn.running_mean", "encoder.s1.b2.conv2.bn.running_var", "encoder.s1.b2.conv3.conv.weight", "encoder.s1.b2.conv3.bn.weight", "encoder.s1.b2.conv3.bn.bias", "encoder.s1.b2.conv3.bn.running_mean", "encoder.s1.b2.conv3.bn.running_var", "encoder.s2.b1.conv1.conv.weight", "encoder.s2.b1.conv1.bn.weight", "encoder.s2.b1.conv1.bn.bias", "encoder.s2.b1.conv1.bn.running_mean", "encoder.s2.b1.conv1.bn.running_var", "encoder.s2.b1.conv2.conv.weight", "encoder.s2.b1.conv2.bn.weight", "encoder.s2.b1.conv2.bn.bias", "encoder.s2.b1.conv2.bn.running_mean", "encoder.s2.b1.conv2.bn.running_var", "encoder.s2.b1.conv3.conv.weight", "encoder.s2.b1.conv3.bn.weight", "encoder.s2.b1.conv3.bn.bias", "encoder.s2.b1.conv3.bn.running_mean", "encoder.s2.b1.conv3.bn.running_var", "encoder.s2.b1.downsample.conv.weight", "encoder.s2.b1.downsample.bn.weight", "encoder.s2.b1.downsample.bn.bias", "encoder.s2.b1.downsample.bn.running_mean", "encoder.s2.b1.downsample.bn.running_var", "encoder.s2.b2.conv1.conv.weight", "encoder.s2.b2.conv1.bn.weight", "encoder.s2.b2.conv1.bn.bias", "encoder.s2.b2.conv1.bn.running_mean", "encoder.s2.b2.conv1.bn.running_var", "encoder.s2.b2.conv2.conv.weight", "encoder.s2.b2.conv2.bn.weight", "encoder.s2.b2.conv2.bn.bias", "encoder.s2.b2.conv2.bn.running_mean", "encoder.s2.b2.conv2.bn.running_var", "encoder.s2.b2.conv3.conv.weight", "encoder.s2.b2.conv3.bn.weight", "encoder.s2.b2.conv3.bn.bias", "encoder.s2.b2.conv3.bn.running_mean", "encoder.s2.b2.conv3.bn.running_var", "encoder.s2.b3.conv1.conv.weight", "encoder.s2.b3.conv1.bn.weight", "encoder.s2.b3.conv1.bn.bias", "encoder.s2.b3.conv1.bn.running_mean", "encoder.s2.b3.conv1.bn.running_var", "encoder.s2.b3.conv2.conv.weight", "encoder.s2.b3.conv2.bn.weight", "encoder.s2.b3.conv2.bn.bias", "encoder.s2.b3.conv2.bn.running_mean", "encoder.s2.b3.conv2.bn.running_var", "encoder.s2.b3.conv3.conv.weight", "encoder.s2.b3.conv3.bn.weight", "encoder.s2.b3.conv3.bn.bias", "encoder.s2.b3.conv3.bn.running_mean", "encoder.s2.b3.conv3.bn.running_var", "encoder.s2.b4.conv1.conv.weight", "encoder.s2.b4.conv1.bn.weight", "encoder.s2.b4.conv1.bn.bias", "encoder.s2.b4.conv1.bn.running_mean", "encoder.s2.b4.conv1.bn.running_var", "encoder.s2.b4.conv2.conv.weight", "encoder.s2.b4.conv2.bn.weight", "encoder.s2.b4.conv2.bn.bias", "encoder.s2.b4.conv2.bn.running_mean", "encoder.s2.b4.conv2.bn.running_var", "encoder.s2.b4.conv3.conv.weight", "encoder.s2.b4.conv3.bn.weight", "encoder.s2.b4.conv3.bn.bias", "encoder.s2.b4.conv3.bn.running_mean", "encoder.s2.b4.conv3.bn.running_var", "encoder.s2.b5.conv1.conv.weight", "encoder.s2.b5.conv1.bn.weight", "encoder.s2.b5.conv1.bn.bias", "encoder.s2.b5.conv1.bn.running_mean", "encoder.s2.b5.conv1.bn.running_var", "encoder.s2.b5.conv2.conv.weight", "encoder.s2.b5.conv2.bn.weight", "encoder.s2.b5.conv2.bn.bias", "encoder.s2.b5.conv2.bn.running_mean", "encoder.s2.b5.conv2.bn.running_var", "encoder.s2.b5.conv3.conv.weight", "encoder.s2.b5.conv3.bn.weight", "encoder.s2.b5.conv3.bn.bias", "encoder.s2.b5.conv3.bn.running_mean", "encoder.s2.b5.conv3.bn.running_var", "encoder.s3.b1.conv1.conv.weight", "encoder.s3.b1.conv1.bn.weight", "encoder.s3.b1.conv1.bn.bias", "encoder.s3.b1.conv1.bn.running_mean", "encoder.s3.b1.conv1.bn.running_var", "encoder.s3.b1.conv2.conv.weight", "encoder.s3.b1.conv2.bn.weight", "encoder.s3.b1.conv2.bn.bias", "encoder.s3.b1.conv2.bn.running_mean", "encoder.s3.b1.conv2.bn.running_var", "encoder.s3.b1.conv3.conv.weight", "encoder.s3.b1.conv3.bn.weight", "encoder.s3.b1.conv3.bn.bias", "encoder.s3.b1.conv3.bn.running_mean", "encoder.s3.b1.conv3.bn.running_var", "encoder.s3.b1.downsample.conv.weight", "encoder.s3.b1.downsample.bn.weight", "encoder.s3.b1.downsample.bn.bias", "encoder.s3.b1.downsample.bn.running_mean", "encoder.s3.b1.downsample.bn.running_var", "encoder.s3.b2.conv1.conv.weight", "encoder.s3.b2.conv1.bn.weight", "encoder.s3.b2.conv1.bn.bias", "encoder.s3.b2.conv1.bn.running_mean", "encoder.s3.b2.conv1.bn.running_var", "encoder.s3.b2.conv2.conv.weight", "encoder.s3.b2.conv2.bn.weight", "encoder.s3.b2.conv2.bn.bias", "encoder.s3.b2.conv2.bn.running_mean", "encoder.s3.b2.conv2.bn.running_var", "encoder.s3.b2.conv3.conv.weight", "encoder.s3.b2.conv3.bn.weight", "encoder.s3.b2.conv3.bn.bias", "encoder.s3.b2.conv3.bn.running_mean", "encoder.s3.b2.conv3.bn.running_var", "encoder.s3.b3.conv1.conv.weight", "encoder.s3.b3.conv1.bn.weight", "encoder.s3.b3.conv1.bn.bias", "encoder.s3.b3.conv1.bn.running_mean", "encoder.s3.b3.conv1.bn.running_var", "encoder.s3.b3.conv2.conv.weight", "encoder.s3.b3.conv2.bn.weight", "encoder.s3.b3.conv2.bn.bias", "encoder.s3.b3.conv2.bn.running_mean", "encoder.s3.b3.conv2.bn.running_var", "encoder.s3.b3.conv3.conv.weight", "encoder.s3.b3.conv3.bn.weight", "encoder.s3.b3.conv3.bn.bias", "encoder.s3.b3.conv3.bn.running_mean", "encoder.s3.b3.conv3.bn.running_var", "encoder.s3.b4.conv1.conv.weight", "encoder.s3.b4.conv1.bn.weight", "encoder.s3.b4.conv1.bn.bias", "encoder.s3.b4.conv1.bn.running_mean", "encoder.s3.b4.conv1.bn.running_var", "encoder.s3.b4.conv2.conv.weight", "encoder.s3.b4.conv2.bn.weight", "encoder.s3.b4.conv2.bn.bias", "encoder.s3.b4.conv2.bn.running_mean", "encoder.s3.b4.conv2.bn.running_var", "encoder.s3.b4.conv3.conv.weight", "encoder.s3.b4.conv3.bn.weight", "encoder.s3.b4.conv3.bn.bias", "encoder.s3.b4.conv3.bn.running_mean", "encoder.s3.b4.conv3.bn.running_var", "encoder.s3.b5.conv1.conv.weight", "encoder.s3.b5.conv1.bn.weight", "encoder.s3.b5.conv1.bn.bias", "encoder.s3.b5.conv1.bn.running_mean", "encoder.s3.b5.conv1.bn.running_var", "encoder.s3.b5.conv2.conv.weight", "encoder.s3.b5.conv2.bn.weight", "encoder.s3.b5.conv2.bn.bias", "encoder.s3.b5.conv2.bn.running_mean", "encoder.s3.b5.conv2.bn.running_var", "encoder.s3.b5.conv3.conv.weight", "encoder.s3.b5.conv3.bn.weight", "encoder.s3.b5.conv3.bn.bias", "encoder.s3.b5.conv3.bn.running_mean", "encoder.s3.b5.conv3.bn.running_var", "encoder.s3.b6.conv1.conv.weight", "encoder.s3.b6.conv1.bn.weight", "encoder.s3.b6.conv1.bn.bias", "encoder.s3.b6.conv1.bn.running_mean", "encoder.s3.b6.conv1.bn.running_var", "encoder.s3.b6.conv2.conv.weight", "encoder.s3.b6.conv2.bn.weight", "encoder.s3.b6.conv2.bn.bias", "encoder.s3.b6.conv2.bn.running_mean", "encoder.s3.b6.conv2.bn.running_var", "encoder.s3.b6.conv3.conv.weight", "encoder.s3.b6.conv3.bn.weight", "encoder.s3.b6.conv3.bn.bias", "encoder.s3.b6.conv3.bn.running_mean", "encoder.s3.b6.conv3.bn.running_var", "encoder.s3.b7.conv1.conv.weight", "encoder.s3.b7.conv1.bn.weight", "encoder.s3.b7.conv1.bn.bias", "encoder.s3.b7.conv1.bn.running_mean", "encoder.s3.b7.conv1.bn.running_var", "encoder.s3.b7.conv2.conv.weight", "encoder.s3.b7.conv2.bn.weight", "encoder.s3.b7.conv2.bn.bias", "encoder.s3.b7.conv2.bn.running_mean", "encoder.s3.b7.conv2.bn.running_var", "encoder.s3.b7.conv3.conv.weight", "encoder.s3.b7.conv3.bn.weight", "encoder.s3.b7.conv3.bn.bias", "encoder.s3.b7.conv3.bn.running_mean", "encoder.s3.b7.conv3.bn.running_var", "encoder.s3.b8.conv1.conv.weight", "encoder.s3.b8.conv1.bn.weight", "encoder.s3.b8.conv1.bn.bias", "encoder.s3.b8.conv1.bn.running_mean", "encoder.s3.b8.conv1.bn.running_var", "encoder.s3.b8.conv2.conv.weight", "encoder.s3.b8.conv2.bn.weight", "encoder.s3.b8.conv2.bn.bias", "encoder.s3.b8.conv2.bn.running_mean", "encoder.s3.b8.conv2.bn.running_var", "encoder.s3.b8.conv3.conv.weight", "encoder.s3.b8.conv3.bn.weight", "encoder.s3.b8.conv3.bn.bias", "encoder.s3.b8.conv3.bn.running_mean", "encoder.s3.b8.conv3.bn.running_var", "encoder.s3.b9.conv1.conv.weight", "encoder.s3.b9.conv1.bn.weight", "encoder.s3.b9.conv1.bn.bias", "encoder.s3.b9.conv1.bn.running_mean", "encoder.s3.b9.conv1.bn.running_var", "encoder.s3.b9.conv2.conv.weight", "encoder.s3.b9.conv2.bn.weight", "encoder.s3.b9.conv2.bn.bias", "encoder.s3.b9.conv2.bn.running_mean", "encoder.s3.b9.conv2.bn.running_var", "encoder.s3.b9.conv3.conv.weight", "encoder.s3.b9.conv3.bn.weight", "encoder.s3.b9.conv3.bn.bias", "encoder.s3.b9.conv3.bn.running_mean", "encoder.s3.b9.conv3.bn.running_var", "encoder.s3.b10.conv1.conv.weight", "encoder.s3.b10.conv1.bn.weight", "encoder.s3.b10.conv1.bn.bias", "encoder.s3.b10.conv1.bn.running_mean", "encoder.s3.b10.conv1.bn.running_var", "encoder.s3.b10.conv2.conv.weight", "encoder.s3.b10.conv2.bn.weight", "encoder.s3.b10.conv2.bn.bias", "encoder.s3.b10.conv2.bn.running_mean", "encoder.s3.b10.conv2.bn.running_var", "encoder.s3.b10.conv3.conv.weight", "encoder.s3.b10.conv3.bn.weight", "encoder.s3.b10.conv3.bn.bias", "encoder.s3.b10.conv3.bn.running_mean", "encoder.s3.b10.conv3.bn.running_var", "encoder.s3.b11.conv1.conv.weight", "encoder.s3.b11.conv1.bn.weight", "encoder.s3.b11.conv1.bn.bias", "encoder.s3.b11.conv1.bn.running_mean", "encoder.s3.b11.conv1.bn.running_var", "encoder.s3.b11.conv2.conv.weight", "encoder.s3.b11.conv2.bn.weight", "encoder.s3.b11.conv2.bn.bias", "encoder.s3.b11.conv2.bn.running_mean", "encoder.s3.b11.conv2.bn.running_var", "encoder.s3.b11.conv3.conv.weight", "encoder.s3.b11.conv3.bn.weight", "encoder.s3.b11.conv3.bn.bias", "encoder.s3.b11.conv3.bn.running_mean", "encoder.s3.b11.conv3.bn.running_var", "encoder.s3.b12.conv1.conv.weight", "encoder.s3.b12.conv1.bn.weight", "encoder.s3.b12.conv1.bn.bias", "encoder.s3.b12.conv1.bn.running_mean", "encoder.s3.b12.conv1.bn.running_var", "encoder.s3.b12.conv2.conv.weight", "encoder.s3.b12.conv2.bn.weight", "encoder.s3.b12.conv2.bn.bias", "encoder.s3.b12.conv2.bn.running_mean", "encoder.s3.b12.conv2.bn.running_var", "encoder.s3.b12.conv3.conv.weight", "encoder.s3.b12.conv3.bn.weight", "encoder.s3.b12.conv3.bn.bias", "encoder.s3.b12.conv3.bn.running_mean", "encoder.s3.b12.conv3.bn.running_var", "encoder.s3.b13.conv1.conv.weight", "encoder.s3.b13.conv1.bn.weight", "encoder.s3.b13.conv1.bn.bias", "encoder.s3.b13.conv1.bn.running_mean", "encoder.s3.b13.conv1.bn.running_var", "encoder.s3.b13.conv2.conv.weight", "encoder.s3.b13.conv2.bn.weight", "encoder.s3.b13.conv2.bn.bias", "encoder.s3.b13.conv2.bn.running_mean", "encoder.s3.b13.conv2.bn.running_var", "encoder.s3.b13.conv3.conv.weight", "encoder.s3.b13.conv3.bn.weight", "encoder.s3.b13.conv3.bn.bias", "encoder.s3.b13.conv3.bn.running_mean", "encoder.s3.b13.conv3.bn.running_var", "encoder.s3.b14.conv1.conv.weight", "encoder.s3.b14.conv1.bn.weight", "encoder.s3.b14.conv1.bn.bias", "encoder.s3.b14.conv1.bn.running_mean", "encoder.s3.b14.conv1.bn.running_var", "encoder.s3.b14.conv2.conv.weight", "encoder.s3.b14.conv2.bn.weight", "encoder.s3.b14.conv2.bn.bias", "encoder.s3.b14.conv2.bn.running_mean", "encoder.s3.b14.conv2.bn.running_var", "encoder.s3.b14.conv3.conv.weight", "encoder.s3.b14.conv3.bn.weight", "encoder.s3.b14.conv3.bn.bias", "encoder.s3.b14.conv3.bn.running_mean", "encoder.s3.b14.conv3.bn.running_var", "encoder.s4.b1.conv1.conv.weight", "encoder.s4.b1.conv1.bn.weight", "encoder.s4.b1.conv1.bn.bias", "encoder.s4.b1.conv1.bn.running_mean", "encoder.s4.b1.conv1.bn.running_var", "encoder.s4.b1.conv2.conv.weight", "encoder.s4.b1.conv2.bn.weight", "encoder.s4.b1.conv2.bn.bias", "encoder.s4.b1.conv2.bn.running_mean", "encoder.s4.b1.conv2.bn.running_var", "encoder.s4.b1.conv3.conv.weight", "encoder.s4.b1.conv3.bn.weight", "encoder.s4.b1.conv3.bn.bias", "encoder.s4.b1.conv3.bn.running_mean", "encoder.s4.b1.conv3.bn.running_var", "encoder.s4.b1.downsample.conv.weight", "encoder.s4.b1.downsample.bn.weight", "encoder.s4.b1.downsample.bn.bias", "encoder.s4.b1.downsample.bn.running_mean", "encoder.s4.b1.downsample.bn.running_var", "encoder.s4.b2.conv1.conv.weight", "encoder.s4.b2.conv1.bn.weight", "encoder.s4.b2.conv1.bn.bias", "encoder.s4.b2.conv1.bn.running_mean", "encoder.s4.b2.conv1.bn.running_var", "encoder.s4.b2.conv2.conv.weight", "encoder.s4.b2.conv2.bn.weight", "encoder.s4.b2.conv2.bn.bias", "encoder.s4.b2.conv2.bn.running_mean", "encoder.s4.b2.conv2.bn.running_var", "encoder.s4.b2.conv3.conv.weight", "encoder.s4.b2.conv3.bn.weight", "encoder.s4.b2.conv3.bn.bias", "encoder.s4.b2.conv3.bn.running_mean", "encoder.s4.b2.conv3.bn.running_var". 
	Unexpected key(s) in state_dict: "encoder.conv1.weight", "encoder.bn1.weight", "encoder.bn1.bias", "encoder.bn1.running_mean", "encoder.bn1.running_var", "encoder.bn1.num_batches_tracked", "encoder.layer1.0.conv1.weight", "encoder.layer1.0.bn1.weight", "encoder.layer1.0.bn1.bias", "encoder.layer1.0.bn1.running_mean", "encoder.layer1.0.bn1.running_var", "encoder.layer1.0.bn1.num_batches_tracked", "encoder.layer1.0.conv2.weight", "encoder.layer1.0.bn2.weight", "encoder.layer1.0.bn2.bias", "encoder.layer1.0.bn2.running_mean", "encoder.layer1.0.bn2.running_var", "encoder.layer1.0.bn2.num_batches_tracked", "encoder.layer1.1.conv1.weight", "encoder.layer1.1.bn1.weight", "encoder.layer1.1.bn1.bias", "encoder.layer1.1.bn1.running_mean", "encoder.layer1.1.bn1.running_var", "encoder.layer1.1.bn1.num_batches_tracked", "encoder.layer1.1.conv2.weight", "encoder.layer1.1.bn2.weight", "encoder.layer1.1.bn2.bias", "encoder.layer1.1.bn2.running_mean", "encoder.layer1.1.bn2.running_var", "encoder.layer1.1.bn2.num_batches_tracked", "encoder.layer1.2.conv1.weight", "encoder.layer1.2.bn1.weight", "encoder.layer1.2.bn1.bias", "encoder.layer1.2.bn1.running_mean", "encoder.layer1.2.bn1.running_var", "encoder.layer1.2.bn1.num_batches_tracked", "encoder.layer1.2.conv2.weight", "encoder.layer1.2.bn2.weight", "encoder.layer1.2.bn2.bias", "encoder.layer1.2.bn2.running_mean", "encoder.layer1.2.bn2.running_var", "encoder.layer1.2.bn2.num_batches_tracked", "encoder.layer2.0.conv1.weight", "encoder.layer2.0.bn1.weight", "encoder.layer2.0.bn1.bias", "encoder.layer2.0.bn1.running_mean", "encoder.layer2.0.bn1.running_var", "encoder.layer2.0.bn1.num_batches_tracked", "encoder.layer2.0.conv2.weight", "encoder.layer2.0.bn2.weight", "encoder.layer2.0.bn2.bias", "encoder.layer2.0.bn2.running_mean", "encoder.layer2.0.bn2.running_var", "encoder.layer2.0.bn2.num_batches_tracked", "encoder.layer2.0.downsample.0.weight", "encoder.layer2.0.downsample.1.weight", "encoder.layer2.0.downsample.1.bias", "encoder.layer2.0.downsample.1.running_mean", "encoder.layer2.0.downsample.1.running_var", "encoder.layer2.0.downsample.1.num_batches_tracked", "encoder.layer2.1.conv1.weight", "encoder.layer2.1.bn1.weight", "encoder.layer2.1.bn1.bias", "encoder.layer2.1.bn1.running_mean", "encoder.layer2.1.bn1.running_var", "encoder.layer2.1.bn1.num_batches_tracked", "encoder.layer2.1.conv2.weight", "encoder.layer2.1.bn2.weight", "encoder.layer2.1.bn2.bias", "encoder.layer2.1.bn2.running_mean", "encoder.layer2.1.bn2.running_var", "encoder.layer2.1.bn2.num_batches_tracked", "encoder.layer2.2.conv1.weight", "encoder.layer2.2.bn1.weight", "encoder.layer2.2.bn1.bias", "encoder.layer2.2.bn1.running_mean", "encoder.layer2.2.bn1.running_var", "encoder.layer2.2.bn1.num_batches_tracked", "encoder.layer2.2.conv2.weight", "encoder.layer2.2.bn2.weight", "encoder.layer2.2.bn2.bias", "encoder.layer2.2.bn2.running_mean", "encoder.layer2.2.bn2.running_var", "encoder.layer2.2.bn2.num_batches_tracked", "encoder.layer2.3.conv1.weight", "encoder.layer2.3.bn1.weight", "encoder.layer2.3.bn1.bias", "encoder.layer2.3.bn1.running_mean", "encoder.layer2.3.bn1.running_var", "encoder.layer2.3.bn1.num_batches_tracked", "encoder.layer2.3.conv2.weight", "encoder.layer2.3.bn2.weight", "encoder.layer2.3.bn2.bias", "encoder.layer2.3.bn2.running_mean", "encoder.layer2.3.bn2.running_var", "encoder.layer2.3.bn2.num_batches_tracked", "encoder.layer3.0.conv1.weight", "encoder.layer3.0.bn1.weight", "encoder.layer3.0.bn1.bias", "encoder.layer3.0.bn1.running_mean", "encoder.layer3.0.bn1.running_var", "encoder.layer3.0.bn1.num_batches_tracked", "encoder.layer3.0.conv2.weight", "encoder.layer3.0.bn2.weight", "encoder.layer3.0.bn2.bias", "encoder.layer3.0.bn2.running_mean", "encoder.layer3.0.bn2.running_var", "encoder.layer3.0.bn2.num_batches_tracked", "encoder.layer3.0.downsample.0.weight", "encoder.layer3.0.downsample.1.weight", "encoder.layer3.0.downsample.1.bias", "encoder.layer3.0.downsample.1.running_mean", "encoder.layer3.0.downsample.1.running_var", "encoder.layer3.0.downsample.1.num_batches_tracked", "encoder.layer3.1.conv1.weight", "encoder.layer3.1.bn1.weight", "encoder.layer3.1.bn1.bias", "encoder.layer3.1.bn1.running_mean", "encoder.layer3.1.bn1.running_var", "encoder.layer3.1.bn1.num_batches_tracked", "encoder.layer3.1.conv2.weight", "encoder.layer3.1.bn2.weight", "encoder.layer3.1.bn2.bias", "encoder.layer3.1.bn2.running_mean", "encoder.layer3.1.bn2.running_var", "encoder.layer3.1.bn2.num_batches_tracked", "encoder.layer3.2.conv1.weight", "encoder.layer3.2.bn1.weight", "encoder.layer3.2.bn1.bias", "encoder.layer3.2.bn1.running_mean", "encoder.layer3.2.bn1.running_var", "encoder.layer3.2.bn1.num_batches_tracked", "encoder.layer3.2.conv2.weight", "encoder.layer3.2.bn2.weight", "encoder.layer3.2.bn2.bias", "encoder.layer3.2.bn2.running_mean", "encoder.layer3.2.bn2.running_var", "encoder.layer3.2.bn2.num_batches_tracked", "encoder.layer3.3.conv1.weight", "encoder.layer3.3.bn1.weight", "encoder.layer3.3.bn1.bias", "encoder.layer3.3.bn1.running_mean", "encoder.layer3.3.bn1.running_var", "encoder.layer3.3.bn1.num_batches_tracked", "encoder.layer3.3.conv2.weight", "encoder.layer3.3.bn2.weight", "encoder.layer3.3.bn2.bias", "encoder.layer3.3.bn2.running_mean", "encoder.layer3.3.bn2.running_var", "encoder.layer3.3.bn2.num_batches_tracked", "encoder.layer3.4.conv1.weight", "encoder.layer3.4.bn1.weight", "encoder.layer3.4.bn1.bias", "encoder.layer3.4.bn1.running_mean", "encoder.layer3.4.bn1.running_var", "encoder.layer3.4.bn1.num_batches_tracked", "encoder.layer3.4.conv2.weight", "encoder.layer3.4.bn2.weight", "encoder.layer3.4.bn2.bias", "encoder.layer3.4.bn2.running_mean", "encoder.layer3.4.bn2.running_var", "encoder.layer3.4.bn2.num_batches_tracked", "encoder.layer3.5.conv1.weight", "encoder.layer3.5.bn1.weight", "encoder.layer3.5.bn1.bias", "encoder.layer3.5.bn1.running_mean", "encoder.layer3.5.bn1.running_var", "encoder.layer3.5.bn1.num_batches_tracked", "encoder.layer3.5.conv2.weight", "encoder.layer3.5.bn2.weight", "encoder.layer3.5.bn2.bias", "encoder.layer3.5.bn2.running_mean", "encoder.layer3.5.bn2.running_var", "encoder.layer3.5.bn2.num_batches_tracked", "encoder.layer4.0.conv1.weight", "encoder.layer4.0.bn1.weight", "encoder.layer4.0.bn1.bias", "encoder.layer4.0.bn1.running_mean", "encoder.layer4.0.bn1.running_var", "encoder.layer4.0.bn1.num_batches_tracked", "encoder.layer4.0.conv2.weight", "encoder.layer4.0.bn2.weight", "encoder.layer4.0.bn2.bias", "encoder.layer4.0.bn2.running_mean", "encoder.layer4.0.bn2.running_var", "encoder.layer4.0.bn2.num_batches_tracked", "encoder.layer4.0.downsample.0.weight", "encoder.layer4.0.downsample.1.weight", "encoder.layer4.0.downsample.1.bias", "encoder.layer4.0.downsample.1.running_mean", "encoder.layer4.0.downsample.1.running_var", "encoder.layer4.0.downsample.1.num_batches_tracked", "encoder.layer4.1.conv1.weight", "encoder.layer4.1.bn1.weight", "encoder.layer4.1.bn1.bias", "encoder.layer4.1.bn1.running_mean", "encoder.layer4.1.bn1.running_var", "encoder.layer4.1.bn1.num_batches_tracked", "encoder.layer4.1.conv2.weight", "encoder.layer4.1.bn2.weight", "encoder.layer4.1.bn2.bias", "encoder.layer4.1.bn2.running_mean", "encoder.layer4.1.bn2.running_var", "encoder.layer4.1.bn2.num_batches_tracked", "encoder.layer4.2.conv1.weight", "encoder.layer4.2.bn1.weight", "encoder.layer4.2.bn1.bias", "encoder.layer4.2.bn1.running_mean", "encoder.layer4.2.bn1.running_var", "encoder.layer4.2.bn1.num_batches_tracked", "encoder.layer4.2.conv2.weight", "encoder.layer4.2.bn2.weight", "encoder.layer4.2.bn2.bias", "encoder.layer4.2.bn2.running_mean", "encoder.layer4.2.bn2.running_var", "encoder.layer4.2.bn2.num_batches_tracked". 
	size mismatch for decoder.blocks.x_0_0.conv1.0.weight: copying a param with shape torch.Size([256, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1920, 3, 3]).
	size mismatch for decoder.blocks.x_0_1.conv1.0.weight: copying a param with shape torch.Size([128, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 736, 3, 3]).
	size mismatch for decoder.blocks.x_1_1.conv1.0.weight: copying a param with shape torch.Size([128, 384, 3, 3]) from checkpoint, the shape in current model is torch.Size([240, 800, 3, 3]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv1.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.0.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([240, 240, 3, 3]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_1_1.conv2.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([240]).
	size mismatch for decoder.blocks.x_0_2.conv1.0.weight: copying a param with shape torch.Size([64, 320, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 368, 3, 3]).
	size mismatch for decoder.blocks.x_1_2.conv1.0.weight: copying a param with shape torch.Size([64, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 400, 3, 3]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_1_2.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.0.weight: copying a param with shape torch.Size([64, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 320, 3, 3]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_2_2.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for decoder.blocks.x_0_3.conv1.0.weight: copying a param with shape torch.Size([32, 320, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 192, 3, 3]).
	size mismatch for decoder.blocks.x_1_3.conv1.0.weight: copying a param with shape torch.Size([64, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 176, 3, 3]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_1_3.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.0.weight: copying a param with shape torch.Size([64, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 144, 3, 3]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_2_3.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.0.weight: copying a param with shape torch.Size([64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 112, 3, 3]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.0.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for decoder.blocks.x_3_3.conv2.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).

In [None]:
count = 0 

for image, mask in DataLoader(valDataset, batch_size=1, shuffle=False, num_workers=0): 
    predict_image(model, image.squeeze(0), mask.squeeze(0))
    print("-----------------------")
    count += 1
    if count == 50:
        break