# Visualisation from the latent space

In [1]:
import numpy as np
import pandas as pd
import torch

In [2]:
import sys

# Put the project root first on the import path so the project's own
# packages (models, common, ...) can be imported from this notebook.
sys.path.insert(0, '/Users/andrzej/Personal/Projects/disentanglement-multi-task')

from models.ae import AEModel
from common.data_loader import get_dataloader

In [None]:
# Latent points from the inverse-UMAP "circle" file (per its file name).
# np.load keeps the on-disk dtype, so cast to float32 to match the model's
# (presumably default float32) parameters; a float64 array would make the
# decoder raise a dtype-mismatch error. .float() is a no-op for float32 data.
okragOli = torch.from_numpy(np.load('../inverse_circle_umap_n_neighbors_50_min_dist_0.1.npy')).float()
# Standardise each point along dim 1 (zero mean, unit unbiased std), the same
# per-row normalisation encode_deterministic applies to encoder outputs.
means = okragOli.mean(dim=1, keepdim=True)
stds = okragOli.std(dim=1, keepdim=True)
normalized_data = (okragOli - means) / stds

In [None]:
from architectures import encoders, decoders

# Look up the encoder and decoder classes by name in the project's
# architecture modules.
encoder_name = "SimpleConv64"
decoder_name = "SimpleConv64"
encoder, decoder = getattr(encoders, encoder_name), getattr(decoders, decoder_name)

# Build the autoencoder on the CPU (no optimizer is created here).
# NOTE(review): (8, 3, 64) presumably are latent size, channels and image
# size (cf. z_dim / image_size in the traversal cell) -- TODO confirm.
model = AEModel(
    encoder(8, 3, 64),
    decoder(8, 3, 64),
).to(torch.device('cpu'))

In [None]:
# Load the "3dshapes-10-multi" checkpoint, mapping every tensor onto the CPU.
checkpoint = torch.load(
    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',
    map_location=torch.device('cpu'),
)

In [None]:
# Restore the 'G' entry of the checkpoint's model states, then switch the
# model to inference mode (train(False) is what eval() does).
model.load_state_dict(checkpoint['model_states']['G'])
model.train(False)

In [None]:
# Decode the normalised latent points. Gradients are never needed for
# visualisation, so skip building an autograd graph.
with torch.no_grad():
    results = model.decoder(normalized_data)

In [None]:
import torchvision.utils
def visualize_recon(recon_image, file_name="test.png"):
    """Tile a batch of images into a single grid and save it as an image file.

    Args:
        recon_image: Image batch accepted by ``torchvision.utils.make_grid``.
        file_name: Output path; defaults to ``"test.png"`` in the working
            directory (the previously hard-coded location).
    """
    grid = torchvision.utils.make_grid(recon_image)
    torchvision.utils.save_image(grid, file_name)

In [None]:
# Save the grid of images decoded from the latent points.
visualize_recon(recon_image=results)

# Latent traversal visualisation

In [None]:
# (Re)load the "3dshapes-10-multi" weights onto the CPU and put the model in
# inference mode.
checkpoint = torch.load(
    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint['model_states']['G'])
model.train(False)

In [None]:
# Loader over the shapes3d multitask training split.
# NOTE(review): the positional 1 / 123 / 64 presumably are batch size, seed
# and image size -- TODO confirm against get_dataloader's signature.
train_loader = get_dataloader(
    'shapes3d_multitask',
    '/Users/andrzej/Personal/Projects/data/test_dsets',
    1,
    123,
    64,
    split="train",
    test_prec=0.125,
    val_prec=0.125,
    num_workers=1,
    pin_memory=True,
    n_task_headers=1,
)

In [None]:
import torchvision.utils
from common.utils import grid2gif, get_data_for_visualization, prepare_data_for_visualization
import os

