In [1]:
import re
import pandas as pd
import csv
import os

# Input locations. Each can be overridden through an environment variable of the
# same name so the notebook can run outside the original cluster home directory;
# the defaults are the original hard-coded paths.
UNIPROT_LINEAGES_PATH = os.environ.get(
    "UNIPROT_LINEAGES_PATH",
    "/scicore/home/schwede/pudziu0000/projects/gLM/data/PINDER/uniprot_lineages_annotations_1024_512.tsv",
)
# Per-model prediction files are read as f"{FALSE_PRED_PREFIX}{model}.tsv".
FALSE_PRED_PREFIX = os.environ.get(
    "FALSE_PRED_PREFIX",
    "/scicore/home/schwede/pudziu0000/projects/gLM/outputs/predictions/PINDER-eubacteria/mean_pooled/catseq_",
)
# NOTE(review): SEED is not referenced by any visible cell; "1024" also appears in the
# lineage file name — confirm whether the two are meant to stay in sync.
SEED = 1024

In [2]:
# Map each bacterial UniProt accession (column "From") to the NCBI taxon IDs of its
# class, order, family, genus and species, parsed from the "Taxonomic lineage (Ids)"
# column. Bacteria are recognised by taxon ID 2 carrying the "(superkingdom)" rank;
# phylum is NOT checked. A rank missing from the lineage is stored as None, and if an
# accession occurs in several rows the last row wins.
lineages = pd.read_csv(UNIPROT_LINEAGES_PATH, sep='\t')

_LINEAGE_COL = 'Taxonomic lineage (Ids)'
# The ID must be exactly "2": the lookbehind stops e.g. "12 (superkingdom)" matching.
_BACTERIA_RE = re.compile(r'(?<!\d)2 \(superkingdom\)')
_RANK_RES = {
    rank: re.compile(rf'(\d+) \({rank}\)')
    for rank in ("class", "order", "family", "genus", "species")
}


def _rank_ids(lineage):
    """Return {rank: taxon ID string, or None if the rank is absent} for each rank in _RANK_RES."""
    matches = {rank: pattern.search(lineage) for rank, pattern in _RANK_RES.items()}
    return {rank: (m.group(1) if m else None) for rank, m in matches.items()}


tax_id_info = {}
for uniprot_id, lineage in zip(lineages['From'], lineages[_LINEAGE_COL]):
    # Non-string lineages (e.g. NaN when UniProt returned no lineage) are skipped.
    if isinstance(lineage, str) and _BACTERIA_RE.search(lineage):
        tax_id_info[uniprot_id] = _rank_ids(lineage)

In [3]:
# For each model's predictions, determine per taxonomic level whether the two proteins of each PINDER complex share the same taxon

def check_tax_match(tax_info, uniprot_ids, level):
    """Compare two proteins' taxon IDs at a single taxonomic rank.

    Args:
        tax_info: mapping UniProt accession -> {rank: taxon ID or None}.
        uniprot_ids: sequence whose first two items are the accessions to compare.
        level: rank key to compare, e.g. "class" or "species".

    Returns:
        1 if both IDs are known and equal, 0 if both are known and differ, and None
        if either accession is missing from ``tax_info`` or has no ID at ``level``.
    """
    # Both lookups use the ``tax_info`` argument (not a global), and a missing
    # accession or rank counts as "undetermined" rather than raising KeyError.
    first = tax_info.get(uniprot_ids[0], {}).get(level)
    second = tax_info.get(uniprot_ids[1], {}).get(level)
    if first is None or second is None:
        return None
    return 1 if first == second else 0

def get_prediction_type(record):
    """Classify one prediction as a confusion-matrix cell.

    ``record`` must support item access by "True label" and "Predicted label"
    (e.g. a pandas row or a dict). Returns "TP", "TN", "FP" or "FN" when both labels
    are 0/1, and None for any other combination.
    """
    truth = record["True label"]
    predicted = record["Predicted label"]
    if truth == 1:
        if predicted == 1:
            return "TP"
        if predicted == 0:
            return "FN"
    elif truth == 0:
        if predicted == 1:
            return "FP"
        if predicted == 0:
            return "TN"
    return None

# For each model's prediction file, collect per complex: its ID, its confusion-matrix
# cell, its true label and, per taxonomic level, whether its two proteins share a taxon
# (1 / 0, or None when undetermined). tax_matches_arr[i] holds the columns for the
# i-th model in the list below.
tax_levels = ["class", "order", "family", "genus", "species"]


