In [None]:
import os

# Notebook parameters; the values below point at one example genome.
# Directory holding the query genome file (joined with `name` to build `queryfile`).
genome_dir = '../../gtdb-contam-dna'
# Directory holding the per-genome files this report reads: the hit list YAML,
# the mashmap outputs and the hitlist accession info CSV.
output_dir = '../../output.gtdb-contam-dna'
# Directory searched with glob for the `<accession>*.fna.gz` file of each matched accession.
genbank_genomes = '../../genbank_genomes'
# File name of the query genome; also used as the prefix of its output file names.
name = 'GCF_002154655.1_genomic.fna.gz'

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import importlib
import pprint
import yaml
import glob

import charcoal.alignplot
importlib.reload(charcoal.alignplot)

from charcoal import alignplot
from charcoal.alignplot import AlignmentContainer, StackedDotPlot, AlignmentSlopeDiagram
from charcoal import utils

In [None]:
# configure paths to files based on parameters
# `genomebase` is the query genome's file name; it prefixes the per-genome output files.
genomebase = name
# path of the query genome file
queryfile = f'{genome_dir}/{genomebase}'
# YAML file with the query info and its matches (read with yaml.safe_load in a later cell)
matches_info_file = f'{output_dir}/{genomebase}.hitlist.matches.yaml'

In [None]:
# `md` and `display` are also used by later cells to emit Markdown-formatted text.
from IPython.display import Markdown as md
from IPython.display import display
# Last expression of the cell, so the notebook renders it as the report title.
md(f"# Charcoal alignment report for genome `{genomebase}`")

In [None]:
# Load the hit list for the query genome: query metadata plus one entry per
# matched accession.
with open(matches_info_file, 'rt') as fp:
    matches_info = yaml.safe_load(fp)

query_info = matches_info['query_info']
genome_lin = query_info['genome_lineage']
match_rank = query_info['match_rank']
scaled = query_info['scaled']

# Split the matches by type. Each entry is (accession, lineage, counts);
# matches whose type is neither 'clean' nor 'dirty' are ignored.
clean_accs = []
dirty_accs = []
for match_acc, acc_info in matches_info['matches'].items():
    match_counts = acc_info['counts']
    match_type = acc_info['match_type']
    match_lineage = acc_info['lineage']

    if match_type == 'clean':
        clean_accs.append((match_acc, match_lineage, match_counts))
    elif match_type == 'dirty':
        dirty_accs.append((match_acc, match_lineage, match_counts))

# Largest counts first.
clean_accs.sort(key=lambda x: -x[2])
dirty_accs.sort(key=lambda x: -x[2])


def _match_line(acc, lineage, counts):
    """Format one Markdown bullet describing a matched accession."""
    # counts * scaled is treated as an estimate in bp (assumes `counts` is a
    # number of scaled hashes -- TODO confirm), so divide by 1000 to report kb.
    est_kb = counts * scaled / 1000
    return f'* `{acc}` with est {est_kb:.1f} kb;\n`{lineage}`'


# Assemble the Markdown summary line by line, then render it in one go.
output = []

output.append(f'loaded {len(clean_accs)} clean accs and {len(dirty_accs)} dirty accs')
output.append('')
output.append(f'query genome lineage: `{genome_lin}`\n')

output.append(f'genomes that match the lineage at {match_rank}:')
output.extend(_match_line(*entry) for entry in clean_accs)

output.append('')
output.append('genomes that do NOT match the lineage:')
output.extend(_match_line(*entry) for entry in dirty_accs)

display(md("\n".join(output)))

In [None]:
def load_target_pairs(match_list, genomes_dir=None):
    """Locate the genome file for every accession in `match_list`.

    Parameters
    ----------
    match_list : iterable of 3-tuples
        Only the first element (the accession) is used.
    genomes_dir : str, optional
        Directory to search. Defaults to the notebook-level `genbank_genomes`.

    Returns
    -------
    list of (accession, filename) tuples, in the order of `match_list`.

    Raises
    ------
    FileNotFoundError
        If no `<accession>*.fna.gz` file exists in the directory.
    """
    if genomes_dir is None:
        genomes_dir = genbank_genomes

    pairs = []
    for acc, _, _ in match_list:
        pattern = f'{genomes_dir}/{acc}*.fna.gz'
        # glob returns matches in arbitrary order; sort so the pick is
        # deterministic when several files share the accession prefix.
        # (More than one match is tolerated, as in the original code.)
        candidates = sorted(glob.glob(pattern))
        if not candidates:
            raise FileNotFoundError(f'no genome file found matching {pattern!r}')
        pairs.append((acc, candidates[0]))

    return pairs

