In [1]:
import pandas as pd
import numpy as np
from pathlib import Path

In [2]:
ls

 Volume in drive C is OS
 Volume Serial Number is 7CF0-0838

 Directory of C:\Users\guppa\Dropbox (Partners HealthCare)\research_bwh\dynamical_rules\RuleBasedDynamics\dynrules

07/20/2023  08:35 AM    <DIR>          .
06/25/2023  08:06 PM    <DIR>          ..
07/20/2023  08:35 AM    <DIR>          .ipynb_checkpoints
06/23/2023  04:44 PM                 0 __init__.py
07/18/2023  12:36 PM    <DIR>          __pycache__
06/29/2023  10:46 AM             7,486 analysis.py
07/18/2023  11:25 AM             7,862 analysis_glv.py
06/30/2023  02:25 PM             6,760 analysis2.py
07/20/2023  08:03 AM    <DIR>          CDIFF_DATA
07/20/2023  08:34 AM             2,791 data.py
06/25/2023  08:00 PM             2,696 data_old.py
07/20/2023  08:35 AM                72 dev_methods.ipynb
06/25/2023  07:37 PM    <DIR>          docs
07/18/2023  11:18 AM    <DIR>          experiments
07/18/2023  12:37 PM             7,688 glv_grad_match_model.py
06/27/2023  08:07 PM            12,684 gradient_matching_mo

In [4]:
# Directory holding the processed input tables (relative to the working directory).
cdiff_path = Path("./CDIFF_DATA/processed_data")

In [5]:
# Read the tab-separated reads table; its first column becomes the row index.
counts = pd.read_csv(
    cdiff_path.joinpath("reads.tsv"),
    sep="\t",
    index_col=0,
)

In [6]:
# Read the tab-separated taxonomy table; its first column becomes the row index.
taxonomy = pd.read_csv(
    cdiff_path.joinpath("taxonomy.tsv"),
    sep="\t",
    index_col=0,
)

In [7]:
# Read the tab-separated sample metadata; its first column becomes the row index.
metadata = pd.read_csv(
    cdiff_path.joinpath("meta.tsv"),
    sep="\t",
    index_col=0,
)

In [8]:
# Read the tab-separated qPCR table; its first column becomes the row index.
qpcr = pd.read_csv(
    cdiff_path.joinpath("qpcr.tsv"),
    sep="\t",
    index_col=0,
)

In [9]:
# Read the tab-separated per-taxon metadata; its first column becomes the row index.
taxmeta = pd.read_csv(
    cdiff_path.joinpath("taxa_meta.tsv"),
    sep="\t",
    index_col=0,
)

In [13]:
# what about rescaling and filtering?

# todo: add consistency filter -- do tonight along with phage data
# add other datasets too; run tests to compare models**

# removes otus -- finds tags to keep and edits internal data accordingly
# needs to remove rows from counts, taxonomy, and taxmeta
def top_abundance_filter(notus):
    """Placeholder stub: not implemented, ignores ``notus`` and returns None."""
    return None

In [10]:
# this will be a method of dataset class
def get_data_for_inference():
    """Placeholder stub: not implemented, does nothing and returns None."""
    return None