def _complex_uniprot_pair(complex_id):
    """Return the two UniProt accessions embedded in a PINDER complex ID.

    IDs look like "8ap0__A1_A0A5S9CYM0_5cyy__B1_P9WG63"; splitting on "_" places
    the accessions at positions 3 and 7.
    """
    parts = complex_id.split("_")
    return parts[3], parts[7]


tax_matches_arr = []
for biolm in ["gLM2", "ESM2", "MINT"]:
    predictions = pd.read_csv(f"{FALSE_PRED_PREFIX}{biolm}.tsv", sep='\t')
    # Key order fixes the column order of the DataFrames built from these dicts.
    tax_matches = {"complex_id": [], "class": [], "order": [], "family": [], "genus": [], "species": [], "prediction_type": [], "truth": []}
    # Read every field from the same per-row record instead of mixing row access
    # with label-based column indexing.
    for record in predictions.to_dict("records"):
        complex_id = record["Complex ID"]
        uniprot_pair = list(_complex_uniprot_pair(complex_id))
        tax_matches["complex_id"].append(complex_id)
        tax_matches["truth"].append(record["True label"])
        tax_matches["prediction_type"].append(get_prediction_type(record))
        for level in tax_levels:
            tax_matches[level].append(check_tax_match(tax_id_info, uniprot_pair, level))
    tax_matches_arr.append(tax_matches)
        

In [12]:
# One row per complex of the first model (gLM2); columns follow the dict key order.
tax_df = pd.DataFrame(tax_matches_arr[0])

In [13]:
# Show the taxonomy-match table of the first model (gLM2).
tax_df

Unnamed: 0,complex_id,class,order,family,genus,species,prediction_type,truth
0,8ap0__A1_A0A5S9CYM0_5cyy__B1_P9WG63,1,0.0,0.0,0.0,,TN,0
1,1wat__B1_P02941_1wat__A1_P02941,1,1.0,1.0,1.0,1.0,TP,1
2,6xrb__A1_A0A0H3NHS6_6xrb__B1_A0A0H3NHS6,1,1.0,1.0,1.0,1.0,TP,1
3,4u66__A1_A8MI53_4wkz__B1_W5IDC3,1,0.0,0.0,0.0,,TN,0
4,1dv1__A1_P24182_1dv1__B1_P24182,1,1.0,1.0,1.0,1.0,TP,1
...,...,...,...,...,...,...,...,...
1575,1d0v__A1_Q05603_1d0v__A2_Q05603,1,1.0,1.0,1.0,1.0,TP,1
1576,4h5b__A1_Q9RUY5_4h5b__B1_Q9RUY5,1,1.0,1.0,1.0,1.0,TP,1
1577,8bfr__A1_P30177_8bfr__A2_P30177,1,1.0,1.0,1.0,1.0,TP,1
1578,8jh0__B1_S2D3K4_8jh0__A1_S2D3K4,1,1.0,1.0,1.0,1.0,TP,1


In [29]:
from scipy.stats import f_oneway
import numpy as np

# Per model, test whether the same-taxon rate differs between two prediction-outcome
# groups (e.g. TP vs FN) at each taxonomic level. Cells hold True/False for p < _ALPHA,
# or NaN when the test is undefined (a group is empty, too few observations, or every
# value in both groups is identical).
# NOTE(review): the values are 0/1 indicators and 20 tests run per model; a test for
# proportions (e.g. Fisher's exact) plus a multiple-testing correction may be more
# appropriate than uncorrected one-way ANOVA.
_ALPHA = 0.05
_GROUP_PAIRS = [["TP", "TN"], ["TP", "FN"], ["TN", "FP"], ["TP", "FP"]]


def _anova_significant(values_a, values_b, alpha=_ALPHA):
    """Return whether a one-way ANOVA of the two samples gives p < alpha, or NaN if undefined."""
    # ANOVA needs both groups populated and at least one within-group degree of freedom.
    if len(values_a) == 0 or len(values_b) == 0 or len(values_a) + len(values_b) < 3:
        return np.nan
    # No variation anywhere: the F statistic is undefined. Groups that are each constant
    # but differ are left to f_oneway, which reports that as F=inf, p=0.
    if len(set(values_a) | set(values_b)) == 1:
        return np.nan
    _, p = f_oneway(values_a, values_b)
    # Keep an undefined p as NaN instead of letting `nan < alpha` read as "not significant".
    if np.isnan(p):
        return np.nan
    return bool(p < alpha)


