## Mutual Information Feature Selection

The aim of this notebook:
- Look at column features and how they influence a cell's self-exciting property (mutual info)

# Why conditional mutual information is larger than mutual information
Conditioning adds a variable with its own set of possible states, so the joint distribution has a higher entropy and the maximum value the conditional mutual information can take is therefore higher. To make the plots easier to read, we can normalise the mutual information by the mutual information of the target signal with itself.

$I(X;Y) = H(X) + H(Y) - H(X,Y)$


$I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)$

In [1]:
from sparse_discrete_table import SparseDiscreteTable, build_discrete_table
from ipywidgets import widgets
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import plotly.graph_objects as go

In [2]:
# rename and move to utils
def new_tbins(t_min, t_max):
    """
    Build unit-width histogram bin edges covering every time index in [t_min, t_max].

    The edges run t_min, t_min+1, ..., t_max+1, i.e. one bin per time index.
    """
    last_edge_exclusive = t_max + 2
    return np.arange(t_min, last_edge_exclusive)
#     return np.arange(t_min-0.5,t_max+1.5,1)

# Interactive Plotly Plots

In [3]:
import pandas as pd
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
from utils.data_processing import time_series_to_time_index
from pprint import pprint

In [4]:
# Disable pandas' chained-assignment warning: later cells add columns to a filtered frame.
pd.set_option('mode.chained_assignment', None)

In [5]:
# Read the Mapbox token through a context manager so the file handle is closed,
# and strip surrounding whitespace (e.g. a trailing newline) from the token.
# The token is deliberately never echoed: cell outputs are saved with the notebook.
with open(".mapbox_token") as token_file:
    mapbox_access_token = token_file.read().strip()
# mapbox_access_token = "open-street-map"

# (colour, marker size) pairs for the two marker states
on_state = ('#f050ae', 8)
off_state = ('#ffab00', 8)

# custom Mapbox style URLs, keyed by a short name
mapbox_styles = {
    "sat": "mapbox://styles/bernsblack/ckecz0wr52pfc1at7tvu43fmj", 
    "mono": "mapbox://styles/bernsblack/ckecyyizy065w19psrikmeo5d",
    "dark": "mapbox://styles/bernsblack/ckeikbchd254619s57np6iyum",
}

color_palette = ["#33a8c7","#52e3e1","#a0e426","#fdf148","#ffab00","#f77976","#f050ae","#d883ff","#9336fd"]

`Just for reference where Crimes_Chicago_2001_to_2019_datetime_fmt came from:`
```python
# was used to convert original csv to dataframe and string dates to datetime format
df = pd.read_csv("data/original/Crimes_Chicago_2001_to_2019.csv")
df['Date'] = pd.DatetimeIndex(pd.to_datetime(df['Date']))
df = df[[
    'ID',
    'Date',
    'Primary Type',
    'Arrest',
    'Latitude',
    'Longitude',
]].dropna()
df.to_pickle("data/raw/Crimes_Chicago_2001_to_2019_datetime_fmt.pkl") # Dates are type datetime not string
```

In [6]:
# Load the pre-converted incident table (dates already parsed to datetime).
# NOTE(review): unpickling can execute arbitrary code - only load files produced locally.
df = pd.read_pickle("data/raw/Crimes_Chicago_2001_to_2019_datetime_fmt.pkl")

In [7]:
# Bounding box and date window used to remove the anomalies.
# Spatial bounds are exclusive on both sides; the date window is [DATE_MIN, DATE_MAX).
LAT_MIN, LAT_MAX = 41.6445, 42.0228

LON_MIN, LON_MAX = -87.9344, -87.5244

DATE_MIN, DATE_MAX = '2007-01-01', '2017-01-01'

# Crime types kept for the analysis; the list order later defines the integer codes in df['c'].
crime_categories = [
        "THEFT",
        "BATTERY",
        "CRIMINAL DAMAGE",
        "NARCOTICS",
        "ASSAULT",
        "BURGLARY",
        "MOTOR VEHICLE THEFT",
        "ROBBERY",
]

# Boolean row masks, one per filter; they are combined in the next cell.
lon_mask = (df.Longitude > LON_MIN) & (df.Longitude < LON_MAX)
lat_mask = (df.Latitude > LAT_MIN) & (df.Latitude < LAT_MAX)
date_mask = (df.Date >= DATE_MIN) & (df.Date < DATE_MAX)
category_mask = df['Primary Type'].isin(crime_categories)

In [8]:
# Keep only rows that pass all four masks.
# NOTE(review): this rebinds `df`, so the unfiltered frame is gone; re-run from the load cell to start over.
df = df[lon_mask & lat_mask & date_mask & category_mask]

In [9]:
from pandas.api.types import CategoricalDtype
# Label encode the crime categories - makes histograms faster.
# The ordered categorical follows the order of `crime_categories`, so df['c'] holds
# the position of each row's crime type in that list.
CrimeType = CategoricalDtype(categories=crime_categories, ordered=True)
df['Primary Type'] = df['Primary Type'].astype(CrimeType)
df['c'] = df['Primary Type'].cat.codes

In [10]:
# Time-step index of every incident at a one-day resolution.
# NOTE(review): presumably counts whole steps since the earliest Date - confirm in utils.data_processing.
df['t'] = time_series_to_time_index(t_series=df.Date, t_step='1D', floor=True)
# df['t'] = time_series_to_time_index(t_series=df.Date, t_step='1H', floor=True)

