In [None]:
import json
from urllib.request import urlopen
from pathlib import Path

import pandas as pd
import geopandas as gpd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

# Shared seaborn look for every static figure in the notebook.
sns.set_style('whitegrid')
sns.set_palette('colorblind')

# Country outlines used as the base map for the choropleths; each feature's
# `properties.ADMIN` holds the country name that plot locations are matched on.
COUNTRIES_GEOJSON_URL = 'https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson'
with urlopen(COUNTRIES_GEOJSON_URL) as response:
    countries = json.load(response)

# Every distinct country name present in the GeoJSON.
geo_loc_countries = {feature['properties']['ADMIN'] for feature in countries['features']}

# National Bacterial AMR burden relative to sequencing

## Collated Bacterial AMR burden

In [None]:
# Stack every CSV found in the IHME directory and drop rows that are exact repeats.
ihme_tables = [pd.read_csv(taxa_fp) for taxa_fp in Path('../data/ihme_microbe').glob("*.csv")]
burden_df = pd.concat(ihme_tables).drop_duplicates()

# Sum `Value` over all rows of each `Location` to get one total per country.
# NOTE(review): the column label assumes `Value` is DALYs per 100k - confirm against the IHME export.
bacterial_burden_per_country = (
    burden_df.groupby('Location')['Value']
    .sum()
    .rename_axis('Country')
    .reset_index(name='Bacterial AMR DALYs per 100k')
)

# Natural log of the total, used as the colour scale of the map.
bacterial_burden_per_country['Bacterial AMR DALYs per 100k (log)'] = np.log(
    bacterial_burden_per_country['Bacterial AMR DALYs per 100k']
)

# IHME location name -> `ADMIN` name in the countries GeoJSON.
# The value "drop" is a sentinel for locations without a map polygon.
ihme_to_geoloc_countries = {
    'Bahamas': 'The Bahamas',
    'Bolivia (Plurinational State of)': 'Bolivia',
    'Brunei Darussalam': 'Brunei',
    'Cabo Verde': 'Cape Verde',
    'Congo': 'Republic of Congo',
    'Czechia': 'Czech Republic',
    "Côte d'Ivoire": "Ivory Coast",
    "Democratic People's Republic of Korea": 'North Korea',
    "Eswatini": "Swaziland",
    'Guinea-Bissau': 'Guinea Bissau',
    'Iran (Islamic Republic of)': "Iran",
    "Lao People's Democratic Republic": "Laos",
    'Micronesia (Federated States of)': "Federated States of Micronesia",
    'North Macedonia': "Macedonia",
    'Republic of Korea': "South Korea",
    'Republic of Moldova': "Moldova",
    'Russian Federation': "Russia",
    'Serbia': 'Republic of Serbia',
    'Syrian Arab Republic': "Syria",
    'Taiwan (Province of China)': "Taiwan",
    'Timor-Leste': 'East Timor',
    'Tokelau': "drop",
    'Venezuela (Bolivarian Republic of)': "Venezuela",
    'Viet Nam': "Vietnam",
}

# Translate where a mapping exists; otherwise the IHME name is used unchanged.
bacterial_burden_per_country['Map Locations'] = bacterial_burden_per_country['Country'].map(
    lambda country: ihme_to_geoloc_countries.get(country, country)
)

# Tokelau has no map location, so remove the rows carrying the "drop" sentinel.
has_map_location = bacterial_burden_per_country['Map Locations'].ne('drop')
bacterial_burden_per_country = bacterial_burden_per_country[has_map_location]

In [None]:
# World map coloured by log total burden; rows are tied to GeoJSON features
# through the `ADMIN` country name stored in 'Map Locations'.
choropleth_options = dict(
    geojson=countries,
    featureidkey='properties.ADMIN',
    locations='Map Locations',
    color='Bacterial AMR DALYs per 100k (log)',
    color_continuous_scale="Viridis",
    projection='eckert4',
)
fig = px.choropleth(bacterial_burden_per_country, **choropleth_options)
fig.show()

## SRA Metadata

In [None]:
# Load the xz-compressed SRA metadata table. low_memory=False makes pandas read
# the file in one pass instead of inferring column dtypes chunk by chunk.
sra_metadata = pd.read_csv(
    '../data/sra_metadata/all_non_human_sra_metadata.csv.xz',
    low_memory=False,
)

# IHME pathogens reported at genus level ("<Genus> spp."); matched on the genus word.
_IHME_GENUS_LEVEL_TAXA = (
    'Aeromonas spp.', 'Campylobacter spp.', 'Chlamydia spp.', 'Citrobacter spp.', 'Enterobacter spp.',
    'Legionella spp.', 'Morganella spp.', 'Mycoplasma spp.', 'Proteus spp.', 'Providencia spp.',
    'Serratia spp.', 'Shigella spp.',
)

