## Description

## Data & modules

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import os
from path import Path
from ete3 import Tree
from Bio import SeqIO
from ete3 import NCBITaxa
ncbi = NCBITaxa()
from hgt_algorithms import *

In [None]:
# Directory that is scanned for raw tree files (names containing 'treefile').
tree_dir = Path('second_round_raw_trees/')
input_fasta_file = 'out.mcl_twoway_filter.mci.I17.second_stage.fa'  # name of the input FASTA file
# Directory that receives the processed trees (created, or emptied if it already exists, when saving).
output_tree_dir = Path('second_round_processed_trees')
# Directory that receives one FASTA per processed tree (created, or emptied if it already exists, when saving).
output_fasta_dir = Path('third_round_clusters')

In [None]:
# Read the target protein accessions, one per line, skipping blank lines.
# The context manager closes the file handle as soon as the set is built.
with open('basal_accessions.txt') as accession_file:
    target_accessions = {line.strip() for line in accession_file if line.strip()}

In [None]:
# Map the first whitespace-separated field of every non-blank line to the second one.
# Later cells look tree leaf names up in this table and pass the values to ncbi.get_lineage.
# Values are kept as strings, exactly as read from the file.
taxid_table = {}
with open('final_homolog_taxid_table') as table_file:
    for line in table_file:
        fields = line.split()  # split() without arguments also discards surrounding whitespace
        if not fields:
            continue  # blank line
        taxid_table[fields[0]] = fields[1]

In [None]:
# NCBI Taxonomy ID of Fungi; a taxon is treated as fungal when this ID occurs in its ncbi.get_lineage() result.
fungi_taxid = 4751

## Selecting a neighbourhood of fungal proteins

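For each tree, we'll retain only those leaves that lie within a given distance threshold (in substitutions per site, sps) of at least one target protein. Trees that contain no target protein are set aside, and trees left with fewer than 4 leaves after this step are not kept.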

In [None]:
# Maximum distance (same units as the tree branch lengths) from a leaf to its nearest
# target protein for the leaf to be kept.
sps_neihbourhood_threshold = 3.33


def _root_spanning_length(tree):
    """Return the value this notebook records as a tree's "diameter".

    It is the sum, over the children of the root, of the child's branch length
    plus the distance from that child to its farthest leaf.
    """
    return sum(child.get_farthest_leaf()[1] + child.dist for child in tree.children)


tree_files = [l for l in os.listdir(tree_dir) if 'treefile' in l]
raw_tree_sizes = []
raw_trees = []
rogue_pruned_tree_sizes = []
truncated_tree_sizes = []
truncated_trees = []
discarded_trees = []
raw_diameters = []
rogue_pruned_diameters = []
truncated_diameters = []
for fname in tree_files:
    T = Tree(tree_dir/fname, format=1)  # treating support as internal node names to handle missing support values
    T.set_outgroup(T.get_midpoint_outgroup())
    T.filename = fname
    raw_diameters.append(_root_spanning_length(T))
    raw_tree_sizes.append(len(T))
    raw_trees.append(T.copy())  # keep an untouched copy; T itself is pruned below
    ## Rogue pruning is disabled because it removed interesting HGT candidates.
    ## It needs `consensus_rogues`, which only exists after running the optional
    ## RogueNaRok parsing cell, so it must stay commented out for a clean top-to-bottom run.
    # non_rogue_names = set(l.name for l in T) - consensus_rogues
    # T.prune(non_rogue_names, preserve_branch_length=True)
    target_in_T = set(l for l in T if l.name in target_accessions)
    if not target_in_T:
        # No target protein in this tree: nothing to build a neighbourhood around.
        discarded_trees.append(T)
        continue
    rogue_pruned_diameters.append(_root_spanning_length(T))
    rogue_pruned_tree_sizes.append(len(T))
    # For every leaf, the distance to the closest target leaf (0 for target leaves themselves).
    min_distances_to_target = [min(T.get_distance(t, l) for t in target_in_T) for l in T]
    target_neighbourhood = [t for d,t in zip(min_distances_to_target, T) if d <= sps_neihbourhood_threshold]
    T.prune(target_neighbourhood, preserve_branch_length=True)
    if len(T) >= 4:
        truncated_tree_sizes.append(len(T))
        truncated_trees.append(T)
        truncated_diameters.append(_root_spanning_length(T))

print('Processed %i trees, %i sequences, from' % (len(raw_tree_sizes), sum(raw_tree_sizes)), tree_dir)
print('Retained %i trees (>=4 leaves) with target proteins after rogue removal' % len(rogue_pruned_tree_sizes))
if truncated_tree_sizes:
    print('Retained %i trees, %i sequences, after truncating neighbourhood, smallest tree has %i leaves' % (len(truncated_tree_sizes), sum(truncated_tree_sizes), min(truncated_tree_sizes)))
else:
    # min() of an empty list would raise; report the empty result explicitly instead.
    print('Retained 0 trees after truncating neighbourhood')

Cells for manual inspection of the results:

In [None]:
# discarded_filenames = set(tree_files) - set(T.filename for T in truncated_trees)
# list(discarded_filenames)[:10]

In [None]:
# T = raw_trees[tree_files.index('663.fa.aln.treefile')]  
# target_in_T = [l.name for l in T if l.name in target_accessions]
# target_in_T

In [None]:
# min_distances_to_target = [min(T.get_distance(t, l) for t in target_in_T) for l in T]
# plt.figure()
# plt.hist(min_distances_to_target, bins=40)
# plt.show()

In [None]:
# retained_filenames = [T.filename for T in truncated_trees]
# '453.fa.aln.treefile' in retained_filenames

In [None]:
# T = truncated_trees[truncated_tree_sizes.index(min(truncated_tree_sizes))]
# print(T.filename)
# set(l for l in T if l.name in target_accessions)

In [None]:
# T = discarded_trees[0]
# print(T.filename)
# T = Tree(tree_dir/T.filename, format=1)
# set(l for l in T if l.name in target_accessions)

In [None]:
# T = truncated_trees[retained_filenames.index('453.fa.aln.treefile')]
# print(T.filename)
# set(l for l in T if l.name in target_accessions)
# print(T)

In [None]:
# Three side-by-side histograms of tree sizes (leaf counts) at each stage of the
# neighbourhood selection. Only one figure is created; a preceding bare plt.figure()
# would leave an extra, empty figure in the output.
plt.figure(figsize=(8,4))
size_panels = [
    ('Leaf count before cutting', raw_tree_sizes),
    ('Leaf count after rogue removal', rogue_pruned_tree_sizes),
    ('Leaf count after nhbd selection', truncated_tree_sizes),
]
for panel_index, (panel_title, panel_sizes) in enumerate(size_panels, start=1):
    plt.subplot(1, 3, panel_index)
    plt.title(panel_title)
    plt.hist(panel_sizes, bins=40)
plt.tight_layout()
plt.show()
print('Largest tree size before cutting:', max(raw_tree_sizes))
print('Largest tree size after cutting:', max(truncated_tree_sizes))

In [None]:
# Histograms of the recorded tree diameters at each stage of the neighbourhood selection.
plt.figure(figsize=(8,4))
diameter_panels = (
    ('Diameter before cutting', raw_diameters),
    ('Diameter after rogue pruning', rogue_pruned_diameters),
    ('Diameter after neighbourhood selection', truncated_diameters),
)
for panel_index, (panel_title, panel_values) in enumerate(diameter_panels, start=1):
    plt.subplot(1, 3, panel_index)
    plt.title(panel_title)
    plt.hist(panel_values, bins=40)
plt.tight_layout()
plt.show()
# raw_diameters is filled in the same order as tree_files, so the index of the maximum
# identifies the file of the widest raw tree.
largest_raw_diameter = max(raw_diameters)
widest_raw_file = tree_files[raw_diameters.index(largest_raw_diameter)]
print('Largest diameter of raw tree:', largest_raw_diameter, 'cluster', widest_raw_file)
print('Largest diameter after rogue pruning:', max(rogue_pruned_diameters))
print('Largest diameter after nbhd selection:', max(truncated_diameters))
print('Smallest diameter of a raw tree:', min(raw_diameters))

## Long branch cutting

For each tree, we'll iteratively remove the longest branch (thus cutting the tree into two trees) until all branches are shorter than a given threshold (in sps).  

In [None]:
# Branch-length threshold handed to UnrootedForest.disintegrate.
sps_threshold = 1.598

cut_trees = []
cut_tree_sizes = []
selected_cut_trees = []
cut_diameters = []
selected_diameters = []
selected_tree_sizes = []
for T in truncated_trees:
    # Give every internal node a unique name (its postorder index); needed for processing.
    for i, n in enumerate(T.traverse('postorder')):
        if n.is_leaf():
            continue
        n.name = str(i)
    UT = UnrootedForest(T)
    UT.disintegrate(sps_threshold)
    # Keep only the resulting pieces with at least 4 leaves and remember which file they came from.
    F = UT.get_ete3()
    F = [FT for FT in F if not len(FT) < 4]
    for FT in F:
        FT.filename = T.filename
    cut_trees.extend(F)
    for FT in F:
        cut_diameters.append(sum(c.get_farthest_leaf()[1] + c.dist for c in FT.children))
    for FT in F:
        cut_tree_sizes.append(len(FT))
    for FT in F:
        remaining_accessions = {l.name for l in FT}
        remaining_taxa = {taxid_table[acc] for acc in remaining_accessions}
        lineages = {tx: ncbi.get_lineage(tx) for tx in remaining_taxa}
        fungal_taxids = {tx for tx in remaining_taxa if fungi_taxid in lineages[tx]}
        # Select pieces that contain at least one fungal taxon and at least three non-fungal taxa.
        if not fungal_taxids:
            continue
        if len(remaining_taxa) < len(fungal_taxids) + 3:
            continue
        selected_cut_trees.append(FT)
        selected_diameters.append(sum(c.get_farthest_leaf()[1] + c.dist for c in FT.children))
        selected_tree_sizes.append(len(FT))

print('Processed %i trees' % len(truncated_trees))
print('Obtained %i trees (>=4 leaves) after cutting long branches' % len(cut_trees))
print('Obtained %i trees (>=4 leaves) with target proteins' % len(selected_cut_trees))

Inspect the results:

In [None]:
# Histograms of tree sizes (leaf counts): raw trees, trees after long-branch cutting,
# and the cut trees that passed the selection.
plt.figure(figsize=(8,4))
cut_size_panels = (
    ('Leaf count before cutting', raw_tree_sizes),
    ('Leaf count after cutting', cut_tree_sizes),
    ('Leaf count after selection', selected_tree_sizes),
)
for panel_index, (panel_title, panel_sizes) in enumerate(cut_size_panels, start=1):
    plt.subplot(1, 3, panel_index)
    plt.title(panel_title)
    plt.hist(panel_sizes, bins=40)
plt.tight_layout()
plt.show()
print('Largest tree size before cutting:', max(raw_tree_sizes))
print('Largest tree size after cutting:', max(cut_tree_sizes))
print('Largest cut tree with target proteins:', max(selected_tree_sizes))

In [None]:
# Histograms of tree diameters: neighbourhood-truncated trees, trees after long-branch
# cutting, and the cut trees that passed the selection.
plt.figure(figsize=(8,4))
cut_diameter_panels = (
    ('Diameter before cutting', truncated_diameters),
    ('Diameter after cutting', cut_diameters),
    ('Diameter after selecting targets', selected_diameters),
)
for panel_index, (panel_title, panel_values) in enumerate(cut_diameter_panels, start=1):
    plt.subplot(1, 3, panel_index)
    plt.title(panel_title)
    plt.hist(panel_values, bins=40)
plt.tight_layout()
plt.show()
print('Largest diameter before cutting:', max(truncated_diameters))
print('Largest diameter after cutting:', max(cut_diameters))
print('Largest diameter with target:', max(selected_diameters))
print('Smallest diameter with target:', min(selected_diameters))

Check the minimal distance to any target protein for all leaves in an example tree:

In [None]:
# Uses whichever tree the name T is bound to when this cell runs (left over from an earlier loop).
target_in_T = {leaf for leaf in T if leaf.name in target_accessions}
# For each leaf of T, the distance to its nearest target leaf.
min_distances_to_target = []
for leaf in T:
    nearest_target_distance = min(T.get_distance(target, leaf) for target in target_in_T)
    min_distances_to_target.append(nearest_target_distance)

In [None]:
# Distribution of each leaf's distance to its nearest target protein in the example tree.
fig, ax = plt.subplots()
ax.hist(min_distances_to_target, bins=40)
plt.show()

In [None]:
# Number of leaves whose nearest target protein is closer than 2.48.
# NOTE(review): 2.48 differs from the neighbourhood threshold set earlier (3.33) — confirm it is intentional.
sum(d < 2.48 for d in min_distances_to_target)

## Saving processed trees & cluster FASTAs

Create the output directory for the processed trees; if it already exists, erase its contents. Then save the trees:

In [None]:
# Start from an empty output directory: create it, or delete the files left by a previous run.
try:
    os.mkdir(output_tree_dir)
except FileExistsError:
    for stale_name in os.listdir(output_tree_dir):
        os.remove(output_tree_dir / stale_name)

# Write each selected tree to '<index>.treefile'.
# `/` and `%` share the same precedence and associate left to right, so the file name
# is formatted inside parentheses before it is joined to the directory; otherwise the
# `%` operator would be applied to the whole joined path.
for i, T in enumerate(selected_cut_trees):
    T.write(outfile=output_tree_dir / ('%i.treefile' % i), format=5)

Saving tree FASTAs for re-alignment:

In [None]:
# Index every record of the input FASTA by its id so that tree leaf names can be looked up.
# The file name comes from `input_fasta_file`, set in the path configuration cell at the top.
# Records are streamed straight into the dict; if an id occurs more than once, the last record wins.
sequences = {record.id: record for record in SeqIO.parse(input_fasta_file, 'fasta')}

Create the output directory for the FASTA files of the sequences in the processed trees; if it already exists, erase its contents. Then save the FASTA files:

In [None]:
# Start from an empty output directory: create it, or delete the files left by a previous run.
try:
    os.mkdir(output_fasta_dir)
except FileExistsError:
    for stale_name in os.listdir(output_fasta_dir):
        os.remove(output_fasta_dir / stale_name)

# Write the sequences of each selected tree to 'Cluster_<index>.fa'; the index matches
# the enumeration order of selected_cut_trees.
# The file name is formatted inside parentheses before being joined to the directory,
# because `/` and `%` share the same precedence and would otherwise be applied left to right.
for i, T in enumerate(selected_cut_trees):
    cluster = [sequences[l.name] for l in T]
    with open(output_fasta_dir / ('Cluster_%i.fa' % i), 'w') as h:
        SeqIO.write(cluster, h, 'fasta')

## Alternative version - long leaf removal without internal branch cutting

In [None]:
# sps_threshold = 1.328
# long_leaves = []
# trees = os.listdir(tree_dir)
# all_accessions_in_trees = []
# for fname in trees:
#     T = Tree(tree_dir/fname, format=1)  # treating support as internal node names to handle missing support values
#     T.set_outgroup(T.get_midpoint_outgroup())
#     all_accessions_in_trees.extend([l.name for l in T])
#     for i,n in enumerate(T.traverse('postorder')):
#         if not n.is_leaf():
#             n.name = str(i)
#     UT = UnrootedForest(T)
#     UT.disintegrate(sps_threshold)
#     F = UT.get_ete3()
#     long_leaves.extend([tf.name for tf in F if len(tf) == 1])
# assert len(long_leaves) == len(set(long_leaves))
# assert len(all_accessions_in_trees) == len(set(all_accessions_in_trees))
# long_leaves = set(long_leaves)
# all_accessions_in_trees = set(all_accessions_in_trees)
# print('Found', len(long_leaves), 'long leaves, including', len(target_accessions & long_leaves), 'target ones for sps threshold', sps_threshold)

Found 3914 long leaves, including 776 target ones for sps threshold 1.1158, full disintegration    
Found 2623 long leaves, including 578 target ones for sps threshold 1.328, full disintegration

In [None]:
# remaining_accessions = all_accessions_in_trees - (ml_rogues | long_leaves)

In [None]:
# print('Long leaf and ML rogues:', len(ml_rogues & long_leaves))
# print('Removed accessions:', len(ml_rogues | long_leaves), 'out of', len(all_accessions_in_trees))
# print('Removed target accessions:', len((ml_rogues|long_leaves)&target_accessions), 'out of', len(all_accessions_in_trees & target_accessions))
# print('Remaining accessions:', len(remaining_accessions))

Saving remaining sequences for reclustering:

In [None]:
# raw_cluster_dir = 'first_round_clusters/' + CLUSTER_DIR
# raw_cluster_files = os.listdir(raw_cluster_dir)
# all_sequences = SeqIO.parse(main_sequence_file, 'fasta')
# all_remaining_sequences = [] 
# for s in all_sequences:
#     if s.id in remaining_accessions:
#         all_remaining_sequences.append(s)

In [None]:
# print('Retrieved', len(all_remaining_sequences), 'sequences')

In [None]:
# new_cluster_dir = 'second_round_clusters/'
# new_cluster_joint_fasta_file = new_cluster_dir + CLUSTER_DIR + '.fa'
# with open(new_cluster_joint_fasta_file, 'w') as h:
#     SeqIO.write(all_remaining_sequences, h, 'fasta')

## Optional: rogue taxon data parsing using RogueNaRok's results

In [None]:
# roguenarok_files = [f for f in os.listdir(tree_dir) if '_droppedRogues' in f]
# ml_rogues = []
# consensus_rogues = []
# empty_result_files = []
# for fname in roguenarok_files: 
#     if 'droppedRogues' in fname:
#         rogues = open(tree_dir / fname).readlines()
#         if not rogues:
#             empty_result_files.append(fname)
#             continue
#         rogue_taxa = set()
#         for l in rogues:
#             l = l.strip().split('\t')
#             taxa =  l[2]
#             if taxa not in {'NA', 'taxon'}:
#                 taxa = taxa.split(',')
#                 rogue_taxa.update(taxa)
#         if 'MLtree' in fname:
#             assert not rogue_taxa & set(ml_rogues)
#             ml_rogues.extend(rogue_taxa)
#         if 'ufboot_consensus' in fname:
#             assert not rogue_taxa & set(consensus_rogues)
#             consensus_rogues.extend(rogue_taxa)
# ml_rogues = set(ml_rogues)
# consensus_rogues = set(consensus_rogues)
# print('Found', len(ml_rogues), 'ML rogues, including', len(ml_rogues & target_accessions), 'target, and', len(consensus_rogues), 'consensus rogues, including', len(consensus_rogues & target_accessions), 'targets')


## SPS value testing

In [None]:
# Evaluate sps_to_identity at 2.11. The function is not defined in this notebook —
# presumably it comes from the star import of hgt_algorithms; confirm.
sps_to_identity(2.11)

In [None]:
# Evaluate identity_to_sps at 0.3 (presumably the sps distance matching 30% identity — confirm).
identity_to_sps(0.3)

In [None]:
# Evaluate identity_to_sps at 0.2 (presumably the sps distance matching 20% identity — confirm).
identity_to_sps(0.2)

In [None]:
# Evaluate idv_std for an sps value of 3.
# NOTE(review): the second argument (300) is presumably an alignment length — confirm against hgt_algorithms.
idv_std(3, 300)

In [None]:
# Identity computed for sps = 1.1158, minus the idv_std value for the same sps
# (second argument 300, as in the cell above).
sps = 1.1158
sps_to_identity(sps) - idv_std(sps, 300)

In [None]:
# Round trip sps -> identity -> sps; should give back 0.8 if the two functions are inverses.
identity_to_sps(sps_to_identity(0.8))

In [None]:
# Round trip identity -> sps -> identity; should give back 0.5 if the two functions are inverses.
sps_to_identity(identity_to_sps(0.5))

In [None]:
# Plot sps_to_identity over sps values from 0 to 5, with curves one idv_std above and below.
sps_values = np.linspace(0, 5)
id_values = sps_to_identity(sps_values)
# NOTE(review): idv_std is called with a single argument here but with two in the cells above —
# presumably the second argument has a default; confirm.
id_stds = idv_std(sps_values)
upper_curve = id_values + id_stds
lower_curve = id_values - id_stds
fig, ax = plt.subplots()
ax.plot(sps_values, id_values)
ax.plot(sps_values, upper_curve)
ax.plot(sps_values, lower_curve)