In [11]:
# date selectors: one timestamp per day between the first and last incident.
# NOTE(review): later cells index date_range with df.t values, which assumes the earliest
# Date falls exactly on midnight (otherwise ceil() shifts the start by one day) - confirm.
date_range = pd.date_range(df.Date.min().ceil('1D'), df.Date.max().floor('1D'), freq='1D')

In [12]:
from utils.utils import ffloor, fceil

In [13]:
from geopy import distance

# get metre-per-degree estimate from the bounding box of the filtered incidents
coord_series = df[['Latitude', 'Longitude']]

# column order is (Latitude, Longitude), so each reduction unpacks as (lat, lon)
lat_min, lon_min = coord_series.min()
lat_max, lon_max = coord_series.max()

lat_mean, lon_mean = coord_series.mean()

# geodesic height (dy) and width (dx) of the bounding box in metres,
# both measured from the (lat_min, lon_min) corner
dy = distance.distance((lat_min, lon_min), (lat_max, lon_min)).m
dx = distance.distance((lat_min, lon_min), (lat_min, lon_max)).m

# degrees spanned per metre along each axis
lat_per_metre = (lat_max - lat_min)/dy
lon_per_metre = (lon_max - lon_min)/dx

# width/height ratio of the bounding box in metres; used to scale plot axes
ratio_xy = dx/dy
print(f"ratio_xy: {ratio_xy}")
print(f"lat_per_metre: {lat_per_metre}")
print(f"lon_per_metre: {lon_per_metre}")

ratio_xy: 0.8127895963778615
lat_per_metre: 9.003326781922841e-06
lon_per_metre: 1.2003344797182708e-05


### Very important - all the metadata is mapped from intervals of 0.001 in the lat and lon space, with ratios of 8 and 11, to ensure that the grid cells are square

#### Confirm that the make_grid technique and the histogram produce the same output

In [14]:
# Available (x_scale, y_scale) pairs:
# [20,16] - [10,8] - [5,4] - [3,2] - [1,1]
x_scale, y_scale = 5,4#3,2
xy_scale = np.array([x_scale, y_scale])  # must be integer so that we can easily sample demographic data
# cell width/height in degrees: the integer scales times the 0.001 degree base interval
dlon, dlat = xy_scale * np.array([0.001, 0.001])
meta_info = {}
meta_info["x_scale"] = x_scale
meta_info["y_scale"] = y_scale
meta_info["dlon"] = float(dlon)
meta_info["dlat"] = float(dlat)
# 85000 and 110000 act as rough metres-per-degree factors for lon and lat
# (compare with the lon_per_metre / lat_per_metre values printed above)
meta_info["x in metres"] = 85000 * float(dlon)
meta_info["y in metres"] = 110000 * float(dlat)

pprint(meta_info)

# use perfect squares
"""
all the meta data is mapped from intervals of 0.001 in 
the lat and lon space with ratios of 8 and 11 to ensure that the grids cels are square
"""
# cell_size_m = 430
# dlat = cell_size_m*lat_per_metre
# dlon = cell_size_m*lon_per_metre
 
# use increments of 0.001

# longitude bin edges over the data's bounding box, snapped with ffloor/fceil
xbins = np.arange(ffloor(lon_min, dlon),fceil(lon_max,dlon), dlon)
nx = len(xbins)

# latitude bin edges, built the same way
ybins = np.arange(ffloor(lat_min, dlat),fceil(lat_max,dlat), dlat)
ny = len(ybins)

# time bin edges: one unit-width bin per time index
nt = int(np.ceil(df.t.max()))
tbins = new_tbins(df.t.min(),df.t.max())

# crime-type bin edges: one unit-width bin per category code
nc = len(crime_categories)
cbins = np.arange(0,nc+1,1)

{'dlat': 0.004,
 'dlon': 0.005,
 'x in metres': 425.0,
 'x_scale': 5,
 'y in metres': 440.0,
 'y_scale': 4}


In [15]:
from pprint import pformat

class State:    
    """
    State object to make handling interactive plots easier.

    Every keyword passed to the constructor becomes an attribute (e.g. state.lon_min).
    Callbacks can be registered with on_update(); they are triggered when fields are
    written with item syntax (state['a'] = 1) or through update(...).
    Plain attribute assignment (state.a = 1) does NOT trigger callbacks.
    """
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        self.callbacks = [] # inner repr of the update callback
        self.__recursion_guard = 0  # non-zero while callbacks are being dispatched
                    
    def __repr__(self):
        # hide the internal (name-mangled) guard from the printed state
        display = {**self.__dict__}
        display.pop("_State__recursion_guard",None)
        return pformat(display)
    
    def __getitem__(self, key):
        return self.__dict__[key] 
    
    def __setitem__(self, key, value):
        self.__dict__[key] = value
        self.__update()
        
    def on_update(self, callback, append=False):
        """
        Register a callback that runs when a value is set using state['a'] = 1 or update().

        callback: function that takes the State instance as its only argument
        append: if True add to the existing callbacks, otherwise replace them all
        """
        if append:
            self.callbacks.append(callback)
        else:
            self.callbacks = [callback]    
            
    def update(self, **kwargs):
        """
        Set/update multiple values, then call the registered callbacks once.
        """
        self.__dict__.update(kwargs)
        self.__update()            
    
    def __update(self):
        """
        Run every callback once, ignoring updates issued from inside a callback.

        The guard is released in a `finally` block so that an exception raised by a
        callback cannot leave the state permanently "busy", and it is only released
        by the outermost call so nested updates can never re-enter the callbacks.
        """
        if self.__recursion_guard:
            return
        self.__recursion_guard = 1
        try:
            # iterate over a snapshot so callbacks may re-register handlers safely
            for fn in list(self.callbacks):
                fn(self)
        finally:
            self.__recursion_guard = 0
        
        
