In [None]:
# Mount Google Drive (Colab only): later cells copy experiment archives to, and load
# checkpoints from, /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Colab setup: deeplay from GitHub, the deeptrack pre-release, and multiprocess.
# NOTE(review): nothing is pinned, so a fresh run may pick up breaking API changes;
# consider pinning versions (e.g. `%pip install -q deeptrack==<version>`) once a working set is known.
%pip install -q git+https://github.com/DeepTrackAI/deeplay.git
# %pip install deeplay
%pip install -q deeptrack --pre
%pip install -q multiprocess

import torch
from rich import print

# Use the first GPU when one is available, otherwise fall back to the CPU.
# (Choose between two strings and build a single torch.device from the result.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Torch version: {torch.__version__}\nCUDA version: {torch.version.cuda}\nDevice: {device}")

In [None]:
!sudo apt install cm-super dvipng texlive-latex-extra texlive-latex-recommended texlive-fonts-extra texlive-science 2>/dev/null >/dev/null

In [None]:
import deeplay as dl
import deeptrack as dt
import torch.nn as nn
import torchvision
import multiprocess as mp
from pathlib import Path
import os
from kornia.contrib import extract_tensor_patches, combine_tensor_patches
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from kornia.utils import tensor_to_image
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from contextlib import contextmanager
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [None]:
# Matplotlib rcParams for publication-style figures: thin lines, inward ticks on all four
# sides, small fonts and LaTeX text rendering (needs the texlive packages from the apt cell).
# Applied globally by the mpl.rcParams.update call at the end of this cell.
PLOT_CONTEXT = {
    ##########
    # Figure #
    ##########
    "figure.autolayout": True,
    # NOTE(review): with figure.autolayout enabled, tight_layout recomputes the subplot
    # margins, so this left margin likely has no visible effect; confirm or drop it.
    "figure.subplot.left": 0.5,
    # High resolution by default; later cells temporarily switch to 100 dpi around training.
    "figure.dpi": 600,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.04,
    # Linewidths
    "axes.linewidth": 0.25,
    "grid.linewidth": 0.25,
    "xtick.major.width": 0.25,
    "xtick.minor.width": 0.25,
    "ytick.major.width": 0.25,
    "ytick.minor.width": 0.25,
    # Plots
    "lines.linewidth": 0.5,
    "lines.markersize": 2,
    "axes.grid": False,
    # Ticks
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.minor.visible": True,
    "ytick.minor.visible": True,
    "xtick.top": True,
    "ytick.right": True,
    # Legend
    "patch.linewidth": 0.25,
    "legend.frameon": True,
    "legend.fancybox": False,
    "legend.loc": "upper left",
    # Colours
    # Given as a cycler expression string; colours and line styles advance together
    # (k/-, r/--, b/:, g/-.) so curves stay distinguishable in greyscale.
    "axes.prop_cycle": "(cycler('color', ['k', 'r', 'b', 'g']) + cycler('ls', ['-', '--', ':', '-.']))",
    ###############
    # Typesetting #
    ###############
    # Title
    "figure.titlesize": 10,
    "figure.titleweight": "bold",
    "axes.titlesize": 9,
    "axes.titleweight": "normal",
    # Axes
    "axes.labelsize": 9,
    "axes.labelweight": "normal",
    # Ticks
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "xtick.major.size" : 3,
    "ytick.major.size" : 3,
    "xtick.minor.size" : 1.5,
    "ytick.minor.size" : 1.5,
    # Legend
    "legend.fontsize": 6,
    "legend.edgecolor": "grey",
    #########
    # LaTeX #
    #########
    "text.usetex": True,  # Use LaTeX
    # LaTeX standard math and physics preamble with sans-serif font in math mode
    # (siunitx uses the UK locale but with a comma decimal marker; sansmath switches math to sans-serif)
    "text.latex.preamble": r"""
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{mathtools}
\usepackage{bbm}
\usepackage{gensymb}
\usepackage[italicdiff]{physics}
\usepackage{icomma}
\usepackage{siunitx}
\sisetup{
  detect-all,
  locale=UK,% Set locale to UK
  output-decimal-marker={,}, % Set comma as decimal separator
  range-phrase=--, % Set range to use "--" instead of " to "
  range-units=single, % Use only a single unit in a range
  per-mode=reciprocal, % Alternatively set to "symbol"
  sticky-per, % Only one \per-command
  bracket-unit-denominator=false, % No parenthesis
  separate-uncertainty=true, % Separate uncertainty with "+/-"
}
\usepackage{sansmath}
\sansmath
\centering
"""
}


mpl.rcParams.update(PLOT_CONTEXT)

In [None]:
import subprocess

# Local data directory (also exported for shell commands that may need it)
DATA_PATH: Path = Path.cwd() / "data"
DATA_PATH.mkdir(exist_ok=True)
os.environ["DATA_PATH"] = str(DATA_PATH)

mnist_dataset_path: Path = DATA_PATH / "MNIST_dataset" / "mnist"

if not mnist_dataset_path.exists():
    # Clone into DATA_PATH/MNIST_dataset. Using an argument list avoids a shell, and
    # check=True makes a failed download stop here instead of surfacing later as a
    # confusing ImageFolder error (the former `!git clone` ignored its exit status).
    subprocess.run(
        ["git", "clone", "https://github.com/DeepTrackAI/MNIST_dataset"],
        cwd=DATA_PATH,
        check=True,
    )

# Image sources for the train and test splits, plus a joined source that the
# pipelines below use to refer to file paths and labels.
train_files = dt.sources.ImageFolder(
    root=str(mnist_dataset_path / "train"),
)
test_files = dt.sources.ImageFolder(
    root=str(mnist_dataset_path / "test"),
)
files = dt.sources.Join(train_files, test_files)

In [None]:
# Image pipeline: load the file, min-max normalise it, move axis 2 to the front
# (presumably HWC -> CHW; TODO confirm LoadImage's layout) and convert to a float tensor.
image_pipeline = (
    dt.LoadImage(files.path)
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

# Label pipeline: take element/character 0 of the source's label name and cast it to int.
# Assumes the ImageFolder label names are digit folder names such as "7" (TODO confirm).
label_pipeline = (
    dt.Value(files.label_name[0])
    >> int
)

In [None]:
# Each dataset item is an (image, label) pair produced by the two pipelines
train_dataset = dt.pytorch.Dataset(image_pipeline & label_pipeline,
                                  inputs=train_files)
test_dataset = dt.pytorch.Dataset(image_pipeline & label_pipeline,
                                  inputs=test_files)

# One loader worker per CPU core
n_dataset_workers = mp.cpu_count()
train_loader = dl.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=n_dataset_workers)
# Evaluation does not need shuffling. Full-pass results such as the accuracies and
# confusion matrices below do not depend on order, and a fixed order keeps the
# "first test batch" plots in later cells reproducible across runs.
test_loader = dl.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=n_dataset_workers)

