In [None]:
import zipfile
# Unpack the dataset into the notebook's working directory so the relative
# paths used below ("PodocyteDataset/training/...") resolve. The previous
# target "D:IS" is a Windows drive-relative path ("IS" under drive D's current
# directory), which is not where the later cells read the data from.
# NOTE(review): assumes the archive's top-level folder is "PodocyteDataset/" — confirm.
import os

# Skip the (slow) extraction when the folder is already present, so
# Restart & Run All does not re-unpack the archive every time.
if not os.path.isdir("PodocyteDataset"):
    with zipfile.ZipFile("PodocyteDataset.zip", 'r') as zip_ref:
        zip_ref.extractall(".")

In [None]:
import os
imageDir = "PodocyteDataset/training/images"
annotationDir = "PodocyteDataset/training/masks"
# holdoutImageDir = "PodocyteDataset/test/images"
# holdoutAnnotationDir = "PodocyteDataset/test/masks"
NUM_CLASSES = 4
height, width = 256, 256

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Sanity check: number of entries found in the training image folder.
nTrainImages = len(os.listdir(imageDir))
print(nTrainImages)
# print(len(os.listdir(holdoutImageDir)))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from keras.optimizers import Adam

In [None]:
# from keras import backend as K
# K.clear_session()

## Create Generator

In [None]:
from generator import Generator

In [None]:
# Segmentation network ("generator" of the cGAN) for 3-channel inputs.
inputShape = (height, width, 3)
gen = Generator(inputShape, NUM_CLASSES)
deepLab = gen.deepLab()   # model object used below via .predict / .summary
deepLab.summary()

## Create Discriminator

In [None]:
from discriminator import Discriminator

In [None]:
# Discriminator conditioned on the 3-channel image; it judges a single-channel
# label map of the same spatial size (see the [realImg, segImg] inputs in training).
imageShape, labelMapShape = (height, width, 3), (height, width, 1)
dis = Discriminator(imageShape, labelMapShape)
conditionalFCD = dis.CFCDiscriminator()
conditionalFCD.summary()

In [None]:
# The discriminator is trained on its own (real vs. generated patch targets),
# so it gets a dedicated optimizer; accuracy is reported as D-Acc in training.
disOptimizer = Adam(5e-4, beta_1=0.9, beta_2=0.99)
conditionalFCD.compile(
    loss='binary_crossentropy',
    optimizer=disOptimizer,
    metrics=['accuracy'],
)

## Combine Model

In [None]:
from generative_adversarial_network import GAN

In [None]:
# Combined model: image in, two outputs out (adversarial patch prediction and
# segmentation), matching the two targets passed to train_on_batch in training.
ganParts = (deepLab, conditionalFCD)
cGAN = GAN(*ganParts, height, width)
combined_model = cGAN.supervisedCGAN()
combined_model.summary()

In [None]:
# One loss per output: binary cross-entropy on the adversarial patch output
# (down-weighted) and categorical cross-entropy against the one-hot mask.
ADV_LOSS_WEIGHT, SEG_LOSS_WEIGHT = 0.001, 1
combined_model.compile(
    loss=['binary_crossentropy', 'categorical_crossentropy'],
    loss_weights=[ADV_LOSS_WEIGHT, SEG_LOSS_WEIGHT],
    optimizer=Adam(5e-5),
)

## Image Preprocessing

### Train / Validation / Test Split

In [None]:
from preprocessing import Preprocessing

In [None]:
%%time
preProc = Preprocessing(height, width, NUM_CLASSES)
trainImages, valImages, testImages = preProc.get_test_train_filenames(imageDir, 0.10, 0.10)
print(len(trainImages), len(valImages), len(testImages))

### Data visualization

In [None]:
# Show the first three training samples: the image (left) and its mask turned
# into class indices with argmax over axis 3 (right).
_, axs = plt.subplots(3, 2, figsize=(15, 15))
sampleStream = preProc.data_gen(trainImages, imageDir, annotationDir, 1)
# zip stops after range(3) is exhausted, so exactly three batches are drawn.
for row, (img, mask) in zip(range(3), sampleStream):
    _, h, w, c = img.shape
    axs[row][0].imshow(img.reshape(h, w, c))
    axs[row][1].imshow(np.argmax(mask, axis=3).reshape(h, w))

## Train Model

In [None]:
%%time
# batch_size = 24
batch_size = 2
# iterations = 32
iterations = 5
epoch = 1
patch_i = 32
patch_j = 32
loss_1, loss_2, loss_3 = [], [], []
trainmetrics, valmetrics = [], []
# patch_ones = np.ones((batch_size, patch_i, patch_j, 1), dtype='int8')
# patch_zeros = np.zeros((batch_size, patch_i, patch_j, 1), dtype='int8')

patch_ones = np.random.uniform(0.85, 1.0, ((batch_size, patch_i, patch_j, 1)))
patch_zeros = np.zeros((batch_size, patch_i, patch_j, 1), dtype='int8')

