In [None]:
import csv
import json
import os
import shutil
import subprocess
import time
from collections import defaultdict, Counter
from itertools import groupby
from pathlib import Path

import numpy as np
import sklearn
from scipy import sparse
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, recall_score
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm


In [None]:
# Map each cluster representative to the list of its members. Every line of the
# cluster TSV holds exactly two tab-separated fields: representative, member.
rep_dict = defaultdict(list)
with open("data/clusterRes_cluster.tsv") as cluster_handle:
    for raw_line in cluster_handle:
        representative, member_id = raw_line.rstrip().split('\t')
        rep_dict[representative].append(member_id)


In [None]:
# Skip
# For each multi-member cluster, run `mash dist` of the representative against all
# members and record one value per member. Singleton clusters get a single 0 placeholder.
# NOTE(review): `mash dist` values are presumably distances (smaller = more similar),
# despite the name sim_dict — confirm the >= .75 / > 0.3 thresholds point the intended way.
sim_dict = defaultdict(list)
for rep, members in tqdm(rep_dict.items(), total=len(rep_dict)):
    if len(members) == 1:
        sim_dict[rep].append(0)
        continue
    # mash reads the query FASTA paths from this list file (-l).
    with open("temp_input.txt", 'w') as temp_query:
        temp_query.writelines(f"data/all_training_files/{mem}.fasta\n" for mem in members)
    # Argument list instead of shell=True: IDs read from the cluster file never reach a shell.
    command = ["mash", "dist", f"data/all_training_files/{rep}.fasta", "temp_input.txt",
               "-l", "-t", "-p", "20"]
    command_out = subprocess.check_output(command).decode("utf-8")
    # Drop empty lines and the first line (the -t table header); column 1 is the value.
    table_rows = [row.split('\t') for row in command_out.split('\n') if row != ""][1:]
    for row in table_rows:
        value = float(row[1])
        sim_dict[rep].append(value)
        if value >= .75:
            print(rep, row)

In [None]:
# Skip
# Across all clusters: how many recorded mash values exceed 0.3, and how many
# clusters hold exactly one value (the singleton placeholder).
count = sum(
    sum(1 for sim in cluster_sims if sim > 0.3)
    for cluster_sims in sim_dict.values()
)
singleton_count = sum(1 for cluster_sims in sim_dict.values() if len(cluster_sims) == 1)

In [None]:
# Skip
import numpy as np

# Summary of the clustering / mash results computed above.
multi_member_sizes = [len(cluster_sims) for cluster_sims in sim_dict.values() if len(cluster_sims) > 1]
print(len(rep_dict))
print(count)
print(singleton_count)
print(np.mean(multi_member_sizes))
print(len(sim_dict["J7OEM"]))
print(len(sim_dict["PUKJQ"]))

In [None]:
# Skip
# Build a sequence-ID -> label lookup. Each TSV line is a label followed by its member IDs.
sequence_to_label = {}
with open("training_labels.tsv") as labels_handle:
    for fields in map(lambda raw: raw.rstrip().split('\t'), labels_handle):
        label, *label_members = fields
        for seq_id in label_members:
            sequence_to_label[seq_id] = label

# Distinct labels inside cluster J7OEM, then the 10 largest clusters as (rep, size).
J7OEM_labels = {sequence_to_label[seq_id] for seq_id in rep_dict["J7OEM"]}
print(len(J7OEM_labels))
print(J7OEM_labels)
largest_clusters = sorted(rep_dict.items(), key=lambda item: len(item[1]), reverse=True)[:10]
print([(rep_id, len(cluster_members)) for rep_id, cluster_members in largest_clusters])

