In [None]:
import torch
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt

# Create an optimizer
initial_lr = 0.001
min_lr = 0.000001
optimizer = torch.optim.SGD([torch.randn(1, requires_grad=True)], lr= initial_lr)

# Define the number of epochs
num_epochs = 100
cycles = 4
# Learning rate schedulers
cosineAnnealingWarmRestarts = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=int(num_epochs/cycles), T_mult=1, eta_min=min_lr)
schedulers = {
    # "LambdaLR": lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch),
    # "MultiplicativeLR": lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95),
    # "StepLR": lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1),
    # "MultiStepLR": lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1),
    # "ConstantLR": lr_scheduler.ConstantLR(optimizer),
    # "LinearLR" : lr_scheduler.LinearLR(optimizer),
    # "ExponentialLR": lr_scheduler.ExponentialLR(optimizer, gamma=0.1),
    # "PolynomialLR": lr_scheduler.PolynomialLR(optimizer,total_iters=4, power=1.0),
    # "CosineAnnealingLR": lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0),
    "ChainedScheduler" : lr_scheduler.ChainedScheduler([lr_scheduler.ConstantLR(optimizer, total_iters=10), cosineAnnealingWarmRestarts]),
    # "SequentialLR": lr_scheduler.SequentialLR(optimizer, schedulers=[lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=2), lr_scheduler.ExponentialLR(optimizer, gamma=0.9)], milestones=[2]),
    # "ReduceLROnPlateau": lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10),
    # "CyclicLR": lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=1, step_size_up=5, mode='triangular2'),
    # "OneCycleLR": lr_scheduler.OneCycleLR(optimizer, max_lr=1, total_steps=num_epochs),
    # "CosineAnnealingWarmRestarts": lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=int(num_epochs/cycles), T_mult=1, eta_min=min_lr)
}

# Create a plot for each scheduler
for name, scheduler in schedulers.items():
    lrs = []
    for epoch in range(num_epochs):
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        if name != "ReduceLROnPlateau":
            scheduler.step()
        else:
            scheduler.step(epoch)  # Assume loss is decreasing with epoch for this example
        optimizer.zero_grad()

    plt.figure()
    plt.plot(lrs)
    plt.title(name)

plt.show()


In [83]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
try:
    import google.colab
    IN_COLAB = True
    %pip install wandb
    %pip install --upgrade "kaleido==0.1.*"
    import kaleido
except:
    IN_COLAB = False
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.utils.data as data_utils
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from torchvision.datasets import MNIST
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
import plotly.io as pio
import tempfile
from PIL import Image
import io
import wandb
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time
from enum import Enum

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

NUM_CLASSES = 10

class TempFileContext:
    def __enter__(self):
        self.tmp_file = tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False)
        self.tmp_filename = self.tmp_file.name
        return self.tmp_filename
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.tmp_file.close()
        os.remove(self.tmp_filename)

