In [4]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
try:
    import google.colab
    IN_COLAB = True
    %pip install wandb
    %pip install --upgrade "kaleido==0.1.*"
    import kaleido
except:
    IN_COLAB = False
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.utils.data as data_utils
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from torchvision.datasets import MNIST
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
import plotly.io as pio
import tempfile
from PIL import Image
import io
import wandb
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time
from enum import Enum
import plotly.express as px

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

# Number of label classes. It sizes the one-hot label that is concatenated to the
# generator noise and the class head of the discriminator.
NUM_CLASSES = 10

class TempFileContext:
    """Context manager that yields the path of a throw-away ``.jpeg`` file.

    The file is created on entry and removed on exit. Only the path is handed
    out; callers open the file themselves by name.
    """
    def __enter__(self):
        self.tmp_file = tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False)
        self.tmp_filename = self.tmp_file.name
        # Release our own handle immediately: the caller reopens the file by
        # name, which fails on platforms that forbid a second open (Windows).
        self.tmp_file.close()
        return self.tmp_filename
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.tmp_file.close()  # harmless when the handle is already closed
        try:
            os.remove(self.tmp_filename)
        except FileNotFoundError:
            # The caller already deleted or moved the file; nothing to clean up.
            pass

# DCGAN recipe: every weight is drawn from a zero-centred normal distribution with standard deviation 0.02.
def initialize_weights(model):
    """Re-draw, in place, the weights of the Conv2d / ConvTranspose2d / BatchNorm2d layers of *model*.

    Layers of any other type (e.g. Linear, BatchNorm1d) are left as they are.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    matching_layers = filter(lambda layer: isinstance(layer, target_types), model.modules())
    for layer in matching_layers:
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

class GeneratorUpSample(nn.Module):
    def __init__(self, channels_noise, channels_img, num_classes, features_g):
        """
        Conditional generator built from one transposed convolution followed by Conv2d + bilinear Upsample stages.

        channels_noise: The size of the input noise vector. This noise vector is a random input from which the generator begins the generation of a new sample.
        channels_img: The number of output channels of the generator. This will typically be 1 for grayscale images or 3 for color (RGB) images.
        num_classes: The number of distinct classes or labels that the generator should generate images for. This is used to form the one-hot vector of class labels, which is concatenated to the noise vector to provide the generator with information about the class of image to generate.
        features_g: Base channel width; the hidden stages use features_g * 4, features_g * 2 and features_g channels.
        """
        # Conv2d          => output_size = floor((input_size + 2 * padding - kernel_size) / stride) + 1
        # ConvTranspose2d => output_size = (input_size - 1) * stride - 2 * padding + kernel_size
        # The shape comments below assume an input of shape (N, channels_noise + num_classes, 1, 1).
        # Zero-argument super() binds to this class; naming a different class here raises TypeError on construction.
        super().__init__()
        self.gen = nn.Sequential(
            self.gen_block(channels_noise + num_classes, features_g * 4, kernel_size=7, stride=1, padding=0),  # output: (features_g*4) x 7 x 7 # Class labels are appended to the input noise.
            nn.Dropout(p=0.05),
            self._block(features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1),  # output: (features_g*2) x 3 x 3
            nn.Upsample(scale_factor=4, mode='bilinear'),  # output: (features_g*2) x 12 x 12
            nn.Dropout(p=0.05),
            self._block(features_g * 2, features_g, kernel_size=4, stride=2, padding=1),  # output: features_g x 6 x 6
            nn.Upsample(scale_factor=5, mode='bilinear'),  # output: features_g x 30 x 30
            nn.Dropout(p=0.05),
            # NOTE(review): _block ends in ReLU, so the Tanh below only ever sees non-negative values and the output lies in [0, 1).
            self._block(features_g, channels_img, kernel_size=7, stride=1, padding=2),  # output: channels_img x 28 x 28
            nn.Tanh(),
        )
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv2d (no bias) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def gen_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    @autocast() # automatically applies precisions to different operations to speed up calculations
    def forward(self, noise):
        """Map a (N, channels_noise + num_classes, 1, 1) input to (N, channels_img, 28, 28) images."""
        return self.gen(noise)

class GeneratorFractional(nn.Module):
    def __init__(self, channels_noise, channels_img, num_classes, features_g):
        """
        Conditional generator built only from transposed (fractionally strided) convolutions.

        channels_noise: The size of the input noise vector. This noise vector is a random input from which the generator begins the generation of a new sample.
        channels_img: The number of output channels of the generator. This will typically be 1 for grayscale images or 3 for color (RGB) images.
        num_classes: The number of distinct classes or labels that the generator should generate images for. This is used to form the one-hot vector of class labels, which is concatenated to the noise vector to provide the generator with information about the class of image to generate.
        features_g: Base channel width; the hidden stages use features_g * 8, * 4, * 2 and * 1 channels.
        """
        # ConvTranspose2d => output_size = (input_size - 1) * stride - 2 * padding + kernel_size
        # The shape comments below assume an input of shape (N, channels_noise + num_classes, 1, 1).
        # Zero-argument super() binds to this class; naming a different class here raises TypeError on construction.
        super().__init__()
        self.gen = nn.Sequential(
            self._block(channels_noise + num_classes, features_g * 8, kernel_size=4, stride=2, padding=1),  # output: (features_g*8) x 2 x 2
            self._block(features_g * 8, features_g * 4, kernel_size=4, stride=2, padding=1),  # output: (features_g*4) x 4 x 4
            self._block(features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1),  # output: (features_g*2) x 8 x 8
            self._block(features_g * 2, features_g, kernel_size=4, stride=2, padding=1),  # output: features_g x 16 x 16
            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=3),  # output: channels_img x 28 x 28
            nn.Tanh(),  # squash the output to [-1, 1]
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    @autocast() # automatically applies precisions to different operations to speed up calculations
    def forward(self, noise):
        """Map a (N, channels_noise + num_classes, 1, 1) input to (N, channels_img, 28, 28) images."""
        return self.gen(noise)

class GeneratorConv(nn.Module):
    """
    Generator that projects the input with a dense layer onto a 7 x 7 feature map and then
    grows it with nearest-neighbour Upsample + Conv2d blocks (no PixelShuffle layer is involved).
    """
    def __init__(self, channels_noise, channels_img, num_classes, features_g):
        # Zero-argument super() binds to this class; naming a different class here raises TypeError on construction.
        super().__init__()
        # In GANs, we start from a random noise vector in a latent space, but we want to generate 2D images.
        # This transformation from a 1D noise vector to a 3D tensor is typically done using a dense layer,
        # which learns to map the latent space to the space of images during training. The capacity of this
        # layer can be adjusted via the number of neurons to control the complexity of the generated images.
        self.initial = nn.Sequential(
            nn.Linear(channels_noise + num_classes, features_g * 4 * 7 * 7, bias=False),
            nn.BatchNorm1d(features_g * 4 * 7 * 7),
            nn.ReLU()
        )
        # Conv2d => output_size = floor((input_size + 2 * padding - kernel_size) / stride) + 1
        self.gen = nn.Sequential(
            self._block(features_g * 4, features_g * 2, upscale_factor=2),  # 7 -> 14 (upsample) -> 13 (conv): (features_g*2) x 13 x 13
            self._block(features_g * 2, features_g, upscale_factor=2),  # 13 -> 26 (upsample) -> 25 (conv): features_g x 25 x 25
            nn.Conv2d(features_g, channels_img, kernel_size=4, stride=1, padding=3),  # output: channels_img x 28 x 28
            # No Tanh here: the output is unbounded.
        )

    def _block(self, in_channels, out_channels, upscale_factor):
        """Nearest-neighbour Upsample -> Conv2d (k=4, s=1, p=1, no bias) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=upscale_factor, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, noise):
        """Flatten the input per sample, project it to (N, features_g * 4, 7, 7) and upsample to 28 x 28."""
        noise = noise.view(noise.shape[0], -1)
        x = self.initial(noise)
        x = x.view(x.shape[0], -1, 7, 7)  # reshape into (batch_size, features_g * 4, 7, 7)
        return self.gen(x)

