In [None]:
# !apt-get install openslide-tools
# !pip install openslide-python
%matplotlib inline
from openslide import open_slide, __library_version__ as openslide_version
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from skimage.color import rgb2gray
import os
from tqdm import tqdm
import random
import time
import copy
import itertools

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torchsummary import summary
import numpy as np
import torchvision
from torchvision.transforms.functional import hflip, vflip, rotate, adjust_hue, adjust_contrast, adjust_brightness, adjust_saturation
from torchvision.transforms import ToTensor

cudnn.benchmark = True

In [None]:
# Key parameters
download_data = 0  # 1 = download the slide images and masks from the Google bucket
gen_data = 0       # 1 = (re)generate the .npy samples our model trains on

ngf = 8  # number of channels that the generator starts with
         # NOTE(review): Generator_Unet uses fixed widths (64..512); ngf looks unused
ndf = 8  # number of channels that the discriminator starts with

batch = 4     # mini-batch size
nc_i = 9      # input channels: 3 RGB patches (levels 2, 3, 4) stacked along the channel axis
nc_m = 1      # number of channels in the mask
height = 128  # height of mask/image patches in pixels
width = 128   # width of mask/image patches in pixels

# Seed every RNG the notebook draws from (random.shuffle / np.random in data generation
# and augmentation, torch for weight init and DataLoader shuffling) so runs are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Download and unpack the slide images and tumor masks from the Google bucket I set up.
# credit: dataset is by Joshua Gordon for the COMS 4995 Applied Deep Learning in Fall 2022
if download_data == 1:
    import urllib.request
    import zipfile

    image_url = 'https://storage.googleapis.com/acv_project/adl_slides.zip'
    archive_path = 'adl_slides.zip'
    # urlretrieve raises on HTTP errors; `curl -O` would silently save the error page
    # and only the later unzip step would fail
    urllib.request.urlretrieve(image_url, archive_path)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall('.')
    os.remove(archive_path)

In [None]:
def read_slide(slide, x, y, level, width, height, as_float=False):
    """Read a (height, width, 3) RGB patch from an OpenSlide-style object.

    (x, y) is passed to `read_region` as the location of the patch, `level` selects
    the pyramid level and (width, height) is the patch size at that level. The alpha
    channel is dropped; the result is cast to float32 when `as_float` is set.
    """
    rgb = slide.read_region((x, y), level, (width, height)).convert('RGB')
    pixels = np.asarray(rgb, dtype=np.float32) if as_float else np.asarray(rgb)
    assert pixels.shape == (height, width, 3)
    return pixels

In [None]:
# Visual sanity check on tumor_075: the whole slide at level 7, then 128x128 patches
# at levels 4, 2 and 0, each followed by the first channel of its tumor mask.
row = 250
col = 400
slide = open_slide("./acv_slides/tumor_075.tif")
mask = open_slide("./acv_slides/tumor_075_mask.tif")


def _show(picture):
    """Render one image (or one mask channel) in a fresh 5x5-inch, 100-dpi figure."""
    plt.figure(figsize=(5, 5), dpi=100)
    plt.imshow(picture)


def _read_and_show(level, offset):
    """Read and plot the 128x128 slide patch and mask patch at `level`.

    Both are read at location ((row - 1) * 128 - offset, (col - 1) * 128 - offset).
    Returns (slide_patch, mask_patch).
    """
    corner_x = (row - 1) * 128 - offset
    corner_y = (col - 1) * 128 - offset
    slide_patch = read_slide(slide, x=corner_x, y=corner_y, level=level, width=128, height=128)
    _show(slide_patch)
    mask_patch = read_slide(mask, x=corner_x, y=corner_y, level=level, width=128, height=128)
    _show(mask_patch[:, :, 0])
    return slide_patch, mask_patch


# notice that the tumor is very small and it can only be viewed at a level 0 to level 4 region
thumb_w, thumb_h = slide.level_dimensions[7]  # the mask thumbnail is read with the slide's size as well
lev_7 = read_slide(slide, x=0, y=0, level=7, width=thumb_w, height=thumb_h)
_show(lev_7)
lev_7_mask = read_slide(mask, x=0, y=0, level=7, width=thumb_w, height=thumb_h)
_show(lev_7_mask[:, :, 0])

