# Plot per-region 'slop' (rate - rate_loc) by lineage

In [None]:
import math
import os
import pickle
import re
import logging
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import pyro.distributions as dist
from pyrocov import mutrans, pangolin, stats
from pyrocov.stats import normal_log10bf
from pyrocov.util import pretty_print, pearson_correlation
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import seaborn as sns
from bokeh.io import output_notebook, show, output_file
from bokeh.plotting import figure
from bokeh.models import GeoJSONDataSource, LinearColorMapper, ColorBar
from bokeh.palettes import brewer
import geopandas as gpd
import json

In [None]:
# Directory holding the fit and dataset artifacts.
results_dir = "results/"

In [None]:
# Load the saved fit onto the CPU and compute the 'slop': the median
# `rate` minus the median `rate_loc`.
fit_path = results_dir + 'mutrans.svi.3000.1.50.coef_scale=0.1.reparam.full.10001.0.05.0.1.10.0.200.6.None..pt'
fit = torch.load(fit_path, map_location='cpu')
median = fit['median']
rate, rate_loc = median['rate'], median['rate_loc']
slop = rate - rate_loc

In [None]:
# Inspect the dimensions of slop: (regions, clades).
slop.size()

In [None]:
# Settings for loading the dataset. They presumably mirror the "3000.1.50"
# in the fit filename; confirm if either side changes.
max_num_clades, min_num_mutations, min_region_size = 3000, 1, 50
ambiguous = False  # NOTE(review): not referenced by any visible cell
columns_filename = "results/columns.{}.pkl".format(max_num_clades)
features_filename = "results/features.{}.{}.pt".format(max_num_clades, min_num_mutations)

In [None]:
# Load the GISAID-derived dataset on the CPU using the settings above.
load_kwargs = dict(
    device="cpu",
    columns_filename=columns_filename,
    features_filename=features_filename,
    min_region_size=min_region_size,
)
input_dataset = mutrans.load_gisaid_data(**load_kwargs)

In [None]:
# Country label for each place: the second ' / '-separated field of the
# location name. Entries are ordered by the index stored in location_id
# (assumed to be the row index of `slop`), not by dict iteration order.
countries = [
    name.split(" / ")[1]
    for name, _ in sorted(input_dataset["location_id"].items(), key=lambda kv: kv[1])
]

In [None]:
# Lineage whose clades are aggregated in the cells below.
lineage_to_plot = "B.1.1.7"

In [None]:
# Column indexes (into slop) of every clade assigned to lineage_to_plot.
clade_idxs = [
    input_dataset["clade_id"][clade]
    for clade, lineage in input_dataset["clade_to_lineage"].items()
    if lineage == lineage_to_plot
]

In [None]:
# One row per region: the slop averaged over the lineage's clades. A mean,
# as in plot_strain_slop, so the value does not grow with the number of
# clades in the lineage.
coarse_data = pd.DataFrame({
    "country": countries,
    "slop": slop[:, clade_idxs].mean(-1).tolist(),
})

In [None]:
# Inspect the mean slop per country. Several regions can share a country;
# a mean (rather than a sum) keeps countries with many sampled regions
# comparable to countries with only one.
coarse_data.groupby("country", as_index=False)["slop"].mean()

# Plotting on a world map

In [None]:
shapefile = "~/disk1/geo_data/ne_110m_admin_0_countries.shp"  # country shapes

In [None]:
# Read the country shapes; keep the ADMIN name, the ADM0_A3 code and the
# geometry, renamed so the name column lines up with coarse_data.
gdf = gpd.read_file(shapefile)
gdf = gdf[["ADMIN", "ADM0_A3", "geometry"]].rename(
    columns={"ADMIN": "country", "ADM0_A3": "country_code"}
)
gdf.head()

In [None]:
# Collapse to one row per country first. coarse_data has one row per
# region, and duplicate country keys would give that country several
# overlapping polygons with different values.
country_slop = coarse_data.groupby("country", as_index=False)["slop"].mean()
# NOTE(review): names are matched verbatim against the shapefile's ADMIN
# names; a spelling difference leaves that country at slop 0.
merged = gdf.merge(country_slop, on="country", how="outer")
merged["slop"] = merged["slop"].fillna(0)
# Drop slop rows whose country has no matching shape.
merged = merged[merged["country_code"].notna()]

In [None]:
# Serialize the merged frame to GeoJSON, parse it into a dict, then dump it
# back to a string.
geojson_text = merged.to_json()
merged_json = json.loads(geojson_text)
json_data = json.dumps(merged_json)

In [None]:
# Reload the same fit file onto the CPU.
fit = torch.load(
    results_dir + 'mutrans.svi.3000.1.50.coef_scale=0.1.reparam.full.10001.0.05.0.1.10.0.200.6.None..pt',
    map_location='cpu',
)

In [None]:
# Dataset settings (repeated from above) and a fresh load of the dataset.
max_num_clades, min_num_mutations, min_region_size = 3000, 1, 50
ambiguous = False
columns_filename = "results/columns.{}.pkl".format(max_num_clades)
features_filename = "results/features.{}.{}.pt".format(max_num_clades, min_num_mutations)