# As per the DCGAN paper: All the weights are initialized from a zero centered normal distribution with standard deviation 0.02
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, num_classes):
        """
        channels_noise: The size of the input noise vector. This noise vector is a random input from which the generator begins the generation of a new sample.
        channels_img: The number of output channels of the generator. This will typically be 1 for grayscale images or 3 for color (RGB) images.
        num_classes: The number of distinct classes or labels that the generator should generate images for. This is used to form the one-hot vector of class labels, which is concatenated to the noise vector to provide the generator with information about the class of image to generate.
        """
        super(Generator, self).__init__()
        self.channels_noise = channels_noise
        self.num_classes = num_classes

        self.gen = nn.Sequential(
            self.gen_block(channels_noise + num_classes, 256, kernel_size=7, stride=1, padding=0), # Append class labels to input noise.
            self.gen_block(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose2d(128, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def gen_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    @autocast() # automatically applies precisions to different operations to speed up calculations
    def forward(self, z):
        return self.gen(z)

class Discriminator(nn.Module):
    def __init__(self, channels_img, num_classes, num_kernels, kernel_dim):
        """
        channels_img: The number of input channels to the discriminator, corresponding to the number of channels in the images to be classified.
        features_d: This is the base size of the feature maps in the discriminator. The number of neurons or nodes in each layer of the discriminator is a multiple of this base size.
        num_classes: The number of distinct classes that the discriminator should be able to distinguish between. This is used to form the softmax output layer of the discriminator, which outputs a class probability distribution.
        num_kernels and kernel_dim: These are parameters for the minibatch discrimination layer. The minibatch discrimination layer is designed to make the discriminator sensitive to the variety of samples within a minibatch, to encourage the generator to generate a variety of different samples. num_kernels is the number of unique patterns the layer can learn to identify, and kernel_dim is the size of these learned patterns.
        """
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels_img, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            self._block(32, 64, 4, 2, 1),
        )
        self.mbd = MinibatchDiscrimination(64*7*7, num_kernels, kernel_dim)
        self.fc = nn.Sequential(
            nn.Linear(64*7*7 + num_kernels, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1 + num_classes),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
    @autocast()
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.mbd(x)
        out = self.fc(x)
        return out[:, 0], nn.functional.softmax(out[:, 1:], dim=1)

class MinibatchDiscrimination(nn.Module):
    def __init__(self, input_features, num_kernels, kernel_dim):
        super(MinibatchDiscrimination, self).__init__()
        self.input_features = input_features
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        self.T = nn.Parameter(torch.randn(input_features, num_kernels * kernel_dim))
    def forward(self, x):
        M = torch.matmul(x, self.T).view(-1, self.num_kernels, self.kernel_dim)
        diffs = M.unsqueeze(0) - M.transpose(0, 1).unsqueeze(2)
        abs_diffs = torch.sum(torch.abs(diffs), dim=2)
        minibatch_features = torch.sum(torch.exp(-abs_diffs), dim=2).T
        return torch.cat((x, minibatch_features), dim=1)

class LR_Metric(Enum):
    VALIDITY = 1
    AGE = 2
    
class CustomDataLoader:
    def __init__(self, dataset, batch_size, device):
        self.dataset = dataset
        self.batch_size = batch_size
        self.device = device
        self.data = self.dataset.data.float().to(self.device)
        self.targets = self.dataset.targets.to(self.device)
        self.num_samples = len(self.data)

    def __iter__(self):
        self.indices = torch.randperm(self.num_samples, device=self.device)
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.num_samples:
            raise StopIteration

        indices = self.indices[self.idx:self.idx+self.batch_size]
        batch_data = self.data[indices]
        batch_targets = self.targets[indices]

        self.idx += self.batch_size

        return batch_data, batch_targets

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data = MNIST(root='./dataset/minst/',train=True,download=True,transform=transform)
test_data = MNIST(root='./dataset/minst/', train=False, download=True, transform=transform)

# Hyperparameters
image_size = 28 * 28

@dataclass
class Config:
    latent_dim: int = 100
    batch_size: int = 256
    num_epochs: int = 100
    num_kernels: int = 10
    kernel_dim: int = 3
    learning_rate: float = 0.0002
    lr_restarts: int = 5
    min_lr: float = 1e-10
    lambda_class: int = 1
    replay_buffer_size: int = 1000

c = Config()
preppedConfig = {}
for k, v in dataclasses.asdict(c).items():
    if dataclasses.is_dataclass(v):
        preppedConfig[k] = dataclasses.asdict(v)
    else:
        preppedConfig[k] = v
wandb.init(project="mnist-gan", config=c)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    # Because the performance of cuDNN algorithms to compute the convolution of different kernel sizes varies, 
    # the auto-tuner can run a benchmark to find the best algorithm (current algorithms are these, these, and these). 
    # It’s recommended to use turn on the setting when your input size doesn’t change often. If the input size changes often, 
    # the auto-tuner needs to benchmark too frequently, which might hurt the performance.

train_loader = CustomDataLoader(train_data, batch_size=c.batch_size, device=device)
test_loader = CustomDataLoader(test_data, batch_size=c.batch_size, device=device)

# logging every epoch
t_age = torch.zeros(num_epochs).to(device)
t_curGap = torch.zeros(num_epochs).to(device)
t_oldGap = torch.zeros(num_epochs).to(device)
t_oldScore = torch.zeros(num_epochs).to(device)
t_replayScore = torch.zeros(num_epochs).to(device)
t_replayValidity = torch.zeros(num_epochs).to(device)
t_oldValidity = torch.zeros(num_epochs).to(device)
t_accuracy = torch.zeros(num_epochs).to(device)
t_d_lr = torch.zeros(num_epochs).to(device)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = MNIST('data/MNIST', train=True, download=True, transform=transform)
# Move the tensors in the dataset to the GPU
# if torch.backends.mps.is_available():
#     device = torch.device("mps")
generator = Generator(c.latent_dim, 1, NUM_CLASSES).to(device)
initialize_weights(generator)
discriminator = Discriminator(1, NUM_CLASSES, c.num_kernels, c.kernel_dim).to(device)
initialize_weights(discriminator)

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=c.learning_rate, betas=(0.5, 0.9))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=c.learning_rate, betas=(0.5, 0.9))\

train_batches = len(train_loader)

# logging every batch
t_real_validity = torch.zeros(num_epochs * train_batches).to(device)
t_fake_validity = torch.zeros(num_epochs * train_batches).to(device)
t_d_fakeClassLoss = torch.zeros(num_epochs * train_batches).to(device)
t_d_realClassLoss = torch.zeros(num_epochs * train_batches).to(device)
t_d_fakeAccuracy = torch.zeros(num_epochs * train_batches).to(device)
t_d_realAccuracy = torch.zeros(num_epochs * train_batches).to(device)
t_d_loss_base = torch.zeros(num_epochs * train_batches).to(device)
t_g_loss_base = torch.zeros(num_epochs * train_batches).to(device)
# logging is varied
t_images = []

class LearningRateScheduler:
    def __init__(self, initial_lr, replay_buffer_size, total_batches, batch_size, METRIC=LR_Metric.VALIDITY):
        self.initial_lr = initial_lr
        self.replay_buffer_size = replay_buffer_size
        self.total_batches = total_batches
        self.batch_size = batch_size
        self.METRIC = METRIC
        self.samplesPerBatch = int(np.ceil(replay_buffer_size / total_batches))
        self.filledIndex = 0
        self.oldFake_validity = torch.zeros(self.replay_buffer_size, device=device)
        self.oldReal_validity = torch.zeros(self.replay_buffer_size, device=device)
        self.oldFake_validities = torch.zeros(self.replay_buffer_size, device=device)
        self.z_replay = torch.zeros(self.replay_buffer_size, c.latent_dim + NUM_CLASSES, device=device)
        self.age = torch.zeros(self.replay_buffer_size, device=device)
        self.kickTopPercent = 0.25
        self.openIndexes = torch.ones(self.replay_buffer_size, device=device)
        self.real_validity_total = torch.zeros(1, device=device)
        self.fake_validity_total = torch.zeros(1, device=device)
        self.numSamples = 0

    def fillReplayBuffer(self, real_validity, real_validities, fake_validity, fake_validities, z):
        """
        samples (amouting to replay_buffer_size) will be evenly provided by all batches to fill the replay buffer in 1 epoch
        """
        with torch.no_grad():
            self.real_validity_total += real_validities.sum()
            self.fake_validity_total += fake_validities.sum()
            self.numSamples += len(z)
            openings = (self.openIndexes > 0).sum().item()
            numSamples = len(z)  
            if self.filledIndex < self.replay_buffer_size: 
                # start filling the buffer front to back, only fill self.samplesPerBatch to prevent, the early batches from dominanting the replay buffer
                remaining = self.replay_buffer_size - self.filledIndex
                numSelected = np.min([remaining, numSamples, self.samplesPerBatch])
                selected = np.random.choice(numSamples, numSelected, replace=False)
                indexes = torch.arange(self.filledIndex, self.filledIndex + len(selected))
                self.filledIndex += len(selected)
            elif openings:
                # randomly select samples to fill the openIndexes in the replay buffer
                indexes = torch.nonzero(self.openIndexes).squeeze()
                numSelected = np.min([openings, numSamples, self.samplesPerBatch])
                selected = np.random.choice(numSamples, numSelected, replace=False)
                indexes = np.random.choice(indexes.numel(), numSelected, replace=False)
            else:
                return
            self.oldFake_validity[indexes] = fake_validity.repeat(len(indexes))
            self.oldReal_validity[indexes] = real_validity.repeat(len(indexes))
            self.oldFake_validities[indexes] = fake_validities[selected].squeeze()
            self.z_replay[indexes,:] = z[selected]
            self.age[indexes] = 0
            self.openIndexes[indexes] = 0

    def plotReplayValidities(self):
        i_replays = (self.openIndexes == 0).nonzero().squeeze()
        fig = go.Figure()
        fig.add_trace(go.Histogram(x=self.oldReal_validity[i_replays].cpu().numpy(), name="real"))
        fig.add_trace(go.Histogram(x=self.oldFake_validity[i_replays].cpu().numpy(), name="fake"))
        fig.update_layout(barmode='overlay', title="saved validity scores histogram")
        fig.show()

    def update_learning_rate(self, epoch, d, g):
        with torch.no_grad():
            i_replays = (self.openIndexes == 0).nonzero().squeeze()
            z_replay = self.z_replay[i_replays]
            z_replay = z_replay.view(len(z_replay), c.latent_dim + NUM_CLASSES, 1, 1)
            fake_replay = g(z_replay)
            replayFake_validities, _ = d(fake_replay)
            replayFake_validities = replayFake_validities.squeeze()
    
            # if gaps are negatives then discriminator then fake images are getting higher validity scores than real ones
            oldGap = (self.oldReal_validity[i_replays] - self.oldFake_validity[i_replays]).mean()
            curGap = self.real_validity_total / self.numSamples - self.fake_validity_total / self.numSamples
            # positive - smaller positive

            replayScores = replayFake_validities - curGap.repeat(len(i_replays))
            oldScores = self.oldFake_validities[i_replays] - oldGap.repeat(len(i_replays))

            # logging
            t_age[epoch] = self.age[i_replays].mean()
            t_curGap[epoch] = curGap
            t_oldGap[epoch] = oldGap
            t_replayScore[epoch] = replayScores.mean()
            t_oldScore[epoch] = oldScores.mean()
            t_replayValidity[epoch] = replayFake_validities.mean()
            t_oldValidity[epoch] = self.oldFake_validities[i_replays].mean()
            # try:
            #     wandb.log({"curGap": curGap, "oldGap": oldGap, "replayScore": replayScores.mean().item(), "oldScore": oldScores.mean().item(), 'avgAge': self.age[i_replays].mean().item()})
            # except:
            #     ...
            if self.METRIC.value == LR_Metric.VALIDITY.value:
                metric = replayFake_validities.squeeze()
            elif self.METRIC.value == LR_Metric.AGE.value:
                metric = self.age[i_replays].squeeze()
                raise NotImplementedError("needs to be adjusted")
            else:
                raise Exception("Invalid metric")
            # Kick out top 10% of the replay buffer based on replayScores scores
            # lowest to highest, drop the highest
            i_highestMetric = torch.argsort(metric)[-int(np.ceil(self.kickTopPercent * self.replay_buffer_size)):]
            self.openIndexes[i_highestMetric] = 1

            # kick out first half for testing
            # self.openIndexes[:int(self.replay_buffer_size/2)] = torch.ones(int(self.replay_buffer_size/2)).to(device)

            self.age += 1

lr_scheduler_trial = LearningRateScheduler(initial_lr=0.001, replay_buffer_size=c.replay_buffer_size, total_batches=train_batches, batch_size=c.batch_size)
d_lr_scheduler = lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0=int(num_epochs/c.lr_restarts), T_mult=1, eta_min=min_lr)

