<a href="https://colab.research.google.com/github/dwarfy35/deep_learning2/blob/main/CGAN_mode_collapse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.colab import drive
from torch.autograd import grad
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset
from torch.autograd import grad
from collections import defaultdict

drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [26]:
ngf = 32
nc = 1
ndf = 32

number of fonts is 14990

In [27]:
class FontStyleEncoder(nn.Module):
    def __init__(self, img_size, style_dim):
        super(FontStyleEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * (img_size // 4) ** 2, style_dim)  # Global style features
        )

    def forward(self, x):
        return self.encoder(x)  # Returns a feature vector representing the font style


In [28]:
class Generator3(nn.Module):
    def __init__(self, latent_dim, style_dim=128, img_size=32):
        """
        Args:
            latent_dim: Dimension of the latent vector (z).
            style_dim: Dimension of the font style vector extracted by FontStyleEncoder.
            img_size: Spatial size of the input image (assumed square).
        """
        super(Generator3, self).__init__()

        self.init_size = img_size // 16  # Initial spatial size after FC layer
        self.fc = nn.Linear(latent_dim + style_dim, 512 * self.init_size * self.init_size)

        # Deconvolutional blocks
        self.deconv_blocks = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

        # FontStyleEncoder to extract font style features
        self.font_style_encoder = FontStyleEncoder(img_size, style_dim)

    def forward(self, z, condition):
        """
        Forward pass for the generator.

        Args:
            z: Latent vector (batch_size, latent_dim).
            condition: Conditional image (batch_size, 1, img_size, img_size).

        Returns:
            Generated image (batch_size, 1, img_size, img_size).
        """
        # Extract font style features from the condition image
        font_style_features = self.font_style_encoder(condition)

        # Concatenate latent vector with font style features
        print(f"Latent vector shape: {z.shape}")
        print(f"Condition features shape: {font_style_features.shape}")
        z = torch.cat((z, font_style_features), dim=1)
        print(f"Combined latent vector shape: {z.shape}")
        # Fully connected layer
        out = self.fc(z)
        out = out.view(out.size(0), 512, self.init_size, self.init_size)

        # Deconvolutional blocks
        img = self.deconv_blocks(out)
        return img


class Discriminator3(nn.Module):
    def __init__(self, nc=1, ndf=32):

        super(Discriminator3, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),  # Input channels = 1 (grayscale)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            #nn.Sigmoid(),
        )

    def forward(self, img):
        validity = self.main(img)
        return validity



