# Model plots for shark and ray meat landings and trade applied to 2014-2019 data

In [1]:
import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pyt
import seaborn as sns
import pdb
from matplotlib.gridspec import GridSpec
import xarray as xr
import xarray_einstats
import rdata as rd
import networkx as nx
from mpl_toolkits.basemap import Basemap
import geopandas as gpd

from matplotlib.path import Path
import matplotlib.patches as patches

In [2]:
# Global plotting style, plus the data (bd) and figure-output (bf) directory
# prefixes, both resolved relative to the current working directory.
az.style.use("arviz-darkgrid")
bd = f"{os.getcwd()}/../Data/"
bf = f"{os.getcwd()}/../Figures/"

In [3]:
# Helper functions
def indexall(L):
    """Return the unique values of ``L`` and an integer code for every element.

    Unique values are kept in order of first appearance, so that
    ``poo[Ix[i]] == L[i]`` for every position ``i``.

    Parameters
    ----------
    L : iterable
        Values to encode. It is traversed twice, so pass a re-iterable
        (list, array, Series) rather than a one-shot iterator.

    Returns
    -------
    poo : list
        Unique values of ``L`` in first-appearance order.
    Ix : numpy.ndarray
        Position of each element of ``L`` within ``poo``.
    """
    try:
        # Dict lookup keeps this O(n) instead of O(n^2) list scans.
        codes = {}
        for p in L:
            codes.setdefault(p, len(codes))
        poo = list(codes)
        Ix = np.array([codes[p] for p in L])
    except TypeError:
        # Unhashable elements (e.g. lists): fall back to linear search.
        poo = []
        for p in L:
            if p not in poo:
                poo.append(p)
        Ix = np.array([poo.index(p) for p in L])
    return poo, Ix


# Helper functions
def match(a, b):
    """For each item of ``a``, return its first index in ``b`` (``None`` if absent).

    The result is a NumPy array; it has object dtype whenever an item is missing.
    """
    positions = []
    for item in a:
        positions.append(b.index(item) if item in b else None)
    return np.array(positions)


def unique(series: pd.Series):
    """Return the distinct values of ``series``, in ascending sorted order."""
    ordered = series.sort_values()
    return ordered.unique()


def pairplot_divergence(trace, basevar, targetvar, ax=None, divergence=True, color='C3', divergence_color='C2'):
    """Scatter posterior draws of ``targetvar`` against ``basevar``.

    Draws are flattened across chains. Both variables are therefore expected to
    be scalar per draw; for vector-valued variables the flattened arrays no
    longer line up with the divergence flags (TODO confirm before such use).

    Parameters
    ----------
    trace : arviz.InferenceData
        Must provide ``posterior[basevar]`` and ``posterior[targetvar]``.
        When ``divergence`` is True it must also provide ``sample_stats.diverging``.
    basevar, targetvar : str
        Variables plotted on the x and y axes.
    ax : matplotlib Axes, optional
        Axes to draw on. A new 10x5 figure is created when omitted.
    divergence : bool
        Overlay divergent transitions in ``divergence_color``.
    color, divergence_color : str
        Marker colours for all draws and for divergent draws.

    Returns
    -------
    The Axes that was drawn on.
    """
    theta = trace.posterior[basevar].values.flatten()
    target = trace.posterior[targetvar].values.flatten()
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.plot(theta, target, 'o', color=color, alpha=.5)
    if divergence:
        # Cast to bool so the flags act as a mask even if they were stored as
        # integers; integer arrays would be treated as positional indices.
        divergent = trace.sample_stats.diverging.values.flatten().astype(bool)
        ax.plot(theta[divergent], target[divergent], 'o', color=divergence_color)
    ax.set_xlabel(basevar)
    ax.set_ylabel(targetvar)
    ax.set_title('scatter plot between log(' + targetvar + ') and ' + basevar)
    return ax