# d_lr_scheduler = lr_scheduler.
classCriterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()

def createGridFakeImages(epoch=0, cubeSide=4, show=False, step=None, log=True):
    fig = make_subplots(rows=1, cols=2,
                        horizontal_spacing=0.01, 
                        shared_yaxes=True)
    numImages = torch.tensor([cubeSide**2], device=device)
    # Generate and plot fake images with labels
    labels = torch.randint(0, 10, (numImages,), device=device)
    labels_one_hot = torch.zeros(numImages, 10, device=device).scatter_(1, labels.view(numImages, 1), 1)
    with torch.no_grad():
        z = torch.randn(numImages, c.latent_dim, device=device)
        g_input = torch.cat((z, labels_one_hot), dim=1)
        g_input = g_input.view(numImages, c.latent_dim + NUM_CLASSES, 1, 1)
        fake_images = generator(g_input)
        # fake_validities, d_fakeClass = discriminator(fake_images)
        # g_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)
    fig = make_subplots(rows=cubeSide, cols=cubeSide, 
                        horizontal_spacing = 0.025,
                        vertical_spacing = 0.04,
                        subplot_titles=[str(label.item()) for label in labels])
    fake_images = fake_images.squeeze().cpu().numpy()
    for i in range(numImages):
        r = int(i/cubeSide) + 1
        c = int(i%cubeSide) + 1
        imageFlipped = np.flip(fake_images[i], 0)
        fig.add_trace(go.Heatmap(z=imageFlipped, 
                                colorscale='Greys',), row=r, col=c)
    fig.update_layout(title_text="Generated Images epoch: " + str(epoch), 
                    margin=dict(l=0, r=0, t=60, b=0),
                    height=800, width=800, showlegend=False)
    fig.update_traces(showscale=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    if show:
        fig.show()
    if log:
        if step is None:
            raise Exception("step must be provided when logging an image")
        # Convert the figure to a JPEG image and log using wandb
        image_bytes = pio.to_image(fig, format='jpeg')
        t_images.append((image_bytes, step))

scaler = GradScaler()
# GradScaler with PyTorch's autocast prevents gradient underflow in mixed precision training.
# It achieves this by scaling up the loss before backward pass to keep float16 gradients from vanishing.
# After gradients are computed, they are scaled back before the optimizer updates the model weights.

for epoch in range(num_epochs):
    correct, total = 0, 0
    for i, (real_images, labels) in enumerate(train_loader):
        # s_time = time.time()
        # print(f"Epoch {epoch}/{num_epochs} Batch {i}/{total_steps}")
        _batch_size = real_images.size(0)
        real_images = real_images.unsqueeze(1)
        labels_one_hot = torch.zeros(_batch_size, 10, device=device).scatter_(1, labels.view(_batch_size, 1), 1).to(device)

        # train generator
        # Setting gradients to zeroes by model.zero_grad() or optimizer.zero_grad() would execute memset for all parameters and update gradients with reading and writing operations. 
        # However, setting the gradients as None would not execute memset and would update gradients with only writing operations.
        generator_optimizer.zero_grad(set_to_none=True)
        z = torch.randn(_batch_size, c.latent_dim).to(device)
        g_input = torch.cat((z, labels_one_hot), dim=1)
        g_input = g_input.view(_batch_size, c.latent_dim + NUM_CLASSES, 1, 1)
        fake_images = generator(g_input)
        fake_validities, d_fakeClass = discriminator(fake_images)
        # g_loss should minimize the difference in predicting classes among the same classes
        g_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)
        # WGAN-GP
        # g_loss = -torch.mean(fake_validities) + g_fakeClassLoss * lambda_class
        d_logits_gen = fake_validities.view(-1)
        # LSGAN
        g_loss_base = criterion(d_logits_gen, torch.ones_like(d_logits_gen))
        g_loss = g_loss_base + g_fakeClassLoss * c.lambda_class
        scaler.scale(g_loss).backward()
        scaler.step(generator_optimizer)
        
        # train discriminator
        discriminator_optimizer.zero_grad(set_to_none=True)
        real_validities, d_realClass = discriminator(real_images)
        fake_validities, d_fakeClass = discriminator(fake_images.clone().detach())
        loss_disc_real = criterion(real_validities, torch.ones_like(real_validities))
        loss_disc_fake = criterion(fake_validities, -torch.ones_like(fake_validities)) # modified to -1 from normal LSGAN 0 target
        # LSGAN
        d_loss_base = (loss_disc_real + loss_disc_fake) / 2
        
        # gradient_penalty = compute_gradient_penalty(discriminator, real_images.data, fake_images.data)
        # d_loss = -torch.mean(real_validities) + torch.mean(fake_validities) + lambda_gp * gradient_penalty
        d_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)
        d_fakeAccuracy = (d_fakeClass.argmax(dim=1) == labels_one_hot.argmax(dim=1)).float().mean()
        d_realClassLoss = classCriterion(d_realClass, labels_one_hot)
        d_realAccuracy = (d_realClass.argmax(dim=1) == labels_one_hot.argmax(dim=1)).float().mean()
        d_loss = d_loss_base + (d_fakeClassLoss + d_realClassLoss) / 2
        scaler.scale(d_loss).backward()
        scaler.step(discriminator_optimizer)

        correct += (real_validities > 0).sum().item() + (fake_validities < 0).sum().item()
        total += len(real_validities) + len(fake_validities)
        if (i+1) % 200 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{train_batches}], d_loss: {d_loss_base.item():.4f}, g_loss: {g_loss.item():.4f}")
        g_input = g_input.view(_batch_size, c.latent_dim + NUM_CLASSES)
        real_validity = real_validities.mean()
        fake_validity = fake_validities.mean()
        lr_scheduler_trial.fillReplayBuffer(real_validity, real_validities, fake_validity, fake_validities, g_input)
        # print("lr_scheduler_trial: ", time.time() - s_time)
        i_step = epoch * train_batches + i
        t_real_validity[i_step] = real_validity
        t_fake_validity[i_step] = fake_validity
        t_d_fakeClassLoss[i_step] = d_fakeClassLoss
        t_d_realClassLoss[i_step] = d_realClassLoss
        t_d_fakeAccuracy[i_step] = d_fakeAccuracy
        t_d_realAccuracy[i_step] = d_realAccuracy
        t_d_loss_base[i_step] = d_loss_base
        t_g_loss_base[i_step] = g_loss_base

        scaler.update()
            
    d_lr_scheduler.step()
    accuracy = correct / total

    lr_scheduler_trial.update_learning_rate(epoch, discriminator, generator)

    t_accuracy[epoch] = accuracy
    t_d_lr[epoch] = discriminator_optimizer.param_groups[0]['lr']

    if epoch % 20 == 0:
        createGridFakeImages(epoch=epoch,cubeSide=5, show=True, log=True, step=i_step)

