## Functions

This notebook includes all the functions used for visualisation and analysis. Here you can modify more specific parameters. If you need help with any of these, use help(name_of_function) in any of the analysis notebooks.



In [None]:
# Default resolution, in dots per inch, for matplotlib figures created after this cell runs.
plt.rcParams['figure.dpi'] = 200

### Create a map showing the sample distribution



In [1]:
def map_samples(pf6plus, label, stock_image=False, individual_sites=False, save_fig=False):

    ''' Plot a world map showcasing the countries where the Pf6+ samples come from.

    Parameters:
      - pf6plus: Dataframe with (at least) the columns 'Population' and 'Country', plus
        'Longitude_adm1' and 'Latitude_adm1' when individual_sites is True.
        NOTE: the columns 'population_colour' and 'Country_color' are added to this dataframe.
      - label: Text used as prefix of the plot title and of the saved file name.
      - stock_image: If True, use cartopy's stock background image instead of plain land/ocean features.
      - individual_sites: If True, plot one point per sample at its first-level administrative
        division, coloured by country. If False, shade every country by its number of samples.
      - save_fig: If True, save the figure as '<label>_<YYYY-MM-DD>.png' in the working directory.

    Returns:
      None. The map is displayed (and optionally saved).
    '''

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.set_extent([-80, 153, -26, 29], crs=ccrs.PlateCarree())

    # dictionary storing the colours for each population
    population_colours = collections.OrderedDict([
        ('SAM', "#4daf4a"),
        ('WAF', "#e31a1c"),
        ('CAF', "#fd8d3c"),
        ('EAF', "#fecc5c"),
        ('SAS', "#984ea3"),
        ('WSEA', "#9ecae1"),
        ('ESEA', "#3182bd"),
        ('OCE', "#f781bf"),
    ])

    # number of samples per country (both arrays are sorted by country name)
    countries, samples = np.unique(pf6plus['Country'], return_counts=True)

    # one evenly spaced value in [0, 1] per country, used to colour the individual sites
    colordict = dict(zip(countries, np.linspace(0, 1, len(countries))))

    # the dataframe may be a slice of a bigger one: silence the chained-assignment warning
    # only while the helper columns are written, then restore the user's setting
    with pd.option_context('mode.chained_assignment', None):
        pf6plus['population_colour'] = pf6plus['Population'].map(population_colours)
        pf6plus['Country_color'] = pf6plus['Country'].map(colordict)

    if individual_sites:
        # admin sites by colour of country (drawn once, it already contains every country)
        ax.scatter(pf6plus['Longitude_adm1'], pf6plus['Latitude_adm1'], c=pf6plus['Country_color'],
                   s=3.5, marker='o', linewidth=0.3)
    else:
        # get country borders and read the shapefile using geopandas
        shpfilename = shapereader.natural_earth('10m', 'cultural', 'admin_0_countries')
        df = geopandas.read_file(shpfilename)

        # colour scale running from 0 to the largest number of samples found in a country
        max_samples = samples.max() if len(samples) else 1
        cmap = plt.get_cmap('YlOrRd')
        norm = plt.Normalize(vmin=0, vmax=max_samples)

        for country, n_samples in zip(countries, samples):
            borders = df.loc[df['ADMIN'] == country, 'geometry']
            if borders.empty:
                # the name used in the dataset does not match any Natural Earth 'ADMIN' name
                print('No borders found for ' + str(country) + ', country not drawn')
                continue
            # plot the coloured country on a map
            ax.add_geometries([borders.values[0]], crs=ccrs.PlateCarree(),
                              facecolor=cmap(norm(n_samples)), edgecolor='none', zorder=1)

        # add colour bar on the map, sharing the scale used to shade the countries
        fig.colorbar(mappable=plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax,
                     label='Number of samples', orientation='vertical', shrink=.3)

    # choose whether to have a stock image on the map
    if stock_image:
        ax.stock_img()
    else:
        ax.add_feature(cartopy.feature.LAND, color='lightgray')
        ax.add_feature(cartopy.feature.OCEAN)
        ax.add_feature(cartopy.feature.COASTLINE, edgecolor='grey')
    ax.set_title(label + " Sampled Countries", fontsize=6)

    # save figure as .png with label & date as name (before showing it, as some backends
    # release the figure once it has been displayed)
    if save_fig:
        fig.savefig(f'{label}_{strftime("%Y-%m-%d")}.png', bbox_inches='tight', dpi=200)

    plt.show()