def default_state():
    """
    Build a fresh State holding the notebook-wide defaults: the full spatial and
    date windows, every crime category, and the initial widget positions.
    """
    full_date_span = (0, len(date_range) - 1)
    defaults = dict(
        dlon=dlon,
        dlat=dlat,
        lon_max=LON_MAX,
        lon_min=LON_MIN,
        lat_max=LAT_MAX,
        lat_min=LAT_MIN,
        date_min=DATE_MIN,
        date_max=DATE_MAX,
        time_index=0,
        date_indices=full_date_span,
        crime_types=crime_categories,
        mi_max_offset=35,
        block_size_index=2,
    )
    return State(**defaults)
    
# Shared, mutable application state read by the plotting helpers and widget callbacks below.
state = default_state()
state

{'block_size_index': 2,
 'callbacks': [],
 'crime_types': ['THEFT',
                 'BATTERY',
                 'CRIMINAL DAMAGE',
                 'NARCOTICS',
                 'ASSAULT',
                 'BURGLARY',
                 'MOTOR VEHICLE THEFT',
                 'ROBBERY'],
 'date_indices': (0, 3652),
 'date_max': '2017-01-01',
 'date_min': '2007-01-01',
 'dlat': 0.004,
 'dlon': 0.005,
 'lat_max': 42.0228,
 'lat_min': 41.6445,
 'lon_max': -87.5244,
 'lon_min': -87.9344,
 'mi_max_offset': 35,
 'time_index': 0}

`Reference for code to compare grids with previous generation techniques:`
```python
file_location = "data/processed/T24H-X425M-Y440M_2013-01-01_2017-01-01/"

generated_data = np.load(file_location + "generated_data.npz")
x_range = generated_data['x_range']
y_range = generated_data['y_range']
t_range = pd.read_pickle(file_location + "t_range.pkl")

crime_types_grids = generated_data['crime_types_grids'] 

xbins  = x_range - dlon/2
xbins = np.array([*xbins,xbins[-1]+dlon])

ybins  = y_range - dlat/2
ybins = np.array([*ybins,ybins[-1]+dlat])

import json

with open(file_location + "info.json", "r") as f:
    meta_info = json.load(f)
         
meta_info
```

### Widget Plot

In [16]:
# reset state: select every crime type (as a copy of the list) and the full date span.
# Plain attribute assignment is used here, which does not trigger State callbacks.
state.crime_types = [*crime_categories]

state.date_indices = (0,len(date_range)-1)
# look up the first and last timestamps of the selected span
state.date_min, state.date_max = date_range[[*state.date_indices]]

In [17]:
from sparse_discrete_table import quick_mutual_info, quick_cond_mutual_info

def mutual_info_over_time(a, max_offset=35,norm=True, lognorm=True, include_self=False):
    """
    Mutual information between a count signal and lagged copies of itself.

    a: np.ndarray (N,1) with the counts over time
    max_offset: furthest we compare to the signal in time
    norm: passed to quick_mutual_info; if the mutual information should be
          normalised between 0 and 1
    lognorm: if the array 'a' should be normalised: round(log2(1 + a))
    include_self: if True the signal is also compared with itself (offset 0)

    returns (mis, offsets): list of scores and the matching array of lags, so that
    mis[i] is the score at lag offsets[i]. Offsets start at 0 when include_self is
    True and at 1 otherwise.
    """
    if lognorm:
        a = np.round(np.log2(1 + a))

    # the self-comparison is lag 0, so it must be labelled with offset 0 (not 1)
    first_offset = 0 if include_self else 1
    offsets = np.arange(first_offset, max_offset + 1)

    mis = []
    for t in offsets:
        # a[:-0] would be empty, so lag 0 uses the full signal on both sides
        lagged = a if t == 0 else a[:-t]
        mis.append(quick_mutual_info(a[t:], lagged, norm))

    return mis, offsets

def conditional_mutual_info_over_time(a, max_offset=35,norm=False,
                                      lognorm=True, include_self=False, cycle=7):
    """
    Conditional mutual information between a count signal and lagged copies of
    itself, conditioned on the position of both samples within a repeating cycle.

    a: np.ndarray (N,1) with the counts over time
    max_offset: furthest we compare to the signal in time
    norm: passed to quick_cond_mutual_info; if the score should be normalised
          between 0 and 1
    lognorm: if the array 'a' should be normalised: round(log2(1 + a))
             (the result is reshaped to (N,1); with lognorm=False 'a' is used as given)
    include_self: if True the signal is also compared with itself (offset 0)
    cycle: cycle of the data we condition on
            if we believe there is a strong weekly trend -> 7 or 24 for daily trends

    returns (cmis, offsets): list of scores and the matching array of lags, so that
    cmis[i] is the score at lag offsets[i]. Offsets start at 0 when include_self is
    True and at 1 otherwise.
    """    
    if lognorm:
        a = np.round(np.log2(1 + a)).reshape(-1,1)

    # position of every sample inside the cycle (e.g. day of week for cycle=7)
    dow = np.arange(len(a)) % cycle

    # the self-comparison is lag 0, so it must be labelled with offset 0 (not 1)
    first_offset = 0 if include_self else 1
    offsets = np.arange(first_offset, max_offset + 1)

    cmis = []
    for t in offsets:
        # x[:-0] would be empty, so lag 0 uses the full arrays on both sides
        lagged_a = a if t == 0 else a[:-t]
        lagged_dow = dow if t == 0 else dow[:-t]
        cond = np.stack([dow[t:], lagged_dow], axis=1)
        cmis.append(quick_cond_mutual_info(a[t:], lagged_a, cond, norm))

    return cmis, offsets

