# PostConterminator

**Version**: 0.8-pub (Jan 2025)

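**Abstract**: Analyze the Conterminator contamination predictions for the NCBI BLAST nt database, summarize the inter-kingdom contamination, and write a decontaminated nt FASTA file.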

**Environment**: Jupyter/JupyterHub

## Initialization
### Dependencies

In [None]:
import csv
from itertools import combinations
from math import log10, ceil
import os
from pathlib import Path, PosixPath
import pwd
import sys
import time
from typing import Set, List, Union, Any, Dict, Tuple, Iterable, NewType, Optional, NamedTuple

In [None]:
from Bio.SeqIO.FastaIO import SimpleFastaParser
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm, trange 

Check the Python version

In [None]:
# Abort early on interpreters older than Python 3.7. A tuple comparison is used so
# that any later major version is accepted as well, not only 3.x releases.
if sys.version_info < (3, 7):
    print('ERROR! This script requires Python 3.7 or higher.', file=sys.stderr)
    print(f'You are using Python {sys.version_info.major}.{sys.version_info.minor}', file=sys.stderr)
    sys.exit(1)

### Known locations

In [None]:
# Login name of the user running the kernel, looked up in the password database
USERNAME: str = pwd.getpwuid(os.getuid()).pw_name

# nt release analysed by this notebook and the directory that holds its files
RELEASE: str = 'nt_wntr25'
NTDIR: PosixPath = Path('/dbs_base_path/dbs', RELEASE)

# Conterminator prediction table, input nt FASTA and decontaminated FASTA to write
CONTOUT: PosixPath = NTDIR.joinpath('nt_umask.result_conterm_prediction')
NTFA: PosixPath = NTDIR.joinpath('nt_umask.fa')
NTFA_DECON: PosixPath = NTDIR.joinpath('nt_umask_decon.fa')

### Other constants

In [None]:
# Column names, in file order, for the 12 tab-separated fields of the Conterminator
# prediction file. The nt sequence count (NTNUMSEQ) is read from `blastdbcmd -info`.
CONTCOLS: List[str] = [
    'NumId',                                           # numeric identifier
    'CedId', 'CedKdom', 'CedSpp',                      # contaminatED seq: id, kingdom, species
    'AliIni', 'AliEnd',                                # alignment start / end
    'CntgLen',                                         # corrected contig length
    'CingTopId', 'CingKdom', 'CingSpp', 'CingTopLen',  # longest contaminatING seq
    'NumAli',                                          # alignments from contaminating kingdom
]

# Conterminator kingdom codes -> short names used in tables and plots
KINGDOMS: Dict[int, str] = dict(enumerate(
    ['BactArch', 'Fungi', 'Metazoa', 'Viridipl', 'OtherEuk']))

Conterminator output key:

1. Numeric identifier
2. Contaminated identifier
3. Kingdom (default codes: 0: Bacteria & Archaea, 1: Fungi, 2: Metazoa, 3: Viridiplantae, 4: Other Eukaryotes)
4. Species name
5. Alignment start
6. Alignment end
7. Corrected contig length (length between flanking Ns)
8. Identifier of the longest contaminating sequence
9. Kingdom of the longest contaminating sequence
10. Species name of the longest contaminating sequence
11. Length of the longest contaminating sequence
12. Count how often sequences from the contaminating kingdom align 

#### Automatically get the number of sequences — please check that the nt version is correct!

In [None]:
blastseqline = !/blast_base_path/blast+/c++/ReleaseMT/bin/blastdbcmd -db nt -info | grep "sequences"
blastdateline = !/blast_base_path/blast+/c++/ReleaseMT/bin/blastdbcmd -db nt -info | grep "Date:"

In [None]:
# The first whitespace-separated token of the "sequences" line is the sequence
# count, presumably printed with thousands separators (commas are stripped).
seq_count_token: str = blastseqline[0].split()[0]
NTNUMSEQ: int = int(seq_count_token.replace(',', ''))
print(f'Num. seq in nt DB is {NTNUMSEQ} from nt version {blastdateline[0]}')

In [None]:
# Show which Conterminator prediction file is going to be loaded
print('Cntrmntr output (LUSTRE): {}'.format(CONTOUT))

### Get data

#### From the single, latest Conterminator output

In [None]:
%%time
cont = pd.read_csv(CONTOUT, sep='\t', names=CONTCOLS)

In [None]:
%%time
cont['CedKdom'].replace(KINGDOMS, inplace=True)
cont['CingKdom'].replace(KINGDOMS, inplace=True)

In [None]:
# How many entries Conterminator reported, how many distinct contaminated sequences
# (CedId) they involve, and what fraction of the nt DB those sequences represent.
ced_id: Set[str] = set(cont['CedId'])
n_entries_total: int = len(cont)
n_entries_nr: int = len(ced_id)
print(f"Conterminator total entries: {n_entries_total}, non-redundant: {n_entries_nr}")
print(f"Conterminator redundancy rate: {(n_entries_total - n_entries_nr) / n_entries_total:.2%}")
print(f"NT database inter-kingdom contamination rate:  {n_entries_nr / NTNUMSEQ:.3%}")

In [None]:
# Number of Conterminator entries per kingdom of the contaminatED sequence.
# The theme is applied before the figure is created so it styles the whole figure.
sns.set_theme()
fig, ax = plt.subplots(figsize=(12,4))
cont['CedKdom'].value_counts().plot(kind='bar', ax=ax, title='Kingdom of contaminatED sequences')
ax.set_xlabel('Kingdom')
ax.set_ylabel('Conterminator entries');

In [None]:
# Number of Conterminator entries per kingdom of the contaminatING sequence.
# The theme is applied before the figure is created so it styles the whole figure.
sns.set_theme(style="whitegrid", palette="pastel")
fig, ax = plt.subplots(figsize=(12,4))
cont['CingKdom'].value_counts().plot(kind='bar', ax=ax, title='Kingdom of contaminatING sequences')
ax.set_xlabel('Kingdom')
ax.set_ylabel('Conterminator entries');

In [None]:
# Entries per (contaminating kingdom, contaminated kingdom) pair; pairs with no
# entries are absent from the resulting MultiIndex Series.
crosscont = cont.groupby(['CingKdom', 'CedKdom']).size()

In [None]:
mpl.rc_file_defaults()  # Restore mpl defaults after sns.set_theme() changes

fig, ax = plt.subplots(figsize=(14,8))

ARC_RAD: float = 0.08  # Curvature (arc3 `rad`) of edges whose reverse edge also exists
NODE_SIZE: int = 12000
ARROW_SIZE: int = 20
EDGE_NORM_FACTOR: float = 50.0  # Sum of all edge widths; each edge gets its count's share

# Build the inter-kingdom graph: one edge contaminating -> contaminated kingdom for each
# ordered pair of distinct kingdoms that has Conterminator entries. Pairs absent from
# `crosscont` (no entries) get no edge instead of raising KeyError. The count is stored
# as `count` (not `weight`) so that spring_layout still treats every edge as unit-weight.
G = nx.MultiDiGraph()
G.add_nodes_from(KINGDOMS.values())  # keep every kingdom, even one without edges
for k1 in KINGDOMS.values():
    for k2 in KINGDOMS.values():
        if k1 == k2:
            continue
        pair_count: int = int(crosscont.get((k1, k2), 0))
        if pair_count:
            G.add_edge(k1, k2, count=pair_count)

# Entry count of every edge, plus one color scale shared by all edges and the colorbar
edge_counts: Dict[Tuple[str, str], int] = {(u, v): d['count'] for u, v, d in G.edges(data=True)}
total_count: int = sum(edge_counts.values())
all_counts: List[int] = list(edge_counts.values()) or [0]
norm = mpl.colors.Normalize(vmin=min(all_counts), vmax=max(all_counts))
cmap = plt.cm.plasma

pos = nx.spring_layout(G, seed=1)

nx.draw_networkx_nodes(G,
                       pos,
                       ax=ax,
                       node_color=list(range(G.number_of_nodes())),
                       node_size=NODE_SIZE)

label_options = {"ec": "k", "fc": "white", "alpha": 0.7}
nx.draw_networkx_labels(G, pos, ax=ax, font_size=14, bbox=label_options)

# Edges with a reverse counterpart are drawn as arcs so the two directions don't overlap.
# Lists keep the graph's edge order, and colors/widths are computed per list so they
# line up positionally with the edges they are drawn for.
curved_edges = [e for e in edge_counts if (e[1], e[0]) in edge_counts]
straight_edges = [e for e in edge_counts if (e[1], e[0]) not in edge_counts]
for edgelist, extra_style in ((straight_edges, {}),
                              (curved_edges, {'connectionstyle': f'arc3, rad = {ARC_RAD}'})):
    if not edgelist:
        continue
    counts = [edge_counts[e] for e in edgelist]
    nx.draw_networkx_edges(G,
                           pos,
                           ax=ax,
                           edgelist=edgelist,
                           node_size=NODE_SIZE,
                           arrowstyle="-|>",
                           arrowsize=ARROW_SIZE,
                           edge_color=counts,
                           edge_cmap=cmap,
                           edge_vmin=norm.vmin,
                           edge_vmax=norm.vmax,
                           width=[c / total_count * EDGE_NORM_FACTOR for c in counts],
                           **extra_style)

# Colorbar for the shared edge color scale; `ax=` tells matplotlib where to put it
scalar_mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
clb = fig.colorbar(scalar_mappable, ax=ax)
clb.ax.set_ylabel('Conterminator entries')

# Title
title_font = {"fontname": "DejaVu Sans", "color": "k", "fontweight": "bold", "fontsize": 14}
ax.set_title("NCBI BLAST nt DB inter-kingdom contamination network", fontdict=title_font)

# Subtitles, stacked in axes coordinates at the top-left corner
subtitles = [
    (0.94,
     f"DB version: {RELEASE} ({NTNUMSEQ/1e+6:.2f} Mseqs)",
     {"color": "b", "fontweight": "bold", "fontsize": 12}),
    (0.91,
     f"{len(ced_id)} non-redundant contaminated seqs (rate: {len(ced_id)/NTNUMSEQ:.2%})",
     {"color": "purple", "fontweight": "light", "fontsize": 10}),
    (0.88,
     "edge width and color ~ number of Conterminator entries",
     {"color": "b", "fontweight": "light", "fontsize": 10}),
]
for y_pos, text, style in subtitles:
    ax.text(
        0.02, y_pos,
        text,
        horizontalalignment="left",
        transform=ax.transAxes,
        fontdict={"fontname": "DejaVu Sans", **style},
    )

# Extra margins keep the large nodes and their labels readable
ax.margins(0.1, 0.1)
fig.tight_layout()
ax.axis("off")
fig.savefig(f'{RELEASE}_contam_profile.pdf')
plt.show()

In [None]:
# Cross-contamination matrix: rows = contaminating kingdom (CingKdom),
# columns = contaminated kingdom (CedKdom), values = Conterminator entries (0 if none).
crosscont.unstack(fill_value=0)

In [None]:
%%time
MAXREADS: Optional[int] = None
cont_ids: Set[str] = set(cont['CedId'])
removed: int = 0
    
with open(NTFA, 'rt') as fa_cont, \
     open(NTFA_DECON, 'wt') as fa_desc:
    for i, (title, seq) in tqdm(enumerate(SimpleFastaParser(fa_cont)), total=NTNUMSEQ):
        if MAXREADS is not None and i >= MAXREADS:
            print('[stopping by maxreads limit!]')
            break
        cont_id: str = title.split()[0]
        try:
            cont_ids.remove(cont_id)
        except KeyError:
            fa_desc.write(f'>{title}\n')
            fa_desc.write(seq + '\n')
        else:
            removed += 1

print(f"\n{removed} contaminated sequences removed")
if cont_ids:
    print(f"Unable to remove {len(cont_ids)} seqs!")
else:
    print(f"Success! All the detected contaminants removed!")                        