for e in range(epoch):
    avgL1, avgL2, avgL3 = 0, 0, 0
    for i in range(iterations):
    
        realImg, maskImg = next(preProc.data_gen(trainImages, imageDir, annotationDir, batch_size))
        segImg = np.expand_dims(np.argmax(maskImg, 3), 3)
        gImg = np.expand_dims(np.argmax(deepLab.predict(realImg), 3), 3)
        dis_loss_1 = conditionalFCD.train_on_batch([realImg, segImg], patch_ones)
        dis_loss_2 = conditionalFCD.train_on_batch([realImg, gImg], patch_zeros)
        dis_loss = 0.5 * np.add(dis_loss_1, dis_loss_2)

        loss_1.append(dis_loss)
        avgL1 += dis_loss[0]

        valid = patch_ones        
        if(np.random.rand(1) > 0.95):
            valid = patch_zeros

        #Train cGAN   
        cgan_loss = combined_model.train_on_batch([realImg], [valid, maskImg])
        loss_2.append(cgan_loss)

        avgL2 += cgan_loss[1]
        avgL3 += cgan_loss[2]

        if((i+1)%5 == 0):
            print("Epoch %d/%d   iteration %d/%d  D-Acc %3d%%  D-Loss: %f  Total-Loss: %f  cGAN_Dis-Loss: %f cGAN_Gen-Loss: %f" % (e+1, epoch, i+1, iterations, 100*dis_loss[1], dis_loss[0], cgan_loss[0], cgan_loss[1], cgan_loss[2]))
        
        if(i == iterations-1):
            break
    loss_3.append([avgL1/iterations, avgL2/iterations, avgL3/iterations])

    #check train accuracy
    # trainmetrics.append(get_all_metrics())
    # valmetrics.append(get_all_metrics())

    # predict_on_image(gen, e+1)
#     if((e+1)%10==0):
#         print("Save Models")
#         gen.save("gdrive/My Drive/IS_Gen_patch70.h5")
#         print("Generator Model Saved ", e+1)
#         dis.save("gdrive/My Drive/IS_Dis_patch70.h5")
#         print("Discriminator Model Saved ", e+1)

In [None]:
# from keras.models import load_model
# gen = load_model('gdrive/My Drive/.h5')

## Model Metrics

In [None]:
from metrics import Metrics
# Display names for the segmentation classes, presumably in class-index order
# (index 0 = background) — confirm against the mask encoding.
labels = np.array(['background', 'class_1', 'class_2', 'class_3'])
# Keep the class names in sync with the model's number of output classes.
assert len(labels) == NUM_CLASSES, "labels must provide one name per class"

### Validation Metrics

In [None]:
# Metrics helper over the validation split.
valMet = Metrics(
    imageDir, annotationDir, valImages,   # data folders + validation filenames
    height, width,                        # model input size
    deepLab, labels,                      # trained model + class names
)

In [None]:
# Report the validation-split metrics (see metrics.Metrics.printAllMetrics).
valMet.printAllMetrics()

In [None]:
# Qualitative check on three validation images: input, model prediction and
# ground truth side by side, with per-class pixel counts printed for each.
_, axs = plt.subplots(3, 3, figsize=(10, 10))
valStream = preProc.data_gen(valImages, imageDir, annotationDir, 1)
# zip stops after range(3) is exhausted, so exactly three samples are drawn.
for row, (img, mask) in zip(range(3), valStream):
    _, h, w, c = img.shape
    seg = np.argmax(deepLab.predict(img), axis=3).reshape(h, w)
    actual = np.argmax(mask, axis=3).reshape(h, w)   # computed once, used twice
    print("Actual Classes : ", np.unique(actual, return_counts=True))
    print("Predicted Classes : ", np.unique(seg, return_counts=True))
    # Label every panel, consistent with the test-set visualization.
    panels = (img.reshape(h, w, c), seg, actual)
    for ax, panel, title in zip(axs[row], panels, ('Original', 'Predicted', 'Actual')):
        ax.imshow(panel)
        ax.set_xlabel(title)

### Test Metrics

In [None]:
# Metrics helper over the test split.
testMet = Metrics(
    imageDir, annotationDir, testImages,  # data folders + test filenames
    height, width,                        # model input size
    deepLab, labels,                      # trained model + class names
)

In [None]:
# Report the test-split metrics (see metrics.Metrics.printAllMetrics).
testMet.printAllMetrics()

In [None]:
# Qualitative check on the first test image: input, prediction, ground truth.
_, axs = plt.subplots(1, 3, figsize=(10, 10))
testStream = preProc.data_gen(testImages, imageDir, annotationDir, 1)
for img, mask in testStream:
    _, h, w, c = img.shape
    seg = np.argmax(deepLab.predict(img), axis=3).reshape(h, w)
    actual = np.argmax(mask, axis=3).reshape(h, w)
    print("Actual Classes : ", np.unique(actual, return_counts=True))
    print("Predicted Classes : ", np.unique(seg, return_counts=True))
    panels = (img.reshape(h, w, c), seg, actual)
    for ax, panel, title in zip(axs, panels, ('Original', 'Predicted', 'Actual')):
        ax.imshow(panel)
        ax.set_xlabel(title)
    break  # only the first test sample is shown