In [1]:
# import the necessary packages
import torch
import os
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import cv2
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import time

In [2]:
# Train and evaluate on the GPU when one is visible to PyTorch, else on the CPU.
_cuda_available = torch.cuda.is_available()
DEVICE = "cuda" if _cuda_available else "cpu"
# Page-locked host memory only pays off when batches are copied to a GPU.
PIN_MEMORY = DEVICE == "cuda"

In [3]:
# --- model description -------------------------------------------------
# Input channels, output classes and U-Net depth.
# NOTE(review): none of these three constants is referenced by the cells
# visible in this notebook (the model below is built with 3 input channels).
NUM_CHANNELS, NUM_CLASSES, NUM_LEVELS = 1, 1, 3

# --- optimisation ------------------------------------------------------
INIT_LR = 1e-3      # initial learning rate handed to the optimizer
NUM_EPOCHS = 200    # full passes over the training set
BATCH_SIZE = 32     # samples per mini-batch

# --- data --------------------------------------------------------------
# Expected input image size in pixels (square images).
INPUT_IMAGE_WIDTH = INPUT_IMAGE_HEIGHT = 256

# Probability cut-off used to filter weak predictions.
THRESHOLD = 0.5

# --- output artefacts --------------------------------------------------
# Base directory for everything the notebook writes to disk.
BASE_OUTPUT = "output"
# Serialized model and training-curve plot, both placed under BASE_OUTPUT.
MODEL_PATH = os.path.join(BASE_OUTPUT, "unet.pth")
PLOT_PATH = os.path.join(BASE_OUTPUT, "plot.png")

In [4]:
# Directory layout: dataset/<split>/<class>_{image,mask}/<k>.jpeg
# `subdir` alternates image/mask folders, one pair per class.
subdir = ["benign_image/", "benign_mask/", "malignant_image/",
          "malignant_mask/", "normal_image/", "normal_mask/"]
dataset = ['dataset/','train/', 'test/', 'validation/']

# One [images, masks, class_labels] triple per split, in the order
# train, test, validation (i.e. the order of dataset[1:]).
train_test_valid = [[[], [], []] for _ in range(3)]

for split_idx, split_dir in enumerate(dataset[1:]):
    images, masks, labels = train_test_valid[split_idx]
    for class_idx in range(3):
        image_dir = dataset[0] + split_dir + subdir[class_idx * 2]
        mask_dir = dataset[0] + split_dir + subdir[class_idx * 2 + 1]
        # Files are addressed as 0.jpeg .. (n-1).jpeg, where n is the number
        # of entries in the image folder; the mask folder is assumed to hold
        # a file with the same name for every image.
        for file_idx in range(len(os.listdir(image_dir))):
            filename = str(file_idx) + ".jpeg"
            images.append(plt.imread(image_dir + filename))
            masks.append(plt.imread(mask_dir + filename))
            labels.append(class_idx)


def _as_unit_float(arrays):
    """Stack a list of images into one float32 array scaled by 1/255."""
    return np.asarray(arrays, dtype=np.float32) / 255


# Images (X) and masks (y) per split; the class labels are collected above
# but are not converted to arrays here.
X_train_npy, y_train_npy = (_as_unit_float(a) for a in train_test_valid[0][:2])
X_test_npy, y_test_npy = (_as_unit_float(a) for a in train_test_valid[1][:2])
X_valid_npy, y_valid_npy = (_as_unit_float(a) for a in train_test_valid[2][:2])

In [5]:
class SegmentationDataset(Dataset):
    """Index-aligned (image, mask) pairs with an optional shared transform.

    Args:
        X: indexable collection of images.
        y: indexable collection of masks; its length defines the dataset size.
        transforms: callable applied to the image and then, separately, to
            the mask, or ``None`` to return both untouched.
    """

    def __init__(self, X, y, transforms):
        self.X, self.y = X, y
        self.transforms = transforms

    def __len__(self):
        # The dataset size follows the masks, not the images.
        return len(self.y)

    def __getitem__(self, idx):
        pair = (self.X[idx], self.y[idx])
        if self.transforms is None:
            return pair

        # The transform is invoked twice (image first, then mask); a random
        # transform would therefore not be synchronised between the two.
        image, mask = pair
        image = self.transforms(image)
        mask = self.transforms(mask)
        return (image, mask)

In [6]:
# import the necessary packages
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch.nn as nn

