# File Preprocessing  

In [None]:
'''
Run from google colab
'''

# Mount Google Drive so files stored there are reachable under /content/drive.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
import os

# List files in the root of Google Drive (a quick check that the mount worked).
# NOTE(review): this cell uses 'My Drive' while the dataset paths below use
# 'MyDrive'; confirm which spelling exists in the current Colab mount.
drive_root = '/content/drive/My Drive/'
for file_name in os.listdir(drive_root):
    print(file_name)

In [None]:
# Data Colab Extraction (ONLY USE ONCE)
# The extraction code below is wrapped in a string literal so it does not run.
# Remove the surrounding quotes to unzip ML_FINAL_PROJECT.zip into Drive the first time.
'''
import zipfile

# Path to your zip file in Google Drive
zip_file_path = '/content/drive/My Drive/ML_FINAL_PROJECT.zip'
extract_to_path = '/content/drive/My Drive/ML_FINAL_PROJECT/'

# Create the extraction directory if it doesn't exist
os.makedirs(extract_to_path, exist_ok=True)

# Extract the zip file
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to_path)

print('Extraction complete.')
'''

In [None]:
import pandas as pd
import numpy as np
import cv2
from pathlib import Path
import os
from tqdm.auto import tqdm
import shutil
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from torchvision.models import vgg16, VGG16_Weights

In [None]:
# Base directory for the dataset paths below: the current working directory
# (presumably /content on Colab — confirm).
cwd_path = Path.cwd()

In [None]:
# Original fruits-360 Training / Test folders. Each is expected to hold one
# sub-directory per fruit class: the sampling helpers list and glob "*.jp*g" inside them.
og_train_path = os.path.join(cwd_path, "drive/MyDrive/ML_FINAL_PROJECT/ML_FINAL_PROJECT/fruits-360/Training")
og_test_path  = os.path.join(cwd_path, "drive/MyDrive/ML_FINAL_PROJECT/ML_FINAL_PROJECT/fruits-360/Test")

In [None]:
def getRandomFruitDirs(classCount, root=None):
    """Pick ``classCount`` distinct fruit class names at random.

    Class names are the names of the sub-directories of ``root``. Hidden
    entries (e.g. ``.ipynb_checkpoints``) and plain files are skipped so they
    can never be returned as a "class".

    Args:
        classCount: number of distinct classes to draw (without replacement).
        root: directory to list; defaults to ``og_train_path``.

    Returns:
        numpy array of ``classCount`` directory names.

    Raises:
        ValueError: from ``np.random.choice`` if fewer than ``classCount``
            class directories exist.
    """
    if root is None:
        root = og_train_path
    # Sorted so that, under a fixed np.random seed, the same classes are drawn
    # on every machine (os.listdir order is arbitrary).
    class_dirs = sorted(
        name
        for name in os.listdir(root)
        if not name.startswith(".") and os.path.isdir(os.path.join(root, name))
    )
    return np.random.choice(class_dirs, classCount, replace=False)

In [None]:
def MakeDirs(srcDir, newDirs):
    """(Re)create one or more empty directories inside ``srcDir``.

    An existing directory with the same name is deleted first, along with
    everything in it, so each resulting directory is empty.

    Args:
        srcDir: parent directory. It must already exist (``os.mkdir`` does
            not create parents).
        newDirs: a single directory name (str) or a list/tuple of names.
            Any other type is ignored.
    """
    if isinstance(newDirs, str):
        newDirs = [newDirs]
    elif not isinstance(newDirs, (list, tuple)):
        return

    for dirname in newDirs:
        dirpath = os.path.join(srcDir, dirname)
        if os.path.isdir(dirpath):
            # Remove content left over from a previous run.
            shutil.rmtree(dirpath)
        os.mkdir(dirpath)

In [None]:
def transferFile(srcPaths, destDir):
    """Expose the given image files inside ``destDir`` as symbolic links.

    ``destDir`` is created if missing (its parent must already exist). The
    n-th link is called ``img{n}.jpg`` and points at the n-th source path;
    no image data is copied.
    """
    if not os.path.isdir(destDir):
        os.mkdir(destDir)

    link_plan = [
        (source, os.path.join(destDir, f"img{position}.jpg"))
        for position, source in enumerate(srcPaths)
    ]
    for source, link_path in link_plan:
        os.symlink(source, link_path)


In [None]:
def getImageMatrix(links, title, image_size):
    """Load image files into a flat, normalised RGB feature matrix.

    Args:
        links: iterable of image file paths.
        title: class label assigned to every loaded image.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(matrix, labels)``. ``matrix`` has shape
        ``(n_images, image_size[0] * image_size[1] * 3)``, dtype float64,
        values in [0, 1], channels in RGB order. ``labels`` repeats ``title``
        once per image.

    Raises:
        IOError: if OpenCV cannot read one of the files.
    """
    n_features = image_size[0] * image_size[1] * 3
    rows = []
    for link in links:
        image = cv2.imread(link)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise IOError(f"Could not read image: {link}")
        image = cv2.resize(image, image_size)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        rows.append(image.reshape(-1))

    # Build the matrix once, instead of np.vstack-ing inside the loop (which
    # copies the whole matrix per image). The explicit reshape also keeps the
    # 2-D shape (0, n_features) when `links` is empty.
    matrix = np.array(rows, dtype=np.float64).reshape(len(rows), n_features)
    return matrix / 255, np.array([title] * len(rows))

