## Get Latent Codings from TR images using ViTAE

### Evaluate the ViTAE model

In [1]:
import torch
import torchvision
import torch
from torch import nn
from torchvision import transforms, models
from torch.autograd import Variable
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
import PIL.Image
import glob
import os
import numpy as np
from datetime import datetime
from matplotlib import pyplot as plt
import model_vitae as ViTAE
# from visdom import Visdom
import utils
from scipy.io import savemat

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# loss_type = "capacity_control" # basic, capacity_control

# Seed every random number generator used in this notebook with the same value.
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

# Hyperparameters (several of them are passed to the model constructor below).
EPOCH = 6
BATCH_SIZE = 96
HIDDEN_DIM = 8  # alternatives noted by the author: 8, 16, 32, 96
BETA = 1.0
ORTH_FACTOR = 500.0
CAPACITY_MAX = 2000.0  # 2000.0
LATENT_UNIT_RANGE = 2.5

# Plots
# viz = Visdom(port=4000)
# loss_plt = utils.VisdomPlotter(viz=viz, env_name='main')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the ViTAE RGB-to-depth network and move it to the selected device.
model = ViTAE.ViTAE_RGB2Depth(
    image_size=(320, 640),
    beta=BETA,
    orth_factor=ORTH_FACTOR,
    latent_unit_range=LATENT_UNIT_RANGE,
    patch_size=32,
    dim_latent=HIDDEN_DIM,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
).to(device)

load_model_params = True
if load_model_params:
    # Previously used checkpoint:
    # 'results/vitea_params_losstype-monotonic_annealing_hdim-8_beta-1.0_EPOCH-6_batch_size-96_capacity_max-2000.0_epoch-6.pkl'
    checkpoint_path = 'results/vitea_params_losstype-monotonic_annealing_hdim-8_beta-1.0_EPOCH-6_batch_size-96_capacity_max-2000.0_epoch-6_Sep_1.pkl'
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU machine can also be restored on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Setting the optimiser
# NOTE(review): no cell visible in this notebook steps the optimiser; it is
# kept in case later code relies on it.
learning_rate = 1e-5
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

# Switch to evaluation mode; as the last expression of the cell this also
# displays the module structure.
model.eval()