In [None]:
class Generator(dl.UNet2d):
    """UNet generator applied patch-wise to the input image.

    The image is cut into square patches of side ``patch_size``. Every patch is passed
    through the UNet with ``torch.vmap``, and the outputs are stitched back into an
    image of the original height and width.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): the patch side length is the first entry of `channels` (16 for the
        # configuration below), which couples patch size to network width; confirm this
        # is intended. `channels` must therefore be passed as a keyword argument.
        self.patch_size = kwargs["channels"][0]
        super().__init__(*args, **kwargs)
        # self[..., "activation"].configure(nn.GELU)

    def forward(self, input):
        # Requires a rank-4 input, presumably (batch, channels, height, width)
        _, _, height, width = input.shape
        # NOTE(review): this mutates the caller's tensor (it now requires grad), so autograd
        # also tracks gradients w.r.t. the input. It does not appear to be needed to train
        # the generator's parameters; consider removing it.
        input.requires_grad_()

        # NOTE(review): the message says "kernel size" but this is the patch size
        if any(l < self.patch_size for l in (height, width)):
            raise ValueError(
                f"Image ({height}x{width}) is smaller than kernel size ({self.patch_size}x{self.patch_size})"
            )

        # gcd(dim, patch) divides (dim - patch), so patches land exactly on the image border.
        # The patches may overlap (e.g. 28 and 16 -> stride 4).
        stride = [np.gcd(height, self.patch_size), np.gcd(width, self.patch_size)]
        # Put the patch axis first so vmap maps the UNet over patches
        input_patches = extract_tensor_patches(input=input, window_size=self.patch_size, stride=stride).movedim(1, 0)

        # Run the plain UNet forward on every patch, then move the patch axis back to dim 1
        output_patches = torch.vmap(super().forward)(input_patches).movedim(0, 1)

        # Reassemble the processed patches into a full-size image
        output = combine_tensor_patches(
            patches=output_patches, original_size=(height, width), window_size=self.patch_size, stride=stride
        )

        return output


# Single-channel in/out UNet with concatenating skip connections
generator_model = Generator(
    in_channels=1,
    channels=[16, 32, 64],
    out_channels=1,
    skip=dl.Cat()
)
# print(generator_model)

In [None]:
class Discriminator(dl.DeeplayModule):
    """Convolutional trunk followed by an MLP head, with a sigmoid on the output.

    Used below both as the real/fake discriminator (``out_features=1``) and as the digit
    classifier (``out_features=10``).

    NOTE(review): ``forward`` returns ``torch.sigmoid`` of the MLP output, i.e. values in
    (0, 1), not logits. Losses applied to it should expect probabilities (for example
    ``F.binary_cross_entropy``) rather than logits (``*_with_logits``, ``F.cross_entropy``).
    """
    def __init__(
            self,
            in_channels,
            hidden_channels_cnn,
            hidden_channels_mlp,
            out_channels_cnn,
            out_features
        ):
        super().__init__()

        # Convolutional trunk; `pool` is handed to deeplay's CNN (presumably applied
        # between conv blocks; TODO confirm) and the last block ends in a ReLU.
        self.cnn = dl.ConvolutionalNeuralNetwork(
            in_channels=in_channels,
            hidden_channels=hidden_channels_cnn,
            out_channels=out_channels_cnn,
            pool=nn.MaxPool2d(kernel_size=2),
            out_activation=nn.ReLU,
        )

        # One more 2x2 max-pool, then flatten to a feature vector
        self.pool = dl.Layer(nn.MaxPool2d, kernel_size=2)
        self.flatten = dl.Layer(nn.Flatten)

        self.dense = dl.MultiLayerPerceptron(
            in_features=out_channels_cnn,
            hidden_features=hidden_channels_mlp,
            out_features=out_features,
            out_activation=nn.Identity,
        )
        # Make the first linear layer lazy so its input size is inferred from the
        # flattened CNN output at the first forward pass.
        self.dense[..., "layer#0"].configure(nn.LazyLinear)
        # self.dense[..., "activation#-1"] = dl.Layer[nn.Softmax]


    def forward(self, x):
        x = self.cnn(x)
        x = self.flatten(self.pool(x))
        x = self.dense(x)

        # Squash to (0, 1): the returned values are probabilities, not logits
        return torch.sigmoid(x)

def _discriminator_config(out_features):
    """Keyword arguments shared by both discriminators; only the output size differs.

    Builds fresh lists on every call so the two models never share mutable arguments.
    """
    return dict(
        in_channels=1,
        hidden_channels_cnn=[16, 32, 64],
        hidden_channels_mlp=[64, 32],
        out_channels_cnn=64,
        out_features=out_features,
    )


# Discriminator 1: a single real/fake score
discriminator_1 = Discriminator(**_discriminator_config(1))
# Discriminator 2: one output per digit class (10 classes)
discriminator_2 = Discriminator(**_discriminator_config(10))

In [None]:
from torchmetrics.functional import accuracy
from torchmetrics.image import StructuralSimilarityIndexMeasure

class GAN(dl.Application):
    """Generator trained against up to two discriminators with manual optimization.

    The generator maps an input image to a perturbed image of the same size.

    * Discriminator 1 (optional) learns to separate real inputs (label 1) from generator
      outputs (label 0). The generator is rewarded when its outputs are scored as real.
    * Discriminator 2 (optional) is trained as a 10-class classifier on real images only.
      The generator is rewarded for lowering the softmax probability of the true class.
    * A norm term keeps the generator output close to its input.

    Each training step updates the generator, then discriminator 1, then discriminator 2.

    Args:
        generator: module mapping an image batch to an image batch of the same shape.
        discriminator_1: optional real/fake discriminator. Its output is treated as a
            probability in [0, 1] (the ``Discriminator`` used here ends in a sigmoid,
            and ``test_step`` rounds it to 0/1).
        discriminator_2: optional 10-class classifier.
        plot_outputs: plot and print diagnostics for the first batch of every epoch.
        disc_1_loss_w, disc_2_loss_w, norm_loss_w: weights of the generator loss terms.
        **kwargs: forwarded to ``dl.Application``.

    Raises:
        ValueError: if neither discriminator is given.
    """

    def __init__(self, generator, discriminator_1=None, discriminator_2=None, plot_outputs=False, disc_1_loss_w=1, disc_2_loss_w=1, norm_loss_w=1, **kwargs):
        super().__init__(**kwargs)
        # Neural networks
        self.generator = generator
        self.discriminator_1 = discriminator_1
        self.discriminator_2 = discriminator_2

        # Compare against None explicitly: `if module:` depends on whether the module
        # class defines __len__/__bool__, which is not a reliable "was it given" test.
        if discriminator_1 is None and discriminator_2 is None:
            raise ValueError("The GAN must have at least one discriminator")

        # Plot outputs for first batch in every epoch
        self.plot_outputs = plot_outputs

        # Generator loss weights
        self.disc_1_loss_w = disc_1_loss_w
        self.disc_2_loss_w = disc_2_loss_w
        self.norm_loss_w = norm_loss_w

        self.ssim = StructuralSimilarityIndexMeasure()

        # Several networks with separate optimizers are stepped by hand in training_step
        self.automatic_optimization = False

    def configure_optimizers(self):
        """Create one Adam optimizer per network, in the order generator, disc 1, disc 2.

        Only networks that are present get an optimizer; training_step relies on this
        order when unpacking ``self.optimizers()``.

        NOTE(review): with ``automatic_optimization = False`` Lightning does not step LR
        schedulers, and training_step never calls ``.step()`` on the StepLR below, so
        discriminator 1 trains at a constant lr of 1e-5. Stepping it (gamma=0.8 per step)
        would change training substantially; decide deliberately before enabling it.
        """
        optimizers, schedulers = [], []

        generator_optimizer = self.create_optimizer_with_params(
            dl.Adam(lr=1e-4, betas=(0.5, 0.999)), self.generator.parameters()
        )
        optimizers.append(generator_optimizer)

        if self.discriminator_1 is not None:
            discriminator_1_optimizer = self.create_optimizer_with_params(
                dl.Adam(lr=1e-5), self.discriminator_1.parameters()
            )
            optimizers.append(discriminator_1_optimizer)
            discriminator_1_scheduler = StepLR(discriminator_1_optimizer, step_size=1, gamma=0.8)
            schedulers.append(discriminator_1_scheduler)

        if self.discriminator_2 is not None:
            discriminator_2_optimizer = self.create_optimizer_with_params(
                dl.Adam(lr=1e-4), self.discriminator_2.parameters()
            )
            optimizers.append(discriminator_2_optimizer)

        return optimizers, schedulers

    def forward(self, batch):
        """Run the generator on an image batch."""
        return self.generator(batch)

    def plot_training_images(self, batch_tensor, n_plots, name=None):
        """Show the first ``n_plots`` images of ``batch_tensor`` in one row.

        ``name`` (if given) labels the first panel; the remaining panels have no axes.
        """
        # squeeze=False keeps `axs` a 2-D array even when n_plots == 1
        fig, axs = plt.subplots(1, n_plots, figsize=((10, n_plots * 10)), squeeze=False)
        images = [tensor_to_image(o) for o in batch_tensor[:n_plots].clone().detach()]

        for i_ax, (ax, img) in enumerate(zip(axs.ravel(), images)):
            ax.imshow(img.squeeze(), cmap="gray")

            if name and i_ax == 0:
                ax.set_ylabel(name)
                ax.set_xticks([])
                ax.set_yticks([])
            else:
                ax.set_axis_off()

        plt.show()

    def generator_discriminator_1_loss(self, input, target):
        """Weighted BCE that pushes discriminator 1 to classify generated images as ``target``.

        ``input`` is discriminator 1's output, which is already a probability, so plain
        BCE is used. ``binary_cross_entropy_with_logits`` would apply a second sigmoid,
        squashing its argument into (0.5, 0.73) and crippling the gradient.
        """
        return self.disc_1_loss_w * F.binary_cross_entropy(input, target)

    def generator_norm_loss(self, input, target):
        """Weighted penalty on the perturbation norm.

        Here ``input`` is a non-negative norm treated as a logit. BCE-with-logits against
        a 0 target equals softplus(norm), which shrinks as the norm goes to 0.
        """
        return self.norm_loss_w * F.binary_cross_entropy_with_logits(input, target)

    def generator_discriminator_2_loss(self, input, target):
        """Weighted batch mean of the squared softmax probability of the true class.

        NOTE(review): discriminator 2 also ends in a sigmoid, so this softmax is taken over
        values in (0, 1) and is close to uniform. The large default weight may compensate
        for that; confirm before changing.
        """
        # Prepare target
        target_oh = F.one_hot(target, num_classes=10)
        fake_target = torch.zeros_like(input)

        # Keep only the true-class probability of each sample
        formatted_input = F.softmax(input, dim=1) * target_oh.squeeze(1)

        return self.disc_2_loss_w * torch.sum((formatted_input - fake_target) ** 2.0) / fake_target.shape[0]

    def discriminator_1_loss(self, output_real, output_fake, target_real, target_fake):
        """Mean of the real and fake BCE terms for discriminator 1 (outputs are probabilities)."""
        loss_real = F.binary_cross_entropy(output_real, target_real)
        loss_fake = F.binary_cross_entropy(output_fake, target_fake)

        return (loss_real + loss_fake) / 2

    def discriminator_2_loss(self, input, target):
        """BCE on discriminator 2's probabilities (not used by the training loop)."""
        return F.binary_cross_entropy(input, target)

    def train_generator(self, optimizer, input, target, batch_idx):
        """One generator update from the discriminator terms plus the norm term.

        When plotting is enabled, it also shows diagnostics on the first batch of each epoch.
        """
        self.toggle_optimizer(optimizer)

        fake_label = torch.zeros(input.size(0), 1).type_as(input)

        # Feed input into generator
        gen_output = self(input)

        if self.discriminator_1 is not None:
            # Feed generated image to discriminator 1
            disc_1_output_fake = self.discriminator_1(gen_output)

            # The generator wants its outputs to be scored as real
            real_label = torch.ones(input.size(0), 1).type_as(input)

            # Compute loss
            gen_loss_disc_1 = self.generator_discriminator_1_loss(disc_1_output_fake, real_label)
            self.log("train_gen_loss_disc_1_step", gen_loss_disc_1, prog_bar=True)
            self.log("train_gen_loss_disc_1_epoch", gen_loss_disc_1, prog_bar=True, on_step=False, on_epoch=True)
        else:
            gen_loss_disc_1 = 0

        if self.discriminator_2 is not None:
            # Feed generated image to discriminator 2
            disc_2_output_fake = self.discriminator_2(gen_output)

            # Compute loss
            gen_loss_disc_2 = self.generator_discriminator_2_loss(disc_2_output_fake, target)
            self.log("train_gen_loss_disc_2_step", gen_loss_disc_2, prog_bar=True)
            self.log("train_gen_loss_disc_2_epoch", gen_loss_disc_2, prog_bar=True, on_step=False, on_epoch=True)
        else:
            gen_loss_disc_2 = 0

        # Norm of the pixel difference between input and generator output, per image and channel
        diff = input - gen_output
        norm = torch.norm(diff, dim=(2, 3))

        # Compute loss for norm
        gen_loss_norm = self.generator_norm_loss(norm, fake_label)
        self.log("train_gen_loss_norm_step", gen_loss_norm, prog_bar=True)
        self.log("train_gen_loss_norm_epoch", gen_loss_norm, prog_bar=True, on_step=False, on_epoch=True)

        # Compute total generator loss
        gen_loss = gen_loss_disc_2 + gen_loss_disc_1 + gen_loss_norm
        self.log("train_gen_loss_step", gen_loss, prog_bar=True)
        self.log("train_gen_loss_epoch", gen_loss, prog_bar=True, on_step=False, on_epoch=True)

        # Run optimization
        self.manual_backward(gen_loss)
        optimizer.step()
        optimizer.zero_grad()
        self.untoggle_optimizer(optimizer)

        ###############################
        # Print output once per epoch #
        ###############################

        if batch_idx == 0 and self.plot_outputs:
            print(f"[bold]Epoch {self.current_epoch}")
            n_plots = 5

            self.plot_training_images(input, n_plots, "Input image")
            self.plot_training_images(gen_output, n_plots, "Generator output")
            self.plot_training_images(torch.abs(diff), n_plots, "Absolute difference")

            if self.discriminator_1 is not None:
                print("Discriminator 1 guess\n", disc_1_output_fake[:5].transpose(0, 1))

            if self.discriminator_2 is not None:
                print("Discriminator 2 guess (fake)\n", torch.argmax(disc_2_output_fake, dim=1)[:5].unsqueeze(-1).transpose(0, 1))

    def train_discriminator_1(self, optimizer, input, output):
        """One update of discriminator 1: real inputs as label 1, generator outputs as label 0.

        ``output`` (the batch targets) is unused; it is kept for call compatibility.
        """
        self.toggle_optimizer(optimizer)

        # The generator output is only an input here, so no graph through the
        # generator is needed (its parameters are not updated in this step).
        with torch.no_grad():
            gen_output = self(input)

        # Feed real input and generator output into discriminator 1
        disc_1_output_real = self.discriminator_1(input)
        disc_1_output_fake = self.discriminator_1(gen_output)

        real_label = torch.ones(input.size(0), 1).type_as(input)
        fake_label = torch.zeros(input.size(0), 1).type_as(input)

        # Compute loss
        disc_1_loss = self.discriminator_1_loss(disc_1_output_real, disc_1_output_fake, real_label, fake_label)
        self.log("train_disc_1_loss_step", disc_1_loss, prog_bar=True)
        self.log("train_disc_1_loss_epoch", disc_1_loss, prog_bar=True, on_step=False, on_epoch=True)

        # Run optimization
        self.manual_backward(disc_1_loss)
        optimizer.step()
        optimizer.zero_grad()
        self.untoggle_optimizer(optimizer)

    def train_discriminator_2(self, optimizer, input, target, batch_idx):
        """One classification update of discriminator 2 on real images only."""
        self.toggle_optimizer(optimizer)

        # Feed input into discriminator directly (only train on real images)
        disc_2_output_real = self.discriminator_2(input)

        # Flatten targets to (B,). Unlike squeeze(), this keeps a batch of size 1 one-dimensional.
        disc_2_loss = F.cross_entropy(disc_2_output_real, target.reshape(-1))
        self.log("train_disc_2_loss_step", disc_2_loss, prog_bar=True)
        self.log("train_disc_2_loss_epoch", disc_2_loss, prog_bar=True, on_step=False, on_epoch=True)

        # Run optimization
        self.manual_backward(disc_2_loss)
        optimizer.step()
        optimizer.zero_grad()
        self.untoggle_optimizer(optimizer)

        # Print discriminator 2 guess at start of epoch
        if batch_idx == 0 and self.plot_outputs:
            print("Discriminator 2 guess (real)\n", torch.argmax(disc_2_output_real[:5], dim=1))

    def training_step(self, batch, batch_idx):
        """Manual-optimization step: update the generator, then each discriminator present."""
        self.train()

        # Prepare inputs and targets
        input, target = batch

        # Optimizers come back in configure_optimizers order: the generator first, then
        # the discriminators that exist (discriminator 1 before discriminator 2).
        gen_opt, *disc_opts = self.optimizers()
        disc_opts = iter(disc_opts)
        disc_1_opt = next(disc_opts) if self.discriminator_1 is not None else None
        disc_2_opt = next(disc_opts) if self.discriminator_2 is not None else None

        # Train generator
        self.train_generator(gen_opt, input, target, batch_idx)

        # Train discriminator 1
        if self.discriminator_1 is not None:
            self.train_discriminator_1(disc_1_opt, input, target)

        # Train discriminator 2
        if self.discriminator_2 is not None:
            self.train_discriminator_2(disc_2_opt, input, target, batch_idx)

    def test_step(self, batch, batch_idx):
        """Log SSIM(generator output, input) and the discriminators' accuracies on a test batch."""
        metrics = {}

        # Prepare inputs and targets
        input, target = batch

        # Feed input into generator
        gen_output = self(input)

        # Compute structural similarity index
        metrics["test_ssim"] = self.ssim(gen_output, input)

        if self.discriminator_1 is not None:
            # Feed real input and generator output into discriminator 1
            disc_1_output_real = self.discriminator_1(input)
            disc_1_output_fake = self.discriminator_1(gen_output)

            real_label = torch.ones(input.size(0), 1).type_as(input)
            fake_label = torch.zeros(input.size(0), 1).type_as(input)

            # Accuracies of the rounded probabilities (threshold 0.5)
            metrics["test_accuracy_discriminator_1_real"] = accuracy(torch.round(disc_1_output_real), real_label, task="binary")
            metrics["test_accuracy_discriminator_1_fake"] = accuracy(torch.round(disc_1_output_fake), fake_label, task="binary")
            metrics["test_accuracy_discriminator_1_total"] = accuracy(torch.round(torch.cat((disc_1_output_real, disc_1_output_fake), dim=0)), torch.cat((real_label, fake_label), dim=0), task="binary")

        if self.discriminator_2 is not None:
            disc_2_output_real = self.discriminator_2(input)
            disc_2_output_fake = self.discriminator_2(gen_output)

            # Classification accuracy of the argmax prediction on real and generated images
            metrics["test_accuracy_discriminator_2_real"] = accuracy(torch.argmax(disc_2_output_real, dim=1).unsqueeze(-1), target, task="multiclass", num_classes=10)
            metrics["test_accuracy_discriminator_2_fake"] = accuracy(torch.argmax(disc_2_output_fake, dim=1).unsqueeze(-1), target, task="multiclass", num_classes=10)

        self.log_dict(metrics)

        return metrics


