In [None]:
import sevenbridges as sbg
from sevenbridges.errors import SbgError
from sevenbridges.http.error_handlers import rate_limit_sleeper, maintenance_sleeper
import sys
import re
import concurrent.futures
import pdb
# Build the Seven Bridges client from the local sbg credentials file
# (substitute your own profile name for 'turbo'). The two error handlers come
# from sevenbridges.http.error_handlers and cover rate-limit and maintenance responses.
config = sbg.Config(profile='turbo')
api = sbg.Api(
    config=config,
    error_handlers=[rate_limit_sleeper, maintenance_sleeper],
)

## Patch input bams

In [None]:
def patch_metadata(info, header_fields=None, first_col=None):
    """
    Set one platform file's metadata from a single tab-separated manifest row.

    :param info: one manifest line; column 0 is the file ID passed to
        api.files.get, and columns from `first_col` onward are metadata values
        whose keys come from `header_fields`.
    :param header_fields: column names for the row; defaults to the module-level
        `header` parsed from the manifest's first line.
    :param first_col: index of the first metadata column; defaults to the
        module-level `start`.
    :return: 0 on success, 1 on any failure (the error and the offending row are
        written to stderr).
    """
    try:
        # Resolve defaults inside the try so a missing global is reported as a
        # failed row (return 1), exactly like any other error.
        if header_fields is None:
            header_fields = header
        if first_col is None:
            first_col = start
        field = info.rstrip('\n').split('\t')
        cur_file = api.files.get(field[0])
        cur_file.metadata = {header_fields[i]: field[i] for i in range(first_col, len(field))}
        cur_file.save()
        return 0
    except Exception as e:
        # Newline-terminate both messages so concurrent failures stay readable.
        sys.stderr.write(str(e) + '\n')
        sys.stderr.write('Failed processing ' + info.rstrip('\n') + '\n')
        return 1

In [None]:
# Manifest: tab-separated, header row first; column 0 is the file ID and columns
# from `start` onward are metadata values keyed by the header names.
manifest = '/Users/brownm28/Documents/2021-Dec-01_OT_PDX/bam_metadata_to_update.tsv'
start = 5 # which array position to start reading data from
n = 100  # progress-report interval

# Read everything up front so the file handle is closed before the pool starts.
with open(manifest) as file_data:
    head = next(file_data)
    header = head.rstrip('\n').split('\t')
    # Skip blank lines (e.g. a trailing newline at EOF) so they are not sent as file IDs.
    manifest_lines = [line for line in file_data if line.strip()]

i = 1
with concurrent.futures.ThreadPoolExecutor(16) as executor:
    results = {executor.submit(patch_metadata, line): line for line in manifest_lines}
    for result in concurrent.futures.as_completed(results):
        if result.result() == 1:
            # Cancel work that has not started; otherwise the executor keeps
            # patching the rest of the manifest while it shuts down. Raise instead
            # of exit(), which would shut down the Jupyter kernel.
            for pending in results:
                pending.cancel()
            raise RuntimeError('Failed processing ' + results[result].rstrip('\n'))
        if i % n == 0:
            sys.stderr.write(str(i) + ' files updated\n')
        i += 1



## Patch batch task outputs