In [7]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    The convolutions use ``padding=1``, so the spatial size is preserved, and
    carry no bias (each is followed by a BatchNorm2d layer).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second stage.
        mid_channels: channels between the two stages; any falsy value
            (``None``, ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()

        hidden = mid_channels or out_channels

        # Layers are created in execution order so that parameter
        # initialisation and the Sequential indices 0..5 stay unchanged.
        stages = []
        for stage_in, stage_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(stage_out))
            stages.append(nn.ReLU(inplace=True))

        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.double_conv(x)
    
class Encoder(nn.Module):
    """One contracting level: a DoubleConv followed by 2x2 max pooling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.max = nn.MaxPool2d(2)
        self.conv_block = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the DoubleConv output (used as a skip connection by
        the decoders) and ``pooled`` is the same tensor after max pooling.
        """
        features = self.conv_block(x)
        return features, self.max(features)

In [8]:
class DecoderB1(nn.Module):
    """Expanding level of branch 1, fusing two skip tensors.

    Args:
        in_channels: channels of the tensor to upsample.
        mid_channels: channels of the concatenation fed to the DoubleConv
            (upsampled tensor + both skip tensors).
        out_channels: channels after upsampling and after the DoubleConv.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = DoubleConv(mid_channels, out_channels)

    def forward(self, x, skip_features, skip_features_b2):
        """Upsample ``x`` by 2, concatenate both skips on the channel axis, refine.

        The tensors are concatenated in the order
        ``[upsampled, skip_features, skip_features_b2]``; their channel counts
        must add up to ``mid_channels``.
        """
        upsampled = self.conv_transpose(x)
        fused = torch.cat((upsampled, skip_features, skip_features_b2), dim=1)
        return self.conv_block(fused)

class DecoderB2(nn.Module):
    """Expanding level of branch 2 that also exposes its upsampled tensor.

    Args:
        in_channels: channels of the tensor to upsample; also the channel
            count the DoubleConv expects after concatenation with the skip.
        out_channels: channels after upsampling and after the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = DoubleConv(in_channels, out_channels)

    def forward(self, x, skip_features):
        """Return ``(refined, upsampled)``.

        ``upsampled`` is the raw transposed-convolution output; ``refined`` is
        the DoubleConv applied to ``[upsampled, skip_features]`` concatenated
        on the channel axis.
        """
        upsampled = self.conv_transpose(x)
        refined = self.conv_block(torch.cat((upsampled, skip_features), dim=1))
        return refined, upsampled

In [9]:
class Branch1(nn.Module):
    """Output head: DoubleConv -> bias-free 1x1 conv -> sigmoid.

    Args:
        in_channels: channels of the incoming feature map.
        mid_channels: channels produced by the DoubleConv.
        out_channels: channels of the returned map.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()

        self.conv_block = DoubleConv(in_channels, mid_channels)
        self.conv_2d = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid-activated values in [0, 1] (not raw logits)."""
        projected = self.conv_2d(self.conv_block(x))
        return self.sigmoid(projected)

class Branch2(nn.Module):
    """Output head of the second branch: DoubleConv -> 1x1 conv -> sigmoid.

    Args:
        in_channels: channels of the incoming feature map.
        mid_channels: channels produced by the DoubleConv.
        out_channels: channels of the returned map.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()

        self.conv_block = DoubleConv(in_channels, mid_channels)
        self.conv_2d = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid-activated values in [0, 1] (not raw logits)."""
        features = self.conv_block(x)
        return self.sigmoid(self.conv_2d(features))

