In [None]:
# Set up environment

#malariagen
import malariagen_data
# base
import numpy as np
import pandas as pd
import pkg_resources
from sklearn.impute import SimpleImputer
from pandas_plink import read_plink

# viz
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# feems
from feems.utils import prepare_graph_inputs
from feems import SpatialGraph, Viz

# Use Arial for figure text: set both the font family and the sans-serif entry.
for rc_key in ("font.family", "font.sans-serif"):
    plt.rcParams[rc_key] = "Arial"


# Identifiers of the sample sets requested from the malariagen_data API below.
sample_sets = [
    '1229-VO-GH-DADZIE-VMF00095',
    '1230-VO-GA-CF-AYALA-VMF00045',
    '1232-VO-KE-OCHOMO-VMF00044',
    '1235-VO-MZ-PAAIJMANS-VMF00094',
    '1236-VO-TZ-OKUMU-VMF00090',
    '1236-VO-TZ-OKUMU-VMF00261',
    '1236-VO-TZ-OKUMU-VMF00248',
    '1236-VO-TZ-OKUMU-VMF00252',
    '1236-VO-TZ-OKUMU-OKFR-TZ-2008',
    '1240-VO-CD-KOEKEMOER-VMF00099',
    '1240-VO-MZ-KOEKEMOER-VMF00101',
    'small-2020-af',
    'AG1000G-KE',
    'AG1000G-MW',
    'AG1000G-TZ',
]


In [None]:
# Load the malariagen_data API, fetch sample metadata and export SNPs for feems.

af1 = malariagen_data.Af1(pre=True)

# Directory holding the feems inputs/outputs for this analysis.
results_dir = '/Users/dennistpw/Projects/funestus_tz/feems/'

# Metadata for the funestus samples east of longitude 23 in the chosen sample sets;
# the longitude/latitude columns of this frame are used as sample coordinates later.
df_samples = af1.sample_metadata(
    sample_query='longitude>23 & taxon == "funestus"',
    sample_sets=sample_sets,
)
df_samples.to_csv(results_dir + 'df_samples.csv')

# Save biallelic SNPs in PLINK format for the same sample selection.
# NOTE(review): the files are written to the current working directory ('.'),
# while the read cell loads the PLINK prefix from results_dir - confirm the
# notebook is run from that directory or move the files there.
af1.biallelic_snps_to_plink(
    region='3RL:6000000-9000000',
    n_snps=100_000,
    sample_query='taxon == "funestus" & longitude > 23',
    sample_sets=sample_sets,
    output_dir='.',
    min_minor_ac=2,
    max_missing_an=1,
)

In [None]:
# Read the PLINK fileset; G is the genotype array (presumably variants x samples,
# since it is transposed before use - confirm against the pandas_plink docs).
(bim, fam, G) = read_plink("/Users/dennistpw/Projects/funestus_tz/feems/3RL-6000000-9000000.100000.2.1.0")

# Materialise the genotypes as a NumPy array with samples on the rows.
genotype_matrix = np.array(G).T

# Replace missing calls (NaN) with the per-column mean.
imp = SimpleImputer(missing_values=np.nan, strategy="mean")
genotypes = imp.fit_transform(genotype_matrix)

print(f"n_samples={genotypes.shape[0]}, n_snps={genotypes.shape[1]}")

In [None]:
# Sample coordinates, one (longitude, latitude) row per sample in df_samples.
coords = df_samples[['longitude', 'latitude']].to_numpy()

# Outline of the study region read from CSV.
# NOTE(review): this array is NOT passed to prepare_graph_inputs below, and the
# name `outer` is overwritten by that call's first return value.
outline_csv = '/Users/dennistpw/Projects/funestus_tz/feems/riftregionoutline.csv'
outer = pd.read_csv(outline_csv).to_numpy()

# Path to the discrete global grid shapefile.
grid_path = "/Users/dennistpw/Projects/funestus_tz/feems/riftregionoutline.TRI.75K.shp"

# Build the graph input arrays (outline, edges, grid nodes) from the sample
# coordinates and the grid file; the fourth return value is discarded.
outer, edges, grid, _ = prepare_graph_inputs(
    coord=coords,
    ggrid=grid_path,
    translated=False,
    buffer=0,
)

In [None]:
# Build the feems spatial graph from the imputed genotypes, the sample
# coordinates and the grid nodes/edges prepared above (SNP scaling enabled).
sp_graph = SpatialGraph(
    genotypes,
    coords,
    grid,
    edges,
    scale_snps=True,
)

In [None]:
from matplotlib.axes import Axes
from cartopy.mpl.geoaxes import GeoAxes