In [None]:
def patch_bam_fastq(task):
    """
    Read the task's 'rg_string' output and add its read-group info to the source
    BAM and to every output fastq. The fastqs also inherit the BAM's other
    metadata, except 'reference_genome'.

    :param task: completed conversion task with an 'input_bam' input and
        'rg_string' / 'output' outputs.
    :return: 0 on success; on an unexpected failure, the exception object itself
        (it is returned, not raised, so the calling pool can print it).
        Individual save() failures are printed and skipped.
    """
    # Read-group tag -> metadata field name.
    api_rg_dict = {"ID": "sample_id", "PL": "platform", "PU": "platform_unit_id", "LB": "library_id", "CN": "center"}
    try:
        rg_str = task.outputs['rg_string'].content()
        rg_dict = {}
        for pair in rg_str.rstrip('\n').split('\t'):
            # Split on the first ':' only, so values that contain ':' are kept whole.
            # Fields with no ':' at all (e.g. '@RG') are ignored.
            tag, sep, value = pair.partition(':')
            if sep and tag in api_rg_dict:
                rg_dict[api_rg_dict[tag]] = value

        # Existing input-BAM metadata, overridden by the read-group values.
        source_meta = task.inputs['input_bam'].metadata
        bam_meta = {key: source_meta[key] for key in source_meta}
        bam_meta.update(rg_dict)

        # Need the actual file object to modify it.
        input_bam = api.files.get(task.inputs['input_bam'].id)
        input_bam.metadata = bam_meta
        try:
            input_bam.save()
        except Exception as e:
            print(e)

        # The fastqs get the same info except the genome. Filtering the key
        # (instead of pop) tolerates BAMs without one and leaves bam_meta untouched.
        fastq_meta = {key: val for key, val in bam_meta.items() if key != 'reference_genome'}
        for output_ref in task.outputs['output']:
            fastq = api.files.get(output_ref.id)
            # Use a fresh dict per fastq so paired_end is never shared between files.
            meta = dict(fastq_meta)
            meta['paired_end'] = '1' if re.search("R1", fastq.name) else '2'
            fastq.metadata = meta
            try:
                fastq.save()
            except Exception as e:
                print(e)
        return 0
    except Exception as e:
        return e
    
    

In [None]:
batch_obj = api.tasks.get('affaaaea-7c37-4703-af2d-8a2e54c59353')
# Collection.all() returns a generator. Materialise it so that counting the tasks
# does not exhaust it before the patching pool runs.
child_tasks = list(batch_obj.get_batch_children(status="COMPLETED").all())
print(len(child_tasks))

i = 1
n = 100  # progress-report interval
with concurrent.futures.ThreadPoolExecutor(16) as executor:
    results = {executor.submit(patch_bam_fastq, task): task for task in child_tasks}
    for result in concurrent.futures.as_completed(results):
        outcome = result.result()
        if outcome != 0:
            # patch_bam_fastq returns the exception instead of raising; report it and continue.
            print(outcome)
        if i % n == 0:
            sys.stderr.write(str(i) + ' tasks processed\n')
        i += 1


## Set up PIVOT/JAX RNAseq WF Tasks

In [None]:
def get_fastqs(project):
    """
    Get all RNAseq fastqs, organize into a dict by sample and read1/2
    Used for both workflows!

    :param project: project ID to query (e.g. 'd3b-bixu/open-targets-pdx-workflow-dev').
    :return: {sample_id: {paired_end value: File}} for RNA-Seq '*fastq.gz' files
        with no reference_genome set.
    """
    fastqs = api.files.query(project=project, metadata = {'experimental_strategy': 'RNA-Seq', 'reference_genome': None}).all()
    fq_dict = {}
    for fastq in fastqs:
        meta = fastq.metadata
        # Re-check the query filters locally. Using .get means files that lack the
        # 'reference_genome' key count as unaligned instead of raising KeyError.
        if meta.get('experimental_strategy') != 'RNA-Seq' or meta.get('reference_genome') is not None:
            continue
        if 'fastq.gz' not in fastq.name:
            continue
        sample_reads = fq_dict.setdefault(meta['sample_id'], {})
        sample_reads[meta['paired_end']] = api.files.get(fastq.id)
    return fq_dict