In [10]:
class UNet(nn.Module):
    """U-Net with one shared encoder and two decoder branches.

    * Branch 1 decodes the bottleneck ("bridge") directly. At every level it
      concatenates the upsampled tensor of branch 2 and the encoder skip.
    * Branch 2 is VAE-like: the flattened bridge is mapped to the mean and
      log-variance of a latent Gaussian, a sample is drawn with the
      reparameterization trick, projected back to the bridge shape and decoded
      with the encoder skips.

    Args:
        input_shape: side length of the (square) input; the bridge is
            flattened assuming a spatial size of ``input_shape / 16``.
        n_input_channels: channels of the input image.
        n_output_channels_b1: channels of the branch-1 output.
        n_output_channels_b2: channels of the branch-2 output.
        n_features: channels of the first encoder level; doubled per level.
        latent_dim: size of the latent vector of branch 2.
    """

    def __init__(self, input_shape=256, n_input_channels=3, n_output_channels_b1=1, n_output_channels_b2=3, n_features=64, latent_dim=128):
        super(UNet, self).__init__()
        self.input_shape = input_shape
        self.n_features = n_features

        f = n_features

        # Shared contracting path: four levels, each halving the resolution.
        self.down1 = Encoder(n_input_channels, f)
        self.down2 = Encoder(f, f*2)
        self.down3 = Encoder(f*2, f*4)
        self.down4 = Encoder(f*4, f*8)

        self.bridge = DoubleConv(f*8, f*16)

        # Branch 1 decoders. The middle argument is the channel count after
        # concatenation: upsampled + branch-2 upsampled + encoder skip.
        self.upb1_1 = DecoderB1(f*16, f*24, f*8)
        self.upb1_2 = DecoderB1(f*8, f*12, f*4)
        self.upb1_3 = DecoderB1(f*4, f*6, f*2)
        self.upb1_4 = DecoderB1(f*2, f*3, f)

        # Number of values in the bridge tensor of a single sample.
        side = int(input_shape/16)
        flatten_size = f*16 * side * side

        # Bridge -> hidden vector -> (mu, log-variance) of the latent Gaussian.
        self.intermediate1 = nn.Sequential(
            nn.Linear(flatten_size, latent_dim*2),
            nn.BatchNorm1d(latent_dim*2),
            nn.ReLU(inplace=True),
        )

        self.fc_mu = nn.Linear(latent_dim*2, latent_dim)
        self.fc_var = nn.Linear(latent_dim*2, latent_dim)

        # Latent sample -> flat vector that is reshaped to the bridge shape.
        self.decoder_input = nn.Sequential(
            nn.Linear(latent_dim, flatten_size),
            nn.ReLU(inplace=True),
        )

        # Branch 2 decoders.
        self.upb2_1 = DecoderB2(f*16, f*8)
        self.upb2_2 = DecoderB2(f*8, f*4)
        self.upb2_3 = DecoderB2(f*4, f*2)
        self.upb2_4 = DecoderB2(f*2, f)

        self.outchannel1 = Branch1(f, f, n_output_channels_b1)
        self.outchannel2 = Branch2(f, f, n_output_channels_b2)

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: draw a sample from N(mu, var) using noise
        from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)

        # Debug aid: report when the standard deviation contains NaNs.
        has_nan = torch.isnan(std).any()
        if has_nan:
            print(f'std : {torch.isnan(std).any()}')

        noise = torch.randn_like(std)

        return noise * std + mu

    def forward(self, x):
        """Run both branches.

        :param x: (Tensor) input batch; its spatial size must match
            ``input_shape`` so the flattened bridge fits the linear layer.
        :return: ``(logits1, logits2, kld_loss)`` - the branch-1 output, the
            branch-2 output (both already passed through a sigmoid by their
            heads) and the KL divergence of the latent Gaussian from N(0, 1),
            summed over latent dimensions and averaged over the batch.
        """
        # Shared encoder; keep every level's pre-pooling features as skips.
        skip1, pooled = self.down1(x)
        skip2, pooled = self.down2(pooled)
        skip3, pooled = self.down3(pooled)
        skip4, pooled = self.down4(pooled)
        skips = (skip4, skip3, skip2, skip1)

        bridge = self.bridge(pooled)

        # Branch 2 (VAE): bridge -> latent distribution -> sample -> bridge shape.
        hidden = self.intermediate1(torch.flatten(bridge, start_dim=1))
        mu = self.fc_mu(hidden)
        log_var = self.fc_var(hidden)
        z = self.reparameterize(mu, log_var)

        side = int(self.input_shape/16)
        b2 = self.decoder_input(z).view(-1, self.n_features*16, side, side)

        # Decode branch 2, remembering each level's raw upsampled tensor
        # because branch 1 consumes it.
        b2_upsampled = []
        for decoder, skip in zip((self.upb2_1, self.upb2_2, self.upb2_3, self.upb2_4), skips):
            b2, upsampled = decoder(b2, skip)
            b2_upsampled.append(upsampled)

        logits2 = self.outchannel2(b2)

        # Branch 1: starts from the deterministic bridge.
        b1 = bridge
        for decoder, upsampled, skip in zip(
                (self.upb1_1, self.upb1_2, self.upb1_3, self.upb1_4), b2_upsampled, skips):
            b1 = decoder(b1, upsampled, skip)

        logits1 = self.outchannel1(b1)

        per_sample_kld = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)
        kld_loss = torch.mean(per_sample_kld, dim = 0)

        return logits1, logits2, kld_loss