# Point GeoAxes' patched pcolormesh hook at the plain matplotlib implementation.
# NOTE(review): presumably a cartopy/matplotlib compatibility workaround - confirm
# it is still needed with the installed versions.
GeoAxes._pcolormesh_patched = Axes.pcolormesh

# Map projection shared by the single-panel figures in this notebook.
projection = ccrs.EquidistantConic(
    central_longitude=30.10921958480732,
    central_latitude=15.874172562000085,
)

fig = plt.figure(dpi=300)
ax = fig.add_subplot(1, 1, 1, projection=projection)

# Styling for the graph overview (before any fit, so edges are drawn unweighted).
viz_kwargs = dict(
    projection=projection,
    edge_width=.5,
    edge_alpha=1,
    edge_zorder=100,
    sample_pt_size=10,
    obs_node_size=7.5,
    sample_pt_color="black",
    cbar_font_size=10,
)
v = Viz(ax, sp_graph, **viz_kwargs)

# These set private matplotlib attributes, as in the original cell.
ax._autoscaleXon = False
ax._autoscaleYon = False

# Draw the base map, the samples, the graph edges and the observed nodes.
v.draw_map()
v.draw_samples()
v.draw_edges(use_weights=False)
v.draw_obs_nodes(use_ids=False)

In [None]:
# Fit the model on the spatial graph with a fixed lamb of 20.0
# (a cross-validated lamb is selected in a later cell).
sp_graph.fit(lamb=20.0)

In [None]:
from feems import SpatialGraph, Viz

# Plot the fit produced by the previous cell (sp_graph.fit is not re-run here).
fig = plt.figure(dpi=300)
ax = fig.add_subplot(1, 1, 1, projection=projection)

# These set private matplotlib attributes, as in the original cell.
ax._autoscaleXon = False
ax._autoscaleYon = False

viz_kwargs = dict(
    projection=projection,
    edge_width=.5,
    edge_alpha=1,
    edge_zorder=100,
    sample_pt_size=20,
    obs_node_size=7.5,
    sample_pt_color="black",
    cbar_font_size=10,
)
v = Viz(ax, sp_graph, **viz_kwargs)

# Base map, samples, weighted edges and observed nodes.
# The edge colorbar is intentionally not drawn on this figure.
v.draw_map()
v.draw_samples()
v.draw_edges(use_weights=True)
v.draw_obs_nodes(use_ids=False)

In [None]:
from feems.cross_validation import run_cv

# Candidate lamb values: 20 log-spaced points between 1e-6 and 1e2, ordered from
# largest to smallest (reversed for warm starts).
lamb_grid_ascending = np.geomspace(1e-6, 1e2, 20)
lamb_grid = lamb_grid_ascending[::-1]

# Cross-validate with one fold per observed node.
cv_err = run_cv(
    sp_graph,
    lamb_grid,
    n_folds=sp_graph.n_observed_nodes,
    factr=1e10,
)

# Average the error over folds (axis 0).
mean_cv_err = np.mean(cv_err, axis=0)

# lamb with the smallest mean CV error.
best_lamb_index = np.argmin(mean_cv_err)
lamb_cv = float(lamb_grid[best_lamb_index])



In [None]:
# Mean CV error against log10(lamb), with the selected lamb marked in orange.
fig, ax = plt.subplots(dpi=300)
log10_lamb_grid = np.log10(lamb_grid)
ax.plot(log10_lamb_grid, mean_cv_err, ".")
ax.set_xlabel("log10(lambda)")
ax.set_ylabel("L2 CV Error")
ax.axvline(np.log10(lamb_cv), color="orange")

In [None]:

from matplotlib import gridspec

# figure params
# Use the same projection centre as the other maps in this notebook. The previous
# centre (-108.842926, 66.037547) did not match the study region (samples are
# filtered to longitude > 23) and also overwrote the global `projection`.
projection = ccrs.EquidistantConic(central_longitude=30.10921958480732, central_latitude=15.874172562000085)
title_loc = "left"
title_pad = "-10"
title_fontsize = 12
edge_width = .2
edge_alpha = 1
edge_zorder = 3  # not used below: the panels keep the literal edge_zorder=100
obs_node_size = 3
obs_node_linewidth = .4
cbar_font_size = 8
cbar_ticklabelsize = 8
cbar_orientation = "horizontal"

# figure setup
fig = plt.figure(dpi=300)
spec = gridspec.GridSpec(ncols=2, nrows=2, figure=fig, wspace=0.0, hspace=0.0)