# Zoom levels: at coarser levels the window is shifted up-left by `offset` pixels so it
# grows around the level-0 patch.
# NOTE(review): the intent is to keep the centre point fixed. If level L is downsampled by
# 2**L, that needs offset = 64 * (2**L - 1): 64*3 at level 2 matches, but level 4 would need
# 64*15 rather than 64*7 -- confirm before relying on the alignment.
lev_4, lev_4_mask = _read_and_show(4, 64 * 7)
lev_2, lev_2_mask = _read_and_show(2, 64 * 3)
lev_0, lev_0_mask = _read_and_show(0, 0)

In [None]:
# some simple helper functions

def find_tissue_pixels(image, intensity=0.8):
    """Return (row, col) coordinates of the pixels treated as tissue.

    The RGB image is converted to grayscale; pixels whose gray value is at most
    `intensity` are kept and brighter pixels (background) are dropped.
    """
    gray = rgb2gray(image)
    assert gray.shape == (image.shape[0], image.shape[1])
    rows, cols = np.nonzero(gray <= intensity)
    return list(zip(rows, cols))

def have_cell(image, mask, intensity=0.8, percentage_thres = 0.15): # originally at 5%
    """Decide whether a patch is worth keeping.

    Always True when the mask has any non-zero value. Otherwise True only when the
    fraction of pixels with grayscale value <= `intensity` exceeds `percentage_thres`.
    """
    if np.sum(mask) > 0:
        return True
    gray = rgb2gray(image)
    assert gray.shape == (image.shape[0], image.shape[1])
    dark_fraction = np.mean(gray <= intensity)
    return bool(dark_fraction > percentage_thres)

In [None]:
def generate_data(slide_id, output_path, root_path = './acv_slides', perc_neg = 0.8, cap = 500): 
    '''
    Create the multi-resolution samples the model trains and evaluates on.

    Every tissue pixel of the level-7 thumbnail is a candidate location. Locations whose
    level-7 mask value is 1 (tumor) are always saved. Other locations are saved with
    probability perc_pos * perc_neg, where perc_pos = (# tumor pixels) / (# tissue pixels).
    Each sample is saved as `<output_path>/<slide_id><n>.npy` and holds the level-2,
    level-3 and level-4 RGB patches plus the level-3 mask patch (each 128x128x3).

    slide_id: the id of the slide (files `<slide_id>.tif` and `<slide_id>_mask.tif`)
    output_path: directory to save the samples in (created, with parents, if missing)
    root_path: directory the slide and mask files reside in
    perc_neg: scales how many tumor-free samples are drawn. The expected number of negatives
              is roughly perc_neg times the number of tumor pixels, so this acts as a
              negative:positive ratio rather than a proportion of all samples
    cap: the maximum number of positive (tumor) samples saved from this slide; sampling
         stops once it is reached
    '''

    image_path = os.path.join(root_path, slide_id + '.tif')
    mask_path = os.path.join(root_path, slide_id + '_mask.tif')

    # makedirs also creates missing parents (e.g. './data' for './data/train') and tolerates re-runs
    os.makedirs(output_path, exist_ok=True)

    slide = open_slide(image_path)
    tumor_mask = open_slide(mask_path)
    dim_x, dim_y = slide.level_dimensions[7]

    # level-7 thumbnails: each tissue pixel is a candidate sample location
    slide_image = read_slide(slide, x=0, y=0, level=7, width=dim_x, height=dim_y)
    mask_image = read_slide(tumor_mask, x=0, y=0, level=7, width=dim_x, height=dim_y)

    indices = find_tissue_pixels(slide_image)
    random.shuffle(indices)  # random visiting order, so the positive cap does not favour one region

    # ratio of tumor pixels to tissue pixels; guarded so a slide without tissue does not divide by zero
    perc_pos = np.sum(mask_image[:, :, 0]) / len(indices) if indices else 0.0
    neg_prob = perc_pos * perc_neg  # probability of keeping a tumor-free location

    sample_id = 0
    pos_count = 0
    neg_count = 0

    for row, col in indices:
        # prevent the overdominance of a single slide: never save more than `cap` positives
        if pos_count >= cap:
            break

        is_tumor = mask_image[row][col][0] == 1
        c = np.random.uniform()
        if not (is_tumor or c <= neg_prob):
            continue

        sample_id += 1
        if is_tumor:
            pos_count += 1
        else:
            neg_count += 1

        # (col*128, row*128) is taken as the level-0 corner of this level-7 pixel
        # (assumes a 128x downsample at level 7).
        # NOTE(review): if level L is downsampled by 2**L, centring a 128-px patch on this pixel
        # needs an offset of 64 * (2**L - 1), i.e. 64*3, 64*7 and 64*15 for levels 2, 3 and 4.
        # The offsets below only match at level 2. They are kept unchanged so samples stay
        # aligned with the evaluation code -- confirm the intended alignment.
        lev_2_slide = read_slide(slide, x=col*128 - 64*3, y=row*128 - 64*3,
                                 level=2, width=128, height=128)
        lev_3_slide = read_slide(slide, x=col*128 - 64*5, y=row*128 - 64*5,
                                 level=3, width=128, height=128)
        lev_4_slide = read_slide(slide, x=col*128 - 64*7, y=row*128 - 64*7,
                                 level=4, width=128, height=128)
        # the mask patch is cut at the same level and position as the level-3 image patch
        mask = read_slide(tumor_mask, x=col*128 - 64*5, y=row*128 - 64*5,
                          level=3, width=128, height=128)
        np.save(os.path.join(output_path, slide_id + str(sample_id)),
                np.array([lev_2_slide, lev_3_slide, lev_4_slide, mask]))

    print(slide_id + ": {} positive slides and {} negative slides added to '{}'".format(int(pos_count), int(neg_count), output_path))
    return None

