In [None]:
import numpy as np

from autumn.core.inputs.social_mixing.queries import get_prem_mixing_matrices
from autumn.core.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.core.inputs.social_mixing.constants import LOCATIONS

import matplotlib.pyplot as plt
import matplotlib.cm as cm

In [None]:
model_iso3 = "LKA"  # ISO3 code of the modelled country (its Prem matrices are loaded)
source_iso3 = "HKG"  # ISO3 code passed as the second country argument to build_synthetic_matrices
age_adjust = True  # forwarded to build_synthetic_matrices

# Lower bounds of the 16 five-year age strata: "0", "5", ..., "75".
AGEGROUP_STRATA = [str(5 * i) for i in range(16)]


prem_matrices_raw = get_prem_mixing_matrices(model_iso3)
new_matrices = build_synthetic_matrices(model_iso3, source_iso3, AGEGROUP_STRATA, age_adjust)

# Rescale the *Prem* matrices (new_matrices is left untouched) so that their
# "all_locations" matrix has the same spectral radius (largest eigenvalue
# modulus) as the synthetic one, putting both sets on a comparable scale.
# Only eigenvalues are needed, so eigvals is used instead of eig.
lambdas_prem = np.linalg.eigvals(prem_matrices_raw["all_locations"])
sp_radius_prem = np.max(np.abs(lambdas_prem))

lambdas_new = np.linalg.eigvals(new_matrices["all_locations"])
sp_radius_new = np.max(np.abs(lambdas_new))

assert sp_radius_prem > 0, "Prem 'all_locations' matrix has zero spectral radius"

# Build a new dict rather than overwriting entries of the dict returned by the
# query, so re-running this cell cannot compound the scaling even if that
# helper hands back a cached/shared object.
prem_matrices = dict(prem_matrices_raw)
for location in LOCATIONS:
    prem_matrices[location] = prem_matrices_raw[location] * sp_radius_new / sp_radius_prem

matrices = {
    "prem": prem_matrices,
    "new": new_matrices,
}


In [None]:
def compare_by_location(location, add_values=False):
    """Plot the "prem" and "new" contact matrices for one location side by side.

    Both heatmaps share a single colour scale running from 0 to the larger of the
    two matrices' maxima. Each matrix is transposed before display with
    ``origin='lower'``, so the x-axis is the row index (labelled "age of
    individual") and the y-axis the column index (labelled "age of contacts").

    Args:
        location: key into each matrix dict of the notebook-level ``matrices``
            (e.g. an entry of ``LOCATIONS``).
        add_values: if True, write every matrix entry (rounded to 1 decimal
            place) on top of the heatmap.
    """
    fig = plt.figure()
    fig.set_size_inches(12, 8)

    # Label every second age group: positions 0, 2, ..., 14 -> "0", "10", ..., "70".
    ticks = [2 * i for i in range(8)]
    labels = [str(5 * tick) for tick in ticks]
    max_value = max(matrices["prem"][location].max(), matrices["new"][location].max())

    for panel, (key, mats) in enumerate(matrices.items(), start=1):
        ax = fig.add_subplot(1, 2, panel)
        mat = mats[location]
        image = ax.imshow(np.transpose(mat), cmap=cm.hot, vmin=0, vmax=max_value, origin="lower")

        if add_values:
            n = mat.shape[0]
            for _i in range(n):
                for _j in range(n):
                    # Because of the transpose, entry [_i, _j] sits at x=_i, y=_j;
                    # the -0.5 shift presumably roughly centres the left-anchored text.
                    ax.text(_i - 0.5, _j, round(mat[_i, _j], 1), color="green", fontsize=12)

        ax.set_xlabel("age of individual", fontsize=20)
        ax.set_ylabel("age of contacts", fontsize=20)
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels, fontsize=18)
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels, fontsize=18)

        cb = fig.colorbar(image, ax=ax, shrink=0.55)
        # Only enlarge the colour-bar tick labels. Relabelling them (as with the
        # age tick positions) would misstate the contact values they show.
        cb.ax.tick_params(labelsize=16)

    plt.show()

    
    
