### Mutual information and correlation over spatial domain

In [1]:
from logger.logger import setup_logging
from utils.configs import BaseConf
from utils.utils import get_data_sub_paths
from utils.setup import setup
from utils.preprocessing import Shaper
from pprint import pprint
import numpy as np
from utils.utils import describe_array
from scipy.stats import pearsonr
import plotly.graph_objects as go

In [2]:
# List every available data sub-path, sorted, so one can be picked by name.
data_sub_paths = get_data_sub_paths()
pprint(np.sort(data_sub_paths))

# Data set used for the rest of the notebook.
# NOTE(review): the name appears to encode time step, cell size and date range — confirm.
data_sub_path = 'T24H-X425M-Y440M_2012-01-01_2019-01-01'

array(['T12H-X850M-Y880M_2013-01-01_2017-01-01',
       'T1H-X1700M-Y1760M_2013-01-01_2017-01-01',
       'T24H-X1275M-Y1320M_2012-01-01_2019-01-01',
       'T24H-X1700M-Y1760M_2012-01-01_2019-01-01',
       'T24H-X255M-Y220M_2013-01-01_2017-01-01',
       'T24H-X425M-Y440M_2012-01-01_2019-01-01',
       'T24H-X425M-Y440M_2013-01-01_2017-01-01',
       'T24H-X850M-Y880M_2012-01-01_2019-01-01',
       'T24H-X850M-Y880M_2013-01-01_2017-01-01',
       'T24H-X85M-Y110M_2013-01-01_2017-01-01',
       'T3H-X850M-Y880M_2013-01-01_2017-01-01',
       'T6H-X850M-Y880M_2013-01-01_2017-01-01'], dtype='<U40')


In [3]:
# setup() returns four objects for the chosen data set; later cells index
# sparse_crimes as a 4-D array and use shaper to squeeze/unsqueeze grids.
conf, shaper, sparse_crimes, crime_feature_indices = setup(data_sub_path=data_sub_path)

In [4]:
# Print every entry of crime_feature_indices next to its position, one
# "'name':index," line per entry.
for i,k in enumerate(crime_feature_indices):
    print(f"'{k}':{i},")

'TOTAL':0,
'THEFT':1,
'BATTERY':2,
'CRIMINAL DAMAGE':3,
'NARCOTICS':4,
'ASSAULT':5,
'BURGLARY':6,
'MOTOR VEHICLE THEFT':7,
'ROBBERY':8,
'Arrest':9,


In [5]:
# Slice out a single feature channel while keeping the channel axis (i:i+1).
i = 0
crimes = sparse_crimes[:,i:i+1]
print(conf.shaper_threshold) # sum over all time should be above this threshold
print(conf.shaper_top_k) # if larger than 0 we filter out only the top k most active cells of the data grid
# Build a Shaper from the single-channel data and the loaded configuration.
new_shaper = Shaper(crimes, conf)

0
-1


In [6]:
# Dense form of the crime grid; indexed below as dense_crimes[:, channel, cell].
dense_crimes = shaper.squeeze(sparse_crimes)

In [7]:
# Unpack the 4-D shape; h and w are the extents of the y and x axes used below.
# NOTE(review): n and c are presumably time steps and feature channels — confirm.
n,c,h,w = sparse_crimes.shape
n,c,h,w

(2557, 10, 94, 66)

In [8]:
# Enumerate every (y, x) grid coordinate, y in the outer loop and x in the inner.
coords = []
for y in range(h):
    for x in range(w):
        coords.append((y,x))
# Keep only the coordinates selected by shaper.index_mask.
# NOTE(review): assumes index_mask addresses the flattened h*w grid in this same
# y-major order — confirm against Shaper.
coords = np.array(coords)[shaper.index_mask]
# Map each kept (y, x) pair to its position i in the masked coordinate list.
yxmap = {tuple(yx): i for i, yx in enumerate(coords)}

In [9]:
# unit test for the mapping: every dense column i must hold exactly the same
# series as the sparse grid cell (y, x) it is mapped from (channel 0 only)
failed = False
for yx, i in yxmap.items():
    y,x = yx
    # Compare element-wise: summing the differences would let positive and
    # negative mismatches cancel out and report a false pass.
    if not np.array_equal(dense_crimes[:,0,i], sparse_crimes[:,0,y,x]):
        print(f"failed: y,x,i => {y,x,i}")
        failed = True
        break