In [None]:
# Save the 3 layers of the slides and the mask image

if gen_data == 1:
    import shutil

    # Start from an empty ./data so samples from an earlier run are not mixed in.
    # shutil.rmtree replaces the `!rm -rf` shell magic and keeps the cell plain Python.
    data_path = './data'
    if os.path.exists(data_path):
        shutil.rmtree(data_path)
    os.mkdir(data_path)

    train_slides_names = [
                          'tumor_001',
                          'tumor_002',
                          'tumor_005',
                          'tumor_012',
                          'tumor_016',
                          'tumor_031',
                          'tumor_035',
                          'tumor_059',
                          'tumor_064',
                          'tumor_075',
                          'tumor_078',
                          'tumor_081',
                          'tumor_084',
                          'tumor_091',
                          'tumor_096',
                          'tumor_110']

    for slide_id in train_slides_names:
        generate_data(slide_id, output_path = './data/train')

    valid_slides_names = ['tumor_057',
                          'tumor_019']

    for slide_id in valid_slides_names:
        generate_data(slide_id, output_path = './data/val')

In [None]:
# data loader (adapted from the lecture slides)

class MyImageDataset(torch.utils.data.Dataset):
    """
    Multi-resolution patch dataset backed by one .npy file per sample.

    Each file holds an array whose entries 0-2 are the level-2, level-3 and level-4 RGB
    patches and whose entry 3 is the tumor mask patch (stored as 0/1 values).
    __getitem__ returns (image, mask):
      image: the three patches stacked into one 9-channel float tensor in [0, 1]
      mask:  a float tensor of 0/1 values, channels first (3, H, W)
    """
    _ALL_TRANSFORMS = frozenset({'flip', 'rotation', 'gitter'})

    def __init__(self, images_dir, image_transform = True):
        '''
        images_dir: directory containing the .npy samples
        image_transform: True/False to enable/disable every augmentation, or a collection of
            augmentation names to enable: 'flip', 'rotation', 'gitter' (color jitter).
            Flips and rotations are applied to the mask too; color jitter only to the image.
        '''
        self.images_dir = images_dir
        self.image_transform = image_transform
        if isinstance(image_transform, (list, tuple, set, frozenset)):
            requested = frozenset(image_transform)
            unknown = requested - self._ALL_TRANSFORMS
            if unknown:
                raise ValueError('unknown transforms: {}'.format(sorted(unknown)))
            self.enabled_transforms = requested
        else:
            self.enabled_transforms = self._ALL_TRANSFORMS if image_transform else frozenset()

        # Only .npy samples; sorted so the index -> file mapping does not depend on os.listdir order
        self.image_files = [os.path.join(images_dir, name)
                            for name in sorted(os.listdir(images_dir))
                            if name.endswith('.npy')]
        self.num_images = len(self.image_files)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        sample = np.load(self.image_files[idx])
        levels = [Image.fromarray(sample[i]) for i in range(3)]  # level 2, 3, 4 patches
        mask = Image.fromarray(sample[3])

        if 'flip' in self.enabled_transforms:
            # 0: none, 1: horizontal, 2: vertical, 3: both (randint's upper bound is exclusive)
            c = np.random.randint(0, 4)
            if c in (1, 3):
                levels = [hflip(im) for im in levels]
                mask = hflip(mask)
            if c in (2, 3):
                levels = [vflip(im) for im in levels]
                mask = vflip(mask)

        if 'rotation' in self.enabled_transforms:
            # multiples of 90 degrees: 0, 90, 180 or 270
            angle = 90 * np.random.randint(0, 4)
            levels = [rotate(im, angle) for im in levels]
            mask = rotate(mask, angle)

        if 'gitter' in self.enabled_transforms:
            # Color jitter, image only. Brightness/contrast/saturation factors are drawn
            # from [0.5, 1.5] and the hue shift from [-0.2, 0.2]. All three levels share
            # the same factor.
            for adjust, low, high in ((adjust_brightness, 0.5, 1.5),
                                      (adjust_contrast, 0.5, 1.5),
                                      (adjust_saturation, 0.5, 1.5),
                                      (adjust_hue, -0.2, 0.2)):
                factor = np.random.uniform(low, high)
                levels = [adjust(im, factor) for im in levels]

        to_tensor = ToTensor()
        image = torch.cat([to_tensor(im) for im in levels], 0)
        # The mask holds 0/1 in uint8. ToTensor would divide it by 255 (0 and ~0.0039 targets),
        # so convert without rescaling to keep proper 0/1 BCE targets.
        mask = torch.from_numpy(np.array(mask, dtype=np.float32)).permute(2, 0, 1)
        return image, mask

