In [1]:
import argparse
import json
import requests
from augur.utils import json_to_tree
from Bio import SeqIO
from Bio.Seq import MutableSeq
from Bio.SeqRecord import SeqRecord



In [4]:
# Extract fasta for each node

def apply_muts_to_root(root_seq, list_of_muts):
    """
    Return a mutable copy of ``root_seq`` with every mutation in
    ``list_of_muts`` applied, in list order.

    Each mutation string is read as ``<old><1-based position><new>``
    (for example ``"A123G"``): the characters between the first and the
    last give the position, and the last character is the new state.
    Because mutations are applied in order, a later mutation at the same
    site overwrites an earlier one, so the list should run from the root
    towards the node.
    """
    node_seq = MutableSeq(root_seq)

    for mutation_code in list_of_muts:
        position, new_state = int(mutation_code[1:-1]), mutation_code[-1]
        # biological numbering starts at 1, Python indexing at 0
        node_seq[position - 1] = new_state

    return node_seq


def _load_json_source(location, local):
    """
    Load a JSON document from a local file path (``local`` true) or fetch it
    by HTTP GET from a URL (``local`` false).

    Raises ``requests.HTTPError`` if the server answers with an error status.
    """
    if local:
        with open(location, 'r') as f:
            return json.load(f)
    response = requests.get(
        location, headers={"accept": "application/json"}, timeout=60
    )
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
    response.raise_for_status()
    return response.json()


def _branch_mutations(clade, gene):
    """
    Return the ``gene`` mutations on the branch leading to ``clade``.

    The value is read from ``clade.branch_attrs["mutations"][gene]``. A missing
    ``branch_attrs`` attribute, ``"mutations"`` entry or gene key is treated
    as "no mutations" rather than an error.
    """
    branch_attrs = getattr(clade, "branch_attrs", None) or {}
    return list(branch_attrs.get("mutations", {}).get(gene, []))


def getNodeSequences(gene, local_files, tree_file, root_file):
    """
    Reconstruct the sequence at every node (internal and terminal) of a tree
    and save them as a FASTA file.

    Each node's sequence is the root sequence for ``gene`` with all branch
    mutations on the path from the root (exclusive) to the node (inclusive)
    applied in root-to-node order.

    Parameters
    ----------
    gene : str
        Key of the sequence in the root-sequence JSON and of the per-branch
        mutation lists (e.g. "nuc" or a gene name such as "PB1").
    local_files : str or bool
        "True"/True to read ``tree_file`` and ``root_file`` from disk, or
        "False"/False to fetch them as URLs. Strings are compared
        case-insensitively.
    tree_file, root_file : str
        Path or URL of the tree JSON and of the root-sequence JSON.

    Writes ``nodeSeqs_{gene}.fasta`` to the current working directory, one
    record per clade in preorder, using the clade name as the record id.

    Raises
    ------
    ValueError
        If ``local_files`` is neither a bool nor "True"/"False".
    """
    if isinstance(local_files, bool):
        use_local = local_files
    else:
        flag = str(local_files).strip().lower()
        if flag not in ("true", "false"):
            raise ValueError(
                f"local_files must be 'True' or 'False', got {local_files!r}"
            )
        use_local = flag == "true"

    # Put the tree in Bio.Phylo form, then get the root sequence for this gene.
    tree = json_to_tree(_load_json_source(tree_file, use_local))
    root_seq = _load_json_source(root_file, use_local)[gene]

    # find_clades() defaults to preorder, so a parent is always visited before
    # its children. Each child's root-to-node mutation list is therefore the
    # parent's list plus the child's own branch mutations. This replaces one
    # tree.get_path() search per node. The root gets an empty list, as
    # tree.get_path(root) does. Keys are id()s; the clades stay alive in the
    # tree, so the ids are stable.
    path_muts = {id(tree.root): []}
    sequence_records = []
    for node in tree.find_clades():
        muts = path_muts.pop(id(node))
        for child in node.clades:
            path_muts[id(child)] = muts + _branch_mutations(child, gene)
        # Later mutations at the same site overwrite earlier ones.
        node_seq = apply_muts_to_root(root_seq, muts)
        sequence_records.append(SeqRecord(node_seq, node.name, '', ''))

    SeqIO.write(sequence_records, f"nodeSeqs_{gene}.fasta", "fasta")