def compare_location_contributions(i_ages_illustrate=(2, 6, 14)):
    """Show, as pie charts, how each location contributes to total contacts.

    For each selected age-group index, draws one pie per matrix set in the
    notebook-level ``matrices`` dict. Each pie's wedges are the row sums of that
    set's location matrices, excluding "all_locations".

    Args:
        i_ages_illustrate: age-group row indices to illustrate; index ``k`` is
            the age group starting at ``5 * k`` years. Defaults to (2, 6, 14).
    """
    colors = ["coral", "cornflowerblue", "palegreen", "bisque"]

    locations_notall = [loc for loc in LOCATIONS if loc != "all_locations"]

    for i_age in i_ages_illustrate:
        print(f"Age group starting {i_age * 5} years old")
        fig = plt.figure()
        fig.set_size_inches(12, 12)
        for panel, (key, mat) in enumerate(matrices.items(), start=1):
            ax = fig.add_subplot(1, 2, panel)
            # Row i_age summed over all contact ages, once per location.
            n_contacts = [mat[loc][i_age, :].sum() for loc in locations_notall]

            ax.pie(n_contacts, labels=locations_notall, autopct="%1.1f%%", shadow=True, startangle=90, colors=colors)
            ax.axis("equal")  # Equal aspect ratio ensures that pie is drawn as a circle.
            # Title each pie with its matrix set; otherwise the two panels are indistinguishable.
            ax.set_title(key)

        plt.show()
    

In [None]:
# Side-by-side heatmaps for every location, without the per-cell values.
for loc_name in LOCATIONS:
    print(loc_name)
    compare_by_location(loc_name, add_values=False)

In [None]:
# Same heatmaps again, this time with every matrix entry written on the plot.
for loc_name in LOCATIONS:
    print(loc_name)
    compare_by_location(loc_name, add_values=True)

In [None]:
# Pie charts of each location's share of total contacts for a few selected age groups, "prem" vs "new".
compare_location_contributions()

In [None]:
def plot_total_contacts():
    """Bar chart of total contacts per index age group, Prem vs. synthetic.

    Each bar is the row sum of the "all_locations" matrix for one age group,
    read from the notebook-level ``prem_matrices`` (coral, "prem") and
    ``new_matrices`` (blue, "new").
    """
    bar_width = 0.25

    # Row sums: total contacts of each index age group across all contact ages.
    prem_s = prem_matrices["all_locations"].sum(axis=1)
    new_s = new_matrices["all_locations"].sum(axis=1)
    n_groups = len(prem_s)

    # subplots leaves margins for the axis labels (add_axes([0, 0, 1, 1]) did not).
    fig, ax = plt.subplots(figsize=(9, 5))

    positions = np.arange(n_groups)
    ax.bar(positions, prem_s, color="coral", width=bar_width, edgecolor="grey", label="prem")
    ax.bar(positions + bar_width, new_s, color="cornflowerblue", width=bar_width, edgecolor="grey", label="new")

    ax.set_xlabel("Index age", fontsize=12)
    ax.set_ylabel("Total contacts", fontsize=12)
    # Bars are centre-aligned at x and x + bar_width, so the pair's midpoint is
    # x + bar_width / 2 (x + bar_width would sit on the "new" bar only).
    ax.set_xticks(positions + bar_width / 2)
    ax.set_xticklabels([str(5 * i) for i in range(n_groups)])

    ax.legend()
    plt.show()


plot_total_contacts()

In [None]:
# Check one off-diagonal pair of the "other_locations" matrices for asymmetry:
# age-group index 2 (stratum "10") versus index 15 (stratum "75"). Matching the
# printed wording, entry [i, j] is read as the contacts that an individual in
# group i has with people in group j.
young_idx, old_idx = 2, 15

prem_other = prem_matrices["other_locations"]
prem_im_young = prem_other[young_idx, old_idx]
prem_im_old = prem_other[old_idx, young_idx]

new_other = new_matrices["other_locations"]
new_im_young = new_other[young_idx, old_idx]
new_im_old = new_other[old_idx, young_idx]

comparison_rows = [("Prem", prem_im_young, prem_im_old), ("New", new_im_young, new_im_old)]
for row_number, (heading, young_count, old_count) in enumerate(comparison_rows):
    if row_number > 0:
        print()  # blank line separating the two matrix sets
    print(heading)
    print(f"I am young, N contacts with old people: {young_count}")
    print(f"I am old, N contacts with young people: {old_count}")