From a51833272a11b2663eac821b69dd6375599457e2 Mon Sep 17 00:00:00 2001
From: Rohan
Date: Fri, 5 Aug 2022 16:31:17 -0700
Subject: [PATCH 1/3] feat: more salmon outputs + better multiqc report

---
 wf/__init__.py | 210 +++++++++++++++++++++----------------------
 1 file changed, 88 insertions(+), 122 deletions(-)

diff --git a/wf/__init__.py b/wf/__init__.py
index 88b7eb3..8c160ba 100644
--- a/wf/__init__.py
+++ b/wf/__init__.py
@@ -11,17 +11,20 @@
 from enum import Enum
 from pathlib import Path
 from typing import Iterable, List, Optional, Tuple, Union
-from urllib.parse import urlparse
 
 import lgenome
 from dataclasses_json import dataclass_json
-from flytekit import task
 from flytekitplugins.pod import Pod
-from kubernetes.client.models import (V1Container, V1PodSpec,
-                                      V1ResourceRequirements, V1Toleration)
+from kubernetes.client.models import (
+    V1Container,
+    V1PodSpec,
+    V1ResourceRequirements,
+    V1Toleration,
+)
 from latch import map_task, message, small_task, workflow
 from latch.resources.launch_plan import LaunchPlan
-from latch.types import LatchDir, LatchFile
+from latch.types import LatchDir, LatchFile, file_glob
 
 
 def _capture_output(command: List[str]) -> Tuple[int, str]:
@@ -56,61 +59,6 @@ def ___repr__(self):
 LatchFile.__repr__ = types.MethodType(___repr__, LatchFile)
 
 
-def file_glob(
-    pattern: str, remote_directory: str, target_dir: Optional[Path] = None
-) -> List[LatchFile]:
-    """Constructs a list of LatchFiles from a glob pattern.
-    Convenient utility for passing collections of files between tasks. See
-    [nextflow's channels](https://www.nextflow.io/docs/latest/channel.html) or
-    [snakemake's wildcards](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#wildcards).
-    for similar functionality in other orchestration tools.
-    The remote location of each constructed LatchFile will be consructed by
-    appending the file name returned by the pattern to the directory
-    represented by the `remote_directory`.
-    Args:
-        pattern: A glob pattern to match a set of files, eg. '*.py'. Will
-            resolve paths with respect to the working directory of the caller.
-        remote_directory: A valid latch URL pointing to a directory, eg.
-            latch:///foo. This _must_ be a directory and not a file.
-        target_dir: An optional Path object to define an alternate working
-            directory for path resolution
-    Returns:
-        A list of instantiated LatchFile objects.
-    Intended Use: ::
-        @small_task
-        def task():
-            ...
-            return file_glob("*.fastq.gz", "latch:///fastqc_outputs")
-    """
-
-    if not _is_valid_url(remote_directory):
-        return []
-
-    if target_dir is None:
-        wd = Path.cwd()
-    else:
-        wd = target_dir
-    matched = sorted(wd.glob(pattern))
-
-    return [LatchFile(str(file), remote_directory + file.name) for file in matched]
-
-
-def _is_valid_url(raw_url: str) -> bool:
-    """A valid URL (as a source or destination of a LatchFile) must:
-      * contain a latch or s3 scheme
-      * contain an absolute path
-    """
-    try:
-        parsed = urlparse(raw_url)
-    except ValueError:
-        return False
-    if parsed.scheme not in ("latch", "s3"):
-        return False
-    if not parsed.path.startswith("/"):
-        return False
-    return True
-
-
 def _get_96_spot_pod() -> Pod:
     """[ "c6i.24xlarge", "c5.24xlarge", "c5.metal", "c5d.24xlarge", "c5d.metal" ]"""
@@ -240,12 +188,16 @@ class TrimgaloreSalmonOutput:
     passed_salmon: bool
     passed_tximport: bool
     sample_name: str
-    sf_files: List[LatchFile]
-    auxiliary_directory: List[LatchDir]
-    # trimmed_reads: List[Replicate]
+    salmon_output: LatchDir
+    # quant: LatchFile
+    # genome_abundance: LatchFile
     trimgalore_reports: List[LatchFile]
 
 
