Make GNN plots available as an SPT CLI command and API endpoint #320

Merged (19 commits) on May 15, 2024. Changes from 9 commits.
9 changes: 0 additions & 9 deletions analysis_replication/README.md
@@ -38,12 +38,3 @@ To run the figure generation script, alter the command below to reference your o
```bash
python retrieve_example_plot.py dataset_directory/ ~/.spt_db.config
```

# GNN importance fractions figure generation

Another figure is generated programmatically from importance scores extracted from Graph Neural Network models, provided by the API.

```bash
cd gnn_figure/
python graph_plugin_plots.py
```
1 change: 1 addition & 0 deletions build/apiserver/Dockerfile
@@ -16,6 +16,7 @@ RUN python -m pip install scikit-learn==1.2.2
RUN python -m pip install Pillow==9.5.0
RUN python -m pip install pydantic==2.0.2
RUN python -m pip install secure==0.3.0
RUN python -m pip install matplotlib==3.7.1
ARG version
ARG service_name
ARG WHEEL_FILENAME
3 changes: 2 additions & 1 deletion build/build_scripts/import_test_dataset1.sh
@@ -11,7 +11,8 @@ rm -f .nextflow.log*; rm -rf .nextflow/; rm -f configure.sh; rm -f run.sh; rm -f

spt graphs upload-importances --config_path=build/build_scripts/.graph.config --importances_csv_path=test/test_data/gnn_importances/1.csv

spt db upload-sync-findings --database-config-file=build/db/.spt_db.config.local test/test_data/findings.json
spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local findings test/test_data/findings.json
spt db upload-sync-small --database-config-file=build/db/.spt_db.config.local gnn_plot_configurations test/test_data/gnn_plot.json

spt db status --database-config-file build/db/.spt_db.config.local > table_counts.txt
diff build/build_scripts/expected_table_counts.txt table_counts.txt
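Aside: the `upload-sync-small` subcommand reads a JSON file mapping each study name to a list of strings (see `upload_sync_small.py` later in this diff). A minimal sketch of generating such a file, where the study name and configuration strings are hypothetical, patterned after what `parse_gnn_plot_settings` expects for `gnn_plot_configurations`:

```python
import json

# Hypothetical payload: one entry per study, each value a list of strings.
# For gnn_plot_configurations, the strings encode (in order) hostname,
# comma-separated phenotypes, comma-separated plugins, figure size as
# "width, height", orientation, then one "count, name" line per cohort.
payload = {
    "Example study": [
        "example-api-host.org",
        "Phenotype A, Phenotype B",
        "plugin-one, plugin-two",
        "8, 6",
        "horizontal",
        "25, Cohort 1",
        "30, Cohort 2",
    ],
}

with open("gnn_plot.json", "wt", encoding="utf-8") as file:
    json.dump(payload, file, indent=2)
```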
4 changes: 3 additions & 1 deletion docs/maintenance.md
@@ -19,7 +19,9 @@ The modules in this repository are built, tested, and deployed using `make` and
| [Docker Engine](https://docs.docker.com/engine/install/) | 20.10.17 |
| [Docker Compose](https://docs.docker.com/compose/install/) | 2.10.2 |
| [bash](https://www.gnu.org/software/bash/) | >= 4 |
| [python](https://www.python.org/downloads/) | >=3.7 |
| [python](https://www.python.org/downloads/) | >=3.7 <3.12 |
| [postgresql](https://www.postgresql.org/download/) | 13.4 |
| [toml](https://pypi.org/project/toml/) | 0.10.2 |

A typical development workflow looks like:

9 changes: 9 additions & 0 deletions environment.yml
@@ -0,0 +1,9 @@
name: spt
channels:
- conda-forge
dependencies:
- python=3.11
- toml
- make
- bash
- postgresql
4 changes: 3 additions & 1 deletion pyproject.toml.unversioned
@@ -31,6 +31,7 @@ repository = "https://github.com/nadeemlab/SPT"

[project.optional-dependencies]
apiserver = [
"matplotlib==3.7.1",
"fastapi==0.100.0",
"uvicorn>=0.15.0,<0.16.0",
"pandas==2.0.2",
@@ -190,7 +191,7 @@ packages = [
"drop.py",
"drop_ondemand_computations.py",
"delete_feature.py",
"upload_sync_findings.py",
"upload_sync_small.py",
"collection.py",
]
"spatialprofilingtoolbox.db.data_model" = [
@@ -211,6 +212,7 @@ packages = [
"extract.py",
"finalize_graphs.py",
"generate_graphs.py",
"plot_importance_fractions.py",
"plot_interactives.py",
"prepare_graph_creation.py",
"upload_importances.py",
60 changes: 60 additions & 0 deletions spatialprofilingtoolbox/apiserver/app/main.py
@@ -12,6 +12,7 @@
from fastapi.responses import StreamingResponse
from fastapi import Query
from fastapi import HTTPException
import matplotlib.pyplot as plt

import secure

@@ -42,6 +43,8 @@
    ValidChannelListNegatives2,
    ValidFeatureClass,
)
from spatialprofilingtoolbox.graphs.importance_fractions import PlotGenerator

VERSION = '0.23.0'

TITLE = 'Single cell studies data API'
Expand Down Expand Up @@ -375,3 +378,60 @@ async def get_plot_high_resolution(
def streaming_iteration():
yield from input_buffer
return StreamingResponse(streaming_iteration(), media_type="image/png")


@app.get("/importance-fraction-plot/")
async def importance_fraction_plot(
study: ValidStudy,
img_format: str = 'svg',
) -> StreamingResponse:
"""Return a plot of the fraction of important cells expressing a given phenotype."""
APPROVED_FORMATS = {'png', 'svg'}
if img_format not in APPROVED_FORMATS:
raise ValueError(f'Image format "{img_format}" not supported.')
CarlinLiao marked this conversation as resolved.
Show resolved Hide resolved

settings: list[str] = cast(list[str], query().get_study_gnn_plot_configurations(study))
(
hostname,
phenotypes,
cohorts,
plugins,
figure_size,
orientation,
) = parse_gnn_plot_settings(settings)

plot = PlotGenerator(
hostname,
study,
phenotypes,
cohorts,
plugins,
figure_size,
orientation,
).generate_plot()
plt.figure(plot.number)
buf = BytesIO()
plt.savefig(buf, format=img_format)
buf.seek(0)
return StreamingResponse(buf, media_type=f"image/{img_format}")


def parse_gnn_plot_settings(settings: list[str]) -> tuple[
str,
list[str],
list[tuple[int, str]],
list[str],
tuple[int, int],
str,
]:
hostname = settings[0]
phenotypes = settings[1].split(', ')
plugins = settings[2].split(', ')
figure_size = tuple(map(int, settings[3].split(', ')))
assert len(figure_size) == 2
orientation = settings[4]
cohorts: list[tuple[int, str]] = []
for cohort in settings[5:]:
count, name = cohort.split(', ')
cohorts.append((int(count), name))
return hostname, phenotypes, cohorts, plugins, figure_size, orientation
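A quick way to exercise the new endpoint once the API server is running; a sketch assuming the `requests` package and a hypothetical host and study name:

```python
import requests

# Hypothetical host and study name; img_format may be 'png' or 'svg'.
response = requests.get(
    "http://localhost:8080/importance-fraction-plot/",
    params={"study": "Example study", "img_format": "svg"},
)
response.raise_for_status()
with open("importance_fractions.svg", "wb") as file:
    file.write(response.content)  # the streamed image bytes
```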
8 changes: 7 additions & 1 deletion spatialprofilingtoolbox/db/accessors/study.py
@@ -83,7 +83,13 @@ def get_available_gnn(self, study: str) -> AvailableGNN:
        return AvailableGNN(plugins=tuple(specifier for (specifier, ) in rows))

    def get_study_findings(self) -> list[str]:
        self.cursor.execute('SELECT txt FROM findings ORDER BY id;')
        return self._get_study_small_artifacts('findings')

    def get_study_gnn_plot_configurations(self) -> list[str]:
        return self._get_study_small_artifacts('gnn_plot_configurations')

    def _get_study_small_artifacts(self, name: str) -> list[str]:
        self.cursor.execute(f'SELECT txt FROM {name} ORDER BY id;')
        return [row[0] for row in self.cursor.fetchall()]

    @staticmethod
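One note on `_get_study_small_artifacts`: interpolating `name` into the SQL with an f-string is acceptable here only because callers pass fixed table names and the CLI validates against `APPROVED_NAMES` before upload. If the identifier ever came from untrusted input, it should be quoted; a sketch using psycopg2's identifier quoting, assuming psycopg2 is the driver in use:

```python
from psycopg2 import sql

def fetch_texts(cursor, table_name: str) -> list[str]:
    # Identifier quoting prevents injection through the table name,
    # unlike direct f-string interpolation.
    query = sql.SQL('SELECT txt FROM {} ORDER BY id;').format(sql.Identifier(table_name))
    cursor.execute(query)
    return [row[0] for row in cursor.fetchall()]
```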
1 change: 1 addition & 0 deletions spatialprofilingtoolbox/db/database_connection.py
@@ -266,6 +266,7 @@ class (QueryCursor) newly provides on each invocation.
    get_sample_names: Callable
    get_available_gnn: Callable
    get_study_findings: Callable
    get_study_gnn_plot_configurations: Callable
    is_public_collection: Callable

    def __init__(self, query_handler: Type):
4 changes: 4 additions & 0 deletions spatialprofilingtoolbox/db/querying.py
@@ -62,6 +62,10 @@ def get_available_gnn(cls, cursor, study: str) -> AvailableGNN:
    def get_study_findings(cls, cursor, study: str) -> list[str]:
        return StudyAccess(cursor).get_study_findings()

    @classmethod
    def get_study_gnn_plot_configurations(cls, cursor, study: str) -> list[str]:
        return StudyAccess(cursor).get_study_gnn_plot_configurations()

    @classmethod
    def get_composite_phenotype_identifiers(cls, cursor) -> tuple[str, ...]:
        return sort(PhenotypesAccess(cursor).get_composite_phenotype_identifiers())
71 changes: 0 additions & 71 deletions spatialprofilingtoolbox/db/scripts/upload_sync_findings.py

This file was deleted.

85 changes: 85 additions & 0 deletions spatialprofilingtoolbox/db/scripts/upload_sync_small.py
@@ -0,0 +1,85 @@
"""Synchronize a small data artifact with the database."""
CarlinLiao marked this conversation as resolved.
Show resolved Hide resolved

import argparse
from json import loads as json_loads

from spatialprofilingtoolbox.db.database_connection import get_and_validate_database_config
from spatialprofilingtoolbox.db.database_connection import DBCursor
from spatialprofilingtoolbox.workflow.common.cli_arguments import add_argument

from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger
logger = colorized_logger('upload_sync_small')

APPROVED_NAMES = ('findings', 'gnn_plot_configurations')


def parse_args():
parser = argparse.ArgumentParser(
prog='spt db upload-sync-small',
description='Synchronize small lists of strings for each study with the database.'
)
parser.add_argument(
'name',
help='The name of the table of strings to be synchronized.',
)
CarlinLiao marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument(
'file',
help='The JSON file containing a list of strings for each study.',
)
add_argument(parser, 'database config')
return parser.parse_args()


def _create_table_query(name: str) -> str:
return f'CREATE TABLE IF NOT EXISTS {name} (id SERIAL PRIMARY KEY, txt TEXT);'


def _sync_data(cursor, name: str, data: tuple[str, ...]) -> bool:
cursor.execute(_create_table_query(name))
cursor.execute(f'SELECT id, txt FROM {name} ORDER BY id;')
rows = tuple(cursor.fetchall())
if tuple(text for _, text in rows) == data:
return True
cursor.execute(f'DELETE FROM {name};')
for datum in data:
cursor.execute(f'INSERT INTO {name}(txt) VALUES (%s);', (datum,))
return False


def _upload_sync_study(
study: str,
name: str,
data: list[str],
database_config_file: str,
) -> None:
with DBCursor(database_config_file=database_config_file, study=study) as cursor:
already_synced = _sync_data(cursor, name, tuple(data))
if already_synced:
logger.info(f'Data for "{study}" are already up-to-date.')
else:
logger.info(f'Data for "{study}" were synced.')


def upload_sync(
name: str,
data_per_study: dict[str, list[str]],
database_config_file: str,
) -> None:
for study, study_data in data_per_study.items():
_upload_sync_study(study, name, study_data, database_config_file)


def main():
args = parse_args()
if args.name not in APPROVED_NAMES:
logger.error(f'{args.name} is not an approved table name.')
return
CarlinLiao marked this conversation as resolved.
Show resolved Hide resolved
database_config_file = get_and_validate_database_config(args)
with open(args.file, 'rt', encoding='utf-8') as file:
contents = file.read()
to_sync = json_loads(contents)
upload_sync(args.name, to_sync, database_config_file)


if __name__ == '__main__':
main()
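The same synchronization can also be driven programmatically through `upload_sync`; a sketch with a hypothetical study entry, reusing the local database config path from the build scripts:

```python
from spatialprofilingtoolbox.db.scripts.upload_sync_small import upload_sync

# Hypothetical findings for a single study; 'findings' is one of the
# APPROVED_NAMES tables.
upload_sync(
    'findings',
    {'Example study': ['Finding one.', 'Finding two.']},
    'build/db/.spt_db.config.local',
)
```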