In [1]:
# imports

# python packages
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
import matplotlib.font_manager
import rpy2
from functools import reduce
import xarray as xr

# rpy2 imports
from rpy2 import robjects as ro
from rpy2.robjects.packages import importr
from rpy2.ipython.ggplot import image_png
from rpy2.robjects import pandas2ri

# load rpy2 extension for ipython
# pandas2ri.activate() switches on automatic pandas <-> R conversion for every rpy2 call
# made after this point (NOTE(review): this global activation is deprecated in newer rpy2
# releases in favour of local converter contexts -- confirm against the installed version)
pandas2ri.activate()
%load_ext rpy2.ipython

# stop showing SettingWithCopyWarning
# (later cells add columns to column-subsets of larger frames, e.g. pangenome_df taken from
# mappings_df; this silences those warnings globally for the rest of the session)
pd.options.mode.chained_assignment = None




In [2]:
# install & import r package sctransform

# sctransform: install from the first CRAN mirror, but only when it is missing
if not ro.packages.isinstalled('sctransform'):
    utils = importr('utils')
    utils.chooseCRANmirror(ind=1)
    utils.install_packages(ro.vectors.StrVector(['sctransform']))

# glmGamPoi is only checked here, never installed automatically -- point the user to it instead
if not ro.packages.isinstalled('glmGamPoi'):
    print('Please install glmGamPoi: https://github.com/const-ae/glmGamPoi')

# load the R packages used below: sctransform (vst model) and Matrix (sparse matrices)
sctransform, rmatrix = importr('sctransform'), importr('Matrix')

# should be version 0.3.5 or higher
print(sctransform.__version__)


0.4.1


In [3]:
# helper functions

# function to calculate 0-sensitive geometric mean
def geometric_mean(vector, pseudocount=1):
    """
    Geometric mean that tolerates zeros.

    Every value is shifted by ``pseudocount`` before taking logs, the logs are
    averaged, and the shift is removed again after exponentiating.

    Parameters
    ----------
    vector : numpy array or pandas Series
        Values to average (anything supporting ``+`` with a scalar).
    pseudocount : number, default 1
        Offset added before the log so that zero values are allowed.

    Returns
    -------
    float
        ``exp(mean(log(vector + pseudocount))) - pseudocount``
    """
    log_of_shifted = np.log(vector + pseudocount)
    mean_log = np.mean(log_of_shifted)
    return np.exp(mean_log) - pseudocount

# function to convert pandas dataframe to r matrix
def pandas_dataframe_to_r_matrix(df, dtype=float):
    """
    Convert a pandas DataFrame into a sparse R ``Matrix`` (R Matrix package).

    The values are flattened row by row and passed with ``byrow=True``, so the R
    object keeps the DataFrame's layout. The DataFrame index and columns become
    the R row and column names.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to convert.
    dtype : {float, int, str}, default float
        Builtin type selecting the R vector class: ``float`` and ``int`` are both
        sent as a numeric (double) vector, ``str`` as a character vector.

    Returns
    -------
    The R object returned by ``Matrix::Matrix(..., sparse=TRUE)``.

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the three supported builtin types.
    """
    # identity comparison (`is`) against the builtin types, checked in this order
    vector_classes = (
        (float, ro.vectors.FloatVector),
        (str, ro.vectors.StrVector),
        (int, ro.vectors.FloatVector),
    )
    for supported_type, vector_class in vector_classes:
        if dtype is supported_type:
            break
    else:
        raise ValueError('The dtype {} is not recognized'.format(dtype))

    n_rows, n_cols = df.shape
    flat_values = df.values.flatten().tolist()
    return rmatrix.Matrix(
        data=vector_class(flat_values),
        nrow=n_rows,
        ncol=n_cols,
        byrow=True,
        dimnames=[df.index.to_list(), df.columns.to_list()],
        sparse=True,
    )


# Read in collated mapping counts and metadata

In [4]:
# read in metadata

# filepaths
filepath_genome_metadata = '../../data/metadata/genome-metadata.csv'
filepath_ortholog_metadata = '../../data/metadata/ortholog-metadata.csv'
filepath_sample_metadata = '../../data/metadata/sample-metadata.csv'

# genome, ortholog and sample metadata tables (read in that order)
genome_df, ortholog_df, sample_df = (
    pd.read_csv(path)
    for path in (filepath_genome_metadata, filepath_ortholog_metadata, filepath_sample_metadata)
)

