diff --git a/workers/gene_annotator.py b/workers/gene_annotator.py index fe014eb..3bc3c11 100644 --- a/workers/gene_annotator.py +++ b/workers/gene_annotator.py @@ -3,6 +3,8 @@ Given genotype data that exists in the DB, joins that data with Ensembl to arrive at gene name(s) for each locus. """ +from contextlib import contextmanager + import csv from pyensembl import EnsemblRelease from sqlalchemy import select, update, Table, Column @@ -17,7 +19,9 @@ @worker.task def annotate(vcf_id): engine, connection, metadata = initialize_database(DATABASE_URI) - try: + + with closing(connection, pre_finally_fun=discard_temp, + connection=connection): genotypes = metadata.tables.get('genotypes') stmt = select([genotypes.c.contig, genotypes.c.position]).where( genotypes.c.vcf_id == vcf_id) @@ -36,18 +40,18 @@ def annotate(vcf_id): gene_name_str = ','.join(gene_names_for_locus) gene_names.append([contig, position, gene_name_str]) - tmp_gene_annotations = Table('gene_annotations', + tmp_table = Table('gene_annotations', metadata, Column('contig', Text, nullable=False), Column('position', Integer, nullable=False), Column('gene_names', Text, nullable=True), prefixes=['TEMPORARY']) - tmp_gene_annotations.create() + tmp_table.create() - # Open file for both writing (the gene annotations) and reading that out - # to Postgres + # Open file for both writing (the gene annotations) and reading that + # out to Postgres csv_file = temp_csv(mode='r+', tmp_dir=TEMPORARY_DIR) - try: + with closing(csv_file): # Don't use commas as a delim, as commas are part of gene_names csv.writer(csv_file, delimiter='\t').writerows(gene_names) @@ -55,22 +59,34 @@ def annotate(vcf_id): csv_file.seek(0, 0) cursor = connection.connection.cursor() - cursor.copy_from(csv_file, sep='\t', null='', - table=tmp_gene_annotations.name) - connection.connection.commit() + with closing(cursor): + cursor.copy_from(csv_file, sep='\t', null='', + table=tmp_table.name) + connection.connection.commit() - # Note: auto-commits - genotypes.update().where( - genotypes.c.contig == tmp_gene_annotations.c.contig).where( - genotypes.c.position == tmp_gene_annotations.c.position).values( - {'annotations:gene_names': tmp_gene_annotations.c.gene_names}).execute() + # Note: auto-commits + genotypes.update().where( + genotypes.c.contig == tmp_table.c.contig).where( + genotypes.c.position == tmp_table.c.position).values( + {'annotations:gene_names': tmp_table.c.gene_names} + ).execute() - # Update the list of extant columns for the UI - update_extant_columns(metadata, connection, vcf_id) - finally: - csv_file.close() - cursor.close() - finally: + # Update the list of extant columns for the UI + update_extant_columns(metadata, connection, vcf_id) + +def discard_temp(**kwargs): + connection = kwargs.get('connection', None) + if connection: connection.execute("DISCARD TEMP") connection.connection.commit() - connection.close() + +@contextmanager +def closing(thing, pre_finally_fun=None, **kwargs): + try: + yield thing + finally: + try: + if pre_finally_fun: + pre_finally_fun(**kwargs) + finally: + thing.close()