# Visualize samples of data used for experiments
- Dec 31, 2020


## Load libraries

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os,sys
import re
import math
from datetime import datetime
import time
sys.dont_write_bytecode = True

In [None]:
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable, TypeVar

from pprint import pprint
from ipdb import set_trace as brpt

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.tuner.tuning import Tuner


# Select the visible GPU through CUDA environment variables:
# enumerate devices in PCI bus order and expose only device "1" to this process.
# NOTE(review): these are set after `import torch`; presumably that is fine as long
# as CUDA has not been initialized yet -- confirm.
for env_key, env_value in (
    ("CUDA_DEVICE_ORDER", "PCI_BUS_ID"),
    ("CUDA_VISIBLE_DEVICES", "1"),
):
    os.environ[env_key] = env_value

## Set Path 
1. Add this notebook's folder and the project root to `sys.path` (so that `src` can be imported as a package)
2. Set DATA_ROOT to the `maptiles_v2` folder

In [None]:
# Resolve the project layout relative to the current working directory
# (the notebook's folder); the project root is its parent.
this_nb_path = Path(os.getcwd())
ROOT = this_nb_path.parent
SRC = ROOT / 'src'
DATA_ROOT = Path("/data/hayley-old/maptiles_v2/")
paths2add = [this_nb_path, ROOT]

print("Project root: ", str(ROOT))
print('Src folder: ', str(SRC))
print("This nb path: ", str(this_nb_path))


# Put each folder at the front of sys.path unless it is already listed there.
for candidate in paths2add:
    entry = str(candidate)
    if entry in sys.path:
        continue
    sys.path.insert(0, entry)
    print(f"\n{entry} added to the path.")

In [None]:
from src.data.datasets.maptiles import MaptilesDataset, MapStyles
from src.data.datamodules.maptiles_datamodule import MaptilesDataModule

from src.data.transforms.transforms import Identity, Unnormalizer, LinearRescaler
from src.data.transforms.functional import unnormalize

from src.visualize.utils import show_timgs, show_batch
from src.utils.misc import info
from collections import OrderedDict


---
## 1. Digit datasets
MNIST, MNISTM and USPS

- MNISTM
    - original size of an image: (3, 28, 28)
    - labels: {0, ..., 9}
- USPS
    - original size of an image: (1, 16, 16)
    - labels: {0, ..., 9}
    

In [None]:
from src.data.datasets.mnistm import MNISTM
from torchvision.datasets import MNIST, USPS


In [None]:
# MNISTM Dataset: build the dataset and a shuffled DataLoader, then inspect one batch.
bs = 16
num_workers = 16
pin_memory = True
xforms = transforms.Compose([
    transforms.ToTensor(),
    ])
mnistm_ds = MNISTM(ROOT/'data', 
          transform=xforms,
          download=True)

# Wrap the dataset created in this cell (not a stale `ds` from another cell).
mnistm_dl = DataLoader(mnistm_ds, batch_size=bs, shuffle=True, 
               num_workers=num_workers, pin_memory=pin_memory)


# Pull a single batch from the MNISTM loader and summarize it with the project's
# `info` helper (presumably prints shape/dtype/range -- confirm in src.utils.misc).
x,y = next(iter(mnistm_dl))
info(x)
info(y)

In [None]:
# Visualize the image batch `x` fetched in the previous cell.
# NOTE(review): `show_timgs` is a project helper from src.visualize.utils;
# presumably it plots a batch of tensor images -- confirm.
show_timgs(x)

In [None]:
# USPS Dataset: build the dataset and a shuffled DataLoader, then inspect one batch.
bs = 16
num_workers = 16
pin_memory = True
xforms = transforms.Compose([
    transforms.ToTensor(),
    ])
usps_ds = USPS(ROOT/'data', 
          transform=xforms,
          download=True)

# Wrap the USPS dataset created in this cell (not a stale `ds` from another cell).
usps_dl = DataLoader(usps_ds, batch_size=bs, shuffle=True, 
               num_workers=num_workers, pin_memory=pin_memory)


# Pull a single batch from the USPS loader, summarize it with the project's
# `info` helper, and display it in grayscale.
x,y = next(iter(usps_dl))
info(x)
info(y)
show_timgs(x, cmap='gray')

In [None]:
# Name-keyed lookups for the digit datasets and their DataLoaders built above.
dsets = dict(mnistm=mnistm_ds, usps=usps_ds)
dls = dict(mnistm=mnistm_dl, usps=usps_dl)