In [None]:
def temporal_samples(samples = 'Pf6+'):

    ''' Plot a stacked bar plot showcasing the temporal distribution of the samples in Pf6+.

    Reads the notebook-level `pf6plus` dataframe (columns 'Year' and 'Process').

    Parameters:
      - samples: Which set of samples to plot. Only 'Pf6+' is implemented: samples whose
        'Process' is 'WGS' are labelled 'Pf6', all the others 'GenRe-Mekong'.

    Returns:
      None. The plot is displayed.
    '''

    fig, ax = plt.subplots()

    if samples == 'Pf6+':
        is_wgs = pf6plus['Process'] == 'WGS'

        # count both groups on the same list of years so that the bars can be stacked
        years = np.sort(pf6plus['Year'].dropna().unique())
        wgs_counts = pf6plus.loc[is_wgs, 'Year'].value_counts().reindex(years, fill_value=0).to_numpy()
        other_counts = pf6plus.loc[~is_wgs, 'Year'].value_counts().reindex(years, fill_value=0).to_numpy()

        ax.bar(years, wgs_counts, label='Pf6')
        # `bottom` stacks these bars on top of the Pf6 ones instead of drawing over them
        ax.bar(years, other_counts, bottom=wgs_counts, label='GenRe-Mekong')
        ax.legend()

        if len(years):
            # half a year of margin so that the first and last bars are not cut in half
            ax.set_xlim(years.min() - 0.5, years.max() + 0.5)
    else:
        # add versions for individual version (TO DO)
        print('Only samples = \'Pf6+\' is implemented, nothing to plot for ' + str(samples))

    ax.set_ylabel('# of samples')
    ax.set_title('Year of Collection')

    plt.show()


### Tabulate drug resistant variants



In [None]:
def tabulate_drug_resistant(drug, country = None, population = None, year = None, bin=False):

    ''' Tabulate the frequency of drug resistant samples per country/year

    Reads the notebook-level `pf6plus` dataframe.

    Parameters:
      - drug: Any of the drugs in the Pf6+ dataframe ['Artemisinin', 'Chloroquine', 'DHA-PPQ', 'Piperaquine', 'Pyrimethamine', 'S-P', 'S-P-IPTp', 'Sulfadoxine']
      - country: Any of the countries in the Pf6+ dataframe (if specified, population value is not used) ['Bangladesh', 'Benin', 'Burkina Faso', 'Cambodia', 'Cameroon', 'Colombia', 'Congo DR', 'Ethiopia', 'Gambia', 'Ghana', 'Guinea', 'India', 'Indonesia', 'Ivory Coast', 'Kenya', 'Laos', 'Madagascar', 'Malawi', 'Mali', 'Mauritania', 'Mozambique', 'Myanmar', 'Nigeria', 'Papua New Guinea', 'Peru', 'Senegal', 'Tanzania', 'Thailand', 'Uganda', 'Viet Nam']
      - population: Any of the populations in the Pf6+ dataframe ['CAF', 'EAF', 'ESEA', 'OCE', 'SAM', 'SAS', 'WAF', 'WSEA']
      - year: An array with the year(s) in the Pf6+ dataframe [2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]. A single year is also accepted.
      - bin: If True, all the years between the lowest and the highest specified values (both included) will be used. If False, individual years are used. Ignored when no year is given.

    Returns:
      A dataframe showing the number of samples of each phenotype (e.g. Resistant, Sensitive &
      Undetermined) for the drug/country/year (or) drug/population/year combination provided:
        - country (or) population together with year: a single row named 'Samples';
        - country only: one row per year;
        - population only: one row for the population;
        - neither: one row per country (restricted to the given years, if any).
      The total number of samples ('Total') and the drug resistant frequency
      ('Resistant Frequency' = Resistant / (Resistant + Sensitive), NaN when there are
      no Resistant or Sensitive samples) are also provided.
    '''

    # accept a single year as well as a list/array of years
    if year is None:
        years = []
    elif isinstance(year, (str, bytes)) or not hasattr(year, '__iter__'):
        years = [year]
    else:
        years = list(year)

    # restrict the samples to the requested place (country takes precedence over population)
    samples = pf6plus
    if country:
        samples = samples.loc[samples.Country.isin([country])]
        place = str(country)
    elif population:
        samples = samples.loc[samples.Population.isin([population])]
        place = str(population)
    else:
        place = 'all countries'

    # restrict the samples to the requested years
    if years:
        if bin:
            first, last = int(min(years)), int(max(years))
            # + 1 because range() excludes its upper bound and the last year must be included
            selected_years = list(range(first, last + 1))
            print(drug + ' resistant samples ' + 'in ' + place + ' from ' + str(first) + ' to ' + str(last))
        else:
            selected_years = years
            print(drug + ' resistant samples in ' + place + ' on ' + str(year))
        samples = samples.loc[samples.Year.isin(selected_years)]
    elif country:
        print(drug + ' resistant samples in ' + str(country))
    else:
        print(drug + ' resistant samples on all years')

    if years and (country or population):
        # one place and a set of years: a single row with the number of samples per phenotype
        phenotypes = samples.groupby([drug]).size().fillna(0).astype(int).to_frame('Samples').transpose()
    else:
        if country:
            rows = 'Year'
        elif population:
            rows = 'Population'
        else:
            rows = 'Country'
        phenotypes = samples.groupby([rows, drug]).size().unstack().fillna(0).astype(int)

    phenotypes = phenotypes.assign(Total=phenotypes.sum(axis=1))

    # calculating the frequency using all samples, no matter how many they are! (note that some combinations
    # will have a small number of samples, so the frequency will not be an adequate estimate)
    # a phenotype that is absent from the selection counts as zero samples
    no_samples = pd.Series(0, index=phenotypes.index)
    resistant = phenotypes.get('Resistant', no_samples)
    sensitive = phenotypes.get('Sensitive', no_samples)
    called = resistant + sensitive
    phenotypes['Resistant Frequency'] = (resistant / called.where(called > 0)).round(2)

    # fully troubleshoot numbers (TO DO)
    # add exception for multiple countries & multiple years (TO DO)
    # add exception for repeated countries/years (TO DO)

    return(phenotypes)



