### Mutual information and correlation over spatial domain

In [1]:
from logger.logger import setup_logging
from utils.configs import BaseConf
from utils.utils import get_data_sub_paths
from utils.setup import setup
from utils.preprocessing import Shaper
from pprint import pprint
import numpy as np
from utils.utils import describe_array
from scipy.stats import pearsonr
import plotly.graph_objects as go
from sparse_discrete_table import new_discrete_table,quick_cond_mutual_info, quick_mutual_info, SparseDiscreteTable

In [2]:
data_sub_paths = get_data_sub_paths()
pprint(np.sort(data_sub_paths))

data_sub_path = 'T24H-X425M-Y440M_2012-01-01_2019-01-01'

array(['T12H-X850M-Y880M_2013-01-01_2017-01-01',
       'T1H-X1700M-Y1760M_2013-01-01_2017-01-01',
       'T24H-X1275M-Y1320M_2012-01-01_2019-01-01',
       'T24H-X1700M-Y1760M_2012-01-01_2019-01-01',
       'T24H-X255M-Y220M_2013-01-01_2017-01-01',
       'T24H-X425M-Y440M_2012-01-01_2019-01-01',
       'T24H-X425M-Y440M_2013-01-01_2017-01-01',
       'T24H-X850M-Y880M_2012-01-01_2019-01-01',
       'T24H-X850M-Y880M_2013-01-01_2017-01-01',
       'T24H-X85M-Y110M_2013-01-01_2017-01-01',
       'T3H-X850M-Y880M_2013-01-01_2017-01-01',
       'T6H-X850M-Y880M_2013-01-01_2017-01-01'], dtype='<U40')


In [3]:
conf, shaper, sparse_crimes, crime_feature_indices = setup(data_sub_path=data_sub_path)

In [165]:
for i,k in enumerate(crime_feature_indices):
    print(f"'{k}':{i},")

'TOTAL':0,
'THEFT':1,
'BATTERY':2,
'CRIMINAL DAMAGE':3,
'NARCOTICS':4,
'ASSAULT':5,
'BURGLARY':6,
'MOTOR VEHICLE THEFT':7,
'ROBBERY':8,
'Arrest':9,


In [166]:
i = 0
crimes = sparse_crimes[:,i:i+1]
print(conf.shaper_threshold) # sum over all time should be above this threshold
print(conf.shaper_top_k) # if larger than 0 we filter out only the top k most active cells of the data grid
new_shaper = Shaper(crimes, conf)

0
-1


In [167]:
dense_crimes = shaper.squeeze(sparse_crimes)

In [168]:
n,c,h,w = sparse_crimes.shape
n,c,h,w

(2557, 10, 94, 66)

In [169]:
def geo_correlation(dense_grid, i, t=0, filter_self=False):
    """
    dense_grid: crime counts nd. (n,l) being index in axis 1 or l
    filter_self: set the selected cell value lower to make the colorbar 
             more readable - select cell wil obviously have highest correlation
    """
    n,l = dense_grid.shape
    centre = dense_grid[:,i]
    
    rs = []
    pvs = []
    
    if t == 0:    
        for j in range(l):
            r, pv = pearsonr(centre, dense_grid[:,j])
            rs.append(r)
            pvs.append(pv)
    else:
        for j in range(l):
            r, pv = pearsonr(centre[t:], dense_grid[:-t,j])
            rs.append(r)
            pvs.append(pv)
            
    if filter_self:
        rs[i] = 0
                        
    rs = np.array(rs)
    pvs = np.array(pvs)
    return rs,pvs

In [170]:
def heatmap(img,pv):
    h,w = img.shape
    hm = go.Heatmap(
        z=img,
        text=pv,
    )
    
    fw = go.FigureWidget(
        data=[hm],
        layout=dict(
            clickmode='event+select',
            width=600,
            height=800,
        )
    )
    return fw

In [171]:
dc = dense_crimes[:,0]
dc = np.round(np.log2(1 + dc)) # scale
topk = dc.mean(0).argsort()[::-1]