print("test passed" if not failed else "test failed")

test passed


In [10]:
def geo_correlation(dense_grid, i, t=0):
    """
    Pearson correlation between one cell's series and the series of every cell.

    dense_grid: np.ndarray (n,l), n observations for each of l cells
    i: index on axis 1 (l) of the reference ("centre") cell
    t: lag in rows; when non-zero, centre[t:] is paired with dense_grid[:-t, j],
       i.e. the centre series is compared with each cell's values t rows earlier

    returns: (rs, pvs), two np.ndarray of length l with the Pearson r and its
             p-value for every cell j (the centre cell itself included)
    """
    _, l = dense_grid.shape
    centre = dense_grid[:, i]

    # A slice ending at -0 is empty, so the zero-lag case must skip slicing.
    if t != 0:
        centre = centre[t:]
        dense_grid = dense_grid[:-t]

    rs = []
    pvs = []
    for j in range(l):
        r, pv = pearsonr(centre, dense_grid[:, j])
        rs.append(r)
        pvs.append(pv)

    return np.array(rs), np.array(pvs)

In [188]:
from sparse_discrete_table import build_discrete_table, quick_cond_mutual_info, quick_mutual_info

def mutual_info(x, y, norm=False):
    """
    Determine the mutual information between x and y

    x: np.ndarray with N observations (called with 1-D series in this notebook)
    y: np.ndarray with N observations, stacked next to x along axis 1
    norm: bool, forwarded to the table's mutual_information call; per the original
          note this requests symmetric normalisation with 0.5*(h(x)+h(y)) as the
          normalising constant (TODO confirm against sparse_discrete_table)
    """
    columns = ['x', 'y']
    paired = np.stack((x, y), axis=1)
    table = build_discrete_table(paired, columns)
    return table.mutual_information(['x'], ['y'], norm)

def conditional_mutual_info(x,y,z, norm=False):
    """
    Determine the mutual information between x and y conditioned on z

    x: np.ndarray with N observations
    y: np.ndarray with N observations
    z: np.ndarray with N observations of the conditioning variable
    norm: bool, if True the result is divided by cmi(x,x,z) (asymmetric
          normalisation); no guard is applied if that denominator is zero
    """
    table = build_discrete_table(np.stack((x, y, z), axis=1), ['x','y', 'z'])
    cmi = table.conditional_mutual_information(['x'], ['y'], ['z'])
    if not norm:
        return cmi
    # Normalise by the information x shares with itself given z.
    return cmi / table.conditional_mutual_information(['x'], ['x'], ['z'])

In [295]:
# Random discrete sample: 1000 rows, columns drawn from [0,4), [0,2) and [0,2).
X = np.random.randint([0,0,0],[4,2,2],(1000,3))
x,y,z = X[:,0],X[:,1],X[:,2]

xy = np.stack([x, y], axis=1)
# NOTE(review): both columns are labelled 'x' and the queries below relate 'x'
# to itself — looks like a self-information sanity check; confirm it is intended.
# z is unpacked above but not used in this cell.
dt = build_discrete_table(xy, ['x','x'])
nmi = dt.mutual_information(['x'], ['x'],True)
mi = dt.mutual_information(['x'], ['x'],False)



(1.9982664712050986, 1.0, 1.0, 0.0)

In [12]:
# define geo mutual information

In [13]:
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

In [87]:
def heatmap(img,pv):
    """
    Build a 600x800 plotly FigureWidget holding a single heatmap trace.

    img: 2-D array used as the heatmap z values
    pv: values attached to the trace as its text
    returns: go.FigureWidget with click mode 'event+select'
    """
    # Unpacking the shape fails early for anything that is not 2-D.
    _rows, _cols = img.shape

    layout = dict(clickmode='event+select', width=600, height=800)
    trace = go.Heatmap(z=img, text=pv)
    return go.FigureWidget(data=[trace], layout=layout)

In [88]:
# Channel 0 of the dense grid; axis 1 of the result indexes cells.
dc = dense_crimes[:,0]
# Cell indices ordered by descending mean over axis 0.
topk = dc.mean(0).argsort()[::-1]