def _draw_lambda_panel(figure, subplot_spec, panel_label, lamb, cv_error, with_colorbar=False):
    """Fit sp_graph at one lamb value and draw the weighted graph in one panel.

    Args:
        figure: matplotlib figure that receives the new subplot.
        subplot_spec: GridSpec cell in which the panel is placed.
        panel_label: title text shown at the top-left of the panel (e.g. "A").
        lamb: value passed (as float) to sp_graph.fit and shown in the annotation.
        cv_error: mean CV error shown in the annotation.
        with_colorbar: when True, also draw the horizontal edge colorbar.

    Returns:
        Tuple (axes, viz) for the created panel.

    Side effects:
        Re-fits the global sp_graph, so the graph holds the fit for `lamb`
        after the call.
    """
    panel_ax = figure.add_subplot(subplot_spec, projection=projection)
    # These set private matplotlib attributes, as in the original cell.
    panel_ax._autoscaleXon = False
    panel_ax._autoscaleYon = False
    panel_ax.set_title(panel_label, loc=title_loc, pad=title_pad, fontdict={"fontsize": title_fontsize})

    sp_graph.fit(float(lamb))
    panel_viz = Viz(panel_ax, sp_graph, projection=projection, edge_width=edge_width,
                    edge_alpha=1, edge_zorder=100, sample_pt_size=20,
                    obs_node_size=obs_node_size, sample_pt_color="black",
                    cbar_font_size=10)
    panel_viz.draw_map()
    panel_viz.draw_edges(use_weights=True)
    panel_viz.draw_obs_nodes(use_ids=False)

    if with_colorbar:
        panel_viz.cbar_font_size = cbar_font_size
        panel_viz.cbar_orientation = cbar_orientation
        panel_viz.cbar_ticklabelsize = cbar_ticklabelsize
        panel_viz.draw_edge_colorbar()

    panel_ax.text(.2, .85, "lambda={:.5f}\ncv l2 error={:.5f}".format(lamb, cv_error),
                  fontdict={"fontsize": 4}, transform=panel_ax.transAxes)
    return panel_ax, panel_viz


# Panels are drawn in the original order (A, B, C, D); each call re-fits sp_graph.
# axis 00: first (largest) lamb on the grid
ax_00, v = _draw_lambda_panel(fig, spec[0, 0], "A", lamb_grid[0], mean_cv_err[0, 0])

# axis 10: lamb_grid[3]
ax_10, v = _draw_lambda_panel(fig, spec[1, 0], "B", lamb_grid[3], mean_cv_err[3, 0])

# axis 01: cross-validated lamb
ax_01, v = _draw_lambda_panel(fig, spec[0, 1], "C", lamb_cv, mean_cv_err[np.argmin(mean_cv_err), 0])

# axis 11: lamb_grid[10], with the shared edge colorbar
ax_11, v = _draw_lambda_panel(fig, spec[1, 1], "D", lamb_grid[10], mean_cv_err[10, 0], with_colorbar=True)

In [None]:
# Re-fit at the cross-validated lamb and restore the notebook's map projection.
sp_graph.fit(lamb_cv)
projection = ccrs.EquidistantConic(
    central_longitude=30.10921958480732,
    central_latitude=15.874172562000085,
)

# Single-panel figure of the final fit.
fig = plt.figure(dpi=300)
ax = fig.add_subplot(1, 1, 1, projection=projection)

# These set private matplotlib attributes, as in the original cell.
ax._autoscaleXon = False
ax._autoscaleYon = False

viz_kwargs = dict(
    projection=projection,
    edge_width=.5,
    edge_alpha=1,
    edge_zorder=100,
    sample_pt_size=20,
    obs_node_size=7.5,
    sample_pt_color="black",
    cbar_font_size=10,
)
v = Viz(ax, sp_graph, **viz_kwargs)

# Base map, weighted edges, observed nodes and the edge colorbar.
v.draw_map()
v.draw_edges(use_weights=True)
v.draw_obs_nodes(use_ids=False)
v.draw_edge_colorbar()

In [None]:
# Node table: one row per node holding its position followed by its 'n_samples'
# attribute, written as CSV.
node_sample_counts = [sp_graph.nodes[n]['n_samples'] for n in range(len(sp_graph.nodes))]
node_table = np.vstack((sp_graph.node_pos.T, node_sample_counts)).T
np.savetxt('rift_incksanodepos.csv', node_table, delimiter=',')

# Edge table: one row per edge holding its two endpoints followed by the fitted
# weight from sp_graph.w, written as CSV.
edge_endpoints = np.array(sp_graph.edges).T
edge_table = np.vstack((edge_endpoints, sp_graph.w)).T
np.savetxt('rift_incksaedgew.csv', edge_table, delimiter=',')