[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlexchange/als_ml_tutorial/blob/main/3_4_ML_tutorial_CNN.ipynb)

# 1. Pretrain an encoder to convert images to latent vectors
## Set up the environment and build the model

In [None]:
import torch

# Pick the compute device. The original setup used the second GPU ("cuda:1"); that ordinal does not
# exist on single-GPU hosts (e.g. Colab), so fall back to "cuda:0" there, and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print("Device:", device)

In [None]:
# Build the convolutional autoencoder and place it on the selected device.
from model_utils import cnnAutoencoder

IMAGE_SHAPE = (3, 64, 64)  # (channels, height, width) the model is built for
LATENT_DIM = 1000          # latent vector size passed to cnnAutoencoder

auto_cnn = cnnAutoencoder(input_shape=IMAGE_SHAPE, latent_dim=LATENT_DIM)
auto_cnn.to(device)
print(auto_cnn)

## Data Preparation



In [None]:
import model_utils
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

# Images used to pretrain the autoencoder (site-specific absolute path; adjust for your machine).
data_directory = '/lovelace/zhuowen/diffusers/als/discriminator_data/data/00000000'
input_size = 64
split_seed = 42  # fixed seed so the train/validation split (and every plot built on it) is reproducible

# Resize and centre-crop to input_size x input_size, convert to a tensor, then normalise each of the
# 3 channels with mean 0.5 / std 0.5.
data_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = model_utils.myDataset(data_directory, transform=data_transform)

# 80/20 train/validation split with a seeded generator
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(split_seed)
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Unshuffled: later cells iterate val_loader more than once and expect the same sample order each time.
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
print(len(train_loader), len(val_loader))

## Train the autoencoder (this pretrains the encoder)

In [None]:
from torch.nn import MSELoss
from torch.optim import AdamW
from tqdm import tqdm

# Training hyper-parameters (values tried previously: lr=0.005, num_epochs=20)
criterion = MSELoss()
optimizer = AdamW(auto_cnn.parameters(), lr=0.001)
num_epochs = 5
log_every = 1  # print a progress line every `log_every` epochs (and always after the last epoch)

epoch_loss = []      # mean training loss of each epoch
epoch_val_loss = []  # mean validation loss of each epoch
for epoch in range(num_epochs):
    # --- Training pass ---
    auto_cnn.train()
    train_loss = 0.0
    for input_batch, _ in tqdm(train_loader):
        input_batch = input_batch.to(device)      # moved once: it is both the input and the target
        optimizer.zero_grad()
        out_batch = auto_cnn(input_batch)
        loss = criterion(out_batch, input_batch)  # autoencoder objective: reconstruct the input
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)

    # --- Validation pass: eval mode so any dropout/batch-norm layers run in inference mode ---
    auto_cnn.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_batch, _ in tqdm(val_loader):
            val_batch = val_batch.to(device)
            val_loss += criterion(auto_cnn(val_batch), val_batch).item()
    avg_val_loss = val_loss / len(val_loader)

    # Record epoch means so the two curves in the loss plot are directly comparable
    epoch_loss.append(avg_train_loss)
    epoch_val_loss.append(avg_val_loss)
    if (epoch + 1) % log_every == 0 or epoch + 1 == num_epochs:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.6f}, Validation Loss: {avg_val_loss:.6f}')

# Save model
torch.save(auto_cnn.state_dict(), 'cnn_autoencoder.pth')

## Loss plot

In [None]:
import matplotlib.pyplot as plt

# Per-epoch training and validation MSE recorded by the training cell
fig, ax = plt.subplots()
ax.plot(epoch_loss, label="Loss")
ax.plot(epoch_val_loss, label="Validation Loss")
ax.set(xlabel="Epoch (0-based)", ylabel="MSE loss", title="Autoencoder training curves")
ax.legend()
ax.grid()
plt.show()

## Evaluate the autoencoder and visualize reconstructions

In [None]:
import numpy as np
import random
from tqdm import tqdm