# relabel cruise ids with the short Gradients names; ids not listed here become NaN
cruise_labels = {
    'kok1606': 'G1',
    'mgl1704': 'G2',
    'km1906': 'G3',
}
sample_df['Cruise'] = sample_df['Cruise'].map(cruise_labels)

# sample name = SampleID minus its last two characters (the replicate suffix, e.g. '.A')
sample_df['SampleName'] = [sample_id[:-2] for sample_id in sample_df['SampleID']]

# unique (CyCOGID, Annotation) pairs, used later to attach annotations to orthologs
annotations_df = ortholog_df[['CyCOGID', 'Annotation']].drop_duplicates()

sample_df


Unnamed: 0,SampleID,Cruise,Dataset,Experiment,Station,Cast,Datetime,Latitude,Longitude,Depth,Replicate,SmallFraction,LargeFraction,Unfractionated,Batch,SampleName
0,G1.SURF.NS.S02C1.15m.A,G1,Gradients 1 surface,,2.0,1.0,4/20/16 5:51,23.495833,-157.994333,15,A,G1.SURF.NS.S02C1.15m.0_2um.A,G1.SURF.NS.S02C1.15m.3um.A,False,G1_SURF,G1.SURF.NS.S02C1.15m
1,G1.SURF.NS.S02C1.15m.B,G1,Gradients 1 surface,,2.0,1.0,4/20/16 5:51,23.495833,-157.994333,15,B,G1.SURF.NS.S02C1.15m.0_2um.B,G1.SURF.NS.S02C1.15m.3um.B,False,G1_SURF,G1.SURF.NS.S02C1.15m
2,G1.SURF.NS.S02C1.15m.C,G1,Gradients 1 surface,,2.0,1.0,4/20/16 5:51,23.495833,-157.994333,15,C,G1.SURF.NS.S02C1.15m.0_2um.C,G1.SURF.NS.S02C1.15m.3um.C,False,G1_SURF,G1.SURF.NS.S02C1.15m
3,G1.SURF.NS.S04C1.15m.A,G1,Gradients 1 surface,,4.0,1.0,4/22/16 5:40,28.143167,-158.000667,15,A,G1.SURF.NS.S04C1.15m.0_2um.A,G1.SURF.NS.S04C1.15m.3um.A,False,G1_SURF,G1.SURF.NS.S04C1.15m
4,G1.SURF.NS.S04C1.15m.B,G1,Gradients 1 surface,,4.0,1.0,4/22/16 5:40,28.143167,-158.000667,15,B,G1.SURF.NS.S04C1.15m.0_2um.B,G1.SURF.NS.S04C1.15m.3um.B,False,G1_SURF,G1.SURF.NS.S04C1.15m
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
217,G3.UW.NS.UW40_1.7m.B,G3,Gradients 3 underway,,,,4/24/19 5:57,40.880000,-158.000000,7,B,G3.UW.NS.UW40_1.7m.0_2um.B,G3.UW.NS.UW40_1.7m.3um.B,False,G3_SURF,G3.UW.NS.UW40_1.7m
218,G3.UW.NS.UW40_1.7m.C,G3,Gradients 3 underway,,,,4/24/19 5:57,40.880000,-158.000000,7,C,G3.UW.NS.UW40_1.7m.0_2um.C,G3.UW.NS.UW40_1.7m.3um.C,False,G3_SURF,G3.UW.NS.UW40_1.7m
219,G3.UW.NS.UW40_2.7m.A,G3,Gradients 3 underway,,,,4/25/19 6:01,40.090000,-158.000000,7,A,G3.UW.NS.UW40_2.7m.0_2um.A,G3.UW.NS.UW40_2.7m.3um.A,False,G3_SURF,G3.UW.NS.UW40_2.7m
220,G3.UW.NS.UW40_2.7m.B,G3,Gradients 3 underway,,,,4/25/19 6:01,40.090000,-158.000000,7,B,G3.UW.NS.UW40_2.7m.0_2um.B,G3.UW.NS.UW40_2.7m.3um.B,False,G3_SURF,G3.UW.NS.UW40_2.7m


In [5]:
# read in mapped transcript abundance data
# (collated salmon output; columns include MappingName, NumReads, SampleID and GenomeName)
filepath_data = '../../data/2-mapping/collated_salmon_data.csv.gz'
mappings_df = pd.read_csv(filepath_data, compression='infer')
mappings_df