_t_age = t_age.cpu().detach().numpy()
_t_curGap = t_curGap.cpu().detach().numpy()
_t_oldGap = t_oldGap.cpu().detach().numpy()
_t_oldScore = t_oldScore.cpu().detach().numpy()
_t_replayScore = t_replayScore.cpu().detach().numpy()
_t_replayValidity = t_replayValidity.cpu().detach().numpy()
_t_oldValidity = t_oldValidity.cpu().detach().numpy()
_t_accuracy = t_accuracy.cpu().detach().numpy()
_t_d_lr = t_d_lr.cpu().detach().numpy()

_t_real_validity = t_real_validity.cpu().detach().numpy()
_t_fake_validity = t_fake_validity.cpu().detach().numpy()
_t_d_fakeClassLoss = t_d_fakeClassLoss.cpu().detach().numpy()
_t_d_realClassLoss = t_d_realClassLoss.cpu().detach().numpy()
_t_d_fakeAccuracy = t_d_fakeAccuracy.cpu().detach().numpy()
_t_d_realAccuracy = t_d_realAccuracy.cpu().detach().numpy()
_t_d_loss_base = t_d_loss_base.cpu().detach().numpy()
_t_g_loss_base = t_g_loss_base.cpu().detach().numpy()

imageIndex = 0
_t_images = t_images.copy()

for epoch in range(num_epochs):
    for i in range(train_batches):
        step = epoch * train_batches + i
        if i != train_batches - 1:
            metrics = {'real_validity': _t_real_validity[step],
                        'fake_validity': _t_fake_validity[step], 
                        'd_fakeClassLoss': _t_d_fakeClassLoss[step], 
                        'd_realClassLoss': _t_d_realClassLoss[step], 
                        'd_fakeAccuracy': _t_d_fakeAccuracy[step], 
                        'd_realAccuracy': _t_d_realAccuracy[step], 
                        'd_loss_base': _t_d_loss_base[step], 
                        'g_loss_base': _t_g_loss_base[step]}
            wandb.log(metrics, step=step)
    epochMetrics = {'avgAge': _t_age[epoch], 
                       'curGap': _t_curGap[epoch], 
                       'oldGap': _t_oldGap[epoch], 
                       'oldScore': _t_oldScore[epoch], 
                       'replayScore': _t_replayScore[epoch], 
                       'replayValidity': _t_replayValidity[epoch], 
                       'oldValidity': _t_oldValidity[epoch], 
                       'accuracy': _t_accuracy[epoch], 
                       'd_lr': _t_d_lr[epoch], }
    epochMetrics.update(metrics)
    if len(_t_images) and _t_images[0][1] == step:
        with TempFileContext() as tmp_filename:
            image_bytes = _t_images[0][0]
            with open(tmp_filename, 'wb') as tmp_file:
                tmp_file.write(image_bytes)
            epochMetrics['generator_output'] = wandb.Image(tmp_filename)
            wandb.log(epochMetrics, step=step)
        _t_images = _t_images[1:]
    else:
        wandb.log(epochMetrics, step=step)
wandb.finish()




You should consider upgrading via the '/Users/mnann/Documents/Code/AuthenticCursor/venvDev/bin/python -m pip install --upgrade pip' command.


TypeError: 'MNIST' object does not support item assignment

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST(
    root='./dataset/minst/',
    train=True,
    download=False,
    transform=transform
)
train_loader = DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=batch_size
)
train_data.train_data.to(torch.device("cuda:0"))  # put data into GPU entirely
train_data.train_labels.to(torch.device("cuda:0"))



In [25]:
numImages = torch.tensor([2000]).to(device)
# Generate and plot fake images with labels
labels = torch.randint(0, 10, (numImages,)).to(device)
labels_one_hot = torch.zeros(numImages, 10).to(device).scatter_(1, labels.view(numImages, 1), 1)
with torch.no_grad():
    z = torch.randn(numImages, c.latent_dim).to(device)
    g_input = torch.cat((z, labels_one_hot), dim=1)
    g_input = g_input.view(numImages, c.latent_dim + NUM_CLASSES, 1, 1)
    fake_images = generator(g_input)
    fake_validities, d_fakeClass = discriminator(fake_images)
    g_fakeClassLoss = classCriterion(d_fakeClass, labels_one_hot)

fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
label_text = [str(label.item()) for label in labels]
plt.figure(figsize=(5, 5))
for i in range(numImages):
    plt.subplot(int(numImages**0.5), int(numImages**0.5), i+1)
    plt.axis('off')
    plt.title(label_text[i], fontsize=10)
    plt.imshow(fake_images[i].cpu().squeeze(), cmap='gray')
plt.subplots_adjust(wspace=0.25, hspace=0.25)
plt.suptitle("Generated Images epoch: " + str(epoch), fontsize=16)
plt.show()

# # save plotly

# # Save the figure to a file
# image_path = "image.jpg"
# plt.savefig(image_path)
# # Convert the saved image file to wandb.Image and log using wandb
# with open(image_path, "rb") as img_file:
#     img_data = img_file.read()
#     image = Image.open(io.BytesIO(img_data))
#     wandb.log({"generator_output": wandb.Image(image)})

g_fakeClassLoss:  2.302658796310425