# Reconstruct every validation image. Switch to eval mode first so any dropout/batch-norm layers
# behave as at inference time, regardless of the mode the previous cell left the model in.
auto_cnn.eval()
out_val = []
with torch.no_grad():
    for batch, _ in tqdm(val_loader):
        out_val.append(auto_cnn(batch.to(device)).cpu().numpy())
reconstructed_set = np.vstack(out_val)

In [None]:
# drop the last batch and labels
val_image_set = [v for v, l in list(val_loader)[:-1]]
val_set = np.vstack(val_image_set)

# # Randomly selectes 5 images from the validation set
indxs = random.sample(range(len(reconstructed_set)), 5)

fig, axs = plt.subplots(2, 5, figsize=(10,4))
for i in range(5):
    original_img = val_set[indxs[i]][0,:]
    reconstructed_img = reconstructed_set[indxs[i]][0,:]
    axs[0, i].imshow(np.squeeze(original_img))
    axs[0, i].axis('off')
    axs[1, i].imshow(np.squeeze(reconstructed_img))
    axs[1, i].axis('off')
plt.subplots_adjust(wspace=0.3, hspace=0.2)

# 2. Explore and visualize the latent space with UMAP

In [None]:
# Reload the trained weights (e.g. after a kernel restart) and switch to inference mode.
# map_location="cpu" lets a checkpoint saved on any GPU load on any machine; .to(device) then moves it.
import torch
auto_cnn = cnnAutoencoder(input_shape=(3, 64, 64), latent_dim=1000)
auto_cnn.load_state_dict(torch.load('cnn_autoencoder.pth', map_location='cpu'))
auto_cnn.eval()
auto_cnn.to(device)

# Encode every validation image into a latent vector (loader order is preserved)
f_vec = []
with torch.no_grad():
    for batch, _ in val_loader:
        f_vec.append(auto_cnn.encoder(batch.to(device)).cpu().numpy())
f_vec = np.vstack(f_vec)

f_vec.shape

## 2.1 Visualize the experimental dataset

In [None]:
import umap
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

def perform_umap(f_vec, n_neighbors=None, min_dist=None, random_state=None):
    """Project feature vectors to 2-D with UMAP.

    Args:
        f_vec: 2-D array-like of feature vectors, passed to ``UMAP.fit_transform``.
        n_neighbors: UMAP ``n_neighbors``; left at the library default when None.
        min_dist: UMAP ``min_dist``; left at the library default when None.
        random_state: seed forwarded to UMAP for a reproducible embedding (None = unseeded).

    Returns:
        Whatever ``umap.UMAP(...).fit_transform(f_vec)`` returns (the 2-D embedding).
    """
    # Forward only the hyper-parameters that were actually given, so specifying just one of
    # n_neighbors / min_dist no longer passes None to UMAP for the other.
    optional = {"n_neighbors": n_neighbors, "min_dist": min_dist}
    tuned = {name: value for name, value in optional.items() if value is not None}
    umap_model = umap.UMAP(n_components=2, random_state=random_state, **tuned)
    return umap_model.fit_transform(f_vec)

# Default legend labels and colours for `groups`, indexed by integer group id.
# NOTE(review): confirm this label order matches the class indices of the dataset that produced `groups`.
_GROUP_LABELS = ['exp', 'generated_unreal', 'generated_real', 'train_set']
_GROUP_COLORS = ['y', 'r', 'g', 'b']


def _reduction_titles(category):
    """Return ``(title, xlabel, ylabel)`` for a reduction name such as 'pca' or 'umap' (None -> generic)."""
    name = (category or '').lower()
    if 'pca' in name:
        return 'PCA Results', 'Principal Component 1', 'Principal Component 2'
    if 'umap' in name or 'u-map' in name:
        return 'U-Map Results', 'U-Map Dimension 1', 'U-Map Dimension 2'
    return 'Dimensionality Reduction Results', 'Reduced Dimension 1', 'Reduced Dimension 2'


def _first_channel(image):
    """Return a drawable 2-D image: channel 0 of a 3-D (assumed channel-first) array, else ``image``."""
    return image[0] if np.ndim(image) == 3 else image