In [None]:
def get_jax_refs():
    """
    Build the shared (non-sample) inputs for the PDX RNA expression estimation app.

    :return: dict mapping app input names to reference File objects or literal settings.
    """
    reference_ids = {
        'ref_flat': '61ae801faad1f926aea3031d',
        'rsem_prepare_reference_archive': '61ae7f30aad1f926aea3030f',
        'ribosomal_intervals': '61ae7f2faad1f926aea30309',
        'index_file': '61ae7fbeaad1f926aea30317',
        'Sites': '61ae8141aad1f926aea30329',
        'Reference': '61a923b3aad1f926aea19eb7',
    }
    inputs = {name: api.files.get(file_id) for name, file_id in reference_ids.items()}
    inputs['forward_prob'] = 0
    inputs['strand_specificity'] = 'SECOND_READ_TRANSCRIPTION_STRAND'
    return inputs
    

In [None]:
def draft_jax_task(samp_id, fq_dict, ref_dict, app_name):
    """
    Draft (create with run=False) one PDX RNA expression estimation task for a sample.

    The task is created in the module-level `project`.

    :param samp_id: sample ID; key into fq_dict.
    :param fq_dict: {sample_id: {'1': read-1 File, '2': read-2 File}}.
    :param ref_dict: shared app inputs; copied, not modified.
    :param app_name: app ID to draft the task from.
    :return: the drafted task, or None if drafting failed (the error goes to stderr).
    """
    try:
        input_dict = dict(ref_dict)
        input_dict['input_pair'] = [fq_dict[samp_id]['1'], fq_dict[samp_id]['2']]
        task_name = 'PDX RNA Expression Estimation Workflow: ' + samp_id
        task = api.tasks.create(name=task_name, project=project, app=app_name, inputs=input_dict, run=False)
        task.save()
        return task
    except Exception as e:
        # Report and move on. The old pdb.set_trace() here halted the drafting
        # loop on every sample with a missing read or an API error.
        sys.stderr.write('Failed to draft task for ' + str(samp_id) + ': ' + str(e) + '\n')
        return None
        


In [None]:
project = 'd3b-bixu/open-targets-pdx-workflow-dev'
# Query the fastqs once; the KF drafting cell further down reuses fastq_dict.
fastq_dict = get_fastqs(project)
jax_refs = get_jax_refs()


In [None]:
# Draft one JAX task per sample, printing progress every n samples.
app_name = 'd3b-bixu/open-targets-pdx-workflow-dev/pdx-rna-tumor-only'
n = 20
i = 0
for samp_id in fastq_dict:
    if not i % n:
        print ("Drafted " + str(i) + " tasks")
    draft_jax_task(samp_id, fastq_dict, jax_refs, app_name)
    i += 1

## Set up KF RNAseq wf

In [None]:
def get_kf_refs():
    """
    Build the shared (non-sample) inputs for the Kids First PDX RNA-Seq app.

    :return: dict mapping app input names to a reference File or literal setting.
    """
    return {
        # same file ID that get_jax_refs uses for 'index_file'
        'xenome_index': api.files.get('61ae7fbeaad1f926aea30317'),
        'input_type': 'FASTQ',
        'runThread': 36,
        'wf_strand_param': 'rf-stranded',
        'idx_prefix': 'hg38_broad_NOD_based_on_mm10_k25',
    }
    

