In [92]:
import os
import pickle
import warnings
import qiime2 as q2

from qiime2.plugins.emperor.actions import (plot, biplot)
from qiime2.plugins.diversity.actions import (beta_phylogenetic, pcoa)
from deicode import rpca #from qiime2.plugins.deicode.actions import rpca

import pandas as pd
from pandas import DataFrame
import random
import numpy as np
import h5py
import biom

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline

from gemelli.ctf import ctf #from gemelli.actions import ctf
from gemelli.ctf import ctf_helper
from gemelli.factorization import TensorFactorization
from gemelli.preprocessing import build, rclr
from gemelli._ctf_defaults import (DEFAULT_COMP, DEFAULT_MSC,
                                   DEFAULT_MFC, DEFAULT_MAXITER,
                                   DEFAULT_FMETA as DEFFM)
from qiime2.plugins.longitudinal.actions import (volatility, linear_mixed_effects)

# import sys
# sys.path.insert(1, '../scripts/')
# from load_environmental_variables import *
# Root directory for every input/output file used in this notebook.
# Defaults to the original location; set IMMUNE_CCI_DATA_PATH to run the
# notebook on another machine without editing the code.
# os.path.join(..., '') guarantees a trailing separator, which the plain
# string concatenations (local_data_path + 'interim/...') rely on.
local_data_path = os.path.join(
    os.environ.get('IMMUNE_CCI_DATA_PATH', '/data2/hratch/immune_CCI_pseudotime/'),
    '')

In [465]:
# load files

# read in t0 cci timepoint
# cell barcodes, in the order used by the CCI distance matrices;
# the context manager closes the text file as soon as it has been read
with open(local_data_path + 'interim/velocyto_analyses/column_names.txt') as colnames_file:
    colnames = colnames_file.read().splitlines()
# cell_ids = pd.read_csv(local_data_path + 'processed/5k_pbmc_celltypes_velocytoformatted.csv', index_col = 0) # map cell barcodes to cell type
# cell_id_map = dict(zip(cell_ids.SampleID, cell_ids.Cell_Type)) #

# NOTE: cci_dt is left open after this cell; call cci_dt.close() once no
# further matrices need to be read from it.
cci_dt = h5py.File(local_data_path + 'interim/CCI_dt.h5', "r")
# the first key in sorted order is taken as the t0 matrix
# (presumably keys sort chronologically -- TODO confirm)
first_key = sorted(cci_dt.keys())[0]
# cci_t0 = pd.DataFrame(cci_dt[first_key], columns = colnames, index = colnames)
cci_t0 = biom.Table(np.array(cci_dt[first_key]),
                    sample_ids = colnames, observation_ids = colnames)

# # FIX THIS
# # duplicate time points not causing error, drop for now
# metadata.drop_duplicates(subset = ['velocity_pseudotime'], inplace = True)
# cci_t0 = cci_t0.loc[metadata.index, metadata.index]
# cci_t0 = biom.Table(cci_t0.values, sample_ids = metadata.index, observation_ids = metadata.index)

# load metadata (including pseudotime) from velocyto analysis
metadata = pd.read_csv(local_data_path + 'interim/velocyto_analyses/velocyto_attributes.csv', index_col = 0)
metadata['cell_ids'] = metadata.index

n_bins = 100
# With duplicates='drop', qcut can produce fewer than n_bins bins when
# quantile edges coincide; passing a fixed list of n_bins labels then raises
# ValueError. Compute the edges first and size the labels to the bins that
# are actually produced (identical to range(n_bins) when nothing is dropped).
time_bin_edges = pd.qcut(x = metadata.velocity_pseudotime, q = n_bins,
                         retbins = True, duplicates = 'drop')[1]
metadata['time'] = pd.qcut(x = metadata.velocity_pseudotime, q = n_bins,
                           labels = list(range(len(time_bin_edges) - 1)),
                           duplicates = 'drop')

In [None]:
# Run CTF on the t0 CCI table, using each cell as its own "individual" and
# the binned pseudotime column 'time' as the state.
print('get ctf')
ctf_results = ctf(table = cci_t0, sample_metadata = metadata,
                  individual_id_column = 'cell_ids',
                  state_column = 'time',
                  feature_metadata = None)
# the five outputs are picked out of the result by position
subject_biplot = ctf_results[0]
state_biplot = ctf_results[1]
distance_matrix = ctf_results[2]
state_subject_ordination = ctf_results[3]
state_feature_ordination = ctf_results[4]