# (accession, genome file) pairs for the matches typed 'dirty' and 'clean' respectively;
# both lists are reused by the alignment and response-curve cells below.
contaminant_pairs = load_target_pairs(dirty_accs)
clean_pairs = load_target_pairs(clean_accs)

In [None]:
# Container for the alignments between the query genome and the 'dirty' targets.
hitlist_info_csv = f'{output_dir}/hitlist-accessions.info.csv'
dirty_alignment = AlignmentContainer(genomebase, queryfile, contaminant_pairs, hitlist_info_csv)

# Read the mashmap output file of every target and attach the results,
# keyed by target accession.
results = {
    acc: dirty_alignment._read_mashmap(f'{output_dir}/{genomebase}.x.{acc}.mashmap.out')
    for acc, _ in contaminant_pairs
}
dirty_alignment.results = results

display(md('filtering dirty alignments to query size >= 500 and identity >= 95%'))
dirty_alignment.filter(query_size=0.5, pident=95)

# Total over the values returned by calc_shared().
dirty_shared = dirty_alignment.calc_shared()
sum_dirty_kb = sum(dirty_shared.values())
display(md(f'**dirty bases: {sum_dirty_kb:.1f}kb of alignments to query genome, across all targets.**'))

In [None]:
# Container for the alignments between the query genome and the 'clean' targets.
hitlist_info_csv = f'{output_dir}/hitlist-accessions.info.csv'
clean_alignment = AlignmentContainer(genomebase, queryfile, clean_pairs, hitlist_info_csv)

# Read the mashmap output file of every target and attach the results,
# keyed by target accession.
results = {
    acc: clean_alignment._read_mashmap(f'{output_dir}/{genomebase}.x.{acc}.mashmap.out')
    for acc, _ in clean_pairs
}
clean_alignment.results = results

display(md('filtering clean alignments to query size >= 500 and identity >= 95%'))
clean_alignment.filter(query_size=0.5, pident=95)

# Total over the values returned by calc_shared().
clean_shared = clean_alignment.calc_shared()
sum_clean_kb = sum(clean_shared.values())
display(md(f'**clean bases: {sum_clean_kb:.1f}kb of alignments to query genome, across all targets.**'))

In [None]:
# Sweep the query-coverage threshold from 0.05 to 1.0 in steps of 0.05 and
# record the summed calc_shared() values remaining at each threshold.
# Uses its own names so that `sum_dirty_kb`, reported by the previous cell,
# is not overwritten by the sweep.
coverage_fractions = [step / 20 for step in range(1, 21)]
dirty_kb_by_fraction = []
for fraction in coverage_fractions:
    filtered_alignment = dirty_alignment.filter_by_query_coverage(fraction)
    dirty_kb_by_fraction.append(sum(filtered_alignment.calc_shared().values()))

plt.plot(coverage_fractions, dirty_kb_by_fraction, '.-')
plt.xlabel('fraction of contig that must be covered')
_ = plt.ylabel('kb covered in dirty contigs')


In [None]:
# Apply the same query-coverage threshold to both containers. The names are
# rebound to the filtered results, so every later cell sees the filtered data.
min_query_coverage = 0.5
dirty_alignment = dirty_alignment.filter_by_query_coverage(min_query_coverage)
clean_alignment = clean_alignment.filter_by_query_coverage(min_query_coverage)

## Stacked DotPlot view

(this is an n-ary dotplot)

In [None]:
# Stacked dot plot of the alignments to the clean genomes.
# `clean_dotplot` is reused by the response-curve cell further down.
clean_dotplot = StackedDotPlot(clean_alignment)
# Assigning to `_` keeps the return values out of the cell output.
_ = clean_dotplot.plot()
# plt.title applies to the current axes -- presumably the ones plot() just drew on.
_ = plt.title('Alignments to clean genomes')

In [None]:
# Stacked dot plot of the alignments to the dirty genomes.
# `dirty_dotplot` is reused by the response-curve cell further down.
dirty_dotplot = StackedDotPlot(dirty_alignment)
# Assigning to `_` keeps the return values out of the cell output.
_ = dirty_dotplot.plot()
# plt.title applies to the current axes -- presumably the ones plot() just drew on.
_ = plt.title('Alignments to dirty genomes')

In [None]:
# Slope diagram of the alignments to the clean genomes.
clean_slope = AlignmentSlopeDiagram(clean_alignment)
# Assigning to `_` keeps the return values out of the cell output.
_ = clean_slope.plot()