In [None]:
def collate_fn(batch):
    """Stack (image, mask) samples into a mini-batch.

    Returns (images, masks): every image tensor stacked along a new batch axis, and
    the stacked mask tensors reduced to their first channel, i.e. (N, 1, H, W).
    """
    images, masks = (torch.stack([item[k] for item in batch]) for k in (0, 1))
    return images, masks[:, 0:1, :, :]

In [None]:
data_dir = './data'

# augmentations for each split; an empty list disables augmentation
data_transforms = {'train': ['flip','rotation','gitter'],
                    'val': []}

# implement custom image_dataset and wrap it with the dataloader
image_datasets = {x: MyImageDataset(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}

# Use `batch` from the Key Parameters cell instead of a hard-coded 4.
# Only the training split needs shuffling.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch,
                                              shuffle=(x == 'train'), num_workers=0,
                                              collate_fn=collate_fn)
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

In [None]:
# U-Net Generator
# implementation from https://github.com/usuyama/pytorch-unet

def double_conv(in_channels, out_channels):
    """Two 3x3 convolutions with padding 1 (spatial size preserved), each followed by ReLU."""
    layers = []
    for c_in in (in_channels, out_channels):
        layers += [nn.Conv2d(c_in, out_channels, 3, padding=1), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)

class Generator_Unet(nn.Module):
    """U-Net generator mapping the nc_i-channel input to a 1-channel mask in [0, 1].

    Encoder stages of 64, 128 and 256 channels and a 512-channel bottleneck, then three
    decoder stages that concatenate the upsampled features with the matching encoder output
    (skip connections). H and W should be divisible by 8 so the skip shapes line up.
    """
    def __init__(self):
        super().__init__()
        # encoder
        self.dconv_down1 = double_conv(nc_i, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # decoder: input channels = upsampled features + skip connection
        self.dconv_up3 = double_conv(512 + 256, 256)
        self.dconv_up2 = double_conv(256 + 128, 128)
        self.dconv_up1 = double_conv(128 + 64, nc_m)
        # 1x1 conv + sigmoid -> per-pixel value in [0, 1]
        self.conv_last = nn.Sequential(nn.Conv2d(nc_m, 1, 1), nn.Sigmoid())

    def forward(self, x):
        # x: (N, nc_i, H, W)
        skip1 = self.dconv_down1(x)                       # (N, 64, H, W)
        skip2 = self.dconv_down2(self.maxpool(skip1))     # (N, 128, H/2, W/2)
        skip3 = self.dconv_down3(self.maxpool(skip2))     # (N, 256, H/4, W/4)
        bottom = self.dconv_down4(self.maxpool(skip3))    # (N, 512, H/8, W/8)

        up = self.dconv_up3(torch.cat([self.upsample(bottom), skip3], dim=1))  # (N, 256, H/4, W/4)
        up = self.dconv_up2(torch.cat([self.upsample(up), skip2], dim=1))      # (N, 128, H/2, W/2)
        up = self.dconv_up1(torch.cat([self.upsample(up), skip1], dim=1))      # (N, nc_m, H, W)
        return self.conv_last(up)                                              # (N, 1, H, W)

In [None]:
# build the generator and disriminator network separately
# adapted from the model given in https://towardsdatascience.com/generative-adversarial-network-gan-for-dummies-a-step-by-step-tutorial-fdefff170391#:~:text=GAN%20Training&text=Step%201%20%E2%80%94%20Select%20a%20number,both%20fake%20and%20real%20images.

class Discriminator(nn.Module):
    '''
    Conditional discriminator: takes an image and a mask and outputs (via a final sigmoid)
    the probability that the mask is a real annotation rather than a generated one.
    '''
    def __init__(self):
        super(Discriminator, self).__init__()
        # Each Conv2d(kernel 4, stride 2, padding 1) maps a side of length s to s // 2, so after
        # four of them the features are (ndf * 8, height // 16, width // 16). Deriving the Linear
        # size (previously a hard-coded 4096, valid only for ndf=8 and 128x128 inputs) keeps the
        # model valid when those parameters change.
        flat_features = ndf * 8 * (height // 16) * (width // 16)
        self.main = nn.Sequential(
            # input shape = (nc_m + nc_i, height, width)
            # discriminator block 1 (output shape = (ndf, height/2, width/2))
            nn.Conv2d(nc_m + nc_i, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # discriminator block 2 (output shape = (ndf * 2, height/4, width/4))
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # discriminator block 3 (output shape = (ndf * 4, height/8, width/8))
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # discriminator block 4 (output shape = (ndf * 8, height/16, width/16))
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # flatten and map to a single real/fake probability
            nn.Flatten(),
            nn.Linear(flat_features, 1),
            nn.Sigmoid())

    def forward(self, image, mask):
        # merge on channels: image channels first, then mask channels
        pair = torch.cat((image, mask), 1)
        return self.main(pair)

In [None]:
# directory where the model checkpoints (state dicts) are saved
model_path = './model'
if not os.path.exists(model_path):
    os.mkdir(model_path)

# warm start: the U-Net generator is first trained on its own with a pixel-wise loss
num_epochs = 15  # we just want to warm start the generator here
model_g = Generator_Unet().to(device)
optimizer_g = optim.Adam(params=model_g.parameters(), lr=0.01)
# multiply the learning rate by 0.1 every 7 epochs
scheduler_g = lr_scheduler.StepLR(optimizer=optimizer_g, step_size=7, gamma=0.1)
criterion_g = nn.BCELoss()  # binary cross-entropy between predicted and true mask

In [None]:
# train the U-Net model (supervised warm start of the generator)
train_loss_list = []
val_loss_list = []
best_loss = float('inf')  # any finite validation loss beats this, so a checkpoint is always written

for epoch in range(num_epochs):
    # training step
    model_g.train()
    train_running_loss = 0.0
    # Scoped grad mode, instead of flipping it globally on every iteration
    with torch.set_grad_enabled(True):
        for inputs, labels in tqdm(dataloaders['train']):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer_g.zero_grad()
            outputs = model_g(inputs)
            loss = criterion_g(outputs, labels)
            loss.backward()
            optimizer_g.step()
            # BCELoss returns a mean, so weight by batch size before dividing by the dataset size
            train_running_loss += loss.item() * inputs.size(0)
    train_loss = train_running_loss / dataset_sizes['train']
    train_loss_list.append(train_loss)
    scheduler_g.step()

    # Validation step. no_grad limits disabled gradients to this block; the former global
    # torch.set_grad_enabled(False) stayed off for every later cell.
    model_g.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model_g(inputs)
            loss = criterion_g(outputs, labels)
            val_running_loss += loss.item() * inputs.size(0)
    val_loss = val_running_loss / dataset_sizes['val']
    val_loss_list.append(val_loss)

    # keep the weights with the lowest validation loss
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model_g.state_dict(), './model/u_net')

    print(f'epoch: {epoch}/{num_epochs}, Train Loss: {train_loss:.8f}, Val Loss: {val_loss:.8f}')

# load best model weight
model_g.load_state_dict(torch.load('./model/u_net', map_location=device))

In [None]:
# validation loss of the warm-started U-Net, one point per epoch
fig, ax = plt.subplots()
ax.plot(val_loss_list)
ax.set(title='validation loss for u-net', xlabel='epoch', ylabel='loss')
plt.show()

In [None]:
# GAN stage: a new discriminator; the generator keeps model_g / optimizer_g from the warm start
num_epochs = 30
criterion_gan = nn.BCELoss()  # real-vs-generated classification loss
model_d = Discriminator().to(device)
optimizer_d = optim.Adam(params=model_d.parameters(), lr=0.001)

In [None]:
# train the conditional GAN
# idea is from https://medium.com/intel-student-ambassadors/segmentation-using-generative-adversarial-networks-80a161cf33c0

discriminator_loss_list = []
generator_loss_list = []
generator_val_loss_list = []
best_loss = float('inf')  # any finite validation loss beats this, so a checkpoint is always written

for epoch in range(num_epochs):
    # train step
    running_loss_d = 0.0
    running_loss_g = 0.0
    model_d.train()
    model_g.train()
    with torch.set_grad_enabled(True):
        for inputs, labels in tqdm(dataloaders['train']):
            inputs = inputs.to(device)
            labels = labels.to(device)
            n = inputs.size(0)
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x)))
            ############################
            optimizer_d.zero_grad()
            # real pairs (ground-truth masks) -> target 1
            output = model_d(inputs, labels)
            loss_d_real = criterion_gan(output, torch.ones_like(output))
            # Generated pairs -> target 0. Detach so D's backward pass does not also
            # compute (and then discard) gradients for the generator.
            fake_labels = model_g(inputs)
            output = model_d(inputs, fake_labels.detach())
            loss_d_fake = criterion_gan(output, torch.zeros_like(output))
            loss_d = loss_d_real + loss_d_fake
            loss_d.backward()
            optimizer_d.step()
            # per-batch mean * batch size, so the epoch total / dataset size is a per-sample mean
            running_loss_d += loss_d.item() * n
            ############################
            # (2) Update G network: maximize log(D(x, G(x)))
            ############################
            optimizer_g.zero_grad()
            # G's weights are unchanged since fake_labels was computed, so reuse it (with its graph).
            # D is re-evaluated with its freshly updated weights.
            output = model_d(inputs, fake_labels)
            # note that fake labels are real for generator cost
            # NOTE(review): no supervised term (e.g. criterion_g(fake_labels, labels)) is added here,
            # so the generator only learns from the discriminator in this stage -- confirm intent
            loss_g = criterion_gan(output, torch.ones_like(output))
            loss_g.backward()
            optimizer_g.step()
            running_loss_g += loss_g.item() * n
    discriminator_loss = running_loss_d / dataset_sizes['train']
    discriminator_loss_list.append(discriminator_loss)
    generator_loss = running_loss_g / dataset_sizes['train']
    generator_loss_list.append(generator_loss)

    # validation step: pixel-wise BCE of the generator against the true masks
    model_g.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model_g(inputs)
            loss = criterion_g(outputs, labels)
            val_running_loss += loss.item() * inputs.size(0)
    val_loss = val_running_loss / dataset_sizes['val']
    generator_val_loss_list.append(val_loss)
    # update the best model
    if val_loss < best_loss:
        # save the weights
        best_loss = val_loss
        torch.save(model_g.state_dict(), './model/gan_generator')
        torch.save(model_d.state_dict(), './model/gan_discriminator')

    # print the progress to determine if the progress has stagnated
    print(f'epoch: {epoch}/{num_epochs}, Generator Loss: {generator_loss:.4f}, Discriminator Loss: {discriminator_loss:.4f}, Generator Validation Loss: {val_loss:.8f}')

