In [1]:
# import the necessary packages
import torch
import os
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [2]:
# Train/evaluate on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Page-locked host memory only speeds up host-to-GPU copies, so pin only on CUDA.
PIN_MEMORY = DEVICE == "cuda"

In [3]:
# define the number of channels in the input, number of classes,
# and number of levels in the U-Net model
# NOTE(review): these three constants are not referenced in the cells shown here,
# and the UNet built in the training cell takes 3 input channels and has 4 encoder
# levels -- confirm whether they are still needed or should drive the model.
NUM_CHANNELS = 1
NUM_CLASSES = 1
NUM_LEVELS = 3

# initialize learning rate, number of epochs to train for, and the
# batch size
INIT_LR = 0.001
NUM_EPOCHS = 200
BATCH_SIZE = 32

# define the input image dimensions
INPUT_IMAGE_WIDTH = 256
INPUT_IMAGE_HEIGHT = 256

# define threshold to filter weak predictions
THRESHOLD = 0.5

# define the path to the base output directory
BASE_OUTPUT = "output"

# define the path to the output serialized model, model training
# plot, and testing image paths (all built the same way with os.path.join)
MODEL_PATH = os.path.join(BASE_OUTPUT, "unet.pth")
PLOT_PATH = os.path.join(BASE_OUTPUT, "plot.png")
TEST_PATHS = os.path.join(BASE_OUTPUT, "test_paths.txt")

In [4]:
# Expected layout: dataset/<split>/<class>_image/<k>.jpeg with the matching mask
# at dataset/<split>/<class>_mask/<k>.jpeg, for k = 0 .. N-1.
subdir = ["benign_image/", "benign_mask/", "malignant_image/",
          "malignant_mask/", "normal_image/", "normal_mask/"]
# dataset[0] is the root directory; dataset[1:] are the splits, in the order
# their results are stored in train_test_valid.
dataset = ['dataset/', 'train/', 'test/', 'validation/']

# train_test_valid[split] = [images, masks, class_ids] with split 0=train,
# 1=test, 2=validation and class id j indexing the subdir pairs
# (0=benign, 1=malignant, 2=normal).
train_test_valid = [[[], [], []], [[], [], []], [[], [], []]]

for i, split in enumerate(dataset[1:]):
    split_dir = os.path.join(dataset[0], split)
    for j in range(3):
        image_dir = os.path.join(split_dir, subdir[2 * j])
        mask_dir = os.path.join(split_dir, subdir[2 * j + 1])
        # Count only .jpeg files: os.listdir also returns hidden files and
        # folders (e.g. .DS_Store, .ipynb_checkpoints), which would inflate the
        # count and make the index-based names below point at missing files.
        n_images = sum(1 for name in os.listdir(image_dir) if name.endswith(".jpeg"))
        for k in range(n_images):
            train_test_valid[i][0].append(plt.imread(os.path.join(image_dir, f"{k}.jpeg")))
            train_test_valid[i][1].append(plt.imread(os.path.join(mask_dir, f"{k}.jpeg")))
            train_test_valid[i][2].append(j)

# Stack each split into float32 arrays scaled by 1/255.
# NOTE(review): the masks are stored as JPEG, so compression artefacts can leave
# values strictly between 0 and 1; consider thresholding if binary masks are intended.
X_train_npy = np.asarray(train_test_valid[0][0], dtype=np.float32)/255
y_train_npy = np.asarray(train_test_valid[0][1], dtype=np.float32)/255

X_test_npy = np.asarray(train_test_valid[1][0], dtype=np.float32)/255
y_test_npy = np.asarray(train_test_valid[1][1], dtype=np.float32)/255

X_valid_npy = np.asarray(train_test_valid[2][0], dtype=np.float32)/255
y_valid_npy = np.asarray(train_test_valid[2][1], dtype=np.float32)/255

In [5]:
class SegmentationDataset(Dataset):
    """In-memory dataset that pairs each image with its segmentation mask.

    Args:
        X: indexable collection of images.
        y: indexable collection of masks; its length defines the dataset size.
        transforms: optional callable, applied separately to the image and to
            the mask; ``None`` returns the raw items.
    """

    def __init__(self, X, y, transforms):
        self.X, self.y, self.transforms = X, y, transforms

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        image, mask = self.X[idx], self.y[idx]
        transform = self.transforms
        if transform is None:
            return (image, mask)
        # The transform is called once per item, so a random augmentation would
        # not be applied identically to the image and its mask.
        return (transform(image), transform(mask))

In [6]:
# import the necessary packages
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch.nn as nn

In [7]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; any falsy value
            (``None`` or 0) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        # The convs carry no bias because each is immediately followed by
        # BatchNorm. The layer order fixes the state_dict keys
        # (double_conv.0 / .1 / .3 / .4).
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Encoder(nn.Module):
    """One contracting U-Net step: a DoubleConv followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the
    full-resolution DoubleConv output, which UNet feeds to its decoders as a
    skip connection. ``pooled`` is its half-resolution copy that feeds the next
    encoder level.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.max = nn.MaxPool2d(2)
        self.conv_block = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        features = self.conv_block(x)
        return features, self.max(features)

In [8]:
class DecoderB1(nn.Module):
    """One expanding U-Net step: 2x upsample, concatenate the skip, DoubleConv.

    The transposed convolution maps ``in_channels -> out_channels`` and doubles
    H and W. The skip tensor is then concatenated on the channel axis (dim 1),
    and the DoubleConv expects ``in_channels`` channels. So ``skip_features``
    must carry ``in_channels - out_channels`` channels and the same H and W as
    the upsampled tensor. Every call in UNet uses in_channels == 2 * out_channels.

    Args:
        in_channels: channels of the incoming (coarser) feature map.
        out_channels: channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = DoubleConv(in_channels, out_channels)

    def forward(self, x, skip_features):
        """Upsample ``x``, append ``skip_features`` on dim 1, and fuse with DoubleConv."""
        x = self.conv_transpose(x)
        x = torch.cat([x, skip_features], dim=1)
        return self.conv_block(x)


class DecoderB2(DecoderB1):
    """Expanding step used by UNet's second (VAE reconstruction) branch.

    Identical to DecoderB1: same layers and same state_dict keys. It was
    previously a line-for-line copy; the separate name is kept so existing
    references to ``DecoderB2`` keep working.
    """