Unnamed: 0,MappingName,GeneLength,TPM,NumReads,SampleID,GenomeName,GeneID
0,AG-316-L16_2717627218,1116,0.0,0.0,G3.DIEL.NS.S4C21.15m.A,AG-316-L16,2717627218
1,AG-316-L16_2717627219,1290,0.0,0.0,G3.DIEL.NS.S4C21.15m.A,AG-316-L16,2717627219
2,AG-316-L16_2717627220,1578,0.0,0.0,G3.DIEL.NS.S4C21.15m.A,AG-316-L16,2717627220
3,AG-316-L16_2717627221,99,0.0,0.0,G3.DIEL.NS.S4C21.15m.A,AG-316-L16,2717627221
4,AG-316-L16_2717627222,1050,0.0,0.0,G3.DIEL.NS.S4C21.15m.A,AG-316-L16,2717627222
...,...,...,...,...,...,...,...
219859471,Syn9_638323192,228,0.0,0.0,G3.DIEL.NS.S4C13.15m.B,Syn9,638323192
219859472,Syn9_638323193,711,0.0,0.0,G3.DIEL.NS.S4C13.15m.B,Syn9,638323193
219859473,Syn9_638323194,180,0.0,0.0,G3.DIEL.NS.S4C13.15m.B,Syn9,638323194
219859474,Syn9_638323195,231,0.0,0.0,G3.DIEL.NS.S4C13.15m.B,Syn9,638323195


# Process mapped reads

- Drop mappings to virus and virocell reference genomes, as well as reference genes without an ortholog mapping (hypothesized to be mostly rRNA and other highly conserved RNA)
- Within each clade in each sample, aggregate reads mapping to the same ortholog


In [6]:
# directory to store output
# (per-clade residual csvs/plots and the per-genus NetCDF tensor datasets are written under this path)

output_dir = '../../data/3-normalization'


In [7]:
# look at read mappings per group and clade

pangenome_df = mappings_df[['MappingName', 'SampleID', 'GenomeName', 'NumReads']]

# attach group and clade (looked up by genome) and ortholog id (looked up by mapping name)
genome_lookup = genome_df.set_index('GenomeName')
for genome_attr in ['Group', 'Clade']:
    pangenome_df[genome_attr] = pangenome_df['GenomeName'].map(genome_lookup[genome_attr])
pangenome_df['CyCOGID'] = pangenome_df['MappingName'].map(ortholog_df.set_index('MappingName')['CyCOGID'])

# share of all mapped reads that falls in each group, then in each clade
total_reads = pangenome_df.NumReads.sum()
for level in ['Group', 'Clade']:
    print(pangenome_df.groupby(level).NumReads.sum() / total_reads)

# total reads per clade within each sample, plus that clade's share of the sample's reads
clade_abun_df = (
    pangenome_df
    .groupby(['SampleID', 'Group', 'Clade'])
    .NumReads.sum()
    .reset_index()
    .rename(columns={'NumReads': 'CladeReads'})
)
clade_abun_df['SampleReads'] = clade_abun_df.groupby('SampleID')['CladeReads'].transform('sum')
clade_abun_df['RelativeAbundance'] = clade_abun_df['CladeReads'] / clade_abun_df['SampleReads']


Group
Prochlorococcus            0.745148
Synechococcus              0.252534
Uncultured-marine-virus    0.000285
Virus                      0.002032
Name: NumReads, dtype: float64
Clade
5.1A-CRD2            0.024086
5.1A-II              0.007827
5.1A-III             0.006876
5.1A-IV              0.065266
5.1A-UC-A-EnvC       0.005206
5.1A-WPC1            0.000570
5.1A-unclassified    0.014533
5.1B-CRD1            0.025221
5.1B-I               0.036755
5.1B-IX              0.004742
5.1B-V               0.000115
5.1B-VI              0.000110
5.1B-VIII            0.000217
5.2                  0.022553
5.3                  0.038389
AMZ-II               0.000671
HLI                  0.293027
HLII                 0.181326
HLII.HLVI            0.001085
HLIII                0.015565
HLIII.HLIV.HLV       0.000524
HLIV                 0.011601
HLVI                 0.002703
LLI                  0.127157
LLI.LLVIII           0.002676
LLII.LLIII           0.024737
LLIV                 0.002942
LLV

In [8]:
# collect reads mapping to known orthologs only and aggregate within orthologs