### Plot Drug Resistant Prevalence

In [None]:
def _dr_well_sampled(frame, min_samples=25):
    ''' Keep the rows of `frame` whose own (Year, Country) pair has more than `min_samples` rows. '''
    sizes = frame.groupby(['Year', 'Country']).size()
    kept_pairs = sizes[sizes > min_samples].index
    row_pairs = pd.MultiIndex.from_frame(frame[['Year', 'Country']])
    return frame.loc[row_pairs.isin(kept_pairs)]


def _dr_prevalence_by_year(frame, drug):
    ''' Return Resistant / (Resistant + Sensitive) per year for `drug`, rounded to 2 decimals. '''
    sizes = frame.groupby(['Year', drug]).size()
    if sizes.empty:
        return pd.Series(dtype=float)
    counts = sizes.unstack().fillna(0)
    # a phenotype that is absent from the data counts as zero samples
    no_samples = pd.Series(0, index=counts.index)
    resistant = counts.get('Resistant', no_samples)
    sensitive = counts.get('Sensitive', no_samples)
    called = resistant + sensitive
    # years without Resistant/Sensitive samples get NaN instead of a division by zero
    return (resistant / called.where(called > 0)).round(2)


def plot_dr_prevalence(drugs, country = None, population = None, year = None, bin=False):

    ''' Plot the prevalence of resistant samples per country/year

    Reads the notebook-level `pf6plus` dataframe.

    Parameters:
      - drugs: Any/list of the drugs in the Pf6+ dataframe ['Artemisinin', 'Chloroquine', 'DHA-PPQ', 'Piperaquine', 'Pyrimethamine', 'S-P', 'S-P-IPTp', 'Sulfadoxine']
      - country: Any of the countries in the Pf6+ dataframe (required) ['Bangladesh', 'Benin', 'Burkina Faso', 'Cambodia', 'Cameroon', 'Colombia', 'Congo DR', 'Ethiopia', 'Gambia', 'Ghana', 'Guinea', 'India', 'Indonesia', 'Ivory Coast', 'Kenya', 'Laos', 'Madagascar', 'Malawi', 'Mali', 'Mauritania', 'Mozambique', 'Myanmar', 'Nigeria', 'Papua New Guinea', 'Peru', 'Senegal', 'Tanzania', 'Thailand', 'Uganda', 'Viet Nam']
      - population: The population the country is compared against (required) ['CAF', 'EAF', 'ESEA', 'OCE', 'SAM', 'SAS', 'WAF', 'WSEA']
      - year: Any/list of the years in the Pf6+ dataframe [2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]. If None, all years are used.
      - bin: If True, all the years between the lowest and the highest specified values (both included) will be used. If False, individual years are used.

    Returns:
      None. A series of plots (one per drug) is displayed, each showing the prevalence of resistant
      variants (Resistant / (Resistant + Sensitive)) per year in the country and in the rest of its population.
      Note that to increase confidence on disperse data, we only use country/year combinations with more than 25 samples.
      Drugs without any Resistant or Sensitive sample in the country are reported and not plotted.

    Raises:
      TypeError if country or population is missing, ValueError if no drug is given.
    '''

    if country is None or population is None:
        raise TypeError('plot_dr_prevalence needs both a country and a population')

    # accept a single drug name as well as a list of drugs
    drugs = [drugs] if isinstance(drugs, str) else list(drugs)
    if not drugs:
        raise ValueError('plot_dr_prevalence needs at least one drug')

    # restrict the samples to the requested years (all years if none is given)
    data = pf6plus
    if year is not None:
        if isinstance(year, (str, bytes)) or not hasattr(year, '__iter__'):
            years = [year]
        else:
            years = list(year)
        if years and bin:
            years = list(range(int(min(years)), int(max(years)) + 1))
        if years:
            data = data.loc[data['Year'].isin(years)]

    # samples of the country, and of the rest of its population, keeping only the
    # year/country combinations with enough samples
    in_population = data['Population'] == population
    in_country = data['Country'] == country
    pf_country = _dr_well_sampled(data.loc[in_country & in_population])
    pf = _dr_well_sampled(data.loc[~in_country & in_population])

    # define axs in fig based on # of drugs: up to 2 plots per row (3 when there are more than 6 drugs)
    n_cols = min(len(drugs), 2 if len(drugs) <= 6 else 3)
    n_rows = -(-len(drugs) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*7, n_rows*4), squeeze=False)
    axes = axes.flatten()

    for ax, drug in zip(axes, drugs):

        country_prevalence = _dr_prevalence_by_year(pf_country, drug)

        if country_prevalence.dropna().empty:
            #add diff exceptions messages (dependending on resistant, sensitive, undetermined)
            print('Not enough data on ' + drug + ' resistant variants in dataset')
            fig.delaxes(ax)
            continue

        population_prevalence = _dr_prevalence_by_year(pf, drug)

        ax.set_ylim(0,1)
        ax.set_xlim(2001,2019)

        ax.plot(population_prevalence, 'o-', label='Prevalence ' + population)
        ax.plot(country_prevalence, 'o-', label='Prevalence ' + country)

        plotted_years = population_prevalence.index.union(country_prevalence.index)
        ax.set_title(drug +' resistance prevalence in ' + country + ' from ' + str(min(plotted_years)) + ' to '+ str(max(plotted_years)), size= 15)
        ax.legend()

        #implement evenly spaced yearly bins TO DO

    # hide the axes left over in the last row of the grid
    for ax in axes[len(drugs):]:
        ax.axis("off")

    plt.show()
            
            #implement evenly spaced yearly bins TO DO