+def slugify(value: str) -> str:
+    return value.replace(" ", "_")
+
+
 @small_task
 def prepare_trimgalore_salmon_inputs(
     samples: List[Sample],
@@ -301,7 +253,7 @@ def prepare_trimgalore_salmon_inputs(
 def _merge_replicates(
     replicates: List[Replicate], sample_name: str
 ) -> Union[Tuple[Path], Tuple[Path, Path]]:
-    local_r1_path = f"{sample_name}_r1_merged.fq"
+    local_r1_path = f"{slugify(sample_name)}_r1_merged.fq"
     r1 = _concatenate_files((str(x.r1.path) for x in replicates), local_r1_path)
 
     if isinstance(replicates[0], SingleEndReads):
@@ -310,7 +262,7 @@ def _merge_replicates(
     if not all(isinstance(x, PairedEndReads) for x in replicates):
         raise RuntimeError("Not all technical replicates were paired end")
 
-    local_r2_path = f"{sample_name}_r2_merged.fq"
+    local_r2_path = f"{slugify(sample_name)}_r2_merged.fq"
     r2 = _concatenate_files((str(x.r2.path) for x in replicates), local_r2_path)
 
     return (r1, r2)
@@ -354,7 +306,7 @@ def _flag(name: str) -> List[str]:
         flags += ["--paired", *_flag("clip_r2"), *_flag("three_prime_clip_r2")]
         read_paths.append(str(reads.r2.local_path))
 
-    local_output = f"{ts_input.sample_name}_replicate_{replicate_index}"
+    local_output = f"{slugify(ts_input.sample_name)}_replicate_{replicate_index}"
     returncode, stdout = _capture_output(
         [
             "trim_galore",
@@ -381,12 +333,12 @@ def _flag(name: str) -> List[str]:
         )
         raise TrimgaloreError(stdout)
 
-    def _output_path(middle: str) -> str:
+    def remote(middle: str) -> str:
         base = f"{ts_input.base_remote_output_dir}{ts_input.run_name}"
         tail = f"{ts_input.sample_name}/replicate_{replicate_index}/"
         return f"{base}/Quality Control Data/Trimming {middle} (TrimGalore)/{tail}"
 
-    reads_directory = _output_path("Reads")
+    reads_directory = remote("Reads")
     if isinstance(reads, SingleEndReads):
         (r1,) = file_glob(f"{local_output}/*trimmed.fq*", reads_directory)
         trimmed_replicate = SingleEndReads(r1=r1)
@@ -400,7 +352,7 @@ def _output_path(middle: str) -> str:
     if isinstance(reads, PairedEndReads):
         os.remove(reads.r2.local_path)
 
-    reports_directory = _output_path("Reports")
+    reports_directory = remote("Reports")
     reports = file_glob("*trimming_report.txt", reports_directory)
 
     return reports, trimmed_replicate
@@ -480,6 +432,19 @@ def trimgalore_salmon(input: TrimgaloreSalmonInput) -> TrimgaloreSalmonOutput:
         elif name == "gtf":
             custom_gtf = file
 
+    def local(suffix: Optional[str] = None) -> str:
+        salmon_path = "/root/salmon_quant/"
+        if suffix is not None:
+            salmon_path += suffix
+        return salmon_path
+
+    def remote(suffix: Optional[str] = None) -> str:
+        i = input
+        path = f"{i.base_remote_output_dir}{i.run_name}/Quantification (salmon)/{i.sample_name}/"
+        if suffix is not None:
+            path += suffix
+        return path
+
     gtf_path = None
     if custom_gtf is not None:
         gtf_path = _unzip_if_needed(custom_gtf.local_path)
@@ -545,7 +510,7 @@ def trimgalore_salmon(input: TrimgaloreSalmonInput) -> TrimgaloreSalmonOutput:
             str(96),
             "--validateMappings",
             "-o",
-            "salmon_quant",
+            local(),
         ]
     )
 
@@ -553,27 +518,36 @@ def trimgalore_salmon(input: TrimgaloreSalmonInput) -> TrimgaloreSalmonOutput:
     for path in merged:
         os.remove(path)
 
-    # todo(rohankan): Add info messages too? Decide which Salmon logs
-    # are interesting or important enough to show on console
+    def parse_salmon_warning(alert_message: str, input: TrimgaloreSalmonInput) -> str:
+        if "of fragments were shorter than the k" in alert_message:
+            # percent = float(alert_message.split("%")[0])
+            # min_read_size = int(alert_message.split(".")[-2].split(" ")[-1])
+            return alert_message
+        elif "Detected a *potential* strand bias" in alert_message:
+            default = "salmon_quant/lib_format_counts.json"
+            return alert_message.replace(default, remote("lib_format_counts.json"))
+        return alert_message
+
+    # Add info messages too?
     identifier = f"sample {input.sample_name}"
     errors = []
     for alert_type, alert_message in re.findall(_SALMON_ALERT_PATTERN, stdout):
-        data = {"title": f"Salmon {alert_type} for {identifier}", "body": alert_message}
+        title = f"Salmon {alert_type} for {identifier}"
         if alert_type == "warning":
-            message("warning", data)
+            message(
+                "warning",
+                {"title": title, "body": parse_salmon_warning(alert_message, input)},
+            )
         else:
-            message("error", data)
+            message("error", {"title": title, "body": alert_message})
             errors.append(alert_message)
 
     if returncode != 0:
         deets = "\n".join(["Error(s) occurred while running Salmon", *errors])
         raise RuntimeError(deets)
 
-    salmon_quant_file = LatchFile(
-        "/root/salmon_quant/quant.sf",
-        _salmon_output_path(input, f"{input.sample_name}_quant.sf"),
-    )
-    sf_files = [salmon_quant_file]
+    quant_name = local(f"{slugify(input.sample_name)}_quant.sf")
+    salmon_quant = Path(local("quant.sf")).rename(quant_name)
 
     try:
         gtf_path = (
@@ -582,24 +556,20 @@ def trimgalore_salmon(input: TrimgaloreSalmonInput) -> TrimgaloreSalmonOutput:
             else gm.download_gtf(show_progress=False)
         )
 
-        tximport_output_path = "/root/salmon_quant/genome_abundance.sf"
+        tximport_output_path = local("genome_abundance.sf")
         subprocess.run(
             [
                 "/root/wf/run_tximport.R",
                 "--args",
-                str(salmon_quant_file.path),
+                str(salmon_quant),
                 gtf_path,
                 tximport_output_path,
             ],
             capture_output=True,
         )
 
-        sf_files.append(
-            LatchFile(
-                tximport_output_path,
-                _salmon_output_path(input, f"{input.sample_name}_genome_abundance.sf"),
-            )
-        )
+        ga_name = local(f"{slugify(input.sample_name)}_genome_abundance.sf")
+        Path(tximport_output_path).rename(ga_name)
     except subprocess.CalledProcessError as e:
         message("error", {"title": f"tximport error for {identifier}", "body": str(e)})
         print(
@@ -607,18 +577,18 @@ def trimgalore_salmon(input: TrimgaloreSalmonInput) -> TrimgaloreSalmonOutput:
         )
         raise RuntimeError(f"Tximport error: {e}")
 
-    auxiliary_directory = LatchDir(
-        "/root/salmon_quant/aux_info",
-        _salmon_output_path(input, "Auxiliary Info"),
-    )
+    for unwanted in ("logs", "cmd_info.json"):
+        path = Path(local(unwanted))
+        if "." in unwanted:
+            path.unlink()
+        else:
+            shutil.rmtree(path)
 
     return TrimgaloreSalmonOutput(
         passed_salmon=True,
         passed_tximport=True,
         sample_name=input.sample_name,
-        sf_files=sf_files,
-        auxiliary_directory=[auxiliary_directory],
-        # trimmed_reads=trimmed_replicates,
+        salmon_output=LatchDir(local(), remote()),
         trimgalore_reports=trimgalore_reports,
     )
 
@@ -640,47 +610,49 @@ class SalmonError(Exception):
 _SALMON_ALERT_PATTERN = re.compile(r"\[(warning|error)\] (.+?)(?:\[\d{4}|$)", re.DOTALL)
 
 
-def _salmon_output_path(inp: TrimgaloreSalmonInput, suffix: str) -> str:
-    return f"{inp.base_remote_output_dir}{inp.run_name}/Quantification (salmon)/{inp.sample_name}/{suffix}"
-
-
 _COUNT_TABLE_GENE_ID_COLUMN = "gene_id"
 
 
 @small_task
 def count_matrix_and_multiqc(
     run_name: str,
-    outputs: List[TrimgaloreSalmonOutput],
+    ts_outputs: List[TrimgaloreSalmonOutput],
     output_directory: Optional[LatchDir],
     latch_genome: LatchGenome,
     custom_gtf: Optional[LatchFile] = None,
 ) -> List[LatchFile]:
     output_files = []
 
-    # Beginning creation of combined count matrix
-    tximport_outputs = [x for x in outputs if x.passed_tximport]
-    if len(tximport_outputs) == len(outputs):
+    Path("/root/inputs").mkdir(parents=True)
+    paths = [
+        Path(x.salmon_output.local_path).rename(f"/root/inputs/{x.sample_name}")
+        for x in ts_outputs
+    ]
+
+    def remote(suffix: str):
+        return _remote_output_dir(output_directory) + run_name + "/" + suffix
+
+    # Create combined count matrix
+    if all(x.passed_tximport for x in ts_outputs):
         message(
             "info",
             {
                 "title": "Generating count matrix from all samples",
-                "body": "\n".join(f"- {x.sample_name}" for x in tximport_outputs),
+                "body": "\n".join(f"- {x.sample_name}" for x in ts_outputs),
             },
         )
 
         combined_counts = defaultdict(dict)
-        for output in tximport_outputs:
-            genome_abundance_file = next(
-                x for x in output.sf_files if x.remote_path.endswith("dance.sf")
-            )
-            with open(genome_abundance_file.local_path, "r") as f:
+        for output, path in zip(ts_outputs, paths):
+            genome_abundance_file = next(path.glob("*genome_abundance.sf"))
+            with genome_abundance_file.open("r") as f:
                 for row in csv.DictReader(f, dialect=csv.excel_tab):
                     gene_name = row["Name"]
                     combined_counts[gene_name][output.sample_name] = row["NumReads"]
 
         raw_count_table_path = Path("./counts.tsv").resolve()
         with raw_count_table_path.open("w") as file:
-            sample_names = (x.sample_name for x in tximport_outputs)
+            sample_names = (x.sample_name for x in ts_outputs)
             writer = csv.DictWriter(
                 file,
                 [_COUNT_TABLE_GENE_ID_COLUMN, *sample_names],
@@ -691,9 +663,10 @@ def count_matrix_and_multiqc(
                 data[_COUNT_TABLE_GENE_ID_COLUMN] = gene_id
                 writer.writerow(data)
 
-        base = _remote_output_dir(output_directory) + run_name
-        output_path = f"{base}/Quantification (salmon)/counts.tsv"
-        count_matrix_file = LatchFile(str(raw_count_table_path), output_path)
+        count_matrix_file = LatchFile(
+            str(raw_count_table_path),
+            remote("Quantification (salmon)/counts.tsv"),
+        )
         output_files.append(count_matrix_file)
     else:
         message(
@@ -704,18 +677,11 @@ def count_matrix_and_multiqc(
             },
         )
 
-    # Beginning creation of MultiQC report
-    salmon_outputs = [x for x in outputs if x.passed_salmon]
-    paths = [Path(x.auxiliary_directory[0].local_path) for x in salmon_outputs]
-    for output in salmon_outputs:
-        paths += [Path(x.local_path) for x in output.sf_files]
-
     try:
         subprocess.run(["multiqc", *paths], check=True)
-        base = _remote_output_dir(output_directory) + run_name
         multiqc_report_file = LatchFile(
             "/root/multiqc_report.html",
-            f"{base}/multiqc_report.html",
+            remote("multiqc_report.html"),
remote("multiqc_report.html"), ) output_files.append(multiqc_report_file) except subprocess.CalledProcessError as e: @@ -1079,7 +1045,7 @@ def rnaseq( outputs = map_task(trimgalore_salmon)(input=inputs) return count_matrix_and_multiqc( run_name=run_name, - outputs=outputs, + ts_outputs=outputs, output_directory=custom_output_dir, latch_genome=latch_genome, custom_gtf=custom_gtf, From 91b43b43a4c9c4d7d807d2a9128105f55fb735e7 Mon Sep 17 00:00:00 2001 From: Rohan Date: Fri, 5 Aug 2022 16:33:51 -0700 Subject: [PATCH 2/3] fix: proper names for test data samples --- wf/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/wf/__init__.py b/wf/__init__.py index 8c160ba..3cfb297 100644 --- a/wf/__init__.py +++ b/wf/__init__.py @@ -1058,7 +1058,7 @@ def rnaseq( { "samples": [ Sample( - name="test_sample", + name="Control rep 1", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1069,7 +1069,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="Control rep 2", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1080,7 +1080,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="Control rep 3", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1091,7 +1091,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="CoCl2 rep 1", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1102,7 +1102,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="CoCl2 rep 2", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1113,7 +1113,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="Oxy rep 1", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1124,7 +1124,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="Oxy rep 2", strandedness=Strandedness.auto, replicates=[ SingleEndReads( @@ -1135,7 +1135,7 @@ def rnaseq( ], ), Sample( - name="test_sample", + name="Oxy rep 3", strandedness=Strandedness.auto, replicates=[ SingleEndReads( From 0300c8bcd56e2fabea9edb2719de682ccee2e283 Mon Sep 17 00:00:00 2001 From: Rohan Date: Fri, 5 Aug 2022 16:34:02 -0700 Subject: [PATCH 3/3] chore: bump version --- version | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version b/version index 7eff86e..7daae12 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.0.316 +0.0.317