In [None]:
# validation loss of the generator during the GAN stage, one point per epoch
fig, ax = plt.subplots()
ax.plot(generator_val_loss_list)
ax.set_title('validation loss for u-net trained using GAN')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
plt.show()

In [None]:
# helper functions to output the results

def slide_to_torch(slide): 
    '''
    Convert one RGB patch array (as returned by read_slide) into a model input tensor.

    The array goes through PIL and ToTensor (channels-first; uint8 values are scaled to
    [0, 1]), then gets a leading batch axis of size 1: (1, 3, H, W).
    '''
    as_tensor = ToTensor()(Image.fromarray(slide))
    return as_tensor.unsqueeze(0)

def model_predict(model_g, lev_2_slide, lev_3_slide, lev_4_slide, threshold):
    '''
    Classify one location as tumor (1) or not (0).

    The level-2, level-3 and level-4 RGB patches are stacked into the model's 9-channel
    input. The location is positive when the mean of the predicted mask exceeds `threshold`.
    A low threshold is used because we want to minimize the chances of missing the
    highlighted tumor region.
    '''
    patches = [slide_to_torch(p) for p in (lev_2_slide, lev_3_slide, lev_4_slide)]
    model_input = torch.cat(patches, 1).to(device)
    # inference only: skip building the autograd graph for every call
    with torch.no_grad():
        pred = model_g(model_input)
    return 1 if torch.mean(pred) > threshold else 0


