In [None]:
import numpy as np
import os
import time

In [None]:
import keras

In [None]:
# Show the installed Keras version so results can be tied to a specific release.
keras.__version__

In [None]:
# os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [None]:
#!pip install -q -U tensorflow-addons==0.11.2

### Utilities

In [None]:
from utils import verifyDir
from utils.networks import normalize, unnormalize, plot_data

### Dataset

In [None]:
from utils.CIFAR10 import load_real_samples

### Discriminator & Generator

In [None]:
from utils.CIFAR10 import define_discriminator
from utils.CIFAR10 import define_generator

### Semi-Supervised GAN

In [None]:
from utils.networks import define_gan

### Selecting the labeled subset

In [None]:
from utils.networks import select_supervised_samples, generate_real_samples
from utils.networks import generate_fake_samples, generate_latent_points

### Training routine

In [None]:
from utils.networks import train_gan

### Loading Dataset

In [None]:
# load image data: load_real_samples returns the train and test splits, each of
# which is unpacked later as an (images, labels) pair
dataset_train, dataset_test = load_real_samples()

### Parameters

In [None]:
input_shape = (32, 32, 3)  # (height, width, channels) passed to define_discriminator
num_classes = 10

learning_rate = 2e-4  # used for the discriminators and for the GAN's Adam optimizer
latent_dim = 100      # dimensionality of the generator's latent input

epochs = 100
batch_size = 128

# Fraction of the training set treated as labeled (passed to train_gan as label_rate).
labeled_rate = 4/50
# Number of labeled samples; in this notebook it is only used to name LOG_PATH.
labeled_samples = int(dataset_train[0].shape[0] * labeled_rate)

# Seed NumPy so that NumPy-driven randomness (presumably the labeled-subset selection
# and batch/latent sampling in utils.networks -- verify) is reproducible across runs.
# NOTE(review): Keras/TensorFlow weight initialisation is NOT seeded by this call.
SEED = 42
np.random.seed(SEED)

In [None]:
# Per-run directory, named after the number of labeled samples. Later cells load the
# final checkpoint and the CSV log from it and save plots into it. The trailing "/"
# lets those cells append file names directly.
# NOTE(review): this must match the directory train_gan writes to -- confirm.
LOG_PATH = "Logs/SSGAN_CIFAR10/Classifier_{}/".format(labeled_samples)
verifyDir(LOG_PATH)  # project helper; presumably creates the directory if it is missing

### Creating Models

In [None]:
from utils.networks import f1_score, auc_pr, precision_score, recall_score

In [None]:
# Metrics handed to define_discriminator: Keras' built-in accuracy plus the custom
# f1_score and auc_pr functions imported from utils.networks.
metrics_list = [
    "accuracy",
    f1_score,
    auc_pr,
]

In [None]:
# create the discriminator models: the factory returns the unsupervised model and
# the supervised (classifier) model, in that order
unsupervised_model, supervised_model = define_discriminator(
    in_shape=input_shape,
    n_classes=num_classes,
    learning_rate=learning_rate,
    metrics_list=metrics_list,
)
# create the generator for latent inputs of size latent_dim
generator_model = define_generator(latent_dim=latent_dim)

In [None]:
# Layer-by-layer summary of the supervised discriminator (the classifier evaluated below).
supervised_model.summary()

In [None]:
# Layer-by-layer summary of the unsupervised discriminator.
unsupervised_model.summary()

In [None]:
# Layer-by-layer summary of the generator.
generator_model.summary()

In [None]:
# create the gan: define_gan combines the generator with the unsupervised
# discriminator and compiles it with its own Adam optimizer.
from keras.optimizers import Adam

# `lr` is a deprecated alias that recent Keras releases no longer accept;
# `learning_rate` is the supported keyword (available since Keras 2.3).
# beta_1=0.5 lowers Adam's first-moment decay from its default of 0.9.
opt_gan = Adam(learning_rate=learning_rate, beta_1=0.5)

gan_model = define_gan(generator_model, unsupervised_model, optimizer_grad=opt_gan)

In [None]:
# Summary of the combined GAN model returned by define_gan.
gan_model.summary()

### Training

In [None]:
# Train all models. The labeled subset is specified as a fraction (label_rate)
# rather than an explicit sample count; unnormalize_image=True is forwarded to train_gan.
train_gan(
    generator_model, unsupervised_model, supervised_model, gan_model,
    dataset_train, dataset_test,
    latent_dim=latent_dim,
    n_epochs=epochs,
    n_batch=batch_size,
    n_classes=num_classes,
    label_rate=labeled_rate,
    unnormalize_image=True,
)