In [None]:
def getSamplesDir(sampleSize, valSize, testSize, selectedFruits, classDist = None):
    """Build symlinked Train / Validation / Test subsets of fruits-360.

    Training images are drawn from ``og_train_path``. Validation and test
    images are drawn, disjointly, from ``og_test_path``. The subsets are
    written as symlinks under
    ``ML_FINAL_PROJECT/fruits-360/Subsets/{Train_sub,Validation_sub,Test_sub}/<fruit>/``,
    replacing any previous subsets.

    NOTE(review): that output path is relative to the working directory,
    unlike ``og_train_path``; ``ML_FINAL_PROJECT/fruits-360`` must already
    exist there. Confirm this is intended.

    Args:
        sampleSize: total number of images across all splits and classes.
        valSize: fraction of ``sampleSize`` used for validation.
        testSize: fraction of ``sampleSize`` used for testing (must be > 0).
        selectedFruits: numpy array of class (directory) names.
        classDist: optional per-class fractions summing to 1; uniform if None.

    Raises:
        Exception: if ``classDist`` does not sum to 1, or if the test or
            training fraction is not positive.
    """
    trainSize = 1 - (valSize + testSize)

    if classDist is None:
        cDist = np.ones(selectedFruits.shape[0]) * 1/selectedFruits.shape[0]
    else:
        cDist = classDist

    if np.abs(np.sum(cDist) - 1) > 1e-9:
        raise Exception("Class Distribution is greater than one!")

    if testSize <= 0:
        raise Exception("Test sample size too small!")
    if trainSize <= 0:
        raise Exception("Training sample size too small!")

    train_sample_sz = int(trainSize * sampleSize)
    val_sample_sz = int(valSize * sampleSize)
    test_sample_sz = int(testSize * sampleSize)

    MakeDirs("ML_FINAL_PROJECT/fruits-360", "Subsets")
    subset_dir = "ML_FINAL_PROJECT/fruits-360/Subsets"
    subs = ["Train_sub", "Validation_sub", "Test_sub"]
    MakeDirs(subset_dir, subs)

    for idx, fruit in enumerate(selectedFruits):
        globbed_train = glob(os.path.join(og_train_path, fruit, "*.jp*g"))
        globbed_test = glob(os.path.join(og_test_path, fruit, "*.jp*g"))

        train_fn = np.random.choice(globbed_train, int(train_sample_sz * cDist[idx]), replace = False)
        testval_fn = np.random.choice(globbed_test, int((test_sample_sz + val_sample_sz) * cDist[idx]), replace = False)

        # The first n_test held-out images go to Test and the rest to
        # Validation. Starting the validation slice at n_test (not n_test + 1)
        # keeps every drawn image.
        n_test = int(test_sample_sz * cDist[idx])
        test_fn, val_fn = testval_fn[:n_test], testval_fn[n_test:]

        transferFile(train_fn, os.path.join(subset_dir, subs[0], fruit))
        transferFile(val_fn, os.path.join(subset_dir, subs[1], fruit))
        transferFile(test_fn, os.path.join(subset_dir, subs[2], fruit))

In [None]:
def getSamplesMat(sampleSize, selectedFruits, classDist = None, resize = (100, 100)):
    """Sample images per class and return them as a flat RGB feature matrix.

    For each class, images are drawn without replacement from the union of its
    fruits-360 Training and Test folders.

    Args:
        sampleSize: total number of images to draw across all classes.
        selectedFruits: numpy array of class (directory) names (``.shape`` is read).
        classDist: optional per-class fractions summing to 1; uniform if None.
        resize: ``(width, height)`` each image is resized to.

    Returns:
        ``(matrix, labels, first_file)``:
        - ``matrix`` has one row of ``resize[0] * resize[1] * 3`` values in
          [0, 1] per image.
        - ``labels`` holds the class name of each row.
        - ``first_file`` is the first path sampled for the last class, or
          None if nothing was sampled.

    Raises:
        Exception: if ``classDist`` does not sum to 1.
    """
    n_features = resize[0] * resize[1] * 3

    if classDist is None:
        cDist = np.ones(selectedFruits.shape[0]) * 1/selectedFruits.shape[0]
    else:
        cDist = classDist

    if np.abs(np.sum(cDist) - 1) > 1e-9:
        raise Exception("Class Distribution is greater than one!")

    matrices = []
    label_parts = []
    fn = None
    for idx, fruit in enumerate(selectedFruits):
        globbed_train = glob(os.path.join(og_train_path, fruit, "*.jp*g"))
        globbed_test = glob(os.path.join(og_test_path, fruit, "*.jp*g"))

        fn = np.random.choice(globbed_train + globbed_test, int(sampleSize * cDist[idx]), replace = False)
        res_matrix, new_labels = getImageMatrix(fn, fruit, resize)
        matrices.append(res_matrix)
        label_parts.append(new_labels)

    # Stack once at the end rather than re-copying a growing matrix per class.
    matrix = np.vstack(matrices) if matrices else np.empty((0, n_features))
    # Flatten label lists directly so the result is a string array (never a
    # float/string mix).
    labels = np.array([label for part in label_parts for label in part])
    first_file = fn[0] if fn is not None and len(fn) > 0 else None
    return matrix, labels, first_file

