In [None]:
from glob import glob
import json
import pandas as pd
import random
from collections import Counter
from tqdm import tqdm

In [None]:
# Phrases a free-text answer may use to assert a finding, e.g. "consistent with pneumonia".
_AFFIRMATIVE_PHRASES = ('consistent with', 'may have', 'indicative of', 'evidence of')
# Whitespace-split tokens treated as an explicit "no" / "yes".
_NO_TOKENS = {'no', 'no,', 'no.'}
_YES_TOKENS = {'yes', 'yes,', 'yes.'}


def _normalize_pred(raw):
    """Lower-case and strip a model response.

    If the response has no ``.lower()`` (presumably a list of candidate strings —
    TODO confirm against the result files), its first element is used instead.
    """
    try:
        return raw.lower().strip()
    except AttributeError:
        return raw[0].lower().strip()


def _mentions_disease(pred, disease):
    """Return True if ``pred`` contains '<affirmative phrase> <name>' for ``disease``
    or for any of its '/'-separated alternative names."""
    return any(
        f'{phrase} {name}' in pred
        for name in disease.split('/')
        for phrase in _AFFIRMATIVE_PHRASES
    )


def _is_correct(pred, gt, record):
    """Score one normalized prediction against the normalized ground truth.

    Rules, applied in order:
    1. exact match -> correct;
    2. ``gt`` occurs as a substring of ``pred`` -> decided by explicit yes/no tokens
       ('not possible' counts as a correct 'no');
    3. ``pred`` uses an affirmative phrase -> correct iff the phrase names the disease
       asked about in ``record['prompt']`` and negation ('not' as a token) agrees with ``gt``;
    4. otherwise incorrect.
    """
    if pred == gt:
        return True
    tokens = set(pred.split())
    if gt in pred:
        # Substring test: gt 'no' also matches inside 'not', 'none', 'know', ...
        if gt == 'no' and 'not possible' in pred:
            return True
        if tokens & _NO_TOKENS:
            return gt != 'yes'
        if tokens & _YES_TOKENS:
            return gt == 'yes'
        return False
    if any(phrase in pred for phrase in _AFFIRMATIVE_PHRASES):
        # pred is lower-cased, so the disease taken from the prompt must be too.
        disease = record['prompt'].split('this patient have')[-1].rstrip('?').strip().lower()
        if not _mentions_disease(pred, disease):
            return False
        # NOTE(review): only the token 'not' counts as negation, so e.g.
        # "no evidence of X" is scored as an affirmative answer.
        negated = 'not' in tokens
        return (not negated) if gt == 'yes' else negated
    return False


def calculate(fp, answers, index_list=None):
    """Score one model's JSONL result file.

    Parameters
    ----------
    fp : str
        Path to a JSONL file. Each non-blank line is a record read for
        'question_id', 'pred_response', 'gt_response', 'model_id' and, when the
        affirmative-phrase rule applies, 'prompt'.
    answers : pandas.DataFrame
        Golden answers indexed by question_id. It is only used to check that every
        scored question_id is known.
    index_list : iterable, optional
        If given, only records whose question_id is in it are scored.

    Returns
    -------
    model_name : str or None
        Name taken from the last scored record (None if nothing was scored).
        'LLaVA-NeXT' is mapped to a 7b/72b OneVision name depending on whether
        '7b' appears in ``fp``.
    correct : list of bool
        Per-record correctness, in file order.
    res : list of str
        Normalized predictions, in file order.

    Raises
    ------
    KeyError
        If a scored question_id is not in ``answers.index``.
    """
    with open(fp) as f:
        data = [json.loads(line) for line in f if line.strip()]
    # A set keeps the filter O(1) per record instead of scanning the list each time.
    wanted = None if index_list is None else set(index_list)

    correct = []
    res = []
    model_name = None
    for d in data:
        qid = d['question_id']
        if wanted is not None and qid not in wanted:
            continue
        if qid not in answers.index:
            raise KeyError(qid)
        pred = _normalize_pred(d['pred_response'])
        gt = d['gt_response'].lower().strip()
        is_correct = _is_correct(pred, gt, d)
        if d['model_id'] == 'LLaVA-NeXT':
            model_name = 'llava-onevision-qwen2-7b' if '7b' in fp else 'llava-onevision-qwen2-72b'
        else:
            model_name = d['model_id']
        correct.append(is_correct)
        res.append(pred)
    return model_name, correct, res

# Balanced-sample evaluation

In [None]:
# Fixed seed so the balanced sample is reproducible across runs.
random_state = 989
with open('CXR-Reason-Golden.jsonl') as f:
    # Only the first line is used: it is parsed as the golden question table.
    answers = json.loads(f.readline())