# Traversal configuration shared by the helpers below.
z_dim = 8            # latent dimensions traversed; one grid row each
l_dim = 0
traverse_z = True
traverse_c = False
num_labels = 0
image_size = 64      # decoded tiles are image_size x image_size (see the .view below)
num_channels = train_loader.dataset.num_channels()


def set_z(z, latent_id, val):
    """Add ``val`` in place to column ``latent_id`` (dim 1) of the latent batch ``z``."""
    z[:, latent_id] += val


def encode_deterministic(**kwargs):
    """Encode ``kwargs['images']`` with the global ``model`` and standardise the result.

    A 3-D ``images`` tensor is treated as one image and gets a leading batch
    dimension. The encoding is shifted/scaled along dim 1 to zero mean and
    unit (unbiased) standard deviation. Other keywords (e.g. ``labels``) are
    accepted and ignored.
    """
    images = kwargs['images']
    if len(images.size()) == 3:
        images = images.unsqueeze(0)
    z = model.encode(images)
    means = z.mean(dim=1, keepdim=True)
    stds = z.std(dim=1, keepdim=True)
    normalized_data = (z - means) / stds
    return normalized_data


def decode_deterministic(**kwargs):
    """Decode ``kwargs['latent']`` with the global ``model``.

    A 1-D latent gets a leading batch dimension. Other keywords (e.g.
    ``labels``) are accepted and ignored.
    """
    latent = kwargs['latent']
    if len(latent.size()) == 1:
        latent = latent.unsqueeze(0)
    return model.decode(latent)


def _with_batch_dim(image):
    """Return ``image`` with a leading batch dim (3-D inputs are unsqueezed, others returned as is)."""
    return image.unsqueeze(0) if image.dim() == 3 else image


def _traverse_tiles(latent_orig, label_orig, orig_image, interp_values):
    """Build every tile of one traversal grid, in row-major order.

    The grid has one row repeating the input image, one row repeating its
    reconstruction, then one row per latent dimension ``zid``. In that row,
    column ``j`` decodes the latent with ``interp_values[j]`` added to ``zid``.

    Returns ``(tiles, traversal)``: all tiles, and only the traversal tiles
    (used later for the GIF frames).
    """
    num_cols = len(interp_values)
    recon = decode_deterministic(latent=latent_orig.detach(), labels=label_orig)
    tiles = [_with_batch_dim(orig_image)] * num_cols + [recon] * num_cols

    traversal = []
    for zid in range(z_dim):
        for val in interp_values:
            latent = latent_orig.clone()
            # Apply the offset exactly once so column j is shifted by
            # interp_values[j] (adding it inline as well would double it).
            set_z(latent, zid, val)
            traversal.append(decode_deterministic(latent=latent, labels=label_orig))
    return tiles + traversal, traversal


def _save_traverse_gif(key, frames, num_cols, total_rows):
    """Write one temp PNG per column for ``key``, merge them into a GIF, then delete the PNGs.

    ``frames[j]`` holds the ``total_rows`` traversal tiles of column ``j``;
    each frame is saved as a single row of tiles.
    """
    # Zero-pad the frame index so name order equals column order
    # (presumably grid2gif orders the globbed frames by name -- TODO confirm).
    digits = max(2, len(str(num_cols - 1)))
    frame_paths = [os.path.join('.', '{}_{}_{}.{}'.format('tmp', key, str(j).zfill(digits), 'png'))
                   for j in range(num_cols)]
    try:
        for j, frame_path in enumerate(frame_paths):
            torchvision.utils.save_image(tensor=frames[j].cpu(),
                                         fp=frame_path,
                                         nrow=total_rows, pad_value=1)
        gif_path = os.path.join('.', '{}_{}.{}'.format('traverse', key, 'gif'))
        grid2gif(str(os.path.join('.', '{}_{}_*.{}').format('tmp', key, 'png')),
                 gif_path, delay=10)
    finally:
        # Remove the temporary frames even if saving or GIF assembly failed.
        for frame_path in frame_paths:
            try:
                os.remove(frame_path)
            except FileNotFoundError:
                pass


