# Learning latent representations of geometry 

A deep generative model (a variational autoencoder) is used to infer scene depth structure, with the goal of identifying the brain regions that carry neural signals of geometry representations.


## environment setup

### conda  
- conda create -n deeplearning python=3.9
- conda activate deeplearning
- conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
- conda install numpy scipy matplotlib pandas
- conda install -c conda-forge visdom
- conda install -c anaconda scikit-image
- conda install -c conda-forge vit-pytorch

- https://github.com/lucidrains/vit-pytorch


### start visdom server
- python -m visdom.server -port 4000

## Vision Transformer Autoencoders

### import and load dataset

In [None]:
import torch
import torchvision
import torch
from torch import nn
from torchvision import transforms, models
from torch.autograd import Variable
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
import PIL.Image
import glob
import os
import numpy as np
from datetime import datetime
from matplotlib import pyplot as plt
import model_vitae as ViTAE
from visdom import Visdom
import utils
from scipy.io import savemat

# Seed every random number generator used in this notebook with the same value.
torch.manual_seed(1)
torch.cuda.manual_seed(1)
np.random.seed(1)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Selects which loss the training cell computes. Supported values:
# monotonic_annealing, cyclical_annealing, betavae_basic, betavae_capacity_control
loss_type = "monotonic_annealing"

# Hyperparameters (the consumers named below are in later cells).
EPOCH = 6                # multiplied by the sampler's batch_num to get capacity_num_iter
HIDDEN_DIM = 8           # passed to the model as dim_latent; values tried: 8, 16, 32, 96
BETA = 1.0               # passed to the model as beta
BATCH_SIZE = 96          # passed to the data sampler
CAPACITY_MAX = 2000.0    # capacity_max of the capacity-control loss
ORTH_FACTOR = 500.0      # passed to the model as orth_factor
LATENT_UNIT_RANGE = 2.5  # passed to the model as latent_unit_range

# Visdom plotting; the server must already be listening on port 4000.
viz = Visdom(port=4000)
loss_plt = utils.VisdomPlotter(viz=viz, env_name='main')

# Dataset location and sampling configuration.
root_dir = "/media/statespace/Spatial/sptialworkspace/spatialfMRI/simulation/carla_ws/recording/output_random_pose"

run_sample_size = 8000
run_num = 7
town_num = 8
sample_step = 1
data_sampler = utils.dataloader_rgb_depth(
    train_or_test="train",
    root_dir=root_dir,
    start_index=10,
    run_num=run_num,
    town_num=town_num,
    run_sample_size=run_sample_size,
    sample_step=sample_step,
    batch_size=BATCH_SIZE,
    device=device,
)


#### dataloader test

In [None]:
# Smoke test: pull every batch from the sampler once and print progress.
# NOTE(review): this call passes `batch_num=` and `batch_size=`, while later
# cells call `images_sample(batch_index=...)` — confirm which signature
# utils.dataloader_rgb_depth currently exposes.
batch_size = 32
num_images = int(run_num * run_sample_size / sample_step)
batch_num = num_images // batch_size

# Progress template: zero-padded "current/total".
batch_str = '{batch_index:03}/{batch_num:03}'

for batch_index in range(batch_num):
    batch_rgb_images, batch_depth_images = data_sampler.images_sample(
        batch_num=batch_index, batch_size=batch_size)

    # Report only every 20th batch.
    if batch_index % 20 != 0:
        continue
    print(batch_str.format(batch_index=batch_index, batch_num=batch_num))


In [None]:
# Show the first file path matched by the sampler's current RGB wildcard pattern.
rgb_matches = glob.glob(data_sampler.current_wildcard_rgb_name)
print("image string:", rgb_matches[0])

In [None]:
# Quick visual check of the first sample in the most recently loaded batch.
# The tensors are detached and moved to host memory before plotting:
# matplotlib converts its input through NumPy, which cannot read a CUDA
# tensor. On a tensor that is already on the CPU these calls change nothing.
print("batch_rgb_images shape:", batch_rgb_images.shape)
data = batch_rgb_images[0, :, :, :].detach().cpu()
plt.imshow(data[0])  # first channel of the RGB image only
plt.colorbar()
plt.show()

print("batch_depth_images shape:", batch_depth_images.shape)
data = batch_depth_images[0, :, :].detach().cpu()
plt.imshow(data)
plt.colorbar()
plt.show()


### define the model


In [None]:
# Build the RGB -> depth ViT autoencoder and move it to the selected device.
model = ViTAE.ViTAE_RGB2Depth(
    image_size=(320, 640),
    beta=BETA,
    orth_factor=ORTH_FACTOR,
    latent_unit_range=LATENT_UNIT_RANGE,
    patch_size=32,
    dim_latent=HIDDEN_DIM,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
).to(device)