In [11]:
#custom dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole flattened batch.

    ``loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1e-6):
        """Return the scalar Dice loss.

        Args:
            inputs: predictions; no activation is applied here, so they are
                expected to already be probabilities.
            targets: ground-truth tensor with the same number of elements.
            smooth: added to numerator and denominator so that two empty
                masks give a loss of 0 instead of a 0/0 division.
        """
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. permuted or sliced inputs, copying only when necessary.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [12]:
# define transformations (applied to both the image and the mask)
transforms_ = transforms.Compose([transforms.ToTensor()])

# create the train and validation datasets
trainDS = SegmentationDataset(X=X_train_npy, y=y_train_npy, transforms=transforms_)
validDS = SegmentationDataset(X=X_valid_npy, y=y_valid_npy, transforms=transforms_)

print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(validDS)} examples in the validation set...")

# create the training and validation data loaders
# NOTE(review): the validation loader also shuffles and drops the last
# partial batch, so a few validation samples are skipped every epoch -
# confirm that this is intended.
trainLoader = DataLoader(trainDS, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY,
                         num_workers=os.cpu_count(), shuffle=True, drop_last=True)
validLoader = DataLoader(validDS, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY,
                         num_workers=os.cpu_count(), shuffle=True, drop_last=True)

# initialize our UNet model
n_features = 64
input_shape = 256

unet = UNet(input_shape=input_shape, n_input_channels=3, n_output_channels_b1=1,
            n_output_channels_b2=3, n_features=n_features, latent_dim=128).to(DEVICE)

# initialize loss functions and optimizer
opt = Adam(unet.parameters(), lr=INIT_LR, weight_decay=1e-5)

dice = DiceLoss()     # segmentation loss for branch 1
mse = nn.MSELoss()    # reconstruction loss for branch 2 (target is the input)

# number of batches per epoch; with drop_last=True this equals
# len(dataset) // BATCH_SIZE and stays correct if the loaders change
trainSteps = len(trainLoader)
validSteps = len(validLoader)

# initialize a dictionary to store training history
# ("*_vae_loss" holds the KL-divergence term returned by the model)
H = {
    "train_loss": [], "valid_loss": [],
    "train_dice_loss": [], "valid_dice_loss": [],
    "train_mse_loss": [], "valid_mse_loss": [],
    "train_vae_loss": [], "valid_vae_loss": []
}

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()

for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    unet.train()

    # running sums of the per-batch losses for this epoch
    totalTrainLoss = 0
    totalValidLoss = 0

    totalTrainKLLoss = 0
    totalValidKLLoss = 0

    totalTrainDiceLoss = 0
    totalValidDiceLoss = 0

    totalTrainMSELoss = 0
    totalValidMSELoss = 0

    # loop over the training set
    for (x, y) in trainLoader:
        # send the input to the device
        (x, y) = (x.to(DEVICE), y.to(DEVICE))

        # forward pass: branch 1 predicts the mask, branch 2 reconstructs
        # the input image, and the model returns the KL term itself
        pred_b1, pred_b2, kl_loss = unet(x)

        loss1 = dice(pred_b1, y)
        loss2 = mse(pred_b2.view(-1), x.view(-1))

        loss = loss1 + loss2 + kl_loss

        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        # accumulate detached values only: summing the graph-attached
        # losses would keep every batch's autograd history alive until
        # the end of the epoch
        totalTrainDiceLoss += loss1.detach()
        totalTrainMSELoss += loss2.detach()
        totalTrainKLLoss += kl_loss.detach()

        totalTrainLoss += loss.detach()

    # set the model in evaluation mode and switch off autograd
    unet.eval()
    with torch.no_grad():

        # loop over the validation set
        for (x, y) in validLoader:

            # send the input to the device
            (x, y) = (x.to(DEVICE), y.to(DEVICE))

            # make the predictions and calculate the validation loss
            pred_b1, pred_b2, valid_kl_loss = unet(x)

            validloss1 = dice(pred_b1, y)
            validloss2 = mse(pred_b2.view(-1), x.view(-1))
            validloss = validloss1 + validloss2 + valid_kl_loss

            # add the loss to the total validation loss so far
            totalValidDiceLoss += validloss1
            totalValidMSELoss += validloss2
            totalValidKLLoss += valid_kl_loss

            totalValidLoss += validloss

    # average every loss component over the number of batches
    avgTrainDiceLoss = totalTrainDiceLoss / trainSteps
    avgValidDiceLoss = totalValidDiceLoss / validSteps

    avgTrainMSELoss = totalTrainMSELoss / trainSteps
    avgValidMSELoss = totalValidMSELoss / validSteps

    avgTrainKLLoss = totalTrainKLLoss / trainSteps
    avgValidKLLoss = totalValidKLLoss / validSteps

    avgTrainLoss = totalTrainLoss / trainSteps
    avgValidLoss = totalValidLoss / validSteps

    # update our training history; the explicit mapping keeps every key
    # paired with its own metric (the "*_vae_loss" entries record the KL
    # averages, not the MSE averages)
    epoch_metrics = {
        "train_loss": avgTrainLoss, "valid_loss": avgValidLoss,
        "train_dice_loss": avgTrainDiceLoss, "valid_dice_loss": avgValidDiceLoss,
        "train_mse_loss": avgTrainMSELoss, "valid_mse_loss": avgValidMSELoss,
        "train_vae_loss": avgTrainKLLoss, "valid_vae_loss": avgValidKLLoss,
    }
    for key, value in epoch_metrics.items():
        H[key].append(value.cpu().numpy())

    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train Dice loss: {:.4f}, Vaild Dice loss: {:.4f}, Train MSE loss: {:.4f}, Vaild MSE loss: {:.4f},"
       " Train VAE loss: {:.4f}, Vaild VAE loss: {:.4f}, Train loss: {:.4f}, Vaild loss: {:.4f}"
       .format(avgTrainDiceLoss, avgValidDiceLoss, avgTrainMSELoss, avgValidMSELoss,
               avgTrainKLLoss, avgValidKLLoss, avgTrainLoss, avgValidLoss))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