In [None]:
def draft_kf_ot_task(samp_id, fq_dict, ref_dict, app_name):
    """
    Draft (create with run=False) one Kids First OT PDX RNA-Seq task for a sample.

    The STAR read-group line is built from read 1's metadata. The task is created
    in the module-level `project`, and its output_basename is set to its own task ID.

    :param samp_id: sample ID; key into fq_dict.
    :param fq_dict: {sample_id: {'1': read-1 File, '2': read-2 File}}.
    :param ref_dict: shared app inputs; copied, not modified.
    :param app_name: app ID to draft the task from.
    :return: the drafted task, or None if drafting failed (the error goes to stderr).
    """
    try:
        input_dict = dict(ref_dict)
        read1 = fq_dict[samp_id]['1']
        input_dict['reads1'] = read1
        input_dict['reads2'] = fq_dict[samp_id]['2']
        input_dict['sample_name'] = samp_id
        # Some vars with RG info for ease of reading.
        rid = "ID:" + samp_id
        pl = "PL:" + read1.metadata['platform']
        pu = "PU:" + read1.metadata['platform_unit_id']
        lb = "LB:" + read1.metadata['library_id']
        sm = "SM:" + samp_id
        cn = "CN:" + read1.metadata['center']
        # ID:PPTC-AF06-XTP1-A-1-0-R PL:Illumina PU:7007001353_20171208_CBK13ACXX-6-ID11 LB:ICD_IND-PPCTC.PPTC-AF06-XTP1-A-1-0-R-1_1pA SM:PPTC-AF06-XTP1-A-1-0-R CN:BCM
        input_dict['STAR_outSAMattrRGline'] = "\t".join([rid, pl, pu, lb, sm, cn])
        task_name = 'Kids First OT PDX RNA-Seq Workflow: ' + samp_id
        task = api.tasks.create(name=task_name, project=project, app=app_name, inputs=input_dict, run=False)
        # The task ID only exists after creation, so the basename is patched in and saved.
        task.inputs['output_basename'] = task.id
        task.save()
        return task
    except Exception as e:
        # Report and move on instead of stopping in pdb for every failed sample.
        sys.stderr.write('Failed to draft task for ' + str(samp_id) + ': ' + str(e) + '\n')
        return None


In [None]:
project = 'd3b-bixu/open-targets-pdx-workflow-dev'
# fastq_dict comes from the JAX set-up cell; uncomment to rebuild it here instead.
# fastq_dict = get_fastqs(project) # only run once depending on wf
kf_refs = get_kf_refs()
app_name = 'd3b-bixu/open-targets-pdx-workflow-dev/kfdrc-pdx-rnaseq-wf'
n = 20
i = 0
for samp_id in fastq_dict:
    if not i % n:
        print ("Drafted " + str(i) + " tasks")
    draft_kf_ot_task(samp_id, fastq_dict, kf_refs, app_name)
    i += 1


## Re-draft failed tasks

In [None]:
project = 'd3b-bixu/open-targets-pdx-workflow-dev'
# Task-name prefixes used by the two drafting functions in this notebook.
redraft_prefixes = ('PDX RNA Expression Estimation Workflow:', 'Kids First OT PDX RNA-Seq Workflow:')
failed = api.tasks.query(project=project, status="FAILED").all()
for fail in failed:
    if not fail.name.startswith(redraft_prefixes):
        continue
    print("Re-drafting " + fail.id + " " + fail.name)
    rerun = fail.clone(run=False)
    rerun.save()

## Patch metadata to outputs

In [None]:
def patch_jax_output(task, workflow='d3b-bixu/open-targets-pdx-workflow-dev/pdx-rna-tumor-only'):
    """
    Set the 'workflow' metadata field on every output file of a task.

    :param task: task whose `outputs` maps output names to a file, a list of
        files, or None (an output that was not produced).
    :param workflow: value written to each file's 'workflow' metadata field.
    :return: always 0.
    """
    for output in task.outputs:
        value = task.outputs[output]
        if value is None:
            # Unproduced optional output: skip instead of crashing on `.id`.
            continue
        refs = value if isinstance(value, list) else [value]
        for ref in refs:
            file_obj = api.files.get(ref.id)
            file_obj.metadata['workflow'] = workflow
            file_obj.save()
    return 0


