## Load libraries

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os,sys
import re
import math
from datetime import datetime
import time
sys.dont_write_bytecode = True  # don't write .pyc files for modules imported after this point

In [None]:
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
from skimage.color import rgb2gray
from skimage.transform import resize

from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable

from pprint import pprint
from ipdb import set_trace as brpt

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import loggers as pl_loggers

import ray
# Select Visible GPU: number GPUs by PCI bus id and expose only GPU 2 to this process.
# This only takes effect if CUDA has not been initialized yet in this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "2",
})

## Set Path 
1. Add the project root and this notebook's folder to `sys.path` (project modules are then imported as `src.*`)
2. Set `DATA_ROOT` to the `maptiles_v2` folder

In [None]:
# Resolve the notebook folder, the project root, the src folder and the maptiles data folder
this_nb_path = Path.cwd()
ROOT = this_nb_path.parent
SRC = ROOT/'src'
DATA_ROOT = Path("/data/hayley-old/maptiles_v2/")
# Only the notebook folder and the project root go on sys.path; project code is imported as `src.*`
paths2add = [this_nb_path, ROOT]

print("Project root: ", str(ROOT))
print('Src folder: ', str(SRC))
print("This nb path: ", str(this_nb_path))

for p in paths2add:
    p_str = str(p)
    if p_str in sys.path:
        continue
    # Prepend so project modules win over same-named installed packages
    sys.path.insert(0, p_str)
    print(f"\n{p_str} added to the path.")



In [None]:
# from src.data.datasets.maptiles import Maptiles, MapStyles
# from src.data.datamodules.mnist_datamodule import MNISTDataModule
# from src.data.datamodules.maptiles_datamodule import MaptilesDataModule
from src.data.datamodules.multisource_maptiles_datamodule import MultiMaptilesDataModule


# from src.models.plmodules.three_fcs import ThreeFCs
# from src.models.plmodules.vanilla_vae import VanillaVAE
# from src.models.plmodules.beta_vae import BetaVAE
from src.models.plmodules.bilatent_vae import BiVAE

from src.visualize.utils import show_timgs, show_npimgs
from src.utils.misc import info, get_next_version_path, n_iter_per_epoch


## Start experiment 
Given a maptile, predict its style as one of OSM, CartoVoyager

In [None]:
# # Instantiate MNIST Datamodule
# in_shape = (1,32,32)
# batch_size = 32
# dm = MNISTDataModule(data_root=ROOT/'data', 
#                        in_shape=in_shape,
#                       batch_size=batch_size)
# dm.setup('fit')
# print("DM: ", dm.name)

In [None]:
# Instantiate Multisource Maptiles DataModule
all_cities = ['la', 'charlotte', 'vegas', 'boston', 'paris',
              'amsterdam', 'shanghai', 'seoul', 'chicago', 'manhattan',
              'berlin', 'montreal', 'rome']

data_root = DATA_ROOT  # same maptiles_v2 folder configured in the "Set Path" section
cities = all_cities # ['berlin', 'rome', 'la', 'amsterdam', 'seoul'] #['paris']
# Other styles used in this notebook: 'StamenTonerBackground', 'OSMDefault', 'CartoVoyagerNoLabels', 'StamenTonerLines'
styles = ['StamenWatercolor']
zooms = ['14']
in_shape = (3, 64, 64)  # (n_channels, height, width); in_shape[0] == 1 is treated as grayscale below
batch_size = 32
print('cities: ', cities)
print('styles: ', styles)
dm = MultiMaptilesDataModule(
    data_root=data_root,
    cities=cities,
    styles=styles,
    zooms=zooms,
    in_shape=in_shape,
    batch_size=batch_size,
)
dm.setup('fit')

## Get country name and tile extent (in meters) from tile numbers and zoom (x, y, z)
- Updated: Mar 8, 2021

In [None]:
from src.utils.geo import getTileExtent, getCountryFromTile, getGeoFromTile