### Testing

In [None]:
# Reload the data so the Testing section presumably works without the training cells'
# in-memory state (the parameter and LOG_PATH cells are still required).
dataset_train, dataset_test = load_real_samples()

In [None]:
from tensorflow.keras.models import load_model

In [None]:
# Step number of the final checkpoint: (full batches per epoch) * epochs.
# NOTE(review): assumes train_gan counts steps the same way when naming checkpoints.
last_step = (dataset_train[0].shape[0] // batch_size) * epochs
last_step

In [None]:
# Reload the supervised classifier saved at the final training step.
# The model was built with metrics_list, which contains the custom functions
# f1_score and auc_pr. A compiled .h5 model cannot deserialize those without
# custom_objects (load_model raises "Unknown metric function").
# NOTE(review): the models were built via `keras` but are loaded via
# `tensorflow.keras`; confirm both refer to the same Keras implementation here.
supervised_model = load_model(
    f'{LOG_PATH}supervised_model_{last_step}.h5',
    custom_objects={"f1_score": f1_score, "auc_pr": auc_pr},
)

In [None]:
X_train, y_train = dataset_train
# evaluate() returns [loss, metric_1, metric_2, ...] when the model has metrics.
# metrics_list has several entries, so unpacking into two names raises ValueError.
# "accuracy" is the first metric in metrics_list, so it sits at index 1.
train_scores = supervised_model.evaluate(X_train, y_train, verbose=0)
acc = train_scores[1]
print('Train Classifier Accuracy: %.3f%%\n' % (acc * 100))

In [None]:
X_test, y_test = dataset_test
# evaluate() returns [loss, metric_1, metric_2, ...] when the model has metrics.
# metrics_list has several entries, so unpacking into two names raises ValueError.
# "accuracy" is the first metric in metrics_list, so it sits at index 1.
test_scores = supervised_model.evaluate(X_test, y_test, verbose=0)
acc = test_scores[1]
print('Test Classifier Accuracy: %.3f%%\n' % (acc * 100))

### Plotting

In [None]:
import pandas as pd

In [None]:
# Semicolon-separated training log stored in the run directory (presumably written by train_gan).
results_file = pd.read_csv(LOG_PATH + "SSL_GAN.csv", delimiter=";")

In [None]:
# Drop the first CSV column (presumably an index written alongside the metrics --
# verify) and keep every remaining metric column.
metric_columns = slice(1, None)
log_file = results_file.iloc[:, metric_columns]
log_file

In [None]:
# Generator, unsupervised (real/fake) and supervised losses over the logged entries.
# Title and axis labels are set so the saved PNG is self-describing.
ax = log_file[["generator_loss", "unsupervised_real_loss", "unsupervised_fake_loss", "supervised_loss"]].plot(figsize=(16,12))
ax.set(title="SSGAN training losses", xlabel="log entry", ylabel="loss")
fig = ax.get_figure()
fig.savefig(f'{LOG_PATH}GAN_loss.png')

In [None]:
# Classifier loss on the train and test splits over the logged entries.
# Title and axis labels are set so the saved PNG is self-describing.
ax = log_file[["train_loss", "test_loss"]].plot(figsize=(16,12))
ax.set(title="Classifier loss (train vs. test)", xlabel="log entry", ylabel="loss")
fig = ax.get_figure()
fig.savefig(f'{LOG_PATH}train_test_loss.png')

In [None]:
# Classifier accuracy on the train and test splits. The 0-100 y-range means the
# logged values are treated as percentages.
ax = log_file[["train_acc", "test_acc"]].plot(figsize=(16,12), ylim=(0,100), yticks=range(0,110,10))
ax.set(title="Classifier accuracy (train vs. test)", xlabel="log entry", ylabel="accuracy (%)")
fig = ax.get_figure()
fig.savefig(f'{LOG_PATH}train_test_acc.png')

In [None]:
# Unsupervised discriminator accuracy on real and fake batches. The 0-100 y-range
# means the logged values are treated as percentages.
ax = log_file[["unsupervised_real_acc", "unsupervised_fake_acc"]].plot(figsize=(16,12), ylim=(0,100), yticks=range(0,110,10))
ax.set(title="Unsupervised discriminator accuracy (real vs. fake)", xlabel="log entry", ylabel="accuracy (%)")
fig = ax.get_figure()
fig.savefig(f'{LOG_PATH}unsupervised_real_fake_acc.png')