In [172]:
# only use a single day of the week
# dow_i = np.array([i for i in range(len(dc)-7) if i % 7 == 0]) + 0
# dc = dc[dow_i]

### Geo-correlation

In [173]:
i = topk[0]
y,x = shaper.i_to_yx(i)
t=0
filter_self=False

state = dict(
    x=x,
    y=y,
    i=i,
    t=t,
    filter_self=filter_self,
)

gcr,gcpv = geo_correlation(dc,i,t,filter_self)
geo_corr_grid = shaper.unsqueeze(gcr.reshape(1,1,-1))[0,0]
geo_pv_grid = shaper.unsqueeze(gcpv.reshape(1,1,-1))[0,0]

In [174]:
from ipywidgets import widgets

In [175]:
def get_widget_value(change):
    if isinstance(change, dict) and change.get('name') == 'value':
        return change.get('new')
    return None

In [176]:
fw = heatmap(geo_corr_grid, geo_pv_grid)

def draw():
    global fw, state, dc
        
    x = state['x']
    y = state['y']
    t = state['t']        
    filter_self = state['filter_self']
    
    i = shaper.yx_to_i(y,x)
    if i is None:
        fw.update_layout(title={"text":f"current cell: invalid coordinate"})
        return
    
    gcr,gcpv = geo_correlation(dc,i,t, filter_self)
    geo_corr_grid = shaper.unsqueeze(gcr.reshape(1,1,-1))[0,0]
    geo_pv_grid = shaper.unsqueeze(gcpv.reshape(1,1,-1))[0,0]
    
    with fw.batch_update():
        fw.update_layout(title={"text":f"current cell: y,x = {y,x}"})
        fw.data[0].z = geo_corr_grid
        fw.data[0].text = geo_pv_grid

def set_state(trace,points,selector):
    global state    
    
    y = points.ys[0]
    x = points.xs[0]

    state["x"] = x
    state["y"] = y
    draw()

    
def on_change_offset_slider(change):
    global state
    value = get_widget_value(change)
    if value is not None:
        state["t"] = value
        draw()
    
offset_slider = widgets.IntSlider(
    value=0,
    min=0, 
    max=30,
    step=1,
    description='Offset:',
    continuous_update=False,
)
offset_slider.observe(on_change_offset_slider)

filter_self_checkbox = widgets.Checkbox(
    value=state['filter_self'],
    description='Filter self',
)


def set_filter_self(change):
    value = get_widget_value(change)
    if value is not None:
        state['filter_self'] = value 
        draw()

filter_self_checkbox.observe(set_filter_self)

    
fw.data[0].on_click(set_state)
# fw.data[0].on_selection(set_state|)
widgets.VBox([    
    filter_self_checkbox,
    offset_slider,
    fw,
])

