In [1]:
import pandas as pd
import pickle

from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.PDB.MMCIF2Dict import MMCIF2Dict
from Bio.SeqUtils import seq1

#write a function to get mutations within a sample that are present in 85% or more of samples
#should take as input an annotation file, and a lineage to return values for.
# if using lineages then requires to be handed a regex
def get_representative_mutations(anno_file, lineage, cutoff = 0.85, overview_lineage = False, mutation_list_only = False):
    anno_file["nt_aa_compound"] = anno_file.REF + anno_file.POS.astype("str") + anno_file.ALT + "_" + anno_file.residues
    
    if overview_lineage:
        lineage_counts = anno_file.loc[anno_file.overview_lineage == lineage, "sample_id"].nunique()
        mutation_counts = anno_file.loc[anno_file.overview_lineage == lineage].groupby("description").nt_aa_compound.value_counts().rename("count").reset_index()
    else:
        lineage_counts = anno_file.loc[anno_file.lineage == lineage, "sample_id"].nunique()
        mutation_counts = anno_file.loc[anno_file.lineage == lineage].groupby("description").nt_aa_compound.value_counts().rename("count").reset_index()
    
    if mutation_list_only:
        representative_mutations = mutation_counts.loc[(mutation_counts["count"] >= (lineage_counts * cutoff)), "nt_aa_compound"]
    else:
        representative_mutations = mutation_counts.loc[(mutation_counts["count"] >= (lineage_counts * cutoff))].copy()
        representative_mutations["sample_count"] = lineage_counts
        representative_mutations["lineage"] = lineage
    return(representative_mutations)

annotation_file = pd.read_csv("gisaid_sequences_spear/spear_annotation_summary.tsv", sep = "\t")

metadata = pd.read_csv("lineages_subset_metadata.csv")
metadata["strain"] = metadata["strain"].str.replace(' ', '')
metadata["strain"] = metadata["strain"].str.replace('[^a-zA-Z0-9\.]', '_', regex = True)

annotation_file_merged = annotation_file.merge(metadata, left_on = "sample_id", right_on = "strain", how = "left")

lineages = annotation_file_merged.lineage.unique()

lineage_mutations = {}
for lineage in lineages:
    lineage_muts = get_representative_mutations(annotation_file_merged, lineage)
    lineage_mutations[lineage] = lineage_muts
    print("Done ", lineage)
pickle.dump(lineage_mutations, open("lineage_classifications/lineage_mutations_file.pkl", "wb"))

for lineage in lineage_mutations.keys():
    if len(lineage_mutations[lineage]) > 0:
        lineage_mutations[lineage].to_csv(f"lineage_classifications/{lineage}_consensus_mutations.csv")

lineage_mutations_df = pd.concat(lineage_mutations.values(), ignore_index=True)

lineage_mutations_df["aa"] = lineage_mutations_df["nt_aa_compound"].str.extract(r'.*_(.*)')


pattern = r'([A-Z])([\d\.]+)([A-z-\*]+)'
#
# Use str.extract() to create new columns for 'ref', 'pos', and 'alt'
lineage_mutations_df[["ref", "pos", "alt"]] = lineage_mutations_df['aa'].str.extract(pattern)

lineage_mutations_df["pos"] = lineage_mutations_df.pos.astype(float).astype(int)

lineage_mutations_df["aa_2"] = lineage_mutations_df["ref"] + lineage_mutations_df["pos"].astype("str") + lineage_mutations_df["alt"] 

#lineage_mutations_df = lineage_mutations_df.loc[lineage_mutations_df.lineage.isin(["ba_1", "ba_2", "ba_3", "ba_4","ba_5"]) == False]


spike_ab_list = pd.read_csv("spike_abs/ab_complexes_chains_updated.csv")
spike_ab_list["chain_id"] = spike_ab_list['chain_id'].apply(list)
spike_ab_list = spike_ab_list.explode("chain_id")
spike_ab_list["pdb_chain"] = spike_ab_list["pdb"].str.upper() + "_" + spike_ab_list["chain_id"]


# Define a function to safely retrieve values from the CIF data
def get_cif_value(key):
    try:
        return data[key]
    except KeyError:
        return [None]  # Return None if the key doesn't exist in the CIF data

    
try:
    all_cif = pd.read_pickle("spike_abs/all_cif.pkl")
except:
    all_cif = pd.DataFrame()

    for i, row in spike_ab_list.iterrows():
        file = f"spike_abs/cif/{row['pdb']}.cif.gz"
        with gzip.open(file, 'rt') as mmcif_file:
            data = MMCIF2Dict(mmcif_file)

            cif_data = {
                "align_id": get_cif_value("_struct_ref_seq_dif.align_id"),
                "pdb_id_code": get_cif_value("_struct_ref_seq_dif.pdbx_pdb_id_code"),
                "mon_id": get_cif_value("_struct_ref_seq_dif.mon_id"),
                "strand_id": get_cif_value("_struct_ref_seq_dif.pdbx_pdb_strand_id"),
                "seq_num": get_cif_value("_struct_ref_seq_dif.seq_num"),
                "pdb_ins_code": get_cif_value("_struct_ref_seq_dif.pdbx_pdb_ins_code"),
                "seq_db_name": get_cif_value("_struct_ref_seq_dif.pdbx_seq_db_name"),
                "seq_db_accession_code": get_cif_value("_struct_ref_seq_dif.pdbx_seq_db_accession_code"),
                "db_mon_id": get_cif_value("_struct_ref_seq_dif.db_mon_id"),
                "seq_db_seq_num": get_cif_value("_struct_ref_seq_dif.pdbx_seq_db_seq_num"),
                "details": get_cif_value("_struct_ref_seq_dif.details"),
                "auth_seq_num": get_cif_value("_struct_ref_seq_dif.pdbx_auth_seq_num"),
                "ordinal": get_cif_value("_struct_ref_seq_dif.pdbx_ordinal")
            }

        cif_df = pd.DataFrame(cif_data)
        all_cif = pd.concat([all_cif, cif_df])

    all_cif["pdb_chain"] = all_cif["pdb_id_code"] + "_" + all_cif["strand_id"]
    all_cif = all_cif.loc[all_cif.pdb_chain.isin(spike_ab_list.pdb_chain)]
    all_cif.to_pickle("spike_abs/all_cif.pkl")

