In [None]:
import csv
import numpy as np
import matplotlib.pyplot as plt
from preprocessing import *
from model import *
from keras.optimizers import Adam

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive', force_remount=True)

In [None]:
# !unzip gdrive/My\ Drive/ComicDataset.zip

In [None]:
%%time
trainColorImages = "ComicDataset/Train/ColorImages"
trainSketchImages = "ComicDataset/Train/Black"

testColorImages = "ComicDataset/Test/ColorImages"
testSketchImages = "ComicDataset/Test/Black"

imageSet = get_images(trainColorImages, trainSketchImages)

In [None]:
%%time
_, axs = plt.subplots(2, 2, figsize=(15, 15))
for i, d in enumerate(data_gen(imageSet, 1)):
    _, h0, w0, c0 = d[0].shape
    _, h1, w1, c1 = d[1].shape
    gImg = d[0].reshape(h0, w0) * 0.5 + 0.5
    cImg = d[1].reshape(h1, w1, c1) * 0.5 + 0.5
    axs[i][0].imshow(gImg, cmap='Greys_r')
    axs[i][1].imshow(cImg)
    if(i == 1):
        break

In [None]:
# Generator (used below as gen.predict(sketch_batch) -> colour batch); print its layer table.
gen = generator()
gen.summary()