if __name__ == '__main__':

    # Command-line interface, currently disabled in favour of the hard-coded
    # settings below. It was presumably disabled because parse_args() would
    # reject the extra arguments Jupyter passes to the kernel; confirm before
    # re-enabling. Kept for reference:
    #
    #     parser = argparse.ArgumentParser(
    #         description=__doc__,
    #         formatter_class=argparse.ArgumentDefaultsHelpFormatter
    #     )
    #     parser.add_argument("--gene", default="nuc",
    #         help="Name of gene to return AA sequences for. 'nuc' will return full genome nucleotide seq")
    #     parser.add_argument("--local-files", default="False",
    #         help="Toggle this on if you are supplying local JSON files for the tree and root sequence." +
    #              "Default is to fetch them from a URL")
    #     parser.add_argument("--tree", default="https://data.nextstrain.org/ncov_gisaid_global_all-time.json",
    #         help="URL for the tree.json file, or path to the local JSON file if --local-files=True")
    #     parser.add_argument("--root", default="https://data.nextstrain.org/ncov_gisaid_global_all-time_root-sequence.json",
    #         help="URL for the root-sequence.json file, or path to the local JSON file if --local-files=True")
    #
    #     args = parser.parse_args()
    #
    #     getNodeSequences(args.gene, args.local_files, args.tree, args.root)

    # Hard-coded run settings. Alternatives used before:
    #   gene: "nuc"
    #   root: "/Users/Carlos/Downloads/h3n2_60y_pb2_root-sequence_Amino.json"
    #   root: "/Users/Carlos/Downloads/h3n2_60y_pb2_root-sequence.json"
    args_gene = "PB1"
    args_local_files = "True"  # string flag, compared against "True"/"False"
    args_tree = "/Users/Carlos/Desktop/Bedford/esm-selection/Flu_Snakemake_Pipeline/h3n2_Sequences/h3n2_60y_pb1.json"
    args_root = "/Users/Carlos/Desktop/Bedford/esm-selection/Flu_Snakemake_Pipeline/root_tree_translated/h3n2_60y_pb1_root-sequence.json"

    getNodeSequences(
        gene=args_gene,
        local_files=args_local_files,
        tree_file=args_tree,
        root_file=args_root,
    )

In [13]:
# Get terminals nodes for each internal node in tree

def collect_terminal_nodes(node):
    """
    Return the names of all terminal nodes (nodes without children) in the
    subtree rooted at ``node``, in left-to-right order.

    ``node`` is a dict with an optional ``"name"`` key and an optional
    ``"children"`` list of nodes of the same shape. A node without a name
    is reported as ``"(unnamed)"``. If ``node`` itself has no children, the
    result is just its own name.

    Uses an explicit stack instead of recursion, so very deep trees cannot
    exceed Python's recursion limit.
    """
    terminal_nodes = []
    stack = [node]
    while stack:
        current = stack.pop()
        children = current.get("children", [])
        if children:
            # Push in reverse so the left-most child is processed first.
            stack.extend(reversed(children))
        else:
            terminal_nodes.append(current.get("name", "(unnamed)"))
    return terminal_nodes