# IHME pathogens reported at species level; matched on the full binomial.
_IHME_SPECIES_LEVEL_TAXA = (
    'Acinetobacter baumannii', 'Clostridioides difficile', 'Enterococcus faecalis', 'Enterococcus faecium',
    'Escherichia coli', 'Haemophilus influenzae', 'Klebsiella pneumoniae', 'Listeria monocytogenes',
    'Mycobacterium tuberculosis', 'Neisseria gonorrhoeae', 'Neisseria meningitidis', 'Pseudomonas aeruginosa',
    'Staphylococcus aureus', 'Streptococcus pneumoniae', 'Vibrio cholerae',
)


def _starts_with_taxon(name, taxon):
    """Return True when `name` begins with `taxon` and the match ends on a word boundary.

    A plain `str.startswith` is too greedy for taxonomy strings: "Enterobacter"
    is a prefix of "Enterobacteriaceae" and "...serovar Typhi" is a prefix of
    "...serovar Typhimurium". The match is therefore only accepted when the
    prefix is followed by the end of the string or a non-letter character.
    """
    if not name.startswith(taxon):
        return False
    remainder = name[len(taxon):]
    return not remainder or not remainder[0].isalpha()


def link_ihme_pathogens_to_sra_taxa(sra_taxa):
    """
    Function to link IHME pathogens to actual SRA taxa

    Parameters
    ----------
    sra_taxa : str
        Organism name of an SRA record. Any non-string value (e.g. NaN for a
        missing organism) is treated as not being an IHME pathogen.

    Returns
    -------
    str
        The IHME pathogen label the organism belongs to, or "Non-IHME Taxa"
        when it matches none of them.

    Matching order: genus-level ("spp.") pathogens, then species-level
    pathogens, then the IHME groupings that need special handling (other
    Klebsiella, other enterococci, Salmonella serovars, Group A/B Streptococcus).
    All matches are anchored at the start of the name and must end on a word
    boundary.
    """
    # Missing organisms arrive as NaN (a float) and have no .startswith.
    if not isinstance(sra_taxa, str):
        return "Non-IHME Taxa"

    for spp_taxa in _IHME_GENUS_LEVEL_TAXA:
        if _starts_with_taxon(sra_taxa, spp_taxa.split()[0]):
            return spp_taxa

    for spec_taxa in _IHME_SPECIES_LEVEL_TAXA:
        if _starts_with_taxon(sra_taxa, spec_taxa):
            return spec_taxa

    # K. pneumoniae, E. faecium and E. faecalis have already returned above, so
    # anything left in these two genera belongs to the "other" groupings.
    if _starts_with_taxon(sra_taxa, 'Klebsiella'):
        return 'Other Klebsiella species'
    if _starts_with_taxon(sra_taxa, 'Enterococcus'):
        return "Other enterococci"

    # The word boundary keeps serovar Typhimurium (non-typhoidal) out of Typhi.
    if _starts_with_taxon(sra_taxa, 'Salmonella enterica subsp. enterica serovar Typhi'):
        return 'Salmonella enterica serovar Typhi'
    if _starts_with_taxon(sra_taxa, 'Salmonella enterica subsp. enterica serovar Paratyphi'):
        return 'Salmonella enterica serovar Paratyphi'
    if _starts_with_taxon(sra_taxa, 'Salmonella'):
        return 'Non-typhoidal Salmonella'

    # Note: this slightly undercounts GAS and GBS, as other streptococci can
    # also carry the A and B antigens.
    if _starts_with_taxon(sra_taxa, 'Streptococcus pyogenes'):
        return 'Group A Streptococcus'
    if _starts_with_taxon(sra_taxa, 'Streptococcus agalactiae'):
        return "Group B Streptococcus"

    return "Non-IHME Taxa"

# Label every SRA record with its IHME pathogen and keep only IHME-tracked taxa.
# The explicit .copy() gives the filtered frame its own data, so adding a column
# below does not raise SettingWithCopyWarning.
sra_metadata['IHME_taxa'] = sra_metadata['organism'].apply(link_ihme_pathogens_to_sra_taxa)
sra_metadata = sra_metadata[sra_metadata['IHME_taxa'] != 'Non-IHME Taxa'].copy()

# A record lacks country metadata when the calculated country is missing or is
# the literal placeholder 'uncalculated'; every other record counts as provided.
country_missing = (
    sra_metadata['geo_loc_name_country_calc'].isna()
    | (sra_metadata['geo_loc_name_country_calc'] == 'uncalculated')
)
sra_metadata['Sampling Country Metadata'] = np.where(country_missing, "Not Provided", "Provided")

In [None]:
# Per taxon: number of records in each metadata status, and total records.
status_counts = sra_metadata.groupby('IHME_taxa')['Sampling Country Metadata'].value_counts()
records_per_taxon = sra_metadata.groupby('IHME_taxa').size()