# A small wrapper function for displaying MCMC sampler diagnostics for one variable
def report_trace(trace,basevar,targetvar,logscale=False,ylim=(0, 2)):
    """Display MCMC diagnostics for ``targetvar``.

    The function:
    - shows a trace plot of ``targetvar``;
    - shows a running-mean plot of its draws (after ``np.log`` when ``logscale``
      is True);
    - prints the number and percentage of divergent transitions;
    - finishes with ``pairplot_divergence(trace, basevar, targetvar)``.

    Parameters
    ----------
    trace : arviz.InferenceData
        Needs ``posterior[targetvar]``, ``posterior[basevar]`` and
        ``sample_stats.diverging``.
    basevar, targetvar : str
        x and y variables of the divergence scatter; ``targetvar`` is also the
        variable whose trace and running mean are shown.
    logscale : bool, default False
        Log-transform the ``targetvar`` draws before the running mean.
    ylim : tuple or None, default (0, 2)
        y-limits of the running-mean plot; ``None`` lets matplotlib autoscale.
    """
    draws = trace.posterior[targetvar].values.flatten()

    # Trace plot of the target variable
    pm.plot_trace({targetvar: draws})

    # Running mean over the first 1..n-1 draws (the same points the previous
    # list comprehension produced), computed with cumsum in O(n) not O(n^2).
    logtau = np.log(draws) if logscale else draws
    mlogtau = np.cumsum(logtau)[:-1] / np.arange(1, len(logtau))
    plt.figure(figsize=(15, 4))
    plt.plot(mlogtau, lw=2.5)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel('Iteration')
    plt.ylabel('MCMC mean of log('+targetvar+')')
    plt.title('MCMC estimation of log('+targetvar+')')
    plt.show()

    # Number and percentage of divergent transitions. The percentage is taken
    # over all draws; the old len(trace) denominator measured the InferenceData
    # object itself, not the number of samples.
    divergent = trace.sample_stats.diverging.values.flatten()
    n_divergent = np.count_nonzero(divergent)
    print('Number of Divergent %d' % n_divergent)
    divperc = n_divergent / divergent.size * 100 if divergent.size else 0.0
    print('Percentage of Divergent %.1f' % divperc)

    # Scatter of targetvar against basevar, to locate problematic regions of
    # parameter space.
    pairplot_divergence(trace, basevar, targetvar)

## Load data

In [4]:
# Load data: run the data-preparation script in this notebook's namespace.
# It presumably defines sdata, txdata, country_, species_, biggest_countries(_long),
# etc., which are used below but not defined here — confirm.
# NOTE(review): the saved traceback below names "Joint_Trade_Landings_Data_Lonsdale.py",
# and the inference files loaded next are "*_Lonsdale.nc"; confirm the Perth
# script matches those fitted models.
# exec() runs arbitrary code, so only point this at a trusted local script.
with open("Joint_Trade_Landings_Data_Perth.py") as _data_script:
    exec(_data_script.read())

FileNotFoundError: [Errno 2] No such file or directory: 'Joint_Trade_Landings_Data_Lonsdale.py'

# Import landings and trade inference objects

In [None]:
# Import the saved ArviZ InferenceData objects for the landings and trade models
# from NetCDF files (relative paths, so resolved against the working directory).
idata_landings_x = az.from_netcdf("idata-landings-model_Lonsdale.nc")
idata_trade_x = az.from_netcdf("idata-trade-model_Lonsdale.nc")

In [None]:
# ArviZ doesn't handle MultiIndex yet, so make every group aware of the real
# country labels. The long country names are attached as extra (non-index)
# coordinates on the exporter and importer dimensions.
more_coords = {
    "Exporter": ("exporter", biggest_countries_long),
    "Importer": ("importer", biggest_countries_long),
}

# Same six groups, in the same order: first for the landings model, then for
# the trade model.
for _idata in (idata_landings_x, idata_trade_x):
    for _group in (
        "prior",
        "prior_predictive",
        "posterior",
        "posterior_predictive",
        "observed_data",
        "constant_data",
    ):
        setattr(_idata, _group, getattr(_idata, _group).assign_coords(more_coords))
# Remove the loop variables so the notebook namespace is left as before.
del _idata, _group

In [None]:
#"""
# Posterior predictive check for species-level landings: observed values against
# the posterior-predictive mean, with the predictive HDI drawn as a band.
mu_pp = idata_landings_x.posterior_predictive["obs_spp"]
_, ax = plt.subplots()
ax.scatter(idata_landings_x.observed_data["obs_spp"], mu_pp.mean(("chain", "draw")))
# No ax= is passed, so this presumably draws on the current axes (ax) — confirm.
az.plot_hdi(idata_landings_x.observed_data["obs_spp"], mu_pp)
# 1:1 reference line: points on it have predictive mean equal to the observation.
# The 0-12.5 span is hard-coded (assumed to cover the log-landings range).
ax.plot((0,12.5),(0,12.5),linestyle=":")
ax.set_xlabel("Observed")
ax.set_ylabel("Expected")
ax.set_title("Observed vs Expected log Species landings")
plt.savefig(bf+'/Diagnostics/'+'Observed_vs_Expected_log_landings.jpg',dpi=300);
#"""