In [None]:
# Load and inspect one StamenWatercolor tile for Paris
fn = "/data/hayley-old/maptiles_v2/paris/StamenWatercolor/14/8301_5639_14.png"
# img = plt.imread(fn)
# NOTE(review): the file has a .png extension but is read with format='jpg';
# presumably the tile is JPEG-encoded despite its extension — confirm.
img = plt.imread(fn, format='jpg')

info(img)
plt.imshow(img)


In [None]:

fn = "/data/hayley-old/maptiles_v2/paris/StamenTonerBackground/14/8301_5638_14.png"
# NOTE(review): .png path read with format='jpg' — presumably JPEG-encoded content; confirm
img = plt.imread(fn, format='jpg')
info(img)

In [None]:
# Sample one tile per style for Paris and print its xyz, location, country and extent.
# The same random index is applied to each style's *sorted* tile list, so styles with
# identical tile sets yield the same xyz tile.
paris_dir = Path("/data/hayley-old/maptiles_v2/paris")
zoom='14'
paris_samples = {}  # style name -> image array of the sampled tile

idx = np.random.randint(100)
for style_dir in sorted(paris_dir.iterdir()):
    style = style_dir.stem
    zoom_dir = style_dir/zoom
    if not zoom_dir.is_dir():
        # Not a style folder, or this style has no tiles at this zoom level
        continue
    tile_paths = sorted(p for p in zoom_dir.iterdir() if p.is_file())
    if not tile_paths:
        continue
    # Wrap the index so every style yields a sample, even with fewer than 100 tiles
    p = tile_paths[idx % len(tile_paths)]
    try:
        img = plt.imread(p)
    except SyntaxError:
        # PIL raises SyntaxError for a bad PNG signature; presumably some .png tiles are JPEG-encoded
        img = plt.imread(p, format='jpg')
    paris_samples[style] = img
    print(style)
    info(img)

    # xyz: get tile extent and country (file stem is "<x>_<y>_<z>")
    x,y,z = map(int, p.stem.split('_'))
    country = getCountryFromTile(x,y,z)
    size_y, size_x = getTileExtent(x,y,z)

    print('x, y, z: ', x,y,z)
    print('lat_deg, lng_deg: ', getGeoFromTile(x,y,z))
    print(country)
    print('tile size in meters (y,x dir): ',size_y, size_x)
    print()
    

In [None]:
# Show random maptiles
# Mar 12, 2021
# Collect `n_samples` distinct random Paris tiles for each style in `styles2show`,
# with their (lat, lng) and physical extent, for plotting below.
paris_dir = Path("/data/hayley-old/maptiles_v2/paris")
zoom='14'

styles2show = ['StamenTonerBackground', 'StamenWatercolor', 'OSMDefault']
n_samples = 10
locations = [] # list of (lat_deg, lng_deg)
imgs = []
extents = [] # list of the physical area extent covered by the maptile; in meters (size_y, size_x)
for style_dir in paris_dir.iterdir():
    style = style_dir.stem
    zoom_dir = style_dir/zoom

    print(f'style: {style}')
    if style not in styles2show:
        print(f'skipping {style}...')
        continue

    img_fns = sorted(p for p in zoom_dir.iterdir() if p.is_file())
    n_imgs = len(img_fns)
    if n_imgs == 0:
        print(f'no tiles found in {zoom_dir}; skipping...')
        continue
    # Sample without replacement so the same tile isn't shown twice
    inds = np.random.choice(n_imgs, size=min(n_samples, n_imgs), replace=False)
    for ind in inds:
        p = img_fns[ind]
        try:
            img = plt.imread(p)
        except SyntaxError:
            # PIL raises SyntaxError for a bad PNG signature; presumably some .png tiles are JPEG-encoded
            img = plt.imread(p, format='jpg')
        imgs.append(img)

        # xyz: get tile extent and country (file stem is "<x>_<y>_<z>")
        x,y,z = map(int, p.stem.split('_'))
        country = getCountryFromTile(x,y,z)
        size_y, size_x = getTileExtent(x,y,z)
        lat_deg, lng_deg =  getGeoFromTile(x,y,z)
        locations.append((lat_deg, lng_deg))
        extents.append((size_y, size_x))

        # Printout
        info(img)
        print('x, y, z: ', x,y,z)
        print('lat_deg, lng_deg: ', lat_deg, lng_deg)
        print(country)
        print('tile size in meters (y,x dir): ',size_y, size_x)
        print()