In [89]:
# Shared UI state: i is the reference cell index (starts at the cell with the
# highest mean), t is the lag passed to geo_correlation.
state = dict(
    i=topk[0],
    t=0,
)

# Correlation and p-value per cell, then expanded back onto the 2-D grid
# (shaper.unsqueeze is given a (1,1,l) array; [0,0] picks the single grid out).
gcr,gcpv = geo_correlation(dc,state['i'],state['t'])
geo_corr_grid = shaper.unsqueeze(gcr.reshape(1,1,-1))[0,0]
geo_pv_grid = shaper.unsqueeze(gcpv.reshape(1,1,-1))[0,0]

In [90]:
from ipywidgets import widgets

In [91]:
def get_widget_value(change):
    """
    Extract the new value from a widget change notification.

    change: the object handed to an observe() callback
    returns: change['new'] when change is a dict whose 'name' is 'value',
             otherwise None
    """
    is_value_change = isinstance(change, dict) and change.get('name') == 'value'
    return change.get('new') if is_value_change else None

In [92]:
# Figure widget showing the correlation grid, with the p-value grid as trace text.
fw = heatmap(geo_corr_grid, geo_pv_grid)


def draw():
    """
    Recompute the correlation and p-value grids for the current state and push
    them into the existing figure widget.

    Reads the notebook globals fw, state, dc and shaper; returns nothing.
    """
    def to_grid(values):
        # (l,) -> (1,1,l) for shaper.unsqueeze, then pick the single 2-D grid.
        return shaper.unsqueeze(values.reshape(1,1,-1))[0,0]

    corr, pvals = geo_correlation(dc, state['i'], state['t'])
    corr_grid = to_grid(corr)
    pval_grid = to_grid(pvals)

    # Apply both trace updates as a single batch.
    with fw.batch_update():
        trace = fw.data[0]
        trace.z = corr_grid
        trace.text = pval_grid

def set_state(trace,points,selector):
    """
    Click handler for the correlation heatmap: make the clicked cell the new
    reference cell and redraw.

    trace, selector: supplied by the on_click callback, unused here
    points: clicked points; only the first (ys[0], xs[0]) is used

    Clicks that carry no points, or that land on a grid cell missing from
    yxmap (filtered out by the shaper mask), are ignored instead of raising.
    """
    global state

    if len(points.ys) == 0 or len(points.xs) == 0:
        return

    y = points.ys[0]
    x = points.xs[0]

    i = yxmap.get((y,x))
    if i is None:
        # Clicked cell has no dense index, keep the current state.
        return

    state["i"] = i
    draw()

    
def on_change(change):
    """
    Observer for the offset slider: store the new lag in state["t"] and redraw.

    change: change notification; ignored unless get_widget_value yields a value
    """
    new_offset = get_widget_value(change)
    if new_offset is None:
        return
    state["t"] = new_offset
    draw()
    
# Integer slider for the lag t (0..30); with continuous_update=False the value
# is only reported once the user releases the slider.
offset_slider = widgets.IntSlider(
    value=0,
    min=0, 
    max=30,
    step=1,
    description='Offset:',
    continuous_update=False,
)

# observe() is registered without a trait filter; on_change relies on
# get_widget_value to ignore everything except 'value' changes.
offset_slider.observe(on_change)
    
# Clicking a heatmap cell selects it as the new reference cell.
fw.data[0].on_click(set_state)
# fw.data[0].on_selection(set_state|)
# Display the slider above the figure.
widgets.VBox([    
    offset_slider,
    fw,
])