# Optionally restore previously saved weights.
load_model_params = True
if load_model_params:
    checkpoint_path = 'results/vitea_params_losstype-monotonic_annealing_hdim-8_beta-1.0_EPOCH-6_batch_size-96_capacity_max-2000.0_epoch-6_Sep_1.pkl'
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU can still be loaded when `device` is the CPU.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Setting the optimiser
learning_rate = 1e-5
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

### train the model

In [None]:
# Run the selected loss on a fixed batch range, take one optimiser step per
# batch, record the loss terms and build preview grids.
# NOTE(review): the section heading says "train", the message below says
# "evaluation", and the model is put in eval() mode while optimizer.step()
# still updates the weights — confirm which behaviour is intended.
current_time = datetime.now()
print('Starting model evaluation...')

model.eval()
loss = []
recon_loss = []
kld_loss = []
total_kld = []

for epoch in np.arange(0, 1):
    for batch_index in range(4200, 4201):

        batch_rgb_images, batch_depth_images = data_sampler.images_sample(batch_index=batch_index)
        # batch_rgb_images = batch_rgb_images.to(device)
        # batch_depth_images = batch_depth_images.to(device)

        # Position on the annealing/capacity schedule, shared by every
        # schedule-based loss below.
        capacity_num_iter = EPOCH * data_sampler.batch_num
        capacity_stop_iter = epoch * data_sampler.batch_num + batch_index

        # forward
        if loss_type == "monotonic_annealing":
            beta_vae_loss = model.calc_monotonic_annealing_loss(
                rgb_images=batch_rgb_images,
                depth_images=batch_depth_images,
                capacity_num_iter=capacity_num_iter,
                capacity_stop_iter=capacity_stop_iter)
        elif loss_type == "cyclical_annealing":
            beta_vae_loss = model.calc_cyclical_annealing_loss(
                rgb_images=batch_rgb_images,
                depth_images=batch_depth_images,
                num_cycles=5,
                capacity_num_iter=capacity_num_iter,
                capacity_stop_iter=capacity_stop_iter)
        elif loss_type == "betavae_basic":
            beta_vae_loss = model.calc_beta_vae_basic_loss(
                rgb_images=batch_rgb_images,
                depth_images=batch_depth_images)
        elif loss_type == "betavae_capacity_control":
            beta_vae_loss = model.calc_beta_vae_capacity_control_loss(
                rgb_images=batch_rgb_images,
                depth_images=batch_depth_images,
                capacity_max=CAPACITY_MAX,
                capacity_num_iter=capacity_num_iter,
                capacity_stop_iter=capacity_stop_iter)
        else:
            # Fail loudly instead of reaching backward() with an undefined
            # (or stale, from a previous run of this cell) loss tensor.
            raise ValueError("unknown loss_type: {!r}".format(loss_type))

        # backward
        optimizer.zero_grad()
        beta_vae_loss.backward()
        optimizer.step()

        # Record the loss terms as NumPy values. The model exposes the
        # individual terms as attributes after the loss call; recon_loss is
        # stored without indexing, the others take element 0.
        loss.append(beta_vae_loss.detach().cpu().numpy()[0])
        recon_loss.append(model.recon_loss.detach().cpu().numpy())
        kld_loss.append(model.kld_loss.detach().cpu().numpy()[0])
        total_kld.append(model.total_kld.detach().cpu().numpy()[0])

        # loss_plt.plot(x=np.arange(0,len(loss)), y=loss, var_name="loss", split_name="loss", title_name="loss along time")

        # The statements below are inside the batch loop (the per-epoch `if`
        # that once guarded them is disabled), so they run once per batch.
        # One column per recorded term: loss, recon_loss, kld_loss, total_kld.
        loss_stack = np.column_stack((np.array(loss), np.array(recon_loss), np.array(kld_loss), np.array(total_kld)))
        # loss_plt.multiplot(x=np.arange(0,len(loss)), y=loss_stack, var_name='multiloss')

        # 3x3 preview grids of the first nine inputs, targets and reconstructions.
        rgb_images_viz = make_grid(batch_rgb_images[0:9].detach(), normalize=True, nrow=3)
        depth_images_viz = make_grid(torch.unsqueeze(batch_depth_images[0:9], 1).detach(), normalize=True, nrow=3)
        recon_images_viz = make_grid(torch.unsqueeze(model.depth_recon[0:9], 1).detach(), normalize=True, nrow=3)
        images = torch.stack([rgb_images_viz, depth_images_viz, recon_images_viz], dim=0).cpu()
        # loss_plt.rgb_depth_images(images=images, var_name="depth", split_name="depth", title_name="rbg recon depth")


### visualize rgb, depth and recon depth