def plot_reduction(reduced_data, original_data, groups=None, category=None, zoom=0.35,
                   savefig=False, outname=None, group_labels=None):
    """Plot a 2-D dimensionality reduction twice: as a scatter, and with image thumbnails.

    Args:
        reduced_data: array of shape (n_samples, >=2); columns 0 and 1 are plotted.
        original_data: images aligned with ``reduced_data`` (at least as many entries); 3-D entries
            are assumed channel-first and only channel 0 is drawn.
        groups: optional per-sample integer group ids used to colour the top panel.
        category: reduction name ('pca', 'umap' / 'u-map') that selects titles; None gives generic ones.
        zoom: thumbnail zoom factor passed to ``OffsetImage``.
        savefig: if True, save the figure to ``outname`` before showing it.
        outname: output path for ``plt.savefig``; required when ``savefig`` is True.
        group_labels: optional legend labels indexed by group id; defaults to ``_GROUP_LABELS``.

    Raises:
        ValueError: if ``savefig`` is True without ``outname``, or ``original_data`` has fewer
            entries than ``reduced_data``.
    """
    # Validate up front so a bad call does not leave a half-drawn figure behind
    if savefig and not outname:
        raise ValueError("outname must be given when savefig=True")
    if len(original_data) < len(reduced_data):
        raise ValueError(f"original_data has {len(original_data)} images but reduced_data "
                         f"has {len(reduced_data)} points")
    title, xtitle, ytitle = _reduction_titles(category)

    plt.figure(figsize=(10, 12))
    plt.rcParams.update({'font.size': 14})  # global setting: also affects figures drawn later

    # Top panel: 2-D embedding, optionally coloured by group
    plt.subplot(2, 1, 1)
    if groups is None or len(groups) == 0:
        plt.scatter(reduced_data[:, 0], reduced_data[:, 1], alpha=0.5)
    else:
        labels = _GROUP_LABELS if group_labels is None else list(group_labels)
        group_ids = [int(g) for g in groups]
        id_array = np.asarray(group_ids)
        # One scatter call per group so the legend gets one correctly coloured entry per group
        # (a single scatter with per-point colours produces only one legend handle).
        for gid in sorted(set(group_ids)):
            members = id_array == gid
            plt.scatter(reduced_data[members, 0], reduced_data[members, 1], alpha=0.5,
                        color=_GROUP_COLORS[gid % len(_GROUP_COLORS)],
                        label=labels[gid] if gid < len(labels) else f'group {gid}')
        plt.legend()
    plt.title(title)
    plt.xlabel(xtitle)
    plt.ylabel(ytitle)
    plt.grid(True)

    # Bottom panel: the same embedding with each point replaced by its image thumbnail
    plt.subplot(2, 1, 2)
    plt.scatter(reduced_data[:, 0], reduced_data[:, 1], marker='o', s=30, c='b')
    ax = plt.gca()
    for i in range(len(reduced_data)):
        thumbnail = OffsetImage(_first_channel(original_data[i]), zoom=zoom)
        ax.add_artist(AnnotationBbox(thumbnail, (reduced_data[i, 0], reduced_data[i, 1]), frameon=False))
    plt.title(title + ' with Original Images')
    plt.xlabel(xtitle)
    plt.ylabel(ytitle)
    plt.grid(True)

    plt.tight_layout()
    if savefig:
        plt.savefig(outname)
    plt.show()

In [None]:
# Embed the validation latent vectors in 2-D (fixed random_state for a reproducible layout)
umap_features = perform_umap(f_vec, n_neighbors=5, min_dist=0.5, random_state=42)

# Each embedded point must correspond to exactly one validation image for the thumbnail panel
assert len(umap_features) == len(val_set), f"{len(umap_features)} embeddings vs {len(val_set)} images"
plot_reduction(umap_features, np.squeeze(val_set), category='umap', zoom=0.4)

## 2.2 Visualize the latent vectors classified as realistic