class DiscriminatorConv(nn.Module):
    def __init__(self, channels_img, num_classes, num_kernels, kernel_dim, filters):
        """
        Convolutional discriminator with a minibatch-discrimination layer and a joint validity / class head.

        channels_img: The number of input channels to the discriminator, corresponding to the number of channels in the images to be classified.
        num_classes: The number of distinct classes that the discriminator should be able to distinguish between. The final layer has 1 + num_classes outputs: one validity score plus one score per class.
        num_kernels and kernel_dim: These are parameters for the minibatch discrimination layer. The minibatch discrimination layer is designed to make the discriminator sensitive to the variety of samples within a minibatch, to encourage the generator to generate a variety of different samples. num_kernels is the number of features the layer appends, and kernel_dim is the size of each learned kernel.
        filters: Base number of convolution channels; the second convolution uses filters * 2 and the hidden dense layer filters * 8 units.
        """
        # Zero-argument super() binds to this class; naming a different class here raises TypeError on construction.
        super().__init__()
        # Two stride-2 convolutions take a 28 x 28 input to 14 x 14 and then 7 x 7,
        # which is where the filters*2*7*7 flattened size below comes from.
        self.conv = nn.Sequential(
            nn.Conv2d(channels_img, filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=0.05),
            self._block(filters, filters*2, 4, 2, 1),
            nn.Dropout(p=0.05),
        )
        self.mbd = MinibatchDiscrimination(filters*2*7*7, num_kernels, kernel_dim)
        self.fc = nn.Sequential(
            nn.Linear(filters*2*7*7 + num_kernels, filters*8),  # the mbd layer appends num_kernels features
            nn.LeakyReLU(0.2),
            nn.Linear(filters*8, 1 + num_classes),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
    @autocast()
    def forward(self, x):
        """Return ``(validity, class_probs)``.

        ``validity`` is the raw first output per sample, shape (N,); ``class_probs`` is the
        softmax over the remaining num_classes outputs, shape (N, num_classes).
        """
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.mbd(x)
        out = self.fc(x)
        return out[:, 0], nn.functional.softmax(out[:, 1:], dim=1)

class Generator(nn.Module):
    """Fully connected conditional generator producing 28 x 28 images.

    Three Linear (no bias) -> BatchNorm1d -> ReLU stages widen the input to
    features_g * 4 * 7 * 7, features_g * 2 * 14 * 14 and features_g * 28 * 28 units;
    a final biased Linear layer emits channels_img * 28 * 28 unbounded values.
    """
    def __init__(self, channels_noise, channels_img, num_classes, features_g):
        super().__init__()
        widths = [
            channels_noise + num_classes,
            features_g * 4 * 7 * 7,
            features_g * 2 * 14 * 14,
            features_g * 28 * 28,
        ]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out, bias=False),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ]
        # No Tanh at the end: the output is not squashed to [-1, 1].
        layers.append(nn.Linear(widths[-1], channels_img * 28 * 28))
        self.net = nn.Sequential(*layers)
    @autocast()
    def forward(self, noise):
        """Flatten *noise* per sample and return a (N, channels_img, 28, 28) tensor."""
        batch = noise.shape[0]
        flat = noise.view(batch, -1)
        pixels = self.net(flat)
        return pixels.view(batch, -1, 28, 28)

class Discriminator(nn.Module):
    """Fully connected discriminator with minibatch discrimination and a joint validity / class head.

    The final layer has 1 + num_classes outputs: column 0 is the validity score,
    the remaining columns are turned into class probabilities with a softmax.
    """
    def __init__(self, channels_img, num_classes, num_kernels, kernel_dim, filters):
        super().__init__()
        pixels = channels_img * 28 * 28
        hidden_wide = filters * 2 * 14 * 14
        hidden_narrow = filters * 4 * 7 * 7
        head_width = filters * 8
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(pixels, hidden_wide),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=0.05),
            nn.Linear(hidden_wide, hidden_narrow),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=0.05),
            # Appends num_kernels extra features to every sample.
            MinibatchDiscrimination(hidden_narrow, num_kernels, kernel_dim),
            nn.Linear(hidden_narrow + num_kernels, head_width),
            nn.LeakyReLU(0.2),
            nn.Linear(head_width, 1 + num_classes),
        )
    @autocast()
    def forward(self, x):
        """Return ``(validity, class_probs)`` with shapes (N,) and (N, num_classes)."""
        flat = x.view(x.size(0), -1)
        scores = self.net(flat)
        validity = scores[:, 0]
        class_probs = nn.functional.softmax(scores[:, 1:], dim=1)
        return validity, class_probs