# Full model: generator plus both discriminators, with per-epoch diagnostic plots.
# The discriminator-2 term is weighted heavily relative to the other two.
gan_loss_weights = dict(disc_1_loss_w=2, disc_2_loss_w=100, norm_loss_w=0.5)

gan = GAN(
    generator=generator_model.create(),
    discriminator_1=discriminator_1.create(),
    discriminator_2=discriminator_2.create(),
    plot_outputs=True,
    **gan_loss_weights,
)
print(gan)

In [None]:
from torchsummary import summary

# Layer-by-layer summary of each network for a single 1x28x28 MNIST image (on the CPU)
for _summary_title, _summary_module in (
    ("Generator", gan.generator),
    ("Discriminator 1", gan.discriminator_1),
    ("Discriminator 2", gan.discriminator_2),
):
    print(f"[bold]{_summary_title}")
    summary(_summary_module, (1, 28, 28), device="cpu")

In [None]:
from deeplay.callbacks import LogHistory
from lightning.pytorch.loggers import CSVLogger
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Lower DPI while training; the per-epoch diagnostic plots would otherwise render
# at the 600 dpi set in PLOT_CONTEXT.
mpl.rcParams.update({"figure.dpi": 100})

# Metrics go both to CSV files under RESULTS_DIR and to an in-memory history for plotting
logger = CSVLogger(save_dir=RESULTS_DIR, name="gan_full_logs")
training_history = LogHistory()

