In [51]:
import pandas as pd
import numpy as np
import torch
import json
from functools import partial
from torchmetrics import F1Score, Recall, Precision, AveragePrecision, MatthewsCorrCoef, AUROC, MetricCollection

In [19]:
combinations_st = ['S', 'T', 'ST', 'STY']
combinations_other = ['Y']
results = {}

for c in combinations_st:
    results[c] = pd.read_csv(f'data/musite_preds/musite_preds_{c}_general_phosphorylation_SorT.txt', sep='\t', 
                             names=['id', 'seq_idx', 'residue', 'prediction'])

for c in combinations_other:
    results[c] = pd.read_csv(f'data/musite_preds/musite_preds_{c}_general_phosphorylation_{c}.txt', sep='\t',
                             names=['id', 'seq_idx', 'residue', 'prediction'])

In [47]:
results_sty_y = pd.read_csv(f'data/musite_preds/musite_preds_STY_general_phosphorylation_Y.txt', sep='\t',
                             names=['id', 'seq_idx', 'residue', 'prediction'])

In [49]:
results['STY'] = pd.concat([results['STY'], results_sty_y])

# Check the number of predictions 

In [21]:
results['S'].groupby('id')[['seq_idx', 'prediction']].agg(list)

Unnamed: 0_level_0,seq_idx,prediction
id,Unnamed: 1_level_1,Unnamed: 2_level_1
>A0A024R4G9,"[16, 20, 31, 45, 54, 60, 82, 93]","[0.0892075836658477, 0.1957073420286178, 0.103..."
>A0A087WQP5,"[98, 102, 107]","[0.2832421027123928, 0.4769477695226669, 0.397..."
>A0A0A6YY25,"[3, 6, 7, 30, 53, 56, 96, 101, 110, 118, 153, ...","[0.6995263218879699, 0.2375560104846954, 0.192..."
>A0A0B4J1F3,"[2, 25, 37, 84, 86, 99, 106, 107, 110, 125, 12...","[0.0789844438433647, 0.2887605041265487, 0.128..."
>A0A0G2JTM7,"[10, 32, 38, 41, 42, 46, 47, 53, 123, 124, 153...","[0.4124750792980194, 0.3041444838047027, 0.297..."
...,...,...
>XP_983730,"[6, 11, 16, 23, 31, 33, 52, 56, 62, 68, 88, 93...","[0.0100663808174431, 0.0058047487866133, 0.005..."
>XP_984438,"[2, 3, 8, 11, 27, 29, 30, 31, 46, 61, 78, 96, ...","[0.4777823090553283, 0.496388179063797, 0.3547..."
>XP_987269,"[3, 10, 16, 25, 38, 64, 77, 90, 103, 116, 129,...","[0.5226779460906983, 0.4313943207263946, 0.700..."
>YP_009725305,"[5, 13, 46, 59, 105]","[0.8353796005249023, 0.4889589428901672, 0.163..."


In [5]:
splits = pd.read_json('data/splits_S.json')

In [8]:
len(splits.iloc[0]['test']) + len(splits.iloc[0]['train'])

11220

Number of proteins match

In [9]:
prot_info = pd.read_json('data/phosphosite_sequences/phosphosite_df.json')

In [57]:

metrics = {
    'f1' : F1Score('binary'),
    'precision' : Precision('binary'),
    'recall' : Recall('binary'),
    'auprc' : AveragePrecision('binary'),
    'auroc' : AUROC('binary'),
    'mcc' : MatthewsCorrCoef('binary')
}

metrics = MetricCollection(metrics)

residues = {'S' : {'S'}, 'T' : {'T'}, 'Y' : {'Y'}, 'ST': {'S', 'T'}, 'STY' : {'S', 'T', 'Y'}}

def prepare_labels(row, residues : set):
    res = []
    for i, s in enumerate(row['sequence']):
        if s in residues:
            if i in row['sites']:
                res.append(1)
            else:
                res.append(0)
    return res

for res, preds in results.items():
    fixed_ids = preds['id'].apply(lambda x: x[1:])
    copy = preds.copy()
    copy['id'] = fixed_ids
    grouped = copy.groupby('id').agg(list)
    merged = grouped.join(prot_info.set_index('id'), how='left')

    # Convert sites to lists of ints
    merged['sites'] = merged['sites'].apply(lambda x: [int(i) for i in x])
    merged['labels'] = merged.apply(partial(prepare_labels, residues=residues[res]), axis=1)
    splits = pd.read_json(f'data/splits_{res}.json')
    
    metric_results = { k : [] for k in metrics.keys()}
    for i in range(len(splits)):
        test_prots = splits.iloc[i]['test']
        ids = prot_info.loc[test_prots]['id']
        for id in ids:
            metrics.update(torch.as_tensor(merged.loc[id]['prediction']), torch.as_tensor(merged.loc[id]['labels']))

        fold_results = metrics.compute()
        for k,v in fold_results.items():
            metric_results[k].append(float(v.numpy()))

        metrics.reset()

    print(f'Residue: {res}')
    for k, vals in list(metric_results.items()):
        print(f'{k} mean: {np.mean(vals):0.4f}')
        print(f'{k} std: {np.std(vals):0.4f}')
        metric_results[f'{k}_mean'] = np.mean(vals)
        metric_results[f'{k}_std'] = np.std(vals)
    print('----------------')

    
    with open(f'data/musite_preds/metrics_{res}.json', 'w') as f:
        json.dump(metric_results, f, indent='\t')

Residue: S
auprc mean: 0.0322
auprc std: 0.0029
auroc mean: 0.6391
auroc std: 0.0125
f1 mean: 0.0532
f1 std: 0.0023
mcc mean: 0.0550
mcc std: 0.0052
precision mean: 0.0282
precision std: 0.0012
recall mean: 0.4643
recall std: 0.0169
----------------
Residue: T
auprc mean: 0.0311
auprc std: 0.0021
auroc mean: 0.6644
auroc std: 0.0122
f1 mean: 0.0673
f1 std: 0.0063
mcc mean: 0.0615
mcc std: 0.0086
precision mean: 0.0397
precision std: 0.0039
recall mean: 0.2216
recall std: 0.0153
----------------
Residue: ST
auprc mean: 0.0283
auprc std: 0.0030
auroc mean: 0.6560
auroc std: 0.0096
f1 mean: 0.0509
f1 std: 0.0037
mcc mean: 0.0570
mcc std: 0.0051
precision mean: 0.0272
precision std: 0.0021
recall mean: 0.3896
recall std: 0.0149
----------------
Residue: STY
auprc mean: 0.0175
auprc std: 0.0005
auroc mean: 0.5771
auroc std: 0.0059
f1 mean: 0.0353
f1 std: 0.0011
mcc mean: 0.0256
mcc std: 0.0011
precision mean: 0.0189
precision std: 0.0006
recall mean: 0.2759
recall std: 0.0032
--------------