diff --git a/README.md b/README.md index f53a787..463a2a3 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ you can install it directly from the github repository. **Note: Only run this if you wish to use the development version** ```bash -pip install https://github.com/biota/sourcetracker2/archive/master.zip +pip install https://github.com/caporaso-lab/sourcetracker2/archive/master.zip ``` To test that your installation was successful, try the following command: diff --git a/sourcetracker/__init__.py b/sourcetracker/__init__.py index 2b20aec..ecc21b6 100644 --- a/sourcetracker/__init__.py +++ b/sourcetracker/__init__.py @@ -9,10 +9,9 @@ from ._compare import compare_sinks, compare_sink_metrics from ._sourcetracker import gibbs -from ._plot import plot_heatmap __version__ = '2.0.1-dev' _readme_url = "https://github.com/biota/sourcetracker2/blob/master/README.md" -__all__ = ['compare_sinks', 'compare_sink_metrics', 'gibbs', 'plot_heatmap'] +__all__ = ['compare_sinks', 'compare_sink_metrics', 'gibbs'] diff --git a/sourcetracker/_cli/gibbs.py b/sourcetracker/_cli/gibbs.py index e747b1a..cdac9b3 100755 --- a/sourcetracker/_cli/gibbs.py +++ b/sourcetracker/_cli/gibbs.py @@ -19,7 +19,7 @@ from sourcetracker._cli.cli import cli from sourcetracker._gibbs import gibbs_helper -from sourcetracker._plot import plot_heatmap +from sourcetracker._plot import ST_graphs from sourcetracker._util import parse_sample_metadata, biom_to_df # import default descriptions @@ -29,7 +29,11 @@ DESC_RAF2, DESC_RST, DESC_DRW, DESC_BRN, DESC_DLY, DESC_PFA, DESC_RPL, DESC_SNK, DESC_SRS, - DESC_SRS2, DESC_CAT) + DESC_SRS2, DESC_CAT, DESC_DIA, + DESC_LIM, DESC_STBAR, DESC_HTM, + DESC_PHTM, DESC_TTL, DESC_HCOL, + DESC_UKN, DESC_TRA, DESC_BCOL, + DESC_FLBR) # import default values from sourcetracker._gibbs_defaults import (DEFAULT_ALPH1, DEFAULT_ALPH2, @@ -120,9 +124,28 @@ help=DESC_CAT) # Stats functions for diagnostics @click.option('--diagnostics', required=False, default=False, 
is_flag=True, - show_default=True) + show_default=True, help=DESC_DIA) @click.option('--limit', required=False, default=0.05, type=click.FLOAT, - show_default=True) + show_default=True, help=DESC_LIM) +# (added options for graphical output and varying stats functions) +@click.option('--stacked_bar', required=False, default=False, is_flag=True, + show_default=True, help=DESC_STBAR) +@click.option('--heatmap', required=False, default=True, is_flag=True, + show_default=True, help=DESC_HTM) +@click.option('--paired_heatmap', required=False, default=False, is_flag=True, + show_default=True, help=DESC_PHTM) +@click.option('--title', required=False, default='Mixing Proportions', + type=click.STRING, show_default=True, help=DESC_TTL) +@click.option('--heatmap_color', required=False, default='viridis', + type=click.STRING, show_default=True, help=DESC_HCOL) +@click.option('--keep_unknowns', required=False, default=True, is_flag=True, + show_default=True, help=DESC_UKN) +@click.option('--transpose', required=False, default=False, is_flag=True, + show_default=True, help=DESC_TRA) +@click.option('--bar_color', required=False, default="", type=click.STRING, + show_default=True, help=DESC_BCOL) +@click.option('--flip_bar', required=False, default=False, is_flag=True, + show_default=True, help=DESC_FLBR) def gibbs(table_fp: Table, mapping_fp: pd.DataFrame, output_dir: str, @@ -144,7 +167,16 @@ def gibbs(table_fp: Table, sink_column_value: str, source_category_column: str, diagnostics: bool, - limit: float): + limit: float, + stacked_bar: bool, + heatmap: bool, + paired_heatmap: bool, + title: str, + heatmap_color: str, + keep_unknowns: bool, + transpose: bool, + bar_color: str, + flip_bar: bool): '''Gibb's sampler for Bayesian estimation of microbial sample sources. For details, see the project README file. 
@@ -180,11 +212,19 @@ def gibbs(table_fp: Table, mpm.to_csv(os.path.join(output_dir, 'mixing_proportions.txt'), sep='\t') mps.to_csv(os.path.join(output_dir, 'mixing_proportions_stds.txt'), sep='\t') - + # need to count number of rows here to check for equality + # add notice if not equal + color_list = bar_color.split(",") # Plot contributions. - fig, ax = plot_heatmap(mpm.T) - fig.savefig(os.path.join(output_dir, 'mixing_proportions.pdf'), dpi=300) - + graphs = ST_graphs(mpm, output_dir, title=title, color=heatmap_color) + if heatmap: + graphs.ST_heatmap(keep_unknowns=keep_unknowns) + if paired_heatmap: + graphs.ST_paired_heatmap(keep_unknowns=keep_unknowns, + normalized=transpose, transpose=transpose) + if stacked_bar: + graphs.ST_Stacked_bar(keep_unknowns=keep_unknowns, coloring=color_list, + flipped=flip_bar) if diagnostics: os.mkdir(output_dir + 'diagnostics') data = np.load('envcounts.npy', allow_pickle=True) diff --git a/sourcetracker/_gibbs_defaults.py b/sourcetracker/_gibbs_defaults.py index ab929a0..5bb250c 100644 --- a/sourcetracker/_gibbs_defaults.py +++ b/sourcetracker/_gibbs_defaults.py @@ -93,3 +93,37 @@ 'sink (or source if `--loo is passed). ' 'This feature table contains the specific ' ' of each fractional contribution.') +DESC_DIA = ('Activate diagnostics function which visualizes the ' + 'deviation of each SourceTracker run, requires' + 'at least 2 restarts. Default is False.') +DESC_LIM = ('Minimum deviation limit for display. Default value' + 'of 0.05') +DESC_STBAR = ('Activates stacked bar plot visualization.' + 'Default is False.') +DESC_HTM = ('Deactivates Heatmap plot. Default is True.') +DESC_PHTM = ('Activates Paired heatmap visualization. Paired' + 'heatmap intends to visualize pairings for sourcetracker' + 'to identify these pairings. Non random pairings are ' + 'identified by the highest result in each column. 
' + 'For each correct pairing, the resulting proportion ' + 'should be calculated against a binomial distribution ' + 'with a p value calculated against random distribution ' + 'with a correct proportion of 1/n at random.' + 'This can be used for groupings such as convergent ' + 'microbiomes or organ transfer similarities as a few' + ' examples.') +DESC_TTL = ('Title input. String format') +DESC_HCOL = ('Heatmap coloring. Coloring pattern for default and' + ' paired heatmap. Default is viridis, other options in' + ' plot.py.') +DESC_UKN = ('keep unknown sources in heatmap or stacked bar plots.' + 'IMPORTANT: Setting to False will normalize proportions back to' + '1.') +DESC_TRA = ('Transpose Heatmap plots.Flips x and y axis. Default is' + ' False.') +DESC_BCOL = ('Coloring for stacked bar plot. Default is matplotlib' + ' default. List format should be used.' + 'An example of this would be [red,green,blue] with' + ' each color in a string format.') +DESC_FLBR = ('Transpose bar plot. Flips x and y axis. Default is' + ' False.') diff --git a/sourcetracker/_plot.py b/sourcetracker/_plot.py index 90e2e11..6735d81 100644 --- a/sourcetracker/_plot.py +++ b/sourcetracker/_plot.py @@ -7,19 +7,259 @@ # # The full license is in the file LICENSE, distributed with this software. 
# ---------------------------------------------------------------------------- - +""" +Parameters +---------- +self.file string + output directory name given from input +self.mpm dataframe + mixing proportion result from gibbs +self.title string + title +self.color string + color scheme for heatmaps pulled from matplotlib +self.out_name string + output name spaces are replaced with _ +""" import seaborn as sns import matplotlib.pyplot as plt +import os + + +class ST_graphs: + def __init__(self, mpm, output_dir, + title='Mixing Proportions', color='viridis'): + self.file = output_dir + self.mpm = mpm.T + self.title = title + self.color = color + self.out_name = title.replace(" ", "_") + + def ST_heatmap(self, keep_unknowns=True, annot=True, + xlabel='Sources', ylabel='Sinks', vmax=1.0): + """ + Default Plot for Gibbs method. + Parameters + ---------- + unknowns bool + removes unknown column + annot bool + Defines visual proportions in plot + xlabel string + x label + ylabel string + y label + vmax float + determines the maximum color value of the plot + returns + -------- + none + outputs a heatmap visualization in a PNG + """ + prop = self.mpm + if keep_unknowns: + fp_suffix = "_heatmap.png" + else: + fp_suffix = "_heatmap_nounknown.png" + prop = prop.drop(['Unknown'], axis=1) + prop = prop.div(prop.sum(axis=1), axis=0) + + fig, ax = plt.subplots(figsize=((prop.shape[1] * 3 / 4)+4, + (prop.shape[0] * 3 / 4)+4)) + sns.heatmap(prop, vmin=0.0, vmax=vmax, cmap=self.color, + annot=annot, linewidths=.5, ax=ax) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(self.title) + plt.xticks(rotation=45, ha='right') + + plt.savefig(os.path.join(self.file, self.out_name + fp_suffix), + bbox_inches="tight") + + def ST_paired_heatmap(self, normalized=False, keep_unknowns=True, + transpose=False, annot=True, ylabel='Sinks', + heat_ratio=0.08): + """ + Parameters + ---------- + normalized bool + normalize each column to equal to 1 to represent the likelihood + unknowns 
bool + removes unknown column + transpose bool + transpose + annot bool + Defines visual proportions in plot + xlabel string + x label + ylabel string + y label + heat ratio float + ratio of example bar as compared to main columns. + Should be much thinner than main columns + returns + -------- + none + outputs a heatmap visualization in a PNG defined by each + individual column + + Any analysis should be done using a bin(n,x) + + Color examples: + "viridis" "icefire" "vlag" "Spectral" "mako" "magma" + "coolwarm" "rocket" "flare" "crest" + "_r" reverses all of these + + vmax and min will show the maximum and minimum for the + heat map settings. vmax=max is default and what we use here. + vmin=min is also what we use here but the standard version + will use vmin=0 in order to show the minimum possible + and vmax=1. + The reason I do not in this case is that these ranges + are not particularly helpful to distinguishing the + successful matches to each other. + + Paired heatmap should be used for assessing overlap in + paired individual microbiomes or metabolomic overlap. + Some examples of this may include organ donors and + recipients, matching animal fecal samples to track + movement patterns, overlap in mouse microbiomes + caged together, and individual identification + given a range of possible known samples. 
+ """ + prop = self.mpm + if not keep_unknowns: + prop = prop.drop(['Unknown'], axis=1) + prop = prop.div(prop.sum(axis=1), axis=0) + if normalized: + prop = prop.div(prop.sum(axis=0), axis=1) + tra = "" + if transpose: + prop = prop.T + tra = "_Transposed" + midpoint = len(prop.columns)/2 + midpoint = round(midpoint) + ratios, g, axes = [], [], [] + for i in range(len(prop.columns)): + ratios.append(1) + axes.append("ax" + str(i)) + g.append("g" + str(i)) + ratios.append(heat_ratio) + axes.append("axcb") + fig, axes = plt.subplots(1, len(axes), + gridspec_kw={'width_ratios': ratios}, + figsize=((prop.shape[1] * 3 / 4)+4, + (prop.shape[0] * 3 / 4)+4)) + for i in range(len(prop.columns)): + if i == 0: + g[i] = sns.heatmap(prop.iloc[:, i:i + 1], + vmin=0, cmap=self.color, + cbar=False, annot=annot, ax=axes[i]) + g[i].set_xlabel("") + g[i].set_ylabel(ylabel) + elif i == midpoint: + g[i] = sns.heatmap(prop.iloc[:, i:i+1], vmin=0, + cmap=self.color, cbar=False, + annot=annot, ax=axes[i]) + g[i].set_xlabel("Sources") + g[i].set_ylabel("") + g[i].set_yticks([]) + g[i].set_title(self.title) + elif i == len(prop.columns) - 1: + g[i] = sns.heatmap(prop.iloc[:, i:i + 1], vmin=0, + cmap=self.color, annot=annot, + ax=axes[i], + cbar_ax=axes[i + 1]) + g[i].set_xlabel("") + g[i].set_ylabel("") + g[i].set_yticks([]) + else: + g[i] = sns.heatmap(prop.iloc[:, i:i + 1], + vmin=0, cmap=self.color, + cbar=False, annot=annot, + ax=axes[i]) + g[i].set_xlabel("") + g[i].set_ylabel("") + g[i].set_yticks([]) + for ax in g: + tl = ax.get_xticklabels() + ax.set_xticklabels(tl, rotation=0) + tly = ax.get_yticklabels() + ax.set_yticklabels(tly, rotation=0) + if normalized: + if keep_unknowns: + add_line = tra + "_pairedheatmap_normalized.png" + else: + add_line = tra + "_pairedheatmap_nounknown_normalized.png" + else: + if keep_unknowns: + add_line = tra + "_pairedheatmap.png" + else: + add_line = tra + "_pairedheatmap_nounknowns.png" + plt.savefig(os.path.join(self.file, self.out_name + 
add_line), + bbox_inches="tight") + + def ST_Stacked_bar(self, keep_unknowns=True, x_lab="Sink", + y_lab="Source Proportion", coloring=[], flipped=False): + """ + Creates a Stacked bar plot for the user with direct png save function + Parameters + ---------- + unknowns bool + removes unknown column + xlabel string + x label + ylabel string + y label + coloring string list + string list of colors to encode the bar plot. + must have an equal number of colors as sources + flipped bool + flips x and y axis + returns + -------- + none + outputs a stacked bar visualization in a PNG + + color example list + '#1f77b4'Blue, '#ff7f0e'Orange, '#2ca02c'Green, '#d62728'Red, + '#9467bd'Purple, '#8c564b'Brown, '#e377c2'Pink, '#7f7f7f'Grey, + '#bcbd22'Gold, '#17becf'Cyan + make sure to use contrasting colors in order to better illuminate + your data; above are some example codes to use + """ + prop = self.mpm + if not keep_unknowns: + prop = prop.drop(['Unknown'], axis=1) + prop = prop.div(prop.sum(axis=1), axis=0) + if flipped: + prop = prop.T + y_lab_flip = x_lab + x_lab_flip = y_lab + y_lab = y_lab_flip + x_lab = x_lab_flip + prop = prop.div(prop.sum(axis=1), axis=0) + prop = prop.reset_index() + if len(coloring) != (prop.shape[1]-1): + coloring = [] + if coloring == []: + coloring = None + prop.plot(kind='bar', x=prop.columns[0], stacked=True, + figsize=((prop.shape[1] * 3 / 4)+4, + (prop.shape[0] * 3 / 4)+4), + color=coloring) + plt.xlabel(x_lab) + plt.ylabel(y_lab) + plt.title(self.title) + plt.autoscale() + plt.xticks(rotation=45, ha='right') + fp_suffix = ".png" + if not keep_unknowns: + fp_suffix = "_no_unknowns" + fp_suffix + if flipped: + fp_suffix = "_flipped" + fp_suffix + fp_suffix = "_stacked_bar" + fp_suffix -def plot_heatmap(mpm, cm=plt.cm.viridis, xlabel='Sources', ylabel='Sinks', - title='Mixing Proportions (as Fraction)'): - '''Make a basic mixing proportion histogram.''' - fig = plt.figure() - ax = fig.add_subplot(1, 1, 1) - sns.heatmap(mpm, vmin=0, vmax=1.0, 
cmap=cm, annot=True, linewidths=.5, - ax=ax) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) - return fig, ax + plt.savefig(os.path.join(self.file, self.out_name + fp_suffix), + bbox_inches="tight")