load_kwargs = dict(
    device="cpu",
    columns_filename=columns_filename,
    features_filename=features_filename,
    min_region_size=min_region_size,
)
input_dataset = mutrans.load_gisaid_data(**load_kwargs)

In [None]:
def plot_strain_slop(
    fit,
    input_dataset,
    results_dir="results/",
    lineage_to_plot="B.1.1.7",
    shapefile="~/disk1/geo_data/ne_110m_admin_0_countries.shp",
):
    """Show a world choropleth of the per-country 'slop' of one lineage.

    'Slop' is ``fit['median']['rate'] - fit['median']['rate_loc']``. It is
    averaged over every clade that ``input_dataset['clade_to_lineage']``
    maps to ``lineage_to_plot``. It is then averaged over all regions that
    share a country and drawn on the country shapes read from ``shapefile``.

    Args:
        fit: Loaded fit dict. ``fit['median']['rate']`` and
            ``fit['median']['rate_loc']`` are indexed as ``[place, clade]``.
        input_dataset: Dataset dict with ``'location_id'`` (location name ->
            place index; assumed to be the row index of ``rate``),
            ``'clade_id'`` (clade -> column index) and ``'clade_to_lineage'``
            (clade -> lineage name).
        results_dir: Unused; kept so existing calls keep working.
        lineage_to_plot: Lineage whose clades are averaged.
        shapefile: Path to the country shapefile; ``~`` is expanded.

    Returns:
        The Bokeh figure, which is also displayed inline.

    Raises:
        ValueError: If no clade belongs to ``lineage_to_plot``.
    """
    country_slop = _country_slop(fit, input_dataset, lineage_to_plot)

    # Country shapes: keep the ADMIN name, the ADM0_A3 code and the geometry.
    gdf = gpd.read_file(os.path.expanduser(shapefile))[["ADMIN", "ADM0_A3", "geometry"]]
    gdf.columns = ["country", "country_code", "geometry"]

    # NOTE(review): names are matched verbatim against the shapefile's ADMIN
    # names; a spelling difference leaves that country at slop 0.
    merged = gdf.merge(country_slop, on="country", how="outer")
    merged["slop"] = merged["slop"].fillna(0)
    # Drop slop rows whose country has no matching shape.
    merged = merged[merged["country_code"].notna()]

    geosource = GeoJSONDataSource(geojson=merged.to_json())

    # Reverse YlGnBu so the darkest blue marks the highest slop.
    palette = brewer["YlGnBu"][8][::-1]
    color_mapper = LinearColorMapper(
        palette=palette, low=merged["slop"].min(), high=merged["slop"].max()
    )
    color_bar = ColorBar(
        color_mapper=color_mapper,
        label_standoff=8,
        width=500,
        height=20,
        border_line_color=None,
        location=(0, 0),
        orientation="horizontal",
    )

    # `width`/`height` are accepted by Bokeh 2 and 3; `plot_width` and
    # `plot_height` were removed in Bokeh 3.
    p = figure(title="", height=600, width=950, toolbar_location=None)
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.patches(
        "xs",
        "ys",
        source=geosource,
        fill_color={"field": "slop", "transform": color_mapper},
        line_color="black",
        line_width=0.25,
        fill_alpha=1,
    )
    p.add_layout(color_bar, "below")

    # Display inline in the notebook.
    output_notebook()
    show(p)

    return p


def _country_slop(fit, input_dataset, lineage_to_plot):
    """Return a DataFrame (country, slop) with one row per country for a lineage."""
    clade_idxs = [
        input_dataset["clade_id"][clade]
        for clade, lineage in input_dataset["clade_to_lineage"].items()
        if lineage == lineage_to_plot
    ]
    if not clade_idxs:
        # An empty selection would average to NaN and give a blank, broken map.
        raise ValueError(f"no clades found for lineage {lineage_to_plot!r}")

    slop = fit["median"]["rate"] - fit["median"]["rate_loc"]

    # Country label per place row, ordered by the stored place index rather
    # than by dict iteration order.
    places = sorted(input_dataset["location_id"].items(), key=lambda kv: kv[1])
    countries = [name.split(" / ")[1] for name, _ in places]

    coarse_data = pd.DataFrame({
        "country": countries,
        "slop": slop[:, clade_idxs].mean(-1).tolist(),
    })
    # Several regions can share a country; keep one row per country so each
    # country polygon is drawn exactly once.
    return coarse_data.groupby("country", as_index=False)["slop"].mean()


In [None]:
plot_strain_slop(fit=fit, input_dataset=input_dataset, lineage_to_plot="B.1.351")

In [None]:
plot_strain_slop(fit=fit, input_dataset=input_dataset, lineage_to_plot="B.1.1.7")

In [None]:
plot_strain_slop(fit=fit, input_dataset=input_dataset, lineage_to_plot="B.1.617.2")

In [None]:
plot_strain_slop(fit=fit, input_dataset=input_dataset, lineage_to_plot="P.1")

In [None]:
plot_strain_slop(fit=fit, input_dataset=input_dataset, lineage_to_plot="BA.1")

In [None]:
plot_strain_slop(fit=fit, input_dataset=input_dataset, lineage_to_plot="BA.2")