In [18]:
def filter_df(df, state):
    """
    Return the rows of df that fall inside the selection described by state.

    A row is kept when its Longitude/Latitude lie within the (inclusive) lon/lat
    bounds, its Date lies within the (inclusive) date bounds, and its
    'Primary Type' is one of state.crime_types.
    """
    inside_lon = df.Longitude.between(state.lon_min, state.lon_max)
    inside_lat = df.Latitude.between(state.lat_min, state.lat_max)
    inside_dates = df.Date.between(state.date_min, state.date_max)
    wanted_type = df['Primary Type'].isin(state.crime_types)

    keep = inside_dates & wanted_type & inside_lon & inside_lat
    return df[keep]

def get_total_counts(sub):
    """
    Count the incidents in `sub` per time step.

    sub: dataframe of crime incidents with an integer time-index column `t`
    returns total_counts_y, total_counts_x
        total_counts_y: number of incidents in every time step from sub.t.min()
                        to sub.t.max() (inclusive)
        total_counts_x: the matching timestamps from `date_range`, same length as
                        total_counts_y

    When `sub` is empty, zeros are returned for every step of the date span
    currently selected in `state.date_indices`.
    """
    if len(sub) > 0:
        tbins = new_tbins(sub.t.min(),sub.t.max())

        total_counts_y, edges = np.histogram(sub.t,bins=tbins)
        # Every bin is labelled by its left edge; dropping only the final (right-most)
        # edge keeps one timestamp per count.
        # NOTE(review): assumes the values of `t` are valid positions in date_range.
        step_index = np.asarray(edges[:-1]).astype(int)
        total_counts_x = date_range[step_index]
    else:
        i, j = state.date_indices
        total_counts_x = date_range[i:j + 1]
        total_counts_y = np.zeros(len(total_counts_x))
    return total_counts_y, total_counts_x


def new_bins(data_frame, state):
    """
    Build the histogram bin edges for the current selection.

    data_frame: incidents with a time-index column `t`
    state: provides lon_min/lon_max/dlon, lat_min/lat_max/dlat and crime_types
    returns (tbins, cbins, ybins, xbins): edges for time, crime-type code,
    latitude and longitude, in the axis order expected by bin_data_frame.
    """
    # spatial edges: the selected window snapped with ffloor/fceil to the cell size
    xbins = np.arange(
        start=ffloor(state.lon_min, state.dlon),
        stop=fceil(state.lon_max,state.dlon), 
        step=state.dlon,
    )
    ybins = np.arange(
        start=ffloor(state.lat_min, state.dlat),
        stop=fceil(state.lat_max,state.dlat),
        step=state.dlat,
    )

    # one unit-width bin per time index present in the data
    tbins = new_tbins(data_frame.t.min(),data_frame.t.max())

    # one unit-width bin per selected crime type
    cbins = np.arange(0, len(state.crime_types) + 1, 1)

    return tbins, cbins, ybins, xbins

def bin_data_frame(data_frame, state):
    """
    Histogram the incidents into a 4-D grid ordered (time, crime type, latitude,
    longitude), using the bin edges that new_bins derives from the state.

    returns (binned_data, bins): the counts and the edges reported by np.histogramdd
    """
    requested_edges = new_bins(data_frame, state)
    coordinates = data_frame[['t', 'c', 'Latitude', 'Longitude']].values

    binned_data, bins = np.histogramdd(sample=coordinates, bins=requested_edges)
    return binned_data, bins

def get_mean_map(data_frame, state):
    """
    Mean number of incidents per time step for every spatial cell.

    The spatial extent always comes from the module-level LON_MIN/LON_MAX and
    LAT_MIN/LAT_MAX (not from the state's selection window), so the heat map keeps
    a constant footprint; only the cell size (state.dlon, state.dlat) varies.

    returns (mean_map, xbins, ybins): the (lat, lon) grid of means and the
    longitude / latitude bin edges it was computed with.
    """
    xbins = np.arange(
        start=ffloor(LON_MIN, state.dlon),
        stop=fceil(LON_MAX,state.dlon), 
        step=state.dlon,
    )
    ybins = np.arange(
        start=ffloor(LAT_MIN, state.dlat),
        stop=fceil(LAT_MAX,state.dlat),
        step=state.dlat,
    )

    # one unit-width bin per time index present in the data
    tbins = new_tbins(data_frame.t.min(),data_frame.t.max())

    # NOTE(review): the bins cover codes 0..len(state.crime_types)-1; that matches the
    # codes in data_frame.c only while every crime category is selected - confirm intent.
    cbins = np.arange(0, len(state.crime_types) + 1, 1)

    binned_data, _ = np.histogramdd(
        sample=data_frame[['t', 'c', 'Latitude', 'Longitude']].values,
        bins=(tbins, cbins, ybins, xbins),
    )
    
    # sum over crime types (axis 1), then average over time (axis 0)
    mean_map = binned_data.sum(1).mean(0)
    return mean_map, xbins, ybins

