# Preamble

In [None]:
# Load IPython's `autoreload` extension into this session.
# NOTE(review): no `%autoreload <mode>` call is visible in this part of the
# notebook -- confirm that a reload mode is actually enabled somewhere.
%load_ext autoreload

In [None]:
import os as _os

# Switch the working directory to the location given by the PROJECT_ROOT
# environment variable (KeyError if it is unset) -- presumably so that the
# relative `data/...` paths used further down resolve; confirm.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Trailing expression: the resolved working directory becomes the cell output.
_os.path.realpath(_os.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
# Notebook-wide plot styling: seaborn's "talk" context and a figure DPI of 50.
sns.set_context("talk")
plt.rcParams.update({"figure.dpi": 50})

In [None]:
def metagenotype_entropy_error(
    world, metagenotype=None, discretized=False, fuzz_eps=1e-5, montecarlo_draws=1
):
    """Compare observed metagenotype entropy with entropy simulated from `world`.

    Expected fractions are computed as ``world.community.data @ genotype`` and
    used as the success probability of a binomial draw whose number of trials
    is the observed total count. The entropy of each simulated metagenotype is
    subtracted from the observed entropy and the differences are averaged over
    ``montecarlo_draws`` draws.

    Parameters
    ----------
    world
        Object providing ``genotype``, ``community`` and ``metagenotype``.
    metagenotype
        Object whose ``.metagenotype`` attribute holds the observed data.
        Defaults to ``world``.
    discretized
        If true, ``world.genotype.discretized().fuzzed(fuzz_eps)`` is used in
        place of ``world.genotype``.
    fuzz_eps
        Argument forwarded to ``fuzzed``; only used when ``discretized`` is
        true.
    montecarlo_draws
        Number of simulated metagenotypes to average over. A value of 0 ends
        in a ``ZeroDivisionError``.

    Returns
    -------
    tuple
        ``(overall, per_entry)`` where ``per_entry`` is the averaged entropy
        difference converted with ``to_series()`` and ``overall`` is the mean
        of those differences weighted by the mean total count over the
        ``position`` dimension.

    Notes
    -----
    ``rvs()`` is called without a ``random_state``, so the draws depend on
    scipy's default random state.
    """
    # NOTE(review): `.metagenotype` is read unconditionally, so the object
    # passed in must expose that attribute (as `world` does) -- confirm that a
    # bare metagenotype object is also accepted.
    # NOTE(review): later cells in this notebook call
    # `sf.evaluation.metagenotype_entropy_error` (with an extra `p` argument)
    # rather than this local definition.
    observed = (world if metagenotype is None else metagenotype).metagenotype

    genotype = world.genotype
    if discretized:
        genotype = genotype.discretized().fuzzed(fuzz_eps)
    genotype_data = genotype.data

    expected_frac = world.community.data @ genotype_data
    totals = observed.total_counts().astype(int)
    mean_totals = totals.mean("position")
    observed_entropy = observed.entropy()

    total_err = 0
    for _ in range(montecarlo_draws):
        # One simulated set of counts with the same totals as the observed data.
        sim_counts = sp.stats.binom(totals, expected_frac).rvs()
        simulated = sf.data.Metagenotype.from_counts_and_totals(
            sim_counts,
            totals,
            coords=dict(sample=observed.sample, position=observed.position),
        )
        total_err += observed_entropy - simulated.entropy()

    mean_err = total_err / montecarlo_draws
    weighted_mean = (mean_err * mean_totals).sum() / mean_totals.sum()

    return weighted_mean.values, mean_err.to_series()

# Data

In [None]:
# Load the World stored for sp-102506 (path is relative to the working directory).
world = sf.data.World.load(
    "data/group/een/species/sp-102506/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
print(world.sizes)

# Seed NumPy's global RNG before the calls below. Statement order is kept
# as-is: presumably the position subsample and the Monte Carlo draws consume
# this global state -- confirm.
np.random.seed(0)

# Random subset of at most 1000 positions, reused by the plotting cells.
position_ss = world.random_sample(
    position=min(1000, world.sizes["position"])
).position
sample_linkage = world.metagenotype.linkage()

# Only element [1] of each evaluation result is kept.
# NOTE(review): presumably these are the per-sample values -- confirm against sfacts.
mgtp_error = sf.evaluation.metagenotype_error2(world, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    world, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = world.community.entropy().to_series()

In [None]:
# Restrict the world to the sampled positions, then apply
# `drop_low_abundance_strains(0.05)`.
w = world.sel(position=position_ss).drop_low_abundance_strains(0.05)

# Linkages computed on the subset; the column linkage (`sample_linkage`) comes
# from the full world's metagenotype so both plots share one column ordering.
w_genotype_linkage = w.genotype.linkage("strain")
w_position_linkage = w.genotype.linkage("position")

# Both plots are annotated with the same three quantities. The community
# entropy key is spelled `centrp`, matching the key used for its thresholded
# counterpart in the later plotting cell (it was `cetrp` here).
sf.plot.plot_metagenotype(
    w,
    col_linkage_func=lambda w: sample_linkage,
    row_linkage_func=lambda w: w_position_linkage,
    col_colors_func=lambda w: xr.Dataset(
        dict(entrp=entrp_error, mgtp=mgtp_error, centrp=comm_entrp)
    ),
)
sf.plot.plot_community(
    w,
    col_linkage_func=lambda w: sample_linkage,
    row_linkage_func=lambda w: w_genotype_linkage,
    col_colors_func=lambda w: xr.Dataset(
        dict(entrp=entrp_error, mgtp=mgtp_error, centrp=comm_entrp)
    ),
)

In [None]:
# Cut-offs used to flag entries with high error / entropy.
# NOTE(review): these look like they correspond to the `clean-m10-e20-c15` tag
# in the filtered world's filename loaded later in the notebook -- confirm.
_MGTP_ERROR_MIN = 0.1
_ENTRP_ERROR_MIN = 0.2
_COMM_ENTRP_MIN = 1.5

# Boolean flags: True where the value is at or above its cut-off.
high_mgtp_error = mgtp_error >= _MGTP_ERROR_MIN
high_entrp_error = entrp_error >= _ENTRP_ERROR_MIN
high_comm_entrp = comm_entrp >= _COMM_ENTRP_MIN

In [None]:
# Same two plots as before, annotated with the thresholded (boolean) flags
# instead of the continuous values.
sf.plot.plot_metagenotype(
    w,
    col_linkage_func=lambda w: sample_linkage,
    row_linkage_func=lambda w: w_position_linkage,
    col_colors_func=lambda w: xr.Dataset(
        {
            "entrp": high_entrp_error,
            "mgtp": high_mgtp_error,
            "centrp": high_comm_entrp,
        }
    ),
)
sf.plot.plot_community(
    w,
    col_linkage_func=lambda w: sample_linkage,
    row_linkage_func=lambda w: w_genotype_linkage,
    col_colors_func=lambda w: xr.Dataset(
        {
            "entrp": high_entrp_error,
            "mgtp": high_mgtp_error,
            "centrp": high_comm_entrp,
        }
    ),
)

In [None]:
# Load the `clean-m10-e20-c15` variant of the world.
# NOTE(review): this path points at species sp-101338, whereas `world` above is
# loaded from sp-102506, yet later cells reuse `position_ss` and `mgtp_error`
# derived from `world` with this object -- confirm the species ID is intended.
world_filt = sf.data.World.load(
    "data/group/een/species/sp-101338/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.clean-m10-e20-c15.world.nc"
)
print(world_filt.sizes)

# Column linkage for the filtered world's plots.
sample_filt_linkage = world_filt.metagenotype.linkage()

In [None]:
# Disabled recomputation of the metagenotype error, kept for reference:
# mgtp_error2 = sf.evaluation.metagenotype_error2(world, discretized=False)[1]

# Entropy error recomputed with p=3 (the earlier cell used p=1).
# NOTE(review): this is evaluated on `world`, not `world_filt`, although the
# result is plotted against `world_filt` in the next cell -- confirm intended.
entrp_error2 = sf.evaluation.metagenotype_entropy_error(
    world, discretized=False, p=3, montecarlo_draws=10
)[1]

In [None]:
# Same subsetting as for the unfiltered world, applied to `world_filt`.
# NOTE(review): `position_ss` was sampled from `world`; confirm every sampled
# position is also present in `world_filt`.
w = world_filt.sel(position=position_ss).drop_low_abundance_strains(0.05)

w_genotype_linkage = w.genotype.linkage("strain")
w_position_linkage = w.genotype.linkage("position")

# Column annotations: `entrp_error2` alongside `mgtp_error`, both of which were
# computed from `world` in earlier cells.
sf.plot.plot_metagenotype(
    w,
    col_linkage_func=lambda w: sample_filt_linkage,
    row_linkage_func=lambda w: w_position_linkage,
    col_colors_func=lambda w: xr.Dataset(
        dict(entrp=entrp_error2, mgtp=mgtp_error)
    ),
)
sf.plot.plot_community(
    w,
    col_linkage_func=lambda w: sample_filt_linkage,
    row_linkage_func=lambda w: w_genotype_linkage,
    col_colors_func=lambda w: xr.Dataset(
        dict(entrp=entrp_error2, mgtp=mgtp_error)
    ),
)