In [None]:
def GrayScaleImageSquare(dataset, flatten=True):
    """Convert square RGB images to grayscale.

    Args:
        dataset: either a 2-D array of flattened images of shape
            ``(n, dim * dim * 3)`` (as produced by ``getSamplesMat``), or a
            4-D array of shape ``(n, dim, dim, 3)``. Channels in RGB order.
        flatten: if True return shape ``(n, dim * dim)``, otherwise
            ``(n, dim, dim)``.

    Returns:
        float64 array of luma values ``0.299 R + 0.587 G + 0.114 B``
        (the weights OpenCV uses for RGB -> GRAY).

    Raises:
        Exception: if a 2-D input cannot be reshaped into square RGB images.
        ValueError: if ``dataset`` is neither 2-D nor 4-D.
    """
    dataset = np.asarray(dataset)
    if dataset.ndim == 2:
        dim = int(np.floor(np.sqrt(dataset.shape[1] / 3)))
        try:
            dataset = dataset.reshape(-1, dim, dim, 3)
        except ValueError as err:
            raise Exception("Input image is probably not colored") from err
    elif dataset.ndim == 4:
        dim = dataset.shape[1]
    else:
        raise ValueError(
            f"Expected a 2-D (flattened) or 4-D image array, got shape {dataset.shape}"
        )

    # getImageMatrix already converted pixels from OpenCV's BGR to RGB, so the
    # channels must be weighted in R, G, B order. A BGR->GRAY conversion here
    # would swap the red and blue weights.
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    grayConv = (dataset.astype(np.float32) @ weights).astype(np.float64)

    if flatten:
        grayConv = grayConv.reshape(-1, dim ** 2)
    return grayConv

# Loading the images into a flat feature matrix with getSamplesMat (intended for all models except the neural networks):

In [None]:
# Specify fruit classes here or get n random fruits
# NOTE: getSamplesMat reads selectedFruits.shape, so a hand-written list must be
# wrapped in np.array(...).
count = 40
classes = getRandomFruitDirs(count) #or ["Mandarine", "Apple Golden 1" ... etc ]

In [None]:
# Drop the Jupyter checkpoint folder if it was drawn as a "class"
# (in that case fewer than `count` classes remain).
classes = classes[classes != '.ipynb_checkpoints']

In [None]:
# overall Sample size:
sample_size = 1000

# Selected fruit classes (from above)
selected_fruits = classes

# Would contain the fraction distribution of classes for experimentation (None is all equal)
classDist = None # ex: [0.2, 0.1, 0.4, 0.1, 0.1] MUST ADD TO 1

# resize: Resizes the image for further experimentation default is 100 x 100
resizes = (100, 100)


# getSamplesMat returns the flattened RGB rows, their labels, and one sampled file
# path (kept in `what`, unused here); the rows are then converted to flattened grayscale.
image_dataset, labels, what = getSamplesMat(sample_size, selected_fruits, classDist, resizes)
image_dataset = GrayScaleImageSquare(image_dataset)

In [None]:
# NOTE(review): after GrayScaleImageSquare each row holds resizes[0] * resizes[1]
# grayscale values, not the 3 * ... colour values this message describes.
print("Dataset shape \"sample_size\" images of shape 3*resizes[1]*resizes[2] is: ", image_dataset.shape)

In [None]:
image_dataset.shape

In [None]:
labels.shape

In [None]:
# display an image:

# image_dataset holds flattened *grayscale* rows (output of GrayScaleImageSquare),
# so reshape each row to a single-channel resizes[0] x resizes[1] image and show it
# with a gray colormap. Reshaping to 3 channels would fail on the size mismatch.

# image index (change)
index = 10
image = image_dataset[index].reshape(resizes[0], resizes[1])
plt.imshow(image, cmap="gray")
plt.title(labels[index])

# Using a CNN with 20 fruit classes:

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torchsummary as summary


import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, TensorDataset

import copy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


import matplotlib.pyplot as plt

In [None]:
# Colored Fruits
# Hand-picked class groups. Entries must match fruits-360 directory names exactly,
# because getSamplesMat globs og_train_path/<name>/.
# NOTE: some names appear in more than one group (e.g. "Kaki", "Peach", "Apricot",
# "Pear Red", "Pomelo Sweetie", "Cherry 1").
red_fruits = [
    "Grape Pink", "Tomato Cherry Red", "Tomato Maroon", "Apple Pink Lady", "Apricot",
    "Tomato 1", "Peach", "Cherry 2", "Apple Braeburn", "Tomato Heart",
    "Pepper Red", "Pomegranate", "Mango Red", "Pear Red", "Kaki", "Onion Red",
    "Strawberry", "Tamarillo", "Cucumber Ripe", "Pear Forelle"
]


yellow_orange_fruits = [
    "Pepper Yellow", "Grapefruit Pink", "Tomato Yellow", "Mandarine", "Tangelo",
    "Grapefruit White", "Cherry Wax Yellow", "Kaki", "Peach 2", "Apple Red Yellow 1",
    "Lemon", "Pomelo Sweetie", "Pear Red", "Pepper Orange", "Peach", "Maracuja",
    "Apricot", "Cucumber Ripe 2", "Carambula", "Tomato 2"
]

green_fruits = ['Pepper Green', 'Grape White 4', 'Grape Blue',
         'Tomato not Ripened', 'Avocado', 'Grape White 3',
         'Cherry Wax Black', 'Blueberry', 'Cherry 1', 'Plum 3', 'Plum',
         'Nut Forest', 'Mango', 'Pomelo Sweetie', 'Limes', 'Pineapple',
         'Apple Granny Smith', 'Guava', 'Watermelon', 'Apple Golden 3']

# Fruit Shapes


circular_fruits = [
    "Cantaloupe 1", "Cantaloupe 2", "Cherry 1", "Cherry 2", "Apricot",
    "Cherry Wax Black", "Cherry Wax Red", "Cherry Wax Yellow", "Clementine",
    "Grape Blue", "Grape White", "Grapefruit Pink", "Nectarine", "Orange",
    "Peach", "Plum", "Plum 2", "Tomato Cherry Red", "Walnut", "Watermelon"
]

