<a href="https://colab.research.google.com/github/dwarfy35/deep_learning2/blob/main/Final_notebook_hopefully.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [197]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.colab import drive
from torch.autograd import grad
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset
from torch.autograd import grad
from collections import defaultdict
import os
import pandas as pd
import torch.nn.functional as F
from itertools import product
from collections import Counter

drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [198]:
os.getcwd()

'/content'

number of fonts is 14990

In [199]:
tot_fonts = 14990
train_fonts = (0,int(14990*0.8))
test_fonts = (int(14990*0.8),int(14990*0.99))
val_fonts = (int(14990*0.99),14990)
print(train_fonts)
print(test_fonts)
print(val_fonts)

(0, 11992)
(11992, 14840)
(14840, 14990)


In [200]:
def count_labels(dataset):
    """
    Count the number of occurrences of each label in the dataset.

    Args:
        dataset (Dataset): An instance of NPZDataset.

    Returns:
        dict: A dictionary with labels as keys and their counts as values.
    """
    # Extract all labels from the dataset
    all_labels = [dataset[i][2].item() for i in range(len(dataset))]  # Label is the third item in each dataset entry

    # Count occurrences of each label
    label_counts = Counter(all_labels)

    # Sort and convert to a dictionary for a cleaner output
    sorted_counts = dict(sorted(label_counts.items()))

    # Display results
    print("Label counts in the dataset:")
    for label, count in sorted_counts.items():
        print(f"Label {chr(label + 65)} ({label}): {count}")

    return sorted_counts


In [201]:
class Generator3(nn.Module):
    def __init__(self, latent_dim, style_dim=1, img_size=32):
        super(Generator3, self).__init__()

        self.init_size = img_size // 16  # Initial spatial size after FC layer
        self.fc = nn.Linear(latent_dim + style_dim, 512 * self.init_size * self.init_size)

        self.deconv_blocks = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, condition):
        condition_features = condition.mean(dim=(2, 3))  # Global average pooling across height and width
        condition_features = condition_features.view(-1, 1)  # Ensure shape matches style_dim, likely not necessary but a good precaution
        z = torch.cat((z, condition_features), dim=1) #Concatenate to get final z

        out = self.fc(z).view(-1, 512, self.init_size, self.init_size) #Pass through initial linear layer and ensure correct shape

        img = self.deconv_blocks(out) # Standard generator
        return img

class Discriminator3(nn.Module):
    def __init__(self):
        super(Discriminator3, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(1 * 2, 32, 4, 2, 1, bias=False),  # Changed to now require 2 input channels since we concatenate
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 32 * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32 * 2, 32 * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32 * 4, 1, 4, 1, 0, bias=False),
        )

    def forward(self, img, condition_img):
        x = torch.cat((img, condition_img), dim=1) # Concatenate the input image and condition along the channel dimension
        validity = self.main(x) # Then through the discrim as usual
        return validity