In [19]:
# Shared font settings that can be passed to the plot builders below.
font_dict = dict(
    family="Times New Roman",  # was misspelled "Time New Roman", which is not a font name
    size=14,
#     color="RebeccaPurple",
)

# quick plotly plot builders
def new_scatter(title, ylabel, xlabel, font_dict=None):
    """
    Create an empty scatter FigureWidget with a centred title and labelled axes.

    title: figure title
    ylabel, xlabel: axis titles
    font_dict: optional plotly font settings applied to the whole layout
    """
    layout = dict(
        title_text=title,
        title_x=0.5,  # centre the title horizontally
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        legend_title="Legend Title",
        font=font_dict,
    )
    empty_trace = go.Scatter()
    return go.FigureWidget(data=empty_trace, layout=layout)


def new_heatmap(z):
    """
    Create a 600px-high heat-map figure of z whose y axis is anchored to the x axis
    with a scale ratio of 1/ratio_xy.
    """
    y_axis = dict(scaleanchor="x", scaleratio=1/ratio_xy)
    layout = dict(
        height=600,
#         width=ratio_xy*fig_height,
        yaxis=y_axis,
    )
    return go.Figure(data=go.Heatmap(z=z), layout=layout)

# possible plotly fonts
# fonts = [
#     "Arial", "Balto", "Courier New",
#     "Droid Sans", "Droid Serif", "Droid Sans Mono",
#     "Gravitas One", "Old Standard TT", "Open Sans",
#     "Overpass", "PT Sans Narrow",
#     "Raleway", "Times New Roman",
# ]

In [20]:
# Empty line plot for the per-day totals of the current selection; draw() fills in its x/y data.
fig_total_counts = new_scatter(title="Total Crimes",ylabel="Total Counts",xlabel="Date Time")

In [21]:
# MI plots on a single figure: one (initially empty) trace for the mutual information
# and one for the conditional mutual information; draw() fills in their x/y data.
fig_mi_cmi = go.FigureWidget(
        data=[
            go.Scatter(
                name="$I(C_t;C_{t-k})$",#"Mutual Information",
            ),
            go.Scatter(
                name="$I(C_t;C_{t-k}|DoW_t;DoW_{t-k})$",#"Conditional Mutual Information",
            ),
        ],
        layout=dict(
            title_text="Mutual and Conditional Mutual Information",
            title_x=0.5,        
            xaxis_title="Offset in days (k)",
            yaxis_title="Normalised Score [0,1]",
            legend_title="Curves",
#             font=font_dict,
        ),
    )

# handles to the two traces, in the order they were added above
line_mi, line_cmi = fig_mi_cmi.data

In [22]:
# Bin every incident, then sum over crime types (axis 1) and average over time (axis 0)
# to get the mean count per spatial cell shown by the initial heat map.
binned_data, bins = bin_data_frame(df,state)
counts_mean = binned_data.sum(1).mean(0)

In [23]:
# Marker trace for the incidents of the currently selected time step (filled in by draw()).
scattergl = go.Scattergl(
#     x=sub.Longitude,
#     y=sub.Latitude,
#     opacity=.6,
    hoverinfo='skip',
    mode='markers',
    name='scattergl',
#     marker_symbol="x-thin", # "cross",
    marker_color='red',
#     hovertext="",
)

# Background heat map of the mean counts per cell.
heatmapgl = go.Heatmap(
    z=counts_mean,
    x=xbins,
    y=ybins, 
    hoverinfo='skip',
    colorscale='viridis',
    opacity=1,
    name='heatmapgl',
)


fig_height = 600
# Main interactive figure: heat map first (index 0), scatter on top (index 1).
fig_app = go.FigureWidget(
    data=[heatmapgl, scattergl],
    layout=dict(
        margin={"r":50,"t":30,"l":0,"b":0},
        height=fig_height,
#         width=fig_height*ratio_xy,
        yaxis=dict(scaleanchor="x", scaleratio=1/ratio_xy),
        clickmode='event+select',
    ),
)

# Rebind the names to the traces held by the widget; the indices follow the
# order of the `data` list above, and these are the objects draw() updates.
scattergl = fig_app.data[1]
heatmapgl = fig_app.data[0]


#### Update figures based on state

In [24]:
# Flag read by draw(): when True the mean heat map is recomputed on the next draw and the flag is cleared.
block_size_changed = True

In [25]:
# global updater
def draw(state):
    """
    Refresh every figure from the given state.

    Filters the incidents, recomputes the per-step totals and the (conditional)
    mutual-information curves, and pushes the results into the plot traces. The mean
    heat map is only recomputed when the module-level `block_size_changed` flag is
    set; the flag is cleared afterwards.
    """
    global block_size_changed

    selection = filter_df(df, state)
    counts, count_dates = get_total_counts(selection)

    mi_scores, mi_offsets = mutual_info_over_time(
        a=counts, max_offset=state.mi_max_offset, norm=True, include_self=False)
    cmi_scores, cmi_offsets = conditional_mutual_info_over_time(
        a=counts, max_offset=state.mi_max_offset, norm=True, include_self=False)

    # incidents that belong to the currently selected time step
    current_step = selection[selection.t == state.time_index]

    with fig_app.batch_update():
        if block_size_changed:
            # the heat map is always built from the full frame, not the selection
            mean_map, xbins, ybins = get_mean_map(df, state)
            block_size_changed = False

            heatmapgl.x = xbins
            heatmapgl.y = ybins
            heatmapgl.z = mean_map

        scattergl.x = current_step.Longitude
        scattergl.y = current_step.Latitude

        totals_trace = fig_total_counts.data[0]
        totals_trace.x = count_dates
        totals_trace.y = counts

        line_mi.x = mi_offsets
        line_mi.y = mi_scores

        line_cmi.x = cmi_offsets
        line_cmi.y = cmi_scores