In [None]:
dataset_real = model_utils.myDataset('/lovelace/zhuowen/diffusers/als/40k_generated/als_2400_labeled/real', transform=data_transform)

# Hold out 20% of the images; the seeded generator makes the plotted subset reproducible
train_size = int(0.8 * len(dataset_real))
val_size = len(dataset_real) - train_size
train_dataset_real, val_dataset_real = random_split(
    dataset_real, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)
val_loader_real = DataLoader(val_dataset_real, batch_size=16, shuffle=False)

# Encode the held-out images, keeping the image tensors for the thumbnail panel. Loader items are
# (images, labels) pairs, so stacking the raw loader output (as before) mixed images with labels.
auto_cnn.eval()
f_vec, real_images = [], []
with torch.no_grad():
    for batch, _ in val_loader_real:
        f_vec.append(auto_cnn.encoder(batch.to(device)).cpu().numpy())
        real_images.append(batch.numpy())
f_vec = np.vstack(f_vec)
val_set_real = np.vstack(real_images)
print(f_vec.shape)

umap_features = perform_umap(f_vec, n_neighbors=5, min_dist=0.5, random_state=42)
plot_reduction(umap_features, np.squeeze(val_set_real), category='umap', zoom=0.4)


## 2.3 Visualize the latent vectors classified as fake

In [None]:
dataset_fake = model_utils.myDataset('/lovelace/zhuowen/diffusers/als/40k_generated/als_2400_labeled/fake', transform=data_transform)

# Hold out 20% of the images; the seeded generator makes the plotted subset reproducible
train_size = int(0.8 * len(dataset_fake))
val_size = len(dataset_fake) - train_size
train_dataset_fake, val_dataset_fake = random_split(
    dataset_fake, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)
val_loader_fake = DataLoader(val_dataset_fake, batch_size=16, shuffle=False)

# Encode the held-out images. Like every other myDataset loader in this notebook, items are
# (images, labels) pairs: unpack them before .to(device), and keep the images for the thumbnails.
auto_cnn.eval()
f_vec, fake_images = [], []
with torch.no_grad():
    for batch, _ in val_loader_fake:
        f_vec.append(auto_cnn.encoder(batch.to(device)).cpu().numpy())
        fake_images.append(batch.numpy())
f_vec = np.vstack(f_vec)
val_set_fake = np.vstack(fake_images)
print(f_vec.shape)

umap_features = perform_umap(f_vec, n_neighbors=5, min_dist=0.5, random_state=42)
plot_reduction(umap_features, np.squeeze(val_set_fake), category='umap', zoom=0.4)

## 2.4 Visualize the latent vectors of experimental and generated images

In [None]:
import torchvision.datasets as dset

# ImageFolder assigns class indices from the sorted (alphabetical) sub-folder names
mix_datasets = dset.ImageFolder(root="/lovelace/zhuowen/diffusers/als/latent_vis", transform=data_transform)
val_loader_mix = DataLoader(mix_datasets, batch_size=16, shuffle=False)
# Show the index -> folder mapping so the legend labels used when plotting can be checked against it
print(mix_datasets.class_to_idx)

f_vec = []
groups = []       # integer class index of every image, in loader order
val_set_mix = []
auto_cnn.eval()
with torch.no_grad():
    for batch, gs in val_loader_mix:
        f_vec.append(auto_cnn.encoder(batch.to(device)).cpu().numpy())
        groups.extend(gs.tolist())  # plain Python ints rather than 0-d tensors
        val_set_mix.append(batch.numpy())

f_vec = np.vstack(f_vec)
val_set_mix = np.vstack(val_set_mix)

print(val_set_mix.shape)
print(f_vec.shape)
print(len(groups))


In [None]:
# 2-D UMAP of the mixed set (fixed random_state for a reproducible layout), coloured by class index
umap_features = perform_umap(f_vec, n_neighbors=5, min_dist=0.5, random_state=42)
plot_reduction(umap_features, np.squeeze(val_set_mix), groups=groups, category='umap', zoom=0.4,
               savefig=True, outname='latent_space_mix.png')