In [None]:
# Show the RGB inputs, target depth maps and reconstructed depth maps of the
# batch currently held in memory (to inspect another batch, re-sample
# batch_rgb_images / batch_depth_images before running this cell).
preview = slice(0, 9)  # first nine samples -> one 3x3 grid per image kind

rgb_images_viz = make_grid(batch_rgb_images[preview].detach(), nrow=3, normalize=True)
depth_images_viz = make_grid(batch_depth_images[preview].unsqueeze(1).detach(), nrow=3, normalize=True)
recon_images_viz = make_grid(model.depth_recon[preview].unsqueeze(1).detach(), nrow=3, normalize=True)

# Stack the three grids into one batch and tile them into a single image.
images = torch.stack((rgb_images_viz, depth_images_viz, recon_images_viz), dim=0).cpu()
images_grids = make_grid(images, nrow=3, normalize=True)
print(images_grids.shape)

fig = plt.figure(figsize=(20, 8))
# Reorder from channel-first to channel-last for matplotlib.
plt.imshow(np.transpose(images_grids, (1, 2, 0)))


### metrics on training data

In [None]:
# Compare the model's latest depth reconstruction with the target depth maps
# of the batch currently held in memory.
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import explained_variance_score, mean_absolute_percentage_error

y_true = batch_depth_images.cpu().numpy()
y_pred = torch.squeeze(model.depth_recon).detach().cpu().numpy()

print("y_true.shape:", y_true.shape)
print("y_pred.shape:", y_pred.shape)

# Flatten each sample to one row. The row count is taken from the data rather
# than from BATCH_SIZE, so the metrics stay correct when the batch in memory
# has a different length (e.g. after the dataloader test cell, which samples
# batches of 32).
n_samples = y_true.shape[0]
y_true_flat = y_true.reshape(n_samples, -1)
y_pred_flat = y_pred.reshape(n_samples, -1)

print("mean_squared_error:", mean_squared_error(y_true_flat, y_pred_flat))
print("explained_variance_score:", explained_variance_score(y_true_flat, y_pred_flat))

# Latent means of the last forward pass, then raw target / predicted values.
plt.plot(model.mu.detach().cpu().numpy().flatten())
plt.show()
# plt.plot(model.logvar.detach().cpu().numpy().flatten())
# plt.show()
plt.plot(y_true.flatten())
plt.show()
plt.plot(y_pred.flatten())
plt.show()

# Correlation between samples after removing the per-pixel mean across the
# batch. np.corrcoef treats each row as a variable and stacks the two inputs,
# so the result is a (2 * n_samples) x (2 * n_samples) matrix.
temp_coef = np.corrcoef(
    y_true_flat - y_true_flat.mean(0),
    y_pred_flat - y_pred_flat.mean(0),
)
plt.imshow(temp_coef)
plt.colorbar()


### traverse hidden units

In [None]:
# Latent traversal: starting from an all-zero latent code, sweep one latent
# unit at a time over [-limit, limit] and decode each code into a depth map.
model.eval()

limit = 3.0
decoder = model.decoder
encoder = model.encoder
num_image_grid = 5  # sweep steps per latent unit == images per grid row
interpolation = np.linspace(-limit, limit, num=num_image_grid)

batch_index = 3500  # 3500, 90 for first night index
# NOTE: the batch is not re-sampled here; this cell works on the
# batch_rgb_images / batch_depth_images left in memory by an earlier cell.

fixed_idx = 30

# Pure inference: no_grad() avoids building an autograd graph, so the decoded
# tensors can be handed to NumPy / matplotlib without an explicit detach().
with torch.no_grad():
    fixed_img = torch.unsqueeze(batch_rgb_images[fixed_idx], dim=0).to(device=device)
    # Squash the first dim_latent encoder outputs into
    # (-latent_unit_range, latent_unit_range).
    fixed_img_z = (torch.sigmoid(encoder(fixed_img)[:, :model.dim_latent]) * 2.0 - 1.0) * model.latent_unit_range
    fixed_depth_recon = decoder(fixed_img_z)

    # The traversal baseline is the zero code, not the encoded image.
    z_ori = torch.zeros_like(fixed_img_z)

    samples = []
    for row in range(model.dim_latent):
        z = z_ori.clone()
        for val in interpolation:
            z[:, row] = val
            sample = torch.unsqueeze(decoder(z), dim=0)
            samples.append(sample)

    samples = torch.cat(samples, dim=0).cpu()

# One grid row per latent unit, one column per sweep value.
samples_grids = make_grid(samples, normalize=False, nrow=num_image_grid)
print(samples_grids.shape)

org_rgb_image = batch_rgb_images[fixed_idx].cpu().numpy()
org_depth_image = batch_depth_images[fixed_idx].cpu().numpy()