In [9]:
class Branch1(nn.Module):
    """Output head: DoubleConv, a bias-free 1x1 conv to ``out_channels``, then sigmoid.

    The result is a per-pixel value in (0, 1), i.e. a probability, not a logit.

    Args:
        in_channels: channels of the decoder features fed to the head.
        mid_channels: channels produced by the DoubleConv.
        out_channels: channels of the returned map.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()

        self.conv_block = DoubleConv(in_channels, mid_channels)
        self.conv_2d = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv_block(x)
        x = self.conv_2d(x)
        return self.sigmoid(x)


class Branch2(Branch1):
    """Output head for UNet's second (reconstruction) branch.

    Identical to Branch1: same layers and same state_dict keys. It was
    previously a line-for-line copy; the separate name is kept so existing
    references to ``Branch2`` keep working.
    """

In [10]:
# n_input_channels=3
# n_features=64
# down1 = Encoder(n_input_channels, n_features)
# down2 = Encoder(n_features, n_features*2)
# down3 = Encoder(n_features*2, n_features*4)
# down4 = Encoder(n_features*4, n_features*8)
# bridge = DoubleConv(n_features*8, n_features*16)

# conv1, pool1 = down1(val1)
# conv2, pool2 = down2(pool1)
# conv3, pool3 = down3(pool2)
# conv4, pool4 = down4(pool3)

# bridge = bridge(pool4)

# bridge.shape

In [18]:
class UNet(nn.Module):
    """U-Net with two decoder heads sharing one encoder.

    Branch 1 upsamples the bridge features directly (segmentation head).
    Branch 2 first passes the flattened bridge through a Gaussian latent
    bottleneck (reparameterization trick) and then upsamples (reconstruction
    head).

    NOTE(review): branch 2 also concatenates the encoder skip features
    (conv1..conv4). The reconstruction can therefore bypass the latent code
    entirely -- confirm this is intended.

    NOTE(review): ``reparameterize`` samples in eval mode too, so branch 2
    outputs and the reported validation losses are stochastic.

    Args:
        input_shape: side length of the square input; must be divisible by 16
            because the encoder pools four times and the decoders re-concatenate
            the skip features.
        n_input_channels: channels of the input image.
        n_output_channels_b1: channels of the branch-1 (segmentation) output.
        n_output_channels_b2: channels of the branch-2 (reconstruction) output.
        n_features: channels after the first encoder level (doubled per level).
        latent_dim: size of the VAE latent vector.
    """

    def __init__(self, input_shape=256, n_input_channels=3, n_output_channels_b1=1, n_output_channels_b2=3, n_features=64, latent_dim=128):
        super(UNet, self).__init__()
        self.input_shape = input_shape
        self.n_features = n_features

        # Shared encoder: each level doubles the channels and halves H and W.
        self.down1 = Encoder(n_input_channels, n_features)
        self.down2 = Encoder(n_features, n_features*2)
        self.down3 = Encoder(n_features*2, n_features*4)
        self.down4 = Encoder(n_features*4, n_features*8)

        # Bottleneck: 16*n_features channels at input_shape/16 resolution
        # (1024 x 16 x 16 for the defaults).
        self.bridge = DoubleConv(n_features*8, n_features*16)

        # Branch 1 (segmentation) decoder.
        self.upb1_1 = DecoderB1(n_features*16, n_features*8)
        self.upb1_2 = DecoderB1(n_features*8, n_features*4)
        self.upb1_3 = DecoderB1(n_features*4, n_features*2)
        self.upb1_4 = DecoderB1(n_features*2, n_features)

        # Size of the flattened bridge (262144 = 1024*16*16 for the defaults).
        latent_hw = input_shape // 16
        flatten_size = n_features*16 * latent_hw * latent_hw

        self.intermediate1 = nn.Sequential(
            nn.Linear(flatten_size, latent_dim*2),
            nn.BatchNorm1d(latent_dim*2),
            nn.ReLU(inplace=True),
        )

        # Heads producing the mean and log-variance of the latent Gaussian.
        self.fc_mu = nn.Linear(latent_dim*2, latent_dim)
        self.fc_var = nn.Linear(latent_dim*2, latent_dim)

        # Maps a latent sample back to a flattened bridge-sized feature map.
        self.decoder_input = nn.Sequential(
            nn.Linear(latent_dim, flatten_size),
            nn.ReLU(inplace=True),
        )

        # Branch 2 (reconstruction) decoder.
        self.upb2_1 = DecoderB2(n_features*16, n_features*8)
        self.upb2_2 = DecoderB2(n_features*8, n_features*4)
        self.upb2_3 = DecoderB2(n_features*4, n_features*2)
        self.upb2_4 = DecoderB2(n_features*2, n_features)

        self.outchannel1 = Branch1(n_features, n_features, n_output_channels_b1)
        self.outchannel2 = Branch2(n_features, n_features, n_output_channels_b2)

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample z ~ N(mu, exp(logvar)) as
        mu + eps * std with eps ~ N(0, I), keeping the sample differentiable
        w.r.t. mu and logvar.
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) Latent sample [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        """Run the shared encoder and both decoder branches.

        Args:
            x: input batch; assumed (B, n_input_channels, input_shape,
                input_shape) -- the reshape below fixes the bridge to a square
                of side input_shape // 16.

        Returns:
            ``(logits1, logits2, kld_loss)``: the branch-1 and branch-2 outputs
            (already passed through a sigmoid, so probabilities rather than
            logits), and the KL divergence from N(mu, exp(log_var)) to N(0, I),
            summed over latent dimensions and averaged over the batch.
        """
        conv1, pool1 = self.down1(x)
        conv2, pool2 = self.down2(pool1)
        conv3, pool3 = self.down3(pool2)
        conv4, pool4 = self.down4(pool3)

        bridge = self.bridge(pool4)

        # 1st branch: plain U-Net decoder on the bridge features.
        decoder_b1_1 = self.upb1_1(bridge, conv4)
        decoder_b1_2 = self.upb1_2(decoder_b1_1, conv3)
        decoder_b1_3 = self.upb1_3(decoder_b1_2, conv2)
        decoder_b1_4 = self.upb1_4(decoder_b1_3, conv1)

        logits1 = self.outchannel1(decoder_b1_4)

        # 2nd branch: VAE bottleneck on the flattened bridge.
        flattened_bridge = torch.flatten(bridge, start_dim=1)
        intermediate1 = self.intermediate1(flattened_bridge)

        mu = self.fc_mu(intermediate1)
        log_var = self.fc_var(intermediate1)

        z = self.reparameterize(mu, log_var)

        latent_hw = self.input_shape // 16
        result = self.decoder_input(z)
        reshaped_result = result.view(-1, self.n_features*16, latent_hw, latent_hw)

        decoder_b2_1 = self.upb2_1(reshaped_result, conv4)
        decoder_b2_2 = self.upb2_2(decoder_b2_1, conv3)
        decoder_b2_3 = self.upb2_3(decoder_b2_2, conv2)
        decoder_b2_4 = self.upb2_4(decoder_b2_3, conv1)

        logits2 = self.outchannel2(decoder_b2_4)

        # Closed-form KL(N(mu, sigma^2) || N(0, I)) per sample, then the batch mean.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        return logits1, logits2, kld_loss


In [12]:
#custom dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss, pooled over every element of the batch.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    All samples are flattened into one vector, so this is a single global Dice
    score for the batch rather than a mean of per-sample scores.

    Args:
        weight: unused; kept for call-site compatibility.
        size_average: unused; kept for call-site compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1e-6):
        """Return ``1 - dice`` as a 0-d tensor.

        Args:
            inputs: predicted probabilities, any shape.
            targets: ground-truth mask with the same number of elements.
            smooth: additive constant; makes two empty masks score a loss of 0
                and avoids division by zero.
        """
        # reshape (not view): also accepts non-contiguous tensors (e.g. permuted
        # or transposed), for which .view(-1) raises.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [13]:
# Smoke test: DiceLoss runs on the mask arrays (as tensors) and returns a scalar.
# NOTE(review): this scores the first 32 *train* masks against the first 32
# *test* masks -- unrelated images -- so the value only shows that the loss
# executes, not how good any prediction is.
loss = DiceLoss()
loss(
    torch.from_numpy(y_train_npy[:32]),
    torch.from_numpy(y_test_npy[:32]),
)

tensor(0.8170)

In [14]:
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import time

In [22]:
# define transformations: ToTensor converts each HxW(xC) numpy array into a
# (C)xHxW tensor. Non-uint8 input is not rescaled, which is what we want because
# the arrays were already divided by 255 when they were loaded.
transforms_ = transforms.Compose([transforms.ToTensor()])

# create the train and validation datasets
trainDS = SegmentationDataset(X=X_train_npy, y=y_train_npy, transforms=transforms_)
validDS = SegmentationDataset(X=X_valid_npy, y=y_valid_npy, transforms=transforms_)

print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(validDS)} examples in the validation set...")

# Create the training and validation data loaders.
# Training keeps drop_last=True: the model contains BatchNorm1d, which raises on
# a single-sample batch in train mode (e.g. 897 % 32 == 1 in the logged run).
trainLoader = DataLoader(trainDS, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY,
                         num_workers=os.cpu_count(), shuffle=True, drop_last=True)
# Validation runs in eval mode (BatchNorm uses its running statistics), so every
# sample is scored: no shuffling and no dropped tail batch.
validLoader = DataLoader(validDS, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY,
                         num_workers=os.cpu_count(), shuffle=False, drop_last=False)

# initialize our UNet model
n_features = 64
input_shape = 256

unet = UNet(input_shape=input_shape, n_input_channels=3, n_output_channels_b1=1,
            n_output_channels_b2=3, n_features=n_features, latent_dim=128).to(DEVICE)

# initialize loss functions and optimizer
opt = Adam(unet.parameters(), lr=INIT_LR, weight_decay=1e-5)

dice = DiceLoss()
mse = nn.MSELoss()

# Number of batches each loader yields per epoch, used to average the losses.
# max(..., 1) avoids dividing by zero when a split is smaller than one batch.
trainSteps = max(len(trainLoader), 1)
validSteps = max(len(validLoader), 1)

# initialize a dictionary to store training history (per-epoch Python floats)
H = {
    "train_loss": [], "valid_loss": [],
    "train_dice_loss": [], "valid_dice_loss": [],
    "train_mse_loss": [], "valid_mse_loss": [],
    "train_vae_loss": [], "valid_vae_loss": []
}

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()

for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    unet.train()

    # Running sums of the per-batch losses, kept as Python floats.
    totalTrainLoss = 0.0
    totalValidLoss = 0.0

    totalTrainKLLoss = 0.0
    totalValidKLLoss = 0.0

    totalTrainDiceLoss = 0.0
    totalValidDiceLoss = 0.0

    totalTrainMSELoss = 0.0
    totalValidMSELoss = 0.0

    # loop over the training set
    for (x, y) in trainLoader:
        (x, y) = (x.to(DEVICE), y.to(DEVICE))

        # Branch 1 is scored against the mask (Dice), branch 2 against the input
        # image (MSE), and the KL term regularizes the VAE latent.
        pred_b1, pred_b2, kl_loss = unet(x)

        loss1 = dice(pred_b1, y)
        loss2 = mse(pred_b2.view(-1), x.view(-1))
        loss = loss1 + loss2 + kl_loss

        # zero old gradients, backpropagate, update parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        # .item() turns each loss into a plain float. Summing the tensors
        # themselves would chain every batch's autograd history into the
        # running totals.
        totalTrainDiceLoss += loss1.item()
        totalTrainMSELoss += loss2.item()
        totalTrainKLLoss += kl_loss.item()
        totalTrainLoss += loss.item()

    # switch off autograd and evaluate
    with torch.no_grad():
        unet.eval()

        for (x, y) in validLoader:
            (x, y) = (x.to(DEVICE), y.to(DEVICE))

            pred_b1, pred_b2, valid_kl_loss = unet(x)

            validloss1 = dice(pred_b1, y)
            validloss2 = mse(pred_b2.view(-1), x.view(-1))
            validloss = validloss1 + validloss2 + valid_kl_loss

            totalValidDiceLoss += validloss1.item()
            totalValidMSELoss += validloss2.item()
            totalValidKLLoss += valid_kl_loss.item()
            totalValidLoss += validloss.item()

    # average the per-batch losses over the number of batches in each loader
    avgTrainDiceLoss = totalTrainDiceLoss / trainSteps
    avgValidDiceLoss = totalValidDiceLoss / validSteps

    avgTrainMSELoss = totalTrainMSELoss / trainSteps
    avgValidMSELoss = totalValidMSELoss / validSteps

    avgTrainKLLoss = totalTrainKLLoss / trainSteps
    avgValidKLLoss = totalValidKLLoss / validSteps

    avgTrainLoss = totalTrainLoss / trainSteps
    avgValidLoss = totalValidLoss / validSteps

    # update our training history
    H["train_loss"].append(avgTrainLoss)
    H["valid_loss"].append(avgValidLoss)

    H["train_dice_loss"].append(avgTrainDiceLoss)
    H["valid_dice_loss"].append(avgValidDiceLoss)

    # The "vae" history holds the KL term, matching the "VAE loss" printed below
    # (the MSE term has its own entries).
    H["train_vae_loss"].append(avgTrainKLLoss)
    H["valid_vae_loss"].append(avgValidKLLoss)

    H["train_mse_loss"].append(avgTrainMSELoss)
    H["valid_mse_loss"].append(avgValidMSELoss)

    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train Dice loss: {:.4f}, Vaild Dice loss: {:.4f}, Train MSE loss: {:.4f}, Vaild MSE loss: {:.4f},"
          " Train VAE loss: {:.4f}, Vaild VAE loss: {:.4f}, Train loss: {:.4f}, Vaild loss: {:.4f}"
          .format(avgTrainDiceLoss, avgValidDiceLoss, avgTrainMSELoss, avgValidMSELoss,
                  avgTrainKLLoss, avgValidKLLoss, avgTrainLoss, avgValidLoss))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

[INFO] found 897 examples in the training set...
[INFO] found 303 examples in the validation set...
[INFO] training the network...


  0%|          | 1/200 [00:40<2:15:07, 40.74s/it]

[INFO] EPOCH: 1/200
Train Dice loss: 0.8376, Vaild Dice loss: 0.8677, Train MSE loss: 0.0238, Vaild MSE loss: 0.0331, Train VAE loss: 2.5020, Vaild VAE loss: 0.1514, Train loss: 3.3634, Vaild loss: 1.0523


  1%|          | 2/200 [01:22<2:16:36, 41.40s/it]

[INFO] EPOCH: 2/200
Train Dice loss: 0.7975, Vaild Dice loss: 0.8119, Train MSE loss: 0.0046, Vaild MSE loss: 0.0044, Train VAE loss: 0.2519, Vaild VAE loss: 0.0378, Train loss: 1.0539, Vaild loss: 0.8541


  2%|▏         | 3/200 [02:04<2:16:29, 41.57s/it]

[INFO] EPOCH: 3/200
Train Dice loss: 0.7603, Vaild Dice loss: 0.7679, Train MSE loss: 0.0015, Vaild MSE loss: 0.0010, Train VAE loss: 0.1075, Vaild VAE loss: 0.0455, Train loss: 0.8693, Vaild loss: 0.8144


  2%|▏         | 4/200 [02:45<2:15:47, 41.57s/it]

[INFO] EPOCH: 4/200
Train Dice loss: 0.7168, Vaild Dice loss: 0.7151, Train MSE loss: 0.0008, Vaild MSE loss: 0.0012, Train VAE loss: 0.0789, Vaild VAE loss: 0.0510, Train loss: 0.7966, Vaild loss: 0.7673


  2%|▎         | 5/200 [03:27<2:15:22, 41.66s/it]

[INFO] EPOCH: 5/200
Train Dice loss: 0.6645, Vaild Dice loss: 0.6475, Train MSE loss: 0.0010, Vaild MSE loss: 0.0010, Train VAE loss: 0.0487, Vaild VAE loss: 0.0245, Train loss: 0.7142, Vaild loss: 0.6730


  3%|▎         | 6/200 [04:09<2:15:10, 41.81s/it]

[INFO] EPOCH: 6/200
Train Dice loss: 0.6086, Vaild Dice loss: 0.6397, Train MSE loss: 0.0008, Vaild MSE loss: 0.0003, Train VAE loss: 0.0288, Vaild VAE loss: 0.0189, Train loss: 0.6382, Vaild loss: 0.6589


  4%|▎         | 7/200 [04:51<2:14:31, 41.82s/it]

[INFO] EPOCH: 7/200
Train Dice loss: 0.6087, Vaild Dice loss: 0.5662, Train MSE loss: 0.0008, Vaild MSE loss: 0.0007, Train VAE loss: 0.0202, Vaild VAE loss: 0.0143, Train loss: 0.6297, Vaild loss: 0.5812


  4%|▍         | 8/200 [05:33<2:13:37, 41.76s/it]

[INFO] EPOCH: 8/200
Train Dice loss: 0.5833, Vaild Dice loss: 0.6012, Train MSE loss: 0.0005, Vaild MSE loss: 0.0004, Train VAE loss: 0.0161, Vaild VAE loss: 0.0117, Train loss: 0.5999, Vaild loss: 0.6134


  4%|▍         | 9/200 [06:15<2:12:57, 41.77s/it]

[INFO] EPOCH: 9/200
Train Dice loss: 0.5332, Vaild Dice loss: 0.5972, Train MSE loss: 0.0005, Vaild MSE loss: 0.0004, Train VAE loss: 0.0150, Vaild VAE loss: 0.0080, Train loss: 0.5487, Vaild loss: 0.6056


  5%|▌         | 10/200 [06:56<2:12:16, 41.77s/it]

[INFO] EPOCH: 10/200
Train Dice loss: 0.5088, Vaild Dice loss: 0.6063, Train MSE loss: 0.0005, Vaild MSE loss: 0.0003, Train VAE loss: 0.0122, Vaild VAE loss: 0.0112, Train loss: 0.5216, Vaild loss: 0.6178


  6%|▌         | 11/200 [07:38<2:11:41, 41.81s/it]

[INFO] EPOCH: 11/200
Train Dice loss: 0.4906, Vaild Dice loss: 0.4774, Train MSE loss: 0.0003, Vaild MSE loss: 0.0005, Train VAE loss: 0.0109, Vaild VAE loss: 0.0122, Train loss: 0.5018, Vaild loss: 0.4901


  6%|▌         | 12/200 [08:20<2:11:00, 41.81s/it]

[INFO] EPOCH: 12/200
Train Dice loss: 0.4714, Vaild Dice loss: 0.4993, Train MSE loss: 0.0005, Vaild MSE loss: 0.0003, Train VAE loss: 0.0109, Vaild VAE loss: 0.0049, Train loss: 0.4828, Vaild loss: 0.5045


  6%|▋         | 13/200 [09:03<2:11:01, 42.04s/it]

[INFO] EPOCH: 13/200
Train Dice loss: 0.4799, Vaild Dice loss: 0.5002, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0087, Vaild VAE loss: 0.0058, Train loss: 0.4890, Vaild loss: 0.5063


  7%|▋         | 14/200 [09:44<2:10:05, 41.97s/it]

[INFO] EPOCH: 14/200
Train Dice loss: 0.4799, Vaild Dice loss: 0.8001, Train MSE loss: 0.0005, Vaild MSE loss: 0.0003, Train VAE loss: 0.0056, Vaild VAE loss: 0.0078, Train loss: 0.4860, Vaild loss: 0.8081


  8%|▊         | 15/200 [10:26<2:09:02, 41.85s/it]

[INFO] EPOCH: 15/200
Train Dice loss: 0.4641, Vaild Dice loss: 0.4881, Train MSE loss: 0.0005, Vaild MSE loss: 0.0003, Train VAE loss: 0.0065, Vaild VAE loss: 0.0031, Train loss: 0.4711, Vaild loss: 0.4914


  8%|▊         | 16/200 [11:08<2:08:14, 41.82s/it]

[INFO] EPOCH: 16/200
Train Dice loss: 0.4503, Vaild Dice loss: 0.5106, Train MSE loss: 0.0003, Vaild MSE loss: 0.0003, Train VAE loss: 0.0053, Vaild VAE loss: 0.0023, Train loss: 0.4560, Vaild loss: 0.5131


  8%|▊         | 17/200 [11:51<2:08:28, 42.12s/it]

[INFO] EPOCH: 17/200
Train Dice loss: 0.4289, Vaild Dice loss: 0.4694, Train MSE loss: 0.0004, Vaild MSE loss: 0.0002, Train VAE loss: 0.0046, Vaild VAE loss: 0.0041, Train loss: 0.4339, Vaild loss: 0.4737


  9%|▉         | 18/200 [12:32<2:07:17, 41.96s/it]

[INFO] EPOCH: 18/200
Train Dice loss: 0.4180, Vaild Dice loss: 0.5799, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0041, Vaild VAE loss: 0.0020, Train loss: 0.4225, Vaild loss: 0.5823


 10%|▉         | 19/200 [13:14<2:06:22, 41.89s/it]

[INFO] EPOCH: 19/200
Train Dice loss: 0.4045, Vaild Dice loss: 0.4760, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0027, Vaild VAE loss: 0.0035, Train loss: 0.4075, Vaild loss: 0.4797


 10%|█         | 20/200 [13:56<2:05:56, 41.98s/it]

[INFO] EPOCH: 20/200
Train Dice loss: 0.3996, Vaild Dice loss: 0.7695, Train MSE loss: 0.0004, Vaild MSE loss: 0.0002, Train VAE loss: 0.0037, Vaild VAE loss: 0.0098, Train loss: 0.4037, Vaild loss: 0.7794


 10%|█         | 21/200 [14:38<2:04:58, 41.89s/it]

[INFO] EPOCH: 21/200
Train Dice loss: 0.4022, Vaild Dice loss: 0.4103, Train MSE loss: 0.0005, Vaild MSE loss: 0.0002, Train VAE loss: 0.0060, Vaild VAE loss: 0.0038, Train loss: 0.4087, Vaild loss: 0.4142


 11%|█         | 22/200 [15:20<2:04:52, 42.09s/it]

[INFO] EPOCH: 22/200
Train Dice loss: 0.3708, Vaild Dice loss: 0.5169, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0036, Vaild VAE loss: 0.0043, Train loss: 0.3748, Vaild loss: 0.5213


 12%|█▏        | 23/200 [16:02<2:03:47, 41.96s/it]

[INFO] EPOCH: 23/200
Train Dice loss: 0.3973, Vaild Dice loss: 0.4785, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0046, Vaild VAE loss: 0.0050, Train loss: 0.4022, Vaild loss: 0.4837


 12%|█▏        | 24/200 [16:44<2:03:03, 41.95s/it]

[INFO] EPOCH: 24/200
Train Dice loss: 0.3867, Vaild Dice loss: 0.6007, Train MSE loss: 0.0003, Vaild MSE loss: 0.0008, Train VAE loss: 0.0053, Vaild VAE loss: 0.0047, Train loss: 0.3923, Vaild loss: 0.6062


 12%|█▎        | 25/200 [17:27<2:03:17, 42.27s/it]

[INFO] EPOCH: 25/200
Train Dice loss: 0.3738, Vaild Dice loss: 0.4128, Train MSE loss: 0.0005, Vaild MSE loss: 0.0018, Train VAE loss: 0.0045, Vaild VAE loss: 0.0040, Train loss: 0.3788, Vaild loss: 0.4186


 13%|█▎        | 26/200 [18:09<2:01:58, 42.06s/it]

[INFO] EPOCH: 26/200
Train Dice loss: 0.3739, Vaild Dice loss: 0.3678, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0034, Vaild VAE loss: 0.0030, Train loss: 0.3777, Vaild loss: 0.3711


 14%|█▎        | 27/200 [18:51<2:01:18, 42.07s/it]

[INFO] EPOCH: 27/200
Train Dice loss: 0.3774, Vaild Dice loss: 0.3928, Train MSE loss: 0.0003, Vaild MSE loss: 0.0006, Train VAE loss: 0.0031, Vaild VAE loss: 0.0031, Train loss: 0.3808, Vaild loss: 0.3965


 14%|█▍        | 28/200 [19:32<2:00:15, 41.95s/it]

[INFO] EPOCH: 28/200
Train Dice loss: 0.3631, Vaild Dice loss: 0.3509, Train MSE loss: 0.0005, Vaild MSE loss: 0.0004, Train VAE loss: 0.0051, Vaild VAE loss: 0.0043, Train loss: 0.3687, Vaild loss: 0.3556


 14%|█▍        | 29/200 [20:14<1:59:27, 41.91s/it]

[INFO] EPOCH: 29/200
Train Dice loss: 0.3570, Vaild Dice loss: 0.3577, Train MSE loss: 0.0005, Vaild MSE loss: 0.0003, Train VAE loss: 0.0058, Vaild VAE loss: 0.0062, Train loss: 0.3633, Vaild loss: 0.3642


 15%|█▌        | 30/200 [20:56<1:58:50, 41.95s/it]

[INFO] EPOCH: 30/200
Train Dice loss: 0.3320, Vaild Dice loss: 0.3646, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0052, Vaild VAE loss: 0.0027, Train loss: 0.3376, Vaild loss: 0.3674


 16%|█▌        | 31/200 [21:38<1:57:52, 41.85s/it]

[INFO] EPOCH: 31/200
Train Dice loss: 0.3219, Vaild Dice loss: 0.3423, Train MSE loss: 0.0005, Vaild MSE loss: 0.0008, Train VAE loss: 0.0026, Vaild VAE loss: 0.0021, Train loss: 0.3250, Vaild loss: 0.3452


 16%|█▌        | 32/200 [22:19<1:57:01, 41.80s/it]

[INFO] EPOCH: 32/200
Train Dice loss: 0.3243, Vaild Dice loss: 0.3647, Train MSE loss: 0.0003, Vaild MSE loss: 0.0007, Train VAE loss: 0.0022, Vaild VAE loss: 0.0012, Train loss: 0.3268, Vaild loss: 0.3666


 16%|█▋        | 33/200 [23:01<1:56:19, 41.79s/it]

[INFO] EPOCH: 33/200
Train Dice loss: 0.3453, Vaild Dice loss: 0.4477, Train MSE loss: 0.0004, Vaild MSE loss: 0.0004, Train VAE loss: 0.0018, Vaild VAE loss: 0.0023, Train loss: 0.3474, Vaild loss: 0.4504


 17%|█▋        | 34/200 [23:43<1:55:46, 41.84s/it]

[INFO] EPOCH: 34/200
Train Dice loss: 0.3156, Vaild Dice loss: 0.6474, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0024, Vaild VAE loss: 0.0023, Train loss: 0.3183, Vaild loss: 0.6500


 18%|█▊        | 35/200 [24:25<1:54:47, 41.75s/it]

[INFO] EPOCH: 35/200
Train Dice loss: 0.3262, Vaild Dice loss: 0.3748, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0020, Vaild VAE loss: 0.0022, Train loss: 0.3285, Vaild loss: 0.3772


 18%|█▊        | 36/200 [25:07<1:54:31, 41.90s/it]

[INFO] EPOCH: 36/200
Train Dice loss: 0.3068, Vaild Dice loss: 0.3417, Train MSE loss: 0.0004, Vaild MSE loss: 0.0004, Train VAE loss: 0.0018, Vaild VAE loss: 0.0023, Train loss: 0.3089, Vaild loss: 0.3443


 18%|█▊        | 37/200 [25:50<1:54:26, 42.12s/it]

[INFO] EPOCH: 37/200
Train Dice loss: 0.3089, Vaild Dice loss: 0.4891, Train MSE loss: 0.0004, Vaild MSE loss: 0.0004, Train VAE loss: 0.0022, Vaild VAE loss: 0.0023, Train loss: 0.3114, Vaild loss: 0.4918


 19%|█▉        | 38/200 [26:31<1:53:18, 41.97s/it]

[INFO] EPOCH: 38/200
Train Dice loss: 0.2988, Vaild Dice loss: 0.3239, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0056, Vaild VAE loss: 0.0023, Train loss: 0.3048, Vaild loss: 0.3263


 20%|█▉        | 39/200 [27:13<1:52:28, 41.91s/it]

[INFO] EPOCH: 39/200
Train Dice loss: 0.3101, Vaild Dice loss: 0.4322, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0059, Vaild VAE loss: 0.0060, Train loss: 0.3164, Vaild loss: 0.4384


 20%|██        | 40/200 [27:55<1:51:40, 41.88s/it]

[INFO] EPOCH: 40/200
Train Dice loss: 0.3134, Vaild Dice loss: 0.3926, Train MSE loss: 0.0003, Vaild MSE loss: 0.0004, Train VAE loss: 0.0094, Vaild VAE loss: 0.0063, Train loss: 0.3231, Vaild loss: 0.3994


 20%|██        | 41/200 [28:37<1:50:55, 41.86s/it]

[INFO] EPOCH: 41/200
Train Dice loss: 0.3020, Vaild Dice loss: 0.4290, Train MSE loss: 0.0003, Vaild MSE loss: 0.0004, Train VAE loss: 0.0041, Vaild VAE loss: 0.0026, Train loss: 0.3064, Vaild loss: 0.4320


 21%|██        | 42/200 [29:19<1:50:28, 41.95s/it]

[INFO] EPOCH: 42/200
Train Dice loss: 0.2654, Vaild Dice loss: 0.3641, Train MSE loss: 0.0004, Vaild MSE loss: 0.0009, Train VAE loss: 0.0033, Vaild VAE loss: 0.0033, Train loss: 0.2691, Vaild loss: 0.3683


 22%|██▏       | 43/200 [30:00<1:49:28, 41.84s/it]

[INFO] EPOCH: 43/200
Train Dice loss: 0.2728, Vaild Dice loss: 0.3858, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0045, Vaild VAE loss: 0.0025, Train loss: 0.2777, Vaild loss: 0.3888


 22%|██▏       | 44/200 [30:42<1:48:39, 41.79s/it]

[INFO] EPOCH: 44/200
Train Dice loss: 0.2760, Vaild Dice loss: 0.3081, Train MSE loss: 0.0004, Vaild MSE loss: 0.0002, Train VAE loss: 0.0033, Vaild VAE loss: 0.0016, Train loss: 0.2797, Vaild loss: 0.3099


 22%|██▎       | 45/200 [31:24<1:47:56, 41.79s/it]

[INFO] EPOCH: 45/200
Train Dice loss: 0.2661, Vaild Dice loss: 0.3418, Train MSE loss: 0.0005, Vaild MSE loss: 0.0002, Train VAE loss: 0.0052, Vaild VAE loss: 0.0069, Train loss: 0.2717, Vaild loss: 0.3489


 23%|██▎       | 46/200 [32:06<1:47:32, 41.90s/it]

[INFO] EPOCH: 46/200
Train Dice loss: 0.2704, Vaild Dice loss: 0.3632, Train MSE loss: 0.0004, Vaild MSE loss: 0.0002, Train VAE loss: 0.0105, Vaild VAE loss: 0.0030, Train loss: 0.2812, Vaild loss: 0.3665


 24%|██▎       | 47/200 [32:48<1:46:41, 41.84s/it]

[INFO] EPOCH: 47/200
Train Dice loss: 0.2878, Vaild Dice loss: 0.3188, Train MSE loss: 0.0003, Vaild MSE loss: 0.0003, Train VAE loss: 0.0134, Vaild VAE loss: 0.0057, Train loss: 0.3015, Vaild loss: 0.3248


 24%|██▍       | 48/200 [33:29<1:45:52, 41.80s/it]

[INFO] EPOCH: 48/200
Train Dice loss: 0.2532, Vaild Dice loss: 0.2900, Train MSE loss: 0.0004, Vaild MSE loss: 0.0004, Train VAE loss: 0.0089, Vaild VAE loss: 0.0028, Train loss: 0.2625, Vaild loss: 0.2932


 24%|██▍       | 49/200 [34:11<1:45:00, 41.73s/it]

[INFO] EPOCH: 49/200
Train Dice loss: 0.2568, Vaild Dice loss: 0.2619, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0044, Vaild VAE loss: 0.0062, Train loss: 0.2616, Vaild loss: 0.2683


 25%|██▌       | 50/200 [34:53<1:44:17, 41.72s/it]

[INFO] EPOCH: 50/200
Train Dice loss: 0.2708, Vaild Dice loss: 0.4164, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0082, Vaild VAE loss: 0.0063, Train loss: 0.2794, Vaild loss: 0.4232


 26%|██▌       | 51/200 [35:34<1:43:38, 41.73s/it]

[INFO] EPOCH: 51/200
Train Dice loss: 0.2515, Vaild Dice loss: 0.2888, Train MSE loss: 0.0004, Vaild MSE loss: 0.0006, Train VAE loss: 0.0023, Vaild VAE loss: 0.0008, Train loss: 0.2542, Vaild loss: 0.2902


 26%|██▌       | 52/200 [36:16<1:43:08, 41.81s/it]

[INFO] EPOCH: 52/200
Train Dice loss: 0.2400, Vaild Dice loss: 0.2616, Train MSE loss: 0.0004, Vaild MSE loss: 0.0004, Train VAE loss: 0.0021, Vaild VAE loss: 0.0011, Train loss: 0.2426, Vaild loss: 0.2631


 26%|██▋       | 53/200 [36:58<1:42:16, 41.74s/it]

[INFO] EPOCH: 53/200
Train Dice loss: 0.2354, Vaild Dice loss: 0.3857, Train MSE loss: 0.0002, Vaild MSE loss: 0.0002, Train VAE loss: 0.0009, Vaild VAE loss: 0.0008, Train loss: 0.2365, Vaild loss: 0.3867


 27%|██▋       | 54/200 [37:40<1:41:26, 41.69s/it]

[INFO] EPOCH: 54/200
Train Dice loss: 0.2454, Vaild Dice loss: 0.2769, Train MSE loss: 0.0003, Vaild MSE loss: 0.0001, Train VAE loss: 0.0011, Vaild VAE loss: 0.0007, Train loss: 0.2468, Vaild loss: 0.2777


 28%|██▊       | 55/200 [38:21<1:40:53, 41.75s/it]

[INFO] EPOCH: 55/200
Train Dice loss: 0.2237, Vaild Dice loss: 0.4382, Train MSE loss: 0.0003, Vaild MSE loss: 0.0004, Train VAE loss: 0.0012, Vaild VAE loss: 0.0009, Train loss: 0.2252, Vaild loss: 0.4395


 28%|██▊       | 56/200 [39:04<1:40:28, 41.87s/it]

[INFO] EPOCH: 56/200
Train Dice loss: 0.2495, Vaild Dice loss: 0.2889, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0012, Vaild VAE loss: 0.0012, Train loss: 0.2511, Vaild loss: 0.2902


 28%|██▊       | 57/200 [39:45<1:39:44, 41.85s/it]

[INFO] EPOCH: 57/200
Train Dice loss: 0.2380, Vaild Dice loss: 0.3379, Train MSE loss: 0.0003, Vaild MSE loss: 0.0005, Train VAE loss: 0.0026, Vaild VAE loss: 0.0064, Train loss: 0.2409, Vaild loss: 0.3448


 29%|██▉       | 58/200 [40:27<1:38:58, 41.82s/it]

[INFO] EPOCH: 58/200
Train Dice loss: 0.2278, Vaild Dice loss: 0.2651, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0037, Vaild VAE loss: 0.0049, Train loss: 0.2318, Vaild loss: 0.2703


 30%|██▉       | 59/200 [41:09<1:38:15, 41.81s/it]

[INFO] EPOCH: 59/200
Train Dice loss: 0.2242, Vaild Dice loss: 0.2752, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0049, Vaild VAE loss: 0.0031, Train loss: 0.2296, Vaild loss: 0.2785


 30%|███       | 60/200 [41:51<1:37:27, 41.77s/it]

[INFO] EPOCH: 60/200
Train Dice loss: 0.2222, Vaild Dice loss: 0.3065, Train MSE loss: 0.0003, Vaild MSE loss: 0.0007, Train VAE loss: 0.0029, Vaild VAE loss: 0.0025, Train loss: 0.2254, Vaild loss: 0.3096


 30%|███       | 61/200 [42:32<1:36:47, 41.78s/it]

[INFO] EPOCH: 61/200
Train Dice loss: 0.2044, Vaild Dice loss: 0.2708, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0021, Vaild VAE loss: 0.0014, Train loss: 0.2069, Vaild loss: 0.2728


 31%|███       | 62/200 [43:14<1:36:05, 41.78s/it]

[INFO] EPOCH: 62/200
Train Dice loss: 0.2148, Vaild Dice loss: 0.2910, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0012, Vaild VAE loss: 0.0012, Train loss: 0.2162, Vaild loss: 0.2924


 32%|███▏      | 63/200 [43:56<1:35:17, 41.74s/it]

[INFO] EPOCH: 63/200
Train Dice loss: 0.1977, Vaild Dice loss: 0.2454, Train MSE loss: 0.0004, Vaild MSE loss: 0.0008, Train VAE loss: 0.0011, Vaild VAE loss: 0.0014, Train loss: 0.1992, Vaild loss: 0.2476


 32%|███▏      | 64/200 [44:37<1:34:30, 41.70s/it]

[INFO] EPOCH: 64/200
Train Dice loss: 0.2222, Vaild Dice loss: 0.3122, Train MSE loss: 0.0005, Vaild MSE loss: 0.0004, Train VAE loss: 0.0023, Vaild VAE loss: 0.0014, Train loss: 0.2250, Vaild loss: 0.3140


 32%|███▎      | 65/200 [45:19<1:33:49, 41.70s/it]

[INFO] EPOCH: 65/200
Train Dice loss: 0.2224, Vaild Dice loss: 0.2660, Train MSE loss: 0.0004, Vaild MSE loss: 0.0013, Train VAE loss: 0.0015, Vaild VAE loss: 0.0009, Train loss: 0.2243, Vaild loss: 0.2681


 33%|███▎      | 66/200 [46:01<1:33:04, 41.68s/it]

[INFO] EPOCH: 66/200
Train Dice loss: 0.1993, Vaild Dice loss: 0.3349, Train MSE loss: 0.0004, Vaild MSE loss: 0.0008, Train VAE loss: 0.0023, Vaild VAE loss: 0.0011, Train loss: 0.2021, Vaild loss: 0.3368


 34%|███▎      | 67/200 [46:42<1:32:21, 41.66s/it]

[INFO] EPOCH: 67/200
Train Dice loss: 0.1989, Vaild Dice loss: 0.3160, Train MSE loss: 0.0005, Vaild MSE loss: 0.0007, Train VAE loss: 0.0061, Vaild VAE loss: 0.0068, Train loss: 0.2055, Vaild loss: 0.3235


 34%|███▍      | 68/200 [47:24<1:31:40, 41.67s/it]

[INFO] EPOCH: 68/200
Train Dice loss: 0.1965, Vaild Dice loss: 0.3004, Train MSE loss: 0.0005, Vaild MSE loss: 0.0003, Train VAE loss: 0.0049, Vaild VAE loss: 0.0040, Train loss: 0.2018, Vaild loss: 0.3048


 34%|███▍      | 69/200 [48:06<1:31:00, 41.68s/it]

[INFO] EPOCH: 69/200
Train Dice loss: 0.1905, Vaild Dice loss: 0.2658, Train MSE loss: 0.0004, Vaild MSE loss: 0.0007, Train VAE loss: 0.0056, Vaild VAE loss: 0.0049, Train loss: 0.1965, Vaild loss: 0.2714


 35%|███▌      | 70/200 [48:47<1:30:18, 41.68s/it]

[INFO] EPOCH: 70/200
Train Dice loss: 0.1907, Vaild Dice loss: 0.4836, Train MSE loss: 0.0003, Vaild MSE loss: 0.0006, Train VAE loss: 0.0044, Vaild VAE loss: 0.0031, Train loss: 0.1954, Vaild loss: 0.4873


 36%|███▌      | 71/200 [49:29<1:29:44, 41.74s/it]

[INFO] EPOCH: 71/200
Train Dice loss: 0.1776, Vaild Dice loss: 0.2316, Train MSE loss: 0.0003, Vaild MSE loss: 0.0010, Train VAE loss: 0.0077, Vaild VAE loss: 0.0197, Train loss: 0.1856, Vaild loss: 0.2523


 36%|███▌      | 72/200 [50:11<1:29:02, 41.74s/it]

[INFO] EPOCH: 72/200
Train Dice loss: 0.1843, Vaild Dice loss: 0.2515, Train MSE loss: 0.0004, Vaild MSE loss: 0.0002, Train VAE loss: 0.0194, Vaild VAE loss: 0.0103, Train loss: 0.2041, Vaild loss: 0.2620


 36%|███▋      | 73/200 [50:53<1:28:25, 41.77s/it]

[INFO] EPOCH: 73/200
Train Dice loss: 0.2011, Vaild Dice loss: 0.2577, Train MSE loss: 0.0004, Vaild MSE loss: 0.0027, Train VAE loss: 0.0081, Vaild VAE loss: 0.0029, Train loss: 0.2097, Vaild loss: 0.2633


 37%|███▋      | 74/200 [51:35<1:27:46, 41.80s/it]

[INFO] EPOCH: 74/200
Train Dice loss: 0.1861, Vaild Dice loss: 0.3601, Train MSE loss: 0.0005, Vaild MSE loss: 0.0001, Train VAE loss: 0.0082, Vaild VAE loss: 0.0024, Train loss: 0.1948, Vaild loss: 0.3626


 38%|███▊      | 75/200 [52:16<1:26:55, 41.73s/it]

[INFO] EPOCH: 75/200
Train Dice loss: 0.1888, Vaild Dice loss: 0.2576, Train MSE loss: 0.0003, Vaild MSE loss: 0.0007, Train VAE loss: 0.0058, Vaild VAE loss: 0.0065, Train loss: 0.1950, Vaild loss: 0.2649


 38%|███▊      | 76/200 [52:58<1:26:12, 41.71s/it]

[INFO] EPOCH: 76/200
Train Dice loss: 0.1848, Vaild Dice loss: 0.2417, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0080, Vaild VAE loss: 0.0044, Train loss: 0.1933, Vaild loss: 0.2466


 38%|███▊      | 77/200 [53:40<1:25:29, 41.70s/it]

[INFO] EPOCH: 77/200
Train Dice loss: 0.1761, Vaild Dice loss: 0.3383, Train MSE loss: 0.0004, Vaild MSE loss: 0.0004, Train VAE loss: 0.0095, Vaild VAE loss: 0.0029, Train loss: 0.1859, Vaild loss: 0.3416


 39%|███▉      | 78/200 [54:22<1:25:07, 41.86s/it]

[INFO] EPOCH: 78/200
Train Dice loss: 0.1691, Vaild Dice loss: 0.2821, Train MSE loss: 0.0002, Vaild MSE loss: 0.0003, Train VAE loss: 0.0039, Vaild VAE loss: 0.0025, Train loss: 0.1732, Vaild loss: 0.2849


 40%|███▉      | 79/200 [55:04<1:24:24, 41.86s/it]

[INFO] EPOCH: 79/200
Train Dice loss: 0.1728, Vaild Dice loss: 0.3033, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0022, Vaild VAE loss: 0.0010, Train loss: 0.1753, Vaild loss: 0.3046


 40%|████      | 80/200 [55:45<1:23:34, 41.78s/it]

[INFO] EPOCH: 80/200
Train Dice loss: 0.1893, Vaild Dice loss: 0.3520, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0026, Vaild VAE loss: 0.0007, Train loss: 0.1924, Vaild loss: 0.3532


 40%|████      | 81/200 [56:27<1:22:44, 41.72s/it]

[INFO] EPOCH: 81/200
Train Dice loss: 0.1803, Vaild Dice loss: 0.2633, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0017, Vaild VAE loss: 0.0015, Train loss: 0.1822, Vaild loss: 0.2650


 41%|████      | 82/200 [57:09<1:21:57, 41.68s/it]

[INFO] EPOCH: 82/200
Train Dice loss: 0.1888, Vaild Dice loss: 0.2648, Train MSE loss: 0.0005, Vaild MSE loss: 0.0005, Train VAE loss: 0.0055, Vaild VAE loss: 0.0045, Train loss: 0.1947, Vaild loss: 0.2698


 42%|████▏     | 83/200 [57:50<1:21:23, 41.74s/it]

[INFO] EPOCH: 83/200
Train Dice loss: 0.1946, Vaild Dice loss: 0.2689, Train MSE loss: 0.0003, Vaild MSE loss: 0.0005, Train VAE loss: 0.0036, Vaild VAE loss: 0.0013, Train loss: 0.1985, Vaild loss: 0.2707


 42%|████▏     | 84/200 [58:32<1:20:49, 41.81s/it]

[INFO] EPOCH: 84/200
Train Dice loss: 0.1688, Vaild Dice loss: 0.3005, Train MSE loss: 0.0005, Vaild MSE loss: 0.0002, Train VAE loss: 0.0016, Vaild VAE loss: 0.0008, Train loss: 0.1709, Vaild loss: 0.3015


 42%|████▎     | 85/200 [59:15<1:20:21, 41.92s/it]

[INFO] EPOCH: 85/200
Train Dice loss: 0.1575, Vaild Dice loss: 0.2780, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0020, Vaild VAE loss: 0.0024, Train loss: 0.1598, Vaild loss: 0.2806


 43%|████▎     | 86/200 [59:56<1:19:29, 41.84s/it]

[INFO] EPOCH: 86/200
Train Dice loss: 0.1742, Vaild Dice loss: 0.3651, Train MSE loss: 0.0003, Vaild MSE loss: 0.0003, Train VAE loss: 0.0023, Vaild VAE loss: 0.0030, Train loss: 0.1768, Vaild loss: 0.3684


 44%|████▎     | 87/200 [1:00:39<1:19:32, 42.23s/it]

[INFO] EPOCH: 87/200
Train Dice loss: 0.1547, Vaild Dice loss: 0.2441, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0027, Vaild VAE loss: 0.0032, Train loss: 0.1577, Vaild loss: 0.2474


 44%|████▍     | 88/200 [1:01:21<1:18:34, 42.09s/it]

[INFO] EPOCH: 88/200
Train Dice loss: 0.1484, Vaild Dice loss: 0.2708, Train MSE loss: 0.0002, Vaild MSE loss: 0.0002, Train VAE loss: 0.0060, Vaild VAE loss: 0.0019, Train loss: 0.1546, Vaild loss: 0.2729


 44%|████▍     | 89/200 [1:02:03<1:17:38, 41.97s/it]

[INFO] EPOCH: 89/200
Train Dice loss: 0.1435, Vaild Dice loss: 0.2723, Train MSE loss: 0.0002, Vaild MSE loss: 0.0003, Train VAE loss: 0.0059, Vaild VAE loss: 0.0488, Train loss: 0.1497, Vaild loss: 0.3214


 45%|████▌     | 90/200 [1:02:45<1:16:58, 41.99s/it]

[INFO] EPOCH: 90/200
Train Dice loss: 0.1493, Vaild Dice loss: 0.3112, Train MSE loss: 0.0003, Vaild MSE loss: 0.0005, Train VAE loss: 0.0191, Vaild VAE loss: 0.0083, Train loss: 0.1688, Vaild loss: 0.3200


 46%|████▌     | 91/200 [1:03:27<1:16:23, 42.05s/it]

[INFO] EPOCH: 91/200
Train Dice loss: 0.1530, Vaild Dice loss: 0.2621, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0059, Vaild VAE loss: 0.0049, Train loss: 0.1593, Vaild loss: 0.2673


 46%|████▌     | 92/200 [1:04:09<1:15:26, 41.91s/it]

[INFO] EPOCH: 92/200
Train Dice loss: 0.1414, Vaild Dice loss: 0.2451, Train MSE loss: 0.0003, Vaild MSE loss: 0.0005, Train VAE loss: 0.0030, Vaild VAE loss: 0.0010, Train loss: 0.1447, Vaild loss: 0.2466


 46%|████▋     | 93/200 [1:04:50<1:14:38, 41.85s/it]

[INFO] EPOCH: 93/200
Train Dice loss: 0.1450, Vaild Dice loss: 0.2720, Train MSE loss: 0.0003, Vaild MSE loss: 0.0004, Train VAE loss: 0.0032, Vaild VAE loss: 0.0019, Train loss: 0.1486, Vaild loss: 0.2743


 47%|████▋     | 94/200 [1:05:32<1:13:47, 41.77s/it]

[INFO] EPOCH: 94/200
Train Dice loss: 0.1518, Vaild Dice loss: 0.2788, Train MSE loss: 0.0004, Vaild MSE loss: 0.0005, Train VAE loss: 0.0015, Vaild VAE loss: 0.0021, Train loss: 0.1537, Vaild loss: 0.2814


 48%|████▊     | 95/200 [1:06:14<1:13:05, 41.76s/it]

[INFO] EPOCH: 95/200
Train Dice loss: 0.1451, Vaild Dice loss: 0.2682, Train MSE loss: 0.0003, Vaild MSE loss: 0.0003, Train VAE loss: 0.0016, Vaild VAE loss: 0.0029, Train loss: 0.1470, Vaild loss: 0.2714


 48%|████▊     | 96/200 [1:06:55<1:12:21, 41.75s/it]

[INFO] EPOCH: 96/200
Train Dice loss: 0.1329, Vaild Dice loss: 0.2442, Train MSE loss: 0.0004, Vaild MSE loss: 0.0006, Train VAE loss: 0.0027, Vaild VAE loss: 0.0017, Train loss: 0.1360, Vaild loss: 0.2465


 48%|████▊     | 97/200 [1:07:37<1:11:31, 41.67s/it]

[INFO] EPOCH: 97/200
Train Dice loss: 0.1408, Vaild Dice loss: 0.2264, Train MSE loss: 0.0003, Vaild MSE loss: 0.0010, Train VAE loss: 0.0035, Vaild VAE loss: 0.0033, Train loss: 0.1445, Vaild loss: 0.2307


 49%|████▉     | 98/200 [1:08:19<1:11:06, 41.83s/it]

[INFO] EPOCH: 98/200
Train Dice loss: 0.1304, Vaild Dice loss: 0.2296, Train MSE loss: 0.0003, Vaild MSE loss: 0.0018, Train VAE loss: 0.0027, Vaild VAE loss: 0.0020, Train loss: 0.1333, Vaild loss: 0.2334


 50%|████▉     | 99/200 [1:09:01<1:10:17, 41.76s/it]

[INFO] EPOCH: 99/200
Train Dice loss: 0.1236, Vaild Dice loss: 0.2289, Train MSE loss: 0.0003, Vaild MSE loss: 0.0008, Train VAE loss: 0.0015, Vaild VAE loss: 0.0011, Train loss: 0.1254, Vaild loss: 0.2308


 50%|█████     | 100/200 [1:09:42<1:09:34, 41.74s/it]

[INFO] EPOCH: 100/200
Train Dice loss: 0.1191, Vaild Dice loss: 0.2335, Train MSE loss: 0.0003, Vaild MSE loss: 0.0006, Train VAE loss: 0.0011, Vaild VAE loss: 0.0010, Train loss: 0.1205, Vaild loss: 0.2351


 50%|█████     | 101/200 [1:10:24<1:08:45, 41.67s/it]

[INFO] EPOCH: 101/200
Train Dice loss: 0.1165, Vaild Dice loss: 0.2443, Train MSE loss: 0.0004, Vaild MSE loss: 0.0001, Train VAE loss: 0.0021, Vaild VAE loss: 0.0009, Train loss: 0.1189, Vaild loss: 0.2453


 51%|█████     | 102/200 [1:11:06<1:08:03, 41.67s/it]

[INFO] EPOCH: 102/200
Train Dice loss: 0.1288, Vaild Dice loss: 0.2228, Train MSE loss: 0.0002, Vaild MSE loss: 0.0001, Train VAE loss: 0.0016, Vaild VAE loss: 0.0015, Train loss: 0.1306, Vaild loss: 0.2244


 52%|█████▏    | 103/200 [1:11:48<1:07:30, 41.75s/it]

[INFO] EPOCH: 103/200
Train Dice loss: 0.1400, Vaild Dice loss: 0.3182, Train MSE loss: 0.0003, Vaild MSE loss: 0.0008, Train VAE loss: 0.0028, Vaild VAE loss: 0.0023, Train loss: 0.1431, Vaild loss: 0.3213


 52%|█████▏    | 104/200 [1:12:29<1:06:41, 41.68s/it]

[INFO] EPOCH: 104/200
Train Dice loss: 0.1558, Vaild Dice loss: 0.3045, Train MSE loss: 0.0003, Vaild MSE loss: 0.0003, Train VAE loss: 0.0018, Vaild VAE loss: 0.0011, Train loss: 0.1580, Vaild loss: 0.3059


 52%|█████▎    | 105/200 [1:13:11<1:05:58, 41.67s/it]

[INFO] EPOCH: 105/200
Train Dice loss: 0.1264, Vaild Dice loss: 0.2219, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0055, Vaild VAE loss: 0.0119, Train loss: 0.1321, Vaild loss: 0.2340


 53%|█████▎    | 106/200 [1:13:52<1:05:11, 41.61s/it]

[INFO] EPOCH: 106/200
Train Dice loss: 0.1351, Vaild Dice loss: 0.2369, Train MSE loss: 0.0004, Vaild MSE loss: 0.0030, Train VAE loss: 0.0034, Vaild VAE loss: 0.0007, Train loss: 0.1389, Vaild loss: 0.2406


 54%|█████▎    | 107/200 [1:14:34<1:04:30, 41.61s/it]

[INFO] EPOCH: 107/200
Train Dice loss: 0.1402, Vaild Dice loss: 0.3001, Train MSE loss: 0.0004, Vaild MSE loss: 0.0008, Train VAE loss: 0.0021, Vaild VAE loss: 0.0006, Train loss: 0.1427, Vaild loss: 0.3015


 54%|█████▍    | 108/200 [1:15:15<1:03:45, 41.58s/it]

[INFO] EPOCH: 108/200
Train Dice loss: 0.1251, Vaild Dice loss: 0.2811, Train MSE loss: 0.0004, Vaild MSE loss: 0.0003, Train VAE loss: 0.0007, Vaild VAE loss: 0.0006, Train loss: 0.1261, Vaild loss: 0.2821


 55%|█████▍    | 109/200 [1:15:57<1:03:00, 41.55s/it]

[INFO] EPOCH: 109/200
Train Dice loss: 0.1190, Vaild Dice loss: 0.2526, Train MSE loss: 0.0003, Vaild MSE loss: 0.0006, Train VAE loss: 0.0006, Vaild VAE loss: 0.0006, Train loss: 0.1199, Vaild loss: 0.2537


 55%|█████▌    | 110/200 [1:16:38<1:02:19, 41.55s/it]

[INFO] EPOCH: 110/200
Train Dice loss: 0.1245, Vaild Dice loss: 0.2322, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0016, Vaild VAE loss: 0.0020, Train loss: 0.1264, Vaild loss: 0.2345


 56%|█████▌    | 111/200 [1:17:20<1:01:39, 41.57s/it]

[INFO] EPOCH: 111/200
Train Dice loss: 0.1170, Vaild Dice loss: 0.2343, Train MSE loss: 0.0003, Vaild MSE loss: 0.0005, Train VAE loss: 0.0010, Vaild VAE loss: 0.0008, Train loss: 0.1183, Vaild loss: 0.2356


 56%|█████▌    | 112/200 [1:18:01<1:00:54, 41.53s/it]

[INFO] EPOCH: 112/200
Train Dice loss: 0.1026, Vaild Dice loss: 0.2255, Train MSE loss: 0.0002, Vaild MSE loss: 0.0005, Train VAE loss: 0.0008, Vaild VAE loss: 0.0031, Train loss: 0.1036, Vaild loss: 0.2291


 56%|█████▋    | 113/200 [1:18:43<1:00:17, 41.58s/it]

[INFO] EPOCH: 113/200
Train Dice loss: 0.0946, Vaild Dice loss: 0.2070, Train MSE loss: 0.0003, Vaild MSE loss: 0.0002, Train VAE loss: 0.0013, Vaild VAE loss: 0.0013, Train loss: 0.0962, Vaild loss: 0.2085


 57%|█████▋    | 114/200 [1:19:25<59:43, 41.67s/it]  

[INFO] EPOCH: 114/200
Train Dice loss: 0.0982, Vaild Dice loss: 0.2135, Train MSE loss: 0.0002, Vaild MSE loss: 0.0006, Train VAE loss: 0.0013, Vaild VAE loss: 0.0008, Train loss: 0.0997, Vaild loss: 0.2149


 57%|█████▊    | 115/200 [1:20:07<59:04, 41.69s/it]

[INFO] EPOCH: 115/200
Train Dice loss: 0.0956, Vaild Dice loss: 0.1936, Train MSE loss: 0.0002, Vaild MSE loss: 0.0002, Train VAE loss: 0.0010, Vaild VAE loss: 0.0027, Train loss: 0.0968, Vaild loss: 0.1966


 58%|█████▊    | 116/200 [1:20:49<58:27, 41.75s/it]

[INFO] EPOCH: 116/200
Train Dice loss: 0.0939, Vaild Dice loss: 0.2161, Train MSE loss: 0.0003, Vaild MSE loss: 0.0001, Train VAE loss: 0.0028, Vaild VAE loss: 0.0100, Train loss: 0.0971, Vaild loss: 0.2262


 58%|█████▊    | 117/200 [1:21:30<57:39, 41.69s/it]

[INFO] EPOCH: 117/200
Train Dice loss: 0.0945, Vaild Dice loss: 0.2173, Train MSE loss: 0.0002, Vaild MSE loss: 0.0001, Train VAE loss: 0.0030, Vaild VAE loss: 0.0048, Train loss: 0.0977, Vaild loss: 0.2222


 59%|█████▉    | 118/200 [1:22:12<56:55, 41.65s/it]

[INFO] EPOCH: 118/200
Train Dice loss: 0.0933, Vaild Dice loss: 0.2442, Train MSE loss: 0.0002, Vaild MSE loss: 0.0005, Train VAE loss: 0.0011, Vaild VAE loss: 0.0009, Train loss: 0.0946, Vaild loss: 0.2456


 60%|█████▉    | 119/200 [1:22:53<56:09, 41.60s/it]

[INFO] EPOCH: 119/200
Train Dice loss: 0.0924, Vaild Dice loss: 0.2400, Train MSE loss: 0.0002, Vaild MSE loss: 0.0002, Train VAE loss: 0.0014, Vaild VAE loss: 0.0013, Train loss: 0.0941, Vaild loss: 0.2414


 60%|██████    | 120/200 [1:23:35<55:35, 41.70s/it]

[INFO] EPOCH: 120/200
Train Dice loss: 0.1284, Vaild Dice loss: 0.3241, Train MSE loss: 0.0003, Vaild MSE loss: 0.0007, Train VAE loss: 0.0015, Vaild VAE loss: 0.0008, Train loss: 0.1302, Vaild loss: 0.3256


 60%|██████    | 120/200 [1:23:46<55:50, 41.88s/it]


KeyboardInterrupt: 

In [None]:
import random

# scratch check: sample one float in [0.0, 1.0) from the stdlib PRNG
random.random()

In [None]:
# Scratch input batch: 12 samples x 3 channels x 32 x 32 of uniform [0, 1)
# noise from the stdlib PRNG. The draws happen in C order (sample, channel,
# row, column), the same order the nested comprehension used.
q = torch.from_numpy(
    np.array([random.random() for _ in range(12 * 3 * 32 * 32)], dtype=np.float32)
    .reshape(12, 3, 32, 32)
)
q.shape

In [None]:
# Scratch experiment 1: probe feature-map shapes of a plain conv / max-pool
# encoder on a real training batch. The 3x3 convs use padding=1 and keep H and W.
# Each MaxPool2d(2) halves H and W, so four pools take a 256x256 input down
# to 16x16, which is why linear1 expects 1024*16*16 features.
# NOTE(review): assumes val1 is (batch, 3, 256, 256) -- confirm against the loader.
max1 = nn.MaxPool2d(2)
conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
max2 = nn.MaxPool2d(2)
conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
max3 = nn.MaxPool2d(2)
conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
max4 = nn.MaxPool2d(2)
conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)

conv5 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)

linear1 = nn.Linear(1024 * 16 * 16, 512)

# Fetch the probe batch here. `val1` used to be defined only in a cell further
# down, so this cell raised NameError on a fresh kernel (Restart & Run All).
# Element 0 of the first batch yielded by trainLoader.
val1 = next(iter(trainLoader))[0]
op = conv5(max4(conv4(max3(conv3(max2(conv2(max1(conv1(val1)))))))))

# Scratch experiment 2: strided convs (no pooling) applied to the random `q`.
# For a 32x32 input the spatial size goes 32 -> 16 -> 8 -> 3 (the last conv has
# padding=0), which is why _linear1 expects 3*3*32 features.
_conv1 = nn.Conv2d(3, 8, 3, stride=2, padding=1)
_conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
_batch2 = nn.BatchNorm2d(16)  # defined but not used by the visible cells
_conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=0)
_linear1 = nn.Linear(3 * 3 * 32, 128)


In [None]:
nn.Sequential(_conv1, _conv2, _conv3)(q).shape  # same layers, composed in order

In [None]:
op.size()  # bottleneck feature-map shape

In [None]:
op.flatten(1).size()  # keep the batch dim, merge channel/height/width

In [None]:
# project the flattened bottleneck through linear1
y = linear1(op.flatten(1))
y.size()

In [None]:
# Scratch decoder probe: project y up to a 31*31*32 feature vector.
# in_features must equal y's last dimension. y comes from linear1, which has
# 512 outputs, so the previous hard-coded 128 raised a shape-mismatch
# RuntimeError. Derive it from y so the probe follows whatever y is.
z = nn.Linear(y.shape[-1], 31 * 31 * 32)(y)
z.shape

In [None]:
# Plot the training / validation loss history recorded in H (presumably one
# entry per epoch), save the figure, then serialize the model.
# savefig and torch.save both fail if the output directory does not exist yet.
os.makedirs(BASE_OUTPUT, exist_ok=True)

plt.style.use("ggplot")
fig, ax = plt.subplots()
ax.plot(H["train_loss"], label="train_loss")
ax.plot(H["valid_loss"], label="valid_loss")
ax.set_title("Training Loss on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Loss")
ax.legend(loc="lower left")

fig.savefig(PLOT_PATH)
# serialize the whole model object (not just its state_dict) to disk;
# the model-loading cell calls torch.load on this file and expects a module back
torch.save(unet, MODEL_PATH)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
import os
def prepare_plot(origImage, origMask, predMask):
    """Show an image, its ground-truth mask and the predicted mask side by side.

    Args:
        origImage: input image, any array accepted by ``Axes.imshow``.
        origMask: ground-truth mask array.
        predMask: predicted mask array.
    """
    # one row of three panels: image, ground truth, prediction
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))
    panels = (
        (origImage, "Image"),
        (origMask, "Original Mask"),
        (predMask, "Predicted Mask"),
    )
    for axis, (data, title) in zip(ax, panels):
        axis.imshow(data)
        axis.set_title(title)

    figure.tight_layout()
    # plt.show() renders the figure immediately (once per call when used in a
    # loop). Figure.show() is meant for GUI backends and may only emit a
    # "non-GUI backend" warning under the notebook's inline backend.
    plt.show()

In [None]:
def make_predictions(model, X, y, threshold=None):
    """Run ``model`` on one image and plot image, ground-truth mask and prediction.

    Args:
        model: segmentation network; ``model(batch)[0]`` is taken as the mask output.
        X: image array, transposed with axes (2, 0, 1) below, i.e. treated as
            (H, W, C) and fed to the model as a (1, C, H, W) tensor.
        y: ground-truth mask array, plotted as-is.
        threshold: cut-off applied to the raw mask output before binarising;
            defaults to the notebook-level ``THRESHOLD``.
    """
    if threshold is None:
        threshold = THRESHOLD

    # set model to evaluation mode and turn off gradient tracking
    model.eval()
    with torch.no_grad():
        # untouched copies of the inputs for plotting
        orig = X.copy()
        gtMask = y.copy()

        # (H, W, C) -> (1, C, H, W) tensor on the current device
        image = np.expand_dims(np.transpose(X, (2, 0, 1)), 0)
        image = torch.from_numpy(image).to(DEVICE)

        # NOTE(review): index 0 presumably selects the segmentation output of a
        # multi-output model (or the single batch element) -- confirm against
        # the model's forward(). No sigmoid is applied here; if the model
        # returns logits rather than probabilities, apply torch.sigmoid first.
        predMask = model(image)[0].squeeze().cpu().numpy()

        # Binarise BEFORE the uint8 cast: casting raw float scores straight to
        # uint8 truncates every value in [0, 1) to 0, producing a blank mask.
        predMask = ((predMask > threshold) * 255).astype(np.uint8)

        # prepare a plot for visualization
        prepare_plot(orig, gtMask, predMask)

In [None]:
import inspect

# load our model from disk and flash it to the current device
print("[INFO] load up model...")
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
load_kwargs = {"map_location": DEVICE}
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle a whole nn.Module (this notebook saves `unet` itself, not a
# state_dict). The checkpoint is our own output, so full unpickling is
# acceptable here; PyTorch versions without the parameter need nothing extra.
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
unet = torch.load(MODEL_PATH, **load_kwargs).to(DEVICE)

# Visualize predictions for the first (up to) 10 TRAINING samples -- these
# show how well the model fits its training data, not how it generalizes.
for i in range(min(10, len(X_train_npy))):
    make_predictions(unet, X_train_npy[i], y_train_npy[i])

In [None]:
# element 0 of the first batch yielded by the training loader
val1 = next(iter(trainLoader))[0]
val1.size()

In [None]:
# Five encoder stages. Each stage is a stride-2 3x3 conv (halving even H and W),
# then batch norm, then LeakyReLU. Channels widen 3 -> 32 -> 64 -> 128 -> 256 -> 512.
hidden_dims = [32, 64, 128, 256, 512]
modules = [
    nn.Sequential(
        nn.Conv2d(c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(),
    )
    for c_in, c_out in zip([3] + hidden_dims[:-1], hidden_dims)
]
# leave behind the same globals the explicit loop did (last stage's width)
in_channels = h_dim = hidden_dims[-1]

encoder = nn.Sequential(*modules)

In [None]:
# Scratch input batch for the strided encoder: 12 x 3 x 64 x 64 uniform [0, 1)
# noise, drawn in C order (sample, channel, row, column) like the nested
# comprehension, then reshaped into a float32 tensor.
q = torch.from_numpy(
    np.array([random.random() for _ in range(12 * 3 * 64 * 64)], dtype=np.float32)
    .reshape(12, 3, 64, 64)
)
q.shape

In [None]:
# run the scratch encoder on the random batch
result = encoder(q)
result.size()

In [None]:
# flatten everything except the batch dimension
x1 = result.flatten(1)
x1.size()

In [None]:
# Scratch arithmetic: 512 * 5 = 2560.
# NOTE(review): this does not match the flattened encoder output that fc_mu
# is sized for (512 channels * 2 * 2 = 2048 for the 64x64 q) -- likely a leftover.
512*5

In [None]:
# x4 = the 2x2 spatial map left after five stride-2 stages on the 64x64 q
fc_mu = nn.Linear(in_features=4 * hidden_dims[-1], out_features=128)

In [None]:
fc_mu(x1).shape  # show the shape, not the full (batch, 128) tensor repr, like the other probe cells