def map_terminal_nodes(node, terminal_map=None):
    """
    Record, for ``node`` and every node below it, the list of terminal-node
    names in that node's subtree.

    Entries are stored as ``terminal_map[name] = [terminal names, ...]``
    (left-to-right order) and inserted in preorder, parent before children.
    A node without a ``"name"`` is stored under ``"(unnamed)"``. A later node
    with a duplicate name overwrites the earlier entry's value.

    Parameters
    ----------
    node : dict
        Tree node with an optional ``"name"`` and an optional ``"children"``
        list of nodes of the same shape.
    terminal_map : dict, optional
        Mapping to fill. Defaults to the module-level ``node_terminal_map``,
        which keeps the original behaviour; that global must exist when the
        function is called.

    Returns
    -------
    dict
        The mapping that was filled.

    Uses explicit stacks instead of recursion, so very deep trees cannot
    exceed Python's recursion limit.
    """
    if terminal_map is None:
        terminal_map = node_terminal_map

    # Pass 1: list every node in preorder (a parent before its descendants).
    preorder = []
    stack = [node]
    while stack:
        current = stack.pop()
        preorder.append(current)
        stack.extend(reversed(current.get("children") or []))

    # Pass 2: walk that list backwards, so every child is finished before its
    # parent. Build each node's terminal list from its children's lists.
    # Keys are id()s because the node dicts are unhashable.
    terminals_by_id = {}
    for current in reversed(preorder):
        children = current.get("children") or []
        if children:
            terminals = [
                name for child in children for name in terminals_by_id[id(child)]
            ]
        else:
            terminals = [current.get("name", "(unnamed)")]
        terminals_by_id[id(current)] = terminals

    # Pass 3: publish in preorder, so insertion order and duplicate-name
    # resolution match a top-down traversal.
    for current in preorder:
        terminal_map[current.get("name", "(unnamed)")] = terminals_by_id[id(current)]
    return terminal_map

# Read the tree JSON from disk.
with open("/Users/Carlos/Downloads/h3n2_60y_pb2.json", "r") as f:
    data = json.load(f)

# Use the top-level "tree" entry as the root node when present; otherwise
# treat the whole document as the root node.
tree_root = data.get("tree", data)

# node name -> names of the terminal nodes in its subtree
# (filled in by map_terminal_nodes below).
node_terminal_map = {}

map_terminal_nodes(node=tree_root)

In [14]:
# Load the tip-frequency JSON and remove its top-level metadata entries.
# Downstream code looks strain names up directly in this dict.
# A `with` block closes the file handle, which the original never closed.
with open("/Users/Carlos/Downloads/h3n2_60y_pb2_tip-frequencies.json", "r") as json_fh_frequency:
    json_dict_frequency = json.load(json_fh_frequency)

# "pivots" and "generated_by" are not per-tip entries. pop() with a default
# tolerates files that lack either key, where `del` raised KeyError.
for metadata_key in ("pivots", "generated_by"):
    json_dict_frequency.pop(metadata_key, None)

In [15]:
# Get frequency sums for all nodes

from collections import defaultdict

# node name -> element-wise sum of the "frequencies" series of its terminal
# nodes. Only terminals present in json_dict_frequency contribute; nodes
# with no such terminal are left out.
summed_frequencies = {}

for node, terminals in node_terminal_map.items():
    tip_series = [
        json_dict_frequency[terminal]["frequencies"]
        for terminal in terminals
        if terminal in json_dict_frequency
    ]
    if not tip_series:
        continue
    running_total = tip_series[0].copy()
    for series in tip_series[1:]:
        # zip() stops at the shorter series, exactly as before
        running_total = [a + b for a, b in zip(running_total, series)]
    summed_frequencies[node] = running_total

In [17]:
# Keep only entries whose name starts with "NODE_". This assumes internal
# nodes are the ones named NODE_*, so tips are dropped.
filtered_dict = dict(
    item for item in summed_frequencies.items() if item[0].startswith("NODE_")
)

In [18]:
# Peak value of each internal node's summed frequency series.
max_values = dict(zip(filtered_dict.keys(), map(max, filtered_dict.values())))

In [39]:
# Multi-FASTA of node sequences (presumably written by getNodeSequences,
# which names its output nodeSeqs_<gene>.fasta).
fasta_path = "/Users/Carlos/Desktop/Bedford/esm-selection/nodeSeqs_PB2.fasta"

# node name -> {"max_frequency": ..., "sequence": ...}
node_data = {}

# record id -> SeqRecord
fasta_dict = SeqIO.to_dict(SeqIO.parse(fasta_path, "fasta"))