trainer = dl.Trainer(
    max_epochs=100,
    callbacks=[training_history],
    logger=logger,
)
trainer.fit(gan, train_loader)

In [None]:
# Plot the metrics recorded by the LogHistory callback during training
training_history.plot()

In [None]:
# Evaluate on the test set (runs GAN.test_step: SSIM and discriminator accuracies)
trainer.test(gan, test_loader)

In [None]:
# Restore the high publication DPI after the low-DPI training cells
mpl.rcParams.update({"figure.dpi":600})

In [None]:
import pandas as pd
from lightning.pytorch.loggers import CSVLogger

# Output folder for experiment 1 (checkpoints, CSV logs and the metrics table)
EXPERIMENT_1_DIR = Path("gan_experiment_1")
EXPERIMENT_1_DIR.mkdir(exist_ok=True, parents=True)

def run_gan_experiment_discriminator_1():
    n_epochs = 50
    n_runs = 10
    norm_loss_w = 1
    disc_1_norm_loss_w_ratio_arr = np.geomspace(1e-2, 1e2, n_runs)
    disc_1_loss_w_arr = disc_1_norm_loss_w_ratio_arr
    sum_arr = disc_1_loss_w_arr + norm_loss_w
    disc_1_loss_w_arr /= sum
    norm_loss_w /= sum

    test_metrics_dcts = []

    for i_run, disc_1_loss_w, disc_1_norm_loss_w_ratio in zip(range(len(disc_1_loss_w_arr)), disc_1_loss_w_arr, disc_1_norm_loss_w_ratio_arr):
        gan = GAN(
            generator=generator_model.create(),
            discriminator_1=discriminator_1.create(),
            disc_1_loss_w=disc_1_loss_w,
            norm_loss_w=norm_loss_w
        )

        logger = CSVLogger(save_dir=EXPERIMENT_1_DIR, name=f"gan_disc_1_{i_run}_{disc_1_loss_w}_logs")
        trainer = dl.Trainer(max_epochs=n_epochs, logger=logger)
        trainer.fit(gan, train_loader)
        torch.save(gan.state_dict(),  EXPERIMENT_1_DIR / f"gan_disc_1_{i_run}_{disc_1_loss_w}.pth")

        test_metrics = trainer.test(gan, test_loader)[0]
        test_metrics["disc_1_loss_w"] = disc_1_loss_w
        test_metrics["norm_loss_w"] = norm_loss_w
        test_metrics["disc_1_norm_loss_w_ratio"] = disc_1_norm_loss_w_ratio
        test_metrics["epochs"] = n_epochs

        test_metrics_dcts.append(test_metrics)

    return pd.DataFrame(test_metrics_dcts)