# Convert counts to a percentage of each taxon's records, keep only the
# 'Not Provided' share, and order taxa from worst to best annotated.
sra_wgs_country_metadata_tally = (
    (status_counts / records_per_taxon * 100)
    .reset_index(name="% of SRA WGS Records")
    .query("`Sampling Country Metadata` == 'Not Provided'")
    .sort_values('% of SRA WGS Records', ascending=False)
    .rename(columns={
        '% of SRA WGS Records': '% of SRA WGS Records\nLacking Country Metadata',
        'IHME_taxa': "Taxa",
    })
)

In [None]:
# Horizontal bars: one per taxon, length = share of records without a country.
lacking_metadata_column = '% of SRA WGS Records\nLacking Country Metadata'
sns.catplot(
    data=sra_wgs_country_metadata_tally,
    x=lacking_metadata_column,
    y='Taxa',
    kind='bar',
)
# Fix the axis to the full percentage range so bars are comparable at a glance.
plt.xlim(0, 100)

## Burden Relative to Sequencing Effort per Country

In [None]:
# SRA country name -> `ADMIN` name in the countries GeoJSON. Territories are
# folded into the country they are mapped to here (e.g. Reunion -> France).
sra_to_geo_loc = {'Bahamas': "The Bahamas",
                 'Christmas Island': 'Australia',
                 'French Guiana': "France",
                 'Gaza Strip': 'Palestine',
                 'Guadeloupe': "France",
                 'Guinea-Bissau': 'Guinea Bissau',
                 'Hong Kong': 'Hong Kong S.A.R.',
                 'Martinique': "France",
                 'Mayotte': "France",
                 'Reunion': "France",
                 'Serbia': "Republic of Serbia",
                 'Tanzania': 'United Republic of Tanzania',
                 'USA': 'United States of America',
                 'Viet Nam': 'Vietnam',
                 'West Bank': "Palestine"}

# Keep records that carry country metadata, excluding Yugoslavia because it no
# longer exists and so cannot be placed on the map.
has_country = sra_metadata['Sampling Country Metadata'] == 'Provided'
not_yugoslavia = sra_metadata['geo_loc_name_country_calc'] != 'Yugoslavia'

# .loc + .assign builds a new frame, so adding 'Country' never writes into a
# slice of sra_metadata (the original assignment raised SettingWithCopyWarning).
country_sra = sra_metadata.loc[has_country & not_yugoslavia].assign(
    Country=lambda records: records['geo_loc_name_country_calc'].map(
        lambda name: sra_to_geo_loc.get(name, name)
    )
)

# Number of SRA records per (country, IHME pathogen) pair.
genomes_per_country = country_sra.groupby(['Country', 'IHME_taxa']).size().reset_index(name='SRA Samples')

In [None]:
# burden_df uses IHME location names ('Viet Nam', 'Russian Federation', ...)
# while genomes_per_country uses the GeoJSON names ('Vietnam', 'Russia', ...).
# Joining on the raw names silently gives those countries 0 samples, so the
# IHME names are translated first. The 'drop' sentinel (Tokelau) is skipped so
# that location keeps its own name instead of being renamed to "drop".
ihme_country_names = {
    ihme_name: geo_name
    for ihme_name, geo_name in ihme_to_geoloc_countries.items()
    if geo_name != 'drop'
}

dalys_per_100k = (
    burden_df
    .rename(columns={'Location': 'Country', 'Pathogen': 'IHME_taxa', "Value": "DALYs per 100k"})
    [['Country', 'IHME_taxa', 'DALYs per 100k']]
    .assign(Country=lambda burden: burden['Country'].map(
        lambda name: ihme_country_names.get(name, name)
    ))
    # Left join: every burden row is kept, in its original order.
    .merge(genomes_per_country, on=['Country', 'IHME_taxa'], how='left')
)

# No matching SRA records means zero sequenced samples for that pair.
dalys_per_100k['SRA Samples'] = dalys_per_100k['SRA Samples'].fillna(0)

# Burden per sequenced sample. Rows with 0 samples evaluate to inf (or NaN for
# 0 / 0); the later cell that adds a pseudocount of 1 overwrites this column.
dalys_per_100k['DALYs per 100k per Genome'] = dalys_per_100k['DALYs per 100k'] / dalys_per_100k['SRA Samples']

In [None]:
# For each IHME pathogen, count the distinct countries that have at least one
# SRA sample, then plot the counts in ascending order as horizontal bars.
countries_with_genomes = (
    dalys_per_100k[dalys_per_100k['SRA Samples'] > 0]
    .groupby('IHME_taxa')['Country']
    .nunique()
    .sort_values()
)
countries_with_genomes.plot(kind='barh')

In [None]:
# Recompute the burden-per-sample ratio with a pseudocount of 1 in the
# denominator, so pairs with zero SRA samples get a finite value.
samples_plus_one = dalys_per_100k['SRA Samples'] + 1
dalys_per_100k['DALYs per 100k per Genome'] = dalys_per_100k['DALYs per 100k'] / samples_plus_one