In [None]:
def get_X_y(plasmid_file, sequence_to_label, frags, labels):
    """Build a fragment-hit feature matrix and encoded labels from a plaster results TSV.

    Consecutive rows sharing the same "Query seq" form one sample (one matrix row).
    Each of those rows contributes its "Frag seq" as a hit in that sample's row.

    Parameters
    ----------
    plasmid_file : str or path-like
        Tab-separated file with at least "Query seq" and "Frag seq" columns.
    sequence_to_label : dict
        Maps each query sequence ID to its class label.
    frags : sequence of str
        Fragment vocabulary; defines the columns (label-encoded order).
        Every "Frag seq" in the file must be present, or ``LabelEncoder.transform`` raises.
    labels : iterable
        Label vocabulary used to integer-encode the targets.

    Returns
    -------
    (scipy.sparse.csr_matrix, numpy.ndarray)
        Matrix of shape (n_samples, len(frags)) and the encoded labels.
        An empty file yields a (0, len(frags)) matrix and an empty label array.
    """
    cols = len(frags)
    x_le = preprocessing.LabelEncoder()
    x_le.fit(list(frags))

    row_ind = []
    col_ind = []
    y_str = []
    with open(plasmid_file) as training_handle:
        reader = csv.DictReader(training_handle, delimiter='\t')
        # groupby merges only *consecutive* rows, i.e. a query's hits are assumed contiguous.
        for row_num, (seq_id, seq_rows) in enumerate(groupby(reader, key=lambda r: r["Query seq"])):
            frags_hit = [r["Frag seq"] for r in seq_rows]
            columns = x_le.transform(frags_hit)
            row_ind.extend([row_num] * len(columns))
            col_ind.extend(columns)
            y_str.append(sequence_to_label[seq_id])

    # csr_matrix sums duplicate (row, col) entries, so a fragment hit twice stores 2.
    data = [1] * len(row_ind)
    X_train_fragged = sparse.csr_matrix((data, (row_ind, col_ind)), shape=(len(y_str), cols))
    y_le = preprocessing.LabelEncoder()
    y_le.fit(labels)
    y_train_fragged = y_le.transform(y_str)
    return X_train_fragged, y_train_fragged