# Run the (expensive) sweep and store the per-run test metrics next to the checkpoints.
# DataFrame.to_csv already uses "," as its separator.
experiment_1_metrics_df = run_gan_experiment_discriminator_1()
experiment_1_metrics_df.to_csv(EXPERIMENT_1_DIR / "test_metrics.csv")

In [None]:
import shutil

# Kept for any shell commands that still refer to these variables
os.environ["EXPERIMENT_1_DIR"] = str(EXPERIMENT_1_DIR)
os.environ["EXPERIMENT_1_ZIP_FILE"] = "gan_experiment_1.zip"

# Zip the experiment folder (entries stay under "gan_experiment_1/...") and copy the
# archive to Google Drive so results outlive the Colab session. shutil raises on
# failure, whereas the former `!zip` / `!cp` silently ignored their exit status.
experiment_1_zip_path = shutil.make_archive(
    Path(os.environ["EXPERIMENT_1_ZIP_FILE"]).stem,
    "zip",
    root_dir=".",
    base_dir=str(EXPERIMENT_1_DIR),
)
shutil.copy(experiment_1_zip_path, Path("/content/drive/MyDrive") / os.environ["EXPERIMENT_1_ZIP_FILE"])

In [None]:
# Show the test metrics of all runs of the loss-weight sweep
experiment_1_metrics_df

## Warp images for model architecture graph

In [None]:
import cv2
from google.colab.patches import cv2_imshow
import kornia as K

In [None]:
WARPED_IMAGES_DIR = Path("warped_images")
WARPED_IMAGES_DIR.mkdir(exist_ok=True)

def warp_image(image, height_factor, width_factor, scaling_factor):
    """Upscale a greyscale image and perspective-warp it into a slanted, transparent tile.

    Used to draw pseudo-3-D image "cards" for the model architecture figure.

    Args:
        image: 2-D greyscale array (H, W), assumed float in [0, 1] (the result is
            multiplied by 255; TODO confirm for all callers).
        height_factor: fraction of the upscaled height that each edge of the warped
            quadrilateral spans.
        width_factor: fraction of the upscaled width kept; the image is squeezed to it.
        scaling_factor: integer nearest-neighbour upscaling factor applied first, so the
            pixels stay crisp after warping.

    Returns:
        BGRA array of shape (H*scaling_factor, int(W*scaling_factor*width_factor), 4)
        scaled by 255. Pixels outside the warped quadrilateral have alpha 0.
    """
    # Upscale with nearest-neighbour interpolation to keep the image pixelated.
    # cv2.resize expects dsize as (width, height).
    height, width = image.shape[:2]
    image = cv2.resize(
        image,
        (int(width * scaling_factor), int(height * scaling_factor)),
        interpolation=cv2.INTER_NEAREST,
    )

    # Geometry of the warp. The caller's factors are used; they were previously
    # overwritten with hard-coded 0.9 / 0.5 values.
    height, width = image.shape[:2]
    height_max = height - 1
    width_max = width - 1
    new_height = int(height * height_factor)
    new_width = int(width * width_factor)
    height_diff = height - new_height
    width_diff = width - new_width

    # Convert to BGRA so the area uncovered by the warp can be transparent
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGRA) * 255

    # Corner correspondences (x, y). The left edge keeps its bottom but is shortened from
    # the top; the right edge moves left by width_diff and is shortened from the bottom.
    points_1 = np.float32([[0, 0], [0, height_max], [width_max, 0], [width_max, height_max]])
    points_2 = np.float32([
        [0, height_diff],
        [0, height_max],
        [width_max - width_diff, 0],
        [width_max - width_diff, height_max - height_diff],
    ])

    # Compute the transformation matrix and warp (the border is filled with zeros -> alpha 0)
    transform_mat = cv2.getPerspectiveTransform(points_1, points_2)
    dst = cv2.warpPerspective(image, transform_mat, (width, height))

    # Drop the empty columns to the right of the warped quadrilateral
    return dst[:, :new_width]


def generate_plot_warped_images(gan, n_images, height_factor=0.9, width_factor=0.6, scaling_factor=20):
    """Save plain, upscaled and warped PNGs of test digits and of their generator outputs.

    Uses the first test batch. For each of the first ``n_images`` digits, four files are
    written to WARPED_IMAGES_DIR: the upscaled input, the upscaled generator output, and
    the warped versions of both.

    Args:
        gan: trained GAN whose generator is used.
        n_images: number of digits to export (capped at the batch size).
        height_factor, width_factor, scaling_factor: forwarded to ``warp_image``.
    """
    input, target = next(iter(test_loader))
    input = input[:n_images]
    target = target[:n_images]

    # Inference only, on whatever device the model currently lives on
    with torch.no_grad():
        gen_output = gan.generator(input.to(gan.device)).cpu()

    for i_plot in range(len(input)):
        input_image = tensor_to_image(input[i_plot])
        gen_image = tensor_to_image(gen_output[i_plot])
        digit = target[i_plot].item()

        warped_input_image = warp_image(input_image, height_factor, width_factor, scaling_factor)
        warped_gen_image = warp_image(gen_image, height_factor, width_factor, scaling_factor)

        # Nearest-neighbour upscaling; cv2.resize expects dsize as (width, height)
        height, width = input_image.shape[:2]
        dsize = (int(width * scaling_factor), int(height * scaling_factor))
        input_image_resized = cv2.resize(input_image, dsize, interpolation=cv2.INTER_NEAREST)
        gen_image_resized = cv2.resize(gen_image, dsize, interpolation=cv2.INTER_NEAREST)

        cv2.imwrite(str(WARPED_IMAGES_DIR / f"mnist_{i_plot}_digit_{digit}.png"), input_image_resized * 255)
        cv2.imwrite(str(WARPED_IMAGES_DIR / f"gen_mnist_{i_plot}_digit_{digit}.png"), gen_image_resized * 255)
        cv2.imwrite(str(WARPED_IMAGES_DIR / f"warped_mnist_{i_plot}_digit_{digit}.png"), warped_input_image)
        cv2.imwrite(str(WARPED_IMAGES_DIR / f"warped_gen_mnist_{i_plot}_digit_{digit}.png"), warped_gen_image)

generate_plot_warped_images(gan, n_images=10)

In [None]:
def generate_plot_generated_images(gan, n_plots):
    """Plot input | generator output | absolute difference for the first test images.

    Args:
        gan: trained GAN whose generator is evaluated.
        n_plots: number of rows (capped at the test batch size).

    Returns:
        The matplotlib Figure, so callers can save or adjust it.
    """
    input, _ = next(iter(test_loader))
    input = input[:n_plots]
    n_rows = len(input)

    # Inference on the model's own device. This can differ from the global `device`,
    # e.g. before the weights are explicitly moved with gan.to(device).
    with torch.no_grad():
        gen_output = gan.generator(input.to(gan.device)).cpu()

    abs_diff = torch.abs(input - gen_output)

    # squeeze=False keeps a 2-D grid of axes even for a single row
    fig, axs = plt.subplots(n_rows, 3, figsize=(3, n_rows), squeeze=False)

    for i_row, row_axs in enumerate(axs):
        row_axs[0].imshow(tensor_to_image(input[i_row]), cmap="gray")
        # NOTE(review): every row is labelled "9" regardless of the digit shown;
        # looks like a placeholder, so replace it with the real label/prediction if needed.
        row_axs[0].set_xlabel(r"$\downarrow$\\9")
        row_axs[1].imshow(tensor_to_image(gen_output[i_row]), cmap="gray")
        row_axs[2].imshow(tensor_to_image(abs_diff[i_row]), cmap="gray")

        for ax in row_axs:
            ax.set_xticks([])
            ax.set_yticks([])

    fig.show()
    return fig