for node, score in max_values.items():
    record = fasta_dict.get(node)
    # Nodes missing from the FASTA get a None sequence.
    node_data[node] = {
        "max_frequency": score,
        "sequence": None if record is None else str(record.seq),
    }

# Example usage
#for node, data in node_data.items():
#    print(f"{node} - max_frequency: {data['max_frequency']}, Sequence: {data['sequence'][:30]}...")  # preview first 30 bases

#node_data

In [38]:
import pandas as pd

# One row per node: the dict keys become a 'node' column next to
# max_frequency and sequence.
df = (
    pd.DataFrame.from_dict(node_data, orient='index')
    .rename_axis('node')
    .reset_index()
)

# NOTE: to_csv() returns None when given a path, so df_to_csv is always None.
df_to_csv = df.to_csv("/Users/Carlos/Desktop/Bedford/esm-selection/Max_Freq_Fasta.csv", index=False)

In [45]:
# Per-node table, plus the per-sequence table with its node and
# max_frequency columns dropped, so that only 'sequence' and its remaining
# columns take part in the merge.
max_freq_df = pd.read_csv("/Users/Carlos/downloads/Max_Freq_Fasta_LL.csv")
max_freq_df_unique = pd.read_csv(
    "/Users/Carlos/downloads/Max_Freq_Fasta_LL_SMOL.csv"
).drop(columns=['node', 'max_frequency'])

# Left join on 'sequence', so every node keeps its row. Any other column
# present in both frames gets pandas' default '_x' (left) / '_y' (right)
# suffixes.
merged = pd.merge(max_freq_df, max_freq_df_unique, how='left', on='sequence')

merged

Unnamed: 0,node,max_frequency,sequence,log_likelihood_x,log_likelihood_y
0,NODE_0000000,1.000007,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-236.884384
1,NODE_0000001,1.000007,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-236.884384
2,NODE_0000004,0.416665,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-236.884384
3,NODE_0000002,0.166666,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-247.045624
4,NODE_0000006,0.249999,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-236.884384
...,...,...,...,...,...
720,NODE_0000447,0.311489,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-274.582397
721,NODE_0001054,0.305340,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-274.582397
722,NODE_0001055,0.259144,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-274.582397
723,NODE_0001056,0.203259,MEIKELRNLMSQSRTREILTKTTVDHMAIIKKYTSGRQEKNPSLRM...,0,-274.582397


In [None]:
import torch # type: ignore
import esm # type: ignore
import argparse
from Bio import SeqIO # type: ignore
import pandas as pd # type: ignore
from tqdm import tqdm # type: ignore

max_freq_df = pd.read_csv('Max_Freq_Fasta.csv')

# Many nodes share an identical sequence, so score each distinct sequence
# once and broadcast the score back to every node with the merge at the end.
max_freq_df_unique = (
    max_freq_df.drop_duplicates(subset='sequence', keep='first')
    .reset_index(drop=True)
)

# Create the score column only on the de-duplicated frame. Adding it to
# max_freq_df as well made the final merge emit a meaningless all-zero
# 'log_likelihood_x' next to the real 'log_likelihood_y'.
# The column is float (NaN = not scored), so assigning float scores does not
# upcast an int column.
max_freq_df_unique['log_likelihood'] = float('nan')

# 1. Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # Disable dropout for evaluation

for index in tqdm(range(len(max_freq_df_unique))):
    node_name = max_freq_df_unique.at[index, 'node']
    sequence = max_freq_df_unique.at[index, 'sequence']

    # Nodes saved without a sequence (None -> empty CSV field -> NaN) cannot
    # be tokenized; leave their score as NaN.
    if pd.isna(sequence):
        continue

    # 2. Tokenize a single-sequence batch
    batch_labels, batch_strs, batch_tokens = batch_converter([(node_name, sequence)])

    # 3. Log-likelihood = sum of log-probabilities of the observed tokens
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[], return_contacts=False)
        log_probs = torch.log_softmax(results["logits"], dim=-1)
        # NOTE(review): the sum runs over every token in batch_tokens,
        # presumably including the special BOS/EOS tokens the batch converter
        # adds; confirm whether those should be excluded.
        log_likelihood = log_probs.gather(2, batch_tokens.unsqueeze(-1)).sum().item()

    tqdm.write(f"{node_name}: Log-Likelihood: {log_likelihood:.2f}")
    max_freq_df_unique.at[index, 'log_likelihood'] = log_likelihood