no_circular_fruits = [
    "Apple Braeburn", "Apple Crimson Snow", "Apple Golden 1", "Apple Granny Smith",
    "Apple Pink Lady", "Apple Red 1", "Apple Red Delicious", "Apple Red Yellow 1",
    "Beetroot", "Blueberry", "Fig", "Guava", "Lemon", "Limes", "Onion Red",
    "Onion White", "Pepper Red", "Lemon Meyer", "Mango Red", "Kiwi"
]

tubular_fruits = [
    "Eggplant", "Avocado", "Avocado Ripe", "Banana", "Banana Lady Finger", "Banana Red",
    "Cactus fruit", "Carambula", "Corn", "Cucumber Ripe", "Hazelnut", "Kumquats",
    "Pear 2", "Pear Red", "Pepino", "Tomato 2", "Nut Pecan", "Melon Piel de Sapo",
    "Pear Abate", "Corn Husk"
]

# Alternative random selection of 40 classes (not used by the visible cells below).
random40 = getRandomFruitDirs(40)

# Name of the chosen group, used in plot and report titles.
ds_name = "Green Fruits (Greyscaled)"

In [None]:
# overall Sample size:
sample_size = 1000


# Selected fruit classes (from above)
selected_fruits = np.array(green_fruits)

# Would contain the fraction distribution of classes for experimentation (None is all equal)
classDist = None # ex: [0.2, 0.1, 0.4, 0.1, 0.1] MUST ADD TO 1

# resize: Resizes the image for further experimentation default is 100 x 100
resizes = (100, 100)


# Flattened RGB rows -> flattened grayscale rows (one value per pixel).
image_dataset, labels, what = getSamplesMat(sample_size, selected_fruits, classDist, resizes)
image_dataset = GrayScaleImageSquare(image_dataset)

In [None]:
# Add a trailing channel axis: (n, H, W, 1). It is moved to channels-first
# (n, 1, H, W) by the permute step further down.
ImageDataReshaped = image_dataset.reshape(-1, resizes[0], resizes[1],1)

In [None]:
class CustomDataset:
    """Minimal map-style dataset over a pair of equal-length tensors.

    ``tensors`` is ``(inputs, targets)``. Item ``i`` is
    ``(inputs[i], targets[i])``, with ``transforms`` (if not None) applied to
    the input only. Because the transform runs on every fetch, random
    augmentation is re-drawn each time.
    """

    def __init__(self, tensors, transforms):
        # Fail early here rather than with an IndexError deep inside a DataLoader.
        if tensors[0].size(0) != tensors[1].size(0):
            raise ValueError("inputs and targets must have the same length")
        self.tensors = tensors
        self.trans = transforms

    def __getitem__(self, index):
        x = self.tensors[0][index]
        if self.trans is not None:
            x = self.trans(x)
        return x, self.tensors[1][index]

    def __len__(self):
        return self.tensors[0].size(0)

In [None]:
# Training-time augmentation, applied by CustomDataset to each (1, H, W) image
# tensor on every fetch. With crop_factor = 1 the RandomCrop keeps the whole
# image, so it is effectively a no-op.
crop_factor = 1
cropSz = int(resizes[0] * crop_factor)
transforms_Basic = T.Compose([
    T.ToPILImage(),
    T.RandomCrop(cropSz),
    T.RandomRotation(20),   # random rotation in [-20, 20] degrees
    T.ToTensor()
])


In [None]:
# Maps fruit-name labels to integer class ids (fitted on the training labels below).
le = LabelEncoder()

In [None]:
# Random 80/20 split (not stratified, no fixed random_state).
X_train, X_test, y_train, y_test = train_test_split(ImageDataReshaped,labels, test_size=0.2, shuffle=True)

In [None]:
unq_train = np.unique(y_train)
unq_test = np.unique(y_test)

fig, ax = plt.subplots(nrows = 1, ncols = 2, figsize = (10, 5))

# Pass order= explicitly. Without it, countplot orders the bars by first
# appearance in the data, which does not match the sorted np.unique tick labels
# set below.
sns.countplot(x=y_train, order=unq_train, ax = ax[0])
ax[0].set_title("Train Dist")
ax[0].set_xticks(np.arange(unq_train.shape[0]), unq_train, rotation=45, ha='right')

sns.countplot(x=y_test, order=unq_test, ax = ax[1])
ax[1].set_title("Test Dist")
ax[1].set_xticks(np.arange(unq_test.shape[0]), unq_test, rotation=45, ha='right')

fig.suptitle("Class Distribution")


In [None]:
# Encode string labels as integer class ids; the encoder is fitted on the training
# split only, so le.transform raises if y_test has a label absent from y_train.
y_train = torch.tensor(le.fit_transform(y_train)).long()
y_test = torch.tensor(le.transform(y_test)).long()
X_train = torch.tensor(X_train).float()
X_test = torch.tensor(X_test).float()

In [None]:
# (n, H, W, 1) -> (n, 1, H, W): the channels-first layout nn.Conv2d expects.
X_train = X_train.permute(0,3,1,2)
X_test = X_test.permute(0,3,1,2)

In [None]:
# Use the first GPU when available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
print(device)

# Testing a Basic CNN