#for these purposes we do not consider indels
all_cif_filtered = all_cif.loc[(all_cif.pdb_id_code.isna() == False) & (all_cif.details.isin(["expression tag", "cloning artifact", "conflict", "linker", "initiating methionine", "amidation", "deletion", "insertion"]) == False)].copy()

all_cif_filtered["mon_id_single"] = all_cif_filtered.mon_id.apply(seq1)
all_cif_filtered.loc[all_cif_filtered.details == "deletion", "mon_id_single"] = "del"
all_cif_filtered["db_mon_id_single"] = all_cif_filtered.db_mon_id.apply(seq1)
all_cif_filtered.loc[all_cif_filtered.details == "insertion", "db_mon_id_single"] = "ins"

all_cif_filtered["aa_2"] = all_cif_filtered["db_mon_id_single"] + all_cif_filtered["seq_db_seq_num"] + all_cif_filtered["mon_id_single"]

all_sarscov2_cif = all_cif_filtered.loc[(all_cif_filtered.seq_db_accession_code == "P0DTC2")].copy()

all_sarscov2_cif_filtered = all_sarscov2_cif.loc[(all_sarscov2_cif.aa_2.isin(["S383C", "D985C", "K955P", "V956P","V705C", "T883C", "A846Y", "I844M", "K835M", "F817P", "A892P", "A899P", "A942P", "V367F", "R682G", "R682S", "R683G", "R683S", "R685G", "R685S", "K986P", "V987P"]) == False)]
all_sarscov2_cif_filtered = all_sarscov2_cif_filtered.loc[(all_sarscov2_cif_filtered.seq_db_seq_num.isin(["682", "683", "684", "685"]) == False)]
all_sarscov2_cif_filtered = all_sarscov2_cif_filtered.drop(all_sarscov2_cif_filtered.loc[all_sarscov2_cif_filtered.pdb_id_code == "7CAC"].index)

wt = spike_ab_list.loc[spike_ab_list.pdb.isin(all_sarscov2_cif_filtered.pdb_id_code) == False].pdb.unique().tolist()
wt.append("7CAC") #manually adding this pdb as it is wt but has unusual mutation profile

Done  gamma
Done  beta
Done  alpha
Done  delta
Done  ba_1
Done  omicron
Done  ba_2
Done  ba_5
Done  ba_4
Done  ba_3


In [2]:
def count_lineage_matches(mutations, lineage_file):
    if mutations.values.tolist() == ["N501Y"]:
        max_lineage = "alpha"
        lineage_stats = {"method" : "assigned by single mut"}
        return max_lineage, lineage_stats
    elif mutations.values.tolist() == ["D614G"]:
        max_lineage = "delta"
        lineage_stats = {"method" : "assigned by single mut"}
        return max_lineage, lineage_stats
    lineage_stats = {}
    max_perc = 0
    max_lineage = None
    for lineage, group in lineage_file.loc[(lineage_file.description == "surface glycoprotein") & (lineage_file.lineage != "omicron")].groupby("lineage"):
        lineage_total_muts = len(group)
        lineage_match_muts = group.loc[group.aa_2.isin(mutations)]
        lineage_match_count = len(lineage_match_muts)
        perc_matches = lineage_match_count/lineage_total_muts
        additional_mutations = len(mutations) - lineage_match_count
        lineage_stats[lineage] = {"lineage_match_count" : lineage_match_count, "percentage_matches" : perc_matches, "num_additional_mutations" : additional_mutations, "method": "assigned by perc"}
        if perc_matches > max_perc:
            max_perc = perc_matches
            max_lineage = lineage
    if (max_lineage == None):
        max_lineage = "to investigate"
    return max_lineage, lineage_stats

results = {}
for i, group in all_sarscov2_cif_filtered.groupby(["pdb_id_code", "strand_id"]):
    group = group.reset_index()
    structure = f"{group.pdb_id_code.values[0]}_{group.strand_id.values[0]}"
    result = count_lineage_matches(group.aa_2, lineage_mutations_df)
    results[structure] = [result[0], result[1]]

classification_results = pd.DataFrame(results).T.reset_index()
classification_results = classification_results.rename(columns = {"index": "pdb_chain_id", 0: "classification", 1 : "classification_stats"})

wt_spike = spike_ab_list.loc[spike_ab_list.pdb.isin(wt), ["pdb_chain"]].copy().rename(columns = {"pdb_chain": "pdb_chain_id"})

wt_spike["classification"] = "wt"
wt_spike["classification_stats"] = "no mutations after filtering"

all_classification_results = pd.concat([wt_spike, classification_results])
all_classification_results["pdb"] = all_classification_results["pdb_chain_id"].str[:4]

all_classification_results["pdb"].nunique()

all_classification_results.drop_duplicates("pdb").classification.value_counts(dropna = False)

all_classification_results = all_classification_results.drop_duplicates("pdb")
all_classification_results.to_pickle("all_classification_results.pkl")

wt       644
ba_1      31
beta      28
delta     22
alpha      7
ba_2       6
Name: classification, dtype: int64