---
## Compute channelwise mean and std of the images in the training/test splits
- First, for the MNISTM dataset, whose images are RGB (i.e., have 3 channels):

In [None]:
def get_channelwise_mean_std(
            dset: Dataset,
            n_channels: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the per-channel mean and standard deviation over every pixel of a dataset.

    Assumes ``dset[i]`` returns a tuple ``(x, y)`` where ``x`` is a torch.Tensor
    of shape ``(n_channels, h, w)`` (documented to be in range [0, 1]); ``y`` is ignored.

    Args:
        dset: indexable dataset supporting ``len(dset)`` and ``dset[i]``.
        n_channels: number of channels of each image.

    Returns:
        ``(channel_mean, channel_std)``: two 1-D tensors of length ``n_channels``
        in torch's default floating dtype. The std is the population standard
        deviation, ``sqrt(E[x^2] - E[x]^2)``. For an empty dataset both are NaN.
    """
    # Accumulate in float64 so that summing over many images does not lose precision.
    channel_sum = torch.zeros(n_channels, dtype=torch.float64)
    channel_squared_sum = torch.zeros(n_channels, dtype=torch.float64)
    n_pixels = 0
    for i in range(len(dset)):
        timg, _ = dset[i]
        timg = timg.to(torch.float64)
        n_pixels += timg.shape[1] * timg.shape[2]
        channel_sum += torch.sum(timg, dim=(1, 2))
        channel_squared_sum += torch.sum(timg ** 2, dim=(1, 2))
    channel_mean = channel_sum / n_pixels
    channel_var = channel_squared_sum / n_pixels - channel_mean ** 2
    # Rounding can push the variance of a (near-)constant channel slightly below
    # zero; clamp so that sqrt never yields NaN in that case.
    channel_std = torch.sqrt(torch.clamp(channel_var, min=0.0))
    out_dtype = torch.get_default_dtype()
    return channel_mean.to(out_dtype), channel_std.to(out_dtype)

In [None]:
# Channelwise statistics of MNISTM (3-channel images), one pass per split.
print("MNISTM")
for is_train in (True, False):
    print("Train: ", is_train)
    ds = MNISTM(
        ROOT/'data',
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )

    print("\tMean, std")
    mnistm_stats = get_channelwise_mean_std(ds, 3)
    print("\t", mnistm_stats)

# Recorded results from a previous run:
# Train mean, std: (tensor([0.4639, 0.4676, 0.4199]), tensor([0.2534, 0.2380, 0.2618]))
# Test mean, std: (tensor([0.4627, 0.4671, 0.4209]), tensor([0.2553, 0.2395, 0.2639]))

- Now, for the USPS dataset, whose images are grayscale (i.e., have 1 channel):

In [None]:
# Channelwise statistics of USPS (1-channel images), one pass per split.
print("USPS")
for is_train in (True, False):
    print("Train: ", is_train)
    ds = USPS(
        ROOT/'data',
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )

    print("\tMean, std")
    usps_stats = get_channelwise_mean_std(ds, 1)
    print("\t", usps_stats)

# Recorded results from a previous run:
# Train mean, std: (tensor([0.2469]), tensor([0.2989]))
# Test mean, std: (tensor([0.2599]), tensor([0.3083]))


---
## DataModule objects
Test custom datamodules on each of the datasets above.
1. MNIST-M


In [None]:
from src.data.datamodules import BaseDataModule, USPSDataModule, MNISTMDataModule, MNISTDataModule


In [None]:
# Settings shared by all three datamodules.
dm_kwargs = dict(data_root=ROOT/'data', batch_size=32)

# One datamodule per digit dataset; MNIST-M takes 3-channel input, the others 1-channel.
mnist_dm = MNISTDataModule(in_shape=(1, 32, 32), **dm_kwargs)
mnistm_dm = MNISTMDataModule(in_shape=(3, 32, 32), **dm_kwargs)
usps_dm = USPSDataModule(in_shape=(1, 32, 32), **dm_kwargs)

# Set up each datamodule for fitting and display a batch from it,
# using a gray colormap when the input has fewer than 3 channels.
for dm in (mnist_dm, mnistm_dm, usps_dm):
    print(dm.name)
    dm.setup('fit')
    is_grayscale = dm.in_shape[0] < 3
    cmap = 'gray' if is_grayscale else None
    show_batch(dm, cmap=cmap, title=dm.name)