# Create a Multi-source datamodule for Maptiles
- Jan 6, 2021

Each "style" of monochrome-mnist dataset puts a different color for the digit pixels.
This notebook shows:

- how we can create each of the datasets so that it outputs a consistent data sample
at each call for the `__getitem__` method (eg. via indexing `myDataset[item_idx]`)

- how to create a single dataset that outputs datapoints from multiple datasets
in a balanced way, i.e. so that each source dataset is (as nearly as possible) 
equally likely to be the one a datapoint is drawn from:

Let's say we have 3 datasets, ds0, ds1, ds2, which contain n0, n1, n2 datapoints/observations
respectively. Currently, drawing at random from `pytorch`'s `ConcatDataset` over d = [ds0, ds1, ds2] (e.g. with a shuffling `DataLoader`) samples a datapoint x 
under a uniform distribution over all datapoints: p(x) = 1/(n0+n1+n2). Consequently, 
this "uniform" distribution puts equal probability mass on each datapoint in the concatenated dataset, 
but the probability distribution over which dataset a sample comes from, say $\pi = [\pi_0, \pi_1, \pi_2]$, is not uniform; it is proportional to the dataset sizes, i.e. $[n_0/n, n_1/n, n_2/n]$ where $n = n_0+n_1+n_2$.  
If we want $\pi$ to be a uniform distribution over the source datasets, we 
can first compute the ratio of the dataset sizes, and then pass each smaller dataset that many times (i.e. a size-weighted number of copies) when creating 
the final, single dataset (of multiple sources).

We will demonstrate how to use the ratio of dataset sizes to create a single, multi-source dataset from multiple data sources, so that the final, multi-source dataset outputs a datapoint drawn (approximately) uniformly from any constituent data source.





## Load libraries

In [None]:
# Load (or reload) IPython's autoreload extension.
%reload_ext autoreload
# Mode 2: reload all imported modules before each cell is executed, so edits
# to the project's source files are picked up without restarting the kernel.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import os,sys
import re
import math
from datetime import datetime
import time
# Ask Python not to write .pyc files for modules imported from this point on.
sys.dont_write_bytecode = True

In [None]:
import pandas as pd

import numpy as np
import joblib
import matplotlib.pyplot as plt

from pathlib import Path
from typing import Any, List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable, TypeVar

from pprint import pprint
from ipdb import set_trace as brpt

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.tuner.tuning import Tuner


# Select Visible GPU
# Order devices by PCI bus id and expose only device "1" to CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

## Set Path 
1. Add this notebook's folder and the project root to `sys.path` (so that the `src` package is importable)
2. Set DATA_ROOT to the `maptiles_v2` folder

In [None]:
# Resolve the folders this notebook works with, relative to the current
# working directory (assumed to be the folder containing this notebook).
this_nb_path = Path.cwd()
ROOT = this_nb_path.parent
SRC = ROOT / 'src'
DATA_ROOT = Path("/data/hayley-old/maptiles_v2/")
paths2add = [this_nb_path, ROOT]

print("Project root: ", str(ROOT))
print('Src folder: ', str(SRC))
print("This nb path: ", str(this_nb_path))

# Prepend each folder to sys.path unless it is already there, so that
# `src.*` modules can be imported from the project root.
for p in paths2add:
    if str(p) in sys.path:
        continue
    sys.path.insert(0, str(p))
    print(f"\n{str(p)} added to the path.")

In [None]:
from src.data.datamodules.maptiles_datamodule import MaptilesDataModule

from src.data.transforms.transforms import Identity, Unnormalizer, LinearRescaler, Monochromizer
from src.data.transforms.functional import unnormalize, to_monochrome

from src.visualize.utils import show_timg, show_timgs, show_batch, make_grid_from_tensors
from src.utils.misc import info
from collections import OrderedDict, defaultdict
from PIL import Image

### Concatenate MNIST-M and USPS datasets

- MNISTM
    - original size of an image: (3, 28, 28)
    - labels: {0, ..., 9}
- USPS
    - original size of an image: (1, 16, 16)
    - labels: {0, ..., 9}
    

In [None]:
from src.data.datasets.mnistm import MNISTM
from torchvision.datasets import USPS


In [None]:
# MNISTM Dataset
# Loader settings and the common target image shape (channels, height, width).
bs, num_workers, pin_memory = 16, 16, True
in_shape = (3, 32, 32)