# Min-max rescale the RGB image to [0, 1] for display.
rgb_min = np.min(org_rgb_image)
rgb_max = np.max(org_rgb_image)
temp_orb_rgb_image = (org_rgb_image - rgb_min) / (rgb_max - rgb_min)

# Input image, target depth, and the depth decoded from the image's own code.
f, axarr = plt.subplots(1, 3, figsize=(30, 10))
axarr[0].imshow(np.transpose(temp_orb_rgb_image, (1, 2, 0)))
axarr[1].imshow(org_depth_image, cmap='jet', vmax=1, vmin=0)
axarr[2].imshow(np.squeeze(fixed_depth_recon.detach().cpu().numpy()), cmap='jet', vmax=1, vmin=0)
plt.show()

fig = plt.figure(figsize=(30, 90))
# Only the first channel of the grid is shown.
plt.imshow(samples_grids[0], cmap='jet')
plt.show()


In [None]:
# Correlation between the latent means produced by the last forward pass.
# presumably model.mu is (batch, dim_latent), so the transpose makes each
# latent unit one variable for torch.corrcoef — TODO confirm.
temp_mu = model.mu.detach().cpu()
print('temp_mu shape:', temp_mu.shape)

# Shift the correlations by +1 and keep only the strictly upper triangle.
shifted_corr = torch.corrcoef(temp_mu.T) + 1.0
triu_corr_mu = torch.triu(shifted_corr, diagonal=1)

# Sum of all retained (shifted) entries, kept as a one-element tensor.
triu_loss = triu_corr_mu.flatten().sum(dim=0, keepdim=True)
print(triu_loss)

# Undo the +1 shift for display: retained entries show the correlation,
# masked entries (diagonal and below) show -1.
fig = plt.figure(figsize=(10, 10))
plt.imshow(triu_corr_mu.numpy() - 1.0, interpolation='None', cmap='jet')
plt.title('corr_mu')
plt.colorbar()
plt.show()


#### test images traverse

In [None]:
# For a range of batches: encode the first image, decode it back to depth, run
# a latent traversal from the zero code, and plot input / target / recon.
model.eval()

limit = 3.0
decoder = model.decoder
encoder = model.encoder
num_steps = 10  # sweep steps per latent unit; also the grid row length
interpolation = np.linspace(-limit, limit, num=num_steps)

# town07 batch: 3500 -> 4080
for batch_index in np.arange(3500, 4080, 25):
    batch_rgb_images, batch_depth_images = data_sampler.images_sample(batch_index=batch_index)
    batch_rgb_images = batch_rgb_images.to(device)
    batch_depth_images = batch_depth_images.to(device)

    fixed_idx = 0

    # Pure inference: no_grad() keeps PyTorch from building (and holding on
    # to) an autograd graph for every decoded sample in every iteration.
    with torch.no_grad():
        fixed_img = torch.unsqueeze(batch_rgb_images[fixed_idx], dim=0).to(device=device)
        # Squash the first dim_latent encoder outputs into
        # (-latent_unit_range, latent_unit_range).
        fixed_img_z = (torch.sigmoid(encoder(fixed_img)[:, :model.dim_latent]) * 2.0 - 1.0) * model.latent_unit_range
        fixed_depth_recon = decoder(fixed_img_z)

        # The traversal baseline is the zero code, not the encoded image.
        z_ori = torch.zeros_like(fixed_img_z)

        samples = []
        for row in range(model.dim_latent):
            z = z_ori.clone()
            for val in interpolation:
                z[:, row] = val
                sample = torch.unsqueeze(decoder(z), dim=0)
                samples.append(sample)

        samples = torch.cat(samples, dim=0).cpu()

    # One grid row per latent unit. The grid is built but not displayed here.
    samples_grids = make_grid(samples, normalize=False, nrow=num_steps)

    org_rgb_image = batch_rgb_images[fixed_idx].cpu().numpy()
    org_depth_image = batch_depth_images[fixed_idx].cpu().numpy()

    # Min-max rescale the RGB image to [0, 1] for display.
    rgb_min = np.min(org_rgb_image)
    rgb_max = np.max(org_rgb_image)
    temp_orb_rgb_image = (org_rgb_image - rgb_min) / (rgb_max - rgb_min)

    f, axarr = plt.subplots(1, 3, figsize=(10, 10))
    axarr[0].imshow(np.transpose(temp_orb_rgb_image, (1, 2, 0)))
    axarr[1].imshow(org_depth_image, cmap='jet', vmax=1, vmin=0)
    axarr[2].imshow(np.squeeze(fixed_depth_recon.detach().cpu().numpy()), cmap='jet', vmax=1, vmin=0)
    plt.show()
