From 78c5a094a4e6225289aa671c5149abfba9c6baa2 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Thu, 9 May 2024 14:10:53 -0400 Subject: [PATCH 01/19] move into spt/graphs --- .../graphs}/graph_plugin_plots.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {analysis_replication/gnn_figure => spatialprofilingtoolbox/graphs}/graph_plugin_plots.py (100%) diff --git a/analysis_replication/gnn_figure/graph_plugin_plots.py b/spatialprofilingtoolbox/graphs/graph_plugin_plots.py similarity index 100% rename from analysis_replication/gnn_figure/graph_plugin_plots.py rename to spatialprofilingtoolbox/graphs/graph_plugin_plots.py From dafa78bf28379e6d3ab64fdb00f6bc7a1c59f1bc Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Thu, 9 May 2024 15:54:39 -0400 Subject: [PATCH 02/19] sketch CLI --- analysis_replication/README.md | 9 -- pyproject.toml.unversioned | 1 + .../graphs/config_reader.py | 36 +++++++- .../graphs/graph_plugin_plots.py | 92 +++++++++---------- .../scripts/plot_importance_fractions.py | 64 +++++++++++++ .../graphs/template.config | 6 ++ 6 files changed, 148 insertions(+), 60 deletions(-) create mode 100644 spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py diff --git a/analysis_replication/README.md b/analysis_replication/README.md index 52b80843..52bd2e0c 100644 --- a/analysis_replication/README.md +++ b/analysis_replication/README.md @@ -38,12 +38,3 @@ To run the figure generation script, alter the command below to reference your o ```bash python retrieve_example_plot.py dataset_directory/ ~/.spt_db.config ``` - -# GNN importance fractions figure generation - -Another figure is generated programmatically from extractions from Graph Neural Network models, provided by the API. - -```bash -cd gnn_figure/ -python graph_plugin_plots.py -``` diff --git a/pyproject.toml.unversioned b/pyproject.toml.unversioned index ed5f802c..b71c512e 100644 --- a/pyproject.toml.unversioned +++ b/pyproject.toml.unversioned @@ -211,6 +211,7 @@ packages = [ "extract.py", "finalize_graphs.py", "generate_graphs.py", + "plot_importance_fractions.py", "plot_interactives.py", "prepare_graph_creation.py", "upload_importances.py", diff --git a/spatialprofilingtoolbox/graphs/config_reader.py b/spatialprofilingtoolbox/graphs/config_reader.py index cd72cce5..70849fa7 100644 --- a/spatialprofilingtoolbox/graphs/config_reader.py +++ b/spatialprofilingtoolbox/graphs/config_reader.py @@ -7,6 +7,7 @@ EXTRACT_SECTION_NAME = 'extract' GENERATION_SECTION_NAME = 'graph-generation' UPLOAD_SECTION_NAME = 'upload-importances' +PLOT_FRACTIONS_SECTION_NAME = 'plot-importance-fractions' def _read_config_file(config_file_path: str, section: str) -> dict[str, Any]: @@ -129,7 +130,7 @@ def read_upload_config(config_file_path: str) -> tuple[ f"""Read the TOML config file and return the '{UPLOAD_SECTION_NAME}' section. For a detailed explanation of the return values, refer to the docstring of - `spatialprofilingtoolbox.graphs.scripts.upload_importances.parse_arguments()`. + `spatialprofilingtoolbox.db.importance_score_transcriber.transcribe_importance()`. """ config = _read_config_file(config_file_path, UPLOAD_SECTION_NAME) db_config_file_path: str = config["db_config_file_path"] @@ -146,3 +147,36 @@ def read_upload_config(config_file_path: str) -> tuple[ plugin_version, cohort_stratifier, ) + + +def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ + str, + str, + list[str], + list[str], + tuple[int, int], + str | None, +]: + f"""Read the TOML config file and return the '{PLOT_FRACTIONS_SECTION_NAME}' section. + + For a detailed explanation of the return values, refer to the docstring of + `spatialprofilingtoolbox.graphs.graph_plugin_plots.PlotGenerator()`. + """ + config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME) + db_config_file_path: str = config["db_config_file_path"] + study_name: str = config["study_name"] + phenotypes: list[str] = config['phenotypes'] + plugins: list[str] = config['plugins'] + figure_size_raw = tuple(config['figure_size']) + if len(figure_size_raw) != 2 or not all(isinstance(x, int) for x in figure_size_raw): + raise ValueError("figure_size must be a two-tuple of integers.") + figure_size: tuple[int, int] = figure_size_raw + orientation: str | None = config.get("orientation", None) + return ( + db_config_file_path, + study_name, + phenotypes, + plugins, + figure_size, + orientation, + ) diff --git a/spatialprofilingtoolbox/graphs/graph_plugin_plots.py b/spatialprofilingtoolbox/graphs/graph_plugin_plots.py index d0890b02..8ff96310 100644 --- a/spatialprofilingtoolbox/graphs/graph_plugin_plots.py +++ b/spatialprofilingtoolbox/graphs/graph_plugin_plots.py @@ -1,14 +1,14 @@ -from os.path import join +"""GNN importance fractions figure generation.""" + from os.path import exists from pickle import load as pickle_load from pickle import dump as pickle_dump -from argparse import ArgumentParser from typing import Literal from typing import Iterable from typing import cast from typing import Any +from typing import TYPE_CHECKING from json import loads as json_loads -import sys from glob import glob import re from enum import Enum @@ -28,9 +28,11 @@ from cattrs import structure as cattrs_structure from tqdm import tqdm -sys.path.append('../') from accessors import DataAccessor # type: ignore +if TYPE_CHECKING: + from matplotlib.figure import Figure + GNNModel = Literal['cg-gnn', 'graph-transformer'] @@ -187,9 +189,9 @@ def _restrict_rows(self, cohorts: set[int], omittable: set[str]) -> None: ( phenotype, df[df['cohort'].isin(cohorts) & ~df.index.isin(omittable)] - .reset_index() - .sort_values(['cohort', 'sample']) - .set_index('sample'), + .reset_index() + .sort_values(['cohort', 'sample']) + .set_index('sample'), ) for phenotype, df in self.get_df_phenotypes() ) @@ -456,35 +458,43 @@ def label_indicators(spec: PlotSpecification) -> LabelIndicators: return LabelIndicators(len(spec.plugins), spec.orientation) -@define class PlotGenerator: - host: str - output_directory: str - show_also: bool - current_specification: PlotSpecification | None = None - - def get_specification(self) -> PlotSpecification: - return cast(PlotSpecification, self.current_specification) + """Generate a importance fractions plot.""" - def generate_plots(self) -> None: - for specification in get_plot_specifications(): - self.current_specification = specification - self._generate_plot() + def __init__( + self, + db_config_file_path: str, + study_name: str, + phenotypes: list[str], + plugins: list[str], + figure_size: tuple[int, int], + orientation: str | None, + ) -> None: + """Instantiate the importance fractions plot generator.""" + self.db_config_file_path = db_config_file_path + self.specification = PlotSpecification( + study_name, + phenotypes, + phenotypes, + None, # TODO: Get cohorts from database + plugins, + figure_size, + 'horizontal' if (orientation is None) else orientation, + ) - def _generate_plot(self) -> None: + def generate_plot(self) -> 'Figure': self._check_viability() dfs = self._retrieve_data() self._gather_subplot_cases(dfs) - self._generate_subplots(dfs) - self._export() + return self._generate_subplots(dfs) def _check_viability(self) -> None: - if len(self.get_specification().plugins) != 2: + if len(self.specification.plugins) != 2: raise ValueError('Currently plot generation requires 2 plugins worth of run data.') def _retrieve_data(self) -> tuple[DataFrame, ...]: - dfs = PlotDataRetriever(self.host).retrieve_data(self.get_specification()) - dfs = self._transfer_cohort_labels(dfs, self.get_specification()) + dfs = PlotDataRetriever(self.host).retrieve_data(self.specification) + dfs = self._transfer_cohort_labels(dfs, self.specification) return dfs def _transfer_cohort_labels( @@ -504,32 +514,25 @@ def _gather_subplot_cases( self, dfs: tuple[DataFrame, ...], ) -> tuple[SubplotSpecification, Indicators, Iterable[tuple[DataFrame, GNNModel]]]: - subplot_specification = derive_subplot_specification(self.get_specification()) - indicators = label_indicators(self.get_specification()).get_label_subplot_indicators() - return subplot_specification, indicators, zip(dfs, self.get_specification().plugins) + subplot_specification = derive_subplot_specification(self.specification) + indicators = label_indicators(self.specification).get_label_subplot_indicators() + return subplot_specification, indicators, zip(dfs, self.specification.plugins) - def _generate_subplots(self, dfs: tuple[DataFrame, ...]) -> None: + def _generate_subplots(self, dfs: tuple[DataFrame, ...]) -> 'Figure': plt.rcParams['font.size'] = 14 norm = self._generate_normalization(dfs) subplot_specification, indicators, cases = self._gather_subplot_cases(dfs) fig, axs = plt.subplots( *subplot_specification.grid_dimensions, - figsize=self.get_specification().figure_size, + figsize=self.specification.figure_size, ) title_location = subplot_specification.title_location subplot_generator = SubplotGenerator(title_location, norm) for i, ((df, plugin), ax) in enumerate(zip(cases, axs)): subplot_generator.plot(df, ax, plugin, indicators[0][i], indicators[1][i]) - fig.suptitle(self.get_specification().study) + fig.suptitle(self.specification.study) plt.tight_layout() - - def _export(self) -> None: - if self.output_directory is not None: - plt.savefig(join(self.output_directory, - f'{sanitized_study(self.get_specification().study)}.svg'), - ) - if self.show_also: - plt.show() + return fig @staticmethod def _generate_normalization(dfs: tuple[DataFrame, ...]) -> Normalize: @@ -537,14 +540,3 @@ def _generate_normalization(dfs: tuple[DataFrame, ...]) -> Normalize: vmin = min(df.min().min() for df in dfs_values_only) vmax = max(df.max().max() for df in dfs_values_only) return Normalize(vmin=vmin, vmax=vmax) - - -if __name__ == '__main__': - parser = ArgumentParser() - add = parser.add_argument - add('host', nargs='?', type=str, default='http://oncopathtk.org/api', help='SPT API host.') - add('output_directory', nargs='?', type=str, default='.', help='Directory in which to save SVGs.') - add('--show', action='store_true', help='If set, will display figures in addition to saving.') - args = parser.parse_args() - generator = PlotGenerator(args.host, args.output_directory, args.show) - generator.generate_plots() diff --git a/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py b/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py new file mode 100644 index 00000000..b0ed4401 --- /dev/null +++ b/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py @@ -0,0 +1,64 @@ +"""GNN importance fractions figure generation + +This is generated programmatically from extractions from Graph Neural Network models. +""" + +from argparse import ArgumentParser + +import matplotlib.pyplot as plt + +from spatialprofilingtoolbox.graphs.config_reader import read_plot_importance_fractions_config +from spatialprofilingtoolbox.graphs.graph_plugin_plots import PlotGenerator + + +def parse_arguments(): + """Process command line arguments.""" + parser = ArgumentParser( + prog='spt graphs plot-importance-fractions', + description="""Generate GNN-derived importance score fractions plot.""" + ) + parser.add_argument( + '--config_path', + type=str, + help='Path to the configuration TOML file.', + required=True, + ) + parser.add_argument( + '--output_filename', + type=str, + default=None, + help='Filename (including extension) to save the plot to.', + ) + parser.add_argument( + '--show', + action='store_true', + help='If set, will display figures in addition to saving.', + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + if not args.show and (args.output_filename is None): + raise ValueError('Nothing requested of the plot, skipping.') + ( + db_config_file_path, + study_name, + phenotypes, + plugins, + figure_size, + orientation, + ) = read_plot_importance_fractions_config(args.config_path) + generator = PlotGenerator( + db_config_file_path, + study_name, + phenotypes, + plugins, + figure_size, + orientation, + ) + fig = generator.generate_plot() + if args.output_filename is not None: + plt.savefig(args.output_filename) + if args.show: + plt.show() diff --git a/spatialprofilingtoolbox/graphs/template.config b/spatialprofilingtoolbox/graphs/template.config index 915a9c85..d0ea8fac 100644 --- a/spatialprofilingtoolbox/graphs/template.config +++ b/spatialprofilingtoolbox/graphs/template.config @@ -38,3 +38,9 @@ plugin_used = cg-gnn plugin_version = None datetime_of_run = 2024-01-01 12:00:00 cohort_stratifier = None + +[plot-importance-fractions] +phenotypes = ["Tumor", ...] +plugins = ["cg-gnn", "graph-transformer"] +figure_size = [11, 8] +orientation = "horizontal" From cb2df9c2a2f5a80f07091c24fe5960d261beb3bb Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Thu, 9 May 2024 15:57:10 -0400 Subject: [PATCH 03/19] rename fraction plotter --- spatialprofilingtoolbox/graphs/config_reader.py | 2 +- .../graphs/{graph_plugin_plots.py => importance_fractions.py} | 0 .../graphs/scripts/plot_importance_fractions.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename spatialprofilingtoolbox/graphs/{graph_plugin_plots.py => importance_fractions.py} (100%) diff --git a/spatialprofilingtoolbox/graphs/config_reader.py b/spatialprofilingtoolbox/graphs/config_reader.py index 70849fa7..ff9722d5 100644 --- a/spatialprofilingtoolbox/graphs/config_reader.py +++ b/spatialprofilingtoolbox/graphs/config_reader.py @@ -160,7 +160,7 @@ def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ f"""Read the TOML config file and return the '{PLOT_FRACTIONS_SECTION_NAME}' section. For a detailed explanation of the return values, refer to the docstring of - `spatialprofilingtoolbox.graphs.graph_plugin_plots.PlotGenerator()`. + `spatialprofilingtoolbox.graphs.importance_fractions.PlotGenerator()`. """ config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME) db_config_file_path: str = config["db_config_file_path"] diff --git a/spatialprofilingtoolbox/graphs/graph_plugin_plots.py b/spatialprofilingtoolbox/graphs/importance_fractions.py similarity index 100% rename from spatialprofilingtoolbox/graphs/graph_plugin_plots.py rename to spatialprofilingtoolbox/graphs/importance_fractions.py diff --git a/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py b/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py index b0ed4401..c24e8ad4 100644 --- a/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py +++ b/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt from spatialprofilingtoolbox.graphs.config_reader import read_plot_importance_fractions_config -from spatialprofilingtoolbox.graphs.graph_plugin_plots import PlotGenerator +from spatialprofilingtoolbox.graphs.importance_fractions import PlotGenerator def parse_arguments(): From dec34b7d4c25a02c41f1b1b9a21964b916e81145 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Mon, 13 May 2024 01:41:28 -0400 Subject: [PATCH 04/19] rely on api server implementation for now --- .../graphs/config_reader.py | 15 +- .../graphs/importance_fractions.py | 231 +++++++++++++++--- .../scripts/plot_importance_fractions.py | 8 +- .../graphs/template.config | 9 +- 4 files changed, 219 insertions(+), 44 deletions(-) diff --git a/spatialprofilingtoolbox/graphs/config_reader.py b/spatialprofilingtoolbox/graphs/config_reader.py index ff9722d5..27a59632 100644 --- a/spatialprofilingtoolbox/graphs/config_reader.py +++ b/spatialprofilingtoolbox/graphs/config_reader.py @@ -153,6 +153,7 @@ def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ str, str, list[str], + list[tuple[int, str]], list[str], tuple[int, int], str | None, @@ -163,7 +164,7 @@ def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ `spatialprofilingtoolbox.graphs.importance_fractions.PlotGenerator()`. """ config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME) - db_config_file_path: str = config["db_config_file_path"] + host_name: str = config.get("host_name", "http://oncopathtk.org/api") study_name: str = config["study_name"] phenotypes: list[str] = config['phenotypes'] plugins: list[str] = config['plugins'] @@ -172,10 +173,20 @@ def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ raise ValueError("figure_size must be a two-tuple of integers.") figure_size: tuple[int, int] = figure_size_raw orientation: str | None = config.get("orientation", None) + cohorts_raw: list[dict[str, str]] = config['cohorts'] + cohorts: list[tuple[int, str]] = [] + for cohort in cohorts_raw: + try: + cohorts.append((int(cohort['index_int']), cohort['label'])) + except KeyError: + 'Each cohort must have an index_int and a label.' + except ValueError: + 'Cohort index_int must be an integer.' return ( - db_config_file_path, + host_name, study_name, phenotypes, + cohorts, plugins, figure_size, orientation, diff --git a/spatialprofilingtoolbox/graphs/importance_fractions.py b/spatialprofilingtoolbox/graphs/importance_fractions.py index 8ff96310..5f066e76 100644 --- a/spatialprofilingtoolbox/graphs/importance_fractions.py +++ b/spatialprofilingtoolbox/graphs/importance_fractions.py @@ -8,9 +8,10 @@ from typing import cast from typing import Any from typing import TYPE_CHECKING -from json import loads as json_loads -from glob import glob import re +from itertools import chain +from urllib.parse import urlencode +from requests import get as get_request # type: ignore from enum import Enum import numpy as np @@ -25,11 +26,8 @@ from matplotlib.colors import Normalize from scipy.stats import fisher_exact # type: ignore from attr import define -from cattrs import structure as cattrs_structure from tqdm import tqdm -from accessors import DataAccessor # type: ignore - if TYPE_CHECKING: from matplotlib.figure import Figure @@ -51,23 +49,12 @@ class Orientation(Enum): class PlotSpecification: study: str phenotypes: tuple[str, ...] - attribute_order: tuple[str, ...] cohorts: tuple[Cohort, ...] plugins: tuple[GNNModel, ...] figure_size: tuple[float, float] orientation: Orientation -def get_plot_specifications() -> tuple[PlotSpecification, ...]: - filenames = glob('*.json') - specifications = [] - for filename in filenames: - with open(filename, 'rt', encoding='utf-8') as file: - contents = file.read() - specifications.append(cattrs_structure(json_loads(contents), PlotSpecification)) - return tuple(specifications) - - def sanitized_study(study: str) -> str: return re.sub(' ', '_', study).lower() @@ -75,20 +62,192 @@ def sanitized_study(study: str) -> str: PhenotypeDataFrames = tuple[tuple[str, DataFrame], ...] +class Colors: + bold_magenta = '\u001b[35;1m' + reset = '\u001b[0m' + + +class ImportanceCountsAccessor: + """Convenience caller of HTTP methods for access of phenotype counts and importance scores.""" + + def __init__(self, study, host=None): + if _host is None: + raise RuntimeError('Expected host name in api_host.txt .') + host = _host + use_http = False + if re.search('^http://', host): + use_http = True + host = re.sub(r'^http://', '', host) + self.host = host + self.study = study + self.use_http = use_http + print('\n' + Colors.bold_magenta + study + Colors.reset + '\n') + self.cohorts = self._retrieve_cohorts() + self.all_cells = self._retrieve_all_cells_counts() + + def counts(self, phenotype_names): + if isinstance(phenotype_names, str): + phenotype_names = [phenotype_names] + conjunction_criteria = self._conjunction_phenotype_criteria(phenotype_names) + all_name = self.name_for_all_phenotypes(phenotype_names) + conjunction_counts_series = self._get_counts_series(conjunction_criteria, all_name) + individual_counts_series = [ + self._get_counts_series(self._phenotype_criteria(name), self._name_phenotype(name)) + for name in phenotype_names + ] + df = concat( + [self.cohorts, self.all_cells, conjunction_counts_series, *individual_counts_series], + axis=1, + ) + df.replace([np.inf, -np.inf], np.nan, inplace=True) + return df + + def name_for_all_phenotypes(self, phenotype_names): + return ' and '.join([self._name_phenotype(p) for p in phenotype_names]) + + def counts_by_signature(self, positives: list[str], negatives: list[str]): + if (not positives) and (not negatives): + raise ValueError('At least one positive or negative marker is required.') + if not positives: + positives = [''] + elif not negatives: + negatives = [''] + parts = list(chain(*[ + [(f'{keyword}_marker', channel) for channel in argument] + for keyword, argument in zip(['positive', 'negative'], [positives, negatives]) + ])) + parts = sorted(list(set(parts))) + parts.append(('study', self.study)) + query = urlencode(parts) + endpoint = 'anonymous-phenotype-counts-fast' + return self._retrieve(endpoint, query)[0] + + def _get_counts_series(self, criteria, column_name): + criteria_tuple = ( + criteria['positive_markers'], + criteria['negative_markers'], + ) + counts = self.counts_by_signature(*criteria_tuple) + df = DataFrame(counts['counts']) + mapper = {'specimen': 'sample', 'count': column_name} + return df.rename(columns=mapper).set_index('sample')[column_name] + + def _retrieve_cohorts(self): + summary, _ = self._retrieve('study-summary', urlencode([('study', self.study)])) + return DataFrame(summary['cohorts']['assignments']).set_index('sample') + + def _retrieve_all_cells_counts(self): + counts = self.counts_by_signature([''], ['']) + df = DataFrame(counts['counts']) + all_name = 'all cells' + mapper = {'specimen': 'sample', 'count': all_name} + counts_series = df.rename(columns=mapper).set_index('sample')[all_name] + return counts_series + + def _get_base(self): + protocol = 'https' + if self.host == 'localhost' or re.search('127.0.0.1', self.host) or self.use_http: + protocol = 'http' + return '://'.join((protocol, self.host)) + + def _retrieve(self, endpoint, query): + base = f'{self._get_base()}' + url = '/'.join([base, endpoint, '?' + query]) + try: + content = get_request(url) + except Exception as exception: + print(url) + raise exception + return content.json(), url + + def _phenotype_criteria(self, name): + if isinstance(name, dict): + criteria = name + keys = ['positive_markers', 'negative_markers'] + for key in keys: + if criteria[key] == []: + criteria[key] = [''] + return criteria + query = urlencode([('study', self.study), ('phenotype_symbol', name)]) + criteria, _ = self._retrieve('phenotype-criteria', query) + return criteria + + def _conjunction_phenotype_criteria(self, names): + criteria_list = [] + for name in names: + criteria = self._phenotype_criteria(name) + criteria_list.append(criteria) + return self._merge_criteria(criteria_list) + + def _merge_criteria(self, criteria_list): + keys = ['positive_markers', 'negative_markers'] + merged = { + key: sorted(list(set(list(chain(*[criteria[key] for criteria in criteria_list]))))) + for key in keys + } + for key in keys: + if merged[key] == []: + merged[key] = [''] + return merged + + def _name_phenotype(self, phenotype): + if isinstance(phenotype, dict): + return ' '.join([ + ' '.join([f'{p}{sign}' for p in phenotype[f'{keyword}_markers'] if p != '']) + for keyword, sign in zip(['positive', 'negative'], ['+', '-']) + ]).rstrip() + return str(phenotype) + + def important( + self, + phenotype_names: str | list[str], + plugin: str = 'cg-gnn', + datetime_of_run: str | None = None, + plugin_version: str | None = None, + cohort_stratifier: str | None = None, + ) -> dict[str, float]: + if isinstance(phenotype_names, str): + phenotype_names = [phenotype_names] + conjunction_criteria = self._conjunction_phenotype_criteria(phenotype_names) + parts = list(chain(*[ + [(f'{keyword}_marker', channel) for channel in argument] + for keyword, argument in zip( + ['positive', 'negative'], [ + conjunction_criteria['positive_markers'], + conjunction_criteria['negative_markers'], + ]) + ])) + parts = sorted(list(set(parts))) + parts.append(('study', self.study)) + if plugin in {'cg-gnn', 'graph-transformer'}: + parts.append(('plugin', plugin)) + else: + raise ValueError(f'Unrecognized plugin name: {plugin}') + if datetime_of_run is not None: + parts.append(('datetime_of_run', datetime_of_run)) + if plugin_version is not None: + parts.append(('plugin_version', plugin_version)) + if cohort_stratifier is not None: + parts.append(('cohort_stratifier', cohort_stratifier)) + query = urlencode(parts) + phenotype_counts, _ = self._retrieve('importance-composition', query) + return {c['specimen']: c['percentage'] for c in phenotype_counts['counts']} + + @define class ImportanceFractionAndTestRetriever: host: str study: str - access: DataAccessor | None = None + access: ImportanceCountsAccessor | None = None count_important: int = 100 df_phenotypes: PhenotypeDataFrames | None = None df_phenotypes_original: PhenotypeDataFrames | None = None def initialize(self) -> None: - self.access = DataAccessor(self.study, host=self.host) + self.access = ImportanceCountsAccessor(self.study, host=self.host) - def get_access(self) -> DataAccessor: - return cast(DataAccessor, self.access) + def get_access(self) -> ImportanceCountsAccessor: + return cast(ImportanceCountsAccessor, self.access) def get_df_phenotypes(self) -> PhenotypeDataFrames: return cast(PhenotypeDataFrames, self.df_phenotypes) @@ -254,22 +413,12 @@ def retrieve_data(self, specification: PlotSpecification) -> tuple[DataFrame, .. cohorts = set(c.index_int for c in specification.cohorts) plugins = cast(tuple[GNNModel, GNNModel], specification.plugins) phenotypes = list(specification.phenotypes) - attribute_order = self._get_attribute_order(specification) retriever = ImportanceFractionAndTestRetriever(self.host, specification.study) retriever.initialize() return tuple( - retriever.retrieve(cohorts, phenotypes, plugin)[attribute_order] for plugin in plugins + retriever.retrieve(cohorts, phenotypes, plugin)[phenotypes] for plugin in plugins ) - @staticmethod - def _get_attribute_order(specification: PlotSpecification) -> list[str]: - attribute_order = list(specification.attribute_order) - if attribute_order is None: - attribute_order = specification.phenotypes.copy() - if 'cohort' not in attribute_order: - attribute_order.append('cohort') - return attribute_order - @define class SubplotGenerator: @@ -463,23 +612,29 @@ class PlotGenerator: def __init__( self, - db_config_file_path: str, + host_name: str, study_name: str, phenotypes: list[str], + cohorts_raw: list[tuple[int, str]], plugins: list[str], figure_size: tuple[int, int], orientation: str | None, ) -> None: """Instantiate the importance fractions plot generator.""" - self.db_config_file_path = db_config_file_path + self.host = host_name + cohorts: list[Cohort] = [] + for cohort in cohorts_raw: + cohorts.append(Cohort(*cohort)) + for model in plugins: + if model != 'cg-gnn' and model != 'graph-transformer': + raise ValueError(f'Unrecognized plugin name: {model}') self.specification = PlotSpecification( study_name, - phenotypes, - phenotypes, - None, # TODO: Get cohorts from database - plugins, + tuple(phenotypes), + tuple(cohorts), + cast(tuple[GNNModel], tuple(plugins)), figure_size, - 'horizontal' if (orientation is None) else orientation, + Orientation.HORIZONTAL if (orientation is None) else Orientation[orientation.upper()], ) def generate_plot(self) -> 'Figure': diff --git a/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py b/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py index c24e8ad4..e8145c9b 100644 --- a/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py +++ b/spatialprofilingtoolbox/graphs/scripts/plot_importance_fractions.py @@ -27,7 +27,7 @@ def parse_arguments(): '--output_filename', type=str, default=None, - help='Filename (including extension) to save the plot to.', + help='Filename to save the plot to. (Plot file type is chosen based on the extension.)', ) parser.add_argument( '--show', @@ -42,17 +42,19 @@ def parse_arguments(): if not args.show and (args.output_filename is None): raise ValueError('Nothing requested of the plot, skipping.') ( - db_config_file_path, + host_name, study_name, phenotypes, + cohorts, plugins, figure_size, orientation, ) = read_plot_importance_fractions_config(args.config_path) generator = PlotGenerator( - db_config_file_path, + host_name, study_name, phenotypes, + cohorts, plugins, figure_size, orientation, diff --git a/spatialprofilingtoolbox/graphs/template.config b/spatialprofilingtoolbox/graphs/template.config index d0ea8fac..090a08e3 100644 --- a/spatialprofilingtoolbox/graphs/template.config +++ b/spatialprofilingtoolbox/graphs/template.config @@ -40,7 +40,14 @@ datetime_of_run = 2024-01-01 12:00:00 cohort_stratifier = None [plot-importance-fractions] +host_name = "http://oncopathtk.org/api" phenotypes = ["Tumor", ...] plugins = ["cg-gnn", "graph-transformer"] figure_size = [11, 8] -orientation = "horizontal" +orientation = horizontal +[[plot-importance-fractions.cohorts]] +index_int = 1 +label = Non-responder +[[plot-importance-fractions.cohorts]] +index_int = 3 +label = Responder From 9bc71762a291696bb1db12d270305f5711497860 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Mon, 13 May 2024 17:49:10 -0400 Subject: [PATCH 05/19] accommodate configreader's stripped down toml support --- docs/maintenance.md | 4 ++- environment.yml | 9 ++++++ .../graphs/config_reader.py | 28 +++++++++++++------ .../graphs/importance_fractions.py | 8 ++---- .../graphs/template.config | 20 ++++++------- 5 files changed, 44 insertions(+), 25 deletions(-) create mode 100644 environment.yml diff --git a/docs/maintenance.md b/docs/maintenance.md index ff7a85c8..3850d8b7 100644 --- a/docs/maintenance.md +++ b/docs/maintenance.md @@ -19,7 +19,9 @@ The modules in this repository are built, tested, and deployed using `make` and | [Docker Engine](https://docs.docker.com/engine/install/) | 20.10.17 | | [Docker Compose](https://docs.docker.com/compose/install/) | 2.10.2 | | [bash](https://www.gnu.org/software/bash/) | >= 4 | -| [python](https://www.python.org/downloads/) | >=3.7 | +| [python](https://www.python.org/downloads/) | >=3.7 <3.12 | +| [postgresql](https://www.postgresql.org/download/) | 13.4 | +| [toml](https://pypi.org/project/toml/) | 0.10.2 | A typical development workflow looks like: diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..9c82fcb9 --- /dev/null +++ b/environment.yml @@ -0,0 +1,9 @@ +name: spt +channels: + - conda-forge +dependencies: + - python=3.11 + - toml + - make + - bash + - postgresql diff --git a/spatialprofilingtoolbox/graphs/config_reader.py b/spatialprofilingtoolbox/graphs/config_reader.py index 27a59632..b0a9d2af 100644 --- a/spatialprofilingtoolbox/graphs/config_reader.py +++ b/spatialprofilingtoolbox/graphs/config_reader.py @@ -17,8 +17,12 @@ def _read_config_file(config_file_path: str, section: str) -> dict[str, Any]: dict(config_file[GENERAL_SECTION_NAME]) if (GENERAL_SECTION_NAME in config_file) else {} if section in config_file: config.update(dict(config_file[section])) + for sec in config_file.sections(): + if sec.startswith(section + '.'): + sub_section = sec.split('.')[1] + config[sub_section] = dict(config_file[sec]) for key, value in config.items(): - if value.lower() in {'none', 'null', ''}: + if isinstance(value, str) and value.lower() in {'none', 'null', ''}: config[key] = None return config @@ -166,22 +170,28 @@ def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME) host_name: str = config.get("host_name", "http://oncopathtk.org/api") study_name: str = config["study_name"] - phenotypes: list[str] = config['phenotypes'] - plugins: list[str] = config['plugins'] - figure_size_raw = tuple(config['figure_size']) - if len(figure_size_raw) != 2 or not all(isinstance(x, int) for x in figure_size_raw): - raise ValueError("figure_size must be a two-tuple of integers.") - figure_size: tuple[int, int] = figure_size_raw + phenotypes: list[str] = config['phenotypes'].split(', ') + plugins: list[str] = config['plugins'].split(', ') + try: + figure_size: tuple[int, int] = tuple(map(int, config['figure_size'].split(', '))) + except ValueError as e: + raise ValueError("figure_size must be a two-tuple of integers.") from e + assert len(figure_size) == 2, "figure_size must be a two-tuple of integers." orientation: str | None = config.get("orientation", None) - cohorts_raw: list[dict[str, str]] = config['cohorts'] + cohorts: list[tuple[int, str]] = [] - for cohort in cohorts_raw: + i_cohort: int = 0 + cohort_section_name: str = f'cohort0' + while cohort_section_name in config: + cohort = config[cohort_section_name] try: cohorts.append((int(cohort['index_int']), cohort['label'])) except KeyError: 'Each cohort must have an index_int and a label.' except ValueError: 'Cohort index_int must be an integer.' + i_cohort += 1 + cohort_section_name = f'cohort{i_cohort}' return ( host_name, study_name, diff --git a/spatialprofilingtoolbox/graphs/importance_fractions.py b/spatialprofilingtoolbox/graphs/importance_fractions.py index 5f066e76..dca7803e 100644 --- a/spatialprofilingtoolbox/graphs/importance_fractions.py +++ b/spatialprofilingtoolbox/graphs/importance_fractions.py @@ -70,10 +70,7 @@ class Colors: class ImportanceCountsAccessor: """Convenience caller of HTTP methods for access of phenotype counts and importance scores.""" - def __init__(self, study, host=None): - if _host is None: - raise RuntimeError('Expected host name in api_host.txt .') - host = _host + def __init__(self, study, host): use_http = False if re.search('^http://', host): use_http = True @@ -413,10 +410,11 @@ def retrieve_data(self, specification: PlotSpecification) -> tuple[DataFrame, .. cohorts = set(c.index_int for c in specification.cohorts) plugins = cast(tuple[GNNModel, GNNModel], specification.plugins) phenotypes = list(specification.phenotypes) + attribute_order = phenotypes + ['cohort'] retriever = ImportanceFractionAndTestRetriever(self.host, specification.study) retriever.initialize() return tuple( - retriever.retrieve(cohorts, phenotypes, plugin)[phenotypes] for plugin in plugins + retriever.retrieve(cohorts, phenotypes, plugin)[attribute_order] for plugin in plugins ) diff --git a/spatialprofilingtoolbox/graphs/template.config b/spatialprofilingtoolbox/graphs/template.config index 090a08e3..bf1d56e0 100644 --- a/spatialprofilingtoolbox/graphs/template.config +++ b/spatialprofilingtoolbox/graphs/template.config @@ -40,14 +40,14 @@ datetime_of_run = 2024-01-01 12:00:00 cohort_stratifier = None [plot-importance-fractions] -host_name = "http://oncopathtk.org/api" -phenotypes = ["Tumor", ...] -plugins = ["cg-gnn", "graph-transformer"] -figure_size = [11, 8] +host_name = http://oncopathtk.org/api +phenotypes = Tumor, ... +plugins = cg-gnn, graph-transformer +figure_size = x_width, y_width orientation = horizontal -[[plot-importance-fractions.cohorts]] -index_int = 1 -label = Non-responder -[[plot-importance-fractions.cohorts]] -index_int = 3 -label = Responder +[plot-importance-fractions.cohort0] +index_int = index_in_database +label = how you want it to be named in the plot +[plot-importance-fractions.cohort1] +index_int = index_in_database +label = how you want it to be named in the plot From b6e446d5fe6226ad80aa258e3ecae720d0dc3875 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Mon, 13 May 2024 19:39:36 -0400 Subject: [PATCH 06/19] rename db upload_sync --- build/build_scripts/import_test_dataset1.sh | 2 +- pyproject.toml.unversioned | 2 +- .../scripts/{upload_sync_findings.py => upload_sync_small.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename spatialprofilingtoolbox/db/scripts/{upload_sync_findings.py => upload_sync_small.py} (100%) diff --git a/build/build_scripts/import_test_dataset1.sh b/build/build_scripts/import_test_dataset1.sh index f57bc51d..5a3eeb0f 100755 --- a/build/build_scripts/import_test_dataset1.sh +++ b/build/build_scripts/import_test_dataset1.sh @@ -11,7 +11,7 @@ rm -f .nextflow.log*; rm -rf .nextflow/; rm -f configure.sh; rm -f run.sh; rm -f spt graphs upload-importances --config_path=build/build_scripts/.graph.config --importances_csv_path=test/test_data/gnn_importances/1.csv -spt db upload-sync-findings --database-config-file=build/db/.spt_db.config.local test/test_data/findings.json +spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local test/test_data/findings.json spt db status --database-config-file build/db/.spt_db.config.local > table_counts.txt diff build/build_scripts/expected_table_counts.txt table_counts.txt diff --git a/pyproject.toml.unversioned b/pyproject.toml.unversioned index b71c512e..897a17b6 100644 --- a/pyproject.toml.unversioned +++ b/pyproject.toml.unversioned @@ -190,7 +190,7 @@ packages = [ "drop.py", "drop_ondemand_computations.py", "delete_feature.py", - "upload_sync_findings.py", + "upload_sync_small.py", "collection.py", ] "spatialprofilingtoolbox.db.data_model" = [ diff --git a/spatialprofilingtoolbox/db/scripts/upload_sync_findings.py b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py similarity index 100% rename from spatialprofilingtoolbox/db/scripts/upload_sync_findings.py rename to spatialprofilingtoolbox/db/scripts/upload_sync_small.py From 3e596d508839db297e4730a052024a80cd6ec83a Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Mon, 13 May 2024 20:19:04 -0400 Subject: [PATCH 07/19] add gnn plot api endpoint --- build/build_scripts/import_test_dataset1.sh | 2 +- spatialprofilingtoolbox/apiserver/app/main.py | 60 ++++++++++++++++ spatialprofilingtoolbox/db/accessors/study.py | 8 ++- .../db/database_connection.py | 1 + spatialprofilingtoolbox/db/querying.py | 4 ++ .../db/scripts/upload_sync_small.py | 68 +++++++++++-------- 6 files changed, 114 insertions(+), 29 deletions(-) diff --git a/build/build_scripts/import_test_dataset1.sh b/build/build_scripts/import_test_dataset1.sh index 5a3eeb0f..a0229714 100755 --- a/build/build_scripts/import_test_dataset1.sh +++ b/build/build_scripts/import_test_dataset1.sh @@ -11,7 +11,7 @@ rm -f .nextflow.log*; rm -rf .nextflow/; rm -f configure.sh; rm -f run.sh; rm -f spt graphs upload-importances --config_path=build/build_scripts/.graph.config --importances_csv_path=test/test_data/gnn_importances/1.csv -spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local test/test_data/findings.json +spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local findings test/test_data/findings.json spt db status --database-config-file build/db/.spt_db.config.local > table_counts.txt diff build/build_scripts/expected_table_counts.txt table_counts.txt diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index 6d7d7140..cfccde23 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -12,6 +12,7 @@ from fastapi.responses import StreamingResponse from fastapi import Query from fastapi import HTTPException +import matplotlib.pyplot as plt import secure @@ -42,6 +43,8 @@ ValidChannelListNegatives2, ValidFeatureClass, ) +from spatialprofilingtoolbox.graphs.importance_fractions import PlotGenerator + VERSION = '0.23.0' TITLE = 'Single cell studies data API' @@ -375,3 +378,60 @@ async def get_plot_high_resolution( def streaming_iteration(): yield from input_buffer return StreamingResponse(streaming_iteration(), media_type="image/png") + + +@app.get("/importance-fraction-plot/") +async def importance_fraction_plot( + study: ValidStudy, + img_format: str = 'svs', +) -> StreamingResponse: + """Return a plot of the fraction of important cells expressing a given phenotype.""" + APPROVED_FORMATS = {'png', 'svs'} + if img_format not in APPROVED_FORMATS: + raise ValueError(f'Image format "{img_format}" not supported.') + + settings: list[str] = cast(list[str], query().get_study_gnn_plot_configurations(study)) + ( + hostname, + phenotypes, + cohorts, + plugins, + figure_size, + orientation, + ) = parse_gnn_plot_settings(settings) + + plot = PlotGenerator( + hostname, + study, + phenotypes, + cohorts, + plugins, + figure_size, + orientation, + ).generate_plot() + plt.figure(plot.number) + buf = BytesIO() + plt.savefig(buf, format=img_format) + buf.seek(0) + return StreamingResponse(buf, media_type=f"image/{img_format}") + + +def parse_gnn_plot_settings(settings: list[str]) -> tuple[ + str, + list[str], + list[tuple[int, str]], + list[str], + tuple[int, int], + str, +]: + hostname = settings[0] + phenotypes = settings[1].split(', ') + plugins = settings[2].split(', ') + figure_size = tuple(map(int, settings[3].split(', '))) + assert len(figure_size) == 2 + orientation = settings[4] + cohorts: list[tuple[int, str]] = [] + for cohort in settings[5:]: + count, name = cohort.split(', ') + cohorts.append((int(count), name)) + return hostname, phenotypes, cohorts, plugins, figure_size, orientation diff --git a/spatialprofilingtoolbox/db/accessors/study.py b/spatialprofilingtoolbox/db/accessors/study.py index e530038d..ba7556a8 100644 --- a/spatialprofilingtoolbox/db/accessors/study.py +++ b/spatialprofilingtoolbox/db/accessors/study.py @@ -83,7 +83,13 @@ def get_available_gnn(self, study: str) -> AvailableGNN: return AvailableGNN(plugins=tuple(specifier for (specifier, ) in rows)) def get_study_findings(self) -> list[str]: - self.cursor.execute('SELECT txt FROM findings ORDER BY id;') + return self._get_study_small_artifacts('findings') + + def get_study_gnn_plot_configurations(self) -> list[str]: + return self._get_study_small_artifacts('gnn-plot-configurations') + + def _get_study_small_artifacts(self, name: str) -> list[str]: + self.cursor.execute(f'SELECT txt FROM {name} ORDER BY id;') return [row[0] for row in self.cursor.fetchall()] @staticmethod diff --git a/spatialprofilingtoolbox/db/database_connection.py b/spatialprofilingtoolbox/db/database_connection.py index 7e0a98b3..32bc274d 100644 --- a/spatialprofilingtoolbox/db/database_connection.py +++ b/spatialprofilingtoolbox/db/database_connection.py @@ -266,6 +266,7 @@ class (QueryCursor) newly provides on each invocation. get_sample_names: Callable get_available_gnn: Callable get_study_findings: Callable + get_study_gnn_plot_configurations: Callable is_public_collection: Callable def __init__(self, query_handler: Type): diff --git a/spatialprofilingtoolbox/db/querying.py b/spatialprofilingtoolbox/db/querying.py index cd886f2e..54f4f456 100644 --- a/spatialprofilingtoolbox/db/querying.py +++ b/spatialprofilingtoolbox/db/querying.py @@ -62,6 +62,10 @@ def get_available_gnn(cls, cursor, study: str) -> AvailableGNN: def get_study_findings(cls, cursor, study: str) -> list[str]: return StudyAccess(cursor).get_study_findings() + @classmethod + def get_study_gnn_plot_configurations(cls, cursor, study: str) -> list[str]: + return StudyAccess(cursor).get_study_gnn_plot_configurations() + @classmethod def get_composite_phenotype_identifiers(cls, cursor) -> tuple[str, ...]: return sort(PhenotypesAccess(cursor).get_composite_phenotype_identifiers()) diff --git a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py index 92909107..76c90ca9 100644 --- a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py +++ b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py @@ -1,4 +1,4 @@ -"""CLI utility to drop one feature (values, specification, specifiers) from the database""" +"""Synchronize a small data artifact with the database.""" import argparse from json import loads as json_loads @@ -8,64 +8,78 @@ from spatialprofilingtoolbox.workflow.common.cli_arguments import add_argument from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger -logger = colorized_logger('upload_sync_findings') +logger = colorized_logger('upload_sync_small') + +APPROVED_NAMES = ('findings', 'gnn-plot-configurations') def parse_args(): parser = argparse.ArgumentParser( - prog='spt db upload-sync-findings', - description='Synchronize (upload or modify) study "findings" with database.' + prog='spt db upload-sync-small', + description='Synchronize small lists of strings for each study with the database.' + ) + parser.add_argument( + 'name', + help='The name of the table of strings to be synchronized.', ) parser.add_argument( - 'findings_file', - help='The JSON file containing a list of findings for each study.', + 'file', + help='The JSON file containing a list of strings for each study.', ) add_argument(parser, 'database config') return parser.parse_args() -def _create_table() -> str: - return 'CREATE TABLE IF NOT EXISTS findings (id SERIAL PRIMARY KEY, txt TEXT);' +def _create_table_query(name: str) -> str: + return f'CREATE TABLE IF NOT EXISTS {name} (id SERIAL PRIMARY KEY, txt TEXT);' -def _sync_findings(cursor, findings: tuple[str, ...]) -> bool: - cursor.execute(_create_table()) - cursor.execute('SELECT id, txt FROM findings ORDER BY id;') +def _sync_data(cursor, name: str, data: tuple[str, ...]) -> bool: + cursor.execute(_create_table_query(name)) + cursor.execute(f'SELECT id, txt FROM {name} ORDER BY id;') rows = tuple(cursor.fetchall()) - if tuple(text for _, text in rows) == findings: + if tuple(text for _, text in rows) == data: return True - cursor.execute('DELETE FROM findings;') - for finding in findings: - cursor.execute('INSERT INTO findings(txt) VALUES (%s);', (finding,)) + cursor.execute(f'DELETE FROM {data};') + for datum in data: + cursor.execute(f'INSERT INTO {data}(txt) VALUES (%s);', (datum,)) return False -def _upload_sync_findings_study( +def _upload_sync_study( study: str, - findings: list[str], + name: str, + data: list[str], database_config_file: str, ) -> None: with DBCursor(database_config_file=database_config_file, study=study) as cursor: - already_synced = _sync_findings(cursor, tuple(findings)) + already_synced = _sync_data(cursor, name, tuple(data)) if already_synced: - logger.info(f'Findings for "{study}" are already up-to-date.') + logger.info(f'Data for "{study}" are already up-to-date.') else: - logger.info(f'Findings for "{study}" were synced.') + logger.info(f'Data for "{study}" were synced.') -def upload_sync_findings(findings: dict[str, list[str]], database_config_file: str) -> None: - for study, study_findings in findings.items(): - _upload_sync_findings_study(study, study_findings, database_config_file) +def upload_sync( + name: str, + data_per_study: dict[str, list[str]], + database_config_file: str, +) -> None: + for study, study_data in data_per_study.items(): + _upload_sync_study(study, name, study_data, database_config_file) def main(): args = parse_args() + if args.name not in APPROVED_NAMES: + logger.error(f'{args.name} is not an approved table name.') + return database_config_file = get_and_validate_database_config(args) - with open(args.findings_file, 'rt', encoding='utf-8') as file: + with open(args.file, 'rt', encoding='utf-8') as file: contents = file.read() - findings = json_loads(contents) - upload_sync_findings(findings, database_config_file) + to_sync = json_loads(contents) + upload_sync(args.name, to_sync, database_config_file) -if __name__=='__main__': +if __name__ == '__main__': main() From fbedc786a3e3caf77c0cbab294e4e40c1cb63c8f Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Tue, 14 May 2024 12:29:17 -0400 Subject: [PATCH 08/19] hotfix typos, missing import --- build/apiserver/Dockerfile | 1 + pyproject.toml.unversioned | 1 + spatialprofilingtoolbox/apiserver/app/main.py | 4 ++-- spatialprofilingtoolbox/db/scripts/upload_sync_small.py | 4 ++-- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/build/apiserver/Dockerfile b/build/apiserver/Dockerfile index d08ae904..f319db51 100644 --- a/build/apiserver/Dockerfile +++ b/build/apiserver/Dockerfile @@ -16,6 +16,7 @@ RUN python -m pip install scikit-learn==1.2.2 RUN python -m pip install Pillow==9.5.0 RUN python -m pip install pydantic==2.0.2 RUN python -m pip install secure==0.3.0 +RUN python -m pip install matplotlib==3.7.1 ARG version ARG service_name ARG WHEEL_FILENAME diff --git a/pyproject.toml.unversioned b/pyproject.toml.unversioned index 897a17b6..c1d2c138 100644 --- a/pyproject.toml.unversioned +++ b/pyproject.toml.unversioned @@ -31,6 +31,7 @@ repository = "https://github.com/nadeemlab/SPT" [project.optional-dependencies] apiserver = [ + "matplotlib==3.7.1", "fastapi==0.100.0", "uvicorn>=0.15.0,<0.16.0", "pandas==2.0.2", diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index cfccde23..b33d2174 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -383,10 +383,10 @@ def streaming_iteration(): @app.get("/importance-fraction-plot/") async def importance_fraction_plot( study: ValidStudy, - img_format: str = 'svs', + img_format: str = 'svg', ) -> StreamingResponse: """Return a plot of the fraction of important cells expressing a given phenotype.""" - APPROVED_FORMATS = {'png', 'svs'} + APPROVED_FORMATS = {'png', 'svg'} if img_format not in APPROVED_FORMATS: raise ValueError(f'Image format "{img_format}" not supported.') diff --git a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py index 76c90ca9..1b8a1ed5 100644 --- a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py +++ b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py @@ -40,9 +40,9 @@ def _sync_data(cursor, name: str, data: tuple[str, ...]) -> bool: rows = tuple(cursor.fetchall()) if tuple(text for _, text in rows) == data: return True - cursor.execute(f'DELETE FROM {data};') + cursor.execute(f'DELETE FROM {name};') for datum in data: - cursor.execute(f'INSERT INTO {data}(txt) VALUES (%s);', (datum,)) + cursor.execute(f'INSERT INTO {name}(txt) VALUES (%s);', (datum,)) return False From e91859d98721b4fe2a9414ffc720bfc29f49c983 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Tue, 14 May 2024 13:49:10 -0400 Subject: [PATCH 09/19] add gnn plot config table --- build/build_scripts/import_test_dataset1.sh | 1 + spatialprofilingtoolbox/db/accessors/study.py | 6 +++--- .../db/scripts/upload_sync_small.py | 2 +- test/test_data/gnn_plot.json | 10 ++++++++++ 4 files changed, 15 insertions(+), 4 deletions(-) create mode 100644 test/test_data/gnn_plot.json diff --git a/build/build_scripts/import_test_dataset1.sh b/build/build_scripts/import_test_dataset1.sh index a0229714..bc49cfa1 100755 --- a/build/build_scripts/import_test_dataset1.sh +++ b/build/build_scripts/import_test_dataset1.sh @@ -12,6 +12,7 @@ rm -f .nextflow.log*; rm -rf .nextflow/; rm -f configure.sh; rm -f run.sh; rm -f spt graphs upload-importances --config_path=build/build_scripts/.graph.config --importances_csv_path=test/test_data/gnn_importances/1.csv spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local findings test/test_data/findings.json +spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local gnn_plot_configurations test/test_data/gnn_plot.json spt db status --database-config-file build/db/.spt_db.config.local > table_counts.txt diff build/build_scripts/expected_table_counts.txt table_counts.txt diff --git a/spatialprofilingtoolbox/db/accessors/study.py b/spatialprofilingtoolbox/db/accessors/study.py index ba7556a8..e76c5e43 100644 --- a/spatialprofilingtoolbox/db/accessors/study.py +++ b/spatialprofilingtoolbox/db/accessors/study.py @@ -84,10 +84,10 @@ def get_available_gnn(self, study: str) -> AvailableGNN: def get_study_findings(self) -> list[str]: return self._get_study_small_artifacts('findings') - + def get_study_gnn_plot_configurations(self) -> list[str]: - return self._get_study_small_artifacts('gnn-plot-configurations') - + return self._get_study_small_artifacts('gnn_plot_configurations') + def _get_study_small_artifacts(self, name: str) -> list[str]: self.cursor.execute(f'SELECT txt FROM {name} ORDER BY id;') return [row[0] for row in self.cursor.fetchall()] diff --git a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py index 1b8a1ed5..f01d66b0 100644 --- a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py +++ b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py @@ -10,7 +10,7 @@ from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger logger = colorized_logger('upload_sync_small') -APPROVED_NAMES = ('findings', 'gnn-plot-configurations') +APPROVED_NAMES = ('findings', 'gnn_plot_configurations') def parse_args(): diff --git a/test/test_data/gnn_plot.json b/test/test_data/gnn_plot.json new file mode 100644 index 00000000..c5e75eb5 --- /dev/null +++ b/test/test_data/gnn_plot.json @@ -0,0 +1,10 @@ +{ + "Melanoma intralesional IL2": [ + "Tumor, Adipocyte or Langerhans cell, Natural killer cell, CD4+ T cell, Nerve, B cell, CD4+/CD8+ T cell, CD4+ regulatory T cell, CD8+ natural killer T cell, CD8+ regulatory T cell, CD8+ T cell, Double negative regulatory T cell, T cell/null phenotype, Natural killer T cell, CD4+ natural killer T cell", + "cg-gnn, graph-transformer", + "11, 8", + "horizontal", + "1, Non-responder", + "3, Responder" + ] +} \ No newline at end of file From 6e426d3d10515d57686f7dc935c8638c56290919 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Wed, 15 May 2024 12:03:16 -0400 Subject: [PATCH 10/19] finish apiserver endpoint --- build/apiserver/Makefile | 2 +- build/build_scripts/.graph_transformer.config | 7 + build/build_scripts/expected_table_counts.txt | 6 +- build/build_scripts/import_test_dataset1.sh | 1 + spatialprofilingtoolbox/apiserver/app/main.py | 22 +- .../graphs/importance_fractions.py | 258 +++++++++++++----- .../test_retrieval_of_gnn_plot.sh | 15 + test/apiserver/unit_tests/record_counts1.txt | 6 +- test/test_data/gnn_plot.json | 2 +- 9 files changed, 226 insertions(+), 93 deletions(-) create mode 100644 build/build_scripts/.graph_transformer.config create mode 100644 test/apiserver/module_tests/test_retrieval_of_gnn_plot.sh diff --git a/build/apiserver/Makefile b/build/apiserver/Makefile index 46613135..e69d8623 100644 --- a/build/apiserver/Makefile +++ b/build/apiserver/Makefile @@ -81,4 +81,4 @@ ${TESTS}: setup-testing clean: >@rm -f ${WHEEL_FILENAME} >@rm -f status_code ->@for f in dlogs.db.txt dlogs.api.txt dlogs.od.txt ../../${TEST_LOCATION}\/${MODULE_NAME}/_proximity.json ../../${TEST_LOCATION}\/${MODULE_NAME}/_squidpy.json; do rm -f $$f; done; +>@for f in dlogs.db.txt dlogs.api.txt dlogs.od.txt ../../${TEST_LOCATION}\/${MODULE_NAME}/_proximity.json ../../${TEST_LOCATION}\/${MODULE_NAME}/_squidpy.json ../../${TEST_LOCATION}\/${MODULE_NAME}/_gnn.svg ; do rm -f $$f; done; diff --git a/build/build_scripts/.graph_transformer.config b/build/build_scripts/.graph_transformer.config new file mode 100644 index 00000000..a8732738 --- /dev/null +++ b/build/build_scripts/.graph_transformer.config @@ -0,0 +1,7 @@ +[general] +db_config_file_path = build/db/.spt_db.config.local +study_name = Melanoma intralesional IL2 + +[upload-importances] +plugin_used = graph-transformer +datetime_of_run = 2023-10-02 10:46 AM diff --git a/build/build_scripts/expected_table_counts.txt b/build/build_scripts/expected_table_counts.txt index b2d348d7..2ba8de2c 100644 --- a/build/build_scripts/expected_table_counts.txt +++ b/build/build_scripts/expected_table_counts.txt @@ -9,15 +9,15 @@ diagnosis 2 diagnostic_selection_criterion 4 expression_quantification 18200 - feature_specification 1 - feature_specifier 4 + feature_specification 2 + feature_specifier 8 histological_structure 700 histological_structure_identification 700 histology_assessment_process 7 intervention 2 plane_coordinates_reference_system 0 publication 2 - quantitative_feature_value 700 + quantitative_feature_value 1400 research_professional 32 shape_file 700 specimen_collection_process 7 diff --git a/build/build_scripts/import_test_dataset1.sh b/build/build_scripts/import_test_dataset1.sh index bc49cfa1..4b5fc9c6 100755 --- a/build/build_scripts/import_test_dataset1.sh +++ b/build/build_scripts/import_test_dataset1.sh @@ -10,6 +10,7 @@ nextflow run . rm -f .nextflow.log*; rm -rf .nextflow/; rm -f configure.sh; rm -f run.sh; rm -f main.nf; rm -f nextflow.config; rm -rf work/; rm -rf results/ spt graphs upload-importances --config_path=build/build_scripts/.graph.config --importances_csv_path=test/test_data/gnn_importances/1.csv +spt graphs upload-importances --config_path=build/build_scripts/.graph_transformer.config --importances_csv_path=test/test_data/gnn_importances/1.csv spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local findings test/test_data/findings.json spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local gnn_plot_configurations test/test_data/gnn_plot.json diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index b33d2174..1dc84b6a 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -392,7 +392,6 @@ async def importance_fraction_plot( settings: list[str] = cast(list[str], query().get_study_gnn_plot_configurations(study)) ( - hostname, phenotypes, cohorts, plugins, @@ -401,7 +400,12 @@ async def importance_fraction_plot( ) = parse_gnn_plot_settings(settings) plot = PlotGenerator( - hostname, + ( + get_anonymous_phenotype_counts_fast, + get_study_summary, + get_phenotype_criteria, + importance_composition, + ), study, phenotypes, cohorts, @@ -417,21 +421,19 @@ async def importance_fraction_plot( def parse_gnn_plot_settings(settings: list[str]) -> tuple[ - str, list[str], list[tuple[int, str]], list[str], tuple[int, int], str, ]: - hostname = settings[0] - phenotypes = settings[1].split(', ') - plugins = settings[2].split(', ') - figure_size = tuple(map(int, settings[3].split(', '))) + phenotypes = settings[0].split(', ') + plugins = settings[1].split(', ') + figure_size = tuple(map(int, settings[2].split(', '))) assert len(figure_size) == 2 - orientation = settings[4] + orientation = settings[3] cohorts: list[tuple[int, str]] = [] - for cohort in settings[5:]: + for cohort in settings[4:]: count, name = cohort.split(', ') cohorts.append((int(count), name)) - return hostname, phenotypes, cohorts, plugins, figure_size, orientation + return phenotypes, cohorts, plugins, figure_size, orientation diff --git a/spatialprofilingtoolbox/graphs/importance_fractions.py b/spatialprofilingtoolbox/graphs/importance_fractions.py index dca7803e..aa642075 100644 --- a/spatialprofilingtoolbox/graphs/importance_fractions.py +++ b/spatialprofilingtoolbox/graphs/importance_fractions.py @@ -7,12 +7,15 @@ from typing import Iterable from typing import cast from typing import Any +from typing import Callable +from typing import Coroutine from typing import TYPE_CHECKING import re from itertools import chain from urllib.parse import urlencode -from requests import get as get_request # type: ignore from enum import Enum +from concurrent.futures import ThreadPoolExecutor +from asyncio import new_event_loop import numpy as np from numpy.typing import NDArray @@ -26,7 +29,6 @@ from matplotlib.colors import Normalize from scipy.stats import fisher_exact # type: ignore from attr import define -from tqdm import tqdm if TYPE_CHECKING: from matplotlib.figure import Figure @@ -61,6 +63,19 @@ def sanitized_study(study: str) -> str: PhenotypeDataFrames = tuple[tuple[str, DataFrame], ...] +APIServerCallables = tuple[ + Callable[[list[str], list[str], str], Any], + Callable[[str], Any], + Callable[[str, str], Any], + Callable[[ + str, + list[str], + list[str], + str, + dict[str, Any], + ], Any], +] + class Colors: bold_magenta = '\u001b[35;1m' @@ -70,19 +85,42 @@ class Colors: class ImportanceCountsAccessor: """Convenience caller of HTTP methods for access of phenotype counts and importance scores.""" - def __init__(self, study, host): - use_http = False - if re.search('^http://', host): - use_http = True - host = re.sub(r'^http://', '', host) - self.host = host + def __init__( + self, + study: str, + what_to_query: str | APIServerCallables, + ) -> None: + self.use_http = False + self.host: str | None = None + self.query_anonymous_phenotype_counts_fast: Callable[[ + list[str], list[str], str], Any] | None = None + self.query_study_summary: Callable[[str], Any] | None = None + self.query_phenotype_criteria: Callable[[str, str], Any] | None = None + self.query_importance_composition: Callable[[ + str, + list[str], + list[str], + str, + dict[str, Any], + ], Any] | None = None + if isinstance(what_to_query, str): + if re.search('^http://', what_to_query): + self.use_http = True + what_to_query = re.sub(r'^http://', '', what_to_query) + self.host = what_to_query + else: + ( + self.query_anonymous_phenotype_counts_fast, + self.query_study_summary, + self.query_phenotype_criteria, + self.query_importance_composition, + ) = what_to_query self.study = study - self.use_http = use_http print('\n' + Colors.bold_magenta + study + Colors.reset + '\n') self.cohorts = self._retrieve_cohorts() self.all_cells = self._retrieve_all_cells_counts() - def counts(self, phenotype_names): + def counts(self, phenotype_names: str | list[str]) -> DataFrame: if isinstance(phenotype_names, str): phenotype_names = [phenotype_names] conjunction_criteria = self._conjunction_phenotype_criteria(phenotype_names) @@ -99,27 +137,32 @@ def counts(self, phenotype_names): df.replace([np.inf, -np.inf], np.nan, inplace=True) return df - def name_for_all_phenotypes(self, phenotype_names): + def name_for_all_phenotypes(self, phenotype_names: list[str]) -> str: return ' and '.join([self._name_phenotype(p) for p in phenotype_names]) - def counts_by_signature(self, positives: list[str], negatives: list[str]): + def counts_by_signature(self, positives: list[str], negatives: list[str]) -> dict[str, Any]: if (not positives) and (not negatives): raise ValueError('At least one positive or negative marker is required.') if not positives: positives = [''] elif not negatives: negatives = [''] - parts = list(chain(*[ - [(f'{keyword}_marker', channel) for channel in argument] - for keyword, argument in zip(['positive', 'negative'], [positives, negatives]) - ])) - parts = sorted(list(set(parts))) - parts.append(('study', self.study)) - query = urlencode(parts) - endpoint = 'anonymous-phenotype-counts-fast' - return self._retrieve(endpoint, query)[0] - - def _get_counts_series(self, criteria, column_name): + if self.host is not None: + parts = list(chain(*[ + [(f'{keyword}_marker', channel) for channel in argument] + for keyword, argument in zip(['positive', 'negative'], [positives, negatives]) + ])) + parts = sorted(list(set(parts))) + parts.append(('study', self.study)) + query = urlencode(parts) + endpoint = 'anonymous-phenotype-counts-fast' + return self._retrieve(endpoint, query)[0] + else: + assert self.query_anonymous_phenotype_counts_fast is not None + query = self.query_anonymous_phenotype_counts_fast(positives, negatives, self.study) + return _finish_retrieving_from_api_server(query) + + def _get_counts_series(self, criteria: dict[str, list[str]], column_name: str) -> Series: criteria_tuple = ( criteria['positive_markers'], criteria['negative_markers'], @@ -129,11 +172,16 @@ def _get_counts_series(self, criteria, column_name): mapper = {'specimen': 'sample', 'count': column_name} return df.rename(columns=mapper).set_index('sample')[column_name] - def _retrieve_cohorts(self): - summary, _ = self._retrieve('study-summary', urlencode([('study', self.study)])) + def _retrieve_cohorts(self) -> DataFrame: + if self.host is not None: + summary, _ = self._retrieve('study-summary', urlencode([('study', self.study)])) + else: + assert self.query_study_summary is not None + query = self.query_study_summary(self.study) + summary = _finish_retrieving_from_api_server(query) return DataFrame(summary['cohorts']['assignments']).set_index('sample') - def _retrieve_all_cells_counts(self): + def _retrieve_all_cells_counts(self) -> Series: counts = self.counts_by_signature([''], ['']) df = DataFrame(counts['counts']) all_name = 'all cells' @@ -141,13 +189,14 @@ def _retrieve_all_cells_counts(self): counts_series = df.rename(columns=mapper).set_index('sample')[all_name] return counts_series - def _get_base(self): + def _get_base(self) -> str: protocol = 'https' if self.host == 'localhost' or re.search('127.0.0.1', self.host) or self.use_http: protocol = 'http' return '://'.join((protocol, self.host)) - def _retrieve(self, endpoint, query): + def _retrieve(self, endpoint: str, query: str) -> tuple[dict[str, Any], str]: + from requests import get as get_request # type: ignore base = f'{self._get_base()}' url = '/'.join([base, endpoint, '?' + query]) try: @@ -157,7 +206,7 @@ def _retrieve(self, endpoint, query): raise exception return content.json(), url - def _phenotype_criteria(self, name): + def _phenotype_criteria(self, name: str | dict[str, list[str]]) -> dict[str, list[str]]: if isinstance(name, dict): criteria = name keys = ['positive_markers', 'negative_markers'] @@ -165,18 +214,23 @@ def _phenotype_criteria(self, name): if criteria[key] == []: criteria[key] = [''] return criteria - query = urlencode([('study', self.study), ('phenotype_symbol', name)]) - criteria, _ = self._retrieve('phenotype-criteria', query) + if self.host is not None: + query = urlencode([('study', self.study), ('phenotype_symbol', name)]) + criteria, _ = self._retrieve('phenotype-criteria', query) + else: + assert self.query_phenotype_criteria is not None + query = self.query_phenotype_criteria(self.study, name) + criteria = _finish_retrieving_from_api_server(query) return criteria - def _conjunction_phenotype_criteria(self, names): - criteria_list = [] + def _conjunction_phenotype_criteria(self, names: str) -> dict[str, list[str]]: + criteria_list: list[dict[str, list[str]]] = [] for name in names: criteria = self._phenotype_criteria(name) criteria_list.append(criteria) return self._merge_criteria(criteria_list) - def _merge_criteria(self, criteria_list): + def _merge_criteria(self, criteria_list: list[dict[str, list[str]]]) -> dict[str, list[str]]: keys = ['positive_markers', 'negative_markers'] merged = { key: sorted(list(set(list(chain(*[criteria[key] for criteria in criteria_list]))))) @@ -187,7 +241,7 @@ def _merge_criteria(self, criteria_list): merged[key] = [''] return merged - def _name_phenotype(self, phenotype): + def _name_phenotype(self, phenotype: str) -> str: if isinstance(phenotype, dict): return ' '.join([ ' '.join([f'{p}{sign}' for p in phenotype[f'{keyword}_markers'] if p != '']) @@ -206,47 +260,84 @@ def important( if isinstance(phenotype_names, str): phenotype_names = [phenotype_names] conjunction_criteria = self._conjunction_phenotype_criteria(phenotype_names) - parts = list(chain(*[ - [(f'{keyword}_marker', channel) for channel in argument] - for keyword, argument in zip( - ['positive', 'negative'], [ - conjunction_criteria['positive_markers'], - conjunction_criteria['negative_markers'], - ]) - ])) - parts = sorted(list(set(parts))) - parts.append(('study', self.study)) - if plugin in {'cg-gnn', 'graph-transformer'}: - parts.append(('plugin', plugin)) + if self.host is not None: + parts = list(chain(*[ + [(f'{keyword}_marker', channel) for channel in argument] + for keyword, argument in zip( + ['positive', 'negative'], [ + conjunction_criteria['positive_markers'], + conjunction_criteria['negative_markers'], + ]) + ])) + parts = sorted(list(set(parts))) + parts.append(('study', self.study)) + if plugin in {'cg-gnn', 'graph-transformer'}: + parts.append(('plugin', plugin)) + else: + raise ValueError(f'Unrecognized plugin name: {plugin}') + if datetime_of_run is not None: + parts.append(('datetime_of_run', datetime_of_run)) + if plugin_version is not None: + parts.append(('plugin_version', plugin_version)) + if cohort_stratifier is not None: + parts.append(('cohort_stratifier', cohort_stratifier)) + query = urlencode(parts) + phenotype_counts, _ = self._retrieve('importance-composition', query) else: - raise ValueError(f'Unrecognized plugin name: {plugin}') - if datetime_of_run is not None: - parts.append(('datetime_of_run', datetime_of_run)) - if plugin_version is not None: - parts.append(('plugin_version', plugin_version)) - if cohort_stratifier is not None: - parts.append(('cohort_stratifier', cohort_stratifier)) - query = urlencode(parts) - phenotype_counts, _ = self._retrieve('importance-composition', query) + assert self.query_importance_composition is not None + optional_args = { + 'datetime_of_run': datetime_of_run, + 'plugin_version': plugin_version, + 'cohort_stratifier': cohort_stratifier, + } + optional_args = {k: v for k, v in optional_args.items() if v is not None} + + query = self.query_importance_composition( + self.study, + conjunction_criteria['positive_markers'], + conjunction_criteria['negative_markers'], + plugin, + **optional_args, + ) + phenotype_counts = _finish_retrieving_from_api_server(query) return {c['specimen']: c['percentage'] for c in phenotype_counts['counts']} -@define +def _finish_retrieving_from_api_server(query: Coroutine[Any, Any, Any]) -> dict[str, Any]: + """Wait for request to complete and return as dictionary.""" + with ThreadPoolExecutor() as executor: + loop = new_event_loop() + try: + result = executor.submit(loop.run_until_complete, query).result() + finally: + loop.close() + return result.dict() + + class ImportanceFractionAndTestRetriever: - host: str - study: str - access: ImportanceCountsAccessor | None = None - count_important: int = 100 - df_phenotypes: PhenotypeDataFrames | None = None - df_phenotypes_original: PhenotypeDataFrames | None = None - def initialize(self) -> None: - self.access = ImportanceCountsAccessor(self.study, host=self.host) + def __init__( + self, + host: str | APIServerCallables, + study: str, + count_important: int = 100, + use_tqdm: bool = False, + ) -> None: + self.host = host + self.study = study + self.count_important = count_important + self.use_tqdm = use_tqdm + + self.df_phenotypes = None + self.df_phenotypes_original = None + self.access = ImportanceCountsAccessor(self.study, self.host) def get_access(self) -> ImportanceCountsAccessor: - return cast(ImportanceCountsAccessor, self.access) + return self.access def get_df_phenotypes(self) -> PhenotypeDataFrames: + if self.df_phenotypes is None: + raise RuntimeError('Phenotype dataframes have not been initialized.') return cast(PhenotypeDataFrames, self.df_phenotypes) def get_sanitized_study(self) -> str: @@ -259,7 +350,7 @@ def get_pickle_file(self, data: Literal['counts', 'importance'], plugin: GNNMode return f'{self.get_sanitized_study()}.{plugin}.pickle' @staticmethod - def get_progress_bar_format(): + def get_progress_bar_format() -> str: return '{l_bar}{bar:30}{r_bar}{bar:-30b}' def reset_phenotype_counts(self, df: DataFrame) -> None: @@ -281,9 +372,15 @@ def _retrieve_phenotype_counts(self, df: DataFrame) -> None: N = len(levels) print('Retrieving count data to support plot.') f = self.get_progress_bar_format() + + if self.use_tqdm: + from tqdm import tqdm + iterable = tqdm(levels, total=N, bar_format=f) + else: + iterable = levels self.df_phenotypes_original = tuple( (str(phenotype), self.get_access().counts(phenotype).astype(int)) - for phenotype in tqdm(levels, total=N, bar_format=f) + for phenotype in iterable ) with open(pickle_file, 'wb') as file: pickle_dump(self.df_phenotypes_original, file) @@ -301,9 +398,14 @@ def retrieve(self, cohorts: set[int], phenotypes: list[str], plugin: GNNModel) - important_proportions = pickle_load(file) print(f'Loaded from cache: {pickle_file}') else: + if self.use_tqdm: + from tqdm import tqdm + iterable = tqdm(self.get_df_phenotypes(), total=N, bar_format=f) + else: + iterable = self.get_df_phenotypes() important_proportions = { phenotype: self.get_access().important(phenotype, plugin=plugin) - for phenotype, _ in tqdm(self.get_df_phenotypes(), total=N, bar_format=f) + for phenotype, _ in iterable } with open(pickle_file, 'wb') as file: pickle_dump(important_proportions, file) @@ -404,15 +506,19 @@ def _get_cell_count(self) -> Series: @define class PlotDataRetriever: - host: str + host: str | APIServerCallables + use_tqdm: bool def retrieve_data(self, specification: PlotSpecification) -> tuple[DataFrame, ...]: cohorts = set(c.index_int for c in specification.cohorts) plugins = cast(tuple[GNNModel, GNNModel], specification.plugins) phenotypes = list(specification.phenotypes) attribute_order = phenotypes + ['cohort'] - retriever = ImportanceFractionAndTestRetriever(self.host, specification.study) - retriever.initialize() + retriever = ImportanceFractionAndTestRetriever( + self.host, + specification.study, + use_tqdm=self.use_tqdm, + ) return tuple( retriever.retrieve(cohorts, phenotypes, plugin)[attribute_order] for plugin in plugins ) @@ -610,16 +716,17 @@ class PlotGenerator: def __init__( self, - host_name: str, + what_to_query: str | APIServerCallables, study_name: str, phenotypes: list[str], cohorts_raw: list[tuple[int, str]], plugins: list[str], figure_size: tuple[int, int], orientation: str | None, + use_tqdm: bool = False, ) -> None: """Instantiate the importance fractions plot generator.""" - self.host = host_name + self.host = what_to_query cohorts: list[Cohort] = [] for cohort in cohorts_raw: cohorts.append(Cohort(*cohort)) @@ -634,6 +741,7 @@ def __init__( figure_size, Orientation.HORIZONTAL if (orientation is None) else Orientation[orientation.upper()], ) + self.use_tqdm = use_tqdm def generate_plot(self) -> 'Figure': self._check_viability() @@ -646,7 +754,7 @@ def _check_viability(self) -> None: raise ValueError('Currently plot generation requires 2 plugins worth of run data.') def _retrieve_data(self) -> tuple[DataFrame, ...]: - dfs = PlotDataRetriever(self.host).retrieve_data(self.specification) + dfs = PlotDataRetriever(self.host, self.use_tqdm).retrieve_data(self.specification) dfs = self._transfer_cohort_labels(dfs, self.specification) return dfs diff --git a/test/apiserver/module_tests/test_retrieval_of_gnn_plot.sh b/test/apiserver/module_tests/test_retrieval_of_gnn_plot.sh new file mode 100644 index 00000000..84112a0c --- /dev/null +++ b/test/apiserver/module_tests/test_retrieval_of_gnn_plot.sh @@ -0,0 +1,15 @@ + +query="http://spt-apiserver-testing:8080/importance-fraction-plot/?study=Melanoma%20intralesional%20IL2" + +curl -sf "$query" > _gnn.svg ; +if [ "$?" -gt 0 ]; +then + echo "Error with apiserver query for GNN plot." + echo "$query" + exit 1 +fi + +if [ ! -s _gnn.svg ] || ! grep -q " Date: Wed, 15 May 2024 16:07:39 -0400 Subject: [PATCH 11/19] change upload-sync-small to upload a config file string --- spatialprofilingtoolbox/apiserver/app/main.py | 7 +++++-- .../db/scripts/upload_sync_small.py | 6 ++---- .../graphs/config_reader.py | 20 +++++++++++++++---- test/test_data/gnn_plot.json | 7 +------ 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index 1dc84b6a..85cd019e 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -43,6 +43,7 @@ ValidChannelListNegatives2, ValidFeatureClass, ) +from spatialprofilingtoolbox.graphs.config_reader import read_plot_importance_fractions_config from spatialprofilingtoolbox.graphs.importance_fractions import PlotGenerator VERSION = '0.23.0' @@ -390,14 +391,16 @@ async def importance_fraction_plot( if img_format not in APPROVED_FORMATS: raise ValueError(f'Image format "{img_format}" not supported.') - settings: list[str] = cast(list[str], query().get_study_gnn_plot_configurations(study)) + settings: str = cast(list[str], query().get_study_gnn_plot_configurations(study))[0] ( + _, + _, phenotypes, cohorts, plugins, figure_size, orientation, - ) = parse_gnn_plot_settings(settings) + ) = read_plot_importance_fractions_config(None, settings) plot = PlotGenerator( ( diff --git a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py index f01d66b0..79990098 100644 --- a/spatialprofilingtoolbox/db/scripts/upload_sync_small.py +++ b/spatialprofilingtoolbox/db/scripts/upload_sync_small.py @@ -20,11 +20,12 @@ def parse_args(): ) parser.add_argument( 'name', + choices=APPROVED_NAMES, help='The name of the table of strings to be synchronized.', ) parser.add_argument( 'file', - help='The JSON file containing a list of strings for each study.', + help='The JSON file containing the list of strings to be synced for each study.', ) add_argument(parser, 'database config') return parser.parse_args() @@ -71,9 +72,6 @@ def upload_sync( def main(): args = parse_args() - if args.name not in APPROVED_NAMES: - logger.error(f'{args.name} is not an approved table name.') - return database_config_file = get_and_validate_database_config(args) with open(args.file, 'rt', encoding='utf-8') as file: contents = file.read() diff --git a/spatialprofilingtoolbox/graphs/config_reader.py b/spatialprofilingtoolbox/graphs/config_reader.py index b0a9d2af..57157ffe 100644 --- a/spatialprofilingtoolbox/graphs/config_reader.py +++ b/spatialprofilingtoolbox/graphs/config_reader.py @@ -10,9 +10,18 @@ PLOT_FRACTIONS_SECTION_NAME = 'plot-importance-fractions' -def _read_config_file(config_file_path: str, section: str) -> dict[str, Any]: +def _read_config_file( + config_file_path: str | None, + section: str, + config_file_string: str | None = None, +) -> dict[str, Any]: config_file = ConfigParser() - config_file.read(config_file_path) + if config_file_path is not None: + config_file.read(config_file_path) + elif config_file_string is not None: + config_file.read_string(config_file_string) + else: + raise ValueError("Either config_file_path or config_file_string must be provided.") config: dict[str, Any] = \ dict(config_file[GENERAL_SECTION_NAME]) if (GENERAL_SECTION_NAME in config_file) else {} if section in config_file: @@ -153,7 +162,10 @@ def read_upload_config(config_file_path: str) -> tuple[ ) -def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ +def read_plot_importance_fractions_config( + config_file_path: str | None, + config_file_string: str | None = None, +) -> tuple[ str, str, list[str], @@ -167,7 +179,7 @@ def read_plot_importance_fractions_config(config_file_path: str) -> tuple[ For a detailed explanation of the return values, refer to the docstring of `spatialprofilingtoolbox.graphs.importance_fractions.PlotGenerator()`. """ - config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME) + config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME, config_file_string) host_name: str = config.get("host_name", "http://oncopathtk.org/api") study_name: str = config["study_name"] phenotypes: list[str] = config['phenotypes'].split(', ') diff --git a/test/test_data/gnn_plot.json b/test/test_data/gnn_plot.json index 1da2d0d3..09843de2 100644 --- a/test/test_data/gnn_plot.json +++ b/test/test_data/gnn_plot.json @@ -1,10 +1,5 @@ { "Melanoma intralesional IL2": [ - "Tumor, CD4+ T cell, CD8+ T cell, T cell/null phenotype", - "cg-gnn, graph-transformer", - "11, 8", - "horizontal", - "1, Non-responder", - "3, Responder" + "[plot-importance-fractions]\nphenotypes = Tumor, CD4+ T cell, CD8+ T cell, T cell/null phenotype\nplugins = cg-gnn, graph-transformer\nfigure_size = 11, 8\norientation = horizontal\n[plot-importance-fractions.cohort0]\nindex_int = 1\nlabel = Non-responder\n[plot-importance-fractions.cohort1]\nindex_int = 3\nlabel = Responder" ] } \ No newline at end of file From 499d1ab9e166e2e3bef731fc6e80c5262a492c9c Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Wed, 15 May 2024 16:18:41 -0400 Subject: [PATCH 12/19] have fastapi handle img format validation --- spatialprofilingtoolbox/apiserver/app/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index 85cd019e..2e135843 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -2,6 +2,7 @@ from typing import cast from typing import Annotated +from typing import Literal import json from io import BytesIO from base64 import b64decode @@ -72,6 +73,7 @@ CELL_DATA_CELL_LIMIT = 100001 + def custom_openapi(): if app.openapi_schema: return app.openapi_schema @@ -117,6 +119,7 @@ async def get_study_names( """The names of studies/datasets, with display names.""" specifiers = query().retrieve_study_specifiers() handles = [query().retrieve_study_handle(study) for study in specifiers] + def is_public(study_handle: StudyHandle) -> bool: if StudyCollectionNaming.is_untagged(study_handle): return True @@ -132,6 +135,7 @@ def is_public(study_handle: StudyHandle) -> bool: status_code=404, detail=f'Collection "{collection}" is not a valid collection string.', ) + def tagged(study_handle: StudyHandle) -> bool: return StudyCollectionNaming.tagged_with(study_handle, collection) handles = list(filter(tagged, map(query().retrieve_study_handle, specifiers))) @@ -345,6 +349,7 @@ async def get_cell_data( if not sample in query().get_sample_names(study): raise HTTPException(status_code=404, detail=f'Sample "{sample}" does not exist.') number_cells = cast(int, query().get_number_cells(study)) + def match(c: PhenotypeCount) -> bool: return c.specimen == sample count = tuple(filter(match, get_phenotype_counts([], [], study, number_cells).counts))[0].count @@ -384,13 +389,9 @@ def streaming_iteration(): @app.get("/importance-fraction-plot/") async def importance_fraction_plot( study: ValidStudy, - img_format: str = 'svg', + img_format: Literal['svg', 'png'] = 'svg', ) -> StreamingResponse: """Return a plot of the fraction of important cells expressing a given phenotype.""" - APPROVED_FORMATS = {'png', 'svg'} - if img_format not in APPROVED_FORMATS: - raise ValueError(f'Image format "{img_format}" not supported.') - settings: str = cast(list[str], query().get_study_gnn_plot_configurations(study))[0] ( _, From 71a05be2abdf9d3fdc2bb391dd500119017a8811 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Wed, 15 May 2024 16:41:44 -0400 Subject: [PATCH 13/19] remove unused file --- spatialprofilingtoolbox/apiserver/app/main.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index 2e135843..0e58e398 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -422,22 +422,3 @@ async def importance_fraction_plot( plt.savefig(buf, format=img_format) buf.seek(0) return StreamingResponse(buf, media_type=f"image/{img_format}") - - -def parse_gnn_plot_settings(settings: list[str]) -> tuple[ - list[str], - list[tuple[int, str]], - list[str], - tuple[int, int], - str, -]: - phenotypes = settings[0].split(', ') - plugins = settings[1].split(', ') - figure_size = tuple(map(int, settings[2].split(', '))) - assert len(figure_size) == 2 - orientation = settings[3] - cohorts: list[tuple[int, str]] = [] - for cohort in settings[4:]: - count, name = cohort.split(', ') - cohorts.append((int(count), name)) - return phenotypes, cohorts, plugins, figure_size, orientation From b5892cab08f85bded1ef97729c33fda5b0aee117 Mon Sep 17 00:00:00 2001 From: James Mathews Date: Wed, 15 May 2024 17:09:23 -0400 Subject: [PATCH 14/19] Update test artifact --- test/apiserver/unit_tests/test_available_gnn.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/apiserver/unit_tests/test_available_gnn.sh b/test/apiserver/unit_tests/test_available_gnn.sh index a96a191b..1db22c60 100644 --- a/test/apiserver/unit_tests/test_available_gnn.sh +++ b/test/apiserver/unit_tests/test_available_gnn.sh @@ -9,7 +9,7 @@ then exit 1 fi response=$(curl -s "$query") -if [[ "$response" != '{"plugins":["cg-gnn"]}' ]]; +if [[ "$response" != '{"plugins":["cg-gnn","graph-transformer"]}' ]]; then echo "API query for available GNN metrics failed." exit 1 From 013132a8c58ae77ff30c9d12021689b399c7c69d Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Wed, 15 May 2024 17:11:34 -0400 Subject: [PATCH 15/19] remove weird thread pool stuff --- spatialprofilingtoolbox/apiserver/app/main.py | 40 ++++++++++++++++--- .../graphs/importance_fractions.py | 37 +++++------------ 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index 0e58e398..e8415ee7 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -194,10 +194,16 @@ async def get_anonymous_phenotype_counts_fast( """Computes the number of cells satisfying the given positive and negative criteria, in the context of a given study. """ + return _get_anonymous_phenotype_counts_fast(positive_marker, negative_marker, study) + +def _get_anonymous_phenotype_counts_fast( + positive_marker: ValidChannelListPositives, + negative_marker: ValidChannelListNegatives, + study: ValidStudy, +) -> PhenotypeCounts: number_cells = cast(int, query().get_number_cells(study)) return get_phenotype_counts(positive_marker, negative_marker, study, number_cells) - @app.get("/request-spatial-metrics-computation/") async def request_spatial_metrics_computation( study: ValidStudy, @@ -261,7 +267,29 @@ async def available_gnn_metrics( @app.get("/importance-composition/") -async def importance_composition( +async def get_importance_composition( + study: ValidStudy, + positive_marker: ValidChannelListPositives, + negative_marker: ValidChannelListNegatives, + plugin: str = 'cg-gnn', + datetime_of_run: str = 'latest', + plugin_version: str | None = None, + cohort_stratifier: str | None = None, + cell_limit: int = 100, +) -> PhenotypeCounts: + """For each specimen, return the fraction of important cells expressing a given phenotype.""" + return _get_importance_composition( + study, + positive_marker, + negative_marker, + plugin, + datetime_of_run, + plugin_version, + cohort_stratifier, + cell_limit, + ) + +def _get_importance_composition( study: ValidStudy, positive_marker: ValidChannelListPositives, negative_marker: ValidChannelListNegatives, @@ -405,10 +433,10 @@ async def importance_fraction_plot( plot = PlotGenerator( ( - get_anonymous_phenotype_counts_fast, - get_study_summary, - get_phenotype_criteria, - importance_composition, + _get_anonymous_phenotype_counts_fast, + query().get_study_summary, + query().get_phenotype_criteria, + _get_importance_composition, ), study, phenotypes, diff --git a/spatialprofilingtoolbox/graphs/importance_fractions.py b/spatialprofilingtoolbox/graphs/importance_fractions.py index aa642075..3cf701fd 100644 --- a/spatialprofilingtoolbox/graphs/importance_fractions.py +++ b/spatialprofilingtoolbox/graphs/importance_fractions.py @@ -8,14 +8,11 @@ from typing import cast from typing import Any from typing import Callable -from typing import Coroutine from typing import TYPE_CHECKING import re from itertools import chain from urllib.parse import urlencode from enum import Enum -from concurrent.futures import ThreadPoolExecutor -from asyncio import new_event_loop import numpy as np from numpy.typing import NDArray @@ -29,6 +26,7 @@ from matplotlib.colors import Normalize from scipy.stats import fisher_exact # type: ignore from attr import define +from pydantic import BaseModel if TYPE_CHECKING: from matplotlib.figure import Figure @@ -64,16 +62,16 @@ def sanitized_study(study: str) -> str: PhenotypeDataFrames = tuple[tuple[str, DataFrame], ...] APIServerCallables = tuple[ - Callable[[list[str], list[str], str], Any], - Callable[[str], Any], - Callable[[str, str], Any], + Callable[[list[str], list[str], str], BaseModel], + Callable[[str], BaseModel], + Callable[[str, str], BaseModel], Callable[[ str, list[str], list[str], str, dict[str, Any], - ], Any], + ], BaseModel], ] @@ -159,8 +157,7 @@ def counts_by_signature(self, positives: list[str], negatives: list[str]) -> dic return self._retrieve(endpoint, query)[0] else: assert self.query_anonymous_phenotype_counts_fast is not None - query = self.query_anonymous_phenotype_counts_fast(positives, negatives, self.study) - return _finish_retrieving_from_api_server(query) + return self.query_anonymous_phenotype_counts_fast(positives, negatives, self.study).dict() def _get_counts_series(self, criteria: dict[str, list[str]], column_name: str) -> Series: criteria_tuple = ( @@ -177,8 +174,7 @@ def _retrieve_cohorts(self) -> DataFrame: summary, _ = self._retrieve('study-summary', urlencode([('study', self.study)])) else: assert self.query_study_summary is not None - query = self.query_study_summary(self.study) - summary = _finish_retrieving_from_api_server(query) + summary = self.query_study_summary(self.study).dict() return DataFrame(summary['cohorts']['assignments']).set_index('sample') def _retrieve_all_cells_counts(self) -> Series: @@ -219,8 +215,7 @@ def _phenotype_criteria(self, name: str | dict[str, list[str]]) -> dict[str, lis criteria, _ = self._retrieve('phenotype-criteria', query) else: assert self.query_phenotype_criteria is not None - query = self.query_phenotype_criteria(self.study, name) - criteria = _finish_retrieving_from_api_server(query) + criteria = self.query_phenotype_criteria(self.study, name).dict() return criteria def _conjunction_phenotype_criteria(self, names: str) -> dict[str, list[str]]: @@ -292,28 +287,16 @@ def important( } optional_args = {k: v for k, v in optional_args.items() if v is not None} - query = self.query_importance_composition( + phenotype_counts = self.query_importance_composition( self.study, conjunction_criteria['positive_markers'], conjunction_criteria['negative_markers'], plugin, **optional_args, - ) - phenotype_counts = _finish_retrieving_from_api_server(query) + ).dict() return {c['specimen']: c['percentage'] for c in phenotype_counts['counts']} -def _finish_retrieving_from_api_server(query: Coroutine[Any, Any, Any]) -> dict[str, Any]: - """Wait for request to complete and return as dictionary.""" - with ThreadPoolExecutor() as executor: - loop = new_event_loop() - try: - result = executor.submit(loop.run_until_complete, query).result() - finally: - loop.close() - return result.dict() - - class ImportanceFractionAndTestRetriever: def __init__( From 6f59e2bae23cd7a0057ff62f07cc0e8a05399eb5 Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Wed, 15 May 2024 17:25:44 -0400 Subject: [PATCH 16/19] hotfix config reading when calling from apiserver --- spatialprofilingtoolbox/apiserver/app/main.py | 2 +- spatialprofilingtoolbox/graphs/config_reader.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index e8415ee7..763d21f0 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -429,7 +429,7 @@ async def importance_fraction_plot( plugins, figure_size, orientation, - ) = read_plot_importance_fractions_config(None, settings) + ) = read_plot_importance_fractions_config(None, settings, True) plot = PlotGenerator( ( diff --git a/spatialprofilingtoolbox/graphs/config_reader.py b/spatialprofilingtoolbox/graphs/config_reader.py index 57157ffe..7230ba6a 100644 --- a/spatialprofilingtoolbox/graphs/config_reader.py +++ b/spatialprofilingtoolbox/graphs/config_reader.py @@ -165,6 +165,7 @@ def read_upload_config(config_file_path: str) -> tuple[ def read_plot_importance_fractions_config( config_file_path: str | None, config_file_string: str | None = None, + calling_by_api: bool = False, ) -> tuple[ str, str, @@ -181,7 +182,7 @@ def read_plot_importance_fractions_config( """ config = _read_config_file(config_file_path, PLOT_FRACTIONS_SECTION_NAME, config_file_string) host_name: str = config.get("host_name", "http://oncopathtk.org/api") - study_name: str = config["study_name"] + study_name: str = config["study_name"] if not calling_by_api else '' phenotypes: list[str] = config['phenotypes'].split(', ') plugins: list[str] = config['plugins'].split(', ') try: From 81f44b7563d280929ddc0fcf1131d577b5b2829c Mon Sep 17 00:00:00 2001 From: Carlin Liao Date: Wed, 15 May 2024 17:40:28 -0400 Subject: [PATCH 17/19] update analysis_replication for gnn plots --- analysis_replication/README.md | 11 +++- .../melanoma_intralesional_il2.config | 15 +++++ .../melanoma_intralesional_il2.json | 63 ------------------- .../gnn_figure/urothelial_ici.config | 15 +++++ .../gnn_figure/urothelial_ici.json | 41 ------------ 5 files changed, 40 insertions(+), 105 deletions(-) create mode 100644 analysis_replication/gnn_figure/melanoma_intralesional_il2.config delete mode 100644 analysis_replication/gnn_figure/melanoma_intralesional_il2.json create mode 100644 analysis_replication/gnn_figure/urothelial_ici.config delete mode 100644 analysis_replication/gnn_figure/urothelial_ici.json diff --git a/analysis_replication/README.md b/analysis_replication/README.md index 52bd2e0c..14158e78 100644 --- a/analysis_replication/README.md +++ b/analysis_replication/README.md @@ -30,7 +30,7 @@ substituting the argument with the address of your local API server. (See *Setti - You can alternatively store the API host in `api_host.txt` and omit the command-line argument above. - The run result is here in [results.txt](results.txt). -# Cell arrangment figure generation +## Cell arrangement figure generation One figure is generated programmatically from published source TIFF files. To run the figure generation script, alter the command below to reference your own database configuration file and path to unzipped Moldoveanu et al dataset. @@ -38,3 +38,12 @@ To run the figure generation script, alter the command below to reference your o ```bash python retrieve_example_plot.py dataset_directory/ ~/.spt_db.config ``` + +## GNN importance fractions figure generation + +This plot replication requires the installation of SPT, as it's not a script in this directory. Instead, it's a command in the `spt graphs` CLI that uses the configuration files stored in `gnn_figure/` to reproduce the plots seen in our publication. + +```bash +spt graphs plot-importance-fractions --config_path gnn_figure/melanoma_intralesional_il2.config --output_filename gnn_figure/melanoma_intralesional_il2.png +spt graphs plot-importance-fractions --config_path gnn_figure/urothelial_ici.config --output_filename gnn_figure/urothelial_ici.png +``` diff --git a/analysis_replication/gnn_figure/melanoma_intralesional_il2.config b/analysis_replication/gnn_figure/melanoma_intralesional_il2.config new file mode 100644 index 00000000..74c91087 --- /dev/null +++ b/analysis_replication/gnn_figure/melanoma_intralesional_il2.config @@ -0,0 +1,15 @@ +[general] +study_name = Melanoma intralesional IL2 + +[plot-importance-fractions] +host_name = http://oncopathtk.org/api +phenotypes = Tumor, Adipocyte or Langerhans cell, Natural killer cell, CD4+ T cell, Nerve, B cell, CD4+/CD8+ T cell, CD4+ regulatory T cell, CD8+ natural killer T cell, CD8+ regulatory T cell, CD8+ T cell, Double negative regulatory T cell, T cell/null phenotype, Natural killer T cell, CD4+ natural killer T cell +plugins = cg-gnn, graph-transformer +figure_size = 11, 8 +orientation = horizontal +[plot-importance-fractions.cohort0] +index_int = 1 +label = Non-responder +[plot-importance-fractions.cohort1] +index_int = 3 +label = Responder diff --git a/analysis_replication/gnn_figure/melanoma_intralesional_il2.json b/analysis_replication/gnn_figure/melanoma_intralesional_il2.json deleted file mode 100644 index 1cdc1e34..00000000 --- a/analysis_replication/gnn_figure/melanoma_intralesional_il2.json +++ /dev/null @@ -1,63 +0,0 @@ -{ - "study": "Melanoma intralesional IL2", - "phenotypes": [ - "Tumor", - "Adipocyte or Langerhans cell", - "Nerve", - "B cell", - "Natural killer cell", - "Natural killer T cell", - "CD4+/CD8+ T cell", - "CD4+ natural killer T cell", - "CD4+ regulatory T cell", - "CD4+ T cell", - "CD8+ natural killer T cell", - "CD8+ regulatory T cell", - "CD8+ T cell", - "Double negative regulatory T cell", - "T cell/null phenotype", - "CD163+MHCII- macrophage", - "CD163+MHCII+ macrophage", - "CD68+MHCII- macrophage", - "CD68+MHCII+ macrophage", - "Other macrophage/monocyte CD14+", - "Other macrophage/monocyte CD4+" - ], - "attribute_order": [ - "Tumor", - "Adipocyte or Langerhans cell", - "Natural killer cell", - "CD4+ T cell", - "Nerve", - "B cell", - "CD4+/CD8+ T cell", - "CD4+ regulatory T cell", - "CD8+ natural killer T cell", - "CD8+ regulatory T cell", - "CD8+ T cell", - "Double negative regulatory T cell", - "T cell/null phenotype", - "Natural killer T cell", - "CD4+ natural killer T cell", - "cohort" - ], - "cohorts": [ - { - "index_int": 1, - "label": "Non-responder" - }, - { - "index_int": 3, - "label": "Responder" - } - ], - "plugins": [ - "cg-gnn", - "graph-transformer" - ], - "figure_size": [ - 11, - 8 - ], - "orientation": "horizontal" -} diff --git a/analysis_replication/gnn_figure/urothelial_ici.config b/analysis_replication/gnn_figure/urothelial_ici.config new file mode 100644 index 00000000..c8c26bbc --- /dev/null +++ b/analysis_replication/gnn_figure/urothelial_ici.config @@ -0,0 +1,15 @@ +[general] +study_name = Urothelial ICI + +[plot-importance-fractions] +host_name = http://oncopathtk.org/api +phenotypes = Tumor, CD4- CD8- T cell, T cytotoxic cell, T helper cell, Macrophage, intratumoral CD3+ LAG3+, Regulatory T cell +plugins = cg-gnn, graph-transformer +figure_size = 14, 5 +orientation = vertical +[plot-importance-fractions.cohort0] +index_int = 1 +label = Responder +[plot-importance-fractions.cohort1] +index_int = 2 +label = Non-responder diff --git a/analysis_replication/gnn_figure/urothelial_ici.json b/analysis_replication/gnn_figure/urothelial_ici.json deleted file mode 100644 index 4d32834b..00000000 --- a/analysis_replication/gnn_figure/urothelial_ici.json +++ /dev/null @@ -1,41 +0,0 @@ -{ - "study": "Urothelial ICI", - "phenotypes": [ - "Tumor", - "CD4- CD8- T cell", - "T cytotoxic cell", - "T helper cell", - "Macrophage", - "intratumoral CD3+ LAG3+", - "Regulatory T cell" - ], - "attribute_order": [ - "Tumor", - "CD4- CD8- T cell", - "T cytotoxic cell", - "T helper cell", - "Macrophage", - "intratumoral CD3+ LAG3+", - "Regulatory T cell", - "cohort" - ], - "cohorts": [ - { - "index_int": 1, - "label": "Responder" - }, - { - "index_int": 2, - "label": "Non-responder" - } - ], - "plugins": [ - "cg-gnn", - "graph-transformer" - ], - "figure_size": [ - 14, - 5 - ], - "orientation": "vertical" -} From b4b3efbdc6d54d1397a423d2b5267ac1e8a80886 Mon Sep 17 00:00:00 2001 From: James Mathews Date: Wed, 15 May 2024 17:44:37 -0400 Subject: [PATCH 18/19] Version bump --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index c86a09df..444d4960 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.23.2 \ No newline at end of file +0.23.3 \ No newline at end of file From b2ae1b06adbf58490c0f75a932b9e8f84d9b8c2b Mon Sep 17 00:00:00 2001 From: James Mathews Date: Wed, 15 May 2024 18:13:01 -0400 Subject: [PATCH 19/19] Remove some build dependencies --- environment.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/environment.yml b/environment.yml index 9c82fcb9..c8cf57bd 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,4 @@ channels: - conda-forge dependencies: - python=3.11 - - toml - - make - - bash - - postgresql + - libpq