In [None]:
# Creating a simple CNN model to train the network:
def CreateCNNModel(n_outputs, n_imgSize):
    """Build the small 3-conv-layer CNN used for single-channel fruit images.

    Args:
        n_outputs: number of classes (size of the final logit layer).
        n_imgSize: input image size. Only ``n_imgSize[0]`` is used; images
            are assumed square with one channel, i.e. input ``(N, 1, S, S)``.

    Returns:
        ``(model, lossfunc, optimizer)``: the untrained model, a
        ``CrossEntropyLoss``, and an Adam optimizer (lr = 1e-4) over the
        model's parameters.
    """

    class CnnModel(nn.Module):
        def __init__(self, n_outputs, n_imgSize):
            super().__init__()

            # Setting universal parameters (shared by all three conv/pool stages)
            self.kernel = 3
            self.pad = 0
            self.stride = 1
            self.pool = 2
            # Gradient of the last conv activations. activations_hook fills it
            # in during backward(); get_activations_hook() returns it (Grad-CAM).
            self.gradient = None

            # Assumes the image remains square
            self.imgSize = n_imgSize

            self.conv1 = nn.Conv2d(1, 8, self.kernel, self.stride, self.pad)
            self.conv2 = nn.Conv2d(8, 16, self.kernel, self.stride, self.pad)
            self.conv3 = nn.Conv2d(16, 32, self.kernel, self.stride, self.pad)

            # Follow the spatial size through the three conv + max-pool stages
            # to size the first fully connected layer.
            outSize = self.imgSize[0]
            for _ in range(3):
                outSize = self._conv_out(outSize) // self.pool

            self.bnorm1 = nn.BatchNorm2d(32)
            self.flatten = nn.Flatten()
            self.fc1 = nn.Linear(outSize**2 * 32, 50)
            self.fc2 = nn.Linear(50, n_outputs)

        def _conv_out(self, size):
            """Spatial size after one conv layer with this model's kernel/stride/padding."""
            return (size + 2 * self.pad - self.kernel) // self.stride + 1

        def activations_hook(self, grad):
            self.gradient = grad

        def get_activations_hook(self):
            return self.gradient

        def _conv_stack(self, x, capture_grad=False):
            """Run the three conv -> ReLU -> max-pool stages.

            If ``capture_grad`` is set and autograd is tracking the pass, the
            pre-pool conv3 activations get a hook that stores their gradient.
            """
            x = F.max_pool2d(F.relu(self.conv1(x)), self.pool, self.pool)
            x = F.max_pool2d(F.relu(self.conv2(x)), self.pool, self.pool)
            x = F.relu(self.conv3(x))
            # register_hook fails on tensors that do not require grad
            # (e.g. under torch.no_grad()), so only hook when autograd is active.
            if capture_grad and x.requires_grad:
                x.register_hook(self.activations_hook)
            return F.max_pool2d(x, self.pool, self.pool)

        def getFeatures(self, x):
            return self._conv_stack(x)

        def forward(self, x):
            # Hooking every grad-tracked forward pass lets GrabGradCam read the
            # gradient after backward() without editing this method.
            x = self._conv_stack(x, capture_grad=True)
            x = self.bnorm1(x)
            x = self.flatten(x)
            x = F.relu(self.fc1(x))
            return self.fc2(x)

    model = CnnModel(n_outputs, n_imgSize)
    lossfunc = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

    return model, lossfunc, optimizer

In [None]:
# The training set uses random augmentation. The evaluation set is the held-out
# split (X_test / y_test) without augmentation; evaluating on X_train would only
# measure how well the model fits its own training images.
train = CustomDataset((X_train, y_train), transforms_Basic)
test = CustomDataset((X_test, y_test), None)

train_loader = DataLoader(train, batch_size = 32, shuffle=True)
# A single batch containing the whole evaluation set.
test_loader = DataLoader(test, batch_size = test.tensors[0].shape[0])

## Constructing and testing the model:

In [None]:
# Number of distinct encoded classes in the training labels.
class_sz = len(np.unique(y_train))

In [None]:
model, lossfunc, optimizer = CreateCNNModel(class_sz, resizes)
model = model.to(device)

In [None]:
# Smoke test: push one augmented training batch through the untrained model.
# NOTE: the model is still in training mode here, so this call also updates the
# BatchNorm running statistics.
x_base,y_base = next(iter(train_loader))
x_base = x_base.to(device)
test_results = model(x_base)

In [None]:
# Show one augmented training image, titled with its integer class id.
idx = 4
plt.imshow(x_base[idx].cpu().permute(1,2,0).detach().numpy())
plt.title(y_base[idx].item())

# Model Training Function:

