In [None]:
import numpy as np

from autumn.tools.inputs.social_mixing.queries import get_prem_mixing_matrices
from autumn.tools.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.tools.inputs.social_mixing.constants import LOCATIONS

import matplotlib.pyplot as plt
import matplotlib.cm as cm

In [None]:
model_iso3 = "MYS"
source_iso_3 = "VNM"

# Age-group strata "0", "5", ..., "75" (presumably lower bounds of 5-year bands).
AGEGROUP_STRATA = [str(5 * i) for i in range(16)]


def spectral_radius(matrix):
    """Return the spectral radius (largest absolute eigenvalue) of a square matrix."""
    # eigvals avoids the eigenvector computation that np.linalg.eig also performs.
    return np.abs(np.linalg.eigvals(matrix)).max()


prem_matrices_raw = get_prem_mixing_matrices(model_iso3)
new_matrices = build_synthetic_matrices(model_iso3, source_iso_3, AGEGROUP_STRATA)

# Rescale the *Prem* matrices (the synthetic ones are left untouched) so that both
# "all_locations" matrices have the same spectral radius. The same factor is applied
# to every location, so relative location contributions within Prem are preserved.
sp_radius_prem = spectral_radius(prem_matrices_raw["all_locations"])
sp_radius_new = spectral_radius(new_matrices["all_locations"])
prem_scale = sp_radius_new / sp_radius_prem

# Build a new dict instead of mutating the object returned by the query (which may be
# cached — NOTE(review): confirm), so re-running this cell never compounds the scaling.
# Keys outside LOCATIONS are carried over unscaled, as before.
prem_matrices = {
    **prem_matrices_raw,
    **{location: prem_matrices_raw[location] * prem_scale for location in LOCATIONS},
}
assert np.isclose(spectral_radius(prem_matrices["all_locations"]), sp_radius_new)

matrices = {
    "prem": prem_matrices,
    "new": new_matrices,
}


In [None]:
def _age_ticks(n_groups, age_group_width, tick_step):
    """Return (tick positions, tick labels) for every tick_step-th age-group cell."""
    groups = range(0, n_groups, tick_step)
    # imshow centres cell k on k, so its lower edge sits at k - 0.5.
    ticks = [group - 0.5 for group in groups]
    labels = [str(group * age_group_width) for group in groups]
    return ticks, labels


def compare_by_location(location, age_group_width=5, tick_step=2):
    """Plot every source's contact matrix for one location as side-by-side heatmaps.

    One panel is drawn per entry of the notebook-level ``matrices`` dict. All panels
    share one colour scale (0 to the largest entry across all sources), so
    intensities are directly comparable between panels.

    Args:
        location: key of the matrix to plot in each source's dict (one of LOCATIONS).
        age_group_width: years per age group; used only to label the ticks.
        tick_step: place a tick on every ``tick_step``-th age group.
    """
    fig, axes = plt.subplots(1, len(matrices), figsize=(12, 8), squeeze=False)
    fig.suptitle(location)

    # Shared colour scale across *all* sources, not only "prem" and "new".
    max_value = max(mats[location].max() for mats in matrices.values())

    for ax, (key, mats) in zip(axes[0], matrices.items()):
        matrix = mats[location]
        n_rows, n_cols = matrix.shape

        # Transpose + origin="lower": matrix rows run along x, matrix columns along y.
        image = ax.imshow(np.transpose(matrix), cmap=cm.hot, vmin=0, vmax=max_value, origin='lower')

        x_ticks, x_labels = _age_ticks(n_rows, age_group_width, tick_step)
        y_ticks, y_labels = _age_ticks(n_cols, age_group_width, tick_step)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_labels)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(y_labels)

        ax.set_title(key)
        # Both axes carry age-group ticks; the contact count is shown by colour only.
        ax.set_xlabel("index age")
        ax.set_ylabel("contact age")

    # vmin/vmax are shared, so a single colour bar describes every panel.
    fig.colorbar(image, ax=axes.ravel().tolist(), label="n_contacts")
    plt.show()

    
    
def compare_location_contributions(i_ages_illustrate=(2, 6, 14), age_group_width=5):
    """For selected age groups, show how contacts split across locations in each source.

    For every age-group index in ``i_ages_illustrate``, one figure is drawn with one
    pie chart per entry of the notebook-level ``matrices`` dict. Each wedge is the sum
    of that age group's row in one location-specific matrix. The "all_locations"
    matrix is left out (presumably the aggregate of the others — confirm).

    Args:
        i_ages_illustrate: row indices of the age groups to illustrate.
        age_group_width: years per age group; used only for the printed heading.
    """
    locations_notall = [loc for loc in LOCATIONS if loc != "all_locations"]

    for i_age in i_ages_illustrate:
        heading = f"Age group age_{i_age * age_group_width}"
        print(heading)
        fig, axes = plt.subplots(1, len(matrices), figsize=(12, 8), squeeze=False)
        fig.suptitle(heading)
        for ax, (key, mat) in zip(axes[0], matrices.items()):
            n_contacts = [mat[loc][i_age, :].sum() for loc in locations_notall]
            ax.pie(n_contacts, labels=locations_notall, autopct='%1.1f%%', shadow=True, startangle=90)
            ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
            # Title each panel with its source; otherwise the pies cannot be told apart.
            ax.set_title(key)
        plt.show()
    

In [None]:
# One heatmap comparison per location; the printed name acts as a section heading.
for location in LOCATIONS:
    print(location)
    compare_by_location(location)

In [None]:
# Pie charts: share of each illustrated age group's contacts by location, per source.
compare_location_contributions()