# gan = GAN(
#     generator=generator_model.create(),
#     discriminator_1=discriminator_1.create()
# ).to(device)
# gan.load_state_dict(torch.load(EXPERIMENT_1_DIR / "gan_experiment_discriminator_1_1_50.0.pth"))
plot = generate_plot_generated_images(gan, 3)
# plot.show()

In [None]:
# Save the trained GAN's weights to the working directory (the next cell reloads a copy from Google Drive)
torch.save(gan.state_dict(),  "nice_gan_full.pth")

In [None]:
# Load the trained weights from Drive and move the model to `device`.
# map_location lets a checkpoint saved from a GPU run be restored on a CPU-only runtime.
gan.load_state_dict(torch.load("/content/drive/MyDrive/data/results/nice_gan_full.pth", map_location=device))
gan.to(device)

In [None]:
def compute_confusion_matrix(generate_image=False):
    """Row-normalised confusion matrix (in %) of discriminator 2 over the whole test set.

    Args:
        generate_image: if True, classify the generator's outputs instead of the raw
            test images.

    Returns:
        ``(conf_matrix, class_labels)``: each row of ``conf_matrix`` holds the percentage
        of images of one true class assigned to each predicted class; ``class_labels``
        are the sorted true labels.
    """
    predicted_labels = []
    ground_truth_labels = []

    # Inference only: no autograd graph is needed, which saves memory over the full test set
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(gan.device)

            if generate_image:
                # Classify the generator's perturbed version of each image
                images = gan.generator(images)

            disc_2_probs = F.softmax(gan.discriminator_2(images).cpu(), dim=1)
            predicted_labels.extend(torch.argmax(disc_2_probs, dim=1).tolist())
            ground_truth_labels.extend(labels.tolist())

    predicted_labels = np.array(predicted_labels)
    # Labels may come as (B, 1) rows; flatten to one label per image
    ground_truth_labels = np.array(ground_truth_labels).flatten()

    conf_matrix = confusion_matrix(ground_truth_labels, predicted_labels, normalize="true") * 100
    class_labels = sorted(set(ground_truth_labels))

    return conf_matrix, class_labels

conf_matrix, class_labels = compute_confusion_matrix(generate_image=False)
conf_matrix_gen, class_labels_gen = compute_confusion_matrix(generate_image=True)

In [None]:
def plot_confusion_matrix(conf_matrix, class_labels, plot_name, fmt=".3g"):
    """Draw a percentage confusion matrix as an annotated heatmap and save it.

    Args:
        conf_matrix: square array of row-normalised percentages (0-100).
        class_labels: tick labels for both axes, in matrix order.
        plot_name: the figure is saved as ``plot_{plot_name}.png``.
        fmt: format of the cell annotations. The default ".3g" prints 99.5 and 100 as
            "99.5" and "100"; ".2g" would render both as "1e+02".

    Returns:
        ``(fig, ax)`` of the heatmap.
    """
    cm = 1 / 2.54  # centimeters in inches
    fig, ax = plt.subplots(figsize=(7 * cm, 7 * cm))
    sns.heatmap(
        conf_matrix,
        annot=True,
        annot_kws=dict(fontsize=5),
        fmt=fmt,
        cmap="Reds",
        xticklabels=class_labels,
        yticklabels=class_labels,
        vmin=0,
        vmax=100,
        # Raw string: "\%" is a LaTeX escape, not a Python one
        cbar_kws=dict(label=r"Accuracy $[\%]$"),
        ax=ax,
    )
    ax.set_ylabel("Actual Class")
    ax.set_xlabel("Predicted Class")
    ax.tick_params(which="both", width=0)
    fig.savefig(f"plot_{plot_name}.png")
    return fig, ax

plot_confusion_matrix(conf_matrix, class_labels, "confusion_matrix_mnist")
plot_confusion_matrix(conf_matrix_gen, class_labels_gen, "confusion_matrix_generated_mnist")


In [None]:
def generate_gaussian_noise(input_images, alpha, noise_scale=3):
    """Add Gaussian noise whose L2 norm per (image, channel) is ``noise_scale * alpha``.

    Args:
        input_images: image batch; norms are taken over the last two dims, so
            presumably (B, C, H, W).
        alpha: target noise norm per (image, channel), shape (B, C), e.g. the norm of
            the generator's perturbation, so both kinds of noise have comparable size.
        noise_scale: extra multiplier on the noise norm (3 by default, to make the
            Gaussian noise stronger than the matched norm).

    Returns:
        Noisy images clamped to the valid pixel range [0, 1].
    """
    # Generate random Gaussian noise with the same size as the input image
    mean = 0
    std = 0.1
    noise = torch.randn_like(input_images) * std + mean

    # Rescale every (image, channel) noise map to L2 norm `alpha`. The trailing
    # [..., None, None] aligns (B, C) with (B, C, H, W) for any number of channels;
    # the previous unsqueeze(1).unsqueeze(2) only broadcast correctly for C == 1.
    scaled_alpha = alpha / torch.norm(noise, dim=(2, 3))
    noise = noise * scaled_alpha[..., None, None] * noise_scale

    noisy_images = input_images + noise

    # Clip the pixel values to ensure they're within the valid range [0, 1]
    return torch.clamp(noisy_images, 0, 1)

def generate_images(input_images):
    """Run the GAN's generator on ``input_images`` (on `device`, no grad) and return CPU outputs."""
    with torch.no_grad():
        gen_output = gan.generator(input_images.to(device)).cpu()

    return gen_output


def classify_images(images):
    """Return discriminator 2's raw outputs for ``images`` (on `device`, no grad), on the CPU."""
    with torch.no_grad():
        discriminator_2_output = gan.discriminator_2(images.to(device)).cpu()

    return discriminator_2_output


In [None]:
def other_classify_images(images):
    """Classify ``images`` with ``gan.other_discriminator`` (stub, still to be implemented).

    NOTE(review): the GAN class in this notebook defines no ``other_discriminator``,
    so calling this presumably fails with an AttributeError until one is added.
    TODO implement.
    """
    with torch.no_grad():
        return gan.other_discriminator(images)


def plot_images(batch_tensor, n_plots, plot_name, classifications=None, name=None):
    """Plot the first ``n_plots`` images in one row and save them as ``plot_{plot_name}.png``.

    Args:
        batch_tensor: image batch (each item is converted with ``tensor_to_image``).
        n_plots: number of panels.
        plot_name: suffix of the saved file name.
        classifications: optional per-image predicted classes, shown under each image.
        name: optional y-label for the first panel.

    Returns:
        The matplotlib Figure.
    """
    images = [tensor_to_image(o) for o in batch_tensor[:n_plots].clone().detach()]
    # squeeze=False keeps `axs` an array even when n_plots == 1
    fig, axs = plt.subplots(1, n_plots, figsize=((6*2, n_plots * 6)), squeeze=False)

    for i_ax, (ax, img) in enumerate(zip(axs.ravel(), images)):
        ax.imshow(img.squeeze(), cmap="gray")

        if name and i_ax == 0:
            ax.set_ylabel(name)

        # Hide ticks but keep the axes so the x-label below can be drawn
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        if classifications is not None:
            ax.set_xlabel(r"\huge$\downarrow$\\\vspace{2mm}\textbf{" + str(classifications[i_ax].item()) + "}")

    plt.savefig(f"plot_{plot_name}.png")
    return fig