def visualize_traverse(limit: tuple, spacing, data=None, test=False, data_to_visualisation=None):
    """Decode latent traversals for a batch of samples and save them as PNG grids and GIFs.

    Args:
        limit: ``(low, high)`` offsets added to each latent dimension. The
            offsets are ``torch.arange(low, high + spacing, spacing)``, so
            ``high`` is normally included (float rounding may add or drop an
            end point).
        spacing: Step between consecutive offsets.
        data, test: Unused; kept for call compatibility.
        data_to_visualisation: Passed to ``prepare_data_for_visualization``,
            which must return ``(images_by_key, labels_by_key)`` dicts.

    Side effects:
        Writes ``./traverse_<key>.png`` and ``./traverse_<key>.gif`` for every
        key. Temporary ``./tmp_<key>_NN.png`` frames are removed afterwards.

    Returns:
        The image grid for the last key (``None`` if there were no keys).
    """
    interp_values = torch.arange(limit[0], limit[1] + spacing, spacing)
    num_cols = interp_values.size(0)
    total_rows = num_labels * l_dim + \
                 z_dim * int(traverse_z) + \
                 num_labels * int(traverse_c)

    sample_images_dict, sample_labels_dict = prepare_data_for_visualization(data_to_visualisation)

    samples = None
    gifs = []
    # Pure inference: no autograd graph is needed for any of the decodes.
    with torch.no_grad():
        encodings = {
            key: encode_deterministic(images=sample_images_dict[key], labels=sample_labels_dict[key])
            for key in sample_images_dict
        }
        for key, latent_orig in encodings.items():
            label_orig = sample_labels_dict[key]
            print('latent_orig: {}, label_orig: {}'.format(latent_orig, label_orig))
            tiles, traversal = _traverse_tiles(latent_orig, label_orig,
                                               sample_images_dict[key], interp_values)
            gifs.extend(traversal)

            samples = torchvision.utils.make_grid(torch.cat(tiles, dim=0).cpu(), nrow=num_cols)
            file_name = os.path.join(".", '{}_{}.{}'.format("traverse", key, "png"))
            torchvision.utils.save_image(samples, file_name)

    if not gifs:
        # Nothing was encoded; torch.cat([]) below would raise.
        return samples

    # (keys, rows, cols, C, H, W) -> (keys, cols, rows, C, H, W): one frame per column.
    frames = torch.cat(gifs).view(len(encodings), total_rows, num_cols,
                                  num_channels, image_size, image_size).transpose(1, 2)
    for i, key in enumerate(encodings):
        _save_traverse_gif(key, frames[i], num_cols, total_rows)
    return samples

In [None]:
# A single batch from the training loader drives all traversal figures below.
_loader_iter = iter(train_loader)
data = next(_loader_iter)

