# Imputing lineages for reconstructed internal nodes

In [None]:
import tskit
import tszip
import pandas as pd
import tqdm

import sys
sys.path.append("../")
import sc2ts.utils
import sc2ts.lineages

In [None]:
# Auxiliary inputs used by the lineage comparison / imputation cells below.
mutations_json_filepath = "../../sc2ts_ts/consensus_mutations.json"
gisaid_metadata_filepath = "../../sc2ts_ts/metadata_tsv_2023_03_09/metadata.tsv"

# Path prefixes of the two tree sequences; file suffixes are appended on use.
ts_long_path = "../../sc2ts_ts/upgma-mds-1000-md-30-mm-3-2022-06-30-recinfo"
ts_wide_path = "../../sc2ts_ts/upgma-full-md-30-mm-3-2021-06-30-recinfo"

# Decompress both tree sequences first (long, then wide) ...
ts_long, ts_wide = (
    tszip.decompress(f"{prefix}-il.ts.tsz")
    for prefix in (ts_long_path, ts_wide_path)
)
# ... then build a TreeInfo helper for each, in the same order.
ti_long, ti_wide = (sc2ts.utils.TreeInfo(ts) for ts in (ts_long, ts_wide))

# GISAID vs Nextclade lineage comparison

In [None]:
# Load the GISAID metadata table (tab-separated) into a DataFrame.
md = pd.read_csv(gisaid_metadata_filepath, sep="\t")

In [None]:
# Pair every accession ID with its Pango lineage as a list of 2-tuples.
gisaid_data = list(zip(md['Accession ID'], md['Pango lineage']))

In [None]:
# Read the mutations file (consensus_mutations.json) via sc2ts.lineages.
# The result is handed to check_lineages() in the following cells.
# NOTE(review): presumably a per-lineage mutation lookup - confirm against
# sc2ts.lineages.read_in_mutations.
linmuts_dict = sc2ts.lineages.read_in_mutations(mutations_json_filepath)

In [None]:
# Run the lineage check on the "long" tree sequence using the
# (accession ID, Pango lineage) pairs from the GISAID metadata.
# The returned object is later passed to lineage_imputation() as the
# tree sequence to impute on.
# NOTE(review): diff_filehandle is given a string path here, not an open
# file object, despite its name - confirm check_lineages expects a path/prefix.
ts_long_gisaid = sc2ts.utils.check_lineages(
    ts_long,
    ti_long,
    gisaid_data,
    linmuts_dict,
    diff_filehandle='../../sc2ts_ts/lineage_disagreement_long',
)

In [None]:
# Run the lineage check on the "wide" tree sequence using the
# (accession ID, Pango lineage) pairs from the GISAID metadata.
# The returned object is later passed to lineage_imputation() as the
# tree sequence to impute on.
# NOTE(review): diff_filehandle is given a string path here, not an open
# file object, despite its name - confirm check_lineages expects a path/prefix.
ts_wide_gisaid = sc2ts.utils.check_lineages(
    ts_wide,
    ti_wide,
    gisaid_data,
    linmuts_dict,
    diff_filehandle='../../sc2ts_ts/lineage_disagreement_wide',
)

# ts lineage imputation

In [None]:
# Impute lineages on the GISAID-annotated "long" tree sequence.
# The result is saved to disk and then inspected for the
# 'Imputed_GISAID_lineage' / 'Imputed_Nextclade_pango' node metadata keys.
# NOTE(review): internal_only=False presumably means imputation is not
# restricted to internal nodes - confirm against sc2ts.utils.lineage_imputation.
edited_ts_long = sc2ts.utils.lineage_imputation(
    mutations_json_filepath,
    ts_long_gisaid, 
    ti_long,
    internal_only=False,
    verbose=False
)

In [None]:
# Persist the imputed "long" tree sequence twice: as a plain .ts file and
# as a tszip-compressed .ts.tsz file next to it.
edited_ts_long.dump(f"{ts_long_path}-gisaid-il.ts")
tszip.compress(edited_ts_long, f"{ts_long_path}-gisaid-il.ts.tsz")

In [None]:
# Agreement rate between the two imputed lineage fields on the "long" tree
# sequence. A node is counted only if it carries BOTH imputed fields and
# NEITHER of the directly assigned fields ('GISAID_lineage',
# 'Nextclade_pango').
correct = total = 0
for node in edited_ts_long.nodes():
    node_md = node.metadata  # read the attribute once per node
    # Skip nodes that already have a directly assigned lineage.
    if 'GISAID_lineage' in node_md or 'Nextclade_pango' in node_md:
        continue
    # Skip nodes that lack either imputed value; nothing to compare.
    if 'Imputed_GISAID_lineage' not in node_md or 'Imputed_Nextclade_pango' not in node_md:
        continue
    total += 1
    if node_md['Imputed_GISAID_lineage'] == node_md['Imputed_Nextclade_pango']:
        correct += 1
# Guard against ZeroDivisionError when no node qualifies.
if total:
    print(correct/total)
else:
    print("No nodes with only imputed GISAID and Nextclade lineages; agreement rate undefined")

In [None]:
# Impute lineages on the GISAID-annotated "wide" tree sequence.
# The result is saved to disk and then inspected for the
# 'Imputed_GISAID_lineage' / 'Imputed_Nextclade_pango' node metadata keys.
# NOTE(review): internal_only=False presumably means imputation is not
# restricted to internal nodes - confirm against sc2ts.utils.lineage_imputation.
edited_ts_wide = sc2ts.utils.lineage_imputation(
    mutations_json_filepath,
    ts_wide_gisaid, 
    ti_wide,
    internal_only=False,
    verbose=False
)

In [None]:
# Persist the imputed "wide" tree sequence twice: as a plain .ts file and
# as a tszip-compressed .ts.tsz file next to it.
edited_ts_wide.dump(f"{ts_wide_path}-gisaid-il.ts")
tszip.compress(edited_ts_wide, f"{ts_wide_path}-gisaid-il.ts.tsz")

In [None]:
# Agreement rate between the two imputed lineage fields on the "wide" tree
# sequence. A node is counted only if it carries BOTH imputed fields and
# NEITHER of the directly assigned fields ('GISAID_lineage',
# 'Nextclade_pango').
correct = total = 0
for node in edited_ts_wide.nodes():
    node_md = node.metadata  # read the attribute once per node
    # Skip nodes that already have a directly assigned lineage.
    if 'GISAID_lineage' in node_md or 'Nextclade_pango' in node_md:
        continue
    # Skip nodes that lack either imputed value; nothing to compare.
    if 'Imputed_GISAID_lineage' not in node_md or 'Imputed_Nextclade_pango' not in node_md:
        continue
    total += 1
    if node_md['Imputed_GISAID_lineage'] == node_md['Imputed_Nextclade_pango']:
        correct += 1
# Guard against ZeroDivisionError when no node qualifies.
if total:
    print(correct/total)
else:
    print("No nodes with only imputed GISAID and Nextclade lineages; agreement rate undefined")