In [None]:
#"""
# Posterior predictive check for taxon-level landings: observed values against
# the posterior-predictive mean, with the predictive HDI drawn as a band.
mu_pp = idata_landings_x.posterior_predictive["obs_taxon"]
_, ax = plt.subplots()
ax.scatter(idata_landings_x.observed_data["obs_taxon"], mu_pp.mean(("chain", "draw")))
# Pass ax explicitly so the HDI band is drawn on this figure's axes.
az.plot_hdi(idata_landings_x.observed_data["obs_taxon"], mu_pp, ax=ax)
# 1:1 reference line (hard-coded 0-12.5 span on the log scale)
ax.plot((0,12.5),(0,12.5),linestyle=":")
ax.set_xlabel("Observed")
ax.set_ylabel("Expected")
ax.set_title("Observed vs Expected log Taxon landings")
# Own filename: this cell previously wrote 'Observed_vs_Expected_log_landings.jpg',
# overwriting the species-level check saved under that name.
plt.savefig(bf+'/Diagnostics/'+'Observed_vs_Expected_log_taxon_landings.jpg',dpi=300)
#"""

In [None]:
#"""
# Posterior predictive check for log shark trade: observed values against the
# posterior-predictive mean, with the predictive HDI drawn as a band.
mu_pp = idata_trade_x.posterior_predictive["log_shark_trade"]
_, ax = plt.subplots()
ax.scatter(idata_trade_x.observed_data["log_shark_trade"], mu_pp.mean(("chain", "draw")))
# No ax= is passed, so this presumably draws on the current axes (ax) — confirm.
az.plot_hdi(idata_trade_x.observed_data["log_shark_trade"], mu_pp)
# 1:1 reference line (hard-coded 0-12.5 span on the log scale)
ax.plot((0,12.5),(0,12.5),linestyle=":")
ax.set_xlabel("Observed")
ax.set_ylabel("Expected")
ax.set_title("Observed vs Expected log Shark Trade")
plt.savefig(bf+'/Diagnostics/'+'Observed_vs_Expected_log_shark_trade.jpg',dpi=300);
#"""

In [None]:
#"""
# Posterior predictive check for log ray trade: observed values against the
# posterior-predictive mean, with the predictive HDI drawn as a band.
mu_pp = idata_trade_x.posterior_predictive["log_ray_trade"]
_, ax = plt.subplots()
ax.scatter(idata_trade_x.observed_data["log_ray_trade"], mu_pp.mean(("chain", "draw")))
# No ax= is passed, so this presumably draws on the current axes (ax) — confirm.
az.plot_hdi(idata_trade_x.observed_data["log_ray_trade"], mu_pp)
# 1:1 reference line (hard-coded 0-12.5 span on the log scale)
ax.plot((0,12.5),(0,12.5),linestyle=":")
ax.set_xlabel("Observed")
ax.set_ylabel("Expected")
ax.set_title("Observed vs Expected log Ray Trade")
plt.savefig(bf+'/Diagnostics/'+'Observed_vs_Expected_log_ray_trade.jpg',dpi=300);
#"""

# Global summaries

## Observed data summaries for plotting


In [None]:
# Species-level observed landings per (country, species):
#  1. drop the year/ID columns;
#  2. average the remaining (log) columns over the rows of each group
#     (presumably one row per year);
#  3. exponentiate, giving a geometric mean of reported landings.
# The log column is renamed to 'Reported_landings'.
Obs_spp_data = np.exp(sdata.drop(columns=['year','year_spp_id','country_spp_id','species_spp_id']
          ).groupby(['country','species']).mean()).rename(columns={"logReported_species_landings": "Reported_landings"})

# xarray Dataset of reported landings summed over countries for each species,
# largest first.
Obs_spp_sums = xr.Dataset(Obs_spp_data.groupby(level='species').sum().sort_values(by='Reported_landings', ascending=False))

# Same construction at taxon level, from txdata, per (country, taxon).
Obs_tax_data = np.exp(txdata.drop(columns=['year','year_tax_id','country_tax_id','taxon_tax_id']
          ).groupby(['country','taxon']).mean()).rename(columns={"logReported_taxon_landings": "Reported_landings"})

# Net estimated unreported species-level landings: posterior draws of latent
# landings minus CountrySPP_landings (the latter is plotted below as "reported").
Net_spp_landings = idata_landings_x.posterior['Latent_landings']-idata_landings_x.posterior['CountrySPP_landings']

In [None]:
# Display the 10 species with the largest reported landings summed over countries.
Obs_spp_data.groupby(level='species').sum().sort_values(by='Reported_landings', ascending=False).head(10)

## Plot top reported and unreported species

In [None]:
# = = = = = Set up plot
# Two side-by-side panels: reported (left) and unreported (right) species landings.
_, ax = plt.subplots(1,2,figsize=(10, 10))