In [26]:
def scattergl_on_select(trace,points,selector):
    """
    Selection callback for the scatter trace: restrict the spatial window in
    `state` to the selected region and redraw.

    trace: the trace the selection was made on (only 'scattergl' is handled)
    points: the selected points (unused)
    selector: the selection shape. Box selections expose `xrange`/`yrange`;
              lasso selections expose the outline as `xs`/`ys`, in which case the
              outline's bounding box is used. Anything else is ignored.
    """
    if trace.name != 'scattergl':
        return

    lon_range = getattr(selector, 'xrange', None)
    lat_range = getattr(selector, 'yrange', None)
    if lon_range is None or lat_range is None:
        # not a box selection: fall back to the bounding box of a lasso outline
        xs = getattr(selector, 'xs', None)
        ys = getattr(selector, 'ys', None)
        if not xs or not ys:
            return  # nothing usable was selected, keep the current window
        lon_range = (min(xs), max(xs))
        lat_range = (min(ys), max(ys))

    # sort so that min <= max regardless of the direction the box was dragged in
    state.lon_min, state.lon_max = sorted(lon_range)
    state.lat_min, state.lat_max = sorted(lat_range)

    draw(state)
            
# Register the handler so selection events on the scatter trace call scattergl_on_select.
scattergl.on_selection(scattergl_on_select)

In [27]:
# widget setup
from ipywidgets import Layout, widgets

# helper functions
def get_widget_index(change):
    """
    Return the new value carried by an 'index' change event.

    Returns None when `change` is not a dict or describes a different trait.
    """
    is_index_event = isinstance(change, dict) and change.get('name') == 'index'
    return change.get('new') if is_index_event else None

def get_widget_value(change):
    """
    Return the new value carried by a 'value' change event.

    Returns None when `change` is not a dict or describes a different trait.
    """
    is_value_event = isinstance(change, dict) and change.get('name') == 'value'
    return change.get('new') if is_value_event else None

#### Widgets and how they should update state

In [28]:
# time index date display label, initialised from the state's current time index
current_date_label = widgets.Label(f'Date: {date_range[state.time_index].strftime("%c")}')

# time index selector: one step per entry of date_range.
# continuous_update=False so a change is only reported once the handle is released.
time_index_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(date_range)-1,
    step=1,
    description='Time Index:',
    continuous_update=False,
    layout=Layout(width='80%'),
)

def on_change_time_index(change):
    """
    Observer for the time-index slider: show the selected date, store the index in
    `state` and redraw. Change events for traits other than 'value' are ignored.
    """
    global current_date_label, state
    time_index = get_widget_value(change)
    # compare against None explicitly: index 0 (the first day) is a valid selection
    if time_index is not None:
        current_date_label.value = f'Date: {date_range[time_index].strftime("%c")}'
        state.time_index = time_index
        draw(state)
        
# No `names=` filter is given, so the handler is called for every trait change;
# it filters for 'value' changes itself.
time_index_slider.observe(on_change_time_index)

# Play button that steps through the same index range as the slider.
play_button = widgets.Play(
    value=0,
    min=time_index_slider.min,
    max=time_index_slider.max,
    step=1,
    interval=800,
    description="Press play",
    disabled=False
)
# link the two 'value' traits so the play button drives the time-index slider
widgets.jslink((play_button, 'value'), (time_index_slider, 'value'))


# Date Slider: selects a (start, end) pair of positions in date_range
date_range_slider = widgets.SelectionRangeSlider(
    options=[d.strftime('%y/%m/%d') for d in date_range],
    index=(0, len(date_range)-1),
    description='Date Range:',
    disabled=False,
    orientation = 'horizontal',
    layout=Layout(width='95%'),
    continuous_update=False,
)

def on_change_date_range_slider(change):
    """
    Observer for the date-range slider: clamp the time-index slider to the selected
    span, store the span in `state` and redraw. Change events for traits other than
    'index' are ignored.
    """
    global state, time_index_slider
    
    index = get_widget_index(change)
    if not index:
        return
    i, j = index

    # Move the bounds in an order that never leaves min > max, not even
    # transiently, because the slider rejects such a state.
    if i > time_index_slider.max:
        time_index_slider.max = j
        time_index_slider.min = i
    else:
        time_index_slider.min = i
        time_index_slider.max = j
    time_index_slider.value = i

    # `date_indices` is the key the rest of the notebook reads; `time_index` is kept
    # in step with the slider, which has just been moved to the start of the span.
    state.update(
        date_indices=(i,j),
        time_index=i,
        date_min=date_range[i],
        date_max=date_range[j],
    )
    draw(state)

# No `names=` filter is given, so the handler sees every trait change and filters for 'index' itself.
date_range_slider.observe(on_change_date_range_slider)