# plt.title applies to the current axes -- presumably the ones plot() just drew on.
_ = plt.title('Alignments to clean genomes')

In [None]:
# Slope diagram of the alignments to the dirty genomes.
dirty_slope = AlignmentSlopeDiagram(dirty_alignment)
# Assigning to `_` keeps the return values out of the cell output.
_ = dirty_slope.plot()

# plt.title applies to the current axes -- presumably the ones plot() just drew on.
_ = plt.title('Alignments to dirty genomes')

## Region response curve

underlying logic:

* for our primary use case here (contamination/legitimate "shared" nucleotides), we can consider nt alignments of >= 95% identity to be contamination that should be removed
* we want to remove as many bp of contamination as possible while removing as few "legit" non-shared nt as possible.
* what's the response curve for that, and can we use it to figure out which genome(s) have the likely contaminants?

In [None]:
# One response curve per dirty target, scaled by its own maximum.
# Target curves are drawn without a legend label.
for t_acc, _ in contaminant_pairs:
    x, y, sat1 = dirty_dotplot.target_response_curve(t_acc)
    peak = max(y)
    if not peak:
        # a zero maximum cannot be used as a scaling divisor; skip this target
        continue
    plt.plot(x, y / peak)

# The query's own curve, scaled the same way; the only labelled line.
x3, y3, sat3 = dirty_dotplot.query_response_curve()
query_peak = max(y3)
plt.plot(x3, y3 / query_peak, label='query loss (non-contaminated genome)')

plt.xlabel('kb in genome contigs removed')
plt.ylabel('fraction of alignments removed')
plt.legend(loc='lower right')

In [None]:
# One response curve per clean target, scaled by its own maximum.
# Target curves are drawn without a legend label.
for t_acc, _ in clean_pairs:
    x, y, sat1 = clean_dotplot.target_response_curve(t_acc)
    peak = max(y)
    if not peak:
        # a zero maximum cannot be used as a scaling divisor; skip this target
        continue
    plt.plot(x, y / peak)

# The query's own curve, scaled the same way; the only labelled line.
x3, y3, sat3 = clean_dotplot.query_response_curve()
query_peak = max(y3)
plt.plot(x3, y3 / query_peak, label='query loss (non-contaminated genome)')

plt.xlabel('kb in genome contigs removed')
plt.ylabel('fraction of alignments removed')
plt.legend(loc='lower right')
# NOTE(review): `sat1` is whatever the last loop iteration left behind (or a
# value from an earlier cell if `clean_pairs` is empty) -- confirm that the
# last target's value is really what should be reported here.
print(f'{sat1:.1f}kb, {sat3:.1f}kb')

## Reporting alignments

In [None]:
# Flatten the per-target result lists of the dirty container into one list.
regions = [
    region
    for target_regions in dirty_alignment.results.values()
    for region in target_regions
]

queryfile = dirty_alignment.queryfile


def _is_high_identity(region):
    """True for alignments with at least 95% identity."""
    return region.pident >= 95


# Group the alignments by query, total the high-identity ones per group,
# and order the groups from largest to smallest total.
# NOTE(review): the helper is named `..._bp` while the result is stored as
# `..._kb` -- confirm which unit the totals are in.
regions_by_query = alignplot.group_regions_by(regions, "query")
regions_aligned_kb = alignplot.calc_regions_aligned_bp(
    regions_by_query, "query", filter_by=_is_high_identity
)
region_items = sorted(regions_aligned_kb.items(), key=lambda item: -item[1])

In [None]:
# Print every alignment for the five query contigs with the largest totals.
print('Top 5 dirty contigs w/all alignments b/t query and matches --')
# The loop variable is deliberately not called `name`: that is the notebook
# parameter holding the query genome file name, and rebinding it here would
# break any earlier cell that is re-run afterwards.
for n, (contig_name, _aligned_kb) in enumerate(region_items[:5]):
    print(f'contig #{n+1} in query - contig name {contig_name}')
    for i, a in enumerate(regions_by_query[contig_name]):
        # coordinates are multiplied by 1000 for display (presumably kb -> bp)
        print(f"* alignment {i}: {a.pident:.1f}% identity across {abs(a.qend - a.qstart):.0f} kb\n  {a.query}[{int(a.qstart*1000)}:{int(a.qend*1000)}] aligns to {a.target}[{int(a.tstart*1000)}:{int(a.tend*1000)}]")
    print('')