# Select top 10% of landed species. Take posterior-median CountrySPP_landings
# summed over exporters, and keep species strictly above the 90th percentile
# (sorted largest first). Note ('exporter') is the plain string, i.e. a sum
# over that one dimension.
tmp = idata_landings_x.posterior['CountrySPP_landings'].sum(('exporter')).median(('chain','draw'))
tmp = tmp.sortby(tmp, ascending=False)
sppx_landed = tmp[tmp>np.quantile(tmp,0.90)].species

# Select top 10% of unreported species: same rule applied to Net_spp_landings
# (latent minus reported).
tmp = Net_spp_landings.sum(('exporter')).median(('chain','draw'))
tmp = tmp.sortby(tmp, ascending=False)
sppx_unrep = tmp[tmp>np.quantile(tmp,0.90)].species

# = = = = = Plot reported species landings (full posterior, 90% HDI)
az.plot_forest(idata_landings_x.posterior['CountrySPP_landings'].sum(('exporter')).sel(species=sppx_landed).rename(''),
                    hdi_prob=0.9,
                    #transform=np.exp,
                    #figsize=(10, 20)
               ax=ax[0]
                   )
ax[0].set(
    title="Top reported landings",
    xlabel="tonnes",
    ylabel="",
);
#ax[0].set_xlim(0,200000)
for tick in ax[0].get_xticklabels():
    tick.set_rotation(45)

# = = = = = Plot unreported species landings (full posterior, 90% HDI)
az.plot_forest(Net_spp_landings.sum(('exporter')).sel(species=sppx_unrep).rename(''),
                    #transform=np.exp,
                    #figsize=(10, 20)
                    hdi_prob=0.9,
                   ax=ax[1]
                   )
ax[1].set(
    title="Top unreported landings",
    xlabel="tonnes",
    ylabel="",
)
#ax[1].set_xlim(0,120000)
for tick in ax[1].get_xticklabels():
    tick.set_rotation(45)

plt.savefig(bf+'/Global/'+'Reported_Unreported_species.jpg',dpi=300);

## Plot top taxa by latent landings

In [None]:
# Plot total latent landings by taxon (not trade). Exponentiate the log-landings
# draws, sum over exporters, and order taxa by posterior mean, largest first.
# tax_landed is reused by the next cell to pick the top taxa.
tax_landed = np.exp(idata_landings_x.posterior['CountryTaxon_log_landings']).sum(('exporter')).rename('')
tax_landed = tax_landed.sortby(tax_landed.mean(('chain','draw')), ascending=False)
ax = az.plot_forest(
        tax_landed,
        hdi_prob=0.9,
        )
ax[0].set(
    title="Top latent landings by taxon",
    xlabel="tonnes",
    ylabel="",
);
#ax[0].set_xlim(0,.5)
for tick in ax[0].get_xticklabels():
    tick.set_rotation(45)
plt.savefig(bf+'/Global/'+'Taxon_landings.jpg',dpi=300);

In [None]:
# For each of the first 5 taxa in tax_landed (highest posterior-mean latent
# landings), plot the species with the largest posterior-median
# CountryTaxon_SPP_landings summed over exporters.
# NOTE(review): titled "unreported"; presumably CountryTaxon_SPP_landings is the
# species breakdown of landings reported only at taxon level — confirm.
# There is no plt.close() here, so every figure stays open for display.
for taxax in tax_landed.taxon.to_numpy()[:5]:
    # Select the top 10% of landed species: drop species at or below 0.1 t, then
    # keep those strictly above the 90th percentile of the rest.
    # NOTE(review): if nothing exceeds 0.1 t, np.quantile receives an empty array
    # and will likely fail — TODO guard.
    # This also overwrites the global sppx_landed set in an earlier cell.
    tmp = idata_landings_x.posterior['CountryTaxon_SPP_landings'].sum(('exporter')).sel(taxon=taxax).median(('chain','draw'))
    tmp = tmp.sortby(tmp, ascending=False)
    tmp = tmp[tmp>0.1]
    sppx_landed = tmp[tmp>np.quantile(tmp,0.90)].species
    
    # = = = = = Forest plot of the selected species' full posterior (90% HDI)
    plt_tmp = idata_landings_x.posterior['CountryTaxon_SPP_landings'].sum(('exporter')).sel(taxon=taxax,species=sppx_landed).rename('')
    ax = az.plot_forest(
        #tmp/tmp.median(('chain','draw')).sum(),
        plt_tmp,
        hdi_prob=0.9,
        #transform=np.log1p,
        #figsize=(10, 20)
   
       )
    
    ax[0].set(
        title='Top unreported '+taxax+' landings',
        xlabel="tonnes",
        ylabel="",
    );
    #ax[0].set_xlim(0,.5)
    for tick in ax[0].get_xticklabels():
        tick.set_rotation(45)
    plt.savefig(bf+'/Global/'+'Top_unreported_'+taxax+'_landings.jpg',dpi=300);