In [None]:
def patch_kf_ot_output(task):
    """
    Copy the read-1 fastq's metadata onto every output file of a KF OT PDX task.

    'paired_end' is dropped, and 'workflow', 'reference_genome' and 'annotation'
    are set to fixed values. Errors are printed, not raised.

    :param task: task with a 'reads1' input; its `outputs` maps names to a file,
        a list of files, or None.
    :return: always 0.
    """
    try:
        reads1_meta = task.inputs['reads1'].metadata
        metadata = {key: reads1_meta[key] for key in reads1_meta}
        # Drop the read-pair marker; tolerate fastqs that never had one.
        metadata.pop('paired_end', None)
        metadata['workflow'] = 'd3b-bixu/open-targets-pdx-workflow-dev/kfdrc-pdx-rnaseq-wf'
        metadata['reference_genome'] = 'GRCh38'
        metadata['annotation'] = 'GENCODE27'
        for output in task.outputs:
            value = task.outputs[output]
            if value is None:
                continue
            refs = value if isinstance(value, list) else [value]
            for ref in refs:
                file_obj = api.files.get(ref.id)
                # Give each file its own copy rather than one shared dict.
                file_obj.metadata = dict(metadata)
                file_obj.save()
    except Exception as e:
        print('Error patching outputs of ' + str(getattr(task, 'name', task)) + ': ' + str(e))
    return 0


In [None]:
def filter_task(task):
    """
    Patch output metadata for Kids First OT PDX RNA-Seq tasks; ignore every other task.

    JAX ('PDX RNA Expression Estimation Workflow:') tasks are not patched here;
    call patch_jax_output on them directly if needed.

    :return: 1 if the task was handed to patch_kf_ot_output, otherwise 0.
    """
    if not task.name.startswith('Kids First OT PDX RNA-Seq Workflow:'):
        return 0
    patch_kf_ot_output(task)
    return 1


In [None]:
project = 'd3b-bixu/open-targets-pdx-workflow-dev'
completed = api.tasks.query(project=project, status="COMPLETED").all()
i = 0  # number of tasks whose outputs were patched
n = 50  # progress-report interval

with concurrent.futures.ThreadPoolExecutor(16) as executor:
    results = {executor.submit(filter_task, task): task for task in completed}
    for result in concurrent.futures.as_completed(results):
        patched = result.result()
        i += patched
        # Report only when the count actually advanced onto a multiple of n, so
        # skipped (non-KF) tasks don't reprint the same progress line.
        if patched and i % n == 0:
            sys.stderr.write("Patched " + str(i) + " task outputs\n")
sys.stderr.write("Patched " + str(i) + " task outputs\n")

## Clean up failed-task outputs and fix output file names

In [None]:
project = 'd3b-bixu/open-targets-pdx-workflow-dev'
# Delete every output file of failed tasks drafted by this notebook's workflows.
cleanup_prefixes = ('PDX RNA Expression Estimation Workflow:', 'Kids First OT PDX RNA-Seq Workflow:')
failed = api.tasks.query(project=project, status="FAILED").all()
for task in failed:
    if not task.name.startswith(cleanup_prefixes):
        continue
    for output in task.outputs:
        value = task.outputs[output]
        if value is None:
            continue
        refs = value if type(value) is list else [value]
        for ref in refs:
            file_obj = api.files.get(ref.id)
            file_obj.delete()
            print ("Deleted " + file_obj.name + " from task " + task.name)


In [None]:
# Task names (one per line) whose output files should lose a leading '_1_' prefix.
# Use a context manager so the file handle is closed.
with open('/Users/brownm28/Documents/2021-Dec-01_OT_PDX/cleanup_task_outputs.txt') as task_file:
    task_dict = {line.rstrip('\n'): 0 for line in task_file}
project = 'd3b-bixu/open-targets-pdx-workflow-dev'
completed = api.tasks.query(project=project, status="COMPLETED").all()
for task in completed:
    if task.name not in task_dict:
        continue
    sys.stderr.write("processing " + task.name + "\n")
    for output in task.outputs:
        value = task.outputs[output]
        if value is None:
            continue
        refs = value if isinstance(value, list) else [value]
        for ref in refs:
            # Guard against None entries inside list outputs.
            if ref is None or not ref.name.startswith('_1_'):
                continue
            file_obj = api.files.get(ref.id)
            file_obj.name = file_obj.name.replace('_1_','', 1)
            file_obj.save()
            print ("Renamed " + file_obj.name + " from task " + task.name)