In [None]:
def top_n_accuracy(preds, truths, n, classes=None):
    """Fraction of samples whose true label is among the ``n`` highest-scoring columns.

    Parameters
    ----------
    preds : array-like, shape (n_samples, n_columns)
        Per-column scores, e.g. the output of ``predict_proba``.
    truths : array-like, shape (n_samples,)
        True labels.
    n : int
        Number of top-scoring columns that count as a hit; must be >= 1.
    classes : array-like, optional
        Label of each column of ``preds`` (e.g. ``clf.classes_``). When omitted, the
        column index itself is the label.

    Returns
    -------
    float
        Top-n accuracy in [0, 1], or ``nan`` when ``truths`` is empty.

    Raises
    ------
    ValueError
        If ``n < 1``; without this check, ``[:, -0:]`` would silently select every column.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    truths = np.asarray(truths)
    if len(truths) == 0:
        return float("nan")
    # argsort is ascending, so the last n columns of each row are its n best scores.
    best_n = np.argsort(np.asarray(preds), axis=1)[:, -n:]
    if classes is not None:
        best_n = np.asarray(classes)[best_n]
    hits = (best_n == truths[:, None]).any(axis=1)
    return float(hits.mean())

In [None]:
def run_plasmidhawk_on_cluster(seq_ids, output_dir="phawk_run", random_state=None, min_labels=9):
    """Train and evaluate a random forest on plaster fragment hits for one cluster.

    Steps: split the sequences, build a plaster template from the training split
    (``--realign``), align the test split to it (``--align-only --template``), then
    fit/evaluate a RandomForestClassifier on binary fragment-hit features.

    Parameters
    ----------
    seq_ids : list
        Candidate sequences; element ``[0]`` of each item is the sequence ID.
        Labels with only one sequence are dropped so the split can stratify.
    output_dir : str
        Directory for plaster inputs, outputs and logs (created if missing).
    random_state : int, optional
        Seed for the train/test split and the random forest. ``None`` keeps the
        previous non-deterministic behaviour.
    min_labels : int
        Minimum number of distinct labels, after filtering, required to run.

    Returns
    -------
    None
        Results are printed. Returns early if fewer than ``min_labels`` labels remain.
    """
    print("Running linear pangenome alignment on", output_dir)
    # Local lookup: previously this mutated a global that only existed if an earlier
    # "Skip" cell had been run.
    sequence_to_label = {}
    with open("training_labels.tsv") as labels_handle:
        for fields in (line.rstrip().split('\t') for line in labels_handle):
            for member in fields[1:]:
                sequence_to_label[member] = fields[0]

    # Stratified splitting needs at least two sequences per label.
    label_counts = Counter(sequence_to_label[mem[0]] for mem in seq_ids)
    seq_ids = [seq_id for seq_id in seq_ids if label_counts[sequence_to_label[seq_id[0]]] > 1]
    y_orig = [sequence_to_label[mem[0]] for mem in seq_ids]
    n_labels = len(set(y_orig))
    if n_labels < min_labels:
        print("Not enough labs\n")
        return
    print(n_labels, "labs for", len(y_orig), "sequences")
    X_train, X_test, _, _ = train_test_split(
        seq_ids, y_orig, test_size=0.2, stratify=y_orig, random_state=random_state)

    t0 = time.time()
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    training_list = os.path.join(output_dir, "training_sequences.txt")
    testing_list = os.path.join(output_dir, "testing_sequences.txt")
    with open(training_list, 'w') as training_seqs_file:
        training_seqs_file.writelines(f"data/all_training_files/{seq_id[0]}.fasta\n" for seq_id in X_train)
    with open(testing_list, 'w') as testing_seqs_file:
        testing_seqs_file.writelines(f"data/all_training_files/{seq_id[0]}.fasta\n" for seq_id in X_test)

    # Argument lists (not "...".split(' ')) so paths containing spaces stay intact.
    train_prefix = os.path.join(output_dir, "plaster_train_results")
    with open(os.path.join(output_dir, "plaster.out"), 'w') as plaster_out, \
            open(os.path.join(output_dir, "plaster.err"), 'w') as plaster_err:
        subprocess.check_call(
            ["plaster", training_list, "--realign", "--output", train_prefix,
             "--work-dir", os.path.join(output_dir, "plaster_train_work"), "-p", "40"],
            stdout=plaster_out, stderr=plaster_err)
    train_plasmid_file = train_prefix + ".tsv"

    # Fragment IDs look like "frag_<index>". Use every index up to the largest one seen so
    # the train and test matrices share one column space.
    with open(train_plasmid_file) as training_handle:
        reader = csv.DictReader(training_handle, delimiter='\t')
        max_frag = max(int(row["Frag seq"].split("_")[1]) for row in reader)
    frags_seen = [f"frag_{idx}" for idx in range(max_frag + 1)]

    test_prefix = os.path.join(output_dir, "plaster_test_results")
    with open(os.path.join(output_dir, "plaster_test.out"), 'w') as plaster_out, \
            open(os.path.join(output_dir, "plaster_test.err"), 'w') as plaster_err:
        subprocess.check_call(
            ["plaster", testing_list, "--output", test_prefix,
             "--work-dir", os.path.join(output_dir, "plaster_test_work"), "-p", "40",
             "--align-only", "--template", train_prefix + ".fasta"],
            stdout=plaster_out, stderr=plaster_err)
    test_plasmid_file = test_prefix + ".tsv"
    t1 = time.time()
    print("Linear method pipeline took", t1 - t0, "seconds.")

    label_vocab = list(set(y_orig))
    X_train_fragged, y_train_fragged = get_X_y(train_plasmid_file, sequence_to_label, frags_seen, label_vocab)
    clf = RandomForestClassifier(n_estimators=1000, n_jobs=80, min_samples_split=2, max_depth=20,
                                 random_state=random_state)
    clf.fit(X_train_fragged, y_train_fragged)
    print("Top 1 train accuracy", accuracy_score(y_train_fragged, clf.predict(X_train_fragged)))
    X_test_fragged, y_test_fragged = get_X_y(test_plasmid_file, sequence_to_label, frags_seen, label_vocab)
    print("Top 1 test accuracy", accuracy_score(y_test_fragged, clf.predict(X_test_fragged)))
    # predict_proba columns follow clf.classes_, which only contains labels present in the
    # training matrix. Scatter them into the full encoded label space so that
    # column index == encoded label, as top_n_accuracy assumes.
    proba = clf.predict_proba(X_test_fragged)
    full_proba = np.zeros((proba.shape[0], len(label_vocab)))
    full_proba[:, clf.classes_] = proba
    print("Top 5 test accuracy", top_n_accuracy(full_proba, y_test_fragged, 5))
    t2 = time.time()
    print("Linear method machine learning took", t2 - t1, "seconds.\n")



In [None]:
# Representatives of the 20 largest clusters, largest first.
top_n_clusters = [rep_id for rep_id, _members in sorted(rep_dict.items(), key=lambda kv: len(kv[1]), reverse=True)[:20]]

In [None]:
def get_graph_X_y(json_obj, sequence_to_label, nodes_seen, labels):
    """Build a graph-node-hit feature matrix and encoded labels from alignment records.

    Parameters
    ----------
    json_obj : list of dict
        Alignment records with a "name" key and, when aligned, "path" -> "mapping" ->
        [{"position": {"node_id": ...}}, ...]. Several records may share a name; their
        hits are merged into one row.
    sequence_to_label : dict
        Maps each record name (sequence ID) to its class label.
    nodes_seen : collection
        Node-ID vocabulary defining the columns. Hits on other nodes are ignored.
    labels : iterable
        Label vocabulary used to integer-encode the targets.

    Returns
    -------
    (scipy.sparse.csr_matrix, numpy.ndarray)
        Matrix of shape (n_unique_names, len(nodes_seen)) and the encoded labels.
    """
    x_le = preprocessing.LabelEncoder()
    x_le.fit(list(nodes_seen))

    # One entry per distinct record name, in first-seen order. Records lacking a
    # path/mapping (nothing aligned) still get a row, just with no hits.
    seq_to_nodes_hit = defaultdict(list)
    for obj in json_obj:
        mappings = obj.get("path", {}).get("mapping", [])
        seq_to_nodes_hit[obj["name"]].extend(
            mapping["position"]["node_id"]
            for mapping in mappings
            if "node_id" in mapping["position"] and mapping["position"]["node_id"] in nodes_seen
        )

    row_ind = []
    col_ind = []
    y_str = []
    for row_num, (seq_id, nodes_hit) in enumerate(seq_to_nodes_hit.items()):
        columns = x_le.transform(nodes_hit)
        row_ind.extend([row_num] * len(columns))
        col_ind.extend(columns)
        y_str.append(sequence_to_label[seq_id])

    # csr_matrix sums duplicate (row, col) entries, so a node hit twice stores 2.
    data = [1] * len(row_ind)
    X_graph = sparse.csr_matrix((data, (row_ind, col_ind)), shape=(len(y_str), len(nodes_seen)))
    y_le = preprocessing.LabelEncoder()
    y_le.fit(labels)
    y_graph = y_le.transform(y_str)
    return X_graph, y_graph

In [None]:

def run_bcalm_GA_on_cluster(seq_ids, output_dir="bcalm_run", random_state=None, min_labels=9):
    """Train and evaluate a random forest on graph-alignment node hits for one cluster.

    Steps: split the sequences, build a bcalm unitig graph from the training split,
    convert it to GFA, align both splits with GraphAligner, dump the alignments as JSON
    with ``vg view -a``, then fit/evaluate a RandomForestClassifier on the graph nodes
    each sequence hits.

    Parameters
    ----------
    seq_ids : list
        Candidate sequences; element ``[0]`` of each item is the sequence ID.
        Labels with only one sequence are dropped so the split can stratify.
    output_dir : str
        Directory for inputs, intermediate files and logs (created if missing).
    random_state : int, optional
        Seed for the train/test split and the random forest. ``None`` keeps the
        previous non-deterministic behaviour.
    min_labels : int
        Minimum number of distinct labels, after filtering, required to run.

    Returns
    -------
    None
        Results are printed. Returns early if fewer than ``min_labels`` labels remain.
    """
    print("Running graph alignment on", output_dir)
    # Local lookup: previously this mutated a global that only existed if an earlier
    # "Skip" cell had been run.
    sequence_to_label = {}
    with open("training_labels.tsv") as labels_handle:
        for fields in (line.rstrip().split('\t') for line in labels_handle):
            for member in fields[1:]:
                sequence_to_label[member] = fields[0]

    # Stratified splitting needs at least two sequences per label.
    label_counts = Counter(sequence_to_label[mem[0]] for mem in seq_ids)
    seq_ids = [seq_id for seq_id in seq_ids if label_counts[sequence_to_label[seq_id[0]]] > 1]
    y_orig = [sequence_to_label[mem[0]] for mem in seq_ids]
    n_labels = len(set(y_orig))
    if n_labels < min_labels:
        # Previously there was no return here, so the pipeline ran anyway.
        print("Not enough labs\n")
        return
    print(n_labels, "labs for", len(y_orig), "sequences")
    X_train, X_test, _, _ = train_test_split(
        seq_ids, y_orig, test_size=0.2, stratify=y_orig, random_state=random_state)

    # Per split: a list of member FASTA paths plus one concatenated FASTA file.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    for split_name, split_ids in (("training", X_train), ("testing", X_test)):
        fasta_paths = [f"data/all_training_files/{seq_id[0]}.fasta" for seq_id in split_ids]
        with open(os.path.join(output_dir, f"{split_name}_sequences.txt"), 'w') as list_file:
            list_file.writelines(f"{path}\n" for path in fasta_paths)
        with open(os.path.join(output_dir, f"{split_name}_sequences.fasta"), 'w') as fasta_file:
            for path in fasta_paths:
                with open(path) as member_fasta:  # closed promptly (was leaked before)
                    fasta_file.write(member_fasta.read())

    t0 = time.time()
    bcalm_dir = f"{output_dir}/bcalm_test"
    gfa_file = f"{bcalm_dir}/training.gfa"
    Path(bcalm_dir).mkdir(parents=True, exist_ok=True)
    with open(f"{bcalm_dir}/bcalm.out", 'w') as bcalm_out, open(f"{bcalm_dir}/bcalm.err", 'w') as bcalm_err:
        subprocess.check_call(
            ["bcalm", "-in", f"{output_dir}/training_sequences.fasta", "-abundance-min", "1",
             "-out", f"{bcalm_dir}/training", "-nb-cores", "20"],
            stdout=bcalm_out, stderr=bcalm_err)
    # NOTE(review): 31 presumably matches bcalm's k-mer size (no -kmer-size is passed) — confirm.
    subprocess.check_call(["convertToGFA.py", f"{bcalm_dir}/training.unitigs.fa", gfa_file, "31"])

    def align_split(split_name, log_mode):
        # Align one split to the training graph and dump the GAM as JSON lines.
        # Returns the path of the JSON file.
        gam_file = f"{bcalm_dir}/{split_name}.gam"
        with open(f"{bcalm_dir}/galigner.out", log_mode) as galigner_out, \
                open(f"{bcalm_dir}/galigner.err", log_mode) as galigner_err:
            subprocess.check_call(
                ["GraphAligner", "-g", gfa_file, "-f", f"{output_dir}/{split_name}_sequences.fasta",
                 "-a", gam_file, "-t", "20", "-x", "dbg"],
                stdout=galigner_out, stderr=galigner_err)
        json_path = f"{bcalm_dir}/{split_name}_alignment.json"
        with open(json_path, 'w') as json_out:
            subprocess.check_call(["vg", "view", "-a", gam_file], stdout=json_out)
        return json_path

    train_json_path = align_split("training", 'w')  # truncate the GraphAligner logs...
    test_json_path = align_split("testing", 'a')    # ...then append the test run to them
    t1 = time.time()
    print("Graph method pipeline took", t1 - t0, "seconds.")

    def load_json_lines(path):
        # One JSON document per line.
        with open(path) as input_json:
            return [json.loads(line) for line in input_json]

    train_json_obj = load_json_lines(train_json_path)
    test_json_obj = load_json_lines(test_json_path)

    # Feature space: every graph node hit by at least one training alignment.
    nodes_seen = {
        mapping["position"]["node_id"]
        for obj in train_json_obj
        for mapping in obj.get("path", {}).get("mapping", [])
        if "node_id" in mapping["position"]
    }
    label_vocab = list(set(y_orig))
    X_train_fragged, y_train_fragged = get_graph_X_y(train_json_obj, sequence_to_label, nodes_seen, label_vocab)
    clf = RandomForestClassifier(n_estimators=1000, n_jobs=80, min_samples_split=2, max_depth=20,
                                 random_state=random_state)
    clf.fit(X_train_fragged, y_train_fragged)
    print("Top 1 train accuracy", accuracy_score(y_train_fragged, clf.predict(X_train_fragged)))
    X_test_fragged, y_test_fragged = get_graph_X_y(test_json_obj, sequence_to_label, nodes_seen, label_vocab)
    print("Top 1 test accuracy", accuracy_score(y_test_fragged, clf.predict(X_test_fragged)))
    # predict_proba columns follow clf.classes_, which only contains labels present in the
    # training matrix. Scatter them into the full encoded label space so that
    # column index == encoded label, as top_n_accuracy assumes.
    proba = clf.predict_proba(X_test_fragged)
    full_proba = np.zeros((proba.shape[0], len(label_vocab)))
    full_proba[:, clf.classes_] = proba
    print("Top 5 test accuracy", top_n_accuracy(full_proba, y_test_fragged, 5))
    t2 = time.time()
    print("Graph method machine learning took", t2 - t1, "seconds.\n")

In [None]:
# Rank clusters by size and run both pipelines on ranks 2-15 (the largest is skipped).
# Each run writes into a directory named after the cluster representative.
top_n_clusters = [rep_id for rep_id, _members in sorted(rep_dict.items(), key=lambda kv: len(kv[1]), reverse=True)[:20]]
for cluster_rep in top_n_clusters[1:15]:
    cluster_seq_ids = [[seq_id] for seq_id in rep_dict[cluster_rep]]
    run_bcalm_GA_on_cluster(cluster_seq_ids, output_dir=cluster_rep)
    run_plasmidhawk_on_cluster(cluster_seq_ids, output_dir=cluster_rep)

    

In [None]:
# Minigraph alone does not create any useful graph ):
# Build one graph per cluster incrementally: start from the representative's FASTA and
# add every other member, feeding the previous GFA back in as the reference.
for cluster_rep in top_n_clusters[1:10]:
    Path(cluster_rep).mkdir(parents=True, exist_ok=True)
    # Same data/ prefix as every other FASTA path in this notebook (it was missing here).
    ref_file = f"data/all_training_files/{cluster_rep}.fasta"
    # One stderr log per cluster; it was previously truncated for every member.
    with open(f"{cluster_rep}/minigraph.err", 'w') as minigraph_err:
        for member in rep_dict[cluster_rep]:
            if member == cluster_rep:
                continue
            member_file = f"data/all_training_files/{member}.fasta"
            command = ["minigraph", "-x", "ggs", ref_file, member_file]
            print(" ".join(command), "> ")
            # stdout cannot be minigraph_out.gfa itself: from the second member on, that
            # file is the input reference, and opening it with 'w' would truncate it first.
            with open(f"{cluster_rep}/minigraph_out_tmp.gfa", 'w') as minigraph_out:
                subprocess.check_call(command, stdout=minigraph_out, stderr=minigraph_err)
            shutil.copyfile(f"{cluster_rep}/minigraph_out_tmp.gfa", f"{cluster_rep}/minigraph_out.gfa")
            ref_file = f"{cluster_rep}/minigraph_out.gfa"

In [None]:
# Try out NucDiff for SVs. It produces .gff files; not yet sure these are what we're looking for.
# Each non-representative member is compared against the cluster representative.
for cluster_rep in top_n_clusters[1:10]:
    # Same data/ prefix as every other FASTA path in this notebook (it was missing here).
    ref_file = f"data/all_training_files/{cluster_rep}.fasta"
    nd_out_dir = f"{cluster_rep}/nd_out"
    Path(nd_out_dir).mkdir(parents=True, exist_ok=True)
    # Dedicated NucDiff logs. Stdout previously went to minigraph_out_tmp.gfa, which
    # clobbered the minigraph results written by the cell above.
    with open(f"{cluster_rep}/nucdiff.out", 'w') as nucdiff_out, open(f"{cluster_rep}/nucdiff.err", 'w') as nucdiff_err:
        for member in rep_dict[cluster_rep]:
            if member == cluster_rep:
                continue
            member_file = f"data/all_training_files/{member}.fasta"
            command = ["nucdiff", "--vcf", "yes", ref_file, member_file, f"{nd_out_dir}/{member}", "nd_nuc"]
            subprocess.check_call(command, stdout=nucdiff_out, stderr=nucdiff_err)