### Plot most common haplotypes per population/country

In [None]:
def plot_haplotype_frequency(gene, top_haplotypes=5, country = None, pop = None, year = None, bin=False):

    ''' Plot the frequency of the top n haplotypes of a specific gene per country (or) population per year

    Reads the notebook-level `pf6plus` dataframe.

    Parameters:
      - gene: Any of the genes in the Pf6+ dataframe ['PfCRT', 'Kelch', 'PfDHFR', 'PfEXO', 'PGB', 'Plasmepsin2/3', 'PfDHPS', 'PfMDR1']
      - top_haplotypes: The (n) most common haplotypes, default is 5. These excludes missing haplotypes.
      - country: Any/list of the countries in the Pf6+ dataframe (if specified, pop value is not used) ['Bangladesh', 'Benin', 'Burkina Faso', 'Cambodia', 'Cameroon', 'Colombia', 'Congo DR', 'Ethiopia', 'Gambia', 'Ghana', 'Guinea', 'India', 'Indonesia', 'Ivory Coast', 'Kenya', 'Laos', 'Madagascar', 'Malawi', 'Mali', 'Mauritania', 'Mozambique', 'Myanmar', 'Nigeria', 'Papua New Guinea', 'Peru', 'Senegal', 'Tanzania', 'Thailand', 'Uganda', 'Viet Nam']
      - pop: Any/list of the populations in the Pf6+ dataframe ['CAF', 'EAF', 'ESEA', 'OCE', 'SAM', 'SAS', 'WAF', 'WSEA']
      - year: Any/list of the years in the Pf6+ dataframe [2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]. If None, all years are used.
      - bin: If True, all the years between the lowest and the highest specified values (both included) will be used. If False, individual years are used.

    Returns:
      None. One plot per country (or) population is displayed, showing per year the proportion of
      each of the top n haplotypes and of all the remaining ones ('Other Haplotypes') among the
      samples with a complete, unambiguous haplotype. Only country/year combinations with more
      than 25 samples are used.
    '''

    if country:
        loc = 'Country'
        locations = country
    elif pop:
        loc = 'Population'
        locations = pop
    else:
        print('No country or population provided')
        return

    # accept a single name as well as a list of names (a bare string would be read letter by letter)
    if isinstance(locations, str):
        locations = [locations]

    # restrict the samples to the requested years (all years if none is given)
    data = pf6plus
    if year is not None:
        if isinstance(year, (str, bytes)) or not hasattr(year, '__iter__'):
            years = [year]
        else:
            years = list(year)
        if years and bin:
            years = list(range(int(min(years)), int(max(years)) + 1))
        if years:
            data = data.loc[data['Year'].isin(years)]

    # colours of the haplotype lines; further lines use matplotlib's default colour cycle
    colours = ['green', 'blue', 'magenta', 'yellow', 'orange']

    for location in locations:
        located = data.loc[data[loc] == location]

        # keep only the year/country combinations with enough samples
        sizes = located.groupby(['Year', 'Country']).size()
        kept_pairs = sizes[sizes > 25].index
        row_pairs = pd.MultiIndex.from_frame(located[['Year', 'Country']])
        samples_subset = located.loc[row_pairs.isin(kept_pairs)]

        # haplotypes from most to least common, excluding the ones with missing ('-')
        # or multiple (',') calls
        ranked_haps = samples_subset.groupby(gene).size().sort_values(ascending=False).index
        top_haps = [hap for hap in ranked_haps
                    if isinstance(hap, str) and ',' not in hap and '-' not in hap]

        if not top_haps:
            print('Not enough data on ' + gene + ' haplotypes in ' + str(location))
            continue

        phenotypes = samples_subset.groupby(['Year', gene]).size().unstack().fillna(0).astype(int)

        # samples with a complete haplotype per year (NaN where there is none, to avoid dividing by zero)
        called = phenotypes[top_haps].sum(axis=1)
        called = called.where(called > 0)

        # the n most common haplotypes, sorted by name so that colours are assigned consistently
        ordered_haps = sorted(top_haps[:top_haplotypes])

        fig, ax = plt.subplots()
        ax.set_xlim(2001,2019)
        ax.set_ylim(0,1)

        for position, hap in enumerate(ordered_haps):
            colour = colours[position] if position < len(colours) else None
            ax.plot(phenotypes[hap] / called, label=hap, color=colour, linestyle='--', marker='o')

        other_haps = 1 - phenotypes[ordered_haps].sum(axis=1) / called
        ax.plot(other_haps, label="Other Haplotypes", color='grey', linestyle='--', marker='o')

        # reference line at 2012 (kept from the original plot; its meaning is not stated here)
        ax.axvline(2012, color='grey', linestyle='--', linewidth=1)

        #ensure inconsistencies between E/WSEA in Thailand are taken into account (TO DO)

        ax.set_title(gene + ' Predominant Variants in ' + str(location))
        ax.legend()
        plt.show()
