# Recreating Figure 3 (rectangle-game heatmaps) from Shafto, Goodman & Griffiths (2014)

- $6 \times 6$ grid
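- The hypothesis space is every axis-aligned rectangle whose height and width are each between 2 and 5 cells (from $2 \times 2$ to $5 \times 5$, including non-square ones such as $2 \times 3$), placed at every possible position on the grid. There should be $(5+4+3+2)^2 = 196$ rectangles.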
- The data points are all pairs of distinct pixels, with each pixel in a pair labelled as either a positive or a negative example, as in the line game. There should be $\binom{36}{2} \times 4 = 630 \times 4 = 2520$ unique labelled pairs.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

## Generate hypothesis space

In [None]:
# Blank 6x6 board; its shape defines the grid the rectangles are placed on.
empty_grid = np.zeros(shape=(6, 6))

In [None]:
# One all-ones block per (height, width) size, both ranging over 2..5.
# Insertion order: height outer, width inner -> (2, 2), (2, 3), ..., (5, 5).
rectangles = {
    (height, width): np.ones((height, width))
    for height in range(2, 6)
    for width in range(2, 6)
}

In [None]:
# Display the 16 (height, width) rectangle sizes, from (2, 2) to (5, 5).
rectangles.keys()

In [None]:
# Generate hypothesis space: one 0/1 grid per placement of each rectangle size,
# in the insertion order of `rectangles`, with row offset outer and column
# offset inner. The grid size comes from `empty_grid` for both the loop bounds
# and the grids themselves, instead of mixing its shape with a hard-coded (6, 6).
h = []
n_rows, n_cols = empty_grid.shape

for (rect_h, rect_w), rect in rectangles.items():
    for i in range(n_rows - rect_h + 1):
        for j in range(n_cols - rect_w + 1):
            grid = np.zeros_like(empty_grid)
            grid[i:i + rect_h, j:j + rect_w] = rect
            h.append(grid)

In [None]:
# Expected: 196 hypotheses (see the count in the introduction).
len(h)

In [None]:
# Spot-check the first three hypotheses. Each should be a 6x6 grid of zeros
# with a 2x2 block of ones (the first size in `rectangles`). The block starts
# at the top-left corner and shifts one column to the right each time.
h[:3]

## Generate data $(k=2)$

In [None]:
# Every unordered pair (a, b), a < b, of positions on the flattened 6x6 grid,
# in lexicographic order.
indices = [(a, b) for a in range(36) for b in range(a + 1, 36)]

In [None]:
# Expected: C(36, 2) = 630 unordered pixel pairs.
len(indices)

In [None]:
# All-zero vector with one slot per pixel of the flattened 6x6 grid.
# NOTE(review): not referenced by any cell visible here; possibly unused.
empty_d = np.zeros(shape=(36,))

In [None]:
def generate_data(n_pixels=36):
    """Return every labelled pair of pixels as a list of data vectors.

    Each vector has length ``n_pixels``. Exactly two entries are observed
    (1 = positive example, 0 = negative example) and every other entry is NaN.
    Pairs ``(a, b)`` with ``a < b`` are enumerated in lexicographic order. For
    each pair, the four labellings (1, 0), (0, 1), (0, 0), (1, 1) follow in
    that order, giving ``4 * C(n_pixels, 2)`` vectors (2520 for a 6x6 grid).

    The pixel pairs are generated here rather than read from a notebook
    global, so the result does not depend on which cells have been run.
    """
    from itertools import combinations

    labellings = [(1, 0), (0, 1), (0, 0), (1, 1)]
    data = []
    for a, b in combinations(range(n_pixels), 2):
        for label_a, label_b in labellings:
            d_i = np.full(n_pixels, np.nan)  # NaN = pixel not shown
            d_i[a], d_i[b] = label_a, label_b
            data.append(d_i)
    return data

In [None]:
# Build all labelled pairs, then spot-check the first three: pixel pair (0, 1)
# labelled (1, 0), (0, 1) and (0, 0), with every other pixel NaN (unobserved).
d = generate_data()
d[:3]

In [None]:
# Expected: 4 labellings x 630 pairs = 2520 data points.
len(d)

### Find iteration 0 probabilities

In [None]:
# Mask the unlabelled pixels of every data vector. masked_invalid masks
# NaN/inf; `d` only holds 0, 1 and NaN, so the mask is True exactly where a
# pixel carries no example.
d_mask = np.ma.masked_invalid(np.asarray(d))

In [None]:
# Flatten every 6x6 hypothesis grid into a length-36 vector, in place (same
# list object), so hypotheses line up element-wise with the data vectors.
h[:] = [hypothesis.flatten() for hypothesis in h]

In [None]:
def _consistency_matrix(data, hyps, observed):
    """Return an (n_data, n_hyp) bool array that is True where a data point
    agrees with a hypothesis on every observed (labelled) pixel."""
    # Broadcast to (n_data, n_hyp, n_pixels). Unobserved pixels always count as
    # a match, so only the labelled pixels are actually compared.
    agree = (data[:, None, :] == hyps[None, :, :]) | ~observed[:, None, :]
    return agree.all(axis=2)


# P(d | h) at iteration 0: 1 for every consistent (data point, hypothesis)
# pair, computed in one vectorized pass. A Python double loop over ~494k
# pairs with scalar .iloc writes would be very slow.
# Rows are data points, columns are hypotheses.
df_0 = pd.DataFrame(
    _consistency_matrix(
        np.asarray(d, dtype=float),
        np.stack(h).reshape(len(h), -1),  # accepts flattened or 6x6 grids
        ~np.ma.getmaskarray(d_mask),      # True where the pixel is labelled
    ).astype(float)
)

# Turn values into probabilities; each column sums up to 1
df_0 = df_0.div(df_0.sum(axis=0), axis=1)

In [None]:
# Rows are data points (labelled pairs), columns are hypotheses; each column sums to 1.
df_0

### Make heatmaps

In [None]:
def select_pos_ex_rows(d):
    """Return a 2-D copy of ``d`` that keeps only two-positive-example rows.

    ``d`` is a sequence of equal-length 1-D data vectors (NaN = unobserved).
    Rows whose observed values sum to 2 (both labelled pixels are positive
    examples) are kept unchanged; every other row is overwritten with NaN.
    The input is never modified.
    """
    stacked = np.stack(d)  # np.stack always allocates a new array
    keep = np.nansum(stacked, axis=1) == 2
    stacked[~keep] = np.nan
    return stacked

In [None]:
def make_heatmap_pos_ex(df, d, h_idx):
    """Positive-example heatmap for hypothesis ``h_idx`` as a 6x6 array.

    Each pixel's value is the total ``df[h_idx]`` weight of the
    two-positive-example data points that label that pixel. The grid is
    normalized to sum to 1. Rows of ``df`` must line up with the rows of ``d``.
    """
    pos_rows = np.nan_to_num(select_pos_ex_rows(d).T)  # (pixels, data), NaN -> 0
    pixel_mass = np.matmul(pos_rows, df[h_idx])
    return (pixel_mass / np.sum(pixel_mass)).reshape(6, 6)

In [None]:
def select_neg_ex_rows(d):
    """Return a 2-D copy of ``d`` that keeps only two-negative-example rows.

    Rows whose observed values sum to 0 are kept, with their 0s relabelled as
    1s so they can be counted the same way as positives. For these data both
    labelled pixels are negative examples. Every other row becomes NaN.
    The input is never modified.
    """
    stacked = np.stack(d)  # np.stack always allocates a new array
    stacked[np.nansum(stacked, axis=1) != 0] = np.nan
    stacked[stacked == 0] = 1
    return stacked

In [None]:
def make_heatmap_neg_ex(df, d, h_idx):
    """Negative-example heatmap for hypothesis ``h_idx`` as a 6x6 array.

    Each pixel's value is the total ``df[h_idx]`` weight of the
    two-negative-example data points that label that pixel. The grid is
    normalized to sum to 1. Rows of ``df`` must line up with the rows of ``d``.
    """
    neg_rows = np.nan_to_num(select_neg_ex_rows(d).T)  # (pixels, data), NaN -> 0
    pixel_mass = np.matmul(neg_rows, df[h_idx])
    return (pixel_mass / np.sum(pixel_mass)).reshape(6, 6)

In [None]:
# Iteration-0 positive-example heatmap (6x6) for h[35], the 2x3 rectangle
# covering rows 2-3 and columns 2-4 (0-based).
make_heatmap_pos_ex(df_0, d, 35)

In [None]:
# Iteration-0 negative-example heatmap (6x6) for the same hypothesis h[35].
make_heatmap_neg_ex(df_0, d, 35)

### Iterate over model

In [None]:
def iterate_over_model(n, df_0):
    '''
    Apply ``n`` rounds of alternating normalization to the iteration-0 P(d|h)
    matrix ``df_0`` (rows = data points, columns = hypotheses).

    Each round first normalizes rows, giving P(h|d), and then normalizes
    columns, giving P(d|h). The resulting P(d|h) is returned with NaNs replaced
    by 0. NaNs arise, for example, from data rows that are zero for every
    hypothesis (0 / 0). ``df_0`` itself is not modified.
    '''
    df = df_0
    for _ in range(n):
        p_h_given_d = df.div(df.sum(axis=1), axis=0)
        df = p_h_given_d.div(p_h_given_d.sum(axis=0), axis=1)
    return df.fillna(0)

In [None]:
# P(d|h) after 100 rounds of alternating row/column normalization.
df_100 = iterate_over_model(n=100, df_0=df_0)

In [None]:
# Same layout as df_0 (rows = data points, columns = hypotheses), NaNs filled with 0.
df_100

### A few positive example heatmaps

In [None]:
# h[95]: the 3x3 rectangle covering rows 1-3 and columns 1-3 (0-based).
sns.heatmap(make_heatmap_pos_ex(df_100, d, 95))

In [None]:
# h[192]: the 5x5 rectangle covering rows 0-4 and columns 0-4 (0-based).
sns.heatmap(make_heatmap_pos_ex(df_100, d, 192))

## Part a of the figure

The four $3 \times 3$ rectangles that do not touch the grid border have indices 95, 96, 99 and 100 in `h`.

In [None]:
# Average the positive-example heatmaps of the four 3x3 rectangles that do not
# touch the border, each cropped to its own rectangle.
# A dedicated name is used here: rebinding `indices` would overwrite the list
# of pixel pairs built earlier in the notebook.
center_h_indices = [95, 96, 99, 100]
center_matrices = []

for h_idx in center_h_indices:
    heatmap = make_heatmap_pos_ex(df_100, d, h_idx)
    # Crop to the hypothesis' bounding box instead of keeping the nonzero heatmap
    # cells. This always yields 3x3, even if a cell inside the rectangle has
    # zero mass.
    rows, cols = np.nonzero(h[h_idx].reshape(heatmap.shape))
    center_matrices.append(
        heatmap[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
    )

center_matrices = np.stack(center_matrices)
center_matrices_heatmap = np.mean(center_matrices, axis=0)

In [None]:
# 3x3 heatmap for h[95], the first of the four centre rectangles.
center_matrices[0]

In [None]:
# Element-wise mean of the four centre-rectangle heatmaps (3x3).
center_matrices_heatmap

In [None]:
# Layout of the 3x3 centre heatmap: 1 marks one of the four corner cells,
# 0 marks an edge or the centre cell.
corner_idx = np.zeros((3, 3), dtype=int)
corner_idx[::2, ::2] = 1

# Split the averaged heatmap's cells (row-major order) into corner and
# non-corner values.
corners, non_corners = (
    center_matrices_heatmap[corner_idx == 1],
    center_matrices_heatmap[corner_idx == 0],
)

In [None]:
# Probability per unit area: the mean value of a corner cell vs. a non-corner
# cell, rescaled so the two numbers sum to 1.
corner_breakdown = np.array([np.mean(corners), np.mean(non_corners)])
corner_breakdown = corner_breakdown / corner_breakdown.sum()
corner_breakdown = dict(corners=corner_breakdown[0], non_corners=corner_breakdown[1])

In [None]:
# Normalized per-cell probability for corner vs. non-corner cells (the two values sum to 1).
corner_breakdown

In [None]:
# 3x3 map with every corner cell set to the corner value and every other cell
# set to the non-corner value.
corner_noncorner_heatmap = np.where(
    corner_idx == 1,
    corner_breakdown['corners'],
    corner_breakdown['non_corners'],
)

In [None]:
# Part a: 3x3 map of the corner vs. non-corner per-cell probabilities, drawn on
# a fixed 0-1 colour scale with tick labels hidden. Assigning to `_` keeps the
# notebook from echoing the returned Axes object.
plt.figure(figsize=(4, 3))
_ = sns.heatmap(corner_noncorner_heatmap, vmin=0, vmax=1, cmap='BuPu', xticklabels=False, yticklabels=False)

In [None]:
# Bar chart of the same two normalized values. They sum to 1, so the y-axis is
# fixed to [0, 1].
plt.bar(corner_breakdown.keys(), corner_breakdown.values())
plt.ylim((0, 1))
plt.ylabel('relative probability per unit area')
plt.show()

## Negative examples heatmap

In [None]:
# NOTE(review): disabled negative-example counterpart of the part-a cell above.
# Presumably disabled because a negative-example heatmap is nonzero *outside*
# the hypothesis rectangle: consistent negatives must lie where h is 0. Keeping
# the nonzero cells would then not give 9 values, and reshape(3, 3) would fail.
# It would also rebind the notebook-global `indices` (the pixel-pair list built
# earlier). Confirm both points before re-enabling.
# indices = [95, 96, 99, 100]
# center_matrices = []

# for i in indices: 
#     heatmap = make_heatmap_neg_ex(df_100, d, i)
#     center_matrices.append(heatmap[np.nonzero(heatmap)].reshape(3, 3))

# center_matrices = np.stack(center_matrices)
# center_matrices_heatmap = np.mean(center_matrices, axis=0)

In [None]:
# Iteration-100 negative-example heatmap (6x6) for h[95], the 3x3 rectangle at rows 1-3, columns 1-3.
make_heatmap_neg_ex(df_100, d, 95)

In [None]:
# Plot of the same negative-example heatmap for h[95].
sns.heatmap(make_heatmap_neg_ex(df_100, d, 95))

In [None]:
# h[13]: the 2x2 rectangle covering rows 2-3 and columns 3-4 (0-based).
sns.heatmap(make_heatmap_neg_ex(df_100, d, 13))

## Part b of the figure