answers = pd.DataFrame(answers)
# Disease = text after 'this patient have' in the first conversation turn, minus the trailing '?'.
answers['disease'] = answers.conversations.apply(
    lambda x: x[0]['value'].split('this patient have')[-1].rstrip('?').strip())

# Boolean masks (not iterrows) keep dtypes, run in O(n), and an empty subset still
# has the 'image'/'disease' columns the groupby below needs.
no_findings = answers[answers.question_type == 'no_findings']
all_findings = answers[answers.question_type == 'all_findings']
findings = answers[answers.question_type == 'findings']
findings_anatomical = answers[answers.question_type == 'findings+anatomy']

# Within each question type keep one randomly chosen question per (image, disease).
no_findings_sampled = no_findings.groupby(['image', 'disease']).sample(1, random_state=random_state)
all_findings_sampled = all_findings.groupby(['image', 'disease']).sample(1, random_state=random_state)
findings_sampled = findings.groupby(['image', 'disease']).sample(1, random_state=random_state)
findings_anatomical_sampled = findings_anatomical.groupby(['image', 'disease']).sample(1, random_state=random_state)

sampled_df = pd.concat([no_findings_sampled, all_findings_sampled, findings_sampled, findings_anatomical_sampled])

index_list = sampled_df.question_id.tolist()
answers = answers.set_index('question_id')

In [None]:
# Score every model's golden result file on the balanced sample only.
# Restricting the frame makes its length match the per-sample score lists
# (re-running this cell is safe: filtering again is a no-op).
answers = answers[answers.index.isin(index_list)].copy()

result = {}
response = {}
# sorted(): glob order is filesystem-dependent; keep model order stable.
for fp in sorted(glob('GOLDEN_RESULT/*.jsonl')):
    model_name, correct, res = calculate(fp, answers, index_list)
    if model_name is not None:
        result[model_name] = correct
        response[model_name] = res

# Scores are attached positionally. This assumes each result file lists the questions
# in the same order as CXR-Reason-Golden.jsonl — TODO confirm.
for k, v in result.items():
    assert len(v) == len(answers), f'{k}: {len(v)} scored rows vs {len(answers)} sampled questions'
    answers[f'{k}_correct'] = v
for k, v in response.items():
    answers[f'{k}_response'] = v

# Overall accuracy per model.
for k in result:
    print(k, answers[f'{k}_correct'].eq(True).mean())
print()
question_types = answers.question_type.value_counts().index.tolist()

# Accuracy per model and question type.
for k in result:
    for qt in question_types:
        temp_df = answers[answers.question_type == qt]
        print(k, qt, temp_df[f'{k}_correct'].eq(True).mean())
    print()

#### CheXagent Disease-Stratified Results

In [None]:
# CheXagent accuracy per disease on the balanced sample (0.0 when no answer for a disease is correct).
answers['CheXagent_correct'].eq(True).groupby(answers['disease']).mean()

In [None]:
# CheXagent accuracy per disease (rows) and question type (columns) on the balanced sample.
# A single groupby gives every disease seen in any question type its own row.
# Assigning columns one by one into an empty frame would silently drop diseases
# missing from the first column. NaN means the disease has no questions of that type.
qt_disease_df = (
    answers['CheXagent_correct'].eq(True)
    .groupby([answers['disease'], answers['question_type']])
    .mean()
    .unstack('question_type')
)
qt_disease_df[['no_findings', 'findings', 'findings+anatomy', 'all_findings']]

# Full Golden Data

In [None]:
# Parse every line of the golden JSONL and keep the first parsed value as the question records.
with open('CXR-Reason-Golden.jsonl') as golden_file:
    full_answers = list(map(json.loads, golden_file))[0]

In [None]:
# Tabulate the golden records, tag each with the disease named in its prompt, and index by question_id.
full_answers = (
    pd.DataFrame(full_answers)
    .assign(disease=lambda frame: frame['conversations'].apply(
        lambda turns: turns[0]['value'].split('this patient have')[-1].rstrip('?').strip()))
    .set_index('question_id')
)

In [None]:
# Score every model's golden result file on the full golden set.
result = {}
response = {}
# sorted(): glob order is filesystem-dependent; keep model order stable.
for fp in sorted(glob('GOLDEN_RESULT/*.jsonl')):
    model_name, correct, res = calculate(fp, full_answers)
    if model_name is not None:
        result[model_name] = correct
        response[model_name] = res

