# Geo

1. Plot global sample distribution (colored by study).

## Import Modules

In [None]:
import requests
import json
from io import StringIO
import os
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import zipfile

# Install these!
import descartes
import geopandas
import contextily as ctx
import fiona
import shapely

In [None]:
# Input table (mugration results, TSV) and the directory for geo figures/downloads.
tree_df_path = "../../docs/results/latest/mugration/mugration.tsv"
outdir = "../../docs/results/latest/geo"

# Create the output directory, including any missing parents.
# exist_ok=True makes re-running this cell a no-op instead of an error.
os.makedirs(outdir, exist_ok=True)

### Variables

In [None]:
# Sentinel that replaces every form of missing value in the input table.
# Columns "<GEO_ATTR>Lon" / "<GEO_ATTR>Lat" hold the coordinates for this resolution.
NO_DATA_CHAR, GEO_ATTR = "NA", "Province"

# Coordinate reference systems: WGS84 lon/lat, and Web Mercator for the contextily basemaps.
CRS, WEB_MERCATOR_CRS = "epsg:4326", "epsg:3857"

### Geo Config

In [None]:
# Rectangular regions of interest. "wsen" is a lon/lat bounding box in degrees
# (EPSG:4326), ordered [west, south, east, north].
# NOTE(review): the Caucasus south edge is 30.0009 while later cells use 30 — confirm intended.
region_poly = {
    "Caucasus" : {"wsen": [35.0000, 30.0009, 60.0000, 50.0000]},
    #Caucasus" : {"wsen": [40.058841, 40.202162, -75.042164, -74.924594]},
    "Europe" : {"wsen" : [-10, 30, 35, 70]},
    "Asia" : {"wsen" : [60, 0, 140, 70]},
}

df_columns = ["Region", "Lon", "Lat"]
corner_records = []

for region, region_cfg in region_poly.items():
    west, south, east, north = region_cfg["wsen"]
    # Corner order: sw, se, ne, nw (counter-clockwise), shared by the points and the polygon.
    corners = [(west, south), (east, south), (east, north), (west, north)]

    corner_records.extend([region, lon, lat] for lon, lat in corners)

    region_cfg["poly"] = shapely.geometry.Polygon(corners)
    region_cfg["geoseries"] = geopandas.GeoSeries([region_cfg["poly"]], crs=CRS)
    region_cfg["xlim"] = (west, east)
    region_cfg["ylim"] = (south, north)

# DataFrame.append was removed in pandas 2.0. Build the frame once from records;
# this also gives numeric Lon/Lat columns instead of object dtype.
region_df = pd.DataFrame(corner_records, columns=df_columns)

region_gdf = geopandas.GeoDataFrame(
    region_df,
    geometry=geopandas.points_from_xy(region_df.Lon, region_df.Lat),
    crs=CRS,
)

In [None]:
# Preview the region corner points in Web Mercator (returns a copy; region_gdf is unchanged).
region_gdf.to_crs(WEB_MERCATOR_CRS)

### Global Plot Config

In [None]:
# Resolution used for created/saved figures.
dpi = 400

# Default colours (the first three colours of matplotlib's "tab10" cycle).
DEF_COL_LIST = ["#1f77b4", "#ff7f0e", "#2ca02c"]

# Font sizes
SM_FONT = 5
MED_FONT = 8
LG_FONT = 10

plt.rc('font', size=SM_FONT)          # controls default text sizes
plt.rc('figure', titlesize=LG_FONT)   # fontsize of the figure title (suptitle)
#plt.rc('axes', labelsize=MED_FONT)    # fontsize of the x and y labels

CMAP = "tab10"

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
# supported spelling on both old and new versions.
cmap = plt.get_cmap(CMAP)
cmaplist = [cmap(i) for i in range(cmap.N)]
print(cmaplist)

### Contextily Basemaps

In [None]:
# Reproject the region corner points (a GeoDataFrame) to Web Mercator in place.
region_gdf.to_crs(WEB_MERCATOR_CRS, inplace=True)

In [None]:
# Caucasus bounding box in Web Mercator metres, derived from the region corner points.
# The hard-coded values used previously had east and north swapped.
# Reprojecting here is a no-op if region_gdf is already in Web Mercator, and it keeps
# the cell independent of execution order.
caucasus_merc = region_gdf[region_gdf["Region"] == "Caucasus"].to_crs(WEB_MERCATOR_CRS)
print(caucasus_merc.geometry)

# total_bounds is (minx, miny, maxx, maxy), i.e. (west, south, east, north).
w, s, e, n = caucasus_merc.total_bounds

# The bounds are already projected metres, so ll must be False.
# ll=True would interpret them as lon/lat degrees.
# NOTE(review): Stamen tiles are now served via Stadia Maps and may require an API key.
ctx_img, ctx_ext = ctx.bounds2img(
    w,
    s,
    e,
    n,
    ll=False,
    source=ctx.providers.Stamen.Toner,
    zoom=1,
)
print(ctx_img.shape)
# ctx_ext is (minX, maxX, minY, maxY), the order imshow's `extent` expects.
plt.imshow(ctx_img, extent=ctx_ext);

In [None]:
# Caucasus extent (35-60 E, 30-50 N) given directly in lon/lat degrees, hence ll=True.
# Earlier test extent (Liverpool):
# w, s, e, n = (-3.0816650390625, 53.268087670237485,
#               -2.7582550048828125, 53.486002749115556)
w, s, e, n = (35, 30, 60, 50)

ctx_img, ctx_ext = ctx.bounds2img(
    w,
    s,
    e,
    n,
    ll=True,
    source=ctx.providers.Stamen.Toner,
    zoom=1,
)

plt.imshow(ctx_img, extent=ctx_ext);

# To write the tiles to a GeoTIFF instead:
# out_path = os.path.join(outdir, "liverpool.tif")
# _ = ctx.bounds2raster(w, s, e, n,
#                       ll=True,
#                       path=out_path,
#                       source=ctx.providers.CartoDB.Positron)

---
## Import Dataframe

Read the mugration results table (TSV) into a pandas DataFrame indexed by sample `Name`.

In [None]:
# Load the table, collapse every form of missing value (NaN) into the single
# NO_DATA_CHAR sentinel, and index rows by sample name.
tree_df = (
    pd.read_csv(tree_df_path, sep='\t')
    .fillna(NO_DATA_CHAR)
    .set_index("Name")
)

Convert to a GeoDataFrame of sample points; samples without coordinates are dropped.

In [None]:
# Keep only samples that have BOTH coordinates.
# The original check tested Lon only, so a missing Lat would have produced a bad point.
# TODO: rethink this when internal nodes have data!
lon_col = GEO_ATTR + "Lon"
lat_col = GEO_ATTR + "Lat"
has_coords = (tree_df[lon_col] != NO_DATA_CHAR) & (tree_df[lat_col] != NO_DATA_CHAR)
df = tree_df[has_coords]

# fillna with a string can leave the coordinate columns as object dtype; cast to float.
geometry = geopandas.points_from_xy(df[lon_col].astype(float), df[lat_col].astype(float))
df = df.drop(columns=[lon_col, lat_col])
gdf_points = geopandas.GeoDataFrame(df, crs=CRS, geometry=geometry)

## Import reference datasets

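Count the samples at each location (`GEO_ATTR` value); the counts set the marker sizes in the maps below. The geopandas world dataset itself is loaded in the Global section.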

In [None]:
# Tally how many samples share each location.
gdf_points_count = {}
for geo in gdf_points[GEO_ATTR]:
    gdf_points_count[geo] = gdf_points_count.get(geo, 0) + 1

# One size per point: the number of samples at that point's location.
gdf_points_size = [gdf_points_count[location] for location in gdf_points[GEO_ATTR]]

## Global

In [None]:
# Natural Earth low-resolution country polygons.
# geopandas.datasets was removed in GeoPandas 1.0, so fall back to the Natural Earth
# download there (network access required).
NATURAL_EARTH_COUNTRIES_URL = (
    "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
)
try:
    world_path = geopandas.datasets.get_path('naturalearth_lowres')
except (AttributeError, ValueError):
    world_path = NATURAL_EARTH_COUNTRIES_URL
world_polygons = geopandas.read_file(world_path)

fig, ax1 = plt.subplots(1, dpi=dpi)

# Basemap: country polygons
world_polygons.plot(ax=ax1,
                    zorder=1,
                    alpha=0.75,
                    color=DEF_COL_LIST[0],
                    edgecolor="white",
                    linewidth=0.25)

# Points: marker size = number of samples at that location
gdf_points.plot(ax=ax1,
                zorder=2,
                markersize=gdf_points_size,
                color=DEF_COL_LIST[1],
                edgecolor="black",
                linewidths=0.5,
                )

# Region-of-interest outlines (unfilled)
for region in region_poly:
    region_poly[region]["geoseries"].plot(color="none", edgecolor="black", ax=ax1)

# Crop the far south (presumably to hide Antarctica).
ax1.set_ylim(bottom=-63)
ax1.set_xlabel("Longitude")
ax1.set_ylabel("Latitude")
ax1.set_title("Global sampling distribution")

# Save through the figure handle rather than pyplot's "current figure".
out_path = os.path.join(outdir, "world_sample_distribution.jpg")
fig.savefig(out_path,
            dpi=dpi,
            bbox_inches="tight")

## Caucasus

In [None]:
caucasus = region_poly["Caucasus"]

fig, ax1 = plt.subplots()

gdf_points.plot(
    ax=ax1,
    markersize=gdf_points_size,
    color=DEF_COL_LIST[1],
    edgecolor="black",
    linewidths=0.5,
)

# Zoom to the region BEFORE adding the basemap. add_basemap fetches tiles for the
# axes' current extent, so setting limits afterwards would download and warp tiles
# for the whole point extent and then crop them.
ax1.set_xlim(caucasus["xlim"])
ax1.set_ylim(caucasus["ylim"])

# gdf_points is lon/lat (EPSG:4326); passing its CRS makes contextily reproject the tiles.
ctx.add_basemap(
    ax1,
    crs=gdf_points.crs.to_string(),
    #source=ctx.providers.Stamen.TonerLite,
    source=ctx.providers.CartoDB.Positron,
    zoom=3,
)

ax1.set_title("Caucasus sampling distribution")
ax1.set_axis_off();

In [None]:
region_xlim = (35, 60)
region_ylim = (30, 50)

fig, ax1 = plt.subplots(1,
                        dpi=dpi
                       )

# Marker sizes come from gdf_points_size, computed in the counting cell. Do NOT reset
# gdf_points_count here: doing so silently discarded those counts.
world_polygons.plot(ax=ax1,
                    zorder=1,
                    alpha=0.75,
                    color=DEF_COL_LIST[0],
                    edgecolor="white",
                    linewidth=0.25)
gdf_points.plot(ax=ax1,
                zorder=2,
                markersize=gdf_points_size,
                color=DEF_COL_LIST[1],
                edgecolor="black",
                linewidths=0.5,
                )

ax1.set_xlim(region_xlim)
ax1.set_ylim(region_ylim)
ax1.set_title("Caucasus sampling distribution");
#plt.axis('off')


In [None]:
# Disabled example: samples over a terrain basemap. Kept as comments so the cell does
# not echo a string as output.
# NOTE(review): it plots all gdf_points although db is filtered to Armenia — confirm intent.
# db = gdf_points[gdf_points["Country"] == "Armenia"]
# ax = gdf_points.plot(color="red", figsize=(9, 9))
# ctx.add_basemap(
#     ax,
#     crs=db.crs.to_string(),
#     source=ctx.providers.Stamen.TerrainBackground,
# )

### Europe

---
## OLD

### GADM Country Datasets

In [None]:
# ISO3 country codes whose GADM 3.6 GeoPackages are downloaded and plotted.
country_codes = ["ARM", "AZE", "GEO"]

GADM_URL_TEMPLATE = "https://biogeo.ucdavis.edu/data/gadm3.6/gpkg/gadm36_{}_gpkg.zip"
DOWNLOAD_CHUNK_SIZE = 1024 * 1024  # bytes written per chunk (was 128, very slow)
DOWNLOAD_TIMEOUT = 60  # seconds; without a timeout requests can hang forever

fig, ax1 = plt.subplots(1, dpi=dpi)

for i_code, code in enumerate(country_codes):
    zip_file_url = GADM_URL_TEMPLATE.format(code)
    zip_filename = os.path.basename(zip_file_url)
    zip_path = os.path.join(outdir, zip_filename)
    gpkg_filename = zip_filename.replace("_gpkg.zip", ".gpkg")
    gpkg_path = os.path.join(outdir, gpkg_filename)
    # Layer "<gpkg name>_1" — presumably GADM administrative level 1; confirm.
    target_layer = os.path.splitext(gpkg_filename)[0] + "_1"

    # Download and extract only when the GeoPackage is not cached yet.
    if not os.path.exists(gpkg_path):
        # The context manager closes the streamed connection; raise_for_status avoids
        # writing an HTML error page to disk as if it were a zip.
        with requests.get(zip_file_url, stream=True, timeout=DOWNLOAD_TIMEOUT) as result:
            result.raise_for_status()
            with open(zip_path, 'wb') as fd:
                for chunk in result.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    fd.write(chunk)

        with zipfile.ZipFile(zip_path, 'r') as zipObj:
            zipObj.extract(gpkg_filename, path=outdir)

    if target_layer in fiona.listlayers(gpkg_path):
        geopkg = geopandas.read_file(gpkg_path, layer=target_layer)
        geopkg.plot(ax=ax1,
                    zorder=1,
                    alpha=0.50,
                    color=cmaplist[i_code],
                    edgecolor="black",
                    linewidth=0.25,
                    )

ax1.set_title("GADM boundaries: " + ", ".join(country_codes));

In [None]:
# "RUS" is not in country_codes, so the download cell never creates this file.
# Skip with a message instead of failing on a fresh run.
rus_gpkg_path = os.path.join(outdir, "gadm36_RUS.gpkg")
if os.path.exists(rus_gpkg_path):
    # NOTE(review): no layer is given, so the first layer is read — confirm it has NAME_2.
    rus_gdf = geopandas.read_file(rus_gpkg_path)
    for name in rus_gdf.NAME_2:
        print(name)
else:
    print("Skipping: {} not found (RUS is not in country_codes).".format(rus_gpkg_path))

In [None]:
# Position of Armenia in country_codes (it selects the colour from cmaplist).
arm_position = country_codes.index("ARM")
print(arm_position)

### Plot with the world basemap

In [None]:
# Bundled geopandas sample datasets: country polygons and city points.
world = geopandas.read_file(geopandas.datasets.get_path('naturalearth_lowres'))
cities = geopandas.read_file(geopandas.datasets.get_path('naturalearth_cities'))

fig, ax = plt.subplots()
# Countries underneath (zorder=1); black city markers drawn above them (zorder=2).
world.plot(ax=ax, zorder=1)
cities.plot(color='k', ax=ax, zorder=2)

# Zoom to the Caucasus.
ax.set_xlim(40, 55)
ax.set_ylim(30, 50)

## TEST

# BACKUP

### How much phylogeographic signal is present?

In [None]:
# BACKUP (disabled): this whole cell is a bare string literal, so none of it runs.
# Intended logic: for every pair of distinct terminals that both have lat/lon, record
# the tree distance (tree_div.distance) and the great-circle distance, printing progress.
# NOTE(review): it references names not defined in this notebook (tree_div,
# ATTRIBUTE_LAT, ATTRIBUTE_LON, great_circle, DIST_UNIT). Define/import them before
# re-enabling.
"""# Store terminals in two lists for comparison
term_list = [t for t in tree_div.get_terminals()]
term_list.reverse()

dist_div_list = []
dist_geo_list = []

i = 0
progress_log_breaks = [num for num in range(0,101)]
i_progress = 0

for t1 in term_list:
    
    t1_name = t1.name
    t1_lat = tree_df[ATTRIBUTE_LAT][t1.name]
    t1_lon = tree_df[ATTRIBUTE_LON][t1.name]
    t1_latlon = (t1_lat, t1_lon)
    
    # Skip if natlon is nan
    if t1_lat == "NA" or t1_lon == "NA":
        continue
        
    # Iterate through terminals again for distance measures
    for t2 in term_list:           
        t2_name = t2.name     
        # Skip if it's a self comparison
        if t1_name == t2_name: continue
            
        t2_lat = tree_df[ATTRIBUTE_LAT][t2_name]
        t2_lon = tree_df[ATTRIBUTE_LON][t2_name]
        
        # Skip if latlon is nan
        if t2_lat == "NA" or t2_lon == "NA":
            continue
        t2_latlon = (t2_lat, t2_lon)
       
        # Calculate branch distance (divergence)
        dist_div = tree_div.distance(t1_name, t2_name)  
        # Calculate geographic distance (great circle)
        dist_geo = getattr(great_circle(t1_latlon, t2_latlon), DIST_UNIT)
        dist_div_list.append(dist_div)
        dist_geo_list.append(dist_geo)
    
    # Counter for a progress log since this is slow
    i+=1
    progress = (i / len(term_list)) * 100
    if progress >= progress_log_breaks[i_progress]:
        print("{:0.0f}%".format(progress))
        i_progress += 1
    #if i ==5:
    #    break"""

In [None]:
# BACKUP (disabled): this whole cell is a bare string literal, so none of it runs.
# Intended logic: a seaborn JointGrid scatter of pairwise genetic vs. geographic distance
# with marginal KDEs, saved to outdir as "treetime_<attribute>_IBD.jpg".
# NOTE(review): it depends on sns, dist_div_list / dist_geo_list and ATTRIBUTE, none of
# which are defined or imported in this notebook.
"""
# Create a joint plot

g = sns.JointGrid()
g.fig.set_dpi(dpi)

# Main Scatter
#sns.regplot(
sns.scatterplot(
                x=dist_div_list,
                y=dist_geo_list,  
                s=1,
                #scatter_kws = {"s": 0.5}, 
                ax=g.ax_joint,
                #line_kws={"color": "red"}
               )

sns.kdeplot(x=dist_div_list, 
            linewidth=1,
            fill=True,
            ax=g.ax_marg_x,
            #bw_method=0.03,
           )

sns.kdeplot(y=dist_geo_list, 
            linewidth=1,
            fill=True,
            ax=g.ax_marg_y,
            bw_method=0.03,
           )

# Constrain view and limits
g.ax_marg_y.set_ylim(0-500,int(max(dist_geo_list)) + 500)

# Labels and Titles
g.ax_joint.set_xlabel("Genetic Distance (Branch Length)")
g.ax_joint.set_ylabel("Geographic Distance (km)")
g.fig.suptitle("Phylogeographic Signal as Isolation by Distance",
                x=0.5,
                y=1.05)

# Caption
caption_text = ("Pairwise genetic distance (branch length) compared with pariwise geographic distance (greater circle)." 
                + "\nGeographic Resolution: {}".format(ATTRIBUTE)
               )
g.fig.text(0.4, -0.05, caption_text, ha='center')

# Save
out_path = os.path.join(outdir, "treetime_{}_IBD.jpg".format(ATTRIBUTE.lower())) 
plt.savefig(out_path, 
            dpi=dpi, 
            bbox_inches = "tight")"""