In [None]:
def round_lat_lng(loc, decimals=2) -> Tuple[float, float]:
    """Round a (lat_deg, lng_deg) pair, e.g. for use as a compact plot title.

    Args:
        loc (Tuple[float, float]): (lat_deg, lng_deg); only the first two items are used
        decimals (int): number of decimal places to keep

    Returns:
        (rounded_lat, rounded_lng) as numpy floats (from `np.round`)
    """
    lat_deg, lng_deg = loc[0], loc[1]
    return np.round(lat_deg, decimals), np.round(lng_deg, decimals)


# Title each sampled tile with its rounded (lat, lng)
titles = [round_lat_lng(loc) for loc in locations]
show_npimgs(imgs, titles=titles);

## Road network queries from OSM using OSMnx
- Verify the road networks look similar to the ones in the maptiles

In [None]:
np_locs = np.asarray(locations)  # one row per sampled tile: (lat_deg, lng_deg)

In [None]:
import pandas as pd

In [None]:
# Tabulate the sampled tile locations; column order follows the (lat_deg, lng_deg) rows of np_locs
df_locs = pd.DataFrame(np_locs, columns=['lat', 'lng'])
df_locs

In [None]:
# Cache the sampled locations for the OSMnx road-network comparison
city = 'paris'
cache_dir = ROOT/'cache'
cache_dir.mkdir(parents=True, exist_ok=True)  # to_pickle/to_csv fail if the directory is missing
df_locs.to_pickle(cache_dir/f'sample_locations_{city}.pkl')
df_locs.to_csv(cache_dir/f'sample_locations_{city}.csv')



In [None]:
# Check that joblib (used below to dump/load the cached df_fns) is installed in this conda env
!conda list | grep joblib


In [None]:
# Instantiate Multisource Maptiles DataModule (three styles) and time its setup
all_cities = ['la', 'charlotte', 'vegas', 'boston', 'paris',
              'amsterdam', 'shanghai', 'seoul', 'chicago', 'manhattan',
              'berlin', 'montreal', 'rome']

data_root = DATA_ROOT  # same maptiles_v2 folder configured in the "Set Path" section
cities = all_cities # ['berlin', 'rome', 'la', 'amsterdam', 'seoul'] #['paris']
# Other styles used in this notebook: 'StamenWatercolor', 'StamenTonerLines'
styles = ['StamenTonerBackground', 'OSMDefault', 'CartoVoyagerNoLabels']
zooms = ['14']
in_shape = (3, 64, 64)  # (n_channels, height, width); in_shape[0] == 1 is treated as grayscale below
batch_size = 32
print('cities: ', cities)
print('styles: ', styles)

start = time.time()
dm = MultiMaptilesDataModule(
    data_root=data_root,
    cities=cities,
    styles=styles,
    zooms=zooms,
    in_shape=in_shape,
    batch_size=batch_size,
)
dm.setup('fit')

print('Took: ', time.time() - start)

In [None]:
# Save the processed cities-styles's filenames for another init of the same dataset in later experiments
from collections import defaultdict

# Key names: the city list is sorted (order-independent); the style list keeps the given order
city_str = '-'.join(sorted(cities))
style_str = '-'.join(styles)

cache = defaultdict(dict)  # {city_str: {style_str: df_fns}}
cache[city_str][style_str] = dm.df_fns

In [None]:
# Inspect the in-memory cache: {city_str: {style_str: df_fns}}
cache

## Save the dataframe of filenames for this DataModule (the entire DM, i.e., including train_ds, val_ds, and test_ds)

In [None]:
# Persist dm.df_fns under cache/<cities>/df_fns_<styles>.pkl so later runs can skip the filename scan
import joblib  # not imported at the top of this notebook