## Plot top species in trade

In [None]:
# Split the posterior trade flows into international trade (trade_spp) and
# domestic consumption (domestic_spp = the exporter == importer diagonal).
#
# Only this DataArray is deep-copied, not the whole InferenceData. rename() and
# assign_coords() return shallow copies that share data with the posterior, so
# without a deep copy the in-place .loc writes below would modify idata_trade_x.
trade_spp = (
    idata_trade_x.posterior['amount_exported']
    .copy(deep=True)
    .rename('')
    .assign_coords({"Exporter": ("exporter", biggest_countries_long),"Importer": ("importer", biggest_countries_long)})
)
# Domestic consumption: same shape and coords, zero except on the diagonal.
domestic_spp = xr.zeros_like(trade_spp)

for c in country_:
    # Move the diagonal entry (exporter == importer == c) from trade to domestic.
    domestic_spp.loc[dict(exporter=c,importer=c)] = trade_spp.sel(exporter=c,importer=c)
    trade_spp.loc[dict(exporter=c,importer=c)] = 0

In [None]:
# = = = = = Set up plot
# Left panel: international trade; right panel: domestic consumption.
_, ax = plt.subplots(1,2,figsize=(10, 10))

# Plot top N species traded
N = 20
# = = = = = International trade per species, summed over exporters and importers,
# ranked by posterior median; keep the top N.
spp_traded = trade_spp.sum(('exporter','importer')).rename('')
spp_traded = spp_traded.sortby(spp_traded.median(('chain','draw')), ascending=False)
sppx_trad = spp_traded.species[:N]

az.plot_forest(
        spp_traded.sel(species=sppx_trad),
        hdi_prob=0.9,
        ax=ax[0]
        )
ax[0].set(
    title="Top "+str(N)+" species in trade",
    xlabel="tonnes",
    ylabel="",
);

for tick in ax[0].get_xticklabels():
    tick.set_rotation(45)


# = = = = = Domestic consumption per species (the exporter == importer flows),
# ranked the same way; keep the top N.
spp_dom = domestic_spp.sum(('exporter','importer')).rename('')
spp_dom = spp_dom.sortby(spp_dom.median(('chain','draw')), ascending=False)
sppx_dom = spp_dom.species[:N]

az.plot_forest(
        spp_dom.sel(species=sppx_dom),
        hdi_prob=0.9,
        ax=ax[1]
        )
ax[1].set(
    title="Top "+str(N)+" domestic spp",
    xlabel="tonnes",
    ylabel="",
);
for tick in ax[1].get_xticklabels():
    tick.set_rotation(45)

plt.savefig(bf+'/Global/'+'Top_marketed_spp.jpg',dpi=300);

In [None]:
# Calculate the proportion of catch exported, per exporter and posterior draw:
# 1 - domestic / (domestic + international), each summed over species and importers.
dom = domestic_spp.sum(('species','importer'))
tot = dom + trade_spp.sum(('species','importer'))
prop_ex = 1 - dom / tot
# Order exporters by their posterior-median export share.
prop_ex = prop_ex.sortby(prop_ex.median(('chain', 'draw')))

In [None]:
# Forest plot of each exporter's export proportion (chains combined, 90% HDI).
ax = az.plot_forest(
        prop_ex,
        hdi_prob=0.9,
        combined=True,
        figsize=(8, 12)
        )
ax[0].set(
    title="Export/landings ratio",
    xlabel="Proportion exports",
    ylabel="",
);
for tick in ax[0].get_xticklabels():
    tick.set_rotation(45)
# Replace the row labels with full country names. The list is reversed,
# presumably because the forest plot draws its first row at the top while tick
# positions increase upwards — NOTE(review): confirm labels match the rows.
ax[0].set_yticklabels(prop_ex.Exporter.to_numpy()[::-1])

plt.savefig(bf+'/Global/'+'Proportion_exports.jpg',dpi=300);

# National plots