[INFO] found 897 examples in the training set...
[INFO] found 303 examples in the validation set...
[INFO] training the network...


  0%|          | 1/200 [00:45<2:31:57, 45.82s/it]

[INFO] EPOCH: 1/200
Train Dice loss: 0.8253, Vaild Dice loss: 0.8555, Train MSE loss: 0.0127, Vaild MSE loss: 0.0137, Train VAE loss: 2.5584, Vaild VAE loss: 0.0885, Train loss: 3.3965, Vaild loss: 0.9577


  1%|          | 2/200 [01:31<2:31:27, 45.90s/it]

[INFO] EPOCH: 2/200
Train Dice loss: 0.7711, Vaild Dice loss: 0.8220, Train MSE loss: 0.0013, Vaild MSE loss: 0.0024, Train VAE loss: 0.2419, Vaild VAE loss: 0.0297, Train loss: 1.0143, Vaild loss: 0.8541


  2%|▏         | 3/200 [02:18<2:31:19, 46.09s/it]

[INFO] EPOCH: 3/200
Train Dice loss: 0.7366, Vaild Dice loss: 0.7325, Train MSE loss: 0.0009, Vaild MSE loss: 0.0009, Train VAE loss: 0.1092, Vaild VAE loss: 0.0479, Train loss: 0.8468, Vaild loss: 0.7814


  2%|▏         | 4/200 [03:04<2:30:52, 46.19s/it]

[INFO] EPOCH: 4/200
Train Dice loss: 0.7086, Vaild Dice loss: 0.6940, Train MSE loss: 0.0009, Vaild MSE loss: 0.0005, Train VAE loss: 0.0711, Vaild VAE loss: 0.0372, Train loss: 0.7807, Vaild loss: 0.7317


  2%|▎         | 5/200 [03:50<2:30:08, 46.20s/it]

[INFO] EPOCH: 5/200
Train Dice loss: 0.6828, Vaild Dice loss: 0.7201, Train MSE loss: 0.0007, Vaild MSE loss: 0.0004, Train VAE loss: 0.0373, Vaild VAE loss: 0.0236, Train loss: 0.7208, Vaild loss: 0.7442


  3%|▎         | 6/200 [04:37<2:29:38, 46.28s/it]

[INFO] EPOCH: 6/200
Train Dice loss: 0.6402, Vaild Dice loss: 0.6399, Train MSE loss: 0.0005, Vaild MSE loss: 0.0004, Train VAE loss: 0.0259, Vaild VAE loss: 0.0224, Train loss: 0.6666, Vaild loss: 0.6627


  4%|▎         | 7/200 [05:23<2:28:46, 46.25s/it]