state_feature_ordination.index = state_feature_ordination.index.astype(int)
print('save ctf')

# pickle.dump needs the open file handle as its second argument
with open(local_data_path + 'interim/ctf_results.pickle', 'wb') as handle:
    pickle.dump(ctf_results, handle)

print('complete')

In [None]:
# PC1 of the state-subject ordination over the binned pseudotime,
# one line per cell type (ci=95 requests a 95% interval around each line).
fig, ax = plt.subplots(figsize=(16, 5))
sns.lineplot(y='PC1',x='time', hue='Cell_Type', ci=95,
             data=state_subject_ordination, ax=ax)
prop_explained = subject_biplot.proportion_explained['PC1'] * 100
ax.set_ylabel('PC1 (%.2f %%)' % (prop_explained),
              fontsize=25,
              color='black')
# the 'time' state passed to ctf is a pseudotime quantile bin, not weeks
ax.set_xlabel('pseudotime bin',
              fontsize=25,
              color='black')

# fix spine colors
ax.set_facecolor('white')
ax.set_axisbelow(True)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['top'].set_visible(False)
for child in ax.get_children():
    if isinstance(child, matplotlib.spines.Spine):
        child.set_color('grey')
ax.tick_params(axis='y', colors='black')
ax.tick_params(axis='x', colors='black')

# rebuild the legend from the handles seaborn created.
# NOTE(review): the first entry is skipped -- presumably the hue title row
# that older seaborn versions add; on newer versions this may drop a real
# cell type, so verify against the installed seaborn.
handles, labels = ax.get_legend_handles_labels()
legend = ax.legend(handles[1:], labels[1:], loc=2,
                         bbox_to_anchor=(0.0, 1.1),
                         fancybox=True, framealpha=0.5,
                         ncol=2, markerscale=3,
                         facecolor="white")

# white figure patch with alpha 0.0, i.e. a fully transparent background
fig.patch.set_facecolor('white')
fig.patch.set_alpha(0.0)

plt.show()



In [None]:
# Lowess-smoothed PC1 trend over the binned pseudotime, one line per cell type.
fig, ax = plt.subplots(figsize=(16, 5))

# cell_types and colors are reused by the 3D plot cell further down
cell_types = state_subject_ordination.Cell_Type.unique()
colors = sns.color_palette("husl", len(cell_types))

for ct, ct_color in zip(cell_types, colors):
    # seaborn documents that confidence intervals cannot be drawn for lowess
    # fits, so ci=95 has no visible effect here; kept for parity with the
    # lineplot cell.
    # label=ct names the fitted line (scatter is off) so a legend can be built.
    sns.regplot(y='PC1',x='time',
                data=state_subject_ordination[state_subject_ordination.Cell_Type == ct],
                ci=95, scatter = False, lowess = True,
                color = ct_color, label = ct,
                ax=ax)
prop_explained = subject_biplot.proportion_explained['PC1'] * 100
ax.set_ylabel('PC1 (%.2f %%)' % (prop_explained),
              fontsize=25,
              color='black')
# the 'time' state passed to ctf is a pseudotime quantile bin, not weeks
ax.set_xlabel('pseudotime bin',
              fontsize=25,
              color='black')

# fix spine colors
ax.set_facecolor('white')
ax.set_axisbelow(True)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['top'].set_visible(False)
for child in ax.get_children():
    if isinstance(child, matplotlib.spines.Spine):
        child.set_color('grey')
ax.tick_params(axis='y', colors='black')
ax.tick_params(axis='x', colors='black')

# legend built from the per-line labels set above; without it the coloured
# lines cannot be matched to cell types
legend = ax.legend(loc=2,
                   bbox_to_anchor=(0.0, 1.1),
                   fancybox=True, framealpha=0.5,
                   ncol=2,
                   facecolor="white")

# white figure patch with alpha 0.0, i.e. a fully transparent background
fig.patch.set_facecolor('white')
fig.patch.set_alpha(0.0)

plt.show()




In [None]:
from mpl_toolkits.mplot3d.axes3d import Axes3D
import matplotlib.pyplot as plt

# One 3D line per cell type: subject PC1 (x), feature PC1 (y), time (z).
fig, ax = plt.subplots(figsize=(10,10), subplot_kw={'projection': '3d'})