In [None]:
# Fixed colour per taxon label, used as the seaborn palette in the taxon
# decomposition bar charts, so colours are consistent across countries.
taxon_colours = {
    'Alopias': '#6baed6',
    'Bathyraja': '#fdae6b',
    'Carcharhinidae': '#3182bd',
    'Dasyatidae': '#fdd0a2',
    'Elasmobranchii': 'black',
    'Etmopterus': '#636363',
    'Isurus': '#9ecae1',
    'Lamnidae': '#c6dbef',
    'Mobulidae': '#756bb1',
    'Mustelus': '#5254a3',
    'Myliobatidae': '#a1d99b',
    'Potamotrygon': '#b5cf6b',
    'Pristidae': '#de9ed6',
    'Pristiophorus': '#ce6dbd',
    'Rajidae': '#e6550d',
    'Rajiformes': '#fd8d3c',
    'Rhinobatidae': '#8c6d31',
    'Scyliorhinidae': '#843c39',
    'Scyliorhinus': '#ad494a',
    'Selachimorpha': '#979797',
    'Sphyrnidae': '#d9d9d9',
    'Squalidae': '#969696',
    'Squatinidae': '#bdbdbd',
}

In [None]:
# Plot the posterior latent species landings for each exporting country. Only
# species whose posterior-mean latent landings exceed 50 t are shown.
# The posterior mean does not depend on the country, so compute it once.
latent_mean = idata_landings_x.posterior['Latent_landings'].mean(("chain", "draw"))
for counx in country_:
    couny = biggest_countries_long[list(biggest_countries).index(counx)]
    try:
        # Select only species estimated to have some landings (largest first)
        tmp = latent_mean.sel(exporter=counx)
        tmp = tmp.sortby(tmp, ascending=False)
        sppxxx = tmp[tmp>50].species.to_numpy()

        ax = az.plot_forest(idata_landings_x.posterior['Latent_landings'].sel(exporter=counx, species=sppxxx).rename(''),
                            combined=True
                           )
        ax[0].set(
            title="Average latent species landings "+couny,
            xlabel="tonnes",
            ylabel="",
        )
        plt.tight_layout()
        plt.savefig(bf+'/Exporters/'+couny+'_Latent_landings_.jpg',dpi=300)
        plt.close()
    except (ValueError, AttributeError) as err:
        # Best effort per country, but not silent. Close any half-built figure
        # so failed iterations don't accumulate open figures, and report the skip.
        plt.close()
        print(f"Skipping latent landings plot for {couny}: {err}")

    

In [None]:
# For each exporter, compare prior and posterior CountrySPP_landings on a log1p
# axis, and overlay the observed average reported landings as black points.
# The posterior median of latent landings does not depend on the country,
# so compute it once.
latent_median = idata_landings_x.posterior['Latent_landings'].median(("chain", "draw"))
for counx in country_:
    couny = biggest_countries_long[list(biggest_countries).index(counx)]
    try:
        # Species with posterior-median latent landings above 10 t, followed by
        # any observed species that did not pass that cut.
        tmp = latent_median.sel(exporter=counx)
        tmp = tmp.sortby(tmp, ascending=False)
        sppxxx = tmp[tmp>10].species.to_numpy()
        poo = list(sppxxx)
        for s in Obs_spp_data.loc[counx].index.values:
            if s not in sppxxx:
                poo+=[s]
        sppxxx = np.array(poo)
        ax = az.plot_forest([idata_landings_x.prior['CountrySPP_landings'].sel(exporter=counx, species=sppxxx),
                            idata_landings_x.posterior['CountrySPP_landings'].sel(exporter=counx, species=sppxxx)],
                            model_names=["Prior", "Posterior"],
                            transform=np.log1p,
                            figsize=(10, 10)
                           )

        # Overlay observed average reported species landings.
        # - They use the same log1p transform as the forest plot; np.log was used
        #   before, which shifted small values left of the model rows.
        # - Rows follow sppxxx order top to bottom, hence the reversed tick locations.
        rowlist = list(sppxxx)
        xtmp = Obs_spp_data.loc[counx].reset_index()
        yticks = ax[0].yaxis.get_majorticklocs()[::-1]
        # Plot only species that have a row. The old "-1" fallback silently put
        # any missing species on the bottom row.
        plotted = xtmp[xtmp.species.isin(rowlist)]
        rows = np.array([rowlist.index(x) for x in plotted.species], dtype=int)
        ax[0].scatter(np.log1p(plotted.Reported_landings), yticks[rows], c='black', zorder=10)

        ax[0].set(
            title="Average reported species landings "+couny,
            xlabel="log(1 + tonnes)",
            ylabel="",
        )
        plt.savefig(bf+'/Exporters/'+couny+'_Reported_species_landings.jpg',dpi=300)
        plt.close()
    except (ValueError, KeyError) as err:
        # Best effort per country: close any half-built figure and report the skip.
        plt.close()
        print(f"Skipping reported species landings plot for {couny}: {err}")

