# Hinton Diagram of Inner vs. Inter Transmissions Across Cities Over Time

# In [1]:
import csv
import os

import numpy as np
import matplotlib

# Select the non-interactive PGF backend (LaTeX output) before pyplot is
# imported: older matplotlib releases required matplotlib.use() to run
# before the first pyplot import, and this order works on every version.
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
})
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

# Sample 20x20 matrix whose entry [i][j] equals i + j. The __main__ block
# below rebinds ``arr`` to the CSV data, so this is only a placeholder.
list1 = [[row + col for col in range(20)] for row in range(20)]
arr = np.array(list1)

# source: https://matplotlib.org/devdocs/gallery/specialty_plots/hinton_demo.html
def hinton(matrix, null_positions=None, max_weight=None, ax=None):
    """Draw a Hinton diagram for visualizing a weight matrix.

    Adapted from the matplotlib gallery Hinton demo. Every element is drawn
    as a square centred at ``(x, y)``, where ``(x, y)`` is its index pair
    from ``np.ndenumerate``, so the first index runs along the horizontal
    axis. A square's side is ``sqrt(abs(w) / max_weight)``, making its area
    proportional to the magnitude. Squares are white for positive values,
    black for zero or negative values, and blue for cells that hold the 0.5
    placeholder at a position listed in ``null_positions``.

    The axes get a gray background, equal aspect, no tick locators, and an
    inverted y axis.

    Parameters
    ----------
    matrix : array_like
        2-D array of values to draw.
    null_positions : iterable of (int, int) pairs, optional
        Index pairs of cells that had no data. ``None`` (the default) means
        no cell is treated as missing.
    max_weight : float, optional
        Normalising weight. If falsy, the smallest power of two that is
        >= ``abs(matrix).max()`` is used. 0.5 is added in either case.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into; defaults to ``plt.gca()``.
    """
    ax = ax if ax is not None else plt.gca()
    # A set gives an O(1) membership test per cell; scanning a list cost a
    # pass over every null position for every element of the matrix.
    if null_positions is None:
        null_set = set()
    else:
        null_set = {tuple(pos) for pos in null_positions}

    if not max_weight:
        peak = np.abs(matrix).max()
        # For an all-zero matrix log2(0) is -inf (with a RuntimeWarning);
        # any positive weight works there because every square has size 0.
        max_weight = 2 ** np.ceil(np.log2(peak)) if peak > 0 else 1.0

    # Padding: when max_weight >= every |w|, each square stays strictly
    # smaller than its unit cell ("to get better squares").
    max_weight += 0.5
    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        # Missing cells carry the 0.5 placeholder; checking the value too
        # means only untouched placeholders are highlighted.
        if w == 0.5 and (x, y) in null_set:
            color = 'blue'
        elif w > 0:
            color = 'white'
        else:
            color = 'black'
        size = np.sqrt(abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()


if __name__ == '__main__':
    # Repository root: everything before the first "code" in the working
    # directory, e.g. "/home/u/repo/code/notebooks" -> "/home/u/repo/".
    # NOTE(review): this matches the substring "code" anywhere in the path
    # (a parent directory such as "mycode" would also split) - confirm the
    # expected layout.
    cwd = os.getcwd()
    path = cwd.split('code')[0]
    # os.path.join adds a separator when ``path`` lacks a trailing one
    # (e.g. when "code" is absent from cwd); plain ``+`` did not.
    analysis_dir = os.path.join(path, "data", "analysis")
    stem = ("a2_inner_vs_inter_transmission_over_total_transmission"
            "_in_cities_15_year_buckets_500_years")
    file_path = os.path.join(analysis_dir, stem + "_formatted.csv")

    # The csv module documents newline='' for files handed to csv.reader.
    # Fully blank lines come back as [] and are dropped so they cannot
    # produce zero-length rows that would make the matrix ragged.
    with open(file_path, newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        rows = [row for row in csv_reader if row]

    # Empty (or whitespace-only) cells mark buckets with no transmission.
    # They are replaced by the 0.5 placeholder that hinton() checks for, and
    # their (row, column) indices are recorded.
    NULL_PLACEHOLDER = 0.5
    null_positions = []  # for years with no transmission
    vals = []
    for i, row in enumerate(rows):
        parsed = []
        for j, cell in enumerate(row):
            if cell.strip():
                parsed.append(float(cell))
            else:
                null_positions.append((i, j))
                parsed.append(NULL_PLACEHOLDER)
        vals.append(parsed)

    # np.array needs a non-empty rectangular table; fail with a clear
    # message instead of a ragged-array error deep inside NumPy.
    widths = {len(row) for row in vals}
    if len(widths) != 1:
        raise ValueError(
            f"{file_path}: expected a non-empty rectangular table, "
            f"got row widths {sorted(widths)}")
    arr = np.array(vals)

    hinton(arr, null_positions)

    # hinton() draws arr[x, y] at (x, y): rows run along the x axis and
    # columns along the y axis.
    # NOTE(review): assumes the CSV has len(t11) rows (time buckets) and
    # len(t12) columns (cities) - confirm against the data file.
    t11 = ['-11','','15','','45','','75','','105','','135','','165','','195','','225','','255','','285','','315','','345','','375','','405','','435','','465','','495']
    t12 = ['bsrh', 'bghdad', 'kwfh', 'almdynh', 'msr', 'dmshq', 'nysabwr', 'mkh', 'asbhan', 'wast', 'mrw', 'hms', 'sn\'ea\'', 'alry', 'qrtbh', 'hran', 'bghlan', 'hrah', 'alrqh', 'jrjan', 'almwsl', 'bkhara', 'hlb', 'tws']

    plt.xticks(range(len(t11)), t11, size='small')
    plt.yticks(range(len(t12)), t12, size='small')

    figure_path = os.path.join(analysis_dir, stem + "_figure.pgf")
    plt.savefig(figure_path)