# references to exclude: genomes in the 'Virus' group and genomes flagged as virocells
is_virus = genome_df['Group'] == 'Virus'
drop_references = genome_df[is_virus | genome_df['Virocell']].GenomeName.to_list()
pangenome_df = pangenome_df[~pangenome_df.GenomeName.isin(drop_references)]

# keep only genes that were assigned an ortholog (CyCOGID)
pangenome_df = pangenome_df[pangenome_df['CyCOGID'].notna()]

# sum reads over all genes sharing an ortholog within the same sample, group and clade
pangenome_df = (
    pangenome_df
    .groupby(['SampleID', 'Group', 'Clade', 'CyCOGID'])
    .agg(Reads=('NumReads', 'sum'))
    .reset_index()
)

# flag ortholog-sample-clade combinations with at least some mapped reads
pangenome_df['Nonzero'] = pangenome_df['Reads'] > 0.0

pangenome_df


Unnamed: 0,SampleID,Group,Clade,CyCOGID,Reads,Nonzero
0,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,AMZ-II,60000002.0,0.0,False
1,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,AMZ-II,60000003.0,0.0,False
2,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,AMZ-II,60000004.0,0.0,False
3,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,AMZ-II,60000005.0,0.0,False
4,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,AMZ-II,60000006.0,0.0,False
...,...,...,...,...,...,...
21899629,G3.UW.NS.UW40_2.7m.C,Synechococcus,Unclassified,60039505.0,0.0,False
21899630,G3.UW.NS.UW40_2.7m.C,Synechococcus,Unclassified,60039506.0,1.0,True
21899631,G3.UW.NS.UW40_2.7m.C,Synechococcus,Unclassified,60039507.0,0.0,False
21899632,G3.UW.NS.UW40_2.7m.C,Synechococcus,Unclassified,60039508.0,0.0,False


In [9]:
# calculate genome coverage in each sample

# pangenome size: distinct orthologs per clade;
# detected pangenome size: distinct orthologs per clade with nonzero reads in at least one sample
orthologs_per_clade = pangenome_df.groupby(['Clade']).CyCOGID.nunique()
detected_per_clade = pangenome_df.loc[pangenome_df.Nonzero].groupby(['Clade']).CyCOGID.nunique()
clade_abun_df['PangenomeSize'] = clade_abun_df['Clade'].map(orthologs_per_clade)
clade_abun_df['DetectedPangenomeSize'] = clade_abun_df['Clade'].map(detected_per_clade)

# number of orthologs with nonzero reads for each sample-clade pair
detected_counts = (
    pangenome_df
    .groupby(['SampleID', 'Clade'])
    .Nonzero.sum()
    .reset_index()
    .rename(columns={'Nonzero': 'DetectedCyCOGs'})
)
clade_abun_df = clade_abun_df.merge(detected_counts, how='left', on=['SampleID', 'Clade'])

# coverage relative to the full and to the detected pangenome
detected_cycogs = clade_abun_df['DetectedCyCOGs']
clade_abun_df['PangenomeCoverage'] = detected_cycogs / clade_abun_df['PangenomeSize']
clade_abun_df['DetectedPangenomeCoverage'] = detected_cycogs / clade_abun_df['DetectedPangenomeSize']

# save as csv
clade_abun_df.to_csv('../../data/2-mapping/aggregate-clade-reads.csv', index=False)
clade_abun_df


Unnamed: 0,SampleID,Group,Clade,CladeReads,SampleReads,RelativeAbundance,PangenomeSize,DetectedPangenomeSize,DetectedCyCOGs,PangenomeCoverage,DetectedPangenomeCoverage
0,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,AMZ-II,507.205,5433304.056,0.000093,1762.0,127.0,0.0,0.000000,0.000000
1,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,HLI,272443.296,5433304.056,0.050143,4906.0,3906.0,575.0,0.117203,0.147209
2,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,HLII,4828459.042,5433304.056,0.888678,6997.0,4787.0,1740.0,0.248678,0.363484
3,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,HLII.HLVI,34148.879,5433304.056,0.006285,1822.0,1144.0,146.0,0.080132,0.127622
4,G1.SURF.NS.S02C1.15m.A,Prochlorococcus,HLIII,24885.537,5433304.056,0.004580,1652.0,221.0,11.0,0.006659,0.049774
...,...,...,...,...,...,...,...,...,...,...,...
6655,G3.UW.NS.UW40_2.7m.C,Synechococcus,5.1B-VI,379.143,6469143.806,0.000059,2564.0,226.0,8.0,0.003120,0.035398
6656,G3.UW.NS.UW40_2.7m.C,Synechococcus,5.1B-VIII,2009.257,6469143.806,0.000311,2515.0,239.0,3.0,0.001193,0.012552
6657,G3.UW.NS.UW40_2.7m.C,Synechococcus,5.2,267334.246,6469143.806,0.041325,5123.0,481.0,13.0,0.002538,0.027027
6658,G3.UW.NS.UW40_2.7m.C,Synechococcus,5.3,1016616.381,6469143.806,0.157149,2783.0,526.0,2.0,0.000719,0.003802


