In [None]:
import os
import json
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

In [None]:
def generate_file_dicts(dir):
    json_files = {}
    sorted_filenames = sorted(os.listdir(dir), key= lambda x: len(x))
    for filename in sorted_filenames:
        if filename.endswith(".json"):
            key_name = filename.replace('.json', '')
            file_path = os.path.join(dir, filename)
            json_files[key_name] = file_path
    
    return json_files       

In [None]:
def calculate_error(json_file, target):
    with open(json_file, 'r') as f:
        data = json.load(f)

    errors = []

    for _, value in data.items():
        pred = value.get(target, {}).get('pred')
        label = value.get(target, {}).get('label')
        if pred is None or label is None:
            continue
        error = round((pred - label), 3)
        errors.append(error)   
    return errors

def calculate_abs_error(json_file, target):
    with open(json_file, 'r') as f:
        data = json.load(f)

    abs_error = []

    for _, value in data.items():
        pred = value.get(target, {}).get('pred')
        label = value.get(target, {}).get('label')
        if pred is None or label is None:
            continue
        error = round(abs(pred - label), 3)
        abs_error.append(error)   
    return abs_error


In [None]:
def numerical_task_error(classifiers):
    purity_errors = {}
    fga_errors = {}

    for classifier_name, json_file in classifiers.items():
        if 'Purity' in classifier_name:
            purity_errors[classifier_name] = calculate_error(json_file, "purity")
        if 'FGA' in classifier_name:
            fga_errors[classifier_name] = calculate_error(json_file, "FRACTION_GENOME_ALTERED")

    return pd.DataFrame(purity_errors), pd.DataFrame(fga_errors)

In [73]:
def report_accuracy(json_file, target):
    with open(json_file, 'r') as f:
        data = json.load(f)

    labels = []
    preds = []
    unique_labels = set()

    for _, value in data.items():
        pred = value.get(target, {}).get('pred')
        label = value.get(target, {}).get('label')
        if pred is None or label is None:
            continue
        labels.append(label)
        preds.append(pred)   
        unique_labels.add(label)
    
    return labels, preds

In [118]:
def categorical_task_staging(classifiers):
    staging_labels = []
    staging_preds = []
    classification_reports = {}
    confusion_matrices = {}

    for classifier_name, json_file in classifiers.items():

        if 'Staging' in classifier_name:
            staging_labels, staging_preds = report_accuracy(json_file, "AJCC_PATHOLOGIC_TUMOR_STAGE_reduced")
            classification_reports[classifier_name] = classification_report(staging_labels, staging_preds, target_names=['Early Stage', 'Late Stage'])
            confusion_matrices[classifier_name] = confusion_matrix(staging_labels, staging_preds)
            print(classification_reports[classifier_name])
            print(confusion_matrices[classifier_name])
    return classification_reports, confusion_matrices

In [140]:
def categorical_task_subtyping(classifiers):
    subtyping_labels = []
    subtyping_preds = []
    classification_reports = {}
    confusion_matrices = {}

    for classifier_name, json_file in classifiers.items():
        if 'Subtyping' in classifier_name:
            print(classifier_name)
            subtyping_labels, subtyping_preds = report_accuracy(json_file, "lung-cancer-subtyping")

            subtyping_labels_counts = pd.Series(list(subtyping_labels)).value_counts()
            print(subtyping_labels_counts)

            classification_reports[classifier_name] = classification_report(subtyping_labels, subtyping_preds, target_names=['normal', 'luad', 'lusc'])
            confusion_matrices[classifier_name] = confusion_matrix(subtyping_labels, subtyping_preds, labels=['normal', 'luad', 'lusc'])
            
            print(classification_reports[classifier_name])
            print(confusion_matrices[classifier_name])
            print()
    
    return classification_reports, confusion_matrices

In [126]:
json_dir = 'data'
json_files = generate_file_dicts(json_dir)

In [None]:
purity_errors, fga_errors = numerical_task_error(json_files)
purity_errors.to_csv('error/purity.csv', index_label='Sample')
fga_errors.to_csv('error/fga.csv', index_label='Sample')

In [141]:
subtyping_reports, subtyping_matrices = categorical_task_subtyping(json_files)

Subtyping
lusc      166
luad      146
normal     92
Name: count, dtype: int64
              precision    recall  f1-score   support

      normal       0.72      0.82      0.77       146
        luad       0.85      0.72      0.78       166
        lusc       0.88      0.95      0.91        92

    accuracy                           0.80       404
   macro avg       0.82      0.83      0.82       404
weighted avg       0.81      0.80      0.80       404

[[ 87   4   1]
 [  7 119  20]
 [  5  42 119]]

Subtyping-FGA
lusc      166
luad      146
normal     92
Name: count, dtype: int64
              precision    recall  f1-score   support

      normal       0.71      0.82      0.76       146
        luad       0.86      0.71      0.78       166
        lusc       0.87      0.93      0.90        92

    accuracy                           0.80       404
   macro avg       0.81      0.82      0.81       404
weighted avg       0.81      0.80      0.80       404

[[ 86   5   1]
 [  8 120  18]
 

In [124]:
staging_reports, staging_matrices = categorical_task_subtyping(json_files)

              precision    recall  f1-score   support

      normal       0.72      0.82      0.77       146
        luad       0.85      0.72      0.78       166
        lusc       0.88      0.95      0.91        92

    accuracy                           0.80       404
   macro avg       0.82      0.83      0.82       404
weighted avg       0.81      0.80      0.80       404

[[119  20   7]
 [ 42 119   5]
 [  4   1  87]]
              precision    recall  f1-score   support

      normal       0.71      0.82      0.76       146
        luad       0.86      0.71      0.78       166
        lusc       0.87      0.93      0.90        92

    accuracy                           0.80       404
   macro avg       0.81      0.82      0.81       404
weighted avg       0.81      0.80      0.80       404

[[120  18   8]
 [ 43 118   5]
 [  5   1  86]]
              precision    recall  f1-score   support

      normal       0.77      0.82      0.79       146
        luad       0.87      0.77    