ViTAE_RGB2Depth(
  (encoder): ViTAE_Encoder(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1=32, p2=32)
      (1): Linear(in_features=3072, out_features=1024, bias=True)
    )
    (transformer): Transformer(
      (layers): ModuleList(
        (0): ModuleList(
          (0): Attention(
            (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
            (to_out): Linear(in_features=1024, out_features=1024, bias=False)
          )
          (1): FeedForward(
            (net): Sequential(
              (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (1): Linear(in_features=1024, out_features=2048, bias=True)
              (2): GELU()
              (3): Linear(in_features=2048, out_features=1024, bias=True)
            )
          )
        )
        (1): ModuleList(
          

### Transform RGB images into the latent space

In [3]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
from torch import nn
import PIL.Image
from scipy.io import savemat, loadmat
from datetime import datetime
# from visdom import Tensor, Visdom
import glob
import pickle
import random
import sys
from sklearn.decomposition import PCA

# Define data loading step
class image_transform_dataloader(Dataset):
    """Load recorded RGB frames of one town/weather run as normalized tensors.

    Frames are read from ``<root_dir>/<Town>_<Weather>/rgb_<index:06>.jpg``,
    resized to 320x640 and normalized channel-wise.
    """

    def __init__(self, root_dir, start_index, weather_num, town_num, run_sample_size, sample_step, device=None):
        """Store the sampling layout and build the image transforms.

        Args:
            root_dir: Directory that contains one sub-folder per town/weather run.
            start_index: File index of the first frame to read in every run.
            weather_num: Number of weather conditions (only stored and reported).
            town_num: Number of towns (only stored and reported).
            run_sample_size: Number of consecutive frames considered per run.
            sample_step: Number of consecutive frames grouped into one sample.
            device: Device for the transforms and the returned tensors.
        """
        self.root_dir = root_dir

        self.weather_num = weather_num
        self.town_num = town_num
        self.run_sample_size = run_sample_size
        self.sample_step = sample_step
        self.start_index = start_index
        self.device = device

        # Resize + per-channel normalization applied to every RGB frame.
        self.pil_to_tensor = torch.nn.Sequential(
            torchvision.transforms.Resize([320, 640]),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # These are RGB mean+std values
                std=[0.229, 0.224, 0.225])   # across a large photo dataset.
        ).to(self.device)

        # Resize-only transform; not used by the methods of this class.
        self.depth_to_tensor = torch.nn.Sequential(
            torchvision.transforms.Resize([320, 640]),
        ).to(self.device)

        # Town index -> town part of the recording folder name.
        self.towns_dict = {
            0: "Town01",
            1: "Town02",
            2: "Town03",
            3: "Town04",
            4: "Town05",
            5: "Town06",
            6: "Town07",
            7: "Town08",
            }

        # Weather index -> weather part of the recording folder name.
        self.weather_dict = {
            0: "ClearNight",
            1: "WetSunset",
            2: "WetCloudyNoon",
            3: "ClearNoon",
            4: "ClearSunset",
            5: "HardRainSunset",
            6: "FogMorning",
            }

        print("---Summary---")
        print("num of towns:", self.town_num)
        print("num of weather in each town:", self.weather_num)
        print("sample size of each town:", self.run_sample_size)
        print("sample step:", self.sample_step)

        # Number of samples (groups of `sample_step` frames) per run.
        self.num_images = int(self.run_sample_size/self.sample_step)
        print("num of total images per weather:", self.num_images)
        print("-----------")

    def _load_rgb_frame(self, folder, frame_index):
        """Read one frame from `folder` and return it resized and normalized.

        Raises:
            FileNotFoundError: If no file matches the expected frame path.
        """
        self.current_wildcard_rgb_name = os.path.join(
            self.root_dir, folder, 'rgb_{idx:06}.jpg'.format(idx=frame_index))

        # glob is kept so that a root_dir containing wildcards keeps working.
        matches = glob.glob(self.current_wildcard_rgb_name)
        if not matches:
            raise FileNotFoundError(
                "RGB frame not found: " + self.current_wildcard_rgb_name)

        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the `with` block ends.
        with PIL.Image.open(matches[0]) as image_file:
            rgb_image = image_file.convert('RGB')

        return self.pil_to_tensor(torchvision.transforms.functional.to_tensor(rgb_image))

    def rgb_images_sample(self, i_town, i_weather):
        """Load every sampled frame of one town/weather run.

        Args:
            i_town: Key into ``self.towns_dict``.
            i_weather: Key into ``self.weather_dict``.

        Returns:
            Tensor of shape ``(num_images, sample_step, 3, 320, 640)``; entry
            ``[index, i_step]`` holds the frame with file index
            ``start_index + index*sample_step + i_step``.

        Raises:
            KeyError: If `i_town` or `i_weather` is not a known index.
            FileNotFoundError: If one of the expected frame files is missing.
        """
        # The folder depends only on the town/weather pair, so resolve it once
        # (and before the large allocation below).
        self.rgb_image_folder_str = self.towns_dict[i_town] + '_' + self.weather_dict[i_weather]

        self.rgb_images = torch.zeros((self.num_images, self.sample_step, 3, 320, 640), device=self.device)

        for index in range(self.num_images):
            # View into the output tensor: frames are written in place instead
            # of being staged in a scratch buffer and copied afterwards.
            self.sample_step_images = self.rgb_images[index]

            for i_step in range(self.sample_step):
                index_in_town_sample = index*self.sample_step + i_step
                self.sample_step_images[i_step] = self._load_rgb_frame(
                    self.rgb_image_folder_str, index_in_town_sample + self.start_index)

        return self.rgb_images


In [4]:
## get weight of each component from TR images
# Sampling layout, as noted by the author:
#   - each run lasts 5*60 = 300 s, recorded at 50 fps, with TR = 1.5 s
#   - number of TRs per run: 300/1.5 = 200
#   - index of the first image: 60
#   - sample step (frames per TR): 1.5*50 = 75
#   - sample size for each town: 300*50 = 15000

# Other recording locations used previously:
# root_dir="/media/statespace/S/recording/depth"
# root_dir = "/home/statespace/Workspace/simulation/carla_ws/recording/output_synchronized"
root_dir = "/media/statespace/S/recording/output_synchronized"

sample_step = 75
run_sample_size = 15000
num_TR = 200
weather_num = 7
town_num = 8

# Number of samples per run.
num_images = run_sample_size // sample_step

data_sampler = image_transform_dataloader(
    root_dir=root_dir,
    start_index=60,
    weather_num=weather_num,
    town_num=town_num,
    run_sample_size=run_sample_size,
    sample_step=sample_step,
)


---Summary---
num of towns: 8
num of weather in each town: 7
sample size of each town: 15000
sample step: 75
num of total images per weather: 200
-----------


In [14]:
# Latest output captured by every registered forward hook, keyed by hook name.
activation = {}


def get_activation(name):
    """Return a forward hook that stores the module output under `name`."""

    def hook(module, inputs, result):
        # Keep a detached copy of the reference so no autograd graph is retained.
        activation[name] = result.detach()

    return hook
    
# Re-running this cell would otherwise stack a second set of hooks on the same
# modules, so remove the ones registered by a previous execution first.
for stale_handle in globals().get('hook_handles', []):
    stale_handle.remove()
hook_handles = []

# In the printed model structure, entry [1] of every encoder layer is the
# FeedForward block and net[3] is its final Linear layer.
for i_encoder_transformer_layer, encoder_layer in enumerate(model.encoder.transformer.layers):
    hook_name = 'encoder_transformer_layer_'+str(i_encoder_transformer_layer)
    hook_handles.append(
        encoder_layer[1].net[3].register_forward_hook(get_activation(hook_name)))

# NOTE(review): head_to_latent[4] is presumably the layer that emits the raw
# latent parameters — confirm against the model definition.
hook_name = 'latent_state_units'
hook_handles.append(
    model.encoder.head_to_latent[4].register_forward_hook(get_activation(hook_name)))

# The decoder layers are indexed the same way as the encoder layers
# (presumably they share that layout — confirm against the model definition).
for i_decoder_transformer_layer, decoder_layer in enumerate(model.decoder.transformer.layers):
    hook_name = 'decoder_transformer_layer_'+str(i_decoder_transformer_layer)
    hook_handles.append(
        decoder_layer[1].net[3].register_forward_hook(get_activation(hook_name)))

In [16]:
# Per-sample latent codes: (town, weather, images, latent_dims).
units_latent_space = np.zeros((town_num, weather_num, num_images, model.dim_latent))
# Per-sample layer activations: (town, weather, images, layers, patches).
# NOTE(review): 200 and 12 are assumed to match the patch count and the total
# number of hooked encoder + decoder layers — confirm against the model.
model_dim_patch = 200
num_layers = 12
encode_decode_space = np.zeros((town_num, weather_num, num_images, num_layers, model_dim_patch))

# Hook names in storage order: all encoder layers first, then all decoder layers.
layer_hook_names = (
    ['encoder_transformer_layer_'+str(i) for i in range(len(model.encoder.transformer.layers))]
    + ['decoder_transformer_layer_'+str(i) for i in range(len(model.decoder.transformer.layers))]
)

# Inference only: without no_grad every forward pass would build an autograd
# graph that is never used.
with torch.no_grad():
    for i_town in range(town_num):
        for i_weather in range(weather_num):

            print("town:",data_sampler.towns_dict[i_town],"weather:",data_sampler.weather_dict[i_weather])

            weather_images = data_sampler.rgb_images_sample(i_town=i_town, i_weather=i_weather)

            for i_image in range(num_images):
                # One batch holds the `sample_step` frames of a single sample.
                temp_weather_image = weather_images[i_image].to(device=device)

                # The forward pass fills `activation` through the registered hooks.
                output = model(temp_weather_image)

                # Average each layer output over axis 2 and then over the batch.
                # Assumes activations are (batch, patches, features) — TODO confirm.
                for i_layers, hook_name in enumerate(layer_hook_names):
                    layer_output = activation[hook_name].cpu().numpy()
                    encode_decode_space[i_town, i_weather, i_image, i_layers] = layer_output.mean(axis=2).mean(axis=0)

                # Squash the first `dim_latent` raw outputs into
                # (-latent_unit_range, latent_unit_range), then average over the batch.
                hook_name = 'latent_state_units'
                image_latent_units = (torch.sigmoid(activation[hook_name][:, :model.dim_latent])*2.0 - 1.0)*model.latent_unit_range
                units_latent_space[i_town, i_weather, i_image] = image_latent_units.cpu().numpy().mean(axis=0)


town: Town01 weather: ClearNight
town: Town01 weather: WetSunset
town: Town01 weather: WetCloudyNoon
town: Town01 weather: ClearNoon
town: Town01 weather: ClearSunset
town: Town01 weather: HardRainSunset
town: Town01 weather: FogMorning
town: Town02 weather: ClearNight
town: Town02 weather: WetSunset
town: Town02 weather: WetCloudyNoon
town: Town02 weather: ClearNoon
town: Town02 weather: ClearSunset
town: Town02 weather: HardRainSunset
town: Town02 weather: FogMorning
town: Town03 weather: ClearNight
town: Town03 weather: WetSunset
town: Town03 weather: WetCloudyNoon
town: Town03 weather: ClearNoon
town: Town03 weather: ClearSunset
town: Town03 weather: HardRainSunset
town: Town03 weather: FogMorning
town: Town04 weather: ClearNight
town: Town04 weather: WetSunset
town: Town04 weather: WetCloudyNoon
town: Town04 weather: ClearNoon
town: Town04 weather: ClearSunset
town: Town04 weather: HardRainSunset
town: Town04 weather: FogMorning
town: Town05 weather: ClearNight
town: Town05 weathe

In [None]:
# Per-sample latent codes: (town, weather, images, latent_dims).
units_latent_space = np.zeros((town_num, weather_num, num_images, model.dim_latent))

# Inference only: without no_grad every encoder pass would build an autograd
# graph that is never used.
with torch.no_grad():
    for i_town in range(town_num):
        for i_weather in range(weather_num):

            print("town:",data_sampler.towns_dict[i_town],"weather:",data_sampler.weather_dict[i_weather])

            weather_images = data_sampler.rgb_images_sample(i_town=i_town, i_weather=i_weather)

            for i_image in range(num_images):
                # One batch holds the `sample_step` frames of a single sample.
                temp_weather_image = weather_images[i_image].to(device=device)

                # Squash the first `dim_latent` encoder outputs into
                # (-latent_unit_range, latent_unit_range).
                image_latent_units = (torch.sigmoid(model.encoder(temp_weather_image)[:, :model.dim_latent])*2.0 - 1.0)*model.latent_unit_range

                # Store the batch average as the code of this sample.
                units_latent_space[i_town, i_weather, i_image] = image_latent_units.cpu().numpy().mean(axis=0)


In [None]:
from scipy.io import savemat

# Write the latent codes to a MATLAB file whose name carries the latent size.
save_dir = "/media/statespace/Spatial/sptialworkspace/spatialfMRI/simulation/carla_ws/vit/results/"
matdic = dict(units_latent_space=units_latent_space)
savemat(f"{save_dir}units_{HIDDEN_DIM}_latent_space.mat", matdic)


In [17]:
from scipy.io import savemat

# Write the latent codes together with the per-layer activations to one
# MATLAB file whose name carries the latent size.
save_dir = "/media/statespace/Spatial/sptialworkspace/spatialfMRI/simulation/carla_ws/vit/results/"
matdic = dict(
    units_latent_space=units_latent_space,
    encode_decode_space=encode_decode_space,
)
savemat(f"{save_dir}units_{HIDDEN_DIM}_latent_encode_decode_space.mat", matdic)


In [None]:
from scipy.io import loadmat

# Reload the latent codes and the per-layer activations saved earlier.
save_dir = "/media/statespace/Spatial/sptialworkspace/spatialfMRI/simulation/carla_ws/vit/results/"

matdic = loadmat(save_dir+"units_"+ str(HIDDEN_DIM) +"_latent_encode_decode_space.mat")

units_latent_space = matdic["units_latent_space"]
encode_decode_space = matdic["encode_decode_space"]
print("units_latent_space shape (town, weather, images, latent_dims) =", units_latent_space.shape)
# encode_decode_space is allocated as (town, weather, images, layers, patches),
# so its label lists those five axes rather than the latent ones.
print("encode_decode_space shape (town, weather, images, layers, patches) =", encode_decode_space.shape)

In [None]:
from scipy.io import loadmat

# Reload the latent codes from the file that matches the current latent size.
save_dir = "/media/statespace/Spatial/sptialworkspace/spatialfMRI/simulation/carla_ws/vit/results/"
matdic = loadmat(f"{save_dir}units_{HIDDEN_DIM}_latent_space.mat")
units_latent_space = matdic["units_latent_space"]

print("units_latent_space shape (town, weather, images, latent_dims) =", units_latent_space.shape)

##### variance for different weather

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Variance over the weather axis -> (town, images, latent_dims).
weather_var = np.var(units_latent_space, axis=1)
print(weather_var.shape)

# Average over the images -> (town, latent_dims).
weather_var = np.mean(weather_var, axis=1)
print(weather_var.shape)

# Transpose to (latent_dims, town) so that every town is drawn as one line.
weather_var = weather_var.T
print(weather_var.shape)

plt.plot(weather_var[:, 0:8])
plt.legend(np.arange(weather_var.shape[1]))

In [None]:
import numpy as np

# Average the latent codes over the weather axis -> (town, images, latent_dims).
temp_var = np.mean(units_latent_space, axis=1)
print(temp_var.shape)

# Variance of those weather-averaged codes over the images -> (town, latent_dims).
weather_var = np.var(temp_var, axis=1)
print(weather_var.shape)

# Plotting the transpose draws one line per town, so the legend needs one
# entry per town (axis 0), not one per latent unit (axis 1).
plt.plot(weather_var.T)
plt.legend(np.arange(weather_var.shape[0]))

##### variance for different weather with 8 units

In [None]:
from scipy.io import loadmat

# Select the 8-unit model explicitly before reloading its latent codes.
HIDDEN_DIM = 8

save_dir = "/media/statespace/Spatial/sptialworkspace/spatialfMRI/simulation/carla_ws/vit/results/"
matdic = loadmat(f"{save_dir}units_{HIDDEN_DIM}_latent_space.mat")
units_latent_space = matdic["units_latent_space"]

print("units_latent_space shape (town, weather, images, latent_dims) =", units_latent_space.shape)

##### variance along weather

In [None]:
# Import required packages
# NOTE(review): `font` is not used in this cell and the import needs Tk to be
# installed; it looks like an accidental auto-import — confirm before removing.
from tkinter import font
import matplotlib
# matplotlib.use('TkAgg')
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pylab import cm
import matplotlib.font_manager as fm

font_names = [f.name for f in fm.fontManager.ttflist]
# print(font_names)

mpl.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['font.size'] = 16
plt.rcParams['axes.linewidth'] = 2
fig = plt.figure(figsize=(16, 5))

print("units_latent_space shape (town, weather, images, latent_dims) =", units_latent_space.shape)

# Variance over the weather axis -> (town, images, latent_dims).
weather_var = np.var(units_latent_space, axis=1)
print(weather_var.shape)

# Average over the images -> (town, latent_dims).
weather_var = np.mean(weather_var, axis=1)
print(weather_var.shape)

# Transpose to (latent_dims, town): rows are latent units, columns are towns.
weather_var = np.einsum('ij->ji', weather_var)
num_latent_units, num_towns = weather_var.shape

# NOTE(review): the last label is "town10", while the loader's towns_dict maps
# index 7 to "Town08" — confirm which name is intended.
legent_list = ["town0"+str(i_legend+1) for i_legend in range(num_towns-1)]
legent_list.append("town10")

colors = plt.cm.jet(np.linspace(0, 1, num_towns))
for i in range(num_towns):
  plt.plot(weather_var[:,i], 'o-', color=colors[i], linewidth=3)
plt.legend(legent_list, loc='upper right', ncol=4)
# The x axis enumerates latent units, so the ticks follow the latent size
# rather than the number of towns.
plt.xticks(np.arange(num_latent_units))
# plt.yticks(np.arange(0, 0.024, 0.004))
plt.xlabel("latent units", fontsize=22)
plt.ylabel("variance along weather", fontsize=22)
plt.title("variance of latent units along weather", fontsize=24)
# Save before showing: once plt.show() returns the figure may already be
# closed, which would write empty files.
fig.savefig("results/latent_units_var_weather.eps", dpi=500)
fig.savefig("results/latent_units_var_weather.png", dpi=500)
plt.show()

##### variance along time sequences

In [None]:
# Import required packages
# NOTE(review): `font` is not used in this cell and the import needs Tk to be
# installed; it looks like an accidental auto-import — confirm before removing.
from tkinter import font
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pylab import cm
import matplotlib.font_manager as fm

font_names = [f.name for f in fm.fontManager.ttflist]
# print(font_names)

mpl.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['font.size'] = 16
plt.rcParams['axes.linewidth'] = 2
fig = plt.figure(figsize=(16, 5))

print("units_latent_space shape (town, weather, images, latent_dims) =", units_latent_space.shape)

# Variance over the image axis -> (town, weather, latent_dims).
weather_var = np.var(units_latent_space, axis=2)
print(weather_var.shape)

# Average over the weather conditions -> (town, latent_dims).
weather_var = np.mean(weather_var, axis=1)
print(weather_var.shape)

# Transpose to (latent_dims, town): rows are latent units, columns are towns.
weather_var = np.einsum('ij->ji', weather_var)
num_latent_units, num_towns = weather_var.shape

# NOTE(review): the last label is "town10", while the loader's towns_dict maps
# index 7 to "Town08" — confirm which name is intended.
legent_list = ["town0"+str(i_legend+1) for i_legend in range(num_towns-1)]
legent_list.append("town10")

colors = plt.cm.jet(np.linspace(0, 1, num_towns))
for i in range(num_towns):
  plt.plot(weather_var[:,i], 'o-', color=colors[i], linewidth=3)
plt.legend(legent_list, loc='upper right', ncol=4)
# The x axis enumerates latent units, so the ticks follow the latent size
# rather than the number of towns.
plt.xticks(np.arange(num_latent_units))
plt.yticks(np.arange(0, 1.6, 0.2))
plt.xlabel("latent units", fontsize=22)
# NOTE(review): the variance above is taken over the image axis, yet the labels
# and file names say "depth" — confirm the intended wording.
plt.ylabel("variance along depth", fontsize=22)
plt.title("variance of latent units along depth", fontsize=24)
# Save before showing: once plt.show() returns the figure may already be
# closed, which would write empty files.
fig.savefig("results/latent_units_var_depth.eps", dpi=500)
fig.savefig("results/latent_units_var_depth.png", dpi=500)
plt.show()