# crime type selector
def on_change_crime_types(button):
    """
    Click handler that toggles one crime type: a highlighted ("success") button is
    switched off and its type removed from state.crime_types; an unhighlighted one
    is switched on and its type added. The figures are redrawn afterwards.
    """
    global state
    currently_selected = button.button_style != ""
    if currently_selected:
        button.button_style = ""
        state.crime_types.remove(button.description)
    else:
        button.button_style = "success"
        state.crime_types.append(button.description)
    draw(state)
    
def new_buttons(names,all_selected=False):
    """
    Create one toggle-style button per name and reset state.crime_types to match.

    names: button labels (crime types)
    all_selected: if True every button starts highlighted and state.crime_types is
                  set to a copy of names; otherwise none are highlighted and
                  state.crime_types is emptied
    returns the list of buttons, each wired to on_change_crime_types
    """
    global state

    # 'success', 'info', 'warning', 'danger' or ''
    initial_style = "success" if all_selected else ""
    state.crime_types = [*names] if all_selected else []

    created = []
    for label in names:
        toggle = widgets.Button(
            description=label,
            tooltip=label,
            disabled=False,
            button_style=initial_style,
        )
        toggle.on_click(on_change_crime_types)
        created.append(toggle)
    return created
    
    
# Stack the two time controls: the date-range slider above the time-index slider.
time_selectors = widgets.VBox([
    widgets.HBox([date_range_slider]),
    widgets.HBox([time_index_slider])    
])    

# One toggle button per crime category, all selected; laid out in two rows of four.
# Note that new_buttons also resets state.crime_types.
buttons = new_buttons(names=crime_categories, all_selected=True)
crime_selectors = widgets.VBox([
    widgets.HBox(buttons[:4]),
    widgets.HBox(buttons[4:]),
])

# Spatial block size selection widget.
# Both lists come from the same five (x_scale, y_scale) pairs, so position k in one
# matches position k in the other: cell sizes in degrees, and the sizes shown in the
# dropdown (scaled by 85 and 110).
dlon_dlat_opts = list(map(tuple,(np.array([[20,16],[10,8], [5,4], [3,2], [1,1]]) * np.array([0.001, 0.001]))))
dx_dly_opts = list(map(tuple,(np.array([[20,16],[10,8], [5,4], [3,2], [1,1]]) * np.array([85, 110]))))

# options are (label, value) pairs; the value is the position in the lists above
dropdown_block_size = widgets.Dropdown(
    options=[(v, i) for i,v in enumerate(dx_dly_opts)],
    value=state.block_size_index,
    description='dx, dy:',
    disabled=False,
)

def on_change_dropdown_block_size(change):
    """
    Observer for the block-size dropdown: store the chosen option and its cell size
    (in degrees) in `state`, flag the heat map for recomputation and redraw.
    Change events for traits other than 'value' are ignored.
    """
    global state, block_size_changed
    chosen = get_widget_value(change)
    if chosen is None:
        return
    state.block_size_index = chosen
    state.dlon, state.dlat = dlon_dlat_opts[chosen]
    block_size_changed = True
    draw(state)
        
# No `names=` filter is given, so the handler sees every trait change and filters for 'value' itself.
dropdown_block_size.observe(on_change_dropdown_block_size)


# widgets combinations and setup: a row of controls above the two time sliders,
# with the current-date label at the bottom
row0 = widgets.HBox([play_button, crime_selectors, dropdown_block_size])
row1 = time_selectors

controller = widgets.VBox([row0, row1, current_date_label])

In [32]:
# Widen the lag window of the MI/CMI curves, render everything once,
# then stack the figures and the controls into the displayed app.
state.mi_max_offset = 80
draw(state)
widgets.VBox([fig_mi_cmi, fig_total_counts, controller, fig_app])