class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination layer (Salimans et al., "Improved Techniques for Training GANs").

    Appends ``num_kernels`` features to every sample; each feature measures how close
    that sample is to all samples of the same batch under one learned kernel.

    input_features: width A of the incoming (N, A) feature matrix.
    num_kernels: number B of appended features.
    kernel_dim: size C of each kernel's projection.
    """
    def __init__(self, input_features, num_kernels, kernel_dim):
        super(MinibatchDiscrimination, self).__init__()
        self.input_features = input_features
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        self.T = nn.Parameter(torch.randn(input_features, num_kernels * kernel_dim))
    def forward(self, x):
        """Return ``x`` with the minibatch features appended: (N, A) -> (N, A + B)."""
        # Project every sample: (N, A) @ (A, B*C) -> (N, B, C)
        M = torch.matmul(x, self.T).view(-1, self.num_kernels, self.kernel_dim)
        # Pairwise differences BETWEEN SAMPLES: diffs[i, j] = M[j] - M[i], shape (N, N, B, C).
        # (Broadcasting against the kernel axis instead would compare kernels within
        # one sample and make the features independent of the rest of the batch.)
        diffs = M.unsqueeze(0) - M.unsqueeze(1)
        # L1 distance over the kernel dimension C -> (N, N, B)
        abs_diffs = torch.sum(torch.abs(diffs), dim=3)
        # o(x_j)_b = sum_i exp(-||M_i,b - M_j,b||_1), the i == j term included -> (N, B)
        minibatch_features = torch.sum(torch.exp(-abs_diffs), dim=0)
        return torch.cat((x, minibatch_features), dim=1)

class LR_Metric(Enum):
    """Criterion LearningRateScheduler uses to choose which replay-buffer entries to free."""
    # Rank entries by the validity score the discriminator currently gives the replayed fake.
    VALIDITY = 1
    # Rank entries by age; selecting this currently raises NotImplementedError.
    AGE = 2
    
class CustomDataLoader:
    """Minimal batch iterator over tensors that are moved to *device* once, up front.

    Every ``iter()`` call draws a fresh random permutation; batches are
    ``(data, targets)`` pairs and the last batch may be shorter than ``batch_size``.
    """
    def __init__(self, X, Y, batch_size, device):
        self.batch_size, self.device = batch_size, device
        self.data = X.float().to(device)
        self.targets = Y.to(device)
        self.num_samples = len(self.data)
    def __iter__(self):
        # Restart the cursor and reshuffle for the new pass.
        self.idx = 0
        self.indices = torch.randperm(self.num_samples, device=self.device)
        return self
    def __next__(self):
        start = self.idx
        if start >= self.num_samples:
            raise StopIteration
        stop = start + self.batch_size
        chosen = self.indices[start:stop]
        self.idx = stop
        return self.data[chosen], self.targets[chosen]
    def __len__(self):
        # Number of batches per pass, counting a trailing partial batch.
        whole, leftover = divmod(self.num_samples, self.batch_size)
        return whole + 1 if leftover else whole

# Number of pixels in one 28 x 28 image; not referenced by the rest of the visible code.
image_size = 28 * 28
@dataclass
class Config:
    """Hyper-parameters and logging switches for the GAN training script."""
    # Length of the random noise vector (the one-hot label is appended to it afterwards).
    latent_dim: int = 100
    # Samples per training/test batch.
    batch_size: int = 256 * 2
    # Number of passes over the training set.
    num_epochs: int = 100
    # Minibatch-discrimination parameters passed to the discriminator.
    num_kernels: int = 10
    kernel_dim: int = 3
    # Adam learning rates for the discriminator and the generator.
    d_learning_rate: float = 0.0002
    g_learning_rate: float = 0.0002
    # The discriminator's cosine schedule restarts every int(num_epochs / lr_restarts) epochs.
    lr_restarts: int = 5
    # Lower bound (eta_min) of the discriminator's cosine schedule.
    min_lr: float = 1e-10
    # Weight of the class loss inside the generator loss.
    lambda_class: int = 1
    # Number of slots in the replay buffer kept by LearningRateScheduler.
    replay_buffer_size: int = 1000
    # Width multipliers handed to the generator and to the discriminator (as `filters`).
    features_g: int = 4
    features_d: int = 4
    # True: buffer metrics/images and send them to wandb after training; False: log while training.
    logEnd: bool = True
    # True: (x - mean) / std with training-set statistics; False: rescale pixels to [-1, 1].
    standardization: bool = True
c = Config(logEnd=False)

# Raw MNIST splits; `.data` holds the uint8 images, `.targets` the labels.
train_data = MNIST(root='data/MNIST',train=True,download=True)
test_data = MNIST(root='data/MNIST', train=False, download=True)

if c.standardization:
    # Standardize both splits with the TRAINING statistics so they share one scale.
    mean = train_data.data.float().mean()
    std = train_data.data.float().std()
    normalized_train_data = (train_data.data.float() - mean) / std
    normalized_test_data = (test_data.data.float() - mean) / std
else:
    # Rescale pixel values from [0, 255] to [-1, 1].
    normalized_train_data = 2 * train_data.data.float() / 255 - 1
    # The test tensor must come from test_data; building it from train_data would pair
    # training images with test labels downstream.
    normalized_test_data = 2 * test_data.data.float() / 255 - 1

# wandb's documented config input is a plain mapping. dataclasses.asdict already
# converts nested dataclasses recursively, so no per-field handling is needed.
preppedConfig = dataclasses.asdict(c)
wandb.init(project="mnist-gan", config=preppedConfig)

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    # cuDNN offers several convolution algorithms whose speed depends on the layer configuration.
    # With benchmark mode on, the auto-tuner times them and keeps the fastest one.
    # This pays off when input sizes rarely change; if they change often, the repeated
    # benchmarking can hurt performance instead.

# Both loaders hold their whole split on `device` and reshuffle on every pass.
train_loader = CustomDataLoader(normalized_train_data, train_data.targets, batch_size=c.batch_size, device=device)
test_loader = CustomDataLoader(normalized_test_data, test_data.targets, batch_size=c.batch_size, device=device)

# Per-epoch log buffers: one slot per epoch for every metric name.
logsPerEpochNames = ['age', 'curGap', 'oldGap', 'oldScore', 'replayScore', 'replayValidity', 'oldValidity', 'accuracy', 'd_lr','g_lr']
logsPerEpoch = {k: torch.zeros(c.num_epochs).to(device) for k in logsPerEpochNames}

# NOTE(review): initialize_weights only touches Conv2d / ConvTranspose2d / BatchNorm2d layers,
# so for the all-Linear Generator and Discriminator used here it leaves every weight unchanged.
generator = Generator(c.latent_dim, 1, NUM_CLASSES, c.features_g).to(device)
initialize_weights(generator)
discriminator = Discriminator(1, NUM_CLASSES, c.num_kernels, c.kernel_dim, c.features_d).to(device)
initialize_weights(discriminator)

# One Adam optimizer per network.
generator_optimizer = optim.Adam(generator.parameters(), lr=c.g_learning_rate, betas=(0.5, 0.9))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=c.d_learning_rate, betas=(0.5, 0.9))

# Number of batches per epoch; used to turn (epoch, batch) into a global step.
train_batches = len(train_loader)

# Per-batch log buffers: one slot per global step for every metric name.
logsPerBatchNames = ['real_validity', 'fake_validity', 'd_fakeClassLoss', 'd_realClassLoss', 'd_fakeAccuracy', 'd_realAccuracy', 'd_loss_base', 'g_loss_base', 'g_fakeClassLoss']
logsPerBatch = {k: torch.zeros(c.num_epochs * train_batches, device=device) for k in logsPerBatchNames}
# (jpeg bytes, step) pairs collected when image logging is deferred to the end of training.
t_images = []

class LearningRateScheduler:
    """Replay buffer of generator inputs plus the validity scores seen when they were stored.

    Despite its name, this class never touches an optimizer; ``initial_lr`` is only stored.
    Every batch, ``fillReplayBuffer`` stores a few generator inputs together with that
    batch's discriminator scores. Once per epoch, ``update_learning_rate`` regenerates the
    stored inputs, re-scores them, writes summary statistics into the module-level
    ``logsPerEpoch`` and frees the slots whose replayed fakes score highest.

    Relies on the module-level ``device``, ``c``, ``NUM_CLASSES`` and ``logsPerEpoch``.
    """
    def __init__(self, initial_lr, replay_buffer_size, total_batches, batch_size, METRIC=LR_Metric.VALIDITY):
        self.initial_lr = initial_lr
        self.replay_buffer_size = replay_buffer_size
        self.total_batches = total_batches
        self.batch_size = batch_size
        self.METRIC = METRIC
        # Per-batch quota, so the whole epoch contributes to the buffer rather than its first batches.
        self.samplesPerBatch = int(np.ceil(replay_buffer_size / total_batches))
        # Next never-used slot during the initial front-to-back fill.
        self.filledIndex = 0
        # Batch-mean fake / real validity at the time each slot was stored.
        self.oldFake_validity = torch.zeros(self.replay_buffer_size, device=device, dtype=torch.float16)
        self.oldReal_validity = torch.zeros(self.replay_buffer_size, device=device, dtype=torch.float16)
        # Validity of the individual stored fake at the time it was stored.
        self.oldFake_validities = torch.zeros(self.replay_buffer_size, device=device, dtype=torch.float16)
        # Stored generator inputs (noise followed by the one-hot label).
        self.z_replay = torch.zeros(self.replay_buffer_size, c.latent_dim + NUM_CLASSES, device=device)
        # Number of update_learning_rate calls each slot has lived through.
        self.age = torch.zeros(self.replay_buffer_size, device=device)
        # Fraction of replay_buffer_size that is freed at every update.
        self.kickTopPercent = 0.25
        # 1 marks a free slot, 0 an occupied one.
        self.openIndexes = torch.ones(self.replay_buffer_size, device=device)
        # Running sums over every sample passed to fillReplayBuffer.
        # NOTE(review): these are never reset, so `curGap` is an average over the whole
        # run so far, not over the current epoch -- confirm that this is intended.
        self.real_validity_total = torch.zeros(1, device=device)
        self.fake_validity_total = torch.zeros(1, device=device)
        self.numSamples = 0

    def fillReplayBuffer(self, real_validity, real_validities, fake_validity, fake_validities, z):
        """
        Store up to ``samplesPerBatch`` randomly chosen rows of ``z`` in free buffer slots.

        Samples (amounting to replay_buffer_size) are provided evenly by all batches, so the
        buffer fills up within one epoch. ``real_validity`` / ``fake_validity`` are the batch
        means, ``real_validities`` / ``fake_validities`` the per-sample scores and ``z`` the
        (N, latent_dim + NUM_CLASSES) generator inputs of the batch.
        """
        # `with a and b:` would only enter `b`; the comma form enters both contexts, so
        # the running sums below do not become part of the autograd graph.
        with torch.no_grad(), torch.cuda.amp.autocast():
            self.real_validity_total += real_validities.sum()
            self.fake_validity_total += fake_validities.sum()
            numSamples = len(z)
            self.numSamples += numSamples
            if self.filledIndex < self.replay_buffer_size:
                # Initial fill, front to back. Only samplesPerBatch entries per batch, to
                # prevent the early batches from dominating the replay buffer.
                remaining = self.replay_buffer_size - self.filledIndex
                numSelected = int(min(remaining, numSamples, self.samplesPerBatch))
                selected = np.random.choice(numSamples, numSelected, replace=False)
                indexes = torch.arange(self.filledIndex, self.filledIndex + numSelected, device=self.openIndexes.device)
                self.filledIndex += numSelected
            else:
                # Refill the slots freed by update_learning_rate.
                openSlots = torch.nonzero(self.openIndexes).flatten()
                openings = openSlots.numel()
                if openings == 0:
                    return
                numSelected = int(min(openings, numSamples, self.samplesPerBatch))
                selected = np.random.choice(numSamples, numSelected, replace=False)
                # Pick random positions among the open slots, then map them back to the
                # real buffer indexes.
                picked = np.random.choice(openings, numSelected, replace=False)
                indexes = openSlots[torch.as_tensor(picked, device=openSlots.device)]
            selected = torch.as_tensor(selected, device=z.device)
            # Indexed assignment requires matching dtypes, so cast to the buffers' dtype.
            self.oldFake_validity[indexes] = fake_validity.to(self.oldFake_validity.dtype)
            self.oldReal_validity[indexes] = real_validity.to(self.oldReal_validity.dtype)
            self.oldFake_validities[indexes] = fake_validities.reshape(-1)[selected].to(self.oldFake_validities.dtype)
            self.z_replay[indexes, :] = z[selected].to(self.z_replay.dtype)
            self.age[indexes] = 0
            self.openIndexes[indexes] = 0

    def plotReplayValidities(self):
        """Show overlaid histograms of the stored batch-mean real and fake validity scores."""
        i_replays = (self.openIndexes == 0).nonzero().flatten()
        fig = go.Figure()
        fig.add_trace(go.Histogram(x=self.oldReal_validity[i_replays].cpu().numpy(), name="real"))
        fig.add_trace(go.Histogram(x=self.oldFake_validity[i_replays].cpu().numpy(), name="fake"))
        fig.update_layout(barmode='overlay', title="saved validity scores histogram")
        fig.show()

    def update_learning_rate(self, epoch, d, g):
        """Re-score the stored inputs, log statistics for ``epoch`` and free the top-scoring slots.

        ``g`` and ``d`` are called as they are (their train/eval mode is not changed here).
        No learning rate is modified.
        """
        with torch.no_grad():
            # flatten() keeps a 1-D index tensor even when exactly one slot is occupied.
            i_replays = (self.openIndexes == 0).nonzero().flatten()
            if i_replays.numel() == 0:
                # Nothing stored yet: there is nothing to replay or to evict.
                self.age += 1
                return
            z_replay = self.z_replay[i_replays].view(-1, c.latent_dim + NUM_CLASSES, 1, 1)
            fake_replay = g(z_replay)
            replayFake_validities, _ = d(fake_replay)
            replayFake_validities = replayFake_validities.reshape(-1)

            # A negative gap means fake images receive higher validity scores than real ones.
            oldGap = (self.oldReal_validity[i_replays] - self.oldFake_validity[i_replays]).mean()
            curGap = ((self.real_validity_total - self.fake_validity_total) / self.numSamples).squeeze()

            replayScores = replayFake_validities - curGap
            oldScores = self.oldFake_validities[i_replays] - oldGap
            values = [self.age[i_replays].mean(), curGap, oldGap, replayScores.mean(), oldScores.mean(), replayFake_validities.mean(), self.oldFake_validities[i_replays].mean()]
            for name, value in zip(['age', 'curGap', 'oldGap', 'replayScore', 'oldScore', 'replayValidity', 'oldValidity'], values):
                logsPerEpoch[name][epoch] = value

            if self.METRIC.value == LR_Metric.VALIDITY.value:
                metric = replayFake_validities
            elif self.METRIC.value == LR_Metric.AGE.value:
                raise NotImplementedError("needs to be adjusted")
            else:
                raise Exception("Invalid metric")
            # Free the highest-ranked entries: kickTopPercent of the buffer size, capped at
            # the number of occupied slots.
            numKick = min(int(np.ceil(self.kickTopPercent * self.replay_buffer_size)), metric.numel())
            if numKick > 0:
                # argsort yields positions within `metric`; map them through i_replays to
                # get the actual buffer slots.
                highest = torch.argsort(metric)[-numKick:]
                self.openIndexes[i_replays[highest]] = 1

            self.age += 1

# Replay-buffer tracker. Despite the class name it does not change any learning rate;
# initial_lr is only stored on the object.
lr_scheduler_trial = LearningRateScheduler(initial_lr=0.001, replay_buffer_size=c.replay_buffer_size, total_batches=train_batches, batch_size=c.batch_size)
# Cosine annealing with warm restarts for the discriminator only: a restart every
# int(num_epochs / lr_restarts) epochs, annealing down to min_lr.
d_lr_scheduler = lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0=int(c.num_epochs/c.lr_restarts), T_mult=1, eta_min=c.min_lr)

# NOTE(review): the discriminators in this file return softmax probabilities for the class
# head, while CrossEntropyLoss expects raw logits -- confirm whether the extra softmax is intended.
classCriterion = nn.CrossEntropyLoss()
# Least-squares (LSGAN-style) loss on the validity scores.
criterion = nn.MSELoss()

def createGridFakeImages(epoch=0, cubeSide=3, show=False, step=None, log=True):
    """Generate a cubeSide x cubeSide grid of fake images with random labels and show/log it.

    epoch: only used in the figure title.
    cubeSide: number of images per row and per column.
    show: display the plotly figure.
    step: wandb step to log at; required when ``log`` is true.
    log: render the figure to JPEG and either log it to wandb right away (``c.logEnd`` false)
        or queue ``(bytes, step)`` in ``t_images`` for logging after training.

    Raises Exception when ``log`` is true and ``step`` is None.
    Uses the module-level ``generator``, ``device``, ``c``, ``NUM_CLASSES`` and ``t_images``.
    """
    # Fail before doing any work rather than after generating and rendering.
    if log and step is None:
        raise Exception("step must be provided when logging an image")
    numImages = cubeSide ** 2
    # Random labels and their one-hot encoding.
    labels = torch.randint(0, NUM_CLASSES, (numImages,), device=device)
    labels_one_hot = torch.zeros(numImages, NUM_CLASSES, device=device).scatter_(1, labels.view(numImages, 1), 1)
    with torch.no_grad():
        z = torch.randn(numImages, c.latent_dim, device=device)
        g_input = torch.cat((z, labels_one_hot), dim=1)
        g_input = g_input.view(numImages, c.latent_dim + NUM_CLASSES, 1, 1)
        fake_images = generator(g_input)
    fig = make_subplots(rows=cubeSide, cols=cubeSide,
                        horizontal_spacing=0.025,
                        vertical_spacing=0.04,
                        subplot_titles=[str(label) for label in labels.tolist()])
    # Drop the channel axis explicitly; a bare squeeze() would also drop the batch axis
    # when there is a single image.
    fake_images = fake_images[:, 0].float().cpu().numpy()
    for i, image in enumerate(fake_images):
        row, col = divmod(i, cubeSide)
        # Heatmaps draw row 0 at the bottom, so flip vertically to show the image upright.
        fig.add_trace(go.Heatmap(z=np.flip(image, 0),
                                 colorscale='Greys',), row=row + 1, col=col + 1)
    fig.update_layout(title_text="Generated Images epoch: " + str(epoch),
                      margin=dict(l=0, r=0, t=60, b=0),
                      height=400, width=400, showlegend=False)
    fig.update_traces(showscale=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    if show:
        fig.show()
    if log:
        # Convert the figure to a JPEG image.
        image_bytes = pio.to_image(fig, format='jpeg')
        if not c.logEnd:
            with TempFileContext() as tmp_filename:
                with open(tmp_filename, 'wb') as tmp_file:
                    tmp_file.write(image_bytes)
                # wandb.log needs a dict; the key matches the one used by wandbLogAtEnd.
                wandb.log({'generator_output': wandb.Image(tmp_filename)}, step=step)
        else:
            t_images.append((image_bytes, step))

def check_frozen_parameters(model):
    """Print one line for every parameter of *model* whose ``requires_grad`` is False."""
    frozen_names = [label for label, tensor in model.named_parameters() if not tensor.requires_grad]
    for label in frozen_names:
        print(f"Parameter '{label}' is frozen!")

# A single gradient scaler, shared by the generator and discriminator updates.
scaler = GradScaler()
# GradScaler with PyTorch's autocast prevents gradient underflow in mixed precision training.
# It achieves this by scaling up the loss before the backward pass to keep float16 gradients from vanishing.
# After gradients are computed, they are scaled back before the optimizer updates the model weights.

# When logging live (not deferred to the end), let wandb track both networks.
if not c.logEnd:
  wandb.watch([generator, discriminator], log="all")

# Main training loop: one generator step and one discriminator step per batch
# (LSGAN-style validity loss plus an auxiliary class loss), with per-batch and
# per-epoch logging.
for epoch in range(c.num_epochs):
    correct, total = 0, 0
    epoch_metrics = {}
    for i, (real_images, labels) in enumerate(train_loader):
        batch_metrics = {}
        _batch_size = real_images.size(0)
        real_images = real_images.unsqueeze(1)  # add the channel axis: (N, 28, 28) -> (N, 1, 28, 28)
        labels_one_hot = torch.zeros(_batch_size, NUM_CLASSES, device=device).scatter_(1, labels.view(_batch_size, 1), 1)

        # ---- train generator ----
        # zero_grad(set_to_none=True) drops the gradient tensors instead of memset-ing them,
        # so the next backward pass only writes gradients rather than read-modify-writing them.
        generator_optimizer.zero_grad(set_to_none=True)
        z = torch.randn(_batch_size, c.latent_dim, device=device)
        g_input = torch.cat((z, labels_one_hot), dim=1)
        g_input = g_input.view(_batch_size, c.latent_dim + NUM_CLASSES, 1, 1)
        fake_images = generator(g_input)
        fake_validities, d_fakeClass = discriminator(fake_images)
        # The generator is also rewarded when the discriminator recognises the requested class.
        g_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)
        d_logits_gen = fake_validities.view(-1)
        # LSGAN: push the validity of fakes towards 1.
        g_loss_base = criterion(d_logits_gen, torch.ones_like(d_logits_gen))
        g_loss = g_loss_base + g_fakeClassLoss * c.lambda_class
        scaler.scale(g_loss).backward()
        scaler.step(generator_optimizer)

        # ---- train discriminator ----
        # This also discards the discriminator gradients produced by the generator's backward pass.
        discriminator_optimizer.zero_grad(set_to_none=True)
        real_validities, d_realClass = discriminator(real_images)
        fake_validities, d_fakeClass = discriminator(fake_images.clone().detach())

        loss_disc_real = criterion(real_validities, torch.ones_like(real_validities))
        loss_disc_fake = criterion(fake_validities, -torch.ones_like(fake_validities)) # target -1 instead of the usual LSGAN target 0
        d_loss_base = (loss_disc_real + loss_disc_fake) / 2

        d_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)
        d_fakeAccuracy = (d_fakeClass.argmax(dim=1) == labels_one_hot.argmax(dim=1)).float().mean()
        d_realClassLoss = classCriterion(d_realClass, labels_one_hot)
        d_realAccuracy = (d_realClass.argmax(dim=1) == labels_one_hot.argmax(dim=1)).float().mean()
        d_loss = d_loss_base + (d_fakeClassLoss + d_realClassLoss) / 2
        scaler.scale(d_loss).backward()
        scaler.step(discriminator_optimizer)

        # From here on the scores are only used for bookkeeping: detach them so neither the
        # replay buffer nor the log tensors get attached to the autograd graph.
        real_validities = real_validities.detach()
        fake_validities = fake_validities.detach()

        # With targets +1 (real) and -1 (fake), the sign of the validity is the prediction.
        correct += (real_validities > 0).sum().item() + (fake_validities < 0).sum().item()
        total += len(real_validities) + len(fake_validities)
        g_input = g_input.view(_batch_size, c.latent_dim + NUM_CLASSES)
        real_validity = real_validities.mean()
        fake_validity = fake_validities.mean()
        lr_scheduler_trial.fillReplayBuffer(real_validity, real_validities, fake_validity, fake_validities, g_input)
        i_step = epoch * train_batches + i
        values = [real_validity, fake_validity, d_fakeClassLoss, d_realClassLoss, d_fakeAccuracy, d_realAccuracy, d_loss_base, g_loss_base, g_fakeClassLoss]
        values = [value.detach() for value in values]
        for name, value in zip(logsPerBatchNames, values):
            logsPerBatch[name][i_step] = value

        if not c.logEnd:
            # Always collect the batch metrics: the last batch of an epoch is logged together
            # with the epoch metrics below and would otherwise be missing from that record.
            batch_metrics.update({name: value.item() for name, value in zip(logsPerBatchNames, values)})
            if i != train_batches - 1:
                wandb.log(batch_metrics, step=i_step)

        # One scaler update per iteration, after both optimizer steps.
        scaler.update()

    check_frozen_parameters(discriminator)
    check_frozen_parameters(generator)

    d_lr_scheduler.step()
    accuracy = correct / total

    lr_scheduler_trial.update_learning_rate(epoch, discriminator, generator)

    logsPerEpoch['accuracy'][epoch] = accuracy
    logsPerEpoch['g_lr'][epoch] = generator_optimizer.param_groups[0]['lr']
    logsPerEpoch['d_lr'][epoch] = discriminator_optimizer.param_groups[0]['lr']

    if not c.logEnd:
        epoch_metrics.update({name: logsPerEpoch[name][epoch].item() for name in logsPerEpochNames})
        epoch_metrics.update(batch_metrics)
        wandb.log(epoch_metrics, step=i_step)

    print(f"Epoch [{epoch+1}/{c.num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, accuracy: {accuracy:.4f}, d_lr: {discriminator_optimizer.param_groups[0]['lr']:.6f}")
    createGridFakeImages(epoch=epoch,cubeSide=5, show=True, log=True, step=i_step)

# Replace every accumulated log tensor by a host-side NumPy array.
for _logs, _names in ((logsPerEpoch, logsPerEpochNames), (logsPerBatch, logsPerBatchNames)):
    for name in _names:
        _logs[name] = _logs[name].detach().cpu().numpy()

def wandbLogAtEnd():
    """Replay the buffered training logs into wandb, then finish the run.

    Reads the module-level buffers filled during training:

    * ``logsPerBatch[name][step]`` for every ``name`` in ``logsPerBatchNames``,
      where ``step = epoch * train_batches + batch_index``;
    * ``logsPerEpoch[name][epoch]`` for every ``name`` in ``logsPerEpochNames``;
    * ``t_images``: a sequence of ``(image_bytes, step)`` pairs. Entries are
      consumed front to back, so they are assumed to be ordered by step
      (NOTE(review): confirm against the code that fills ``t_images``).

    Every step is logged exactly once. For the last batch of an epoch the
    batch metrics of that same step are merged with the epoch metrics (batch
    values win on a name clash) and, when an image was recorded for that step,
    with a ``generator_output`` image. ``wandb.finish()`` is called at the end.
    """
    imageIndex = 0  # cursor over t_images: next image that has not been logged yet
    for epoch in range(c.num_epochs):
        if train_batches < 1:
            # No batches means no step to attach this epoch's values to.
            continue
        first_step = epoch * train_batches
        last_step = first_step + train_batches - 1

        # All batches except the last one carry batch metrics only.
        for step in range(first_step, last_step):
            metrics = {name: logsPerBatch[name][step] for name in logsPerBatchNames}
            wandb.log(metrics, step=step)

        # The last batch shares its step with the epoch summary, so both are
        # sent in a single call. Built from last_step itself, which also keeps
        # this working when an epoch has a single batch.
        step = last_step
        epochMetrics = {name: logsPerEpoch[name][epoch] for name in logsPerEpochNames}
        epochMetrics.update({name: logsPerBatch[name][step] for name in logsPerBatchNames})

        if imageIndex < len(t_images) and t_images[imageIndex][1] == step:
            image_bytes = t_images[imageIndex][0]
            with TempFileContext() as tmp_filename:
                with open(tmp_filename, 'wb') as tmp_file:
                    tmp_file.write(image_bytes)
                epochMetrics['generator_output'] = wandb.Image(tmp_filename)
                # Log while the temporary file still exists; it is removed on exit.
                wandb.log(epochMetrics, step=step)
            imageIndex += 1
        else:
            wandb.log(epochMetrics, step=step)
    wandb.finish()

# With c.logEnd set, the training loop above skips its own wandb.log calls for
# batch/epoch metrics, so the buffered values are uploaded here in one pass.
if c.logEnd:
    wandbLogAtEnd()

SyntaxError: cannot assign to function call (3621095602.py, line 607)

In [3]:
class Generator(nn.Module):
    """Fully connected conditional generator producing 28x28 images.

    channels_noise: size of the noise part of the input vector.
    channels_img: number of channels of the generated image.
    num_classes: size of the label part of the input vector (the first layer
        expects channels_noise + num_classes features per sample).
    features_g: width multiplier for the hidden layers.
    """

    def __init__(self, channels_noise, channels_img, num_classes, features_g):
        super().__init__()
        # Feature count entering each Linear layer, in order.
        widths = [
            channels_noise + num_classes,
            features_g * 4 * 7 * 7,
            features_g * 2 * 14 * 14,
            features_g * 28 * 28,
        ]
        stack = []
        # Hidden stages: bias-free Linear -> BatchNorm1d -> ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=False))
            stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.ReLU())
        # Output projection to one value per pixel; no Tanh is applied, so the
        # outputs are not squashed to [-1, 1].
        stack.append(nn.Linear(widths[-1], channels_img * 28 * 28))
        self.net = nn.Sequential(*stack)

    def forward(self, noise):
        """Flatten `noise` per sample and return images shaped (batch, -1, 28, 28)."""
        batch = noise.shape[0]
        flat = noise.view(batch, -1)
        pixels = self.net(flat)
        return pixels.view(batch, -1, 28, 28)


class Discriminator(nn.Module):
    """Fully connected discriminator with a joint validity / class head.

    channels_img: number of channels of the 28x28 input images.
    num_classes: number of class outputs produced next to the validity score.
    num_kernels, kernel_dim: forwarded to MinibatchDiscrimination.
    filters: width multiplier for the hidden layers.
    """

    def __init__(self, channels_img, num_classes, num_kernels, kernel_dim, filters):
        super().__init__()
        in_features = channels_img * 28 * 28
        wide = filters * 2 * 14 * 14
        narrow = filters * 4 * 7 * 7
        head = filters * 8
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, wide),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=0.05),
            nn.Linear(wide, narrow),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=0.05),
            # The following Linear takes `narrow + num_kernels` inputs, so this
            # layer presumably appends num_kernels features - verify against
            # the MinibatchDiscrimination definition.
            MinibatchDiscrimination(narrow, num_kernels, kernel_dim),
            nn.Linear(narrow + num_kernels, head),
            nn.LeakyReLU(0.2),
            # One validity output followed by num_classes class outputs.
            nn.Linear(head, 1 + num_classes),
        )

    @autocast()
    def forward(self, x):
        """Return (validity, class_probabilities) for a batch of images.

        validity is column 0 of the network output, returned as-is;
        class_probabilities is the softmax over the remaining columns.
        """
        flat = x.view(x.size(0), -1)
        out = self.net(flat)
        validity = out[:, 0]
        class_probs = nn.functional.softmax(out[:, 1:], dim=1)
        return validity, class_probs

# Smoke test: run a small random batch through freshly constructed networks
# and display the discriminator's output.
numImages = 10
generator = Generator(c.latent_dim, 1, NUM_CLASSES, 32).to(device)

# Random labels and their one-hot encoding.
labels = torch.randint(0, 10, (numImages,), device=device)
labels_one_hot = torch.zeros(numImages, 10, device=device)
labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)

# Generator input: noise concatenated with the one-hot label, shaped (N, D, 1, 1).
z = torch.randn(numImages, c.latent_dim, device=device)
g_input = torch.cat((z, labels_one_hot), dim=1).view(
    numImages, c.latent_dim + NUM_CLASSES, 1, 1
)
fake_images = generator(g_input)
fake_images.shape

discriminator = Discriminator(1, NUM_CLASSES, 5, 3, 32).to(device)
discriminator(fake_images)


(tensor([-0.0310, -0.0364, -0.0615, -0.0573, -0.0472, -0.0359, -0.0599, -0.0405,
         -0.0438, -0.0354], grad_fn=<SelectBackward0>),
 tensor([[0.0967, 0.0990, 0.1055, 0.0961, 0.0938, 0.0966, 0.1057, 0.0964, 0.1057,
          0.1046],
         [0.0985, 0.0981, 0.1069, 0.0968, 0.0941, 0.0944, 0.1047, 0.0951, 0.1063,
          0.1051],
         [0.0980, 0.0979, 0.1031, 0.0984, 0.0943, 0.0959, 0.1049, 0.0966, 0.1076,
          0.1034],
         [0.0979, 0.0967, 0.1049, 0.0973, 0.0925, 0.0963, 0.1044, 0.0983, 0.1086,
          0.1030],
         [0.0982, 0.0966, 0.1048, 0.0977, 0.0932, 0.0963, 0.1048, 0.0962, 0.1095,
          0.1027],
         [0.0996, 0.0965, 0.1031, 0.0987, 0.0963, 0.0947, 0.1062, 0.0967, 0.1072,
          0.1011],
         [0.0985, 0.0955, 0.1052, 0.0963, 0.0945, 0.0967, 0.1029, 0.0972, 0.1095,
          0.1038],
         [0.0987, 0.0956, 0.1053, 0.0965, 0.0932, 0.0969, 0.1046, 0.0974, 0.1069,
          0.1050],
         [0.1003, 0.0939, 0.1051, 0.0967, 0.0951, 0.095

wandb: Network error (ConnectionError), entering retry loop.


In [None]:
import math

# Number of samples to generate. Kept as a plain int because it is used as a
# tensor size, a range() bound and for the subplot grid arithmetic.
numImages = 2000
# Generate and plot fake images with labels
labels = torch.randint(0, NUM_CLASSES, (numImages,)).to(device)
# One-hot width must match the NUM_CLASSES used when reshaping g_input below.
labels_one_hot = torch.zeros(numImages, NUM_CLASSES).to(device).scatter_(1, labels.view(numImages, 1), 1)
with torch.no_grad():
    z = torch.randn(numImages, c.latent_dim).to(device)
    g_input = torch.cat((z, labels_one_hot), dim=1)
    g_input = g_input.view(numImages, c.latent_dim + NUM_CLASSES, 1, 1)
    fake_images = generator(g_input)
    fake_validities, d_fakeClass = discriminator(fake_images)
    g_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)

fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
label_text = [str(label.item()) for label in labels]

# Smallest square grid that can hold every image. Rounding the square root
# down (44 x 44 = 1936 cells for 2000 images) makes plt.subplot fail on the
# images that do not fit, so round up instead.
gridSide = math.ceil(math.sqrt(numImages))

plt.figure(figsize=(5, 5))
for i in range(numImages):
    plt.subplot(gridSide, gridSide, i + 1)
    plt.axis('off')
    plt.title(label_text[i], fontsize=10)
    plt.imshow(fake_images[i].cpu().squeeze(), cmap='gray')
plt.subplots_adjust(wspace=0.25, hspace=0.25)
# `epoch` is whatever value the training loop left behind.
plt.suptitle("Generated Images epoch: " + str(epoch), fontsize=16)
plt.show()

# # save plotly

# # Save the figure to a file
# image_path = "image.jpg"
# plt.savefig(image_path)
# # Convert the saved image file to wandb.Image and log using wandb
# with open(image_path, "rb") as img_file:
#     img_data = img_file.read()
#     image = Image.open(io.BytesIO(img_data))
#     wandb.log({"generator_output": wandb.Image(image)})

In [None]:
import torch
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt

# Preview of learning-rate schedules: step each enabled scheduler for
# num_epochs epochs on a dummy optimizer and plot the resulting learning rate.

# Create an optimizer (a single dummy parameter is enough to drive the schedulers)
initial_lr = 0.001
min_lr = 0.000001
optimizer = torch.optim.SGD([torch.randn(1, requires_grad=True)], lr=initial_lr)

# Define the number of epochs
num_epochs = 100
cycles = 4
# Learning rate schedulers
# The restart period is derived from this cell's own num_epochs (not from the
# training config) so that exactly `cycles` restarts fit in the plotted range.
cosineAnnealingWarmRestarts = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=num_epochs // cycles, T_mult=1, eta_min=min_lr)
# NOTE(review): every entry shares the same optimizer, so enabling more than
# one scheduler at a time makes them interfere with each other.
schedulers = {
    # "LambdaLR": lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch),
    # "MultiplicativeLR": lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95),
    # "StepLR": lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1),
    # "MultiStepLR": lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1),
    # "ConstantLR": lr_scheduler.ConstantLR(optimizer),
    # "LinearLR" : lr_scheduler.LinearLR(optimizer),
    # "ExponentialLR": lr_scheduler.ExponentialLR(optimizer, gamma=0.1),
    # "PolynomialLR": lr_scheduler.PolynomialLR(optimizer,total_iters=4, power=1.0),
    # "CosineAnnealingLR": lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0),
    "ChainedScheduler" : lr_scheduler.ChainedScheduler([lr_scheduler.ConstantLR(optimizer, total_iters=10), cosineAnnealingWarmRestarts]),
    # "SequentialLR": lr_scheduler.SequentialLR(optimizer, schedulers=[lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=2), lr_scheduler.ExponentialLR(optimizer, gamma=0.9)], milestones=[2]),
    # "ReduceLROnPlateau": lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10),
    # "CyclicLR": lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=1, step_size_up=5, mode='triangular2'),
    # "OneCycleLR": lr_scheduler.OneCycleLR(optimizer, max_lr=1, total_steps=num_epochs),
    # "CosineAnnealingWarmRestarts": lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=int(num_epochs/cycles), T_mult=1, eta_min=min_lr)
}

# Create a plot for each scheduler
for name, scheduler in schedulers.items():
    lrs = []
    for epoch in range(num_epochs):
        # optimizer.step() precedes scheduler.step(), the order PyTorch expects.
        optimizer.step()
        # Record the learning rate that was in effect for this epoch.
        lrs.append(optimizer.param_groups[0]["lr"])
        if name != "ReduceLROnPlateau":
            scheduler.step()
        else:
            scheduler.step(epoch)  # Assume loss is decreasing with epoch for this example
        optimizer.zero_grad()

    plt.figure()
    plt.plot(lrs)
    plt.title(name)

plt.show()