for color_idx, ct in enumerate(cell_types):
    # row labels of the state-subject ordination belonging to this cell type
    is_ct = (state_subject_ordination.Cell_Type == ct).values
    ct_rows = state_subject_ordination.index[is_ct]
    # NOTE(review): the same labels are used to index
    # state_feature_ordination -- assumes both frames share an index; confirm.
    xs = state_subject_ordination.loc[ct_rows, 'PC1']
    ys = state_feature_ordination.loc[ct_rows, 'PC1']
    zs = state_subject_ordination.loc[ct_rows, 'time']
    ax.plot(xs, ys, zs, color = colors[color_idx])

plt.show()

# Source code: manual step-by-step CTF run (match ids, filter, build tensor, rclr, TensorFactorization)

In [411]:
# Inputs for the manual CTF run in the next cell.
# `table` is an alias of cci_t0, not a copy; likewise sample_metadata/metadata.
table, sample_metadata = cci_t0, metadata
individual_id_column = 'cell_ids'
state_columns = ['velocity_pseudotime']

# Defaults imported from gemelli._ctf_defaults.
n_components, min_sample_count, min_feature_count = (
    DEFAULT_COMP, DEFAULT_MSC, DEFAULT_MFC)
# all three iteration/initialization settings share DEFAULT_MAXITER
max_iterations_als = max_iterations_rptm = n_initializations = DEFAULT_MAXITER

In [412]:
# Manual CTF pipeline: subset the metadata, match sample ids between table
# and metadata, filter by minimum counts, build the tensor and factorize it.
if sample_metadata is not None and not isinstance(sample_metadata,
                                                      DataFrame):
    sample_metadata = sample_metadata.to_dataframe()
keep_cols = state_columns + [individual_id_column]
# every metadata column that is NOT needed for the tensor construction
all_sample_metadata = sample_metadata.drop(keep_cols, axis=1)
sample_metadata = sample_metadata[keep_cols]

# match the data (borrowed in part from gneiss.util.match)
subtablefids = table.ids('observation')
subtablesids = table.ids('sample')
if len(subtablesids) != len(set(subtablesids)):
    raise ValueError('Data-table contains duplicate sample IDs')
if len(subtablefids) != len(set(subtablefids)):
    raise ValueError('Data-table contains duplicate feature IDs')
submetadataids = set(sample_metadata.index)

# Shared sample ids, kept in table order. A plain set intersection has no
# stable order, so reindexing the metadata with it gave a row order that
# could change from run to run.
sidx = [sid for sid in subtablesids if sid in submetadataids]
if len(sidx) == 0:
    raise ValueError(("No more samples left.  Check to make sure that "
                      "the sample names between `sample-metadata` and"
                      " `table` are consistent"))

# inplace=False returns a new table, so the object `table` was aliased to
# (cci_t0 in the parameter cell) is not modified by this cell
table = table.filter(sidx, axis='sample', inplace=False)
sample_metadata = sample_metadata.reindex(sidx)

# filter and import table
# (in-place is safe from here on: `table` is now a private copy)
for axis, min_sum in zip(['sample',
                          'observation'],
                         [min_sample_count,
                          min_feature_count]):
    table = table.filter(table.ids(axis)[table.sum(axis) >= min_sum],
                         axis=axis, inplace=True)

# table to dataframe (rows = observations, columns = samples)
table = DataFrame(table.matrix_data.toarray(),
                  table.ids('observation'),
                  table.ids('sample'))

tensor = build()
tensor.construct(table, sample_metadata,
                 individual_id_column, state_columns)

# factorize the rclr-transformed tensor counts
TF = TensorFactorization(
        n_components=n_components,
        max_als_iterations=max_iterations_als,
        max_rtpm_iterations=max_iterations_rptm,
        n_initializations=n_initializations).fit(rclr(tensor.counts))

In [None]:
# Re-run CTF on the t0 table with the binned pseudotime as the state and
# pick the five outputs out of the result by position.
ctf_results = ctf(
    table = cci_t0,
    sample_metadata = metadata,
    individual_id_column = 'cell_ids',
    state_column = 'time',
    feature_metadata = None,
)
(subject_biplot,
 state_biplot,
 distance_matrix,
 state_subject_ordination,
 state_feature_ordination) = (ctf_results[position] for position in range(5))

# cast the feature-ordination index to integers
state_feature_ordination.index = state_feature_ordination.index.astype(int)