You should open this .ipynb on Google Colab and choose GPU. Then, you can run this .ipynb step by step.


**Requirements**

Pytorch

tensorboardX


In [1]:
!pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/af/0c/4f41bcd45db376e6fe5c619c01100e9b7531c55791b7244815bac6eac32c/tensorboardX-2.1-py2.py3-none-any.whl (308kB)
[K     |█                               | 10kB 25.0MB/s eta 0:00:01[K     |██▏                             | 20kB 29.6MB/s eta 0:00:01[K     |███▏                            | 30kB 17.0MB/s eta 0:00:01[K     |████▎                           | 40kB 11.5MB/s eta 0:00:01[K     |█████▎                          | 51kB 7.5MB/s eta 0:00:01[K     |██████▍                         | 61kB 7.9MB/s eta 0:00:01[K     |███████▍                        | 71kB 7.8MB/s eta 0:00:01[K     |████████▌                       | 81kB 8.6MB/s eta 0:00:01[K     |█████████▌                      | 92kB 9.2MB/s eta 0:00:01[K     |██████████▋                     | 102kB 9.4MB/s eta 0:00:01[K     |███████████▊                    | 112kB 9.4MB/s eta 0:00:01[K     |████████████▊                   | 122kB

In [2]:
#EDA
import sys
import os
import glob
import random
import time

import numpy as np
import pandas as pd

import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import cm

import argparse
import math

from torchvision import transforms
from torchvision.utils import save_image

from torch.utils.data import Dataset,DataLoader
from torchvision import datasets

import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch
import glob
from PIL import Image
from tensorboardX import SummaryWriter
from sklearn.metrics import roc_auc_score,roc_curve,auc,precision_recall_curve,average_precision_score
from sklearn import preprocessing 

**Dataset:** The dataset is saved in Google Drive: 

https://drive.google.com/drive/folders/1PKPgkOkTBqQERUCvd0pWXtQx_64q0I34?usp=sharing

You should just drag 'kaggle_3m' and 'all_data.csv' into your 'My Drive'.

In [4]:
from google.colab import drive
drive.mount('/content/gdrive')
#the path has kaggle_3m and all_data.csv
default_path = "/content/gdrive/My Drive"

#create a file for saving model
model_path = "/content/gdrive/My Drive/AnoAAE"
os.makedirs(model_path,exist_ok=True)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [5]:
#################################################################################
# reference: https://www.kaggle.com/bonhart/brain-mri-data-visualization-unet-fpn
#################################################################################
# Raw data

#if you have all_data.csv, please set read_df_csv=True.
read_df_csv=True

# Adding A/B column for diagnosis
def positiv_negativ_diagnosis(mask_path):
    value = np.max(cv2.imread(mask_path))
    if value > 0 : return 1
    else: return 0

if read_df_csv:
    df = pd.read_csv(default_path+"/all_data.csv")
else:
    data_map = []
    for sub_dir_path in glob.glob(default_path+"/kaggle_3m/*"):
        if os.path.isdir(sub_dir_path):
            dirname = sub_dir_path.split("/")[-1]
            for filename in os.listdir(sub_dir_path):
                image_path = sub_dir_path + "/" + filename
                data_map.extend([dirname, image_path])
        else:
            print("This is not a dir:", sub_dir_path)
    print("Generating data list. Please wait for 30 minutes...")
    df = pd.DataFrame({"dirname" : data_map[::2],
                      "path" : data_map[1::2]})
    df.head()
    # Masks/Not masks
    df_imgs = df[~df['path'].str.contains("mask")]
    df_masks = df[df['path'].str.contains("mask")]
    base_len = 75
    # Data sorting
    imgs = sorted(df_imgs["path"].values, key=lambda x : int(x[base_len:-4]))
    masks = sorted(df_masks["path"].values, key=lambda x : int(x[base_len:-9]))
    # Final dataframe
    df = pd.DataFrame({"patient": df_imgs.dirname.values,
                       "image_path": imgs,
                       "mask_path": masks})
    df["diagnosis"] = df["mask_path"].apply(lambda m: positiv_negativ_diagnosis(m))
    if not os.path.exists(default_path+"/all_data.csv"):
        df.to_csv(default_path+"/all_data.csv")


df_p = df[df['diagnosis']==0]
df_n = df[df['diagnosis']==1]

In [6]:
#CNN model
class Generator(nn.Module):
    def __init__(self, opt):
        super().__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim,
                                128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):  
    def __init__(self,opt):
        super().__init__()
        self.lin1 = nn.Linear(opt.latent_dim, 40)
        self.drop1 = nn.Dropout2d(0.25)
        self.act1 = nn.LeakyReLU(0.2, inplace=True)
        self.lin2 = nn.Linear(40, 40*2)
        self.drop2 = nn.Dropout2d(0.25)
        self.act2 = nn.LeakyReLU(0.2, inplace=True)
        self.lin3 = nn.Linear(40*2, 1)
        self.act3 = nn.Identity()
        #self.act4 = nn.Sigmoid()
    def forward(self, x):
        x = self.drop1(self.lin1(x))
        x = self.act1(x)
        x = self.drop2(self.lin2(x))
        x = self.act2(x)
        x = self.lin3(x)
        x = self.act3(x)
        output = x
        #output = self.act4(x)
        return output,x


class Encoder(nn.Module):
    def __init__(self, opt):
        super().__init__()

        def encoder_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *encoder_block(opt.channels, 16, bn=False),
            *encoder_block(16, 32),
            *encoder_block(32, 64),
            *encoder_block(64, 128),
        )



        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2,
                                                 opt.latent_dim),
                                       nn.Identity())
        
    def forward(self, img):
        features = self.model(img)
        features = features.view(features.shape[0], -1)
        validity = self.adv_layer(features)
        return validity

In [9]:
#load Dataset
class BrainMriDataset(Dataset):
    def __init__(self, df, transforms):
        
        self.df = df
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        image = cv2.imread(self.df.iloc[idx, -3],cv2.IMREAD_GRAYSCALE)
        image = Image.fromarray(image)
        label = self.df.iloc[idx, -1]
        if self.transforms:
            image = self.transforms(image)
        mask = cv2.imread(self.df.iloc[idx, -2],cv2.IMREAD_GRAYSCALE)
        mask = Image.fromarray(mask)
        if self.transforms:
            mask = self.transforms(mask)
 
        return image,label,mask

# Train AAE

In [13]:
#gradient penalty computation of WGAN-GP
lambda_gp = 10
def compute_gradient_penalty(D, real_samples, fake_samples, device):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(*real_samples.shape[:2], 1, 1, device=device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    d_i_output,d_interpolates = D(interpolates)
    fake = torch.ones(*d_interpolates.shape, device=device)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(outputs=d_interpolates, inputs=interpolates,
                              grad_outputs=fake, create_graph=True,
                              retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.shape[0], -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

#weight initialization
def truncated_normal_(tensor,mean=0,std=0.09):
    with torch.no_grad():
        size = tensor.shape
        tmp = tensor.new_empty(size+(4,)).normal_()
        valid = (tmp < 2) & (tmp > -2)
        ind = valid.max(-1, keepdim=True)[1]
        tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
        tensor.data.mul_(std).add_(mean)
    return tensor

def weight_init(net):
    for op in net.modules():
        if isinstance(op,nn.Conv2d):
            op.weight.data=truncated_normal_(op.weight.data,std=0.02)
            nn.init.constant_(op.bias.data, val=0)
        else:
            pass

#train aae
def train(opt,load_model=False):
    if type(opt.seed) is int:
        torch.manual_seed(opt.seed)
    print(torch.cuda.get_device_name(0))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    

    transform = transforms.Compose([transforms.Resize([opt.img_size]*2),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor()])
    train_df = df_p.reset_index(drop=True)
    train_dataset = BrainMriDataset(df=train_df, transforms=transform)
    train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size,shuffle=True)
    generator = Generator(opt)
    discriminator = Discriminator(opt)
    encoder = Encoder(opt)

    weight_init(generator)
    weight_init(discriminator)
    weight_init(encoder)


    generator.to(device)
    discriminator.to(device)
    encoder.to(device)

    # optimizers
    #optim_Gen = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optim_Dec = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optim_Enc = torch.optim.Adam(encoder.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optim_Dis = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    writer = SummaryWriter(model_path+'/tensorboard')

    print("Training..")

    for epoch in range(opt.n_epochs):
        begin_time = time.time()
        for i,(imgs,_lbl,_msk) in enumerate(train_dataloader):
            real_imgs = imgs.to(device)
            real_z = torch.randn(imgs.shape[0], opt.latent_dim, device=device)
            optim_Enc.zero_grad()
            optim_Dec.zero_grad()
            optim_Dis.zero_grad()

            # reconstruct
            encoder_output=encoder(real_imgs)
            output = generator(encoder_output)
            autoencoder_loss = torch.mean(torch.sqrt(torch.sum(torch.square(output - real_imgs),[1, 2, 3])))

            #l2 loss
            l2_loss = nn.MSELoss()
            reconstruct_loss_l2 = l2_loss(real_imgs,output)
            reconstruct_loss = autoencoder_loss+0.5*reconstruct_loss_l2
            reconstruct_loss.backward()

            optim_Enc.step()
            optim_Dec.step()

            # discriminator
            encoder.eval()
            
            with torch.no_grad():
                z_fake = encoder(real_imgs)
            real_d,real_logits = discriminator(real_z)
            fake_d,fake_logits = discriminator(z_fake)
            gradient_penalty = compute_gradient_penalty(discriminator,
                                                        real_z.data,
                                                        z_fake.data,
                                                        device)
            d_loss = -torch.mean(real_logits) + torch.mean(fake_logits) + lambda_gp * gradient_penalty
            d_loss.backward()
            optim_Dis.step()

            # encoder
            encoder.train()
            z_fake = encoder(real_imgs)
            e_fake,e_fake_logits = discriminator(z_fake)
            e_loss = -torch.mean(e_fake_logits)
            #g_loss = -torch.mean(torch.log(e_fake_logits + 1e-8))
            e_loss.backward()
            optim_Enc.step()

            #print(f"[Epoch {epoch}/{opt.n_epochs}] "
            #      f"[Batch {i}/{len(train_dataloader)}] "
            #      f"[R loss: {reconstruct_loss.item():3f}] "
            #      f"[D loss: {d_loss.item():3f}] "
            #      f"[G loss: {g_loss.item():3f}]")

            #visualization
            writer.add_scalar('Reconstruction Loss',reconstruct_loss,epoch*opt.batch_size+i)
            writer.add_scalar('Discriminator Loss',d_loss,epoch*opt.batch_size+i)
            writer.add_scalar('Encoder Loss',e_loss,epoch*opt.batch_size+i)
            writer.add_histogram('Encoder Distribution',encoder_output,epoch*opt.batch_size+i)
            writer.add_histogram('Real Distribution',real_z,epoch*opt.batch_size+i)
            #print(torch.max(real_imgs[0]))
            #print(torch.min(real_imgs[0]))
            #print(torch.max(output[0]))
            #print(torch.min(output[0]))
            input_image = real_imgs[0].reshape(opt.channels,opt.img_size,opt.img_size)
            generate_image = output[0].reshape(opt.channels,opt.img_size,opt.img_size)
            writer.add_image('input image',input_image,epoch*opt.batch_size+i)
            writer.add_image('generate image',generate_image,epoch*opt.batch_size+i)
        print(f"[Epoch {epoch}/{opt.n_epochs}] "
              f"[R loss: {reconstruct_loss.item():3f}] "
              f"[D loss: {d_loss.item():3f}] "
              f"[G loss: {e_loss.item():3f}]")
    torch.save(generator.state_dict(), model_path + "\aae_generator.pth")
    torch.save(discriminator.state_dict(), model_path+"\aae_discriminator.pth")
    torch.save(encoder.state_dict(), model_path+"\aae_encoder.pth")

In [None]:
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=300,
                        help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002,
                        help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100,
                        help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=64,
                        help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=3,
                        help="number of image channels")
    parser.add_argument("--seed", type=int, default=None,
                        help="value of a random seed")
    opt = parser.parse_args(['--seed',str(1),'--n_epochs',str(3000),'--lr',str(1e-4),
                             '--img_size',str(64),'--latent_dim',str(128),
                             '--channels',str(1)])

    train(opt)

Tesla T4
Training..
[Epoch 0/3000] [R loss: 3.792454] [D loss: 67.356903] [G loss: 0.075394]
[Epoch 1/3000] [R loss: 3.996897] [D loss: 28.676685] [G loss: 0.081016]
[Epoch 2/3000] [R loss: 3.712724] [D loss: 17.830408] [G loss: 0.082219]
[Epoch 3/3000] [R loss: 3.556369] [D loss: 12.648266] [G loss: 0.092498]
[Epoch 4/3000] [R loss: 3.429359] [D loss: 7.290919] [G loss: 0.096107]
[Epoch 5/3000] [R loss: 3.252504] [D loss: 4.877563] [G loss: 0.100359]
[Epoch 6/3000] [R loss: 2.951897] [D loss: 3.842360] [G loss: 0.098423]


# Test & Evaluation

In [None]:
def BinaryConfusionMatrix(prediction, groundtruth):
    """Computes scores:
    TP = True Positives  
    FP = False Positives   
    FN = False Negatives   
    TN = True Negatives    
    return: TP, FP, FN, TN"""
 
    TP = np.float(np.sum((prediction == 1) & (groundtruth == 1)))
    FP = np.float(np.sum((prediction == 1) & (groundtruth == 0)))
    FN = np.float(np.sum((prediction == 0) & (groundtruth == 1)))
    TN = np.float(np.sum((prediction == 0) & (groundtruth == 0)))
 
    return TN, FP, FN,TP
    
def get_dice(prediction, groundtruth):
    TN, FP, FN,TP = BinaryConfusionMatrix(prediction, groundtruth)    
    dice = 2 * float(TP)/(float(FP + 2 * TP + FN) + 1e-6)
    return dice

def evaluation(ano_score,ano_mask,mask,plot_curve = False):
    #############
    # Dice sore #
    #############
    dice_score = get_dice(ano_mask.reshape(-1).astype(int), 
                          mask.reshape(-1).astype(int))
    #print("Dice Score:",dice_score)
    #############
    #   AUROC   #
    #############
    auroc = roc_auc_score(mask.reshape(-1).astype(int), ano_score.reshape(-1))
    #print("AUROC:",auroc)
    if plot_curve==True:
        fpr, tpr, _ = roc_curve(mask.reshape(-1).astype(int), ano_score.reshape(-1))
        roc_auc = auc(fpr, tpr)
        plt.figure()
        lw = 2
        plt.plot(fpr, tpr, color='darkorange',
                 lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic')
        plt.legend(loc="lower right")
        plt.show()
    #############
    #   AUPRC   #
    #############
    auprc = average_precision_score(mask.reshape(-1).astype(int),ano_score.reshape(-1))
    #print("AUPRC:",auprc)
    if plot_curve==True:
        precision,recall,_ = precision_recall_curve(mask.reshape(-1).astype(int),ano_score.reshape(-1))
        plt.figure()
        lw = 2
        plt.plot(recall, precision, color='darkorange',
                 lw=lw, label='PR curve (area = %0.2f)' % auprc)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall curve')
        plt.legend(loc="lower right")
        plt.show()
    return dice_score,auroc,auprc


In [None]:
def test(opt,load_model=True):
    if type(opt.seed) is int:
        torch.manual_seed(opt.seed)
    print(torch.cuda.get_device_name(0))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    

    transform = transforms.Compose([transforms.Resize([opt.img_size]*2),
                                    #transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor()
                                    ])
    #df_test = pd.concat([df_p[-1373:],df_n],axis=0,ignore_index=True)
    #print(df_test)

    test_df = df_n.reset_index(drop=True)
    test_dataset = BrainMriDataset(df=test_df, transforms=transform)
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size,shuffle=True)
    generator = Generator(opt).to(device)
    discriminator = Discriminator(opt).to(device)
    encoder = Encoder(opt).to(device)
    generator.load_state_dict(torch.load(model_path+ "\aae_generator.pth"))
    discriminator.load_state_dict(torch.load(model_path + "\aae_discriminator.pth"))
    encoder.load_state_dict(torch.load(model_path + "\aae_encoder.pth"))
    %matplotlib inline
    dice_score = 0
    auroc=0
    auprc=0
    count = 0
    for ind,(img, label,mask) in enumerate(test_dataloader):
        #Add FLAIR feature
        #img_flair = 1-np.multiply(1-img,1-mask)
        img_flair = np.multiply(img,1-mask)+np.multiply(img,mask)*1.6
        #Generate normal image and anomaly region
        img_z = encoder(img_flair.to(device))
        img_g = generator(img_z)
        img_ano = img_flair.to(device)-img_g
        #generate anomaly mask
        ano_mask = []
        for i in range(img.shape[0]):
            ano_img = img_ano[i].cpu().detach().numpy().squeeze()
            ret,thresh_img = cv2.threshold(ano_img,np.max(ano_img)/2.2,1,cv2.THRESH_BINARY)
            ano_mask.append(thresh_img)
        #normalize anomaly region image
        ano_score = img_ano.cpu().detach().numpy().reshape(32,-1)
        minmax_scaler = preprocessing.MinMaxScaler() 
        ano_score = minmax_scaler.fit_transform(ano_score)
        #evaluation
        dice,roc,prc = evaluation(ano_score,np.array(ano_mask),mask.numpy())
        dice_score+=dice
        auroc+=roc
        auprc+=prc
        print('Evaluating:',count)
        count+=1
    dice_score = dice_score/count
    auroc = auroc/count
    auprc = auprc/count
    print('Dice score:',dice_score)
    print('AUROC:',auroc)
    print('auprc:',auprc)


    #Visualization the first batch
    for ind,(img, label,mask) in enumerate(test_dataloader):
        plt.figure(figsize=(10, 10))
        for i in range(32):
            ax = plt.subplot(4, 8, i + 1)
            plt.imshow(img[i].squeeze(),cmap ='gray')
            plt.axis("off")
        plt.show()
        print("=================================================================")
        #img_flair = 1-np.multiply(1-img,1-mask)
        img_flair = np.multiply(img,1-mask)+np.multiply(img,mask)*1.6
        plt.figure(figsize=(10, 10))
        for i in range(32):
            ax = plt.subplot(4, 8, i + 1)
            plt.imshow(img_flair[i].squeeze(),cmap ='gray')
            plt.axis("off")
        #plt.savefig(path+"AnoAAE_results/"+"img_flair"+".jpg")
        plt.show()
        print("=================================================================")
        img_z = encoder(img_flair.to(device))
        img_g = generator(img_z)
        img_ano = img_flair.to(device)-img_g
        plt.figure(figsize=(10, 10))
        for i in range(32):
            ax = plt.subplot(4, 8, i + 1)
            plt.imshow(img_g[i].cpu().detach().numpy().squeeze(),cmap ='gray')
            plt.axis("off")
        #plt.savefig(path+"AnoAAE_results/"+"img_g"+".jpg")
        plt.show()
        print("=================================================================")
        plt.figure(figsize=(10, 10))
        for i in range(32):
            ax = plt.subplot(4, 8, i + 1)
            plt.imshow(mask[i].cpu().detach().numpy()[0].squeeze(),cmap ='gray')
            plt.axis("off")
        #plt.savefig(path+"AnoAAE_results/"+"mask"+".jpg")
        plt.show()
        print("=================================================================")
        plt.figure(figsize=(10, 10))
        ano_score = []
        ano_mask = []
        for i in range(32):
            ax = plt.subplot(4, 8, i + 1)
            ano_img = img_ano[i].cpu().detach().numpy().squeeze()
            ret,thresh_img = cv2.threshold(ano_img,np.max(ano_img)/2.2,1,cv2.THRESH_BINARY)
            ano_score.append(ano_img)
            ano_mask.append(thresh_img)
            plt.imshow(thresh_img,cmap ='gray')
            plt.axis("off")
        #plt.savefig(path+"AnoAAE_results/"+"ano_mask"+".jpg")
        plt.show()
        print("=================================================================")

        ano_score = np.array(ano_score).reshape(32,-1)
        minmax_scaler = preprocessing.MinMaxScaler() 
        ano_score = minmax_scaler.fit_transform(ano_score)

        dice_score,auroc,auprc=evaluation(ano_score,np.array(ano_mask),mask.numpy(),True)
        break

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=300,
                        help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002,
                        help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100,
                        help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=64,
                        help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=3,
                        help="number of image channels")
    parser.add_argument("--seed", type=int, default=None,
                        help="value of a random seed")
    opt = parser.parse_args(['--seed',str(1),'--n_epochs',str(3000),'--lr',str(1e-4),
                             '--img_size',str(64),'--latent_dim',str(128),
                             '--channels',str(1)])
    test(opt,True)


# Visualize the training curve

In [None]:
%load_ext tensorboard
%tensorboard --logdir=model_path+'/tensorboard' --host=127.0.0.1