In [29]:
class NPZDataset(Dataset):
    def __init__(self, npz_file, font_range=None, transform=None, filter_label=None, num_samples=None):
        """
        Initialize the dataset with optional filtering by font range.

        Args:
            npz_file (str): Path to the .npz file containing images and labels.
            font_range (tuple, optional): Range of fonts to include (start, end). If None, include all fonts.
            transform (callable, optional): Transformations to apply to the images.
            filter_label (int, optional): Filter dataset by a specific label.
            num_samples (int, optional): Limit the number of samples in the dataset.
        """
        # Load the data from the .npz file
        data = np.load(npz_file)
        self.images = data['images']
        self.labels = data['labels']
        self.transform = transform

        font_size = 26  # Number of letters per font (A-Z)

        # Limit the dataset to the specified range of fonts
        if font_range is not None:
            start_font, end_font = font_range
            start_idx = start_font * font_size
            end_idx = end_font * font_size  # Exclusive of the last font
            self.images = self.images[start_idx:end_idx]
            self.labels = self.labels[start_idx:end_idx]

        # Filter by label if specified
        if filter_label is not None:
            # Find indices of the desired label
            label_indices = np.where(self.labels == filter_label)[0]

            # If num_samples is specified, limit the number of samples
            if num_samples is not None:
                label_indices = label_indices[:num_samples]

            # Filter images and labels
            self.images = self.images[label_indices]
            self.labels = self.labels[label_indices]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Retrieve an item from the dataset.

        Args:
            idx (int): Index of the item.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Image, condition image, and label.
        """
        # Get the image and label for the given index
        image = self.images[idx]
        label = self.labels[idx]

        # Determine the font ID based on dataset ordering
        font_id = idx // 26  # Calculate the font ID based on position in dataset

        # Get the "A" from the same font as the condition
        condition_index = font_id * 26  # First letter (A) of the current font
        condition_image = self.images[condition_index]

        # Reshape the images
        image = image[np.newaxis, ...]  # Add channel dimension
        condition_image = condition_image[np.newaxis, ...]

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)
            condition_image = self.transform(condition_image)

        # Convert to PyTorch tensors
        image = torch.tensor(image, dtype=torch.float32)
        condition_image = torch.tensor(condition_image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return image, condition_image, label


In [30]:
def number_to_alphabet(num):
    if 0 <= num <= 25:
        return chr(num + 65)  # 65 is the ASCII code for 'A'
    else:
        raise ValueError("Number must be between 0 and 25 inclusive.")


In [31]:
def displayGeneratedImage(class_index, generator, z_s, num_classes, device='cuda'):
    """
    Generates and displays an image for a given class using the generator.

    Args:
        class_index (int): Index of the character class to generate (0 to num_classes - 1).
        generator (nn.Module): Pre-trained generator model.
        z_dim (int): Dimension of the style vector.
        num_classes (int): Number of character classes.
        device (str): Device for computation ('cuda' or 'cpu').
    """
    # Ensure the class index is valid
    if not (0 <= class_index < num_classes):
        raise ValueError(f"Invalid class_index: {class_index}. Must be in range [0, {num_classes - 1}].")

    # Create the one-hot vector for the class
    z_c = torch.zeros(1, num_classes, device=device)
    z_c[0, class_index] = 1  # Set the desired class

    # Create the random style vector
    #z_s = torch.randn(1, z_dim, device=device)

    # Concatenate the style and class vectors
    z = torch.cat((z_s, z_c), dim=1)

    # Generate the image
    with torch.no_grad():
        generated_img = generator(z).cpu().numpy()[0, 0]  # Extract the first batch and first channel

    # Rescale the image from [-1, 1] to [0, 255]
    #generated_img = np.rot90(generated_img, k=-1)
    generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))

    # Display the image
    plt.figure(figsize=(5, 5))
    plt.imshow(generated_img, cmap='gray')
    plt.axis("off")
    plt.title(f"Generated Image for Class {number_to_alphabet(class_index)}")
    plt.show()


In [32]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

#dataset = NPZDataset(npz_file, transform=transform)

In [33]:
import torch
from torch.autograd import grad

def compute_gradient_penalty(discriminator, real_samples, fake_samples, device="cuda"):
    """
    Compute the gradient penalty for WGAN-GP in a conditional GAN setup where
    the discriminator takes the conditional input separately.

    Args:
        discriminator: The discriminator model.
        real_samples: Batch of real images.
        fake_samples: Batch of fake images generated by the generator.
        condition_samples: Batch of conditional images (e.g., letter A from the same font).
        device: The device to run the computations on.

    Returns:
        Gradient penalty (scalar).
    """
    # Interpolate between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolated_samples = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    # Pass interpolated samples and condition samples separately to the discriminator
    d_interpolates = discriminator(interpolated_samples)  # Output shape: [batch_size, 1]

    # Flatten the discriminator output
    d_interpolates = d_interpolates.view(-1)  # Shape: [batch_size]

    # Create grad_outputs for the gradient computation
    fake = torch.ones(d_interpolates.size(), device=device)

    # Compute gradients of the discriminator's output w.r.t. interpolated inputs
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolated_samples,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # Compute gradient penalty
    gradients = gradients.view(gradients.size(0), -1)  # Flatten gradients
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty


In [34]:
def split_dataset_no_shuffle(full_data, test_size=0.1, val_size=0.1):
    """
    Split the dataset into train, validation, and test sets without shuffling.

    Args:
        full_data (Dataset): Full dataset to be split.
        test_size (float): Proportion of the dataset to include in the test split.
        val_size (float): Proportion of the dataset to include in the validation split.

    Returns:
        train_data (Dataset): Training split.
        val_data (Dataset): Validation split.
        test_data (Dataset): Test split.
    """
    total_size = len(full_data)
    test_size_count = int(total_size * test_size)
    val_size_count = int(total_size * val_size)
    train_size_count = total_size - test_size_count - val_size_count

    # Perform the split without shuffling (sequential split)
    train_data = torch.utils.data.Subset(full_data, range(0, train_size_count))
    val_data = torch.utils.data.Subset(full_data, range(train_size_count, train_size_count + val_size_count))
    test_data = torch.utils.data.Subset(full_data, range(train_size_count + val_size_count, total_size))

    return train_data, val_data, test_data


# Example usage with your dataset (assuming `fullData` is already loaded as a PyTorch Dataset)


In [35]:
# Hyperparameters
z_dim = 100
num_classes = 26  # For uppercase alphabets
img_size = 32  # Assuming 32x32 images
batch_size = 1024
print(batch_size)
lr = 0.0002
lambda_gp = 10  # Gradient penalty weight
n_critic = 5  # Number of discriminator updates per generator update
epochs = 100
npz_file = "/content/gdrive/My Drive/character_font.npz"

fullData = NPZDataset(npz_file, transform=transform)
#train_data, val_data, test_data = split_dataset_no_shuffle(full_data=fullData, test_size=0.1, val_size=0.1)
train_data = NPZDataset(npz_file, font_range=(0,100),transform=transform)

# Print sizes to verify
#print(f"Train Data: {len(train_data)} samples")
#print(f"Validation Data: {len(val_data)} samples")
#print(f"Test Data: {len(test_data)} samples")
#print(len(fullData))

#testing_data = NPZDataset(npz_file, transform=transform, num_fonts=1)

1024


In [36]:
#testloader = DataLoader(test_data, batch_size=1, shuffle=False)

In [37]:
from collections import defaultdict


In [38]:
from collections import defaultdict
from torch.utils.data import DataLoader

# Step 1: Group the data by label
grouped_data = defaultdict(list)
for img, condition_img, label in train_data:  # Unpack condition_img as well
    grouped_data[label.item()].append((img, condition_img, label))

# Step 2: Prepare datasets for each label
grouped_datasets = {
    label: [(img, cond_img, lbl) for img, cond_img, lbl in imgs]
    for label, imgs in grouped_data.items()
}

# Step 3: Create dataloaders for each group
grouped_dataloaders = {
    letter: DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for letter, dataset in grouped_datasets.items()
}


  image = torch.tensor(image, dtype=torch.float32)
  condition_image = torch.tensor(condition_image, dtype=torch.float32)


In [39]:
def displayGeneratorOutput(generator, z_s, condition_image, class_index, num_classes, device='cuda'):
    """
    Displays the output of the generator for a single class.

    Args:
        generator (nn.Module): Pre-trained generator model.
        z_s (torch.Tensor): Style vector.
        condition_image (torch.Tensor): Conditional image (e.g., "A").
        class_index (int): Index of the class to generate (0-25 for A-Z).
        num_classes (int): Total number of character classes.
        device (str): Device for computation ('cuda' or 'cpu').
    """
    # Ensure the inputs are on the correct device
    z_s = z_s.to(device)
    condition_image = condition_image.to(device)

    # Create the one-hot vector for the class
    z_c = torch.zeros(batch_size, num_classes, device=device)
    z_c[0, class_index] = 1  # Set the desired class

    # Concatenate the style and class vectors
    z = torch.cat((z_s, z_c), dim=1)

    # Generate the image
    with torch.no_grad():
        generated_img = generator(z, condition_image.unsqueeze(0)).cpu().numpy()[0, 0]

    # Rescale the image from [-1, 1] to [0, 255]
    generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))
    condition_img = np.uint8(np.interp(condition_img, (-1, 1), (0, 255)))
    # Display the generated image
    plt.imshow(generated_img, cmap='gray')
    plt.axis("off")
    plt.title(f"Generated Image for Class {chr(65 + class_index)}")  # Convert index to letter
    plt.show()
    plt.imshow(condition_img, cmap='gray')
    plt.axis("off")
    plt.title(f"Condition")  # Convert index to letter
    plt.show()


In [40]:
import matplotlib.pyplot as plt
import numpy as np


def displayGeneratedSample(generator, z_s, condition_imgs, labels, num_classes, class_index=0, device='cuda'):
    """
    Displays one generated image from the generator.

    Args:
        generator (nn.Module): Pre-trained generator model.
        z_s (torch.Tensor): Style vector.
        condition_imgs (torch.Tensor): Conditional images (batch).
        labels (torch.Tensor): Labels for the batch.
        num_classes (int): Number of character classes.
        class_index (int): Index of the sample in the batch to display.
        device (str): Device for computation ('cuda' or 'cpu').
    """
    # Ensure the inputs are on the correct device
    z_s = z_s.to(device)
    condition_imgs = condition_imgs.to(device)
    labels = labels.to(device)

    # Create one-hot encoding for the batch
    batch_size = labels.size(0)
    z_c = torch.zeros(batch_size, num_classes, device=device)
    z_c[torch.arange(batch_size), labels] = 1

    # Concatenate style and class vectors
    z = torch.cat((z_s, z_c), dim=1)

    # Generate images
    with torch.no_grad():
        fake_imgs = generator(z, condition_imgs)

    # Select the specified class index
    generated_img = fake_imgs[class_index].cpu().numpy()[0]  # Extract the first channel

    # Rescale the image from [-1, 1] to [0, 255]
    generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))

    # Display the generated image
    plt.imshow(generated_img, cmap='gray')
    plt.axis("off")
    plt.title(f"Generated Image for Class {chr(65 + labels[class_index].item())}")  # Convert label to letter
    plt.show()


In [43]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
latent_dim = 100  # Dimension of the latent vector z
style_dim = 64  # Dimension of the font style vector
img_size = 32  # Image size
num_classes = 26  # Number of character classes (A-Z)
z_dim = 126
epochs = 1000
n_crtic = 5
zc_weight =  1
print(batch_size)

generator = Generator3(z_dim, style_dim, img_size).to(device)
discriminator = Discriminator3(nc=1, ndf=32).to(device)
encoder = FontStyleEncoder(img_size, style_dim).to(device)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

# Training Loop
for epoch in range(epochs):
    for letter in range(num_classes):
        dataloader = grouped_dataloaders[letter]
        for i, (real_imgs, condition_imgs, labels) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            condition_imgs = condition_imgs.to(device)
            print(condition_imgs.shape)

            labels = labels.to(device)
            batch_size = labels.size(0)

            real_imgs = real_imgs.permute(0, 2, 3, 1).cuda()

            condition_imgs = condition_imgs.permute(0, 2, 3, 1).cuda()



            # Extract font style features from conditional images
            #font_style_features = encoder(condition_imgs) this is now done in the generator

            # Train Discriminator
            z_s = torch.randn(batch_size, latent_dim).to(device)  # Latent vector
            z_c = torch.zeros(batch_size, num_classes).to(device)  # One-hot class vector
            z_c[torch.arange(batch_size), labels] = 1
            z = torch.cat((z_s, z_c * zc_weight), dim=1)  # Combine latent vector and one hot
            fake_imgs = generator(z, condition_imgs)  # Generate fake images


            real_validity = discriminator(real_imgs)
            fake_validity = discriminator(fake_imgs.detach())
            gradient_penalty = compute_gradient_penalty(
                discriminator,
                real_samples=real_imgs,
                fake_samples=fake_imgs,
                device=device
            )
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            if i % n_critic == 0:
                z_s = torch.randn(batch_size, latent_dim).to(device)
                z = torch.cat((z_s, z_c), dim=1)
                fake_imgs = generator(z, condition_imgs)
                fake_validity = discriminator(fake_imgs)
                g_loss = -torch.mean(fake_validity)

                optimizer_G.zero_grad()
                g_loss.backward()
                optimizer_G.step()

            #Display generated sample periodically
            if epoch % 10 == 0 and i == 0:
                displayGeneratedSample(generator, z_s, condition_imgs, labels, num_classes, class_index=0, device=device)

    # Print training progress
    print(f"Epoch [{epoch+1}/{epochs}], D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")

    # Save models periodically
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), "/content/gdrive/My Drive/generator.pth")
        torch.save(discriminator.state_dict(), "/content/gdrive/My Drive/discriminator.pth")


1024
torch.Size([100, 32, 1, 32])
torch.Size([100, 1, 32, 32])
condition shape: torch.Size([100, 1, 32, 32])
z shape: torch.Size([100, 126])
condition shape: torch.Size([102400, 1])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 100 but got size 102400 for tensor number 1 in the list.

In [None]:
display_data = DataLoader(testing_data, batch_size=1, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch




In [None]:
z_s = torch.randn(1, z_dim, device='cuda')
z_c = torch.zeros(1, num_classes, device='cuda')
z_c[0, 0] = 1  # Class index 0 (A)
condition_image = torch.randn(1, 1, 32, 32, device='cuda')  # Random condition
z = torch.cat((z_s, z_c), dim=1)
with torch.no_grad():
    output = generator(z, condition_image)
    print("Generated image range:", output.min().item(), output.max().item())
    plt.imshow(output.cpu().numpy()[0, 0], cmap='gray')
    plt.show()

In [None]:
def plot_condition_and_real_images(condition_imgs, real_imgs, batch_index):
    """
    Plot all condition images alongside their corresponding real images in the batch.

    Args:
        condition_imgs (torch.Tensor): Batch of condition images (B, C, H, W).
        real_imgs (torch.Tensor): Batch of real images (B, C, H, W).
        batch_index (int): Index of the current batch.
    """
    condition_imgs = condition_imgs.cpu().numpy()  # Convert to NumPy for plotting
    real_imgs = real_imgs.cpu().numpy()  # Convert to NumPy for plotting

    # Determine batch size
    batch_size = condition_imgs.shape[0]

    # Create a figure for all condition-real image pairs
    fig, axes = plt.subplots(2, batch_size, figsize=(15, 5))
    fig.suptitle(f"Condition and Real Images for Batch {batch_index}", fontsize=16)

    # Handle the case where batch_size == 1
    if batch_size == 1:
        axes = np.expand_dims(axes, axis=1)  # Ensure axes is 2D

    for i in range(batch_size):
        # Rescale the condition image from [-1, 1] to [0, 1]
        condition_img = (condition_imgs[i, 0] + 1) / 2  # Assuming grayscale
        real_img = (real_imgs[i, 0] + 1) / 2  # Rescale the real image as well

        # Plot condition image
        axes[0, i].imshow(condition_img, cmap='gray')
        axes[0, i].set_title(f"Condition Image {i}", fontsize=10)
        axes[0, i].axis('off')

        # Plot real image
        axes[1, i].imshow(real_img, cmap='gray')
        axes[1, i].set_title(f"Real Image {i}", fontsize=10)
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch


def saveGridOfGeneratedImages(generator, z_s, condition_image, num_classes, output_file, condition_folder,j, device='cuda'):
    """
    Generates a grid of images for all classes and saves it as a single image file.
    Also saves the condition image in a separate folder.
    """
    # Ensure condition_image has the correct dimensions
    if condition_image.ndim == 5:
        condition_image = condition_image.squeeze(0)  # Remove unnecessary batch dimension
    print("Condition Image Shape After Squeeze:", condition_image.shape)

    # Save the conditional image
    condition_img_rescaled = np.uint8(np.interp(condition_image.cpu().numpy()[0, 0], (-1, 1), (0, 255)))
    condition_file = os.path.join(condition_folder, f"condition_image{j//26}.png")
    plt.imsave(condition_file, condition_img_rescaled, cmap='gray')
    print(f"Conditional image saved to {condition_file}")

    # Create a grid for displaying images
    fig, axes = plt.subplots(2, 13, figsize=(20, 8))

    for class_index in range(num_classes):
        # Create the one-hot vector for the class
        z_c = torch.zeros(1, num_classes, device=device)
        z_c[0, class_index] = 1  # Set the desired class

        # Concatenate the style and class vectors
        z = torch.cat((z_s, z_c), dim=1)

        # Generate the image
        with torch.no_grad():
            generated_img = generator(z, condition_image).cpu().numpy()[0, 0]  # Extract the first batch and first channel

        # Rescale the image from [-1, 1] to [0, 255]
        generated_img = np.uint8(np.interp(generated_img, (-1, 1), (0, 255)))

        # Plot the image in the appropriate subplot
        row = class_index // 13
        col = class_index % 13
        ax = axes[row, col]
        ax.imshow(generated_img, cmap='gray')
        ax.axis("off")
        ax.set_title(chr(65 + class_index), fontsize=12)  # Display A-Z above images

    # Adjust layout and save the figure as a single image
    plt.tight_layout()
    plt.savefig(output_file, bbox_inches='tight')
    plt.close(fig)

    print(f"Grid image saved to {output_file}")



# Mount Google Drive
#drive.mount('/content/drive')

# Define output folders
output_folder = "/content/gdrive/My Drive/cGAN_overfit_working"
condition_folder = "/content/gdrive/My Drive/cGAN_conditions_working"
os.makedirs(output_folder, exist_ok=True)
os.makedirs(condition_folder, exist_ok=True)

# Example conditional image (e.g., the "A" image)
for j, (real_imgs, condition_imgs, labels) in enumerate(display_data):
            real_imgs = real_imgs.cuda()
            condition_imgs = condition_imgs.cuda()
            labels = labels.cuda()
            #batch_size = labels.size(0)
            real_imgs = real_imgs.permute(0, 2, 3, 1).cuda()
            condition_imgs = condition_imgs.permute(0, 2, 3, 1).cuda()
            #condition_imgs = condition_imgs.cuda()
            plot_condition_and_real_images(condition_imgs,real_imgs, batch_index=j)


            # Loop to generate and save 10 grids
            #num_classes = 26  # A-Z
            # Generate a new random style vector for each grid
            #z_dim = 100  # Adjust based on your generator's style vector dimension
            z_s = torch.randn(1, 100, device='cuda')

            # Define the output file paths
            saveGridOfGeneratedImages(generator=generator,z_s=z_s,condition_image=condition_imgs,num_classes=26,output_file=os.path.join(output_folder, f"grid_{j%26,j//26}.png"),condition_folder=condition_folder,j=j)
print("All 10 grids and conditional images saved successfully!")