out_dir = ROOT/f'cache/{city_str}'
if not out_dir.exists():  # `exists` is a method: without the call it is always truthy
    out_dir.mkdir(parents=True)
    print('Created: ', out_dir)
fn = f'df_fns_{style_str}.pkl'

joblib.dump(dm.df_fns, out_dir/fn)

In [None]:
# Test if we can load the cached df_fns and create the same train_ds and val_ds
import joblib  # not imported at the top of this notebook

# joblib.load unpickles the file: only load caches this project wrote itself
df_fns = joblib.load(out_dir/fn)
print(len(df_fns))
assert df_fns.equals(dm.df_fns), "Cached df_fns differs from dm.df_fns"


In [None]:
# Re-create the same datamodule from the cached df_fns and time its init + setup
start = time.time()
dm2_kwargs = dict(
    df_fns=df_fns,
    data_root=data_root,
    cities=cities,
    styles=styles,
    zooms=zooms,
    in_shape=in_shape,
    batch_size=batch_size,
)
dm2 = MultiMaptilesDataModule(**dm2_kwargs)
dm2.setup('fit')
print('=== Using cached df_fns ===')
print('DM init took: ', time.time() - start)

In [None]:
# # Pickle this datamodule
# import joblib
# nb_name = '16-a'
# joblib.dump(dm, ROOT/'cache'/f'dm_{nb_name}.pkl')

In [None]:
print('train size: ', len(dm.train_ds))

# Show the first training batch, titled by its style labels
dl = dm.train_dataloader()
batch = next(iter(dl))
x, label_c, label_s = dm.unpack(batch)
info(x)
batch_cmap = 'gray' if in_shape[0] == 1 else None  # single-channel tiles are shown in grayscale
show_timgs(x, titles=label_s.tolist(), cmap=batch_cmap)
print(label_c)
print(label_s)

In [None]:
# Instantiate the pl Module
from src.models.plmodules.bilatent_vae import BiVAE

# betas = [0.1 * 3**i for i in range(10)]
# for kld_weight in [1.0]
n_styles = len(styles)  # one style class per style loaded into the datamodule
latent_dim = 10
hidden_dims = [32, 64, 128, 256] #,512]
adversary_dims = [32,32,32]
act_fn = nn.LeakyReLU()
learning_rate = 3e-4

is_contrasive = True  # (sic) spelling matches the BiVAE keyword argument below
kld_weight = 1.0 # vae_loss = recon_loss + kld_weight * kld_loss; betas[0];
adv_loss_weight = 15. # loss = vae_loss + adv_loss_weight * adv_loss

# enc_type = 'resnet'
enc_type = 'conv'

# dec_type = 'conv'
dec_type = 'resnet'

# The resnet encoder is configured with its own list of hidden dims
if enc_type == 'resnet':
    hidden_dims = [32, 32, 64, 128, 256]

model = BiVAE(
    in_shape=in_shape, 
    n_styles=n_styles,
    latent_dim=latent_dim,
    hidden_dims=hidden_dims,
    adversary_dims=adversary_dims,
    learning_rate=learning_rate,
    act_fn=act_fn,
    is_contrasive=is_contrasive,
    kld_weight=kld_weight,
    adv_loss_weight=adv_loss_weight,
    enc_type=enc_type,
    dec_type=dec_type,
)


In [None]:
# Model name; it is combined with dm.name to form the experiment/log-dir name
model.name

In [None]:
# Instantiate a PL `Trainer` object
# Start the experiment
max_epochs = 200
# Experiment name = "<model.name>_<dm.name>"; logs go to <ROOT>/temp-logs/<exp_name>/version_<n>
exp_name = f'{model.name}_{dm.name}'
tb_logger = pl_loggers.TensorBoardLogger(save_dir=f'{ROOT}/temp-logs', 
                                         name=exp_name,
                                         log_graph=False,
                                        default_hp_metric=False)
print("Log dir: ", tb_logger.log_dir)