[INFO] EPOCH: 7/200
Train Dice loss: 0.6067, Vaild Dice loss: 0.6254, Train MSE loss: 0.0005, Vaild MSE loss: 0.0006, Train VAE loss: 0.0202, Vaild VAE loss: 0.0147, Train loss: 0.6273, Vaild loss: 0.6408


  4%|▎         | 7/200 [06:00<2:45:41, 51.51s/it]


KeyboardInterrupt: 

In [25]:
# Scratch shape check: run one training batch through a throw-away encoder
# level, upsample its features, and look at the shape of three such tensors
# concatenated on the channel axis.
probe_encoder = Encoder(3, 16)
first_batch_images = next(iter(trainLoader))[0]
t1 = probe_encoder(first_batch_images)[0]   # pre-pooling features

probe_upsample = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)
t2 = probe_upsample(t1)

torch.cat((t2, t2, t2), dim=1).shape

torch.Size([32, 24, 512, 512])

In [19]:
# torch.Size([32, 16, 256, 256])

In [None]:
# make sure the target directories exist; savefig and torch.save do not
# create missing parent directories themselves
for _artifact_path in (PLOT_PATH, MODEL_PATH):
    _artifact_dir = os.path.dirname(_artifact_path)
    if _artifact_dir:
        os.makedirs(_artifact_dir, exist_ok=True)

# plot the total training and validation loss per epoch
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["valid_loss"], label="valid_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")

plt.savefig(PLOT_PATH)
# serialize the model to disk
# NOTE: this pickles the whole module object (not just a state_dict), so
# loading it later requires the same class definitions to be importable;
# the loading cell of this notebook relies on this format.
torch.save(unet, MODEL_PATH)

In [None]:
def prepare_plot(origImage, origMask, predMask):
    """Show an image, its ground-truth mask and the predicted mask in one row.

    Args:
        origImage: image drawn in the first panel.
        origMask: ground-truth mask drawn in the second panel.
        predMask: predicted mask drawn in the third panel.
    """
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))

    # (content, title) for each of the three panels, left to right
    panels = (
        (origImage, "Image"),
        (origMask, "Original Mask"),
        (predMask, "Predicted Mask"),
    )
    for axis, (content, title) in zip(ax, panels):
        axis.imshow(content)
        axis.set_title(title)

    # tidy up the spacing and display the figure
    figure.tight_layout()
    figure.show()

In [None]:
def make_predictions(model, X, y):
    """Predict the mask for one image and plot it next to the ground truth.

    Args:
        model: network whose first output is the segmentation map (values
            in [0, 1] for the UNet defined in this notebook).
        X: a single image as a NumPy array laid out height x width x
            channels (the axes are reordered to channels-first below).
        y: the ground-truth mask belonging to ``X``.

    Returns:
        None; the result is displayed via ``prepare_plot``.
    """
    # set model to evaluation mode
    model.eval()
    # turn off gradient tracking
    with torch.no_grad():
        # keep untouched copies of the image and its mask for plotting
        orig = X.copy()
        gtMask = y.copy()

        # make the channel axis the leading one, add a batch dimension,
        # create a PyTorch tensor, and move it to the current device
        image = np.transpose(X, (2, 0, 1))
        image = np.expand_dims(image, 0)
        image = torch.from_numpy(image).to(DEVICE)

        # run the model, keep the first output (the segmentation branch)
        # and convert it to a NumPy array
        predMask = model(image)[0].squeeze()
        predMask = predMask.cpu().numpy()

        # binarize before the integer cast: casting probabilities in
        # [0, 1] straight to uint8 would truncate almost all of them to 0
        predMask = (predMask > THRESHOLD) * 255
        predMask = predMask.astype(np.uint8)

        # prepare a plot for visualization
        prepare_plot(orig, gtMask, predMask)

In [None]:
# load our model from disk and move it to the current device
print("[INFO] load up model...")
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): the file is a fully pickled module, so only load checkpoints
# you trust; recent PyTorch releases default torch.load to weights_only=True
# and will then require weights_only=False for this file - confirm against
# the installed version.
unet = torch.load(MODEL_PATH, map_location=DEVICE).to(DEVICE)

# visualize predictions for (up to) the first 10 training examples
# NOTE(review): these are training samples, not held-out test images.
for i in range(min(10, len(X_train_npy))):
    # make predictions and visualize the results
    make_predictions(unet, X_train_npy[i], y_train_npy[i])