def accuracy(model, slide_id, threshold = 0.1, root_path = './acv_slides'):
    '''
    Output the accuracy of `model` on one slide at the level-7 scale.

    Every tissue pixel of the level-7 thumbnail is one location. The model predicts
    tumor / no tumor from the level-2/3/4 patches around it (via model_predict), and the
    prediction is compared with the level-7 mask pixel.

    model: the generator to evaluate
    slide_id: the slide to evaluate (`<slide_id>.tif` and `<slide_id>_mask.tif` in root_path)
    threshold: the average mask value for a positive prediction
    root_path: directory containing the slide files

    Returns a NumPy array with one accuracy per mask channel, because the prediction is
    compared with the full RGB mask pixel. Callers index it, e.g. accuracy(...)[0].
    '''
    image_path = os.path.join(root_path, slide_id + '.tif')
    mask_path = os.path.join(root_path, slide_id + '_mask.tif')

    slide = open_slide(image_path)
    tumor_mask = open_slide(mask_path)

    # level-7 thumbnail and mask: each tissue pixel is one evaluated location
    dim_x, dim_y = slide.level_dimensions[7]
    slide_image = read_slide(slide, x=0, y=0, level=7, width=dim_x, height=dim_y)
    mask_image = read_slide(tumor_mask, x=0, y=0, level=7, width=dim_x, height=dim_y)

    indices = find_tissue_pixels(slide_image)
    acc = 0

    for row, col in tqdm(indices):
        # These use the same offsets as the training-data generation, so inputs match training.
        # NOTE(review): if level L is downsampled by 2**L, centring on this pixel would need
        # 64 * (2**L - 1), i.e. 64*3, 64*7 and 64*15 for levels 2, 3 and 4 -- confirm.
        lev_2_slide = read_slide(slide, x=col*128 - 64*3, y=row*128 - 64*3,
                                 level=2, width=128, height=128)
        lev_3_slide = read_slide(slide, x=col*128 - 64*5, y=row*128 - 64*5,
                                 level=3, width=128, height=128)
        lev_4_slide = read_slide(slide, x=col*128 - 64*7, y=row*128 - 64*7,
                                 level=4, width=128, height=128)

        # evaluate the model that was passed in (previously the global model_g was used)
        pred = model_predict(model, lev_2_slide, lev_3_slide, lev_4_slide, threshold)
        # comparing with the RGB mask pixel yields one boolean per channel
        acc += (pred == mask_image[row][col])

    return acc/len(indices)

In [None]:
test_slides_names = ['tumor_094',
                     'tumor_101']

# load the best generator checkpoint from the GAN stage and switch to inference mode
model_g.load_state_dict(torch.load('./model/gan_generator', map_location=device))
model_g.to(device)
model_g.eval()

# Evaluate every listed test slide; previously only the first was scored.
# [0] keeps the accuracy for the first mask channel.
test_accuracy = {slide_id: accuracy(model_g, slide_id)[0] for slide_id in test_slides_names}
test_accuracy