# Compare discriminator 2's accuracy on clean test images, on generator outputs, and on
# Gaussian-noise images whose noise budget comes from the generator's perturbation norm.
# NOTE(review): the loop variables (input_images, noisy_images, ...) remain defined after
# this cell and later cells may read them; check before renaming.
input_images_matches = 0
generated_images_matches = 0
noisy_images_matches = 0

for i_batch, batch in enumerate(test_loader):
    input_images = batch[0]
    gen_images = generate_images(input_images)
    # Drop the trailing singleton label dimension (presumably (B, 1) -> (B,); TODO confirm)
    labels = batch[1].squeeze()

    # Per-image L2 norm of the generator's perturbation, used as the Gaussian noise budget
    gen_diff = input_images - gen_images
    gen_norm = torch.norm(gen_diff, dim=(2,3))
    alpha = gen_norm

    noisy_images = generate_gaussian_noise(input_images, alpha = alpha)
    gaussian_diff = input_images - noisy_images
    gaussian_norm = torch.norm(gaussian_diff, dim=(2,3))

    # Predicted digit = argmax over discriminator 2's outputs
    classified_input_images = torch.argmax(classify_images(input_images),dim=1)
    classified_generated_images = torch.argmax(classify_images(gen_images),dim=1)
    classified_noisy_images = torch.argmax(classify_images(noisy_images),dim=1)

    # Count correct predictions. The unsqueeze/transpose gives shape (1, B); broadcasting
    # against labels of shape (B,) still compares element-wise, so the sums are per-image matches.
    input_images_matches += torch.sum(labels == classified_input_images)
    generated_images_matches += torch.sum(labels == classified_generated_images.unsqueeze(-1).transpose(0,1))
    noisy_images_matches += torch.sum(labels == classified_noisy_images.unsqueeze(-1).transpose(0,1))


    # print('gen norm', gen_norm[:5])
    # print('gaussian', gaussian_norm[:5])
    if i_batch == 0:
      # Visualize the first few images
      n_plots = 5
      plot_images(input_images, n_plots, "mnist_input", classified_input_images)
      plot_images(noisy_images, n_plots, "mnist_noisy", classified_noisy_images)
      plot_images(torch.abs(gaussian_diff), n_plots, "mnist_gaussian_diff")
      plot_images(gen_images, n_plots, "mnist_generated", classified_generated_images)
      plot_images(torch.abs(gen_diff), n_plots, "mnist_generated_diff")

      print("Discriminator 2 guess on input images (real)\n", classified_input_images[:5])
      print("Discriminator 2 guess on generated images (fake)\n", classified_generated_images[:5].unsqueeze(-1).transpose(0,1))
      print("Discriminator 2 guess on noisy images (fake)\n", classified_noisy_images[:5].unsqueeze(-1).transpose(0,1))


    #break  # Remove this line to process the entire dataset


# Accuracy = correct predictions / number of test images
print("Input images accuracy\n", input_images_matches / len(test_loader.dataset)) #/input_images.size(0))
print("Generated images accuracy\n", generated_images_matches / len(test_loader.dataset)) #/input_images.size(0))
print("Noisy images accuracy\n", noisy_images_matches / len(test_loader.dataset)) #/input_images.size(0))



In [None]:
def generate_ssims(gan, strength_factor=1):
    """Per-image SSIM against the clean input for adversarial and Gaussian perturbations.

    For every batch of ``test_loader`` the generator output is computed. A Gaussian-noise
    baseline is then built with ``generate_gaussian_noise`` using, as ``alpha``, the per-image
    norm of the generator's perturbation. Both perturbations are scaled by ``strength_factor``
    (``input + strength_factor * (perturbed - input)``), and SSIM against the clean input is
    measured for each.

    Parameters
    ----------
    gan : model exposing a ``generator`` module.
    strength_factor : float, default 1
        Scale applied to both perturbations. With the default of 1, the unscaled generator
        output and noisy image are used directly.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        SSIM values for the generated images and for the Gaussian-noise images, concatenated
        over all test batches.
    """
    gen_image_ssims = []
    noisy_image_ssims = []
    ssim = StructuralSimilarityIndexMeasure(reduction=None)

    for images, _ in test_loader:
        with torch.no_grad():
            gen_output = gan.generator(images.to(device)).cpu()

        # Gaussian baseline whose per-image norm matches the generator's perturbation.
        gen_norm = torch.norm(images - gen_output, dim=(2, 3))
        noisy_input = generate_gaussian_noise(images, alpha=gen_norm)

        if strength_factor != 1:
            # Skipped for the default so the unscaled images are used exactly.
            gen_output = images + strength_factor * (gen_output - images)
            noisy_input = images + strength_factor * (noisy_input - images)

        gen_image_ssims.append(ssim(gen_output, images))
        noisy_image_ssims.append(ssim(noisy_input, images))

    return torch.concat(gen_image_ssims), torch.concat(noisy_image_ssims)

In [None]:
def plot_ssim_histograms(gen_image_ssims, noisy_image_ssims, plot_name, bins=30):
    """Overlay SSIM histograms for adversarially and Gaussian-perturbed images.

    Both distributions are drawn on one axes with ``stat="probability"`` and a shared set of
    bin edges, and the figure is saved as ``plot_{plot_name}.png``.

    Parameters
    ----------
    gen_image_ssims, noisy_image_ssims : torch.Tensor
        SSIM values (e.g. as returned by ``generate_ssims``); they are flattened for binning.
    plot_name : str
        Suffix for the output file name.
    bins : int, default 30
        Number of bins, shared by both histograms.
    """
    cm = 1 / 2.54  # centimeters in inches
    gen_values = gen_image_ssims.detach().cpu().numpy().ravel()
    noisy_values = noisy_image_ssims.detach().cpu().numpy().ravel()

    # One set of bin edges for both series: with independent `bins=30`, each histogram gets
    # its own bin width, so their probability heights are not comparable on a shared axis.
    combined = np.concatenate([gen_values, noisy_values])
    bin_edges = np.histogram_bin_edges(combined[np.isfinite(combined)], bins=bins)

    fig, ax = plt.subplots(figsize=(7 * cm, 7 * cm))
    sns.histplot(gen_values, stat="probability", bins=bin_edges, color="darkred", edgecolor=None, alpha=0.8, label="Adversarial noise", ax=ax)
    sns.histplot(noisy_values, stat="probability", bins=bin_edges, color="gray", edgecolor=None, alpha=0.8, label="Gaussian noise", ax=ax)
    ax.set_xlim(0, 1)
    ax.set_xlabel("SSIM")
    ax.set_ylabel("Frequency")
    ax.legend()
    fig.savefig(f"plot_{plot_name}.png")

gen_image_ssims, noisy_image_ssims = generate_ssims(gan)
print(torch.mean(gen_image_ssims).item(), torch.mean(noisy_image_ssims).item())
plot_ssim_histograms(gen_image_ssims, noisy_image_ssims, "test_ssim_histogram_adversarial_gaussian")

In [None]:
import pandas as pd