In [None]:
#
def TrainModel(net, lossfunc, optimizer, train_loader, test_loader, epochs = 50, val_set = True, device = None):
    """Train ``net`` and record per-epoch loss and accuracy curves.

    Args:
        net: model to train (updated in place).
        lossfunc: loss taking ``(logits, targets)``.
        optimizer: optimizer over ``net``'s parameters.
        train_loader: yields ``(x, y)`` training batches.
        test_loader: evaluation loader. Only its first batch is used each
            epoch, so it should yield the whole evaluation set as one batch.
        epochs: number of passes over ``train_loader``.
        val_set: if True, evaluate on ``test_loader`` after every epoch.
        device: where batches are moved. Defaults to the device of ``net``'s
            first parameter.

    Returns:
        ``(train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist)``,
        one entry per epoch. The test lists stay empty when ``val_set`` is
        False. Accuracies are percentages. If ``val_set`` is True the model
        is left in eval mode.
    """
    if device is None:
        device = next(net.parameters()).device

    train_acc_hist = []
    train_loss_hist = []
    test_acc_hist = []
    test_loss_hist = []

    for _ in range(epochs):
        # Re-enter training mode every epoch. The evaluation step below switches
        # to eval mode, which would otherwise freeze BatchNorm/dropout behaviour
        # for all remaining epochs.
        net.train()
        batch_acc = []
        batch_loss = []
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            pred = net(x)
            loss = lossfunc(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss.append(loss.item())
            acc_score = 100 * torch.mean((torch.argmax(pred.detach().cpu(), dim=1) == y.cpu()).float())
            batch_acc.append(acc_score.numpy())

        train_acc_hist.append(np.mean(batch_acc))
        train_loss_hist.append(np.mean(batch_loss))

        if val_set:
            X_test, y_test = next(iter(test_loader))
            X_test = X_test.to(device)
            y_test = y_test.to(device)
            net.eval()
            with torch.no_grad():
                y_hat = net(X_test)
                test_loss = lossfunc(y_hat, y_test)
                test_acc = 100 * torch.mean((torch.argmax(y_hat.cpu(), dim=1) == y_test.cpu()).float())
            test_loss_hist.append(test_loss.item())
            test_acc_hist.append(test_acc.numpy())

    return train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist


In [None]:
train_loss, test_loss, train_acc, test_acc = TrainModel(model, lossfunc, optimizer, train_loader, test_loader, val_set=True)  # default 50 epochs, evaluated after each

In [None]:
# Metrics:
fig, ax = plt.subplots(nrows = 1, ncols=2, figsize=(10, 5))
fig.suptitle(f"CNN Metrics for {ds_name}")
ax[0].plot(train_loss, label="Train Loss")
ax[0].plot(test_loss, label="Test Loss")
ax[0].set_title("Model Loss")
ax[0].set_ylabel("Loss (CCE)")
ax[0].set_xlabel("Epochs")

ax[1].plot(train_acc, label="Train")
ax[1].plot(test_acc, label="Test")
ax[1].set_title("Model Accuracy")
ax[1].set_ylabel("Accuracy (%)")
ax[1].set_xlabel("Epochs")

plt.legend()

fig.tight_layout()

In [None]:
from sklearn.metrics import classification_report

# Evaluate the trained model on the full evaluation batch: eval mode (BatchNorm
# uses running statistics) and no autograd graph.
x_test, y_test = next(iter(test_loader))
x_test = x_test.to(device)
model.eval()
with torch.no_grad():
    y_hat = model(x_test)
preds = torch.argmax(y_hat.cpu(), dim=1)

print(f"Classification Report for {ds_name}:")
# sklearn expects (y_true, y_pred); reversing them swaps precision and recall.
print(classification_report(y_test, preds,
                            labels=np.arange(len(le.classes_)),
                            target_names=le.classes_))

In [None]:
from sklearn.metrics import confusion_matrix

# Rows = true class, columns = predicted class (sklearn's (y_true, y_pred) order).
conf_matrix = confusion_matrix(y_test, preds, labels=np.arange(len(le.classes_)))
sns.heatmap(conf_matrix, annot = True)
plt.title(f"Confusion Matrix for CNN with {ds_name}")
plt.xlabel("Predicted")
plt.ylabel("Actual")

# Label the ticks with le.classes_. LabelEncoder numbers classes in sorted order,
# which differs from the order of selected_fruits.
plt.xticks(np.arange(len(le.classes_)) + 0.5, le.classes_, rotation=90);
plt.yticks(np.arange(len(le.classes_)) + 0.5, le.classes_, rotation=0);



In [None]:
# Source: https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82
def GrabGradCam(image, correct_label, vgg_model):
    """Overlay a Grad-CAM heatmap for ``correct_label`` on ``image``.

    Args:
        image: a single-image batch of shape ``(1, C, H, W)`` on the model's
            device, values presumably in [0, 1] (they are scaled by 255 for display).
        correct_label: class index whose logit is explained.
        vgg_model: a model exposing ``getFeatures(x)`` and
            ``get_activations_hook()``, whose forward pass records the
            gradient of its last convolutional activations.

    Returns:
        ``(H, W, 3)`` float array: a JET heatmap blended with the image,
        scaled so its maximum is 1.

    Raises:
        RuntimeError: if the model did not record an activation gradient.
    """
    vgg_model.eval()
    prediction = vgg_model(image)
    prediction[0][correct_label].backward()
    final_activations = vgg_model.get_activations_hook()
    if final_activations is None:
        raise RuntimeError(
            "Model did not record activation gradients; register "
            "activations_hook in its forward pass to use Grad-CAM."
        )
    # Channel importance: gradient averaged over batch and spatial dims.
    fa_mean = final_activations.mean(dim=[0, 2, 3])

    with torch.no_grad():
        features = vgg_model.getFeatures(image)[0]
    # Weight each feature map by its importance, then average over channels.
    heatmap = (features * fa_mean[:, None, None]).mean(dim=0)

    # ReLU, then normalise to [0, 1]; skip the division for an all-zero map
    # (it would otherwise produce NaNs).
    heatmap = heatmap.clamp(min=0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = heatmap.cpu().numpy()

    # cv2.resize takes dsize as (width, height), i.e. (W, H).
    heatmap = cv2.resize(heatmap, (image.shape[3], image.shape[2]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    og_img = image[0].cpu().detach().permute(1,2,0).numpy() * 255
    og_img = og_img.astype(np.uint8)
    superimposed_img = heatmap * 0.4 + og_img
    return superimposed_img/np.max(superimposed_img)

In [None]:
# Testing the GrabGradCam model:
test, label = next(iter(test_loader))
test_img = test[10].unsqueeze(0).to(device)
test_label = label[10].to(device)

In [None]:
model.eval()
plt.imshow(GrabGradCam(test_img, test_label, model))
pred = model(test_img).cpu()
pred_class = torch.argmax(pred, dim = 1)
pred_name = le.inverse_transform(pred_class.detach().numpy().reshape(1, -1))
actual_name = le.inverse_transform(test_label.cpu().detach().numpy().reshape(1, -1))
plt.title(f"Predicted: {pred_name} Actual: {actual_name}")

# Creating the WGAN For Fruits

In [None]:
# Steps:
# build Generator network
# build Discriminator network
# build Gradient Penalty Function
# build training function:

In [None]:
from tqdm.auto import tqdm
from torchvision.utils import make_grid

In [None]:
def show(tensor, num=25, wandbactive=0, name=''):
    """Plot the first ``num`` images of an NCHW batch as a grid, 5 per row.

    Pixel values are mapped from [-1, 1] (the generator's Tanh range) to
    [0, 1] and clipped before display. ``wandbactive`` and ``name`` are
    accepted but unused.
    """
    batch = tensor.detach().cpu()[:num]
    grid_img = make_grid(batch, nrow=5).permute(1, 2, 0)
    grid_img = (grid_img / 2 + 0.5).clip(0, 1)
    plt.imshow(grid_img)
    plt.show()

In [None]:
# WGAN hyper-parameters.
learning_rate_crit = 0.0002   # critic Adam learning rate
learning_rate_gen = 0.0001    # generator Adam learning rate
samples = 10000               # total images sampled for GAN training
img_size = (64,64)            # NOTE(review): not used by any visible cell
gan_reshaped = (64,64)        # resize target passed to getSamplesMat
bs = 128                      # DataLoader batch size
critic_cycles = 5             # critic updates per generator update
n_epochs = 70                 # overridden to 150 in the training cell
z_dim = 225                   # latent noise dimension

In [None]:
# Get Overall Set of fruits;
# change this to add more samples (preferable 10000)

fruits = getRandomFruitDirs(130)
fruits = fruits[fruits != '.ipynb_checkpoints']
dataset = getSamplesMat(samples, fruits, resize = gan_reshaped)
main = dataset[0]
labels = dataset[1]


In [None]:
# Fresh encoder for the GAN labels. NOTE: this replaces the CNN's `le`, so later
# le.inverse_transform calls map ids to these classes instead.
le = LabelEncoder()

In [None]:
# (n, 64*64*3) rows in [0, 1] -> (n, 64, 64, 3) images, rescaled to [-1, 1] to
# match the generator's Tanh output, then NHWC -> NCHW for PyTorch.
main = main.reshape(-1, 64, 64, 3).astype(np.float32)
main_norm = main * 2 - 1
main_trans = np.transpose(main_norm, (0, 3,1,2)).astype(np.float32, copy=False)
print(main.shape)
real_data = torch.from_numpy(main_trans)
labels = torch.tensor(le.fit_transform(labels)).long()

In [None]:
plt.imshow(main[2].reshape(64,64,3))

In [None]:
# Labels travel with the images, but the GAN training loop ignores them.
tensor_dataset = TensorDataset(real_data, labels)
dataloader = DataLoader(tensor_dataset, batch_size = bs, shuffle=True)

In [None]:
def GenNoise(num, z_dim):
    """Draw ``num`` standard-normal latent vectors shaped ``(num, z_dim, 1, 1)``, moved to ``device``."""
    shape = (num, z_dim)
    return torch.randn(*shape).reshape(*shape, 1, 1).to(device)


class GeneratorModel(nn.Module):
    """DCGAN-style generator: latent vector -> 3 x 64 x 64 image in [-1, 1].

    Args:
        n_dim: base channel width; the stages use 8x, 4x, 2x and 1x as many channels.
        z_dim: length of the latent noise vector.
    """

    def __init__(self, n_dim=16, z_dim = 100):
        super().__init__()
        self.z_dim = z_dim

        layers = [
            # z_dim x 1 x 1 -> (n_dim*8) x 4 x 4
            nn.ConvTranspose2d(z_dim, n_dim * 8, 4, 1, 0),
            nn.BatchNorm2d(n_dim * 8),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size and halves the channels: 4 -> 8 -> 16 -> 32.
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(n_dim * mult_in, n_dim * mult_out, 4, 2, 1),
                nn.BatchNorm2d(n_dim * mult_out),
                nn.ReLU(True),
            ]
        # n_dim x 32 x 32 -> 3 x 64 x 64, squashed into [-1, 1].
        layers += [nn.ConvTranspose2d(n_dim, 3, 4, 2, 1), nn.Tanh()]
        self.gen = nn.Sequential(*layers)

    def forward(self, noise):
        # Accept (N, z_dim) or (N, z_dim, 1, 1) noise; reshape to (N, z_dim, 1, 1).
        latent = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(latent)


In [None]:
class CriticModel(nn.Module):
    """WGAN-GP critic: 3 x 64 x 64 image -> one unbounded score per image.

    Args:
        n_dim: base channel width; the conv stages use 1x, 2x, 4x and 8x as many channels.

    Normalisation uses ``InstanceNorm2d(affine=True)`` instead of
    ``BatchNorm2d``. The gradient penalty is defined per sample, but batch
    normalisation makes each sample's score depend on the rest of the batch,
    so WGAN-GP critics should not use it.
    """

    def __init__(self, n_dim=16):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(3, n_dim, 4, 2, 1),  # 64x64 -> 32x32
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_dim, n_dim*2, 4, 2, 1),  # 32x32 -> 16x16
            nn.InstanceNorm2d(n_dim*2, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_dim*2, n_dim*4, 4, 2, 1),  # 16x16 -> 8x8
            nn.InstanceNorm2d(n_dim*4, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_dim*4, n_dim*8, 4, 2, 1),  # 8x8 -> 4x4
            nn.InstanceNorm2d(n_dim*8, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_dim*8, 1, 4, 1, 0)  # 4x4 -> 1x1
        )

    def forward(self, image):
        # (N, 1, 1, 1) -> (N, 1)
        return self.disc(image).view(len(image), -1)


# Instantiate the generator and critic models:

In [None]:
# Build both networks (the generator uses the z_dim chosen above) and move them to `device`.
gen_model = GeneratorModel(z_dim = z_dim)
gen_model = gen_model.to(torch.device(device))
crit_model = CriticModel()
crit_model = crit_model.to(torch.device(device))

# Defining the Gradient Penalty function

In [None]:
# Constructing the gradient penalty function
# The gradient penalty is used as a regularizer term in the Wasserstein loss
def get_gp(real, fake, crit, device, gamma = 10):
    """Compute the WGAN-GP gradient penalty ``mean((||grad crit(x_hat)||_2 - 1)^2)``.

    ``x_hat`` are random per-sample interpolations between ``real`` and
    ``fake``. The penalty is returned unscaled; the caller multiplies it by
    its own coefficient.

    Args:
        real: batch of real samples, shape ``(N, ...)``.
        fake: batch of generated samples with the same shape. It is detached
            here, so the penalty only produces gradients for the critic.
        crit: critic mapping a batch to per-sample scores.
        device: device for the random interpolation weights.
        gamma: unused; kept for backward compatibility (the training loop
            applies its own coefficient).

    Returns:
        Scalar tensor that can be back-propagated into the critic's parameters.
    """
    cur_bs = real.shape[0]
    # One interpolation weight per sample, broadcast over all remaining dims.
    alpha = torch.rand(cur_bs, *([1] * (real.dim() - 1)), device=device)
    inter_img = alpha * real + (1 - alpha) * fake.detach()
    inter_img.requires_grad_(True)
    predictions = crit(inter_img)
    model_gradient = torch.autograd.grad(
        inputs = inter_img,
        outputs = predictions,
        grad_outputs = torch.ones_like(predictions),
        retain_graph = True,
        # Keep the graph so the penalty itself can be back-propagated.
        create_graph = True,
    )[0]

    model_gradient = model_gradient.reshape(len(model_gradient), -1)
    gradient_norm = model_gradient.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()

In [None]:
# Construct the optimizer functions:
gen_optimizer = torch.optim.Adam(gen_model.parameters(),lr=learning_rate_gen, betas = (0.5, 0.99))
crit_optimizer = torch.optim.Adam(crit_model.parameters(),lr=learning_rate_crit, betas = (0.5, 0.99))

In [None]:
# Train Function:
# WGAN-GP training: `critic_cycles` critic updates per generator update.
n_epochs = 150
gen_loss_hist = []
cur_step = 0
gamma = 10  # gradient-penalty coefficient
crit_loss_hist = []
for epochi in range(n_epochs):
    gen_loss_batch = []
    crit_loss_batch = []
    for real, _ in dataloader:
        real = real.to(device)
        cur_bs = len(real)

        # --- Critic updates ---
        for _ in range(critic_cycles):
            # Fresh fakes for every critic step. They are generated without an
            # autograd graph so the critic loss cannot back-propagate into (or
            # accumulate gradients on) the generator, and no retain_graph is needed.
            with torch.no_grad():
                fake = gen_model(GenNoise(cur_bs, z_dim))
            fake_preds = crit_model(fake).mean()
            real_preds = crit_model(real).mean()
            gp = get_gp(real, fake, crit_model, device)

            # Wasserstein critic loss plus gradient penalty
            crit_loss = fake_preds - real_preds + gamma * gp
            crit_loss_batch.append(crit_loss.item())

            crit_optimizer.zero_grad()
            crit_loss.backward()
            crit_optimizer.step()

        # --- Generator update ---
        noise = GenNoise(cur_bs, z_dim)
        fake = gen_model(noise)
        gen_loss = -crit_model(fake).mean()

        gen_optimizer.zero_grad()
        gen_loss.backward()
        gen_optimizer.step()

        gen_loss_batch.append(gen_loss.item())
        cur_step += 1
        if cur_step % 100 == 0:
            # Preview only; no autograd graph needed.
            with torch.no_grad():
                sample_image = gen_model(GenNoise(cur_bs, z_dim))
            show(real)
            show(sample_image)

    gen_loss_hist.append(np.mean(gen_loss_batch))
    crit_loss_hist.append(np.mean(crit_loss_batch))





In [None]:
# Per-epoch mean losses from the training loop; the critic curve includes the
# gamma-weighted gradient penalty.
plt.plot(gen_loss_hist, label="Generator Loss")
plt.plot(crit_loss_hist, label = "Critic Loss")
plt.legend()
plt.title("GAN Generator and Critic Loss (Wasserstien Loss)")
plt.xlabel("Epoch")
plt.ylabel("Loss")


In [None]:
# Generator Test:
gan_noise = GenNoise(10, z_dim)
sample_image = gen_model(gan_noise)
print(sample_image.shape)
targ_img = sample_image.cpu().permute(0, 2, 3, 1).detach().numpy()
targ_img = targ_img/2 + 0.5
plt.imshow(targ_img[1])

In [None]:
# Re-create fresh models with a 250-dimensional latent space and train them again.
# NOTE(review): no `train` function is defined in this notebook as shown (the
# training loop above is inline), so this call raises NameError unless `train`
# is defined elsewhere. The arguments 250 and 10 presumably mean z_dim and an
# epoch/critic count; confirm against the intended definition. Also note that
# gen_optimizer / crit_optimizer still point at the previous models' parameters.
gen_model = GeneratorModel(z_dim=250)
crit_model = CriticModel()

gen_model = gen_model.to(device)
crit_model = crit_model.to(device)
train(gen_model, crit_model, dataloader, 250, 10)