In [6]:
from utils_v10 import *

In [7]:
# Per-ORF input table (presumably the test split, judging by the file name).
# NOTE(review): absolute, machine-specific path; later cells use paths relative
# to the working directory.
input_df = pd.read_csv(
    filepath_or_buffer='/home/tobamo/analize/model-tobamo/notebooks/ver10/results/training/test_input_df.csv'
)

In [8]:
def add_info_to_training_input_df(
    input_df, inv_nt_species2id_path: str, inv_type_acc_dict_path: str, orf_fasta_path: str
):
    """Annotate a per-ORF table with species, ORF type and strand.

    The frame is modified in place and the same object is returned.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Must contain an ``orf_name`` column.
    inv_nt_species2id_path : str
        File read with ``mpu.io.read``; a mapping from ``nt_id`` to a sequence
        whose first element is stored in the ``species`` column.
    inv_type_acc_dict_path : str
        File read with ``mpu.io.read``; a mapping from ``nt_id`` to a sequence
        whose first element is stored in the ``orf_type`` column.
    orf_fasta_path : str
        FASTA file of ORFs. The first whitespace-delimited token of each
        record description is matched against ``orf_name``.

    Returns
    -------
    pandas.DataFrame
        ``input_df`` with ``species``, ``orf_type`` and ``strand`` added, plus
        ``contig_name`` and/or ``nt_id`` when they had to be derived.
    """
    orf_names = input_df["orf_name"]
    had_contig_name = "contig_name" in input_df.columns

    if not had_contig_name:
        # Everything before the trailing "_ORF.<n>" / "_aa_frame<n>" suffix.
        input_df["contig_name"] = orf_names.str.extract(
            r"(.*)(?:_ORF\.\d+|_aa_frame\d+)", expand=False
        )

    # The lookups below are keyed on nt_id, so it must exist on every path.
    # Previously it was only derived when contig_name was already present,
    # which raised KeyError('nt_id') otherwise. An nt_id supplied by the caller
    # is still kept when contig_name had to be derived, as before.
    if had_contig_name or "nt_id" not in input_df.columns:
        input_df["nt_id"] = orf_names.str.extract(r"(.*)_start", expand=False)

    # add species info
    inv_nt_species2id = mpu.io.read(inv_nt_species2id_path)
    species_by_nt_id = {nt_id: values[0] for nt_id, values in inv_nt_species2id.items()}
    input_df["species"] = input_df["nt_id"].map(species_by_nt_id)

    # add orf_type info
    inv_type_acc_dict = mpu.io.read(inv_type_acc_dict_path)
    orf_type_by_nt_id = {nt_id: values[0] for nt_id, values in inv_type_acc_dict.items()}
    input_df["orf_type"] = input_df["nt_id"].map(orf_type_by_nt_id)

    # add strand orientation info: a "(-)" marker anywhere in the FASTA
    # description means reverse strand, anything else is treated as forward.
    with open(orf_fasta_path, "r") as file:
        orf_dict = {
            record.description.split()[0]: "REVERSE" if "(-)" in record.description else "FORWARD"
            for record in SeqIO.parse(file, "fasta")
        }
    input_df["strand"] = input_df["orf_name"].map(orf_dict)

    return input_df

In [9]:
# Lookup tables and ORF FASTA used for annotation (paths are relative to the
# notebook's working directory).
inv_nt_species2id_path = "data/inv_nt_dict.json"
inv_type_acc_dict_path = "data/inv_type_acc_dict.json"
orf_fasta_path = "results/training/orfs/combined_orfs.fasta"

# The helper annotates input_df in place and returns that same frame.
df = add_info_to_training_input_df(
    input_df,
    inv_nt_species2id_path=inv_nt_species2id_path,
    inv_type_acc_dict_path=inv_type_acc_dict_path,
    orf_fasta_path=orf_fasta_path,
)

In [10]:
# df.to_csv('results/training/test_input_df_with_info.csv', index=False)

In [11]:
# Model results; the list-valued columns come back from CSV as strings.
# NOTE(review): absolute, machine-specific path.
dd = pd.read_csv(
    filepath_or_buffer='/home/tobamo/analize/model-tobamo/notebooks/ver10/results/training/model/model_results.csv'
)

