## Load libraries

In [56]:
# Reload all imported modules (e.g. edited `src.*` code) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the cell that creates them.
%matplotlib inline

In [None]:
# Standard-library imports, one module per line (PEP 8).
import os
import sys
import re
import math
from datetime import datetime
import time

# Do not write .pyc files / __pycache__ directories when modules are imported
# from this kernel (e.g. the project's `src.*` modules).
sys.dont_write_bytecode = True

In [None]:
import pandas as pd
import joblib
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable, TypeVar

from pprint import pprint
from ipdb import set_trace as brpt

In [None]:
# Select the visible GPU *before* importing torch / pytorch_lightning.
# CUDA reads CUDA_VISIBLE_DEVICES only once, when the driver is first
# initialised; if any import below queries CUDA (e.g. via
# torch.cuda.is_available()), setting the variable afterwards is silently ignored.
# NOTE: this only takes effect on a fresh kernel (before torch is first imported).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs by PCI bus ID (the order nvidia-smi shows)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.tuner.tuning import Tuner

## Set Path 
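1. Add this notebook's folder and the project root to `sys.path` (the project root makes `src.*` importable); `SRC` points at the `src` folder.
2. Set `DATA_ROOT` to the `maptiles_v2` folder.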

In [None]:
# Resolve project locations relative to the directory the notebook runs from.
this_nb_path = Path.cwd()
ROOT = this_nb_path.parent
SRC = ROOT / 'src'
DATA_ROOT = Path("/data/hayley-old/maptiles_v2/")
paths2add = [this_nb_path, ROOT]

# Report the resolved locations.
for label, folder in (
    ("Project root: ", ROOT),
    ('Src folder: ', SRC),
    ("This nb path: ", this_nb_path),
):
    print(label, str(folder))

# Prepend each folder to the import path once; ROOT is inserted last, so it ends up first.
for entry in map(str, paths2add):
    if entry in sys.path:
        continue
    sys.path.insert(0, entry)
    print(f"\n{entry} added to the path.")

In [None]:
# Data transforms
from src.data.transforms.transforms import Identity, Unnormalizer, LinearRescaler
from src.data.transforms.functional import unnormalize

# Utils
from src.visualize.utils import show_timg, show_timgs, show_batch, make_grid_from_tensors
from src.utils.misc import info, get_next_version_path
from collections import OrderedDict

In [None]:
# DataModules
from src.data.datamodules import MNISTDataModule, MNISTMDataModule, MonoMNISTDataModule
from src.data.datamodules import MultiMonoMNISTDataModule

# plModules
from src.models.plmodules.vanilla_vae import VanillaVAE
from src.models.plmodules.iwae import IWAE
from src.models.plmodules.bilatent_vae import BiVAE
from src.models.plmodules.three_fcs import ThreeFCs


# Evaluations
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.utilities.cloud_io import load as pl_load
from src.evaluator.qualitative import save_content_transfers, save_style_transfers, run_both_transfers

*(Excerpt from the dataset description; the start of the first sentence is missing.)* … is centred on the object, the scene background is removed and additional generative factors (shape and lighting) are held constant. Each generative factor is independently sampled from its respective uniform distribution: azimuth $z_0 \sim \mathcal{U}[0, 2\pi]$, elevation $z_1 \sim \mathcal{U}[0, \pi/2]$, red $z_2 \sim \mathcal{U}[0, 1]$, green $z_3 \sim \mathcal{U}[0, 1]$, blue $z_4 \sim \mathcal{U}[0, 1]$. We divide the images into training (160,000), validation (20,000) and test (20,000) sets before removing images which contain particular generative factor combinations to facilitate the evaluation of zero-shot performance (see Appendix B.2). This left 142,927, 17,854 and 17,854 images in the training, validation and test sets respectively.


---
$z = [z_0, z_1, z_2, z_3, z_4]$

- azimuth: $z_0 \sim \mathcal{U}[0, 2\pi]$
- elevation: $z_1 \sim \mathcal{U}[0, \pi/2]$
- red: $z_2 \sim \mathcal{U}[0, 1]$
- green: $z_3 \sim \mathcal{U}[0, 1]$
- blue: $z_4 \sim \mathcal{U}[0, 1]$

In [None]:
# Teapots dataset: the image/ground-truth archive plus the list of "gap" sample indices.
data_dir = Path('/data/hayley-old/Tenanbaum2000/data/Teapots')
data, gap_ids = (np.load(data_dir / fname) for fname in ('teapots.npz', 'gap_ids.npy'))

In [None]:
# Names of the arrays stored in the archive (iterating the mapping yields its keys).
list(data)

In [None]:
# Drop the samples whose indices are listed in `gap_ids` (presumably the held-out
# factor combinations described above; confirm against the data preparation).
# A single boolean keep-mask preserves the original order and keeps the array's
# dtype and trailing shape even if nothing survives. It also avoids the per-image
# linear scan that `i not in gap_ids` does on a numpy array.
all_imgs = data["images"]
keep_img = np.isin(np.arange(len(all_imgs)), gap_ids, invert=True)
imgs = all_imgs[keep_img]
print(len(imgs))

In [None]:
# Apply the same gap-index filter to the ground-truth factors so that gts[k]
# stays aligned with imgs[k]. A vectorised mask replaces the per-element
# linear scan of `i not in gap_ids`.
all_gts = data["gts"]
keep_gt = np.isin(np.arange(len(all_gts)), gap_ids, invert=True)
gts = all_gts[keep_gt]
# Alignment sanity check: exactly one ground-truth record per remaining image.
assert len(gts) == len(imgs), f"{len(gts)} ground truths vs {len(imgs)} images"

In [None]:
# Shape of the filtered ground-truth array.
np.shape(gts)

In [None]:
# Shape of the filtered image array.
np.shape(imgs)

In [None]:
# Inspect the first ground-truth record (presumably the factors [z0..z4] listed above; confirm).
gts[0]


In [None]:
# Spot-check random samples. Each figure's title shows that sample's index and
# ground-truth record, so an image/label mismatch is easy to spot.
N_PREVIEW = 10
PREVIEW_SEED = 0  # fixed seed -> the same samples are shown on every Restart & Run All
preview_rng = np.random.default_rng(PREVIEW_SEED)
# Sample without replacement so no image is shown twice; cap at the dataset size.
preview_ids = preview_rng.choice(len(imgs), size=min(N_PREVIEW, len(imgs)), replace=False)
for idx in preview_ids:
    fig, ax = plt.subplots()
    ax.imshow(imgs[idx])
    ax.set_title(f"#{idx}: {gts[idx]}")
    ax.axis('off')
    plt.show()