In [None]:
# maxes: largest pangenome sizes and coverages seen for each group/clade across samples

coverage_columns = ['PangenomeSize', 'DetectedPangenomeSize', 'PangenomeCoverage', 'DetectedPangenomeCoverage']
clade_abun_df.groupby(['Group', 'Clade'])[coverage_columns].agg('max')

In [None]:
# fraction of each clade's pangenome with nonzero reads in at least one sample
# (the column is DetectedPangenomeSize -- there is no DetectedPangenome column)
clade_size_maxes = clade_abun_df.groupby(['Group', 'Clade'])[['DetectedPangenomeSize', 'PangenomeSize']].max()
clade_size_maxes['DetectedPangenomeSize'] / clade_size_maxes['PangenomeSize']

In [None]:
# number of reference genome rows per clade in the genome metadata
genome_df.Clade.value_counts()

# Separate out data by clade

In [None]:
# separate out data for most highly represented clades in dataset

# thresholds
# an ortholog is kept for a clade only if it has reads in at least detection_min of the kept
# samples (the same value is later passed to sctransform as min_cells)
detection_min = 3
# a sample is kept for a clade only if its number of detected orthologs is greater than
# coverage_threshold times the count of the best-covered sample for that clade
coverage_threshold = 0.01
clades = {
    'pro': ['HLI', 'HLII', 'LLI', 'LLII.LLIII', 'LLIV', 'LLVII'], 
    'syn': ['5.1A-CRD2', '5.1A-II', '5.1A-III', '5.1A-IV', '5.1A-UC-A-EnvC', '5.1B-CRD1', '5.1B-I']
}

# storage variables: ortho_list and sample_list get one entry per clade, in the order
# clades['pro'] + clades['syn']; core_df_dict maps clade -> ortholog x sample read-count table
ortho_list = []
sample_list = []
core_df_dict = {}

for selected_clade in clades['pro'] + clades['syn']:
    print(selected_clade)
    clade_pangenome_df = pangenome_df[pangenome_df.Clade == selected_clade]

    # down-select data to orthologs with nonzero reads in at least one sample
    detections = clade_pangenome_df.groupby('CyCOGID').Nonzero.sum()
    detected_orthos = detections[detections.gt(0)].index
    total_orthos = len(detected_orthos)
    if total_orthos == 0:
        # fail fast: an empty clade would otherwise give NaN coverage stats and an empty
        # table that only errors later inside sctransform
        raise ValueError('No orthologs with mapped reads found for clade {}'.format(selected_clade))
    clade_pangenome_df = clade_pangenome_df[clade_pangenome_df.CyCOGID.isin(detected_orthos)]

    # down-select to samples whose coverage is above coverage_threshold x the max coverage
    coverage = clade_pangenome_df.groupby('SampleID').Nonzero.sum()
    max_cov = coverage.max()
    print('Max coverage: {} / {} orthologs ({})'.format(max_cov, total_orthos, max_cov/total_orthos))
    samples = coverage[coverage.gt(max_cov * coverage_threshold)].index
    sample_list.append(samples)
    min_cov = coverage[samples].min()
    print('Min coverage: {} / {} orthologs ({})'.format(min_cov, total_orthos, min_cov/total_orthos))
    print('Total samples: {}\n'.format(len(samples)))
    clade_pangenome_df = clade_pangenome_df[clade_pangenome_df['SampleID'].isin(samples)]

    # down-select to orthologs detected in at least detection_min of the remaining samples
    detections = clade_pangenome_df.groupby('CyCOGID').Nonzero.sum()
    core_orthos = detections[detections.ge(detection_min)].index
    clade_pangenome_df = clade_pangenome_df[clade_pangenome_df.CyCOGID.isin(core_orthos)]
    ortho_list.append(core_orthos)

    # pivot to an ortholog (rows) x sample (columns) table of aggregated read counts
    clade_core_df = clade_pangenome_df.pivot(index='CyCOGID', columns='SampleID', values='Reads')
    core_df_dict[selected_clade] = clade_core_df
    

