In [31]:
from Bio import SeqIO
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
import time
import seaborn as sns

%matplotlib inline

In [32]:
def greedy_partition(df, threshold, debug=False):
    """Greedily group the labels of a pairwise-score matrix into clusters.

    Rows are visited in index order. A row whose label has not yet been
    assigned to any cluster becomes a cluster representative; every
    still-unassigned column whose value in that row is strictly greater than
    ``threshold`` becomes one of its children. The representative and its
    children are then excluded from all later rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Score matrix. The row labels are expected to be the same labels as
        the columns (e.g. a square all-vs-all comparison matrix).
    threshold : float
        A column joins the current row's cluster only when its value is
        strictly greater than this.
    debug : bool, default False
        When True, print the intermediate state of every iteration.

    Returns
    -------
    dict
        Maps each representative's row label to the list of its children's
        column labels, in column order. The list is empty for a
        representative that attracted no children.
    """
    # A set gives constant-time "already assigned?" checks; the rows are
    # still processed in index order, so the clustering result is unchanged.
    assigned = set()
    tally_dict = {}

    for row_id in df.index:

        if debug:
            print(row_id)
            print(assigned)

        # A label that is already a child of (or is itself) an earlier
        # representative does not start a cluster of its own.
        if row_id in assigned:
            continue

        # Full row of scores for this candidate representative.
        this_row_init = df.loc[row_id]

        # Keep only the columns that have not been assigned to a cluster yet.
        this_row = this_row_init[~this_row_init.index.isin(assigned)]

        if debug:
            print(this_row_init.shape)
            print(this_row.shape)

        # Children are the remaining columns scoring above the threshold,
        # not including the representative itself.
        above_threshold = this_row[this_row > threshold].index
        children_filtered = [label for label in above_threshold if label != row_id]

        if debug:
            print("len of children_filtered: " + str(len(children_filtered)) + '\n\n')

        # Record the cluster, then retire the representative and its children
        # so later rows can neither claim them nor start clusters from them.
        tally_dict[row_id] = children_filtered
        assigned.update(children_filtered)
        assigned.add(row_id)

    return tally_dict

In [36]:
# Taxon-specific data dumps are downloaded from GenBank using the corresponding query,
# written to a .fasta file, and placed where this notebook can read them.
#
# `this_dump` is the base name of the dump to process: the notebook reads
# "<this_dump>.fasta" and writes its results under a directory named "<this_dump>".
# Leave exactly one of the assignments below uncommented to choose the dataset.

#this_dump = "moraxellacatarhalis_29497_dump"
#this_dump = "candidaauris_498019_dump"
this_dump = "mpox_10244_dump"
#this_dump = "rhinovirusc_463676_dump"
#this_dump = "chkv_txid37124_dump"

In [37]:
# Run Sourmash on the input fasta containing all accessions for that species

K = 31
ST = 1000
    
start = time.time()

! mkdir {this_dump}
! rm cmp*;
! rm *sig;
! sourmash sketch dna -p k={K},scaled={ST} --singleton {this_dump}.fasta;
! sourmash compare *.sig --containment -o {this_dump}/cmp.dist --csv {this_dump}/cmp.csv;
! mv {this_dump}.fasta.sig {this_dump}

end = time.time()
elapsed = end - start


mkdir: mpox_10244_dump: File exists
zsh:1: no matches found: cmp*
zsh:1: no matches found: *sig
[K
== This is sourmash version 4.6.1. ==
[K== Please cite Brown and Irber (2016), doi:10.21105/joss.00027. ==

[Kcomputing signatures for files: mpox_10244_dump.fasta
[KComputing a total of 1 signature(s) for each input.
[Kcalculated 5452 signatures for 5452 sequences in mpox_10244_dump.fasta
[Ksaved 5452 signature(s) to 'mpox_10244_dump.fasta.sig'. Note: signature license is CC0.
[K
== This is sourmash version 4.6.1. ==
[K== Please cite Brown and Irber (2016), doi:10.21105/joss.00027. ==

[Kloaded 5452 signatures total.                                                  
[K
min similarity in matrix: 0.000
[Ksaving labels to: mpox_10244_dump/cmp.dist.labels.txt
[Ksaving comparison matrix to: mpox_10244_dump/cmp.dist


In [38]:
# Load the comparison matrix that sourmash wrote as CSV. The file's header row
# holds the column labels, so reusing those labels as the row index lets
# rows and columns be looked up by the same names.
result_df = pd.read_csv(f"{this_dump}/cmp.csv")
result_df.index = result_df.columns

# Greedy clustering: a column joins a row's cluster when its score exceeds 0.6.
result_dict = greedy_partition(result_df, 0.6)

In [39]:
# Print every cluster for manual inspection (via BLAST, etc.): the representative's
# label, then the list of labels assigned to it. Each cluster's member count is
# collected in len_of_key_pair along the way.
len_of_key_pair = []
for key, members in result_dict.items():
    print(key, members, '\n', sep='\n')
    len_of_key_pair.append(len(members))

OQ503835.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00439/2022, partial genome
['OQ503834.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00438/2022, partial genome', 'OQ503833.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00445/2022, partial genome', 'OQ503811.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00437/2022, partial genome', 'OQ503810.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00436/2022, partial genome', 'OQ503809.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00435/2022, partial genome', 'OQ503808.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00434/2022, partial genome', 'OQ503807.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00433/2022, partial genome', 'OQ503806.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00432/2022, partial genome', 'OQ503805.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00431/2022, partial genome', 'OQ503804.1 Monkeypox virus isolate MPXV/Human/USA/CA-LACPHL-MA00430/2022, partial

In [40]:
# Compression ratio = number of accessions before clustering / number of clusters kept.

# Sort the per-cluster member counts in place (ascending), so the largest clusters
# sit at the end of the list when it is inspected.
len_of_key_pair.sort()

# Report the database size before and after clustering, and the resulting ratio.
print(f"initial count of accession: {len(result_df.index)}")
print(f"final count after clustering: {len(result_dict)}")
print(f"compression ratio: {len(result_df.index) / len(result_dict)}")

initial count of accession: 5452
final count after clustering: 240
compression ratio: 22.716666666666665


### Take the clustered data and generate a new .fasta file of accessions

In [41]:
# Write only the cluster representatives (the keys of result_dict) to a new FASTA file.
records = list(SeqIO.parse(this_dump + ".fasta", "fasta"))
new_dump_fasta = this_dump + '.new.fasta'
with open(new_dump_fasta, 'w') as f:
    for r in records:
        # A record is kept when its full description line is a key of result_dict.
        seqid = r.description
        if seqid in result_dict:
            f.write('>' + seqid + '\n' + str(r.seq) + '\n')
# No explicit f.close() is needed: leaving the `with` block has already closed the file.