In [None]:
import os
import argparse
import json
import matplotlib.pyplot as plt
import re
import numpy as np
def get_args_parser():
    """Parse the options that configure this evaluation.

    Arguments the parser does not recognise are tolerated and discarded
    (``parse_known_args``).  After parsing, any option whose value is exactly
    the string ``'True'`` or ``'False'`` is replaced by the matching boolean.

    Returns:
        argparse.Namespace: namespace with ``type``, ``root`` and
        ``savepath`` attributes.
    """
    # (flag, default, help) for every supported option; all are plain strings.
    option_specs = (
        ('--type', 'TCGA_BRCA', 'name in TCGAs'),
        ('--root', '../GDC_DATA', 'path to TCGA'),
        ('--savepath', './results/BRCA_dino_scratch/vis', 'path to wsi-text pairs'),
    )

    parser = argparse.ArgumentParser()
    for flag, default, help_text in option_specs:
        parser.add_argument(flag, default=default, type=str, help=help_text)

    namespace, _ignored = parser.parse_known_args()

    # Turn textual booleans into real ones.
    string_to_bool = {'True': True, 'False': False}
    for name, value in list(vars(namespace).items()):
        if isinstance(value, str) and value in string_to_bool:
            setattr(namespace, name, string_to_bool[value])

    return namespace

def clean_report_brca(report):
    """Normalise a free-text report into lower-case, ' . '-separated sentences.

    Processing order:
      1. Newlines become spaces, then double spaces are collapsed (three passes).
      2. The list markers ' 10. ' ... ' 13. ', ' 14.' and ' 1. ' ... ' 9. ' are
         each replaced by a single space.
      3. The text is stripped, lower-cased, given one trailing space and split
         on '. '.
      4. Every sentence has double quotes, backslashes and single quotes
         removed, is stripped and lower-cased, and then loses the punctuation
         matched by the character class below.
      5. The sentences are joined with ' . '.

    Args:
        report (str): raw report text.

    Returns:
        str: the cleaned report.  A report ending in '.' produces a trailing
        ' . ' because the split in step 3 leaves an empty last sentence.
    """
    # Raw string: the previous non-raw literal relied on the invalid escape
    # sequences '\[' and '\]', which raise SyntaxWarning on Python 3.12+.
    # The compiled pattern is character-for-character the same as before.
    # NOTE(review): ':-\[' inside the class is a RANGE (':' through '['), so
    # '<', '=', '>' and '@' are stripped while '-' is kept.  Left unchanged so
    # existing outputs do not shift; confirm whether hyphen removal was meant.
    punctuation = re.compile(r'[#,?;*!^&_+():-\[\]{}]')

    # Order matters only in that it mirrors the original replacement chain.
    # ' 14.' deliberately has no trailing space, exactly as before.
    list_markers = (
        ' 10. ', ' 11. ', ' 12. ', ' 13. ', ' 14.',
        ' 1. ', ' 2. ', ' 3. ', ' 4. ', ' 5. ', ' 6. ', ' 7. ', ' 8. ', ' 9. ',
    )

    text = report.replace('\n', ' ')
    for _ in range(3):
        text = text.replace('  ', ' ')
    for marker in list_markers:
        text = text.replace(marker, ' ')
    sentences = (text.strip().lower() + ' ').split('. ')

    def clean_sentence(sentence):
        # Quotes/backslashes go first, then strip, then punctuation removal;
        # the result is intentionally not stripped a second time.
        unquoted = sentence.replace('"', '').replace('\\', '').replace("'", '')
        return punctuation.sub('', unquoted.strip().lower())

    return ' . '.join(clean_sentence(sentence) for sentence in sentences)
# Progress marker: printed once the statements above it in this cell have run.
print('ready')

def is_idc(text):
    """Return 1 if ``text`` contains the phrase 'ductal carcinoma', else 0."""
    return 1 if 'ductal carcinoma' in text else 0
def is_pr(text):
    """Return 1 if the substring 'positive' occurs in ``text``, else 0."""
    contains_positive = 'positive' in text
    return int(contains_positive)

In [None]:
import nltk
# Fetch the NLTK data packages this notebook relies on.
for _resource in (
    'averaged_perceptron_tagger',
    'punkt',
    'maxent_ne_chunker',
    'words',
    # NLTK 3.9+ replaced the pickled models with differently named packages:
    # sent_tokenize() loads 'punkt_tab' and pos_tag() loads
    # 'averaged_perceptron_tagger_eng'.  Without them the tokenising/tagging
    # calls later in this notebook raise LookupError on current NLTK releases.
    # On older releases these two extra downloads are simply unused.
    'punkt_tab',
    'averaged_perceptron_tagger_eng',
):
    nltk.download(_resource)

In [None]:
# A token qualifies as an entity only if it contains at least one a-z letter.
_HAS_LOWERCASE_ASCII = re.compile('[a-z]')


def _extract_nouns(text):
    """Return, in order and with repeats, the noun-tagged tokens of ``text``
    (POS tag starting with 'NN') that contain at least one a-z letter."""
    nouns = []
    for sentence in nltk.sent_tokenize(text):
        for token, tag in nltk.pos_tag(nltk.word_tokenize(sentence)):
            if tag.startswith('NN') and _HAS_LOWERCASE_ASCII.search(token):
                nouns.append(token)
    return nouns