In [None]:
# look at intersections of samples and genes

# ortho_list / sample_list hold one entry per clade in the order clades['pro'] + clades['syn'],
# so the first len(clades['pro']) entries are Prochlorococcus and the rest Synechococcus
# (a hard-coded split at 3 would mix Pro clades into the Syn totals)
n_pro = len(clades['pro'])
subsets = {
    'cyanobacterial': (ortho_list, sample_list),
    'Pro': (ortho_list[:n_pro], sample_list[:n_pro]),
    'Syn': (ortho_list[n_pro:], sample_list[n_pro:]),
}

for label, (group_orthos, group_samples) in subsets.items():
    print('Total {} orthologs: {}'.format(label, len(reduce(np.union1d, group_orthos))))
    print('Total {} samples: {}'.format(label, len(reduce(np.union1d, group_samples))))
    print('Shared {} orthologs: {}'.format(label, len(reduce(np.intersect1d, group_orthos))))
    print('Shared {} samples: {}\n'.format(label, len(reduce(np.intersect1d, group_samples))))


# Normalize data using sctransform

* Save diagnostic plots & csv files, as well as tensorized data in NetCDF4 format

In [None]:
# run the model on each clade, saving outputs and plots


def _to_slab(wide_df, value_name, clade):
    """
    Reshape an ortholog x sample table into a long "slab" for the xarray tensor.

    Parameters
    ----------
    wide_df : pandas.DataFrame
        Rows are orthologs (the pivoted tables are indexed by CyCOGID), columns are sample IDs.
    value_name : str
        Name of the single value column ('readcount' or 'residual').
    clade : str
        Clade label stored in the 'clade' index level.

    Returns
    -------
    pandas.DataFrame indexed by (ortholog, clade, sample) with one column, ``value_name``.
    """
    # name the row index 'ortholog' explicitly so melt finds it whatever the incoming index
    # name is (renaming an 'OrthologID' column would be a no-op here)
    slab = (
        wide_df
        .rename_axis('ortholog')
        .reset_index()
        .melt(id_vars='ortholog', var_name='sample', value_name=value_name)
    )
    slab['ortholog'] = slab['ortholog'].astype('Int64')
    slab['clade'] = clade
    return slab.set_index(['ortholog', 'clade', 'sample'])


# CyCOGID -> annotation lookup; rows without a CyCOGID cannot be looked up and are dropped
annotation_lookup = annotations_df.dropna(subset=['CyCOGID']).set_index('CyCOGID')['Annotation']

# store normalized data as slabs to arrange in xarray tensor
read_slabs = {}
residual_slabs = {}