In [None]:
checkpoint = torch.load('/Users/andrzej/Personal/results/3dshapes-10-multi/last', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_states']['G'])
model.eval()

min_ = -1
max_ = 1
spacing_ = 0.1
samples = visualize_traverse(limit=(min_,max_), spacing=spacing_, data_to_visualisation=data)

import matplotlib.pyplot as plt

npimg = samples.detach().numpy()
print(npimg.shape)

plt.figure(figsize=(50,40))
plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Put ticks at tile centres. make_grid (inside visualize_traverse) uses its
# default 2 px padding, so tiles repeat every image_size + 2 px; a fixed
# 70 px row step drifts away from the rows.
grid_pad = 2
grid_step = image_size + grid_pad
n_grid_rows = (npimg.shape[1] - grid_pad) // grid_step
n_grid_cols = (npimg.shape[2] - grid_pad) // grid_step
row_centres = grid_pad + grid_step * np.arange(n_grid_rows) + image_size / 2
col_centres = grid_pad + grid_step * np.arange(n_grid_cols) + image_size / 2
# Label the first, middle and last columns (the middle is the range midpoint
# when the column count is odd).
plt.xticks(col_centres[[0, n_grid_cols // 2, -1]],
           ['{:g}'.format(v) for v in (min_, (min_ + max_) / 2, max_)], fontsize=20)
plt.yticks(row_centres, ['input', 'recon'] + [str(i) for i in range(1, n_grid_rows - 1)],
           fontsize=20)

plt.savefig("travers-multi-10.png", bbox_inches='tight')

In [None]:
checkpoint = torch.load('/Users/andrzej/Personal/results/3dshapes-single-5/last', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_states']['G'])
model.eval()

min_ = -1
max_ = 1
spacing_ = 0.1
samples = visualize_traverse(limit=(min_,max_), spacing=spacing_, data_to_visualisation=data)

import matplotlib.pyplot as plt
plt.figure(figsize=(50,40))

npimg = samples.detach().numpy()
print(npimg.shape)
plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Put ticks at tile centres. make_grid (inside visualize_traverse) uses its
# default 2 px padding, so tiles repeat every image_size + 2 px; a fixed
# 70 px row step drifts away from the rows.
grid_pad = 2
grid_step = image_size + grid_pad
n_grid_rows = (npimg.shape[1] - grid_pad) // grid_step
n_grid_cols = (npimg.shape[2] - grid_pad) // grid_step
row_centres = grid_pad + grid_step * np.arange(n_grid_rows) + image_size / 2
col_centres = grid_pad + grid_step * np.arange(n_grid_cols) + image_size / 2
# Label the first, middle and last columns (the middle is the range midpoint
# when the column count is odd).
plt.xticks(col_centres[[0, n_grid_cols // 2, -1]],
           ['{:g}'.format(v) for v in (min_, (min_ + max_) / 2, max_)], fontsize=20)
plt.yticks(row_centres, ['input', 'recon'] + [str(i) for i in range(1, n_grid_rows - 1)],
           fontsize=20)

plt.savefig("travers-single-5.png", bbox_inches='tight')

In [None]:
checkpoint = torch.load('/Users/andrzej/Personal/results/3dshapes-random/last', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_states']['G'])
model.eval()

min_ = -1
max_ = 1
spacing_ = 0.1
samples = visualize_traverse(limit=(min_,max_), spacing=spacing_, data_to_visualisation=data)

import matplotlib.pyplot as plt
plt.figure(figsize=(50,40))

npimg = samples.detach().numpy()
print(npimg.shape)
plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Put ticks at tile centres. make_grid (inside visualize_traverse) uses its
# default 2 px padding, so tiles repeat every image_size + 2 px; a fixed
# 70 px row step drifts away from the rows.
grid_pad = 2
grid_step = image_size + grid_pad
n_grid_rows = (npimg.shape[1] - grid_pad) // grid_step
n_grid_cols = (npimg.shape[2] - grid_pad) // grid_step
row_centres = grid_pad + grid_step * np.arange(n_grid_rows) + image_size / 2
col_centres = grid_pad + grid_step * np.arange(n_grid_cols) + image_size / 2
# Label the first, middle and last columns (the middle is the range midpoint
# when the column count is odd).
plt.xticks(col_centres[[0, n_grid_cols // 2, -1]],
           ['{:g}'.format(v) for v in (min_, (min_ + max_) / 2, max_)], fontsize=20)
plt.yticks(row_centres, ['input', 'recon'] + [str(i) for i in range(1, n_grid_rows - 1)],
           fontsize=20)

plt.savefig("travers-random.png", bbox_inches='tight')

# Plot the traversal grid with axis labels

In [None]:
import matplotlib.pyplot as plt

npimg = samples.detach().numpy()
print(npimg.shape)
plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Put ticks at tile centres. make_grid (inside visualize_traverse) uses its
# default 2 px padding, so tiles repeat every image_size + 2 px; a fixed
# 70 px row step drifts away from the rows. min_/max_ come from the last
# traversal cell that was run.
grid_pad = 2
grid_step = image_size + grid_pad
n_grid_rows = (npimg.shape[1] - grid_pad) // grid_step
n_grid_cols = (npimg.shape[2] - grid_pad) // grid_step
row_centres = grid_pad + grid_step * np.arange(n_grid_rows) + image_size / 2
col_centres = grid_pad + grid_step * np.arange(n_grid_cols) + image_size / 2
plt.xticks(col_centres[[0, n_grid_cols // 2, -1]],
           ['{:g}'.format(v) for v in (min_, (min_ + max_) / 2, max_)], fontsize=20)
plt.yticks(row_centres, ['input', 'recon'] + [str(i) for i in range(1, n_grid_rows - 1)],
           fontsize=20)

plt.savefig("test.png", bbox_inches='tight')

In [None]:
import torchvision.utils
def visualize_recon(input_image, recon_image, test=False, file_name='reconstruction_random.png'):
    """Save an input grid and a reconstruction grid side by side, split by a white bar.

    Args:
        input_image: Batch of input images (anything ``make_grid`` accepts).
        recon_image: Reconstructions; their grid must be as tall as the input
            grid for the horizontal concatenation to succeed.
        test: Unused; kept for call compatibility.
        file_name: Output path; defaults to the previously hard-coded
            ``'reconstruction_random.png'``.
    """
    input_grid = torchvision.utils.make_grid(input_image)
    recon_grid = torchvision.utils.make_grid(recon_image)

    # 10 px white separator. Create it with the grid's channel count, dtype and
    # device so torch.cat also works for GPU tensors and non-RGB grids (a fixed
    # 3-channel CPU tensor would not).
    white_line = torch.ones((input_grid.size(0), input_grid.size(1), 10),
                            dtype=input_grid.dtype, device=input_grid.device)

    samples = torch.cat([input_grid, white_line, recon_grid], dim=2)
    torchvision.utils.save_image(samples, file_name)

In [4]:
import torchvision.utils
def visualize_recon(recon_image, file_name="input.png"):
    """Tile a batch of images into a single grid and save it as an image file.

    Args:
        recon_image: Image batch accepted by ``torchvision.utils.make_grid``.
        file_name: Output path; defaults to ``"input.png"`` in the working
            directory (the previously hard-coded location).
    """
    grid = torchvision.utils.make_grid(recon_image)
    torchvision.utils.save_image(grid, file_name)

In [3]:
train_loader = get_dataloader('shapes3d_multitask', '/Users/andrzej/Personal/Projects/data/test_dsets', 64,
                              123, 64, split="train", test_prec=0.125, val_prec=0.125,
                              num_workers=1, pin_memory=True,n_task_headers=1)

k = iter(train_loader)
# Use the builtin next(): DataLoader iterators are Python 3 iterators, and the
# legacy ``.next()`` method is not available in recent PyTorch releases.
# Element 0 of the loaded item is kept as ``batch``.
batch = next(k)[0]


In [5]:
# Save the raw input batch as an image grid.
visualize_recon(recon_image=batch)

In [None]:
# Reload the "3dshapes-10-multi" weights onto the CPU and put the model in
# inference mode.
checkpoint = torch.load(
    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint['model_states']['G'])
model.train(False)

In [None]:
# Index order used to reorder dim 1 of ``batch`` below (swaps entries 1 and 2;
# presumably the colour channels -- TODO confirm).
permute = [0, 2, 1]

In [None]:
# Move dim 1 to the end (presumably NCHW -> NHWC for display -- TODO confirm).
batch.permute(0, 2, 3, 1)

In [None]:
# Reorder dim 1 of the batch according to ``permute``; trailing dims are untouched.
batch[:, permute, ...]