VBox(children=(Checkbox(value=False, description='Filter self'), IntSlider(value=0, continuous_update=False, d…

### Mutual Information Grids

In [158]:
from sparse_discrete_table import quick_mutual_info, quick_cond_mutual_info

def mutual_info_over_time(a, max_offset=35,norm=True, lognorm=True, include_self=False):
    """
    a: np.ndarray (N,1) with the counts over time
    max_offset: furthest we compare to the signal in time
    norm: if the mutual information should be normalised between 0 and 1
    lognorm: if the array 'a' should be normalised: round(log2(1 + a))
    """
    if lognorm:
        a = np.round(np.log2(1 + a))
    mis = []
    if include_self:
        mi = quick_mutual_info(a,a,norm)
        mis.append(mi)        
    for t in range(1,max_offset+1):
        mi = quick_mutual_info(a[t:],a[:-t],norm)
        mis.append(mi)
    
    offsets = np.arange(1,len(mis)+1)
    return mis, offsets

def conditional_mutual_info_over_time(a, max_offset=35,norm=False, lognorm=True, include_self=False):
    """
    a: np.ndarray (N,1) with the counts over time
    max_offset: furthest we compare to the signal in time
    norm: if the mutual information should be normalised between 0 and 1
    lognorm: if the array 'a' should be normalised: round(log2(1 + a))
    """    
    if lognorm:
        a = np.round(np.log2(1 + a)).reshape(-1,1)
        
    dow = np.arange(len(a)) % 7
    cmis = []
    if include_self:
        cond = np.stack([dow,dow],axis=1)
        cmi = quick_cond_mutual_info(a,a,cond,norm)
        cmis.append(cmi)        
    for t in range(1,max_offset+1):
        cond = np.stack([dow[t:],dow[:-t]],axis=1)
        cmi = quick_cond_mutual_info(a[t:],a[:-t],cond,norm)
        cmis.append(cmi)
        
    offsets = np.arange(1,len(cmis)+1)
    return cmis, offsets

In [159]:
def geo_mutual_info(dense_grid, i, t=0, condition=False, filter_self=False, lognorm=True):
    """
    dense_grid: crime counts nd. (n,l)
     being index in axis 1 or l
    """
    if lognorm:
        dense_grid = np.round(np.log2(1 + dense_grid))
    
    n,l = dense_grid.shape
    centre = dense_grid[:,i]
    dow  = np.arange(len(centre))
        
    mis = [] 
    if t == 0:    
        for j in range(l):
            if condition:
                cond = dow
                mi = quick_cond_mutual_info(centre, dense_grid[:,j], cond,True)
            else:
                mi = quick_mutual_info(centre, dense_grid[:,j],True)
            mis.append(mi)
    else:
        for j in range(l):
            if condition:
                cond = np.stack([dow[t:],dow[:-t]], axis=1)
                mi = quick_cond_mutual_info(centre[t:], dense_grid[:-t,j],cond, True)
            else:            
                mi = quick_mutual_info(centre[t:], dense_grid[:-t,j], True)
            mis.append(mi)    
               
    if filter_self:
        mis[i] = 0
    mis = np.array(mis)
    return mis

In [160]:
state = dict(
    i=topk[0],
    t=0,
)

geo_mi = geo_mutual_info(dc,state['i'],state['t'],False)
geo_mi_grid = shaper.unsqueeze(geo_mi.reshape(1,1,-1))[0,0]

In [161]:
fw = go.FigureWidget(
    data=[
        go.Heatmap(            
            z=geo_mi_grid,
            text=geo_mi_grid,
            zmin=0,
            zmax=.03,
        ),
#         go.Surface(            
#             z=geo_mi_grid,
#             text=geo_mi_grid,
#         ),                
    ],
    layout=dict(
        clickmode='event+select',
        width=600,
        height=800,
    )
)

def draw():
    global fw, state, dc
    
    x = state['x']
    y = state['y']
    t = state['t']
        
    i = shaper.yx_to_i(y,x)
    if i is None:
        fw.update_layout(title={"text":f"current cell: invalid coordinate"})
        return
    
    geo_mi = geo_mutual_info(dc,i,t,False)
    geo_mi_grid = shaper.unsqueeze(geo_mi.reshape(1,1,-1))[0,0]
    
    with fw.batch_update():
        fw.data[0].z = geo_mi_grid
        fw.update_layout(title={"text":f"y,x = {y,x}"})

def on_click_heatmap(trace,points,selector):
    global state    
    
    y = points.ys[0]
    x = points.xs[0]
#     print(f"click {x},{y} => {mm[y,x]}")
#     print(f"click {i} => {dm[i]}")    
    
    state["x"] = x
    state["y"] = y    
    draw()

    
def on_change_offset(change):
    global state
    value = get_widget_value(change)
    if value is not None:
        state["t"] = value
        draw()
    
offset_slider = widgets.IntSlider(
    value=0,
    min=0, 
    max=30,
    step=1,
    description='Offset:',
    continuous_update=False,
)

# callback registers
offset_slider.observe(on_change_offset)
fw.data[0].on_click(on_click_heatmap)



widgets.VBox([    
    offset_slider,
    fw,
])

VBox(children=(IntSlider(value=0, continuous_update=False, description='Offset:', max=30), FigureWidget({
    …

KeyError: 'x'