for i, biolm in enumerate(["gLM2", "ESM2", "MINT"]):
    tax_df = pd.DataFrame.from_dict(tax_matches_arr[i])
    stats = {"group": [], "class": [], "order": [], "family": [], "genus": [], "species": []}
    for first, second in _GROUP_PAIRS:
        stats["group"].append(f"{first}_{second}")
        for level in tax_levels:
            values_first = tax_df.loc[tax_df["prediction_type"] == first, level].dropna()
            values_second = tax_df.loc[tax_df["prediction_type"] == second, level].dropna()
            stats[level].append(_anova_significant(values_first, values_second))
    stats_df = pd.DataFrame.from_dict(stats)
    print(biolm)
    print(stats_df)
            

gLM2
   group  class  order  family  genus  species
0  TP_TN   True   True    True   True     True
1  TP_FN   True   True    True   True     True
2  TN_FP  False  False   False  False    False
3  TP_FP   True   True    True   True     True
ESM2
   group  class  order  family  genus  species
0  TP_TN   True   True    True   True     True
1  TP_FN   True   True    True   True     True
2  TN_FP  False  False   False  False    False
3  TP_FP   True   True    True   True     True
MINT
   group  class  order family  genus species
0  TP_TN   True   True   True   True    True
1  TP_FN   True   True   True   True    True
2  TN_FP  False  False  False  False   False
3  TP_FP   True    NaN    NaN    NaN     NaN


In [None]:
# Per model and taxonomic level, count how many complexes of each outcome type have a
# defined (non-missing) taxonomy-match value.
for model_idx, biolm in enumerate(["gLM2", "ESM2", "MINT"]):
    tax_df = pd.DataFrame.from_dict(tax_matches_arr[model_idx])
    for level in tax_levels:
        print(f"{biolm}, {level}")
        defined = tax_df[tax_df[level].notna()]
        for outcome in ("TN", "FN", "FP", "TP"):
            count = int((defined["prediction_type"] == outcome).sum())
            print(f"{outcome}:\t{count}")
        

In [28]:
# Make a predictor based on the taxonomy level match
from sklearn.metrics import confusion_matrix
from sklearn.metrics import matthews_corrcoef

# Baseline predictor: call a pair interacting (1) iff its two proteins share a taxon at
# the given level, then score it against the true labels. Pairs whose match is undefined
# at that level are excluded.
for i, biolm in enumerate(["gLM2", "ESM2", "MINT"]):
    print(biolm)
    tax_df = pd.DataFrame.from_dict(tax_matches_arr[i])
    for level in tax_levels:
        # dropna returns a new frame, so no explicit copy is needed.
        scored = tax_df.dropna(subset=[level])
        if scored.empty:
            print(f"{level}: no pairs with a defined match")
            print()
            continue
        y_true = scored["truth"].astype(int)
        y_pred = scored[level].astype(int)
        print(f"{level} MCC: {matthews_corrcoef(y_true, y_pred)}")
        print(f"# predictions: {len(y_true)}")
        print("\tpred 0\tpred 1")
        # labels=[0, 1] keeps the matrix 2x2 even when one class is absent, so the
        # four-way unpacking cannot fail.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        print(f"true 0\t{tn}\t{fp}")
        print(f"true 1\t{fn}\t{tp}")
        print()


gLM2
class MCC: 0.42281287371840015
# predictions: 1580
	pred 0	pred 1
true 0	242	548
true 1	1	789

order MCC: 0.699144667808426
# predictions: 1577
	pred 0	pred 1
true 0	519	269
true 1	1	788

family MCC: 0.7942674453491199
# predictions: 1577
	pred 0	pred 1
true 0	611	177
true 1	1	788

genus MCC: 0.8338451122629497
# predictions: 1577
	pred 0	pred 1
true 0	649	139
true 1	2	787

species MCC: 0.820334791672207
# predictions: 991
	pred 0	pred 1
true 0	324	90
true 1	1	576

ESM2
class MCC: 0.42281287371840015
# predictions: 1580
	pred 0	pred 1
true 0	242	548
true 1	1	789

order MCC: 0.699144667808426
# predictions: 1577
	pred 0	pred 1
true 0	519	269
true 1	1	788

family MCC: 0.7942674453491199
# predictions: 1577
	pred 0	pred 1
true 0	611	177
true 1	1	788

genus MCC: 0.8338451122629497
# predictions: 1577
	pred 0	pred 1
true 0	649	139
true 1	2	787

species MCC: 0.820334791672207
# predictions: 991
	pred 0	pred 1
true 0	324	90
true 1	1	576

MINT
class MCC: 0.42281287371840015
# predictions: