# Experimental ML with Holoviews/Geoviews + PyTorch

- type: PyData LA 2019 Proposal
- date: 2019-09-21
- author: Hayley Song (haejinso@usc.edu)
- Prereq:
    - Basic understanding of visualization in Python (e.g., prior use of the `matplotlib.pyplot` library)
    - Basic understanding of the neural network training process
    I'll give a brief overview of the workflow, assuming the audience has previous experience with the following concepts:
        - mini-batch training
        - forward pass, backward pass
        - gradient, gradient descent algorithm
        - classification, semantic segmentation
        - image data stored as a numpy ndarray
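The training concepts listed above fit in a few lines of code. The following is a minimal, framework-free sketch (plain Python fitting a straight line rather than a neural network; the function name `minibatch_gd` and the hyperparameters are illustrative choices, not part of this talk's code): each epoch shuffles the data and splits it into mini-batches, and each mini-batch gets a forward pass, a backward pass that computes the gradient of the mean squared error, and one gradient-descent step.

```python
import random
from typing import List, Tuple


def minibatch_gd(xs: List[float], ys: List[float], lr: float = 0.1,
                 batch_size: int = 4, epochs: int = 200, seed: int = 0) -> Tuple[float, float]:
    """Fit y ~ w * x + b with mini-batch gradient descent on the mean squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)  # a new mini-batch split every epoch
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            # forward pass: residuals of the current predictions on this mini-batch
            residuals = [(w * xs[i] + b) - ys[i] for i in batch]
            # backward pass: gradients of the mini-batch MSE w.r.t. w and b
            n = len(batch)
            grad_w = 2.0 / n * sum(r * xs[i] for r, i in zip(residuals, batch))
            grad_b = 2.0 / n * sum(residuals)
            # gradient-descent step
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


x_train = [i / 10 for i in range(20)]
y_train = [2.0 * x + 1.0 for x in x_train]
w_fit, b_fit = minibatch_gd(x_train, y_train)
print(round(w_fit, 3), round(b_fit, 3))
```

With noise-free data from `y = 2x + 1`, the learned `(w, b)` should end up very close to `(2, 1)`.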
        

## Load Libraries


In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, time
import numpy as np
import pandas as pd
    
from pathlib import Path
from pprint import pprint as pp

import joblib
import pdb

import matplotlib.pyplot as plt
%matplotlib inline

# ignore warnings
import warnings
if not sys.warnoptions:
    warnings.simplefilter('ignore')
    
# Don't generate bytecode
sys.dont_write_bytecode = True

In [None]:
import holoviews as hv
import xarray as xr

from holoviews import opts
from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
from holoviews.streams import Stream, param
from holoviews import streams
import geoviews as gv
import geoviews.feature as gf
from geoviews import tile_sources as gvts


import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cf

hv.notebook_extension('bokeh')
hv.Dimension.type_formatters[np.datetime64] = '%Y-%m-%d'

# Dashboards
import param as pm, panel as pn

In [None]:
# Geoviews visualization default options
H, W = 500, 500
opts.defaults(
    opts.RGB(height=H, width=W, tools=['hover'], active_tools=['wheel_zoom']),
    opts.Image(height=H, width=W, tools=['hover'], active_tools=['wheel_zoom'], framewise=True),  # alternative: axiswise=True
    opts.Image('mask', alpha=0.3),
    opts.Points(tools=['hover'], active_tools=['wheel_zoom']),
    opts.Path(height=H, width=W, tools=['hover'], active_tools=['wheel_zoom']),
    opts.Tiles(height=H, width=W, tools=['hover'], active_tools=['wheel_zoom']),
)

## Set up additional library path

In [None]:
# Add the utils directory to the search path
SP_ROOT = Path.home()/'Playground/ContextNet'
SP_LIBS = SP_ROOT/'scripts' # to be changed to 'src'
LIBS_DIR = Path('../src').absolute()
DIRS_TO_ADD = [SP_LIBS, LIBS_DIR]
for p in DIRS_TO_ADD:
    assert p.exists()
    
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))
        print(f"Added to sys.path: {p}")

# pp(sys.path)
    

In [None]:
from output_helpers import print_mro as mro, nprint
from naming_helpers import get_sp_mask1300_fn
import SpacenetPath as spp
import spacenet_globals as spg


## Step 1: Explore your dataset
    

In [None]:
city = 'vegas'
rgb8_dir = spp.sample_rgb8_dirs[city]
mask_dir = spp.sample_mask_dirs[city]
sp_vec_dir = spp.sample_road_vec_dirs[city]
osm_mask_dir = spp.sample_mask_dirs[city]

In [None]:
# Collect the GeoTIFF files with standard pathlib (iterdir yields paths already joined to the directory)
rgb_fns = sorted(p for p in rgb8_dir.iterdir() if p.suffix in ['.tif', '.tiff'])
mask_fns = sorted(p for p in mask_dir.iterdir() if p.suffix in ['.tif', '.tiff'])

In [None]:
for rgb_fn, mask_fn in zip(rgb_fns, mask_fns):
    assert rgb_fn.exists() and mask_fn.exists()
    

In [None]:
# def read_img_and_mask(idx):
#     for img_type, fns in dict(rgb=rgb_fns, mask=mask_fns):
#         da = xr.open_rasterio(fns[idx])
#         nc = len(da.band) #number of bands (ie. channels)


In [None]:
idx = 1
rgb_da = xr.open_rasterio(rgb_fns[idx])/255.
mask_da = xr.open_rasterio(mask_fns[idx])/255.
r,g,b = map(np.asarray,[rgb_da.sel(band=1), rgb_da.sel(band=2), rgb_da.sel(band=3)])
xs, ys = np.array(rgb_da.coords['x']), np.array(rgb_da.coords['y'])

## Library Introduction: `xarray`

- [xarray fundamentals by Ryan Abernathey](https://tinyurl.com/y3d2l86g)

In [None]:
rgb_data = dict(lon=xs,
                lat=ys,
                R=r,
                G=g,
                B=b)
# rgb_data

In [None]:
my_rgb_da = xr.DataArray(data=np.dstack([r,g,b]),
                         dims=['y','x','band'],
                         coords={'y': ys,
                                 'x': xs,
                                'band': 'R G B'.split()}
                        )
                                        
my_rgb_da

In [None]:
# Store R, G, B as separate data variables instead of a single `band` dimension.
# To do so, we need an `xr.Dataset` rather than a single `xr.DataArray`.
red_da = xr.DataArray(data=r,
                      dims=['y','x'],
                      coords={'y': ys,'x': xs})
green_da = xr.DataArray(data=g,
                      dims=['y','x'],
                      coords={'y': ys,'x': xs})
                                        
blue_da = xr.DataArray(data=b,
                      dims=['y','x'],
                      coords={'y': ys,'x': xs})

In [None]:
# red_da.plot()
# green_da.plot()
# blue_da.plot()

In [None]:
rgb_ds = xr.Dataset(data_vars={'R': red_da, 
                      'G': green_da,
                      'B': blue_da},
           coords={'y': ys, 'x': xs})
rgb_ds

In [None]:
# gv.RGB(rgb_ds, kdims=['x','y'], vdims=['R','G','B'])

End of `xarray` introduction

In [None]:
# # Read as gv elements
# # gv.RGB(rgb_da, kdims=['x','y'], vdims='R G B'.split()) #fails
# gv_rgb = gv.RGB((xs,ys,r,g,b), 
#                 kdims=['Longitude', 'Latitude'], 
#                 vdims='R G B'.split(), 
#                 crs=ccrs.PlateCarree(),
#                group='rgb')

# # gv_mask = gv.Image((xs,ys,mask_da.sel(band=1))) #ok?
# gv_mask = gv.Image(mask_da, kdims=['x','y'],  crs=ccrs.PlateCarree(), group='mask')
# gv_rgb + gv_mask

## Library Introduction: `rasterio`
Alternatively, we can use `rasterio` to read the GeoTIFF files.

In [None]:
import rasterio as rio
from rasterio.plot import reshape_as_image

In [None]:
# Context managers close the file handles once the pixels and bounds are read
with rio.open(rgb_fns[idx]) as src:
    rgb_bounds = src.bounds
    rgb_img = reshape_as_image(src.read())/255.

with rio.open(mask_fns[idx]) as src:
    mask_bounds = src.bounds
    mask_img = reshape_as_image(src.read())/255.

In [None]:
hv_rgb = hv.RGB(rgb_img, bounds=rgb_bounds, group='rgb').redim(x='Longitude', y='Latitude')
hv_mask = hv.Image(mask_img, bounds=mask_bounds, group='mask').redim(x='Longitude', y='Latitude')

In [None]:
hv_rgb + hv_mask

### Summary: `xarray.open_rasterio()` vs. `rasterio.open()`

|`xarray.open_rasterio()`|`rasterio.open()`|
|---|---|
|computes the coordinate points from the `transform` matrix in the GeoTIFF file|computes the `bounds` information when reading the metadata|
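To make the difference in the table concrete: a GeoTIFF stores an affine `transform` that maps pixel (column, row) indices to map coordinates. The minimal pure-Python sketch below assumes a north-up transform given as the six coefficients `(a, b, c, d, e, f)` in the order used by rasterio's `Affine` (`x = a*col + b*row + c`, `y = d*col + e*row + f`) with no rotation. It derives both the per-pixel center coordinates (what `xarray.open_rasterio()` exposes as its `x`/`y` coordinates, which xarray places at pixel centers) and the outer edges of the raster (what `rasterio` reports as `.bounds`). The helper names and the toy transform are hypothetical.

```python
from typing import List, NamedTuple, Tuple


class Bounds(NamedTuple):
    left: float
    bottom: float
    right: float
    top: float


def pixel_center_coords(transform: Tuple[float, ...], width: int,
                        height: int) -> Tuple[List[float], List[float]]:
    """Coordinates of pixel centers along each axis (north-up transforms only)."""
    a, b, c, d, e, f = transform
    if b != 0 or d != 0:
        raise ValueError("rotated/sheared transforms are not handled in this sketch")
    xs = [c + a * (col + 0.5) for col in range(width)]
    ys = [f + e * (row + 0.5) for row in range(height)]  # e < 0 for north-up images
    return xs, ys


def raster_bounds(transform: Tuple[float, ...], width: int, height: int) -> Bounds:
    """Outer edges of the raster (pixel corners, not pixel centers)."""
    a, b, c, d, e, f = transform
    if b != 0 or d != 0:
        raise ValueError("rotated/sheared transforms are not handled in this sketch")
    x0, x1 = c, c + a * width
    y0, y1 = f, f + e * height
    return Bounds(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))


# Hypothetical 4 x 2 tile with 0.5-degree pixels and its top-left corner at (-115.0, 36.0)
toy_transform = (0.5, 0.0, -115.0, 0.0, -0.5, 36.0)
center_x, center_y = pixel_center_coords(toy_transform, width=4, height=2)
print(center_x)  # [-114.75, -114.25, -113.75, -113.25]
print(center_y)  # [35.75, 35.25]
print(raster_bounds(toy_transform, 4, 2))  # Bounds(left=-115.0, bottom=35.0, right=-113.0, top=36.0)
```

The two views differ by half a pixel at each edge: the first center coordinate sits `a / 2` inside the left bound.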

In [None]:
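# Let's make a DynamicMap: HoloViews calls the callback with the selected `fn`
def get_hv_rgb(fn):
    with rio.open(fn) as ds:
        bounds = ds.bounds
        img = reshape_as_image(ds.read())/255.
    return hv.RGB(img, bounds=bounds).redim(x='Longitude', y='Latitude')
# get_hv_rgb(rgb_fns[0])
# Declaring the allowed values of `fn` gives the DynamicMap a selection widget
dmap = hv.DynamicMap(get_hv_rgb, kdims=['fn']).redim.values(fn=[str(fn) for fn in rgb_fns])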


# Better to consistently use geoviews for geospatial datasets
def get_gv_rgb(fn):
    rgb_da = xr.open_rasterio(fn)/255.
    r,g,b = map(np.asarray,[rgb_da.sel(band=1), rgb_da.sel(band=2), rgb_da.sel(band=3)])
    xs, ys = np.array(rgb_da.coords['x']), np.array(rgb_da.coords['y'])
    # gv.RGB(rgb_da, kdims=['x','y'], vdims='R G B'.split()) #fails
    return gv.RGB((xs,ys,r,g,b), 
                    kdims=['Longitude', 'Latitude'], 
                    vdims='R G B'.split(), 
                    crs=ccrs.PlateCarree(),
                   group='rgb')

def get_gv_mask(fn):
    da = xr.open_rasterio(fn)/255.
    return gv.Image(da, kdims=['x','y'],  crs=ccrs.PlateCarree(), group='mask')


In [None]:
# Let's make it into a small GUI

class DataExplorer(pm.Parameterized):
    rgb_fn = pm.Selector(rgb_fns)
    mask_fn = pm.Selector(mask_fns)
    mask_alpha = pm.Magnitude(0.5)
    
    @pm.depends('rgb_fn', watch=True)
    def get_gv_rgb(self):
        rgb_da = xr.open_rasterio(self.rgb_fn)/255.
        r,g,b = map(np.asarray,[rgb_da.sel(band=1), rgb_da.sel(band=2), rgb_da.sel(band=3)])
        xs, ys = np.array(rgb_da.coords['x']), np.array(rgb_da.coords['y'])
        # gv.RGB(rgb_da, kdims=['x','y'], vdims='R G B'.split()) #fails
        return gv.RGB((xs,ys,r,g,b), 
                        kdims=['Longitude', 'Latitude'], 
                        vdims='R G B'.split(), 
                        crs=ccrs.PlateCarree(),
                       group='rgb')

    @pm.depends('mask_fn', watch=True)
    def get_gv_mask(self):
        da = xr.open_rasterio(self.mask_fn)/255.
        return gv.Image(da,
                        kdims=['x','y'],  
                        crs=ccrs.PlateCarree(), 
                        group='mask').redim(z='RT')
    
    
    @pm.depends('rgb_fn', watch=True)
    def get_bounds(self):
        ds = rio.open(self.rgb_fn)
        xmin, ymin, xmax, ymax = ds.bounds
        return hv.Div(
            f""" 
            <h2>Bounds</h2>
            <p>lon:{xmin, xmax},</p>
            <p>lat: {ymin, ymax}</p>""")
        
        
    def viewable(self):
        dmap_rgb = hv.DynamicMap(self.get_gv_rgb).opts(show_legend=False)
        dmap_mask = hv.DynamicMap(self.get_gv_mask)
        opted = dmap_mask.apply.opts(alpha=self.param.mask_alpha)
        
        dmap_bounds = hv.DynamicMap(self.get_bounds)
        return dmap_rgb * opted + dmap_bounds

ex = DataExplorer()
pn.Row(ex.param, ex.viewable)

In [None]:
# A variant of the GUI that looks up the matching mask for the selected RGB file

class ImagePairExplorer(pm.Parameterized):
    rgb_fn = pm.Selector(rgb_fns)
    mask_alpha = pm.Magnitude(0.5)
    show_legend = pm.Boolean(False)
    show_rgb_hover = pm.Boolean(False)
    show_mask_hover = pm.Boolean(True)
    
    @pm.depends('rgb_fn', watch=True)
    def get_gv_rgb(self):
        rgb_da = xr.open_rasterio(self.rgb_fn)/255.
        r,g,b = map(np.asarray,[rgb_da.sel(band=1), rgb_da.sel(band=2), rgb_da.sel(band=3)])
        xs, ys = np.array(rgb_da.coords['x']), np.array(rgb_da.coords['y'])
        # gv.RGB(rgb_da, kdims=['x','y'], vdims='R G B'.split()) #fails
        return gv.RGB((xs,ys,r,g,b), 
                        kdims=['Longitude', 'Latitude'], 
                        vdims='R G B'.split(), 
                        crs=ccrs.PlateCarree(),
                       group='rgb')
    
    @pm.depends('rgb_fn', watch=True)
    def get_gv_mask(self):
        fn = get_sp_mask1300_fn(self.rgb_fn)
        da = xr.open_rasterio(fn)/255.
        return gv.Image(da,
                        kdims=['x','y'],  
                        crs=ccrs.PlateCarree(), 
                        group='mask').redim(z='RT')
    
    @pm.depends('rgb_fn', watch=True)
    def get_bounds(self):
        ds = rio.open(self.rgb_fn)
        xmin, ymin, xmax, ymax = ds.bounds
        return hv.Div(
            f""" 
            <h2>Bounds</h2>
            <p>lon:{xmin, xmax},</p>
            <p>lat: {ymin, ymax}</p>""")
        
    def viewable(self):
        # Use the parameter *values* (self.show_rgb_hover), not the Parameter objects
        # (self.param.show_rgb_hover), which are always truthy. The values are read once,
        # when viewable() builds the plots.
        dmap_rgb = hv.DynamicMap(self.get_gv_rgb)#.opts(show_legend=False)
        dmap_rgb_opted = dmap_rgb.apply.opts(
            opts.RGB(
                tools=['hover'] if self.show_rgb_hover else [],
#                 show_legend=self.param.show_legend
            )
        )
        
        dmap_mask = hv.DynamicMap(self.get_gv_mask)
        dmap_mask_opted = dmap_mask.apply.opts(
            opts.Image(
#                 alpha=self.param.mask_alpha, 
                       tools=['hover'] if self.show_mask_hover else ['tap'])
        )
        
        dmap_bounds = hv.DynamicMap(self.get_bounds)
        return dmap_rgb_opted * dmap_mask_opted + dmap_bounds

In [None]:
ex = ImagePairExplorer()
pn.Row(ex.param, ex.viewable)

## Library introduction: `osmnx`

- Refer to this [overview](https://github.com/gboeing/osmnx-examples/blob/master/notebooks/01-overview-osmnx.ipynb)
- [Automating GIS](https://tinyurl.com/y6ncxg93)

#### Main functions
- Download road network data from OSM
    - OSM stores data in the EPSG:4326 CRS (i.e., lat/lon)
    - `ox.graph_from_place`, `ox.graph_from_polygon`, `ox.graph_from_bbox`, `ox.graph_from_point`, and a few more
- Project the data to the appropriate UTM zone
    - `ox.project_graph`
    - useful for spatial computations in metric units (e.g., the distance between two points, the area of a polygon)
    - supports network analysis and basic geostatistics on the network
- Plot the network graph for visualization
    - `ox.plot_graph`
- Save the figure

#### Useful functions
- Simplify the street network
    - `ox.is_endpoint(G, node)`: checks whether `node` is a real endpoint of an edge in graph `G` (e.g., an intersection or a dead end in the graph-theoretic sense) rather than an intermediate point along a road
    - `ox.simplify_graph(G)`: removes nodes that are not real network nodes
        - For example, OSM commonly stores multiple intermediate points along a curved road. These intermediate points are not real nodes in the graph-theoretic sense, and `simplify_graph` removes them. See Part 3 of this [notebook](https://tinyurl.com/y53jcpw5).

- Calculate basic network metrics
    - `ox.basic_stats(G)`
    - e.g., `circuity_avg`
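Projecting to a UTM zone is what lets distances and areas come out in meters. `ox.project_graph` chooses a suitable zone automatically; the sketch below only illustrates the arithmetic behind choosing one. `utm_epsg` is a hypothetical helper that returns the EPSG code of the WGS84 / UTM zone containing a lon/lat point, using the regular 6-degree zones (EPSG:326xx in the northern hemisphere, EPSG:327xx in the southern) and ignoring the Norway/Svalbard exceptions.

```python
import math


def utm_epsg(lon: float, lat: float) -> int:
    """EPSG code of the WGS84 / UTM zone containing (lon, lat).

    Regular 6-degree zones only; the Norway/Svalbard special cases are ignored.
    """
    if not (-180.0 <= lon <= 180.0 and -80.0 <= lat <= 84.0):
        raise ValueError("UTM covers -80 <= lat <= 84 and -180 <= lon <= 180")
    zone = int(math.floor((lon + 180.0) / 6.0)) + 1
    zone = min(zone, 60)  # lon == 180 would otherwise give zone 61
    return (32600 if lat >= 0 else 32700) + zone


print(utm_epsg(-115.14, 36.17))  # 32611 (UTM zone 11N, around Las Vegas)
```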
        
We are going to use `osmnx` to fetch road network data from OSM for the regions we are interested in. Since `osmnx` can cache downloaded responses, it is better to download the data once for the entire region of interest (ROI) rather than separately for each image tile.
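Following that advice, one way to download once for the whole ROI is to take the union of the tile bounds first. A minimal sketch, assuming every tile's bounds are in the same lon/lat CRS and the region does not cross the antimeridian; `BoundingBox` here is a stand-in NamedTuple with the same field order as the `bounds` object returned by `rasterio` (left, bottom, right, top), and the tile values are made up:

```python
from typing import Iterable, NamedTuple, Tuple


class BoundingBox(NamedTuple):
    # same field order as rasterio's BoundingBox
    left: float
    bottom: float
    right: float
    top: float


def roi_bounds(tile_bounds: Iterable[BoundingBox]) -> BoundingBox:
    """Smallest box covering every tile (all tiles assumed to share one lon/lat CRS)."""
    tiles = list(tile_bounds)
    if not tiles:
        raise ValueError("need at least one tile")
    return BoundingBox(
        left=min(t.left for t in tiles),
        bottom=min(t.bottom for t in tiles),
        right=max(t.right for t in tiles),
        top=max(t.top for t in tiles),
    )


def to_north_south_east_west(b: BoundingBox) -> Tuple[float, float, float, float]:
    """Reorder into the (north, south, east, west) order passed to ox.graph_from_bbox in this notebook."""
    return b.top, b.bottom, b.right, b.left


toy_tiles = [BoundingBox(-115.30, 36.10, -115.29, 36.11),
             BoundingBox(-115.29, 36.10, -115.28, 36.11),
             BoundingBox(-115.30, 36.11, -115.29, 36.12)]
roi = roi_bounds(toy_tiles)
print(to_north_south_east_west(roi))
```

Downloading the ROI once and then clipping roads per tile keeps the number of OSM requests independent of the number of tiles.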

In [None]:
# We'll add an action button to fetch OSM roads later; for now, import osmnx
import osmnx as ox

Since we have the `bounds` information, we can download the OSM data with the `ox.graph_from_bbox` function.

In [None]:
bounds = rgb_bounds
north, south, east, west = bounds.top, bounds.bottom, bounds.right, bounds.left
G1 = ox.graph_from_bbox(north, south, east, west, network_type='all') #'all_private'
ox.plot_graph(G1);

`G1` is a `networkx` `MultiDiGraph`. We would like to look at the attributes of this network; for instance, what is the road type of each edge? The easiest way to inspect them is to convert the graph into `geopandas.GeoDataFrame` objects. Since a graph is defined by a set of nodes and a set of edges, the conversion produces two distinct GeoDataFrames: one for the nodes and one for the edges.
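To make the node/edge split concrete, here is a minimal pure-Python sketch of what such a conversion produces. It is not how `ox.graph_to_gdfs` is implemented (that function builds real GeoDataFrames with shapely geometries); the toy graph, the `graph_to_tables` helper, and the straight-line geometry are illustrative assumptions. As in osmnx street graphs, the two-way street is stored as two directed edges:

```python
from typing import Any, Dict, List, Tuple

Row = Dict[str, Any]

# Hypothetical miniature road graph shaped like a MultiDiGraph:
# node id -> node attributes, and (u, v, key) -> edge attributes.
toy_nodes: Dict[int, Dict[str, Any]] = {
    1: {"x": -115.17, "y": 36.12},
    2: {"x": -115.16, "y": 36.12},
    3: {"x": -115.16, "y": 36.13},
}
toy_edges: Dict[Tuple[int, int, int], Dict[str, Any]] = {
    (1, 2, 0): {"highway": "residential"},
    (2, 1, 0): {"highway": "residential"},  # reverse direction of the same two-way street
    (2, 3, 0): {"highway": "service"},
}


def graph_to_tables(nodes, edges) -> Tuple[List[Row], List[Row]]:
    """Flatten a graph into one row per node and one row per edge."""
    node_rows = [{"node": n, **attrs} for n, attrs in nodes.items()]
    edge_rows = []
    for (u, v, key), attrs in edges.items():
        # a straight segment between the endpoints stands in for the road geometry
        geometry = [(nodes[u]["x"], nodes[u]["y"]), (nodes[v]["x"], nodes[v]["y"])]
        edge_rows.append({"u": u, "v": v, "key": key, "geometry": geometry, **attrs})
    return node_rows, edge_rows


node_rows, edge_rows = graph_to_tables(toy_nodes, toy_edges)
print(len(node_rows), len(edge_rows))             # 3 3
print(sorted({r["highway"] for r in edge_rows}))  # ['residential', 'service']
```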

In [None]:
gdf_nodes, gdf_edges = ox.graph_to_gdfs(G1, nodes=True, edges=True)

In [None]:
gdf_nodes.head()

In [None]:
gdf_edges.head()

We will focus on the edges (i.e., the roads) in this tutorial. Notice that the GeoDataFrame has a column called `geometry`, which stores the geometry of each road segment in `EPSG:4326`, the coordinate reference system OSM uses to store its data. Let's drop some columns holding metadata we are not going to use.

In [None]:
# errors='ignore' skips any of these columns that are absent for this area
gdf_edges.drop(columns=['ref', 'service', 'u', 'v'], inplace=True, errors='ignore')
print(np.unique(gdf_edges.geom_type))
gdf_edges.head()

Now that we have the `geopandas.GeoDataFrame`, we can easily use `Geoviews` to visualize this road network. All the geometry objects are of type `LINESTRING`, so we are going to use the `gv.Path` element.

In [None]:
gv_osm = gv.Path(gdf_edges)

In [None]:
gv_osm

Great! Let's overlay this on top of our Spacenet RGB and mask rasters.

In [None]:
# display(hv_rgb.opts(axiswise=True) + hv_mask.opts(axiswise=True) + gv_osm.opts(axiswise=True))
display(hv_rgb.opts(axiswise=True).redim(x='Longitude', y='Latitude') * hv_mask.opts(axiswise=True)) #* gv_osm.opts(axiswise=True))

In [None]:

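# Build the GeoViews elements with the helper functions defined earlier
gv_rgb = get_gv_rgb(rgb_fns[idx])
gv_mask = get_gv_mask(mask_fns[idx])
display(gv_rgb.opts(tools=[], height=500, width=500) * gv_mask * gv_osm.opts(axiswise=True))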

Now let's put all the pieces together and make an explorer for 
- Spacenet RGB raster image
- Spacenet Road vector linestrings
- Downloaded OSM Road vector Linestrings


In [None]:
basemap = hv.element.tiles.tile_sources['EsriImagery']()

class DSExplorer(pm.Parameterized):
    rgb_fn = pm.Selector(rgb_fns)
    mask_alpha = pm.Magnitude(0.3)
    show_legend = pm.Boolean(False)
#     show_rgb_hover = pm.Boolean(False)
#     show_mask_hover = pm.Boolean(True)
    
    # OSM download parameters
    action_dl = pm.Action(lambda x: x.param.trigger('action_dl'), label='click to download osm')
    osm_log = pm.String(default="", 
                        label="OSM log",
                        doc="Log for OSM download status")
    osm_dl_count = pm.Number(0)  # download counter; pass precedence=-1 to hide its widget in panel
    
    
    def __init__(self, **params):
        super().__init__(**params)
        self.osm_g = None
        self.osm_edges = gpd.GeoDataFrame()
        print('initialized')
    
    ################################################################################
    # Methods
    ################################################################################
    @pm.depends('rgb_fn', watch=True)
    def get_gv_rgb(self):
        rgb_da = xr.open_rasterio(self.rgb_fn)/255.
        r,g,b = map(np.asarray,[rgb_da.sel(band=1), rgb_da.sel(band=2), rgb_da.sel(band=3)])
        xs, ys = np.array(rgb_da.coords['x']), np.array(rgb_da.coords['y'])
        # gv.RGB(rgb_da, kdims=['x','y'], vdims='R G B'.split()) #fails
        return gv.RGB((xs,ys,r,g,b), 
                        kdims=['Longitude', 'Latitude'], 
                        vdims='R G B'.split(), 
                        crs=ccrs.PlateCarree(),
                       group='rgb')
    
    @pm.depends('rgb_fn', watch=True)
    def get_gv_mask(self):
        fn = get_sp_mask1300_fn(self.rgb_fn)
        da = xr.open_rasterio(fn)/255.
        return gv.Image(da,
                        kdims=['x','y'],  
                        crs=ccrs.PlateCarree(), 
                        group='mask').redim(z='RT')
    
    @pm.depends('rgb_fn', watch=True)
    def get_bounds(self):
        ds = rio.open(self.rgb_fn)
        xmin, ymin, xmax, ymax = ds.bounds
        return hv.Div(
            f""" 
            <h2>Bounds</h2>
            <p>lon:{xmin, xmax},</p>
            <p>lat: {ymin, ymax}</p>""")
    
    # No watch=True: the DynamicMap in viewable() already re-runs this method when
    # the button is clicked, so an extra watcher would download the same data twice
    @pm.depends('action_dl')
    def _download_osm(self):
        print('Started downloading osm data')
        with rio.open(self.rgb_fn) as ds:
            bounds = ds.bounds
        north, south, east, west = bounds.top, bounds.bottom, bounds.right, bounds.left
        self.osm_g = ox.graph_from_bbox(north, south, east, west)
        self.osm_edges = ox.graph_to_gdfs(self.osm_g, edges=True, nodes=False)

        print("OSM data downloaded")
        self.osm_dl_count += 1
        self.osm_log = f'OSM downloaded: {bounds}'
        return gv.Path(self.osm_edges, group='osm')

    ################################################################################
    # Viewable
    ################################################################################
    def viewable(self):
        dmap_rgb = hv.DynamicMap(self.get_gv_rgb).opts(show_legend=False,
                                                      tools=[])
        dmap_mask = hv.DynamicMap(self.get_gv_mask)
        dmap_mask_opted = dmap_mask.apply.opts(
            opts.Image(alpha=self.param.mask_alpha)
        )
#         dmap_bounds = hv.DynamicMap(self.get_bounds)
        dmap_osm = hv.DynamicMap(self._download_osm)
        layout = dmap_rgb * dmap_mask_opted + basemap * dmap_osm  # opted mask so mask_alpha takes effect
        return layout.cols(2)

In [None]:
ex = DSExplorer()
col = pn.Column()
for p in ex.param:
    col.append(pn.panel(ex.param[p]))
pn.Row(col, ex.viewable).servable()

## Step 2: Monitor the training process 

### Build a simple semantic segmentation neural network

In [None]:
# Load a pretrained FCN-ResNet101 semantic segmentation model in evaluation mode
from torchvision import models 
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()

In [None]:
#todo: link to https://tinyurl.com/yxgck3pp


In [None]:
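# Apply the preprocessing transformations
import torchvision.transforms as T
trf = T.Compose([T.Resize(256),
                 T.CenterCrop(224),
                 T.ToTensor(), 
                 T.Normalize(mean = [0.485, 0.456, 0.406], 
                             std = [0.229, 0.224, 0.225])])
# Load the selected 8-bit RGB tile (H x W x 3, uint8) as a PIL image, which T.Resize expects
with rio.open(rgb_fns[idx]) as src:
    img = T.ToPILImage()(reshape_as_image(src.read()))
inp = trf(img).unsqueeze(0)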

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import torch

from skimage import io, transform


In [None]:
class SpDataset(Dataset):
    """Spacenet RGB tiles paired with their road masks."""
    def __init__(self, rgb_fns, mask_fns=None, transform=None):
        self.rgb_fns = rgb_fns
        self.mask_fns = mask_fns
        self.transform = transform

    def __len__(self):
        return len(self.rgb_fns)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        rgb_fn = self.rgb_fns[idx]
        # Use the given mask list if there is one; otherwise derive the mask filename from the RGB filename
        if self.mask_fns is not None:
            mask_fn = self.mask_fns[idx]
        else:
            mask_fn = get_sp_mask1300_fn(rgb_fn)
        with rio.open(rgb_fn) as ds:
            rgb_np = reshape_as_image(ds.read())
        with rio.open(mask_fn) as ds:
            mask_np = reshape_as_image(ds.read())
        print('rgb, mask: ', rgb_np.shape, mask_np.shape)
        
        sample = {'rgb': rgb_np, 'mask': mask_np}
        
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
            
            
        

In [None]:
spd = SpDataset(rgb_fns)

In [None]:
spd[0];

In [None]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        rgb_np, mask_np = sample['rgb'], sample['mask']

        h, w = rgb_np.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        rgb_np = transform.resize(rgb_np, (new_h, new_w))
        # Nearest-neighbour interpolation without anti-aliasing keeps the mask values
        # discrete (no blended values between road and background)
        mask_np = transform.resize(mask_np, (new_h, new_w), order=0, anti_aliasing=False)

        return {'rgb': rgb_np, 'mask': mask_np}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        rgb_np, mask_np = sample['rgb'], sample['mask']

        h, w = rgb_np.shape[:2]
        new_h, new_w = self.output_size

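        # +1 because np.random.randint excludes the upper bound; this also
        # allows h == new_h (or w == new_w) without raising a ValueError
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)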

        rgb_np = rgb_np[top: top + new_h,
                      left: left + new_w]
        mask_np = mask_np[top: top + new_h,
                      left: left + new_w]


        return {'rgb': rgb_np, 'mask': mask_np}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        rgb_np, mask_np = sample['rgb'], sample['mask']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        rgb = rgb_np.transpose((2, 0, 1))
        rgb = np.asarray(rgb, dtype=np.float32)
        mask = np.asarray(mask_np.squeeze(), dtype=np.float32)
        return {'rgb': torch.from_numpy(rgb),
                'mask': torch.from_numpy(mask)}
    
# Normalize only the RGB image; the mask is passed through unchanged
class PairNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
    
    def __call__(self, sample):
        # sample contains tensors
        normalizer = transforms.Normalize(mean=self.mean, std=self.std)
        return {"rgb": normalizer(sample['rgb']),
                "mask": sample['mask']}
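`PairNormalize` delegates to `transforms.Normalize`, which standardizes each channel as `(x - mean[c]) / std[c]`. The same arithmetic, including the H x W x C to C x H x W axis swap done by the `ToTensor` class above, can be sketched in plain numpy. The sketch assumes the RGB array is already scaled to [0, 1], as `skimage.transform.resize` in `Rescale` produces for 8-bit input; `to_normalized_chw`, `MEAN_RGB`, and `STD_RGB` are hypothetical names.

```python
import numpy as np

# ImageNet statistics, the same values passed to PairNormalize in this notebook
MEAN_RGB = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD_RGB = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(rgb_hwc: np.ndarray) -> np.ndarray:
    """H x W x 3 image scaled to [0, 1] -> normalized 3 x H x W float32 array."""
    chw = np.transpose(rgb_hwc, (2, 0, 1)).astype(np.float32)
    # [:, None, None] reshapes (3,) to (3, 1, 1) so it broadcasts over H and W
    return (chw - MEAN_RGB[:, None, None]) / STD_RGB[:, None, None]


# An image whose every pixel equals the channel means normalizes to all zeros
flat = np.broadcast_to(MEAN_RGB, (4, 5, 3))
normed = to_normalized_chw(flat)
print(normed.shape)          # (3, 4, 5)
print(np.abs(normed).max())  # 0.0
```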
        

In [None]:
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])
def show_sample_mpl(sample):
    """
    Plot each image in `sample` (e.g., 'rgb' and 'mask') side by side with matplotlib.
    """
    f,ax = plt.subplots(1,2)
    for i, (img_type, img) in enumerate(sample.items()):
        ax[i].set_title(img_type)
        ax[i].imshow(img.squeeze())
    return ax

def show_sample_gv(sample):
    """
    todo: make it geoviews
    """
    f,ax = plt.subplots(1,2)
    for i, (img_type, img) in enumerate(sample.items()):
        ax[i].set_title(img_type)
        ax[i].imshow(img.squeeze())
    return ax

In [None]:
# Apply the composed transform (Rescale, then RandomCrop) to one sample
idx = 2
sample = spd[idx]
trsfmed = composed(sample)
show_sample_mpl(trsfmed);

In [None]:
# transformed dataset
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
tsfmed_ds = SpDataset(rgb_fns, transform=transforms.Compose([
    Rescale(256), 
    RandomCrop(224),
    ToTensor(),
    PairNormalize(mean=IMAGENET_MEAN,
                 std=IMAGENET_STD)
]))
    
    
    

In [None]:
for i in range(len(tsfmed_ds)):
    sample = tsfmed_ds[i]

    print(i, sample['rgb'].size(), sample['mask'].size())

    if i == 3:
        break

In [None]:
# dataloader
# loader settings
batch_size = 1
shuffle = True

dataloader = DataLoader(tsfmed_ds, batch_size=batch_size,
                        shuffle=shuffle, num_workers=4)


In [None]:
it = iter(dataloader)

In [None]:
sample = next(it)
rgb_np = reshape_as_image(sample['rgb'].numpy().squeeze())
mask_np = sample['mask'].numpy().squeeze()
show_sample_mpl({'rgb': rgb_np, 
                 'mask': mask_np})

In [None]:
out = fcn(sample['rgb'])['out']
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print(om.shape, np.unique(om))
plt.imshow(om)

In [None]:
out.shape
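Monitoring a segmentation model usually means tracking a per-class metric alongside the loss. One common choice (our suggestion; the notebook does not define a metric yet) is intersection-over-union (IoU): for each class, the number of pixels labeled with that class in both the prediction and the target, divided by the number labeled with it in either. The numpy sketch below takes integer label maps such as the argmax output `om` above; `per_class_iou` is a hypothetical helper. Note that the pretrained FCN predicts its own label set, so comparing `om` with the road mask is only meaningful once the model is trained on the road labels.

```python
import numpy as np


def per_class_iou(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Intersection-over-union for each class id in [0, num_classes).

    Classes absent from both maps get NaN, so np.nanmean skips them.
    """
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    ious = np.full(num_classes, np.nan)
    for c in range(num_classes):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        if union > 0:
            ious[c] = np.logical_and(p, t).sum() / union
    return ious


pred = np.array([[0, 1, 1],
                 [0, 1, 0]])
target = np.array([[0, 1, 0],
                   [0, 1, 1]])
ious = per_class_iou(pred, target, num_classes=3)
print(ious)              # [0.5 0.5 nan]
print(np.nanmean(ious))  # 0.5 (mean IoU over the classes that occur)
```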

## Step 3: Interactively test your trained model on the new data

In [None]:
from torchvision import models

## Step 4: What have the model learned?

## Step 5: Examples

## Step 6: Summary
- Main Takeaway

- Resources
    - General: 
        - Github repo for this talk:
        - Holoviews:
        - PyTorch:

    - Data:
        - Remote sensing data: google-earth-engine
    - More from PyViz team:
        - Link to scipy tutorials:
        - Panel
        - 

---

- Prepare your dataset: train, validation, test
    - classification: 
        - e.g.: airplane/not-airplane, cat/dog/giraffe, land cover classification (forest, road, ...)
        - e.g.: semantic segmentation: classify each pixel into one of the label categories
    This talk focuses on semantic segmentation. Our dataset therefore consists of input images (RGB) and target images: "mask" images in which each pixel has one of the labels in {'highway', 'track', 'dirt', 'others'}.
    
    - clustering: 
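A segmentation loss expects integer class ids rather than label names, so the label set above has to be mapped to ids first. A minimal pure-Python sketch: the id assignment (the order in which the labels are listed) and the helper names are our own assumptions. Counting pixels per class is a quick way to spot class imbalance, which is common in road masks where background pixels dominate.

```python
from collections import Counter
from typing import Dict, List

# Label set from the text; ids follow the order in which the labels are listed
LABELS: List[str] = ["highway", "track", "dirt", "others"]
LABEL_TO_ID: Dict[str, int] = {name: i for i, name in enumerate(LABELS)}


def encode_mask(label_grid: List[List[str]]) -> List[List[int]]:
    """Turn a grid of label names into the integer mask a segmentation loss expects."""
    return [[LABEL_TO_ID[name] for name in row] for row in label_grid]


def class_pixel_counts(mask: List[List[int]]) -> Dict[str, int]:
    """Number of pixels per label name, including labels that never occur."""
    counts = Counter(v for row in mask for v in row)
    return {name: counts.get(i, 0) for i, name in enumerate(LABELS)}


grid = [["others", "highway", "highway"],
        ["others", "dirt", "others"]]
label_mask = encode_mask(grid)
print(label_mask)                      # [[3, 0, 0], [3, 2, 3]]
print(class_pixel_counts(label_mask))  # {'highway': 2, 'track': 0, 'dirt': 1, 'others': 3}
```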

===================
PyCon Talk Proposal
===================

:Title: Experimental ML with PyViz + PyTorch
:Duration: 30 min
:Level: Intermediate
:Categories: ?

Summary
=======

Both newcomers and experienced developers alike love Python's built-in data types — especially dictionaries!  But how do dictionaries work? What do they do better than other container types, and where, on the other hand, are their weaknesses?  Using simple, vivid diagrams that show the secrets of how the dictionary is implemented, and a series of progressively interesting examples of its behavior, we will train the Python developer's mind to picture what the dictionary is doing in just enough detail to make good decisions, as your data sets get larger, about when to use dictionaries and when other data structures might be more appropriate.

Description
===========

With some judicious use of ``ctypes``, one can write a Python routine
that dissects a Python dictionary or set and displays its internals.
Using such a tool — which I will also release on PyPI, to accompany the
presentation slides — my presentation will show how dictionaries behave
as you add items, overwrite them later, remove them, and iterate across
the whole dictionary.

By contemplating these normally hidden mechanics, and by showing some
judicious results from the ``timeit`` module, both newcomers and
experienced developers can gain new insight into the trade-offs that
dictionaries provide between space and computational complexity,
compared to the other alternatives in Python.  They will also understand
why Python provides a ``hash()`` function; why user-defined classes are
given the freedom to define their own hash function as well; and, what
happens if they choose not to.

The talk will actually discuss sets for most of its length, since they
are simpler to diagram and understand, then show, at the end, how a
dictionary is just a set with a second column, that holds a reference to
an object stored at that key value.

The talk will go something like this — each of the following 5 items,
I'm imagining, will take up about five minutes (and probably five to ten
slides) of my presentation, adding together to 25 minutes (leaving
5 minutes left over for questions):

1. Computer memory is like a Python list
----------------------------------------

Computer memory is indexed by integers, like Python lists (though the
indexes tend to be much bigger!).  So a Python list is simple: it's an
array of numbers, each indicating where in memory a Python object is
stored, and Python can jump directly to list item *n*, but has to
iterate across the whole list to find whether a particular item is in
the list.

An ordered list would let you find items more quickly, by jumping in
halfway through, and then restricting your search to one half of the
remaining list, just like looking for a name in a telephone book.  But,
the cost would still grow as the list grew longer.  And, lists would be
expensive to keep ordered!

So, let's think about another plan.  In a normal list, items wind up at
all sorts of indexes.  What if we created a list, and magically knew
ahead of time exactly where each Python object belonged?  Then we could
jump right to a given item, immediately, every time!

2. The idea of a hash table
---------------------------

We would need a function, called a *hash function*, that when given a
certain value — like the number 42, or the string ``"Ni"`` — always
returns the same index.  Python provides this with a built-in called
``hash()`` for which each built-in type provides an implementation.

As an example, we will examine an empty ``set()`` — “look, it starts
with space for eight items, even when it's empty!” — then we run
``hash()`` on three simple Python values; then we insert them into the
set, and see them land right at the indexes where ``hash()`` told them
to.

We now see the trade-off a hash table makes: in return for holding open
several empty slots, and thus spending *memory*, it can find an item (or
discover that it's absent) after only incurring the *static* cost of
computing a hash.  With a few ``timeit`` tests, we determine how costly
it is to compute a hash compared to two simple list operations: jumping
directly to list item *n*, versus iterating across a small or large list
to find an item.  Very small lists are very fast, but quickly become
more expensive than computing the hash value to look in a set or
dictionary.

3. When indexes collide
-----------------------

The ``hash()`` function has to return a limited range of values for an
unlimited range of inputs, so many objects *collide*.  I will create a
collision in the set shown in the slides, and show how the hash table
shunts aside the second item and puts it in a second spot where it can
quickly find it again when we ask.  Removing one of the colliding items
can still leave the other object stranded where it was put, so the cost
of a collision can linger.

When I add a fifth item to the set, it suddenly becomes 32 items long!
This is to prevent collisions from piling up too deep; both the size of
the hash table, *and* some of its behavior, are thus driven by the need
to handle collisions.  I will note that, when a set or dictionary is
re-sized, all of its contents are re-inserted, so that any junk left
over gets periodically erased as long as the dictionary is occasionally
growing or shrinking.

I will show how the cost of re-allocating the whole hash table is
reasonable if spread across many hundreds of set inserts, and also
quickly show an animation of the dictionary growing, then shrinking as
items are removed.  By showing some animations of how a dictionary
"looks" as it grows and gets used, using some real-world data from
observing a dictionary in one of my own applications, I will give a feel
for how they behave in the wild.

4. Providing your own hash function
-----------------------------------

The critical idea of giving one of your own classes its own
``__hash__()`` method is whether each member of your class represents a
*value* that could later — or even simultaneously — be represented by a
different instance of your class.  I will show how we can easily create
two floats with different ``id()`` but the same value (such that they
satisfy Python equality), and show how they both go into the same
dictionary slot because they have the same hash value.

With a few examples and simple illustrations, I will show how, by simply
calling ``hash()`` on the instance variables that give your class
instance its own unique value and combining their values together, you
can create a decent hash for your own class.

What if your class has no hash routine?  I will show how objects are
then tested for uniqueness, rather than value, and how deleting and
re-creating the "same" object gives it a different hash-table slot.

5. The dictionary, and its alternatives
---------------------------------------

First, I finally show the dictionary in all of its glory: like a set, it
keeps items in a hash table; but it adds a second column that for each
key provides a "value".

I will then show an animation of iteration across a dictionary: why the
objects come out in random order, and why it's dangerous to modify the
dictionary during iteration.

Finally, I will briefly discuss alternatives.  If you want items back
out in order, rather than having random access, use a ``heapq``.  If you
only add and remove objects from the ends of a series, use a ``deque``.
If you need both key-value referencing *and* ordering, then (for today)
you might just use ``sorted()`` on your keys each time, or (in the
future) use an ``OrderedDict``.

But, for most uses, the List and Dictionary are king, and the audience
will now hopefully understand why they're each perfect for their common
uses.  “Any questions?”