In [None]:
# Discriminator: build, inspect, then compile it on its own so the training loop
# can update it directly with dis.train_on_batch.
dis = discriminator()
dis.summary()
dis.compile(
    optimizer=Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Freeze D inside the combined model so generator updates do not also push D toward
# labelling generated images as real. `dis` was compiled in the previous cell, so its own
# train_on_batch still updates it (standard Keras GAN pattern).
# NOTE(review): redundant if cGAN() in model.py already sets this — confirm.
dis.trainable = False
combined_model = cGAN(gen, dis)
combined_model.summary()
# Outputs are [discriminator patch map, generated colour image]; this order is fixed by the
# [patch_ones, cImg] targets passed to train_on_batch in the training loop.
# Giving losses by position (instead of keying on Keras' auto-generated names 'model_1' /
# 'model_2') keeps compile working when a model cell is re-run in the same kernel, which
# bumps those names.
combined_model.compile(
    loss=['binary_crossentropy', 'mse'],
    loss_weights=[0.1, 1],
    optimizer=Adam(5e-4),
    metrics=['accuracy'],
)

In [None]:
import os

# Output folders written by the training cell (sample images, metric CSVs, model files).
RESULT_DIRS = [
    'Manga Colorization/results/Train',
    'Manga Colorization/results/Test',
    'Manga Colorization/AccuracyAndLosses',
    'Manga Colorization/models',
]
for directory in RESULT_DIRS:
    # exist_ok=True keeps the cell re-runnable (Restart & Run All, resuming training);
    # without it a second run raises FileExistsError.
    os.makedirs(directory, exist_ok=True)

In [None]:
%%time
batch_size = 8
iteration = 64
epoch = 1000
patch_i = 18
patch_j = 12
discriminatorLoss, generatorLoss = [], []
discriminatorAccuracy, generatorAccuracy = [], []

trainSet = get_image_set(trainColorImages, trainSketchImages,['45.png', '416.png'])
testSet = get_image_set(testColorImages, testSketchImages,['237.png', '441.png'])

patch_ones = np.ones((batch_size, patch_i, patch_j, 1))
patch_zeros = np.zeros((batch_size, patch_i, patch_j, 1))
for e in range(epoch):
    for i, d in enumerate(data_gen(imageSet, batch_size)):
        gImg, cImg = d[0], d[1]
        
        r = np.random.rand(1)
        if(r > 0.95):
            real = patch_zeros
            fake = patch_ones
        else:
            real = patch_ones
            fake = patch_zeros
        
        genImg = gen.predict(gImg)

        dis_loss_1 = dis.train_on_batch([gImg, cImg], real)
        dis_loss_2 = dis.train_on_batch([gImg, genImg], fake)
        dis_loss = 0.5 * np.add(dis_loss_1, dis_loss_2)

        discriminatorLoss.append(dis_loss[0])
        discriminatorAccuracy.append(100*dis_loss[1])
        
        cgan_loss = combined_model.train_on_batch([gImg], [patch_ones, cImg])
        
        generatorLoss.append(cgan_loss[2])
        generatorAccuracy.append(100*cgan_loss[4])
        
        if(i%10 == 0):
            print("Epoch %d/%d   iteration %d/%d  D Acc %3d%%  D Loss: %f  cGAN_Gen Acc: %3d%%  cGAN_Gen Loss: %f  cGAN_Dis Acc: %3d%%  cGAN_Dis Loss: %f  cGAN Total Loss %f" % (e, epoch, i, iteration, 100*dis_loss[1], dis_loss[0], 100*cgan_loss[4], cgan_loss[2], cgan_loss[3], cgan_loss[1], cgan_loss[0]))
        
        if(i == iteration-1):
            break    
    if ((e+1) % 2 == 0):
        rpath = 'Manga Colorization/results/'
        for i, img in enumerate(zip(trainSet[0],trainSet[1])):
            p = rpath + 'Train/{}_{}_result.png'.format(e,i)
            save_and_plot_image(img[1], img[0], gen, p)
        for i, img in enumerate(zip(testSet[0],testSet[1])):
            p = rpath + 'Test/{}_{}_result.png'.format(e,i)
            save_and_plot_image(img[1], img[0], gen, p)
        
        with open('Manga Colorization/AccuracyAndLosses/DiscriminatorLossFile.csv', mode='a') as discriminatorLossFile:
            discriminatorLossFile_writer = csv.writer(discriminatorLossFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for r in discriminatorLoss:
                discriminatorLossFile_writer.writerow([r])
            discriminatorLoss = []
        
        with open('Manga Colorization/AccuracyAndLosses/DiscriminatorAccuracyFile.csv', mode='a') as discriminatorAccuracyFile:
            discriminatorAccuracyFile_writer = csv.writer(discriminatorAccuracyFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for r in discriminatorAccuracy:
                discriminatorAccuracyFile_writer.writerow([r])
            discriminatorAccuracy = []
        
        with open('Manga Colorization/AccuracyAndLosses/GeneratorLossFile.csv', mode='a') as DCGanLossFile:
            DCGanLossFile_writer = csv.writer(DCGanLossFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for r in generatorLoss:
                DCGanLossFile_writer.writerow([r])
            generatorLoss = []
        
        with open('Manga Colorization/AccuracyAndLosses/GeneratorAccuracyFile.csv', mode='a') as generatorAccuracyFile:
            generatorAccuracyFile_writer = csv.writer(generatorAccuracyFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for r in generatorAccuracy:
                generatorAccuracyFile_writer.writerow([r])
            generatorAccuracy = []

        gen.save('Manga Colorization/models/generator.h5')
        dis.save('Manga Colorization/models/discriminator.h5')
        combined_model.save('Manga Colorization/models/combined_model.h5')

In [None]:
# Side by side for two training samples: sketch | generator output | ground-truth colour.
_, axs = plt.subplots(2, 3, figsize=(15, 15))
for row, batch in zip(range(2), data_gen(imageSet, 1)):
    generated = gen.predict(batch[0])
    _, h, w, c = batch[1].shape

    # x * 0.5 + 0.5 undoes what is presumably a [-1, 1] scaling so imshow gets [0, 1].
    sketch = batch[0].reshape(h, w) * 0.5 + 0.5
    colorized = (generated * 0.5 + 0.5).reshape(h, w, c)
    target = batch[1].reshape(h, w, c) * 0.5 + 0.5

    left, middle, right = axs[row]
    left.imshow(sketch, cmap='Greys_r')
    middle.imshow(colorized)
    right.imshow(target)

In [None]:
# Full test split for the held-out comparison grid below.
# Note: this rebinds `testSet`, which the training cell used for its fixed checkpoint samples.
testSet = get_images(testColorImages, testSketchImages)

In [None]:
# Same three-panel comparison as above, on two samples from the test split.
_, axs = plt.subplots(2, 3, figsize=(15, 15))
for row, batch in zip(range(2), data_gen(testSet, 1)):
    prediction = gen.predict(batch[0])
    _, h, w, c = batch[1].shape

    # Map presumably [-1, 1] data back to [0, 1] for display.
    panels = [
        batch[0].reshape(h, w) * 0.5 + 0.5,
        (prediction * 0.5 + 0.5).reshape(h, w, c),
        batch[1].reshape(h, w, c) * 0.5 + 0.5,
    ]
    axs[row][0].imshow(panels[0], cmap='Greys_r')
    axs[row][1].imshow(panels[1])
    axs[row][2].imshow(panels[2])

In [None]:
# Per-iteration losses appended by the training cell, one value per CSV row.
# loadtxt skips blank lines, unlike indexing row[0], which raised IndexError on them;
# ndmin=1 keeps a 1-D array even when only one value has been logged.
disLoss = np.loadtxt('Manga Colorization/AccuracyAndLosses/DiscriminatorLossFile.csv', delimiter=',', ndmin=1)
genLoss = np.loadtxt('Manga Colorization/AccuracyAndLosses/GeneratorLossFile.csv', delimiter=',', ndmin=1)

In [None]:
# Loss curves, one point per training iteration.
fig, ax = plt.subplots()
ax.plot(disLoss, label='binary_crossentropy loss')
# genLoss holds cgan_loss[2]: only the MSE (L2) term on the generated image,
# not the weighted binary_crossentropy + L2 total.
ax.plot(genLoss, label='L2 (MSE) loss')
ax.set_title('Model Loss')
ax.set_ylabel('loss')
ax.set_xlabel('iteration')
ax.legend(loc='upper right')
# Save BEFORE show(): the inline backend closes the figure on show(), so a later
# plt.savefig would write a blank image.
# NOTE: this path exists only when Google Drive is mounted (Colab mount cell at the top).
fig.savefig('gdrive/My Drive/Model_loss(b_l2).png')
plt.show()

In [None]:
# Per-iteration accuracies (already scaled to percent by the training cell), one per CSV row.
# loadtxt skips blank lines, unlike indexing row[0], which raised IndexError on them;
# ndmin=1 keeps a 1-D array even when only one value has been logged.
disAccuracy = np.loadtxt('Manga Colorization/AccuracyAndLosses/DiscriminatorAccuracyFile.csv', delimiter=',', ndmin=1)
genAccuracy = np.loadtxt('Manga Colorization/AccuracyAndLosses/GeneratorAccuracyFile.csv', delimiter=',', ndmin=1)

In [None]:
# Accuracy curves (percent), one point per training iteration.
# Labels are attached to each line explicitly: the old positional legend
# ['Discriminator', 'cGAN'] was the reverse of the plotting order.
fig, ax = plt.subplots()
ax.plot(genAccuracy, label='cGAN')
ax.plot(disAccuracy, label='Discriminator')
ax.set_title('Model Accuracy')
ax.set_ylabel('Accuracy')
ax.set_xlabel('iteration')
ax.legend(loc='upper right')
plt.show()