In [None]:
# For each exporter, compare prior and posterior taxon-level log landings, and
# overlay the observed average reported taxon landings (log scale) as black points.
# The posterior median does not depend on the country, so compute it once.
taxon_log_median = idata_landings_x.posterior['CountryTaxon_log_landings'].median(("chain", "draw"))
for counx in country_:
    couny = biggest_countries_long[list(biggest_countries).index(counx)]
    try:
        # Taxa whose posterior-median log landings exceed 2. Assumes
        # taxon_shortlist is a NumPy array ordered like the taxon coordinate — confirm.
        taxxx = taxon_shortlist[taxon_log_median.sel(exporter=counx)>2]

        ax = az.plot_forest([idata_landings_x.prior['CountryTaxon_log_landings'].sel(exporter=counx, taxon=taxxx).rename(''),
                            idata_landings_x.posterior['CountryTaxon_log_landings'].sel(exporter=counx, taxon=taxxx).rename('')],
                            model_names=["Prior", "Posterior"]
                           )

        # Overlay observed averages for taxa that are both plotted and observed.
        # Rows follow taxxx order top to bottom, hence the reversed tick locations.
        xtmp = np.log(Obs_tax_data.loc[counx].Reported_landings.to_numpy())
        obslist = list(Obs_tax_data.loc[counx].index)
        rowlist = list(taxxx)
        # sorted() gives a deterministic order instead of arbitrary set order.
        setlist = sorted(set(rowlist) & set(obslist))
        # dtype=int keeps an empty overlap a valid (empty) index. A default float
        # array raises IndexError, which the except clause below does not catch.
        xindx = np.array([obslist.index(x) for x in setlist], dtype=int)
        rows = np.array([rowlist.index(x) for x in setlist], dtype=int)
        ytmp = ax[0].yaxis.get_majorticklocs()[::-1][rows]
        ax[0].scatter(xtmp[xindx],ytmp,c='black',zorder=10)

        ax[0].set(
            title="Average reported taxon landings "+couny,
            xlabel="log(tonnes)",
            ylabel="",
        )
        plt.savefig(bf+'/Exporters/'+couny+'_Reported_taxon_landings.jpg',dpi=300)
        plt.close()
    except (ValueError, KeyError) as err:
        # Best effort per country: close any half-built figure and report the skip.
        plt.close()
        print(f"Skipping reported taxon landings plot for {couny}: {err}")

In [None]:
# Scratch check: TaxonMASK_Sx entry for 'IND', summed over axis 0, at the
# position of 'Elasmobranchii' in taxon_shortlist. TaxonMASK_Sx is presumably
# defined by the data-loading script — confirm.
TaxonMASK_Sx[list(country_).index('IND')].sum(0)[list(taxon_shortlist).index('Elasmobranchii')]

In [None]:
# Scratch check: observed average reported taxon landings for country code 'AGO'.
Obs_tax_data.loc['AGO']

In [None]:
# Debug a single country ('AGO'): the same prior/posterior taxon plot as the
# per-country loop, run outside any try/except.
counx = 'AGO'
couny = biggest_countries_long[list(biggest_countries).index(counx)]

# Taxa whose posterior-median log landings exceed 2 (assumes taxon_shortlist is
# ordered like the taxon coordinate — confirm).
taxxx = taxon_shortlist[idata_landings_x.posterior['CountryTaxon_log_landings'].median(("chain", "draw")).sel(exporter=counx)>2]

ax = az.plot_forest([idata_landings_x.prior['CountryTaxon_log_landings'].sel(exporter=counx, taxon=taxxx).rename(''),
                    idata_landings_x.posterior['CountryTaxon_log_landings'].sel(exporter=counx, taxon=taxxx).rename('')],
                    model_names=["Prior", "Posterior"]
                   )

# Overlay observed averages for taxa that are both plotted and observed.
xtmp = np.log(Obs_tax_data.loc[counx].Reported_landings.to_numpy())
obslist = list(Obs_tax_data.loc[counx].index)
rowlist = list(taxxx)
setlist = sorted(set(rowlist) & set(obslist))
# dtype=int so an empty overlap is a valid empty index rather than a float array.
xindx = np.array([obslist.index(x) for x in setlist], dtype=int)
ytmp = ax[0].yaxis.get_majorticklocs()[::-1][np.array([rowlist.index(x) for x in setlist], dtype=int)]
ax[0].scatter(xtmp[xindx],ytmp,c='black',zorder=10)

ax[0].set(
    title="Average reported taxon landings "+couny,
    xlabel="log(tonnes)",
    ylabel="",
)


In [None]:
# Display the taxa selected by the most recently run taxon-selection cell.
taxxx