VBox(children=(IntSlider(value=0, continuous_update=False, description='Offset:', max=30), FigureWidget({
    …

### Mutual Information Grids

In [93]:
def geo_mutual_info(dense_grid, i, t=0):
    """
    Normalised mutual information between one cell's series and every other cell.

    dense_grid: crime counts nd. (n,l)
    i: index on axis 1 (l) of the reference ("centre") cell
    t: lag in rows; when non-zero, centre[t:] is paired with dense_grid[:-t, j]

    returns: np.ndarray of length l; the entry for the centre cell itself is 0
    """
    _n_obs, n_cells = dense_grid.shape
    centre = dense_grid[:, i]

    # A slice ending at -0 is empty, so the zero-lag case uses the full series.
    if t == 0:
        reference, others = centre, dense_grid
    else:
        reference, others = centre[t:], dense_grid[:-t]

    scores = [
        0 if j == i else mutual_info(reference, others[:, j], True)
        for j in range(n_cells)
    ]
    return np.array(scores)

In [94]:
# Reset the shared UI state for the mutual-information view: i is the
# reference cell index, t the lag. Note that no 'x'/'y' keys are set here.
state = dict(
    i=topk[0],
    t=0,
)

# Mutual information per cell, expanded back onto the 2-D grid
# (shaper.unsqueeze is given a (1,1,l) array; [0,0] picks the single grid out).
geo_mi = geo_mutual_info(dc,state['i'],state['t'])
geo_mi_grid = shaper.unsqueeze(geo_mi.reshape(1,1,-1))[0,0]

In [113]:
# Figure widget for the mutual-information grid; the same grid is used for the
# colour values and the trace text.
# NOTE(review): the colour scale is clamped to [0, 0.03] — presumably tuned by
# eye for this data set; confirm before reusing on other data.
fw = go.FigureWidget(
    data=[
        go.Heatmap(            
            z=geo_mi_grid,
            text=geo_mi_grid,
            zmin=0,
            zmax=.03,
        ),
#         go.Surface(            
#             z=geo_mi_grid,
#             text=geo_mi_grid,
#         ),                
    ],
    layout=dict(
        clickmode='event+select',
        width=600,
        height=800,
    )
)

def draw():
    """
    Recompute the mutual-information grid for the current state and push it
    into the existing figure widget.

    Reads the notebook globals fw, state, dc, coords, yxmap and shaper.
    The reference cell is state['x'], state['y'] once a cell has been clicked;
    before the first click those keys do not exist, so the cell index
    state['i'] is used and its (y, x) is looked up in coords. Nothing is drawn
    when the clicked cell has no entry in yxmap.
    """
    global fw, state, dc

    t = state['t']
    x = state.get('x')
    y = state.get('y')

    if x is None or y is None:
        # No click yet (e.g. the slider was moved first): fall back to the
        # initial cell index instead of failing on the missing 'x'/'y' keys.
        i = state.get('i')
        if i is None:
            return
        # coords[i] is the (y, x) pair that yxmap maps to i.
        y, x = (int(v) for v in coords[i])
    else:
        i = yxmap.get((y,x))
        if i is None:
            return

    geo_mi = geo_mutual_info(dc,i,t)
    geo_mi_grid = shaper.unsqueeze(geo_mi.reshape(1,1,-1))[0,0]

    with fw.batch_update():
        fw.data[0].z = geo_mi_grid
        # Keep the trace text in step with z; it was only set at creation.
        fw.data[0].text = geo_mi_grid
        fw.update_layout(title={"text":f"y,x = {y,x}"})

def on_click_heatmap(trace,points,selector):
    """
    Click handler for the mutual-information heatmap: remember the clicked
    grid coordinate in state["x"] / state["y"] and redraw.

    trace, selector: supplied by the on_click callback, unused here
    points: clicked points; only the first (ys[0], xs[0]) is used
    """
    row = points.ys[0]
    col = points.xs[0]

    state.update(x=col, y=row)
    draw()

    
def on_change_offset(change):
    """
    Observer for the offset slider: store the new lag in state["t"] and redraw.

    change: change notification; ignored unless get_widget_value yields a value
    """
    new_offset = get_widget_value(change)
    if new_offset is None:
        return
    state["t"] = new_offset
    draw()
    
# Integer slider for the lag t (0..30); with continuous_update=False the value
# is only reported once the user releases the slider.
offset_slider = widgets.IntSlider(
    value=0,
    min=0, 
    max=30,
    step=1,
    description='Offset:',
    continuous_update=False,
)

# callback registers
# observe() is registered without a trait filter; on_change_offset relies on
# get_widget_value to ignore everything except 'value' changes.
offset_slider.observe(on_change_offset)
fw.data[0].on_click(on_click_heatmap)



# Display the slider above the figure.
widgets.VBox([    
    offset_slider,
    fw,
])

VBox(children=(IntSlider(value=0, continuous_update=False, description='Offset:', max=30), FigureWidget({
    …