# iterate through pro & syn
for genus in ['pro', 'syn']:
    print(genus)
    read_slabs[genus] = []
    residual_slabs[genus] = []

    # iterate through clades
    for clade in clades[genus]:
        print(clade)

        # make output directory (no-op if it already exists)
        dir_path = '{}/{}'.format(output_dir, clade)
        os.makedirs(dir_path, exist_ok=True)

        # convert clade_core_df to r matrix
        clade_core_df = core_df_dict[clade]
        r_clade_core_df = pandas_dataframe_to_r_matrix(clade_core_df)

        # fit vst normalization model
        # n_genes must be R NULL (not the string 'null') so the model is fit on all genes
        result = sctransform.vst(
            r_clade_core_df,
            n_genes=ro.NULL,
            min_cells=detection_min,
            return_gene_attr=True,
            return_cell_attr=True,
            vst_flavor='v2',
            verbosity=1,
        )

        # residual matrix = first element of the R list returned by vst
        # NOTE(review): assumes vst keeps every row and the row/column order of clade_core_df
        # (all kept orthologs already pass min_cells) -- confirm if the model settings change
        residual_df = pd.DataFrame(
            np.asarray(result[0]),
            index=clade_core_df.index,
            columns=clade_core_df.columns
        )
        residual_df.to_csv('{}/residuals_{}.csv'.format(dir_path, clade))

        # save residuals and raw read counts as slabs
        residual_slabs[genus].append(_to_slab(residual_df, 'residual', clade))
        read_slabs[genus].append(_to_slab(clade_core_df, 'readcount', clade))

        # save plot of model parameters
        plots = sctransform.plot_model_pars(result, show_theta=True)
        img = image_png(plots)
        with open('{}/parameters_{}.png'.format(dir_path, clade), 'wb') as png:
            png.write(img.data)

        # plot high variance genes: per-ortholog residual variance vs. geometric-mean read count.
        # The reference line is three standard deviations (of the residual variances) above 1
        # (presumably 1 = the residual variance expected under the model).
        residual_var = residual_df.var(axis=1)
        three_sigma = 1 + 3 * residual_var.std()
        expression = clade_core_df.apply(geometric_mean, axis=1)
        plt.figure(figsize=(10, 4))
        sns.scatterplot(x=expression, y=residual_var, alpha=0.1)
        plt.hlines(three_sigma, expression.min(), expression.max(),
                   colors=['orange'], linestyles=[':'], label='3*sigma')
        plt.legend()
        plt.xlabel('mean gene abundance (reads/sample)')
        plt.xscale('log')
        plt.ylabel('residual variance')
        plt.title('normalized residual abundance of {} genes'.format(clade))
        plt.savefig('{}/residual_variance_{}.png'.format(dir_path, clade))
        plt.show()

        # save csv of residual variances with annotations, highest variance first
        res_var_df = residual_var.rename('ResidualVariance').rename_axis('CyCOGID').reset_index()
        res_var_df['Annotation'] = res_var_df['CyCOGID'].map(annotation_lookup)
        res_var_df = res_var_df.sort_values('ResidualVariance', ascending=False).reset_index()
        res_var_df.to_csv('{}/high_variance_orthologs_{}.csv'.format(dir_path, clade))


In [None]:
# make slabs into xarray tensors and save

# metadata DataArrays, indexed by the same dimension names as the slabs ('sample', 'ortholog')
# so they line up with the readcount/residual tensors
sample_meta_df = sample_df.set_index('SampleID').rename_axis('sample')
replicate_da = xr.DataArray.from_series(sample_meta_df['Replicate'])
samplename_da = xr.DataArray.from_series(sample_meta_df['SampleName'])
# CyCOGIDs are floats in the ortholog metadata but integers in the slabs: drop rows without
# an ortholog and cast so the ortholog labels match
annotation_da = xr.DataArray.from_series(
    annotations_df
    .dropna(subset=['CyCOGID'])
    .astype({'CyCOGID': 'int64'})
    .set_index('CyCOGID')
    .rename_axis('ortholog')['Annotation']
)

for genus in ['pro', 'syn']:
    # stack every clade's slabs into (ortholog, clade, sample) tensors; combinations that
    # do not occur for a clade are filled with 0
    read_da = xr.DataArray.from_series(pd.concat(read_slabs[genus])['readcount']).fillna(0)
    residual_da = xr.DataArray.from_series(pd.concat(residual_slabs[genus])['residual']).fillna(0)
    # combine read counts and residuals with metadata into an xarray Dataset
    ds = xr.Dataset(
        dict(
            readcount=read_da,
            residual=residual_da,
            replicate=replicate_da.loc[residual_da['sample']],
            samplename=samplename_da.loc[residual_da['sample']],
            annotation=annotation_da.loc[residual_da['ortholog']],
        )
    )
    # save as netCDF4 file in data directory
    ds.to_netcdf('{}/{}-tensor-dataset.nc'.format(output_dir, genus))

# display the last dataset built (syn)
ds
    

## Batch corrected normalization