# 4. Attach each sequence's score to every node that carries that sequence.
max_freq_df_unique = max_freq_df_unique.drop(columns=['node', 'max_frequency'])

merged = max_freq_df.merge(max_freq_df_unique, on='sequence', how='left')

merged.to_csv('Max_Freq_Fasta_LL.csv', index=False)

In [41]:
#translate nucleotide to protein 

import argparse
import json
import requests
from Bio.Seq import Seq
from Bio import SeqIO
from Bio.Seq import MutableSeq
from Bio.SeqRecord import SeqRecord
import os

def convert_to_prot(tree_file, root_file, segment):
    """
    Translate the coding region of ``segment`` from the root nucleotide
    sequence and add the protein to the root-sequence JSON.

    The coding coordinates come from
    ``meta.genome_annotations[SEGMENT]["start"/"end"]`` in the tree JSON.
    They are used as 1-based and inclusive, hence the ``start - 1`` slice.
    The protein is stored under the upper-cased segment name. The whole JSON
    is printed and then written to the current working directory under the
    basename of ``root_file``.

    Parameters
    ----------
    tree_file : str
        Path to the tree JSON that holds ``meta.genome_annotations``.
    root_file : str
        Path to the root-sequence JSON; must contain a ``"nuc"`` entry.
    segment : str
        Gene/segment name, upper-cased before lookup.

    Raises
    ------
    KeyError
        If the annotation or the ``"nuc"`` entry is missing.
    """
    gene = segment.upper()

    # `with` closes both handles; the original rebound the names and leaked them.
    with open(tree_file, "r") as f:
        tree_json = json.load(f)
    with open(root_file, "r") as f:
        root_json = json.load(f)

    annotation = tree_json["meta"]["genome_annotations"][gene]
    start, end = annotation["start"], annotation["end"]

    coding_seq = Seq(root_json["nuc"])[start - 1:end]
    # NOTE(review): strand is not checked, and translate() keeps stop codons
    # as '*' rather than stopping; confirm both are intended.
    root_json[gene] = str(coding_seq.translate())

    print(root_json)

    # NOTE(review): if the CWD is root_file's directory, this overwrites the
    # input file; confirm that is intended.
    with open(os.path.basename(root_file), 'w') as f:
        json.dump(root_json, f, indent=2)

def combine_ha_segments(root_file):
    """
    Add an ``"HA"`` entry to the root-sequence JSON and save the result.

    The new entry is the concatenation of the ``"SigPep"``, ``"HA1"`` and
    ``"HA2"`` entries, in that order. The updated JSON is written to the
    current working directory under the basename of ``root_file``; all
    other entries are kept.

    Raises
    ------
    KeyError
        If any of the three component entries is missing.
    """
    # `with` closes the handle; the original rebound the name and leaked it.
    with open(root_file, "r") as f:
        root_json = json.load(f)

    root_json["HA"] = root_json["SigPep"] + root_json["HA1"] + root_json["HA2"]

    with open(os.path.basename(root_file), 'w') as f:
        json.dump(root_json, f, indent=2)

if __name__ == '__main__':

    # Hard-coded run settings (no command-line interface for this script).
    args_tree = "/Users/Carlos/Desktop/Bedford/esm-selection/h3n2_Sequences/h3n2_60y_ha.json"
    args_root = "/Users/Carlos/Desktop/Bedford/esm-selection/h3n2_Sequences/h3n2_60y_ha_root-sequence.json"
    args_segment = "ha"

    # HA is assembled from the SigPep/HA1/HA2 entries already in the root
    # JSON; any other segment is translated from the root "nuc" sequence
    # using the tree's genome annotations.
    if args_segment != "ha":
        convert_to_prot(args_tree, args_root, args_segment)
    else:
        combine_ha_segments(args_root)