# Create the version directory up front so it exists before training starts
log_dir = Path(tb_logger.log_dir)
if not log_dir.exists():
    log_dir.mkdir(parents=True)
    print("\nCreated: ", log_dir)


# Log computational graph
# model_wrapper = ModelWrapper(model)
# tb_logger.experiment.add_graph(model_wrapper, model.example_input_array.to(model.device))
# tb_logger.log_graph(model)

# Keyword arguments for pl.Trainer. gpus=1 uses one of the GPUs left visible by
# CUDA_VISIBLE_DEVICES; validation runs every 10 epochs; training stops on NaN loss.
trainer_config = {
    'gpus':1,
    'max_epochs': max_epochs,
    'progress_bar_refresh_rate':0,
    'terminate_on_nan':True,
    'check_val_every_n_epoch':10,
    'logger':tb_logger,
#     'callbacks':callbacks,
}


## Ray Tune: efficient hyperparameter tuning of DL experiments with a distributed execution engine (in Python)

In [None]:
# Start a local Ray runtime. NOTE(review): nothing else in this notebook uses ray yet.
ray.init()

### Training is logged to (on March 1st, 2021)

Log dir:  /data/hayley-old/Tenanbaum2000/temp-logs/BiVAE-C-conv-resnet-1.0-15.0_Maptiles_la-charlotte-vegas-boston-paris-amsterdam-shanghai-seoul-chicago-manhattan-berlin-montreal-rome_CartoVoyagerNoLabels-OSMDefault-StamenTonerBackground_14/version_0
Created:  /data/hayley-old/Tenanbaum2000/temp-logs/BiVAE-C-conv-resnet-1.0-15.0_Maptiles_la-charlotte-vegas-boston-paris-amsterdam-shanghai-seoul-chicago-manhattan-berlin-montreal-rome_CartoVoyagerNoLabels-OSMDefault-StamenTonerBackground_14/version_0

In [None]:
# trainer = pl.Trainer(fast_dev_run=3)
trainer = pl.Trainer(**trainer_config)
# trainer.tune(model=model, datamodule=dm)
# Queried before fit(), so this presumably shows no metrics yet
print("\nMetrics: ", trainer.callback_metrics.keys())# todo: delete

# Fit model
trainer.fit(model, dm)
# The f-string interpolates a tuple, so this prints "(current_epoch, batch_idx)"
print(f"Finished at ep {trainer.current_epoch, trainer.batch_idx}")

In [None]:
# Sanity check: last epoch reached and the TensorBoard log directory used
model.current_epoch, model.logger.log_dir

## Log hyperparameters and `best_score` to TensorBoard

In [None]:
# Collect model + datamodule hparams and the best checkpoint score for hparam logging
hparams = model.hparams.copy()
hparams.update(dm.hparams)
# best_model_score may be None (e.g. if no checkpoint was scored); guard before calling .item()
best_model_score = trainer.checkpoint_callback.best_model_score
best_score = best_model_score.item() if best_model_score is not None else None
metrics = {'hparam/best_score': best_score} if best_score is not None else {} #todo: define a metric and use it here
pprint(hparams)
pprint(metrics)

In [None]:
# Use pl.Logger's method "log_hyperparameters" which handles the 
# hparams' element's formats to be suitable for Tensorboard logging
# See: 
# https://sourcegraph.com/github.com/PyTorchLightning/pytorch-lightning@be3e8701cebfc59bec97d0c7717bb5e52afc665e/-/blob/pytorch_lightning/loggers/tensorboard.py#explorer:~:text=def%20log_hyperparams
best_model_score = trainer.checkpoint_callback.best_model_score
if best_model_score is None:
    # Nothing to attach as a metric; still log the hparams themselves
    print('No best_model_score recorded; logging hparams without a metric')
    best_score = None
    metrics = {}
else:
    best_score = best_model_score.item()
    metrics = {'hparam/best_score': best_score} #todo: define a metric and use it here
trainer.logger.log_hyperparams(hparams, metrics)

# Evaluations

In [None]:
from src.models.plmodules.utils import get_best_ckpt, load_model, load_best_model
from pytorch_lightning.utilities.cloud_io import load as pl_load


