In [1]:
import argparse
import itertools

from src.pipeline import pipeline
from src.training_utils import training_utils

In [2]:
EXP_HPARAMS = {
    "params": (
        {},
    ),
    "seeds": (420,),
}
config = training_utils.get_config("MNIST")
for hparams_overwrite_list, seed in itertools.product(EXP_HPARAMS["params"], EXP_HPARAMS["seeds"]):
    hparams_str = ""
    for k, v in hparams_overwrite_list.items():
        config[k] = v
        hparams_str += str(k) + "-" + str(v) + "_"
    config["model_architecture"] = "bigbigan"
    config["hparams_str"] = hparams_str.strip("_")
    config["seed"] = seed
    #set batch size
    config["bs"] = 1


In [3]:
pip = pipeline.BigBiGANInference.from_checkpoint(checkpoint_path="./data/MNIST/bigbigan/checkpoints/checkpoint_40.pth", data_path="./data", config=config)


In [4]:
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

In [5]:
import torch
from torchvision import datasets, transforms
# Define a transform to normalize the data
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the training data
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

In [48]:
def plot_images(images, labels, title):
    fig, axs = plt.subplots(1, len(images), figsize=(15, 3))
    fig.suptitle(title, fontsize=16)
    for i, (img, label) in enumerate(zip(images, labels)):
        axs[i].imshow(np.transpose(img.numpy(), (1, 2, 0)).squeeze(), cmap='gray')
        axs[i].set_title(f'Label: {label.item()}')
        axs[i].axis('off')
    plt.show()

def get_dataloader(dataset, batch_size=1, class_label=0):
    filtered_indices = [i for i, (_, label) in enumerate(dataset) if label == class_label]
    filtered_dataset = torch.utils.data.Subset(train_dataset, filtered_indices)
    dataloader = torch.utils.data.DataLoader(filtered_dataset, batch_size=batch_size, shuffle=True)
    #Print the number of images in the dataloader
    print(f"Number of images in dataloader: {len(dataloader)}")
    #PLOT IMAGES
    return dataloader
    
def plot_batch(dataloader, title):
    for batch_images, batch_labels in dataloader:
        # Plot the images from the batch
        plot_images(batch_images, batch_labels, 'Sample Images from class: ')
        break  # Only plot the first batch for demonstration purposes

def encode_batch(pip, dataloader):
    print("Dataloader size: ", len(dataloader))
    encoded_images = []
    for batch_images, batch_labels in dataloader:
        #batch_images = batch_images.unsqueeze(0)
        batch_images = batch_images.to(config.device)
        z_img = pip.encode(batch_images)
        encoded_images.append(z_img.detach().cpu().numpy())
    #encoded_images = np.array(encoded_images)
    encoded_images = np.concatenate(encoded_images, axis=0)

    return encoded_images

In [54]:
dt_zero = get_dataloader(train_dataset, batch_size=2, class_label=0)
dt_nine = get_dataloader(train_dataset, batch_size=2, class_label=9)
dt_seven = get_dataloader(train_dataset, batch_size=2, class_label=7)

Number of images in dataloader: 2962
Number of images in dataloader: 2975
Number of images in dataloader: 3133


In [55]:
z_batchzero = encode_batch(pip, dt_zero)
print(z_batchzero.shape)
z_batchnine = encode_batch(pip, dt_nine)
print(z_batchnine.shape)
z_batchseven = encode_batch(pip, dt_seven)
print(z_batchseven.shape)



Dataloader size:  2962
(5923, 100)
Dataloader size:  2975
(5949, 100)
Dataloader size:  3133
(6265, 100)


In [58]:
# Combine the batches into a single array for each class
z_combined_zero = np.concatenate(z_batchzero, axis=0).reshape(-1, 1)
z_combined_nine = np.concatenate(z_batchnine, axis=0).reshape(-1, 1)
z_combined_seven = np.concatenate(z_batchseven, axis=0).reshape(-1, 1)

# Apply PCA separately for each class
pca_zero = PCA(n_components=1)
pca_nine = PCA(n_components=1)
pca_seven = PCA(n_components=1)

latent_zero_pca = pca_zero.fit_transform(z_combined_zero)
latent_nine_pca = pca_nine.fit_transform(z_combined_nine)
latent_seven_pca = pca_seven.fit_transform(z_combined_seven)

# Plot the results with labels
plt.scatter(latent_zero_pca[:,0], latent_zero_pca[:,1], label='Class 0', alpha=0.55)
plt.scatter(latent_nine_pca[:,0], latent_nine_pca[:,1], label='Class 9', alpha=0.55)
plt.scatter(latent_seven_pca[:,0], latent_seven_pca[:,1], label='Class 7', alpha=0.55)
plt.title("PCA on Latent Space")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.legend()
plt.show()


ValueError: n_components=2 must be between 0 and min(n_samples, n_features)=1 with svd_solver='full'