# Scores are attached positionally. This assumes each result file lists the questions
# in the same order as CXR-Reason-Golden.jsonl — TODO confirm.
for k, v in result.items():
    assert len(v) == len(full_answers), f'{k}: {len(v)} scored rows vs {len(full_answers)} questions'
    full_answers[f'{k}_correct'] = v
for k, v in response.items():
    full_answers[f'{k}_response'] = v

# Overall accuracy per model.
for k in result:
    print(k, full_answers[f'{k}_correct'].eq(True).mean())
print()
question_types = full_answers.question_type.value_counts().index.tolist()

# Accuracy per model and question type.
for k in result:
    for qt in question_types:
        temp_df = full_answers[full_answers.question_type == qt]
        print(k, qt, temp_df[f'{k}_correct'].eq(True).mean())
    print()

#### CheXagent Disease-Stratified Results

In [None]:
# CheXagent accuracy per disease on the full golden set (0.0 when no answer for a disease is correct).
full_answers['CheXagent_correct'].eq(True).groupby(full_answers['disease']).mean()

In [None]:
# CheXagent accuracy per disease (rows) and question type (columns) on the full golden set.
# A single groupby gives every disease seen in any question type its own row.
# Assigning columns one by one into an empty frame would silently drop diseases
# missing from the first column. NaN means the disease has no questions of that type.
qt_disease_full_df = (
    full_answers['CheXagent_correct'].eq(True)
    .groupby([full_answers['disease'], full_answers['question_type']])
    .mean()
    .unstack('question_type')
)
qt_disease_full_df[['no_findings', 'findings', 'findings+anatomy', 'all_findings']]

#### Specific Findings and Anatomical Structure Result

In [None]:
# CheXagent accuracy per specific question type (0.0 when none are answered correctly).
full_answers['CheXagent_correct'].eq(True).groupby(full_answers['question_type_specific']).mean()

#### Pneumonia Results

In [None]:
# Questions whose disease name contains 'pneumonia' (case-sensitive substring match).
pneumonia = full_answers[full_answers['disease'].str.contains('pneumonia')]
# CheXagent accuracy per specific question type, highest first. The sort is applied to
# the final ratio; sorting before a division is undone by index alignment.
pneumonia['CheXagent_correct'].eq(True).groupby(pneumonia['question_type_specific']).mean().sort_values(ascending=False)

In [None]:
# Accuracy for every specific question type among pneumonia questions; types with no
# correct answers are 0.0 rather than missing, so the .loc below cannot raise KeyError.
pneumonia_result = pneumonia['CheXagent_correct'].eq(True).groupby(pneumonia['question_type_specific']).mean()
# Restrict to the 10 most frequent specific question types, then rank by accuracy.
top_types = pneumonia['question_type_specific'].value_counts().head(10).index
pneumonia_result.loc[top_types].sort_values(ascending=False)

# Silver Dataset

In [None]:
# Parse every line of the silver JSONL and keep the first parsed value as the question records.
with open('CXR-Reason-Silver.jsonl') as silver_file:
    silver_answers = list(map(json.loads, silver_file))[0]

In [None]:
# Tabulate the silver records, tag each with the disease named in its prompt, and index by question_id.
silver_answers = (
    pd.DataFrame(silver_answers)
    .assign(disease=lambda frame: frame['conversations'].apply(
        lambda turns: turns[0]['value'].split('this patient have')[-1].rstrip('?').strip()))
    .set_index('question_id')
)

In [None]:
# Score every model's silver result file on the full silver set.
result = {}
response = {}
# sorted(): glob order is filesystem-dependent; keep model order stable.
for fp in sorted(glob('SILVER_RESULT/*.jsonl')):
    model_name, correct, res = calculate(fp, silver_answers)
    if model_name is not None:
        result[model_name] = correct
        response[model_name] = res

# Scores are attached positionally. This assumes each result file lists the questions
# in the same order as CXR-Reason-Silver.jsonl — TODO confirm.
for k, v in result.items():
    assert len(v) == len(silver_answers), f'{k}: {len(v)} scored rows vs {len(silver_answers)} questions'
    silver_answers[f'{k}_correct'] = v
for k, v in response.items():
    silver_answers[f'{k}_response'] = v

# Overall accuracy per model.
for k in result:
    print(k, silver_answers[f'{k}_correct'].eq(True).mean())
print()
question_types = silver_answers.question_type.value_counts().index.tolist()

# Accuracy per model and question type.
for k in result:
    for qt in question_types:
        temp_df = silver_answers[silver_answers.question_type == qt]
        print(k, qt, temp_df[f'{k}_correct'].eq(True).mean())
    print()