Load best model recorded during the training


In [None]:
# Locate the best checkpoint for `model` via the project helper and load it
ckpt_path = get_best_ckpt(model, verbose=True)
# map_location returns each storage unchanged, so tensors are loaded onto the CPU.
# NOTE: this unpickles the checkpoint file — only load trusted checkpoints.
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)  # dict object
print(ckpt['epoch'])  # epoch recorded in the checkpoint

In [None]:
# Load the best checkpoint's weights into the in-memory model
model.load_state_dict(ckpt['state_dict'])


## Reconstruction
    
    

In [None]:
from torch.utils.tensorboard import SummaryWriter

def show_recon(model: BiVAE, 
               tb_writer: SummaryWriter=None,
               global_step:int=0,
               unnorm:bool=True, 
               to_show:bool=True, 
               verbose:bool=False):
    """Reconstruct the first train batch and the first val batch, then show and/or log them.

    For each of the 'train' and 'val' dataloaders, runs `model.generate` on the
    batch's images. If `unnorm` is True, it undoes normalization with the
    datamodule's train mean/std. It displays inputs and linearly rescaled
    reconstructions, and logs an "inputs | rescaled recons" grid to `tb_writer`.

    Args:
        model: BiVAE attached to a trainer that has a datamodule
        tb_writer: TensorBoard writer for the image grids; nothing is logged when None
        global_step: step recorded with the logged images
        unnorm: if True, unnormalize inputs/reconstructions before showing/logging
        to_show: if True, display the images inline
        verbose: if True, print tensor summaries via `info`

    The model's train/eval mode is restored before returning.
    """
    # `torchvision.utils.make_grid` needs the torchvision package itself;
    # the notebook only imports `datasets` and `transforms` from it.
    import torchvision

    # NOTE(review): `unnormalize` and `LinearRescaler` are not imported anywhere in this
    # notebook — TODO import them from the project before calling this function.
    was_training = model.training
    model.eval()
    dm = model.trainer.datamodule
    cmap = 'gray' if dm.size()[0] == 1 else None
    train_mean, train_std = dm.train_mean, dm.train_std
    try:
        with torch.no_grad():
            for mode in ['train', 'val']:
                dl = getattr(model, f"{mode}_dataloader")()
                batch = next(iter(dl))

                x = batch['img'].to(model.device)
                x_recon = model.generate(x)

                # Move to cpu for visualization
                x = x.cpu()
                x_recon = x_recon.cpu()

                if verbose:
                    info(x, f"{mode}_x")
                    info(x_recon, f"{mode}_x_recon")

                if unnorm:
                    x_unnormed = unnormalize(x, train_mean, train_std)
                    x_recon_unnormed = unnormalize(x_recon, train_mean, train_std)
                    if verbose:
                        print("===After unnormalize===")
                        info(x_unnormed, f"{mode}_x_unnormed")
                        info(x_recon_unnormed, f"{mode}_x_recon_unnormed")

                # Tensors used for display and logging (raw tensors when unnorm=False)
                _x = x_unnormed if unnorm else x
                _x_recon = x_recon_unnormed if unnorm else x_recon

                if to_show:
                    show_timgs(_x, title=f"Input: {mode}", cmap=cmap)
                    show_timgs(LinearRescaler()(_x_recon), title=f"Recon(linearized): {mode}", cmap=cmap)

                # Log input-recon grid to TB
                if tb_writer is not None:
                    input_grid = torchvision.utils.make_grid(_x)  # (C, gridh, gridw)
                    rescaled_recon_grid = torchvision.utils.make_grid(LinearRescaler()(_x_recon))
                    grid = torch.cat([input_grid, rescaled_recon_grid], dim=-1)  # inputs | recons
                    tb_writer.add_image(f"{mode}/recons", grid, global_step=global_step)
    finally:
        model.train(was_training)


In [None]:
# Show reconstructions for one train and one val batch and log them to TensorBoard at step 1
show_recon(model, tb_logger.experiment, global_step=1, verbose=True)