# Convert each image to a tensor, then resize it to the target height/width.
xforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(in_shape[-2:]),
])
# No target transform is applied here (the original note says the targets
# are "already-so", i.e. presumably already tensors -- verify against MNISTM).
mnistm_ds = MNISTM(ROOT/'data', transform=xforms, download=True)

mnistm_dl = DataLoader(
    mnistm_ds,
    batch_size=bs,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Pull one mini-batch and inspect it (use show_timgs(x) to visualize it).
x, y = next(iter(mnistm_dl))
info(x)
info(y)

In [None]:
# USPS Dataset
# Loader settings and the common target image shape (channels, height, width).
bs, num_workers, pin_memory = 16, 16, True
n_channels = 3
in_shape = (n_channels, 32, 32)

# Convert to a tensor, resize to the target height/width, then repeat the
# tensor n_channels times along dim 0 (presumably to turn single-channel
# USPS digits into n_channels-channel images -- confirm against the data).
xforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(in_shape[-2:]),
    transforms.Lambda(lambda img: img.repeat((n_channels, 1, 1))),
])
# Wrap each label in a tensor.
target_xforms = transforms.Lambda(lambda label: torch.tensor(label))

usps_ds = USPS(
    ROOT/'data',
    transform=xforms,
    target_transform=target_xforms,
    download=True,
)

usps_dl = DataLoader(
    usps_ds,
    batch_size=bs,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Pull one mini-batch and inspect it
# (use show_timgs(x, cmap='gray') to visualize it).
x, y = next(iter(usps_dl))
info(x)
info(y)

In [None]:
# Concatenated dataset
# ConcatDataset is not among the names imported from torch.utils.data in the
# import cell, so bring it into scope here.
from torch.utils.data import ConcatDataset

ds = ConcatDataset([mnistm_ds, usps_ds])

# DataLoader w/o shuffling will iterate over the datasets in order
# -- So, iterate over mnistm_ds and then iterate over usps_ds
ordered_dl = DataLoader(ds, batch_size=16, shuffle=False)

# Pull the first mini-batch, visualize it and inspect the tensors.
x, y = next(iter(ordered_dl))
show_timgs(x)
info(x)
info(y)

In [None]:
# With shuffling enabled, the DataLoader visits the concatenated dataset in
# random order, so a mini-batch mixes images from mnistm_ds and usps_ds.
shuffled_dl = DataLoader(dataset=ds, shuffle=True, batch_size=32)

# Pull one mini-batch, visualize it and inspect the tensors.
x, y = next(iter(shuffled_dl))
show_timgs(x)
info(x)
info(y)

In [None]:
# Compare the sizes of the two source datasets (shown as the cell output).
len(mnistm_ds), len(usps_ds)

Notice, however, that the MNIST-M dataset has a lot more samples (60,000 vs. 7291). 
A quick fix for creating a single dataloader (from multiple dataset sources) in which each mini-batch has a roughly equal/balanced number of samples from each dataset is to... pass in multiple copies of the smaller dataset, as many as the (floored) ratio of the dataset sizes:

- see: https://discuss.pytorch.org/t/train-simultaneously-on-two-datasets/649/36

In [None]:
# Balanced Concatenated dataset
# ConcatDataset is not among the names imported from torch.utils.data in the
# import cell, so bring it into scope here.
from torch.utils.data import ConcatDataset

# Number of times the smaller dataset (usps_ds) is repeated so that its total
# size approaches that of mnistm_ds. Floor division keeps the repeated total
# at or below len(mnistm_ds); the max(1, ...) guard keeps usps_ds in the mix
# even if it is not actually the smaller dataset (the ratio would floor to 0).
n_copies = max(1, len(mnistm_ds) // len(usps_ds))
dsets = [mnistm_ds] + [usps_ds] * n_copies

balanced_ds = ConcatDataset(dsets)

# DataLoader w/o shuffling will iterate over the datasets in order
# -- So, iterate over mnistm_ds and then over each copy of usps_ds
ordered_dl = DataLoader(balanced_ds, batch_size=16, shuffle=False)

# Pull the first mini-batch, visualize it and inspect the tensors.
x, y = next(iter(ordered_dl))
show_timgs(x)
info(x)
info(y)

In [None]:
# With shuffling enabled, the DataLoader visits the balanced concatenated
# dataset in random order, so a mini-batch mixes images from mnistm_ds and
# the repeated usps_ds.
shuffled_dl = DataLoader(dataset=balanced_ds, shuffle=True, batch_size=32)

# Pull one mini-batch, visualize it and inspect the tensors.
x, y = next(iter(shuffled_dl))
show_timgs(x)
info(x)
info(y)