In [None]:
def calc_accuracy_over_strength_factor(gan, n_steps=11):
    """Sweep a perturbation strength factor and measure classifier accuracy and SSIM.

    For each batch of ``test_loader`` two perturbations are computed once:
    - the generator's perturbation (``gen_output - input``);
    - a Gaussian perturbation from ``generate_gaussian_noise`` with ``alpha`` set to the
      generator perturbation's per-image norm.
    For every factor ``phi`` in ``linspace(0, 1, n_steps)``, ``input + phi * perturbation`` is
    classified with ``gan.discriminator_2`` (argmax over dim 1) and compared with the clean
    input via SSIM.

    Parameters
    ----------
    gan : model exposing ``generator`` and ``discriminator_2`` modules.
    n_steps : int, default 11
        Number of strength factors, evenly spaced over [0, 1] with both endpoints included.

    Returns
    -------
    pandas.DataFrame
        One row per factor. Columns:
        - ``strength_factor``;
        - ``gen_image_accuracy`` and ``noisy_image_accuracy``: matches divided by
          ``len(test_loader.dataset)``;
        - ``gen_image_ssim`` and ``noisy_image_ssim``: mean SSIM over all test images.
    """
    strength_factors = np.linspace(0, 1, n_steps).tolist()
    n_samples = len(test_loader.dataset)

    # Per-factor accumulators: index k corresponds to strength_factors[k].
    n_gen_matches = [0] * len(strength_factors)
    n_noisy_matches = [0] * len(strength_factors)
    gen_ssims = [[] for _ in strength_factors]
    noisy_ssims = [[] for _ in strength_factors]
    ssim = StructuralSimilarityIndexMeasure(reduction=None)

    for images, labels in test_loader:
        # The generator output and the noise do not depend on the factor, so compute them once
        # per batch instead of once per (factor, batch) pair.
        with torch.no_grad():
            gen_output = gan.generator(images.to(device)).cpu()

        gen_diff = gen_output - images
        gen_norm = torch.norm(gen_diff, dim=(2, 3))
        noisy_diff = generate_gaussian_noise(images, alpha=gen_norm) - images

        # Flatten so the match count is elementwise whether labels is (B,) or (B, 1);
        # comparing a (B,) tensor with a (B, 1) tensor would broadcast to (B, B).
        flat_labels = labels.reshape(-1)

        for k, factor in enumerate(strength_factors):
            # Add the perturbations to the input, scaled by the strength factor.
            gen_scaled = images + factor * gen_diff
            noisy_scaled = images + factor * noisy_diff

            gen_ssims[k].append(ssim(gen_scaled, images))
            noisy_ssims[k].append(ssim(noisy_scaled, images))

            with torch.no_grad():
                gen_pred = torch.argmax(gan.discriminator_2(gen_scaled.to(device)).cpu(), dim=1)
                noisy_pred = torch.argmax(gan.discriminator_2(noisy_scaled.to(device)).cpu(), dim=1)

            n_gen_matches[k] += torch.sum(flat_labels == gen_pred.reshape(-1)).item()
            n_noisy_matches[k] += torch.sum(flat_labels == noisy_pred.reshape(-1)).item()

    return pd.DataFrame(
        {
            "strength_factor": strength_factors,
            "gen_image_accuracy": [m / n_samples for m in n_gen_matches],
            "noisy_image_accuracy": [m / n_samples for m in n_noisy_matches],
            "gen_image_ssim": [torch.concat(s).mean().item() for s in gen_ssims],
            "noisy_image_ssim": [torch.concat(s).mean().item() for s in noisy_ssims],
        }
    )

accuracy_over_strength_factor_df = calc_accuracy_over_strength_factor(gan, 11)
accuracy_over_strength_factor_df.to_csv("accuracy_over_strength_factor.csv", sep=",")
accuracy_over_strength_factor_df

In [None]:
def plot_accuracy_over_strength_factor(df):
    """Plot classification accuracy on generated images against the strength factor phi.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``strength_factor`` and ``gen_image_accuracy`` columns, as produced by
        ``calc_accuracy_over_strength_factor``.

    The figure is saved as ``plot_accuracy_over_strength_factor.png``.
    """
    cm = 1 / 2.54  # centimeters in inches
    fig, ax = plt.subplots(figsize=(7 * cm, 7 * cm))
    # Plot the table passed in, not the global of the same name, so any result table works.
    ax.plot(df["strength_factor"], df["gen_image_accuracy"], marker="o")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel(r"Strength factor, $\varphi$")
    ax.set_ylabel("Classification accuracy")
    fig.savefig("plot_accuracy_over_strength_factor.png")

plot_accuracy_over_strength_factor(accuracy_over_strength_factor_df)

In [None]:
import pandas as pd

# Load the training metrics logged during the full GAN run from Google Drive
# (requires the drive.mount('/content/drive') cell at the top of the notebook).
# NOTE(review): absolute Colab path — consider moving the results directory into a config constant.
training_metrics_path = Path("/content/drive/MyDrive/data/results/gan_full_logs/version_0/metrics.csv")
training_metrics = pd.read_csv(training_metrics_path)

In [None]:
# Display the raw training-metrics table (pandas truncates long frames in the rendered output).
training_metrics

In [None]:
def _logged_epoch_series(metrics_df, column):
    """Return (1-based epoch numbers, values) for the rows of ``metrics_df`` where ``column`` is not null."""
    logged = metrics_df[metrics_df[column].notnull()]
    return logged["epoch"] + 1, logged[column]


def _set_epoch_axis(ax, epochs, tick_step=4):
    """Fit ``ax``'s x-range to ``epochs``, tick every ``tick_step`` epochs and label the axis."""
    first, last = epochs.min(), epochs.max()
    ax.set_xlim(first, last)
    ax.set_xticks(np.arange(first, last + 1, tick_step))
    ax.set_xlabel("Training epoch")


def plot_metrics(metrics_df):
    """Plot per-epoch generator loss components and discriminator-2 loss from a metrics table.

    Saves two figures:
    - ``plot_train_gen_loss_components_epoch.png``: the ``train_gen_loss_disc_1_epoch``,
      ``train_gen_loss_disc_2_epoch`` and ``train_gen_loss_norm_epoch`` columns;
    - ``plot_train_disc_2_loss_epoch.png``: the ``train_disc_2_loss_epoch`` column.
    Only rows where a column is not null are plotted. Epochs are shown 1-based
    (``epoch + 1``).

    Parameters
    ----------
    metrics_df : pandas.DataFrame
        Must contain an ``epoch`` column and the loss columns listed above.
    """
    cm = 1 / 2.54  # centimeters in inches
    figsize = (7 * cm, 7 * cm)

    # Generator loss components
    fig_gen, ax_gen = plt.subplots(figsize=figsize)
    gen_components = (
        ("train_gen_loss_disc_1_epoch", r"$\alpha \mathcal{L}_\text{Disc. 1}$"),
        ("train_gen_loss_disc_2_epoch", r"$\beta \mathcal{L}_\text{Disc. 2}^*$"),
        ("train_gen_loss_norm_epoch", r"$\gamma \mathcal{L}_\text{Norm}$"),
    )
    for column, label in gen_components:
        epochs, values = _logged_epoch_series(metrics_df, column)
        ax_gen.plot(epochs, values, label=label, marker="o")
    # Limits and ticks both follow the epochs at which the generator losses were logged.
    gen_epochs, _ = _logged_epoch_series(metrics_df, "train_gen_loss_disc_1_epoch")
    _set_epoch_axis(ax_gen, gen_epochs)
    ax_gen.set_ylabel("Loss")
    ax_gen.legend(loc="best")
    fig_gen.savefig("plot_train_gen_loss_components_epoch.png")

    # Discriminator 2 loss
    fig_disc, ax_disc = plt.subplots(figsize=figsize)
    disc_epochs, disc_loss = _logged_epoch_series(metrics_df, "train_disc_2_loss_epoch")
    ax_disc.plot(disc_epochs, disc_loss, marker="o")
    _set_epoch_axis(ax_disc, disc_epochs)
    ax_disc.set_ylabel(r"$\mathcal{L}_\text{Disc. 2}$ loss")
    fig_disc.savefig("plot_train_disc_2_loss_epoch.png")


plot_metrics(training_metrics)