In [None]:
# Plot taxon decomposition bar charts.
# For each exporter, draw one stacked bar per species. Each bar shows how that
# species' posterior-mean CountryTaxon_SPP_landings splits across aggregated taxa.
# The posterior means do not depend on the country or the cutoff, so compute them once.
taxon_spp_mean = idata_landings_x.posterior["CountryTaxon_SPP_landings"].mean(("chain", "draw"))
spp_mean = idata_landings_x.posterior["CountrySPP_landings"].mean(("chain", "draw"))
for counx in country_:
    couny = biggest_countries_long[list(country_).index(counx)]
    try:
        # Posterior-mean landings for this exporter, as DataFrames.
        taxon_spp_all = (
                taxon_spp_mean
                .assign_coords({"Exporter": ("exporter", biggest_countries_long)})
                .sel(exporter=counx)
                .drop_vars(['exporter'])
            ).to_dataframe()
        spp_all = (
                spp_mean
                .assign_coords({"Exporter": ("exporter", biggest_countries_long)})
                .sel(exporter=counx)
                .drop_vars(['exporter'])-1
            ).to_dataframe()

        # Raise the cutoff by 15% at a time until at most 9 species remain.
        # NOTE(review): the first cutoff applied is 58 t, not 50 t, because the
        # increase happens before the first filter. The 1 t subtracted from
        # species landings only affects this count. Confirm both are intended.
        cut_tmp = 50
        ps_len = 10
        while ps_len>9:
            cut_tmp = round(cut_tmp*1.15)
            PostSpp = spp_all[spp_all.CountrySPP_landings>cut_tmp]
            ps_len = len(PostSpp)
        PostTaxonSpp = taxon_spp_all[taxon_spp_all.CountryTaxon_SPP_landings>cut_tmp]

        # One liner to create a stacked bar chart.
        ax = sns.histplot(PostTaxonSpp, x='species', hue='taxon', weights='CountryTaxon_SPP_landings',
                     multiple='stack', palette=taxon_colours, shrink=0.8)
        ax.set_ylabel('Landings (t)')
        # Rotate with tick_params. Re-setting get_xticklabels() would pin them
        # with a FixedFormatter and trigger a matplotlib warning.
        ax.tick_params(axis='x', labelrotation=90)
        ax.set_title(couny+' unreported landings \n by aggregated taxa')
        # Move the legend outside the bars.
        legend = ax.get_legend()
        legend.set_bbox_to_anchor((1, 1))
        plt.savefig(bf+'/Exporters/'+couny+'_unreported_by_taxa.jpg',dpi=300)
        plt.close()

    except (TypeError, AttributeError) as err:
        # Best effort per country: close the figure and report the skip.
        plt.close()
        print(f"Skipping taxon decomposition for {couny}: {err}")

# Species plots

In [None]:
# Heatmap of posterior-mean latent trade (exporter x importer) for each species.
# The mean over chains and draws does not depend on the species, so compute it
# once rather than once per species.
post_latent_trade_all = (
    idata_trade_x.posterior["amount_exported"]
    .mean(("chain", "draw"))
    .assign_coords({"Exporter": ("exporter", biggest_countries_long),
                    "Importer": ("importer", biggest_countries_long)})
    .drop_vars(['exporter','importer'])
)
for sppx in species_:
    try:
        post_latent_trade = post_latent_trade_all.sel(species=sppx)

        # Hide exact zeros (no modelled flow) so they render as blank cells.
        ds_masked = post_latent_trade.where(post_latent_trade!=0)

        _, ax = plt.subplots(figsize=(23, 18))
        sns.heatmap(
            ds_masked.to_dataframe().drop(["species"], axis="columns")
            .reset_index().drop(['exporter','importer'],axis=1)
            .set_index(['Exporter','Importer']).unstack()
            .droplevel(0, axis=1)
            .T,
            ax=ax,
            cmap=sns.color_palette("Blues", as_cmap=True),
        )
        # Faint grid lines between country cells.
        ax.hlines(range(ncou),xmin=0,xmax=ncou,color='grey',alpha=0.1)
        ax.vlines(range(ncou),ymin=0,ymax=ncou,color='grey',alpha=0.1)

        ax.set_title("Average latent trade\n"+sppx, fontsize=30)
        ax.set_xlabel("Exporter", fontsize=30)
        ax.set_ylabel("Importer", fontsize=30)
        plt.savefig(bf+'/Species/'+sppx+'_latent_trade.jpg',dpi=300)
        plt.close()
    except Exception as err:
        # Best effort per species, but not a bare except: KeyboardInterrupt and
        # SystemExit now propagate, and skipped species are reported.
        plt.close()
        print(f"Skipping latent trade heatmap for {sppx}: {err}")