VBox(children=(FigureWidget({
    'data': [{'name': '$I(C_t;C_{t-k})$',
              'type': 'scatter',
     …

### Statistics on selection

In [None]:
# distribution of selection counts


# distribution of crime types


# distribution of hour of day by type


# distribution of selection counts

In [None]:
# Static heat map of the mean count per cell for the current state:
# sum over crime types (axis 1), then average over time (axis 0).
binned_data, bins = bin_data_frame(df,state)
mean_map = binned_data.sum(1).mean(0)
new_heatmap(z=mean_map)

### Geo-plots

In [None]:
from utils.plots import hist2d_to_geo

# get block infos for whole df
counts, edges = np.histogramdd(
    sample=df[['t', 'c', 'Latitude', 'Longitude']].values,
    bins=(tbins, cbins, ybins, xbins),
)

# collapse the time and crime-type axes, leaving one value per spatial cell
counts_sum = counts.sum(0).sum(0)
counts_mean = counts.mean(0).mean(0)

blocks_geo, blocks_info = hist2d_to_geo(counts_mean,xbins,ybins,filter_zero=True)

# Fixed seed so the plotted sample is the same on every run; the sample size is
# capped so that frames with fewer than 5,000 rows do not raise.
sub = df.sample(min(5_000, len(df)), random_state=0)

In [None]:
# Choropleth layer: one polygon per grid block, coloured by the block's value.
heatmap_geo = go.Choroplethmapbox(
    geojson=blocks_geo,
    locations=blocks_info.id,
    z=blocks_info.value,
    colorscale="Viridis",
    zmin=blocks_info.value.min(),
    zmax=blocks_info.value.max(),
    marker_opacity=0.5,
    marker_line_width=0,
    name='heatmap',
    hoverinfo='text',
    hovertext=blocks_info[['y','x','value']],
)

# scatter_geo = go.Scattermapbox(
#     lon=sub.Longitude,
#     lat=sub.Latitude,
#     opacity=.4,
#     hoverinfo='skip',
#     name='scatter',
# )

# One scatter layer per crime category (drawn from the sampled incidents),
# so each category gets its own legend entry.
scatter_geo = []
for cat in crime_categories:
    tmp = sub[sub["Primary Type"] == cat]
    scatter = go.Scattermapbox(
        lon=tmp.Longitude,
        lat=tmp.Latitude,
        opacity=.4,
        hoverinfo='skip',
        name=cat,
    )
    scatter_geo.append(scatter)
    
    

# The choropleth is added first, so it is fig_geo.data[0].
fig_geo = go.FigureWidget(data=[heatmap_geo, *scatter_geo])

# Map styling: satellite style, centred on the middle of the grid, legend in the top-left corner.
fig_geo.update_layout(
    clickmode='event+select',
    mapbox_style=mapbox_styles["sat"],
    mapbox_accesstoken=mapbox_access_token,
    mapbox_zoom=10,
    mapbox_center = {"lat": np.mean(ybins), "lon": np.mean(xbins)},
    margin={"r":0,"t":30,"l":0,"b":10},
    height=500,
    showlegend=True,
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
    ),
)

def heatmap_geo_on_select(trace,points,selector):
    """
    Selection callback for the choropleth layer: set the spatial window in `state`
    to the bounding box of the selected blocks.

    Only `state` is modified; draw() is not called from here.
    """
    global state
    
    # rows of blocks_info that belong to the selected blocks
    inds = points.point_inds
    selection = blocks_info.iloc[inds]
    # NOTE(review): unpacking four values assumes blocks_info has exactly four columns
    # whose last two hold the x and y bin positions - confirm against hist2d_to_geo.
    _,_,f_x_min,f_y_min = selection.min()
    _,_,f_x_max,f_y_max = selection.max()

    # translate the bin positions into longitude / latitude bin edges
    state.lon_min, state.lon_max = xbins[[f_x_min,f_x_max]]
    state.lat_min, state.lat_max = ybins[[f_y_min,f_y_max]]
    
    
# Rebind to the trace held by the widget (the choropleth was added first) before
# registering the selection handler on it.
heatmap_geo = fig_geo.data[0]    
heatmap_geo.on_selection(heatmap_geo_on_select)

fig_geo

---
### Animated Crime Over Time

In [None]:
def new_ffloor(delta):
    """
    Bind `delta` and return a one-argument function that computes ffloor(value, delta).

    Useful for handing ffloor to APIs that call a function with a single argument.
    """
    return lambda value: ffloor(value, delta)



# Work on a sample of the incidents. The seed is fixed so the animation is the same
# on every run, and the sample size is capped so small frames do not raise.
sub = df[['t', 'Longitude', 'Latitude','c', 'Primary Type', 'Date']]
sub = sub.sample(min(10_000, len(sub)), random_state=0).copy()
sub['date_str'] = sub.Date.dt.floor('1D').dt.strftime('%c')
sub['count'] = np.ones_like(sub.t)
# snap the coordinates with ffloor at a 0.01 degree resolution so nearby incidents share a point
ff = new_ffloor(0.01)
sub[['Longitude', 'Latitude']] = sub[['Longitude', 'Latitude']].apply(ff)

# one row per (location, time step, crime type) with the number of incidents it groups
gb = sub.groupby(by=['Longitude', 'Latitude','t','c'])
points = gb.first()
points['count'] = gb.count()['count']
# points['count'] = points['count'].apply(lambda x: np.log2(1+x))
# express the counts relative to the smallest group
points['count'] = points['count']/points['count'].min()
points = points.reset_index()
points = points.sort_values(['t','c'])

# keep only the first 20 time steps for the animation
points = points[points.t < 4*5]
points

In [None]:
import plotly.express as px

# marker size proportional to the relative count of each point
points['size'] = 10*points['count']

# Animated scatter: one frame per time step `t`, coloured by crime type.
fig = px.scatter(
    data_frame=points, 
    x="Longitude", 
    y="Latitude", 
    title="Animated Crime Over Time",
#     facet_col="Primary Type",
#     facet_col_wrap=3,
    width=800, 
    height=800,
    opacity=.5,
    size="size",
    size_max=points['size'].max(),
    color='Primary Type',
#     marginal_x='violin',
    animation_frame="t",
#     animation_group="t",
#     color="count",
)

fig.update_layout(
    margin=dict(l=20, r=50, t=50, b=20),
    paper_bgcolor="LightSteelBlue",
#     transition ={'duration': 1},
#     frame={'duration': 100},
)

# Adjust the arguments of the first animation button: 500 for the frame duration
# and 0 for the transition duration.
fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 500
fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0

fig

---
## 2D Histogram

In [None]:
# display distribution of a sample of the data
# (fixed seed so the figure is reproducible; sample size capped so small frames do not raise)
fig_height=1000
px.density_heatmap(
    data_frame=df.sample(min(10000, len(df)), random_state=0),
    x="Longitude",
    y="Latitude", 
    marginal_x="histogram",
    marginal_y="histogram",
    height=fig_height,
    width=ratio_xy*fig_height,
    nbinsx=nx,
    nbinsy=ny,
)

---