In [202]:
class NPZDataset(Dataset):
    def __init__(self, npz_file, font_range=None, transform=None, filter_label=None, num_samples=None, missing_p=False):
        """
        Initialize the dataset with optional filtering by font range.

        Args:
            npz_file (str): Path to the .npz file containing images and labels.
            font_range (tuple, optional): Range of fonts to include (start, end). If None, include all fonts.
            transform (callable, optional): Transformations to apply to the images.
            filter_label (int, optional): Filter dataset by a specific label.
            num_samples (int, optional): Limit the number of samples in the dataset.
        """
        # Load the data from the .npz file
        data = np.load(npz_file)
        self.images = data['images']
        self.labels = data['labels']
        self.transform = transform

        font_size = 26  # Number of letters per font (A-Z)

        # Limit the dataset to the specified range of fonts
        if font_range is not None and missing_p == True:
            start_font, end_font = font_range
            start_idx = start_font * font_size
            end_idx = end_font * font_size  # Exclusive of the last font
            self.images = self.images[start_idx:end_idx]
            self.labels = self.labels[start_idx:end_idx]

        if font_range is not None and missing_p == False:
            start_font, end_font = font_range
            start_idx = start_font * font_size
            end_idx = end_font * font_size  # Exclusive of the last font
            self.images = self.images[start_idx:end_idx]
            self.labels = self.labels[start_idx:end_idx]

        # Filter by label if specified
        if filter_label is not None:
            # Find indices of the desired label
            label_indices = np.where(self.labels == filter_label)[0]

            # If num_samples is specified, limit the number of samples
            if num_samples is not None:
                label_indices = label_indices[:num_samples]

            # Filter images and labels
            self.images = self.images[label_indices]
            self.labels = self.labels[label_indices]

        # Validate that labels follow the sequence 0-25 for each font group
        self.validate_labels(font_size)

    def validate_labels(self, font_size):
      """
      Validate that each font group has labels 0-25 and transitions between groups correctly.

      Args:
          font_size (int): Number of letters per font (default is 26 for A-Z).
      """
      num_fonts = len(self.labels) // font_size

      for font_id in range(num_fonts):
          start_idx = font_id * font_size
          end_idx = start_idx + font_size

          # Extract labels for the current font group
          font_labels = self.labels[start_idx:end_idx]

          # Check if the labels are sequentially 0-25
          expected_labels = list(range(font_size))
          if font_labels.tolist() != expected_labels:
              print(f"Error: Font {font_id} does not have labels 0-25.")
              print(f"Actual labels: {font_labels.tolist()}")
              raise ValueError(f"Font {font_id} has incorrect labels.")

      # Check transitions between font groups
      for font_id in range(num_fonts - 1):
          current_end_label = self.labels[(font_id + 1) * font_size - 1]  # Last label of the current font
          next_start_label = self.labels[(font_id + 1) * font_size]      # First label of the next font
          if current_end_label != 25 or next_start_label != 0:
              print(f"Error in transition between fonts {font_id} and {font_id + 1}.")
              print(f"Last label of font {font_id}: {current_end_label}")
              print(f"First label of font {font_id + 1}: {next_start_label}")
              raise ValueError(f"Incorrect transition between fonts {font_id} and {font_id + 1}.")


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Retrieve an item from the dataset.

        Args:
            idx (int): Index of the item.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Image, condition image, and label.
        """
        # Get the image and label for the given index
        image = self.images[idx]
        label = self.labels[idx]

        # Determine the font ID based on dataset ordering
        font_id = idx // 26  # Calculate the font ID based on position in dataset

        # Get the "A" from the same font as the condition
        condition_index = font_id * 26  # First letter (A) of the current font
        condition_image = self.images[condition_index]
        if self.labels[condition_index] != 0:
          raise ValueError("Condition image label should be 0.")

        # Reshape the images
        image = image[np.newaxis, ...]  # Add channel dimension
        condition_image = condition_image[np.newaxis, ...]

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)
            condition_image = self.transform(condition_image)

        # Convert to PyTorch tensors
        image = torch.tensor(image, dtype=torch.float32)
        condition_image = torch.tensor(condition_image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return image, condition_image, label


In [203]:
def number_to_alphabet(num):
    if 0 <= num <= 25:
        return chr(num + 65)  # 65 is the ASCII code for 'A'
    else:
        raise ValueError("Number must be between 0 and 25 inclusive.")


In [204]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples, condition_samples, device="cuda"):
    """
    Compute the gradient penalty for WGAN-GP in a conditional GAN setup where
    the discriminator takes both the input and the condition image.

    Args:
        discriminator: The discriminator model.
        real_samples: Batch of real images.
        fake_samples: Batch of fake images generated by the generator.
        condition_samples: Batch of conditional images (e.g., letter A from the same font).
        device: The device to run the computations on.

    Returns:
        Gradient penalty (scalar).
    """
    # Interpolate between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)  # Mixing factor
    interpolated_samples = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    # Compute the discriminator output for the interpolated samples and conditions
    d_interpolates = discriminator(interpolated_samples, condition_samples)

    # Flatten the discriminator output
    d_interpolates = d_interpolates.view(-1)  # Shape: [batch_size]

    # Create grad_outputs for the gradient computation
    fake = torch.ones(d_interpolates.size(), device=device)

    # Compute gradients of the discriminator's output w.r.t. interpolated inputs
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolated_samples,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # Compute the gradient penalty
    gradients = gradients.view(gradients.size(0), -1)  # Flatten gradients
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty


In [205]:
import matplotlib.pyplot as plt

def display_fonts(dataset, start_font, end_font, font_size=26):
    """
    Display a range of fonts from the dataset.

    Args:
        dataset (NPZDataset): An instance of the NPZDataset class.
        start_font (int): The starting font index (inclusive).
        end_font (int): The ending font index (exclusive).
        font_size (int): The number of letters per font (default is 26 for A-Z).
    """
    num_fonts = end_font - start_font
    if num_fonts <= 0:
        print("Invalid font range. Ensure end_font > start_font.")
        return

    # Calculate the number of rows and columns for the grid
    num_rows = num_fonts
    num_cols = font_size

    # Set up the figure
    plt.figure(figsize=(num_cols * 1.5, num_rows * 1.5))

    for font_idx in range(start_font, end_font):
        for letter_idx in range(font_size):
            global_idx = font_idx * font_size + letter_idx
            if global_idx >= len(dataset):
                print(f"Reached end of dataset at index {global_idx}.")
                break

            image, _, label = dataset[global_idx]
            image = image.squeeze().numpy()  # Remove channel dimension for visualization

            # Plot the image
            plt.subplot(num_rows, num_cols, (font_idx - start_font) * font_size + letter_idx + 1)
            plt.imshow(image, cmap="gray")
            plt.axis("off")
            plt.title(chr(65 + label.item()), fontsize=8)  # Use label to display the corresponding letter

    plt.tight_layout()
    plt.show()


In [206]:
def validate_labels(self, font_size):
    """
    Validate that each font group has labels 0-25 and transitions between groups correctly.

    Args:
        font_size (int): Number of letters per font (default is 26 for A-Z).
    """
    num_fonts = len(self.labels) // font_size

    for font_id in range(num_fonts):
        start_idx = font_id * font_size
        end_idx = start_idx + font_size

        # Extract labels for the current font group
        font_labels = self.labels[start_idx:end_idx]

        # Check if the labels are sequentially 0-25
        expected_labels = list(range(font_size))
        if font_labels.tolist() != expected_labels:
            print(f"Error: Font {font_id} does not have labels 0-25.")
            print(f"Actual labels: {font_labels.tolist()}")
            raise ValueError(f"Font {font_id} has incorrect labels.")

    # Check transitions between font groups
    for font_id in range(num_fonts - 1):
        current_end_label = self.labels[(font_id + 1) * font_size - 1]  # Last label of the current font
        next_start_label = self.labels[(font_id + 1) * font_size]      # First label of the next font
        if current_end_label != 25 or next_start_label != 0:
            print(f"Error in transition between fonts {font_id} and {font_id + 1}.")
            print(f"Last label of font {font_id}: {current_end_label}")
            print(f"First label of font {font_id + 1}: {next_start_label}")
            raise ValueError(f"Incorrect transition between fonts {font_id} and {font_id + 1}.")


In [207]:
# Dataset creation
#npz_file = "/content/gdrive/My Drive/character_font.npz"
npz_file = "/content/gdrive/My Drive/corrected_dataset2.npz"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

#fullData = NPZDataset(npz_file, font_range =(0,14990), transform=transform)

train_data = NPZDataset(npz_file, font_range=train_fonts,transform=transform)
test_data = NPZDataset(npz_file, font_range=test_fonts,transform=transform)
val_data = NPZDataset(npz_file, font_range=val_fonts,transform=transform)
#testing_data = NPZDataset(npz_file, font_range=(0,100),transform=transform,remove_missing_p=True)


In [208]:
count_labels(train_data)

  image = torch.tensor(image, dtype=torch.float32)
  condition_image = torch.tensor(condition_image, dtype=torch.float32)


KeyboardInterrupt: 

In [None]:
count_labels(test_data)

In [None]:
count_labels(val_data)

In [None]:
# Step 1: Group the data by label
grouped_data = defaultdict(list)
for img, condition_img, label in train_data:  # Unpack condition_img as well
    grouped_data[label.item()].append((img, condition_img, label))

# Step 2: Prepare datasets for each label
grouped_datasets = {
    label: [(img, cond_img, lbl) for img, cond_img, lbl in imgs]
    for label, imgs in grouped_data.items()
}

# Step 3: Create dataloaders for each group
grouped_dataloaders = {
    letter: DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for letter, dataset in grouped_datasets.items()
}


In [None]:
def displayGeneratedSample(generator, z_s, condition_imgs, labels, num_classes, class_index=0, device='cuda'):
    """
    Displays one generated image from the generator.

    Args:
        generator (nn.Module): Pre-trained generator model.
        z_s (torch.Tensor): Style vector.
        condition_imgs (torch.Tensor): Conditional images (batch).
        labels (torch.Tensor): Labels for the batch.
        num_classes (int): Number of character classes.
        class_index (int): Index of the sample in the batch to display.
        device (str): Device for computation ('cuda' or 'cpu').
    """
    # Ensure the inputs are on the correct device
    z_s = z_s.to(device)
    condition_imgs = condition_imgs.to(device)
    labels = labels.to(device)

    # Create one-hot encoding for the batch
    batch_size = labels.size(0)
    z_c = torch.zeros(batch_size, num_classes, device=device)
    z_c[torch.arange(batch_size), labels] = 1

    # Concatenate style and class vectors
    z = torch.cat((z_s, z_c), dim=1)

    # Generate images
    with torch.no_grad():
        fake_imgs = generator(z, condition_imgs)

    # Select the specified class index
    generated_img = fake_imgs[class_index].cpu().numpy()[0]  # Extract the first channel

    # Rescale the image from [-1, 1] to [0, 255]
    generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))

    # Display the generated image
    plt.imshow(generated_img, cmap='gray')
    plt.axis("off")
    plt.title(f"Generated Image for Class {chr(65 + labels[class_index].item())}")  # Convert label to letter
    plt.show()


In [None]:
def add_noise(images, noise_level=0.1):
    """
    Add Gaussian noise to images.

    Args:
        images: Tensor of images (B, C, H, W).
        noise_level: Standard deviation of the Gaussian noise.

    Returns:
        Noisy images.
    """
    noise = torch.randn_like(images) * noise_level
    return images + noise


In [None]:
d_losses = []
g_losses = []

In [None]:
save_directory = "/content/gdrive/My Drive"

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The multiple values are a remnant for when we wanted to do grid search
latent_dim_values = [50,100,150]  # Dimension of the latent vector z
style_dim = 1  # Remnant of when we tried changing the dimension of the flattened image
img_size = 32
num_classes = 26
epochs = 1500
n_critic_values = [1,3,5]
zc_weight =  1
lr_values = [0.00001,0.0002, 0.002]
noise_levels = [0.005, 0.05, 0.2]
smooth_values = [0.05,0.1,0.15]
lambda_gp_values = [5,10,20]
batch_size = 1024

lr = 0.0002
latent_dim = 100
nl = 0.05
n_critic = 1
lambda_gp = 10
sv = 0.1

In [None]:
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
#for lambda_gp in lambda_gp_values:
for h in range(1):
  print(f"Training with lr={lr}, lambda_gp={lambda_gp}, nl={nl}, n_critic={n_critic}, latent_dim={latent_dim}, sv={sv}")
  z_dim = latent_dim + num_classes #Remnant of when we wanted to train with different latent_dims

  generator = Generator3(z_dim, style_dim, img_size).to(device)
  discriminator = Discriminator3().to(device)

    # Optimizers
  optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
  optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
  dlosses = []
  glosses = []
  d_test_losses = []
  g_test_losses = []

    # Training Loop
  for epoch in range(epochs):
        for letter in range(num_classes):
            dataloader = grouped_dataloaders[letter]
            for i, (real_imgs, condition_imgs, labels) in enumerate(dataloader):
                real_imgs = real_imgs.to(device)
                condition_imgs = condition_imgs.to(device)

                labels = labels.to(device)
                current_size = labels.size(0) # Makes sure that we dynamically adjust to anything requiring number of batches

                real_imgs = real_imgs.permute(0, 2, 3, 1).cuda() # Fixes it so we are in the correct order for pytorch

                condition_imgs = condition_imgs.permute(0, 2, 3, 1).cuda()


                # Train Discriminator
                z_s = torch.randn( current_size, latent_dim).to(device)  # Latent vector
                z_c = torch.zeros( current_size, num_classes).to(device)  # One-hot class vector
                z_c[torch.arange( current_size), labels] = 1 # Set the desired target to 1
                z = torch.cat((z_s, z_c * zc_weight), dim=1)  # Combine latent vector and one hot
                fake_imgs = generator(z, condition_imgs)  # Generate fake images
                noisy_condition_imgs = add_noise(condition_imgs, noise_level=nl) # Add a small amount of noise to condition imgs for discrimininator for leniency
                real_validity = discriminator(real_imgs,  noisy_condition_imgs)  # Include condition image
                real_validity = real_validity - sv # Label smoothing which is supposed to help the generator by artifically decreasing the confidence of the discriminator on the real imgs
                fake_validity = discriminator( fake_imgs,  noisy_condition_imgs)  # Include condition image
                gradient_penalty = compute_gradient_penalty(
                    discriminator,
                    real_samples=real_imgs,
                    fake_samples=fake_imgs,
                    condition_samples= noisy_condition_imgs,
                    device=device
                )

                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty # Wasserstein loss for discrim

                optimizer_D.zero_grad()
                d_loss.backward()
                optimizer_D.step()

                # Train Generator
                if i % n_critic == 0: # Makes the generator train less frequently
                    z_s = torch.randn (current_size, latent_dim).to(device) # Same as for discrim
                    z = torch.cat((z_s, z_c), dim=1)
                    fake_imgs = generator(z, condition_imgs)
                    fake_validity = discriminator(fake_imgs,  noisy_condition_imgs)
                    g_loss = -torch.mean(fake_validity)
                    optimizer_G.zero_grad()
                    g_loss.backward()
                    optimizer_G.step()
                # Display generated samples for first image in first batch every 200 epochs
                if epoch % 200 == 0 and i == 0:
                    displayGeneratedSample(generator, z_s, condition_imgs, labels, num_classes, class_index=0, device=device)


        # Print training progress
        print(f"Epoch [{epoch+1}/{epochs}], D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}, grad_penalty: {gradient_penalty}")
        # Test and save
        if epoch % 10 == 0:
          torch.save(generator.state_dict(), "/content/gdrive/My Drive/generator.pth")
          torch.save(discriminator.state_dict(), "/content/gdrive/My Drive/discriminator.pth")
          dlosses.append(d_loss.item())
          glosses.append(g_loss.item())
          # Test. Basically same as above just without updating the weights for either.
          generator.eval()
          discriminator.eval()
          with torch.no_grad():
            for i, (real_imgs, condition_imgs, labels) in enumerate(test_dataloader):
                real_imgs = real_imgs.to(device)
                condition_imgs = condition_imgs.to(device)

                labels = labels.to(device)
                current_size = labels.size(0)

                real_imgs = real_imgs.permute(0, 2, 3, 1).cuda()

                condition_imgs = condition_imgs.permute(0, 2, 3, 1).cuda()





                z_s = torch.randn( current_size, latent_dim).to(device)  # Latent vector
                z_c = torch.zeros( current_size, num_classes).to(device)  # One-hot class vector
                z_c[torch.arange( current_size), labels] = 1 # Set the desired target to 1
                z = torch.cat((z_s, z_c * zc_weight), dim=1)  # Combine latent vector and one hot
                fake_imgs = generator(z, condition_imgs)  # Generate fake images
                noisy_condition_imgs = add_noise(condition_imgs, noise_level=nl) # Add a small amount of noise to condition imgs for discrimininator for leniency
                real_validity = discriminator(real_imgs,  noisy_condition_imgs)  # Include condition image
                real_validity = real_validity - sv # Label smoothing which is supposed to help the generator by artifically decreasing the confidence of the discriminator on the real imgs
                fake_validity = discriminator(fake_imgs.detach(),  noisy_condition_imgs)  # Include condition image
                real_imgs.requires_grad = True # Required for gradient penalty
                with torch.enable_grad():
                  gradient_penalty = compute_gradient_penalty(
                    discriminator,
                    real_samples=real_imgs,
                    fake_samples=fake_imgs,
                    condition_samples= noisy_condition_imgs,  # Pass condition for gradient penalty
                    device=device
                )

                d_test_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
                g_test_loss = -torch.mean(fake_validity)
            d_test_losses.append(d_loss.item())
            g_test_losses.append(g_loss.item())
            generator.train() # Prep for training again
            discriminator.train()
            print(f"TESTING, D_loss_test: {d_test_loss.item():.4f}, G_loss_test: {g_test_loss.item():.4f}")
          csv_file = f"{save_directory}/gan_metrics_lr_{lr}_gp_{lambda_gp}_nl_{nl}_critic_{n_critic}_dim_{latent_dim}_sv_{sv}_train.csv"
  metrics = pd.DataFrame({
            "epoch": list(range(len(dlosses))),
            "d_loss": dlosses,
            "g_loss": glosses
        })
  metrics.to_csv(csv_file, index=False)
  print(f"Saved metrics to {csv_file}")

  csv_file = f"{save_directory}/gan_metrics_lr_{lr}_gp_{lambda_gp}_nl_{nl}_critic_{n_critic}_dim_{latent_dim}_sv_{sv}_test.csv"
  metrics = pd.DataFrame({
            "epoch": list(range(len(d_test_losses))),
            "d_loss": d_test_losses,
            "g_loss": g_test_losses
        })
  metrics.to_csv(csv_file, index=False)
  print(f"Saved metrics to {csv_file}")


  # Saving metrics
  csv_file = f"{save_directory}/gan_metrics_lr_{lr}_gp_{lambda_gp}_nl_{nl}_critic_{n_critic}_dim_{latent_dim}_sv_{sv}_train.csv"
  metrics = pd.DataFrame({
            "epoch": list(range(len(dlosses))),
            "d_loss": dlosses,
            "g_loss": glosses
        })
  metrics.to_csv(csv_file, index=False)
  print(f"Saved metrics to {csv_file}")

  csv_file = f"{save_directory}/gan_metrics_lr_{lr}_gp_{lambda_gp}_nl_{nl}_critic_{n_critic}_dim_{latent_dim}_sv_{sv}_test.csv"
  metrics = pd.DataFrame({
            "epoch": list(range(len(d_test_losses))),
            "d_loss": d_test_losses,
            "g_loss": g_test_losses
        })
  metrics.to_csv(csv_file, index=False)
  print(f"Saved metrics to {csv_file}")

  torch.save(generator.state_dict(), "/content/gdrive/My Drive/generator.pth")
  torch.save(discriminator.state_dict(), "/content/gdrive/My Drive/discriminator.pth")


In [None]:
z_dim = latent_dim + num_classes #Remnant of when we wanted to train with different latent_dims

generator = Generator3(z_dim, style_dim, img_size).to(device)
discriminator = Discriminator3().to(device)
generator.load_state_dict(torch.load("/content/gdrive/My Drive/generator.pth"))
discriminator.load_state_dict(torch.load("/content/gdrive/My Drive/discriminator.pth"))

In [None]:
display_data = DataLoader(val_data, batch_size=1, shuffle=False)

In [None]:
def saveGridOfGeneratedImages(generator, z_s, condition_image, num_classes, output_file, condition_folder,j, device='cuda'):
    """
    Generates a grid of images for all classes and saves it as a single image file.
    Also saves the condition image in a separate folder.
    """
    # Ensure condition_image has the correct dimensions
    if condition_image.ndim == 5:
        condition_image = condition_image.squeeze(0)  # Remove unnecessary batch dimension
    print("Condition Image Shape After Squeeze:", condition_image.shape)

    # Save the conditional image
    condition_img_rescaled = np.uint8(np.interp(condition_image.cpu().numpy()[0, 0], (-1, 1), (0, 255)))
    condition_file = os.path.join(condition_folder, f"condition_image{j//26}.png")
    plt.imsave(condition_file, condition_img_rescaled, cmap='gray')
    print(f"Conditional image saved to {condition_file}")

    # Create a grid for displaying images
    fig, axes = plt.subplots(2, 13, figsize=(20, 8))

    for class_index in range(num_classes):
        # Create the one-hot vector for the class
        z_c = torch.zeros(1, num_classes, device=device)
        z_c[0, class_index] = 1  # Set the desired class

        # Concatenate the style and class vectors
        z = torch.cat((z_s, z_c), dim=1)

        # Generate the image
        with torch.no_grad():
            generated_img = generator(z, condition_image).cpu().numpy()[0, 0]  # Extract the first batch and first channel

        # Rescale the image from [-1, 1] to [0, 255]
        generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))

        # Plot the image in the appropriate subplot
        row = class_index // 13
        col = class_index % 13
        ax = axes[row, col]
        ax.imshow(generated_img, cmap='gray')
        ax.axis("off")
        ax.set_title(chr(65 + class_index), fontsize=12)  # Display A-Z above images

    # Adjust layout and save the figure as a single image
    plt.tight_layout()
    plt.savefig(output_file, bbox_inches='tight')
    plt.close(fig)

    print(f"Grid image saved to {output_file}")



# Mount Google Drive
#drive.mount('/content/drive')

# Define output folders
output_folder = "/content/gdrive/My Drive/cGAN_final_fonts"
condition_folder = "/content/gdrive/My Drive/cGAN_final_conditions"
os.makedirs(output_folder, exist_ok=True)
os.makedirs(condition_folder, exist_ok=True)

# Example conditional image (e.g., the "A" image)
for j, (real_imgs, condition_imgs, labels) in enumerate(display_data):
            real_imgs = real_imgs.cuda()
            condition_imgs = condition_imgs.cuda()
            labels = labels.cuda()
            #batch_size = labels.size(0)
            real_imgs = real_imgs.permute(0, 2, 3, 1).cuda()
            condition_imgs = condition_imgs.permute(0, 2, 3, 1).cuda()
            #condition_imgs = condition_imgs.cuda()
            #plot_condition_and_real_images(condition_imgs,real_imgs, batch_index=j)


            # Loop to generate and save 10 grids
            #num_classes = 26  # A-Z
            # Generate a new random style vector for each grid
            #z_dim = 100  # Adjust based on your generator's style vector dimension
            z_s = torch.randn(1, 100, device='cuda')

            fig, axes = plt.subplots(2, 13, figsize=(20, 8))
            #axes = np.expand_dims(axes, axis=1)  # Ensure axes is 2D

            for class_index in range(num_classes):
                # Create the one-hot vector for the class
                z_c = torch.zeros(1, num_classes, device=device)
                z_c[0, class_index] = 1  # Set the desired class

                # Concatenate the style and class vectors
                z = torch.cat((z_s, z_c), dim=1)

                # Generate the image
                with torch.no_grad():
                    generated_img = generator(z, condition_imgs).cpu().numpy()[0, 0]  # Extract the first batch and first channel

                # Rescale the image from [-1, 1] to [0, 255]
                generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))

                # Plot the image in the appropriate subplot
                row = class_index // 13
                col = class_index % 13
                ax = axes[row, col]
                ax.imshow(generated_img, cmap='gray')
                ax.axis("off")
                ax.set_title(chr(65 + class_index), fontsize=12)  # Display A-Z above images

            # Adjust layout and save the figure as a single image
            plt.tight_layout()
            #plt.savefig(output_file, bbox_inches='tight')
            #plt.close(fig)

            # Define the output file paths
            saveGridOfGeneratedImages(generator=generator,z_s=z_s,condition_image=condition_imgs,num_classes=26,output_file=os.path.join(output_folder, f"grid_{j%26,j//26}.png"),condition_folder=condition_folder,j=j)
print("All 10 grids and conditional images saved successfully!")