In [126]:
class Dataset:
    """
    Container for count, qPCR and metadata tables read from tab-separated files.

    Attributes
    ----------
    counts : DataFrame of read counts; rows are taxa, columns are sample IDs
    qpcr : DataFrame of qPCR values; rows are sample IDs, columns are replicates
    metadata : DataFrame with one row per sample and (at least) the columns
        "sampleID", "subjectID" and "time"
    taxmeta : per-taxon DataFrame with (at least) the columns "time" and "type"
    taxonomy : per-taxon DataFrame; its index defines the taxa, and their
        order, used when building the data arrays
    """

    def __init__(self, counts, qpcr, metadata, taxa_metadata, taxonomy):
        """
        Read the five tab-separated tables.

        Each argument is a path (or anything else ``pandas.read_csv`` accepts).
        All tables except ``metadata`` use their first column as the index;
        ``metadata`` keeps "sampleID" as a regular column because
        ``get_sampleID`` selects it by name.
        """
        self.counts = pd.read_csv(counts, sep="\t", index_col=0)
        self.qpcr = pd.read_csv(qpcr, sep="\t", index_col=0)
        self.metadata = pd.read_csv(metadata, sep="\t")
        self.taxmeta = pd.read_csv(taxa_metadata, sep="\t", index_col=0)
        self.taxonomy = pd.read_csv(taxonomy, sep="\t", index_col=0)

    def get_sampleID(self, t, s):
        """
        Return the sample ID recorded in the metadata for time ``t`` and
        subject ``s``.

        The (time, subject) pair SHOULD BE UNIQUE; if several rows match, the
        first one is returned.

        Raises
        ------
        IndexError
            if no metadata row matches ``t`` and ``s``
        """
        meta = self.metadata
        matches = meta.loc[(meta["time"] == t) & (meta["subjectID"] == s), "sampleID"]
        if matches.empty:
            # same exception type as indexing an empty array, but with a message
            # that says which (time, subject) combination is missing
            raise IndexError(
                "no sample found for time={!r}, subjectID={!r}".format(t, s)
            )
        return matches.values[0]

    def top_abundance_filter(self, notus):
        """
        filter by total relative abundance, keep 'notus' otus
        this is done 'in place' and modifies internal data
        (counts, taxonomy and taxmeta are all subset, in abundance order)
        """
        # take top N (averaged over all time points and subjects)
        ra = self.counts.sum(axis=1) / self.counts.values.sum()
        # TODO: move this comment to c. diff loading function
        #* for c diff data top 14 agree with mdsine, can use top 13 if excluding hiranonis
        index = ra.sort_values(ascending=False).iloc[:notus].index

        # use index to subset counts, taxonomy, and taxmeta
        self.counts = self.counts.loc[index, :]
        self.taxonomy = self.taxonomy.loc[index, :]
        self.taxmeta = self.taxmeta.loc[index, :]

    # TODO: add consistency filter
    # TODO: add option to remove times here too... currently doing outside of this class

    def get_count_mass_data(self):
        """
        get count and mass matrix data

        Returns
        -------
        times : sorted unique values of the metadata "time" column
        ycounts : array (ntime, nsubj, notus) of counts; the taxa axis follows
            the order of ``self.taxonomy.index``, the subject axis follows the
            order of first appearance of "subjectID" in the metadata
        wmass : array (nrep, ntime, nsubj) of qPCR values, one slice per
            qPCR column

        Raises
        ------
        IndexError
            (from ``get_sampleID``) if some (time, subject) pair has no sample
        """
        subjs = self.metadata.loc[:, "subjectID"].unique()
        times = np.sort(self.metadata.loc[:, "time"].unique())
        taxa = list(self.taxonomy.index)
        nrep = self.qpcr.shape[1]

        ycounts = np.zeros((len(times), len(subjs), len(taxa)))
        wmass = np.zeros((nrep, len(times), len(subjs)))

        for i, t in enumerate(times):
            for j, s in enumerate(subjs):
                sample = self.get_sampleID(t, s)
                wmass[:, i, j] = self.qpcr.loc[sample, :].values
                # one lookup for all taxa of this sample instead of one per taxon
                # TODO: make all strings, not sure why columns strings, but qpcr index ints?
                ycounts[i, j, :] = self.counts.loc[taxa, str(sample)].values
        return times, ycounts, wmass #, taxa_type

    def get_data_for_inference(self, rescale=1):
        """
        return data matrices to use for model inference
        returns times, abundance data, time mask, and type info

        The abundance data is log(x + EPS), where x is the relative abundance
        (counts normalised over the taxa axis) multiplied by the geometric
        mean of the qPCR replicates and by ``rescale``. The time mask and the
        type info are ordered like ``self.taxonomy.index``, i.e. like the taxa
        axis of the abundance data. Type info is 0 for 'bacteria' and 1 for
        'phage'.

        Raises
        ------
        ValueError
            if a taxon has a type other than 'bacteria' or 'phage'
        KeyError
            if a taxon in ``self.taxonomy`` has no row in ``self.taxmeta``

        TODO: figure out... **take geomean of qpcr .... what does mdsine2 do, since it shows plots of data???
        """
        EPS = 1e-8
        times, ycounts, wmass = self.get_count_mass_data()

        wmass_geomean = np.exp(np.mean(np.log(wmass), axis=0)) * rescale
        rel_abund = ycounts / ycounts.sum(axis=2, keepdims=True)
        xdata = rel_abund * wmass_geomean[:, :, None]
        xlogdata = np.log(xdata + EPS)

        # ycounts follows taxonomy order, so read the per-taxon metadata in
        # that same order rather than in whatever order taxmeta was stored
        taxmeta = self.taxmeta.loc[self.taxonomy.index]
        time_mask = taxmeta["time"].values
        type_names = taxmeta["type"].values
        # 0 == bacteria; 1 == phage
        typeinfo = -1 * np.ones(len(type_names), dtype=int)
        typeinfo[type_names == 'bacteria'] = 0
        typeinfo[type_names == 'phage'] = 1

        if (typeinfo == -1).any():
            raise ValueError("invalid taxa type")

        return times, xlogdata, time_mask, typeinfo

In [127]:
# Build the Dataset from the five tab-separated tables under cdiff_path.
dataset = Dataset(
    counts=cdiff_path.joinpath("reads.tsv"),
    qpcr=cdiff_path.joinpath("qpcr.tsv"),
    metadata=cdiff_path.joinpath("meta.tsv"),
    taxa_metadata=cdiff_path.joinpath("taxa_meta.tsv"),
    taxonomy=cdiff_path.joinpath("taxonomy.tsv"),
)

In [128]:
# Keep only the 13 most abundant taxa; this filters `dataset` in place.
dataset.top_abundance_filter(13)

In [130]:
# Unpack, in order: times, log abundance data, time mask, and taxa type info.
t, x, tm, tp = dataset.get_data_for_inference(rescale=1e-9)

In [131]:
t.shape

(26,)

In [132]:
x.shape

(26, 5, 13)

In [133]:
tm

array([ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,
       28.9,  0. ])

In [134]:
tp

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [135]:
t

array([ 0.75,  1.  ,  2.  ,  3.  ,  4.  ,  6.  ,  8.  , 10.  , 14.  ,
       17.  , 21.  , 24.  , 28.  , 28.75, 29.  , 30.  , 31.  , 32.  ,
       34.  , 36.  , 38.  , 42.  , 45.  , 49.  , 52.  , 56.  ])