In [12]:
import ast
def process_df(df, input_df):
    """Expand model results to one row per ORF and score each prediction.

    Parameters
    ----------
    df : pandas.DataFrame
        Model results. The columns ``orf_name``, ``ground_truth``,
        ``prediction`` and ``probability`` hold list-likes, or the string
        literal of one, aligned element-wise within each row. Labels are
        compared against 0 and 1; each probability entry is indexed at
        ``[0]`` and ``[1]``. The frame is not modified.
    input_df : pandas.DataFrame
        Per-ORF metadata with ``strand`` and ``contig_length`` columns and
        ``orf_name`` either as a column or as the index. Not modified.

    Returns
    -------
    tuple
        ``(res_df, confusion_matrix)``. ``res_df`` has one row per ORF with
        the outcome flags, ``prob_0``, ``prob_1``, ``max_prob``, ``match``,
        ``strand`` and ``contig_length`` added. ``confusion_matrix`` is a
        Series with the summed outcome flags.
    """
    list_columns = ['orf_name', 'ground_truth', 'prediction', 'probability']

    def _parse_cell(value):
        # Strings are parsed as Python literals; other values pass through.
        return ast.literal_eval(value) if isinstance(value, str) else value

    # assign() returns a new frame, so the caller's df keeps its raw values
    # (the previous version overwrote these four columns in place).
    parsed = df.assign(**{col: df[col].apply(_parse_cell) for col in list_columns})
    res_df = parsed.explode(list_columns).reset_index(drop=True)

    res_df['true_positive'] = (res_df['ground_truth'] == 1) & (res_df['prediction'] == 1)
    res_df['false_positive'] = (res_df['ground_truth'] == 0) & (res_df['prediction'] == 1)
    res_df['true_negative'] = (res_df['ground_truth'] == 0) & (res_df['prediction'] == 0)
    res_df['false_negative'] = (res_df['ground_truth'] == 1) & (res_df['prediction'] == 0)

    res_df['prob_0'] = res_df['probability'].apply(lambda pair: pair[0])
    res_df['prob_1'] = res_df['probability'].apply(lambda pair: pair[1])
    # Larger of the two entries, i.e. the score of the more likely class.
    res_df['max_prob'] = res_df[['prob_0', 'prob_1']].max(axis=1)

    res_df['match'] = res_df['ground_truth'] == res_df['prediction']

    # Attach per-ORF metadata; ORFs missing from input_df get NaN.
    orf_info = input_df if input_df.index.name == "orf_name" else input_df.set_index("orf_name")
    for column in ('strand', 'contig_length'):
        res_df[column] = res_df['orf_name'].map(orf_info[column].to_dict())

    confusion_matrix = res_df[
        ['true_positive', 'true_negative', 'false_positive', 'false_negative']
    ].sum()

    return res_df, confusion_matrix

def select_best_orf(res_df):
    """Keep, for every contig, the single ORF with the highest ``max_prob``.

    A ``best_orf`` flag (1 for rows that reach their contig's maximum, else 0)
    is written onto the frame that was passed in. Among tied rows the first
    one in frame order is kept. The returned frame also carries ``best_orf``
    and a ``rank`` column.
    """
    per_contig_prob = res_df.groupby('contig')['max_prob']
    reaches_max = per_contig_prob.transform(lambda probs: probs == probs.max())
    res_df['best_orf'] = reaches_max.astype(int)

    candidates = res_df.loc[res_df['best_orf'] == 1].copy()
    # method='first' breaks ties by order of appearance.
    candidates['rank'] = candidates.groupby('contig')['max_prob'].rank(
        method='first', ascending=False
    )
    return candidates[candidates['rank'] == 1]

In [13]:
# Score every ORF, then keep the single most confident ORF per contig.
res_df, confusion_matrix = process_df(df=dd, input_df=df)
best = select_best_orf(res_df=res_df)

In [14]:
# Outcome counts restricted to the one ORF retained per contig.
# NOTE(review): the recorded output shows zero false positives and zero false
# negatives; confirm this is expected and not an artefact of the selection.
cm = best.loc[
    :, ['true_positive', 'true_negative', 'false_positive', 'false_negative']
].sum()

cm

true_positive     1350
true_negative     1075
false_positive       0
false_negative       0
dtype: int64