# TO DO
- common interface for each type of model, same trainer and streamline of result output/analysis

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import mcspace.visualization as vis
from mcspace.dataset import PerturbationDataSet
from mcspace.utils import pickle_load
from pathlib import Path
from mcspace.models import PerturbationModel
import torch
from mcspace.data_utils import get_normed_data_garb_clusters
from mcspace.trainer import train
# TODO: add model and vis of model reults; as well as post processing

# Load data

In [2]:
ls "../data/FMT_MAPSEQ_data/processed_data"

 Volume in drive C is OS
 Volume Serial Number is 7CF0-0838

 Directory of C:\Users\guppa\Dropbox (Partners HealthCare)\research_bwh\mapseq_topic_model_JULY_2022\MCSPACE_model\mcspace\data\FMT_MAPSEQ_data\processed_data

08/05/2023  06:46 PM    <DIR>          .
08/05/2023  02:55 PM    <DIR>          ..
08/05/2023  02:49 PM         7,652,385 env_data.tsv
08/05/2023  02:54 PM        14,639,349 env2jax_data.tsv
08/05/2023  06:46 PM        34,620,160 filtered_dataset.pkl
08/05/2023  02:47 PM         7,331,013 jax_data.tsv
               4 File(s)     64,242,907 bytes
               2 Dir(s)  47,192,596,480 bytes free


In [3]:
datapath = Path("../data/FMT_MAPSEQ_data/processed_data")

In [4]:
dataset = pickle_load(datapath / "filtered_dataset.pkl")

In [5]:
dataset.describe()

3 groups: pre-perturb, post-perturb, comparator
165 OTUs in study
4 subjects per group
stats for pre_perturb group:
	 Subject J1:
		 353 particles
		 min read depth: 2218
		 median read depth: 4245.0
		 max read depth: 122311
	 Subject J2:
		 353 particles
		 min read depth: 2396
		 median read depth: 4781.0
		 max read depth: 678320
	 Subject J3:
		 336 particles
		 min read depth: 1734
		 median read depth: 2980.0
		 max read depth: 242450
	 Subject J4:
		 229 particles
		 min read depth: 777
		 median read depth: 1250.0
		 max read depth: 245585
	 1271 particles for group pre_perturb


stats for post_perturb group:
	 Subject JE10:
		 417 particles
		 min read depth: 1416
		 median read depth: 2903.0
		 max read depth: 77217
	 Subject JE11:
		 384 particles
		 min read depth: 1246
		 median read depth: 2842.0
		 max read depth: 71299
	 Subject JE12:
		 344 particles
		 min read depth: 1434
		 median read depth: 3229.0
		 max read depth: 273209
	 Subject JE9:
		 321 particles
		 min r

# Load trainer and run inference

In [6]:
# load model and train

In [7]:
# get from dataset
num_otus = 165
num_subjects = 4

In [8]:
# TODO: put in utils, only have dataset as input
def estimate_group_variances(data, notus):
    groups = data.keys()
    xvar = {}

    for grp in groups:
        grpdata = data[grp]
        nsubj = len(list(grpdata.keys()))
        sdata = np.zeros((notus, nsubj))
        for i,s in enumerate(grpdata.keys()):
            counts = grpdata[s]
            ra = counts.sum(axis=0)/counts.sum()
            sdata[:,i] = np.log(ra + 1e-20)

        #* see outliers > 50; filter out and take median
        if nsubj < 3:
            svarmed = 0.1
        else:
            svar = np.var(sdata, axis=1)
            svarmed = np.median(svar)
        xvar[grp] = svarmed
    return xvar 

In [9]:
data = dataset.data

In [10]:
subject_variance = estimate_group_variances(data, num_otus)


In [11]:
# data

In [12]:
device = torch.device("cpu") #cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
datatrain, garbclusts = get_normed_data_garb_clusters(data, device)

In [14]:
model = PerturbationModel(num_communities=20,
                         num_otus=num_otus,
                         num_subjects=num_subjects,
                         subject_variance=subject_variance,
                         device=torch.device('cpu')
                         )

In [15]:
def get_data(reads_in, device):
    reads = torch.from_numpy(reads_in).to(torch.float)
    norm = torch.sum(reads, dim=1)
    rel_data = torch.div(reads, norm.unsqueeze(1))
    z_data = torch.log(rel_data+0.0001)
    z_std, z_mean = torch.std_mean(z_data, dim=1)
    z_data = z_data - z_mean.unsqueeze(1)
    z_data = torch.div(z_data, z_std.unsqueeze(1))

    if z_data.isnan().any():
        raise ValueError("nan in normed data")

    return {'count_data': reads.to(device), 'normed_data': z_data.to(device)}

In [16]:
def get_normed_reads_counts_combine_reps_group_garb(reads, device):
    # reads is a dict [t] of dicts [s]
    counts = {}
    garbage_clusters = {}
    normed_data = {}
    full_normed_data = [] #* output L* x O; for all particles concatenated together...

    for t in reads.keys():
        subjs = reads[t].keys()
        counts[t] = {}
        all_particles = None
        # garbage_clusters[t] = {}
        normed_data[t] = {}
        for s in subjs:
            reps = reads[t][s].keys()
            combined_reads = reads[t][s]
            counts[t][s] = torch.from_numpy(reads[t][s]).to(dtype=torch.float, device=device)
            data = get_data(combined_reads, device)
            normed_data[t][s] = data['normed_data']
            full_normed_data.append(data['normed_data'])

        bulk = all_particles.sum(axis=0)/all_particles.sum()
        garbage_clusters[t] = torch.from_numpy(bulk).to(dtype=torch.float, device=device)
    combined_data = torch.cat(full_normed_data, dim=0)
    return {'count_data': counts, 'normed_data': normed_data, 'full_normed_data': combined_data}, garbage_clusters


## train

In [17]:
dataset

<mcspace.dataset.PerturbationDataSet at 0x2a96774d460>

In [18]:
num_epochs = 100

In [19]:
losses = train(model, dataset, num_epochs, verbose=True)

TypeError: 'PerturbationDataSet' object is not subscriptable

# Post process runs, evaulate stability metric

In [None]:
# no stability metric; just need bayes factors!!

# Create figures

In [None]:
# would like tree methods to work here too...
# could use a 'preprocessed tree' -- we're not giving a tutorial on pplacer