# Linear Transformation 


This notebook creates a linear transformation between two latent spaces. It takes as input two models, and uses then cvxpy to optimize a matrix to transform model 1 into model 2. The Anchor points are selected purely at random in the latent space.

Expected results are:
- Representations in PCA of the two latent spaces
- Representation of the transformed latent space of model 1
- Output of a image inserted into model 1 and decoded by model 2 through latent stitching

In [1]:
# Import relevant libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from sklearn.decomposition import PCA
import os
from gekko import GEKKO
import cvxpy as cp
import numpy as np
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration
# Add here all the models that we want to use
# Naming convention is 'path' + number of model + 'modelname' + number of model

config = {
    'path1': "Polished/models/checkpoints/Pretrained_AE/MNIST/LightningAutoencoder_3.ckpt",
    'modelname1': 'Model 1',
    'path2' : 'Polished/models/checkpoints/Mario_VAE/MNIST/MNIST_VAE_3_10.pth',
    'modelname2': 'Model 2',
}

os.chdir('/Users/federicoferoggio/Documents/vs_code/latent-communication/')

from Polished.models.definitions.Pretrained_Autoencoder import LightningAutoencoder
from Polished.models.definitions.VariationalAE import VAE

model1 = LightningAutoencoder.load_from_checkpoint(checkpoint_path=config['path1'])

model2 = VAE(in_dim=784, dims=[256, 128, 64, 32], distribution_dim=16)
model2.load_state_dict(state_dict = torch.load(config['path2']))

<All keys matched successfully>

# Optimization Problem in the Linear Case 
Let $x^i,y^i \in \mathbb{R^n}$ for $i = 1,...,m$ and $A \in \mathbb{R}^{n \times n}$ we are looking for the optimal A, which solves the following optimization problem 
$$ min_A \sum_{i = 1}^n ||Ax^i - y^i||^2 $$
where we are using the euclidian norm when not otherwise stated.



## Load dataset

Here the dataset is loaded. For stability, the anchors are also defined here, so any change in that can be done here.

In [3]:

# Import Data 
from Polished.utils.dataloaders.DataLoaderMNIST_single import DataLoader_MNIST

# Transdormations
transformations = [transforms.ToTensor(), 
                                # Normalize between -1 and 1
                                transforms.Normalize((0.5,), (0.5,))
                                ]
# Load the data
data_loader = DataLoader_MNIST(128, transformations)

n_anchors = 500
images, _ = next(iter(data_loader.train_loader))
all_images = []
all_labels = []
for images, labels in data_loader.train_loader:
    all_images.append(images)
    all_labels.append(labels)
# Concatenate all the batches to form a single tensor for images and labels
all_images = torch.cat(all_images, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Distinct labels
labels = torch.unique(all_labels)
# Sample size per label
m_per_label = n_anchors // len(labels)
# Sample from each label
indices = []
for label in labels:
    indices_label = np.where(all_labels == label)[0]
    indices_label = np.random.choice(indices_label, m_per_label, replace=False)
    indices.extend(indices_label)

all_images_sample = all_images[indices]
all_labels_sample = all_labels[indices]

print(all_images_sample.shape)

# Get latent space 
z1 = model1.get_latent_space(all_images_sample)
all_images_sample_view_for_vae = all_images_sample.view(-1, 784)
z2 = model2.get_latent_space(all_images_sample_view_for_vae)

print(z1.T.shape)
print(z2.T.shape)
# Detach from GPU
z1_values = z1.T.detach().numpy()
z2_values = z2.T.detach().numpy()

torch.Size([500, 1, 28, 28])
torch.Size([500, 500])
torch.Size([32, 500])


#### Notes

For the optimization, it is assumed that the first model has a bigger latent space than the second one.

In [8]:
lamda = 0.01

latent_size_1, n_anchors = z1_values.shape
latent_size_2, _ = z2_values.shape

A = cp.Variable((latent_size_2, latent_size_1))

loss = cp.sum((cp.vstack([cp.sum(A @ z1_values[:,i] - z2_values[:,i])**2 for i in range(n_anchors)]))) + lamda * cp.norm(A, 'fro')
objective = cp.Minimize(loss)
problem = cp.Problem(objective)

# Solve the problem
problem.solve(verbose=True, solver=cp.SCS, enforce_dpp = True)

# Print results
print("Optimal value: ", problem.value)
print(A.value)


                                     CVXPY                                     
                                     v1.5.1                                    
(CVXPY) May 30 12:41:56 PM: Your problem has 16000 variables, 0 constraints, and 0 parameters.
(CVXPY) May 30 12:41:56 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) May 30 12:41:56 PM: (If you need to solve this problem multiple times, but with different data, consider using parameters.)
(CVXPY) May 30 12:41:56 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
(CVXPY) May 30 12:41:56 PM: Your problem is compiled with the CPP canonicalization backend.
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
(CVXPY) May 30 12:41:56 PM: Compiling problem (target solver=SCS).
(C

In [11]:
from Polished.utils.plotting_fun.plot_pca_latent import Plot_pca_latent

plotting_fun1 = Plot_pca_latent(data_loader)
plotting_fun1.pca_def(vector = z1_values)
plotting_fun1.plot_latent_transformed(model1, np.eye(latent_size_1), 'Model 1')

plotting_fun2 = Plot_pca_latent(data_loader)
plotting_fun2.pca_def(z2_values)
plotting_fun2.plot_latent_transformed(model2, np.eye(latent_size_2), 'Model 2')

plotting_fun2.plot_latent_transformed(model1, A.values, 'Model 1 Transformed')

TypeError: plot_pca_latent.pca_def() got an unexpected keyword argument 'vector'

In [None]:
from Polished.utils.metrics.distances import Distances

distances = Distances(model1, model2)
print(distances.distance_latents(A.value @ z1_values, z2_values))