As of February 2023, batch correction is not possible with the `sctransform.vst` function when using the "v2" version of the model (see [this issue](https://github.com/satijalab/sctransform/issues/126) for details). Because the v2 model is the most appropriate for this data, we do not apply batch correction at this time. Additionally, PCA and UMAP analyses of the data suggest that it may be difficult to disentangle batch effects from biological signal, since the batches are correlated with different sampling conditions.

In [None]:
# # run the model on each clade, saving outputs and plots

# # store normalized data as slabs to arrange in xarray tensor
# slabs = {}
# # directory to store output
# output_dir = '../../data/3-normalization/batch-corrected'

# # iterate through pro & syn
# for genus in ['pro', 'syn']:
#     print(genus)
#     slabs[genus] = []
        
#     # iterate through clades
#     for clade in clades[genus]:
#         print(clade)

#         # make fresh directory
#         dir_path = '{}/{}'.format(output_dir, clade)
#         if not os.path.isdir(dir_path):
#             os.makedirs(dir_path)

#         # convert clade_core_df to r matrix
#         clade_core_df = core_df_dict[clade]
#         r_clade_core_df = pandas_dataframe_to_r_matrix(clade_core_df)
        
#         # pull out sample attributes that match samples in clade_core_df
#         sample_attr_df = sample_df.set_index('sample').loc[clade_core_df.columns, ['Dataset', 'Cruise']]
        
#         # convert pandas df to r dataframe
#         r_sample_attr_df = pandas2ri.py2rpy(sample_attr_df)

#         # fit vst normalization model
#         result = sctransform.vst(
#             r_clade_core_df, 
#             cell_attr=r_sample_attr_df, 
#             batch_var=ro.vectors.StrVector(["Dataset"]), 
#             n_genes='null', 
#             min_cells=detection_min,
#             return_gene_attr=True, 
#             return_cell_attr=True, 
# #             vst_flavor='v2', 
#             verbosity=10
#         )

#         # save residuals as csv
#         residual_df = pd.DataFrame(
#             np.asarray(result[0]), 
#             index=clade_core_df.index, 
#             columns=clade_core_df.columns
#         )
#         residual_df.to_csv('{}/residuals_{}.csv'.format(dir_path, clade))
        
#         # save residuals as slab 
#         slab = (
#             residual_df
#             .reset_index()
#             .rename(columns={'OrthologID': 'ortholog'})
#             .melt(id_vars='ortholog', var_name='sample', value_name='residual')
#         )
#         slab['ortholog'] = slab['ortholog'].astype('Int64')
#         slab['clade'] = clade
#         slabs[genus].append(slab.set_index(['ortholog', 'clade', 'sample']))

#         # save plot of model parameters
#         plots = sctransform.plot_model_pars(result, show_theta=True)
#         img = image_png(plots)
#         with open('{}/parameters_{}.png'.format(dir_path, clade), 'wb') as png:
#             png.write(img.data)

#         # plot high variance genes
#         residual_var = residual_df.var(axis=1)
#         three_sigma = 1 + residual_var.var() * 3
#         expression = clade_core_df.apply(geometric_mean, axis=1)
#         plt.figure(figsize=(10, 4))
#         sns.scatterplot(x=expression, y=residual_var, alpha=0.1);
#         plt.hlines(three_sigma, expression.min(), expression.max(), 
#                    colors=['orange'], linestyles=[':'], label='3*sigma');
#         plt.xlabel('mean gene abundance (reads/sample)')
#         plt.xscale('log')
#         plt.ylabel('residual variance')
#         plt.title('normalized residual abundance of {} genes'.format(clade))
#         plt.savefig('{}/residual_variance_{}.png'.format(dir_path, clade))
#         plt.show()

#         # save csv of residual variances with annotations
#         res_var_df = residual_var.reset_index().rename(columns={0:'ResidualVariance'})
#         res_var_df['Annotation'] = res_var_df.OrthologID.map(annotations_df.set_index('ortholog')['annotation'])
#         res_var_df = res_var_df.sort_values('ResidualVariance', ascending=False).reset_index()
#         res_var_df.to_csv('{}/high_variance_orthologs_{}.csv'.format(dir_path, clade))
        

In [None]:
# # make slabs into xarray tensors and save

# # make DataArrays for metadata
# replicate_da = xr.DataArray.from_series(sample_df.set_index('sample')['replicate'])
# samplename_da = xr.DataArray.from_series(sample_df.set_index('sample')['samplename'])
# annotation_da = xr.DataArray.from_series(annotations_df.set_index('ortholog')['annotation'])

# for genus in ['pro', 'syn']:
#     # make residual DataArray
#     data = slabs[genus]
#     all_slabs = pd.concat(data)
#     residual_da = xr.DataArray.from_series(all_slabs['residual']).fillna(0)
#     # combine residuals with metadata into xarray Dataset
#     ds = xr.Dataset(
#         dict(
#             residual=residual_da, 
#             replicate=replicate_da,
#             samplename=samplename_da, 
#             annotation=annotation_da.loc[residual_da.ortholog]
#         )
#     )
#     # save as netCDF4 file in data directory
#     ds.to_netcdf('{}/{}-res-abun.nc'.format(output_dir, genus))

# ds