def entity_match(reports, gt):
    """Score the noun overlap between a generated text and its reference.

    Nouns are extracted from both texts with NLTK.  Precision is the share of
    generated nouns that also occur in the reference, recall is the share of
    reference nouns that also occur in the generated text; repeated nouns are
    counted every time they occur.

    Args:
        reports (str): generated text.
        gt (str): reference text.

    Returns:
        float | bool: ``2 * rc * pr / (rc + pr + 0.00001)``; the small constant
        keeps the division defined when there is no overlap, in which case the
        result is ``0.0``.  Returns ``False`` when either text yields no
        nouns.  Callers must therefore distinguish ``False`` (undefined) from
        ``0.0`` (no overlap) with an identity check, not truthiness.
    """
    entities = _extract_nouns(reports)
    entities_gt = _extract_nouns(gt)

    if not entities or not entities_gt:
        return False

    # Sets make each membership test O(1); the counts still run over the
    # original lists so duplicates weigh exactly as before.
    reference_vocab = set(entities_gt)
    generated_vocab = set(entities)

    pr = sum(1 for e in entities if e in reference_vocab) / len(entities)
    rc = sum(1 for e in entities_gt if e in generated_vocab) / len(entities_gt)

    f = 2*rc*pr/(rc+pr+0.00001)
    return f

In [None]:
from sklearn.metrics import recall_score,f1_score, roc_auc_score, precision_score
from sksurv.metrics import concordance_index_censored
import json
# Walk every 'TCGA*' result file under the save path and collect, per question
# type, the prediction/target pairs needed by the metrics computed afterwards.
args = get_args_parser()
root = os.path.join(args.savepath)

subtype_pred = []      # 1/0 from is_idc() on the model answer
subtype_target = []    # 1/0 from is_idc() on the reference answer
pr_pred = []           # 1/0 from is_pr() on the model answer
pr_target = []         # 1/0 from is_pr() on the reference answer
all_event_times = []   # reference values for 'survival time' questions
all_estimate = []      # predicted values for 'survival time' questions
result = {}
fact_vol = 0           # running sum of entity_match() scores
fact_count = 0         # number of items that contributed to fact_vol

# sorted(): os.listdir order is arbitrary; a fixed order keeps the collected
# lists (and the float summation order) reproducible between runs.
for file in sorted(os.listdir(root)):
    if not file.startswith('TCGA'):
        continue
    file_name = os.path.join(root, file)
    with open(file_name) as f:
        # Each file is iterated as a sequence of records; only the keys
        # 'Question', 'gts' and 'res' are read from a record.
        data = json.load(f)

    for item in data:
        tgt = item['gts']
        predict = item['res']
        question = item['Question'][0]

        # Fact entity reward.  entity_match() returns False when the score is
        # undefined (no nouns) and a float otherwise.  The previous truthiness
        # test also discarded genuine 0.0 scores, which inflated the mean;
        # only the undefined case is skipped now.
        fact = entity_match(predict, tgt)
        if fact is not False:
            fact_vol += fact
            fact_count += 1

        # BRCA subtyping: only references naming one of the two subtypes count.
        if 'logical' in question:
            if not ('ductal carcinoma' in tgt or 'lobular carcinoma' in tgt):
                continue
            subtype_pred.append(is_idc(predict))
            subtype_target.append(is_idc(tgt))

        # Receptor status: the reference must be exactly 'negative'/'positive'.
        if 'receptor' in question:
            if tgt not in ('negative', 'positive'):
                continue
            pr_pred.append(is_pr(predict))
            pr_target.append(is_pr(tgt))

        # Survival time: keep only predictions made purely of decimal digits.
        if 'survival time' in question:
            if not predict.isdecimal():
                continue
            # Explicit numeric parsing instead of eval(): the strings come
            # from files on disk and must never be executed as code.
            all_event_times.append(float(tgt))
            all_estimate.append(int(predict))


# scikit-learn metrics take (y_true, y_pred).  The arguments used to be passed
# as (pred, target), which made the reported recall the true precision and
# vice versa (F1 is symmetric for binary labels, so it was unaffected).
r = recall_score(subtype_target, subtype_pred)
f1 = f1_score(subtype_target, subtype_pred)
p = precision_score(subtype_target, subtype_pred)

pr_r = recall_score(pr_target, pr_pred)
pr_p = precision_score(pr_target, pr_pred)
pr_f1 = f1_score(pr_target, pr_pred)
print(subtype_pred)
print(subtype_target)

# Every sample is flagged as an observed event ([True] * n).  Element [0] of
# the returned tuple is the concordance index.
# NOTE(review): concordance_index_censored ranks by risk score; the result is
# reported as 1 - c, presumably because the estimates are predicted survival
# times (higher = longer survival) rather than risks. Confirm.
cindex = 1 - concordance_index_censored(
    [True] * len(all_estimate), all_event_times, all_estimate, tied_tol=1e-08
)[0]

# Mean entity-match score; NaN rather than ZeroDivisionError when no item
# contributed a score.
fact_mean = fact_vol / fact_count if fact_count else float('nan')

result.update({
    'subtype_r': r, 'subtype_p': p, 'subtype_f1': f1,
    'pr_r': pr_r, 'pr_p': pr_p, 'pr_f1': pr_f1,
    'fact': fact_mean,
})
print(result)
print(f'cindex:{cindex}')

