In [None]:
from utils import f1_score, micro_f1_score

In [None]:
import re

def preprocess_text(input_text, text, delimiter=', ', is_pred=True):
    """Turn a raw answer string into a list of cleaned term strings.

    Args:
        input_text: The source sentence given to the model. Only read when
            ``is_pred`` is true, to decide whether an "x and y" phrase occurs
            verbatim in the input (and therefore must not be split).
        text: The string to parse (a model prediction or a reference list).
        delimiter: Separator between terms in ``text``.
        is_pred: When true, apply prediction-only clean-up: drop trailing
            chit-chat / explanation sections, cut everything up to an answer
            keyword, split "a and b" items, and collapse runaway repetition.

    Returns:
        A list of whitespace-stripped term strings.
    """
    if is_pred:
        lowered = text.lower()
        # Drop a closing "Let me know ...!" sentence.
        if "let me know" in lowered and lowered.endswith(("! \n", "!\n")):
            text = text[:lowered.index("let me know")]

        # Drop trailing rationale sections.
        if "## reason:" in text.lower():
            text = text[:text.lower().find("## reason:")].strip()

        if "**explanation:**" in text.lower():
            text = text[:text.lower().find("**explanation:**")].strip()

        text = text.strip('\n')

        lowered_text = text.lower()
        lowered_input_text = input_text.lower()

        keywords1 = ['answer:', 'categorize terms:', 'output:', 'are:', 'is:', 'terms:', 'domain:', 'terms**', 'output**', '**output:**']
        keywords2 = ['terms are', 'term is', 'should be:', 'would be:', 'answer is:', 'answer is ', 'output is:', 'the annotator should write']
        # Keep only what follows the first keyword (in list order) that occurs.
        for keyword in keywords1 + keywords2:
            if keyword in lowered_text:
                # NOTE(review): this tests whether the *whole output* is a
                # substring of the input, not whether the keyword is; confirm
                # that this is the intended condition.
                if keyword in keywords2 and lowered_text not in lowered_input_text:
                    text = text[lowered_text.find(keyword)+len(keyword):].strip()
                else:
                    text = text[lowered_text.rfind(keyword)+len(keyword):].strip()
                break

        # Replace any leading run of non-word characters with a single space.
        pattern1 = r'^[^\w]+'
        text = re.sub(pattern1, ' ', text)

    # Blank out square brackets and quote characters.
    pattern2 = r'[\[\]\"\']'
    text = re.sub(pattern2, ' ', text)

    domain_words = []
    for domain_word in text.split(delimiter):
        stripped_word = domain_word.strip()
        # NOTE(review): when a word starts or ends with '.', *every* dot in it
        # is blanked, including interior ones ("U.S." -> "U S").
        if stripped_word.startswith('.') or stripped_word.endswith('.'):
            domain_word = domain_word.replace('.', ' ')
        domain_words.append(domain_word)

    if is_pred:
        # Split "a and b" items. Build new lists instead of removing from the
        # list being iterated, which used to skip the item after each removal.
        kept_words = []
        split_words = []
        for domain_word in domain_words:
            lowered_domain_word = domain_word.lower()
            needs_split = lowered_domain_word.startswith('and ') or (
                ' and ' in lowered_domain_word
                and lowered_domain_word.strip() not in lowered_input_text
            )
            if not needs_split:
                kept_words.append(domain_word)
                continue
            for part in domain_word.split('and '):
                part = part.strip()
                if part != '':
                    split_words.append(part)

        # Untouched words first, then the split parts (same order as before).
        domain_words = kept_words + split_words

        # Collapse degenerate outputs that repeat the same answers many times.
        cnt_threshold = 50
        occurrences = {}
        for domain_word in domain_words:
            occurrences[domain_word] = occurrences.get(domain_word, 0) + 1
        # Number of (i, j) pairs with i < j holding equal words.
        duplicate_pairs = sum(n * (n - 1) // 2 for n in occurrences.values())
        if duplicate_pairs > cnt_threshold:
            # dict.fromkeys keeps first-seen order, so the result is deterministic.
            domain_words = list(dict.fromkeys(domain_words))
            print('Repetitive answers are removed:', domain_words)

    return [domain_word.strip() for domain_word in domain_words]

In [None]:
import os
import pandas as pd
import json
from glob import glob

# Collect every run under `master_path` into one `reports` DataFrame.
# Expected layout: <master_path>/<dataset>/<retrieval_style>/<NN_NN_NN_NN_NN>/
# with `report/report_overall.csv` (parsed as JSON) and `config.json` inside.
master_path = 'outputs/test'
reports = None

empty_paths = []
exclude_recalcuation_models = ['bart', 'roberta', 'mbart']
# Manually selected paths
unwanted_paths = []
really_calculate_f1_score = True
really_store_unwnated_paths = False

select_columns = ['model_name', 'dataset_name', 'num_shots', 'retrieval_method', 
                  'total_precision', 'total_recall', 
                  'micro_f1_score', 
                  'seed', 'individual_report_path',
                  'f1_score', 'precision', 'preprocessed_refs']
# Columns that later analysis cells in this notebook read from `reports`.
# They are kept when present so those cells do not fail with a KeyError.
downstream_columns = ['retrieved_result', 'temperature', 'do_sample',
                      'prompt_style', 'default_prompt_style', 'total_em_score']

report_frames = []
for dataset_name in os.listdir(master_path):
    dataset_path = os.path.join(master_path, dataset_name)
    if not os.path.isdir(dataset_path):
        # Stray files (e.g. .DS_Store) would make os.listdir raise.
        continue
    for retrieval_style in os.listdir(dataset_path):
        retrieval_style_path = os.path.join(dataset_path, retrieval_style, "[0-9][0-9]_[0-9][0-9]_[0-9][0-9]_[0-9][0-9]_[0-9][0-9]")
        for output in glob(retrieval_style_path):
            report_path = os.path.join(output, 'report')
            individual_report_path = os.path.join(report_path, 'report_0.csv')
            overall_report_path = os.path.join(report_path, 'report_overall.csv')
            config_path = os.path.join(output, 'config.json')
            
            try:
                with open(overall_report_path, 'r') as f:
                    overall_report = json.load(f)
                
                with open(config_path, 'r') as f:
                    config = json.load(f)
            except (OSError, ValueError) as e:
                # Missing/unreadable file or malformed JSON: remember the run and move on.
                print(f"Error: {e}, overall_report_path: {overall_report_path}, config_path: {config_path}")
                empty_paths.append(output)
                continue  
            ############## Manually selected paths ##############
            if really_store_unwnated_paths and 'retrieval_method' in config and config.get('prompt_style') == 'default' and config['retrieval_method'] == 'default' and 'mbart' not in config['model_name'] and 'roberta' not in config['model_name']:
                unwanted_paths.append(output)
            ####################################################
                
            # Keep only the last path component of the model name.
            config['model_name'] = os.path.basename(config['model_name'])
            overall_report = pd.DataFrame([overall_report])
            config = pd.DataFrame([config])
            individual_report_path_df = pd.DataFrame({"individual_report_path": individual_report_path}, index=[0])
            # One row per run: overall metrics + config + path to the per-example report.
            report_frames.append(pd.concat([overall_report, config, individual_report_path_df], axis=1))

if not report_frames:
    raise FileNotFoundError(f"No readable reports found under {master_path!r}")

# Concatenate once instead of re-copying the accumulated frame for every run.
reports = pd.concat(report_frames, axis=0)

# list.remove() returns None, so the subset has to be built explicitly;
# the report path is unique per run and must not take part in de-duplication.
duplicate_columns = [column for column in select_columns if column != 'individual_report_path']
kept_columns = select_columns + [column for column in downstream_columns
                                 if column in reports.columns and column not in select_columns]
reports = reports[kept_columns]
reports = reports.drop_duplicates(subset=duplicate_columns).reset_index(drop=True)
reports.head(5)

In [None]:
from ast import literal_eval
from scipy.stats import spearmanr
import numpy as np

# For each retrieval method, measure how strongly per-example F1 correlates
# (Spearman) with the share of reference terms that also appear among the
# labels of the retrieved demonstrations, split into in-/cross-domain datasets.
retrieval_method_to_model_name = {  
                                    'fastkassim': '',
                                    'default':'bge-large-en-v1.5', 
                                    'default_w_ins': 'bge-en-icl', 
                                    'bm25': '',
                                    'random': '',
                                    }
retrieval_method_dfs = []
model_names = ['gemma-2-9b-it', 'Meta-Llama-3.1-8B-Instruct', 'Mistral-Nemo-Instruct-2407']
dataset_names = ['ACTER', 'ACL-RD', 'BCGM']

# Keep only 5-shot runs of the selected models, datasets and seeds
retrieval_method_df = reports[reports['model_name'].isin(model_names)]
retrieval_method_df = retrieval_method_df[retrieval_method_df['dataset_name'].isin(dataset_names)]
retrieval_method_df = retrieval_method_df[retrieval_method_df['num_shots']!=0.0]
retrieval_method_df = retrieval_method_df[retrieval_method_df['num_shots'].isin([5])]   
retrieval_method_df = retrieval_method_df[retrieval_method_df['seed'].isin([42, 1000, 2000, 3000])]   

in_domain_datasets = ['ACL-RD','BCGM']
cross_domain_datasets = ['ACTER']

for retrieval_method in retrieval_method_to_model_name:
    cross_domain = {'correlation': [], 'overlap_ratio': []}
    in_domain = {'correlation': [], 'overlap_ratio': []}
    for dataset_name in dataset_names:
        corr_list = []
        label_frequency_list = []
        ds_df = retrieval_method_df[retrieval_method_df['dataset_name'] == dataset_name]
        for model_name in model_names:
            df1 = ds_df[ds_df['retrieval_method'] == retrieval_method]
            df2 = df1[df1['model_name']==model_name]   
            label_frequency = []
            f1_scores = []
        
            for _, row in df2.iterrows():
                preprocessed_refs = row['preprocessed_refs']
                retrieved_result = row['retrieved_result']
                # Skip runs without retrieval output or without serialized references.
                if not isinstance(retrieved_result, list) or not isinstance(preprocessed_refs, str):
                    continue
                
                preprocessed_refs = literal_eval(preprocessed_refs)
                # Named so that it does not rebind the imported `f1_score` function.
                per_example_f1 = [float(score) for score in row['f1_score'].split(',')]
                for rr, pr, fs in zip(retrieved_result, preprocessed_refs, per_example_f1):
                    if pr == ['No term']:
                        continue
                    
                    # All (lower-cased) labels attached to the retrieved demonstrations.
                    label_set = set()
                    for r in rr:
                        label = r['label']
                        if len(label) == 0:
                            label = ['No term']

                        for l in label:
                            label_set.add(l.lower())

                    pr = {p.lower() for p in pr}
                    if not pr:
                        # No reference terms: the overlap ratio is undefined.
                        continue

                    # Share of reference terms covered by the demonstrations' labels.
                    label_frequency.append(len(label_set & pr) / len(pr))
                    f1_scores.append(fs)

            # The emptiness test must come first: np.all() of an empty array is
            # True, which previously made this branch unreachable.
            if len(label_frequency) == 0 or len(f1_scores) == 0:
                corr, p_value = np.nan, np.nan
                print('Correlation is nan')
            elif np.all(np.array(label_frequency)==0.0) or np.all(np.array(f1_scores)==0.0):
                corr, p_value = 0, 0
            else:
                corr, p_value = spearmanr(f1_scores, label_frequency)
                
                # NOTE(review): only this branch contributes to the averages;
                # the all-zero and empty cases above are left out. Confirm intended.
                corr_list.append(corr)
                label_frequency_list.extend(label_frequency)
        
        if dataset_name in in_domain_datasets:
            in_domain['correlation'].extend(corr_list)
            in_domain['overlap_ratio'].extend(label_frequency_list)
        elif dataset_name in cross_domain_datasets:
            cross_domain['correlation'].extend(corr_list)
            cross_domain['overlap_ratio'].extend(label_frequency_list)
            
    print(f"Cross-domain method: {retrieval_method}, correlation: {np.mean(cross_domain['correlation'])*100}, overlap_ratio:{np.mean(cross_domain['overlap_ratio'])*100}")
    print(f"In-domain method: {retrieval_method}, correlation: {np.mean(in_domain['correlation'])*100}, overlap_ratio:{np.mean(in_domain['overlap_ratio'])*100}")
    print()

# Comparison between different retrieval methods

In [None]:
# Model names iterated by the "Baseline performance" cell further down,
# where each one is matched against the `model_name` column of `reports`.
plm_model_names = ['roberta-large', 'bart-large']

In [None]:
import seaborn as sns
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt

# Plot micro F1 against the number of demonstrations for every retrieval
# method (one figure per model/dataset) and print best-score / 10-shot summaries.
figure_output_dir = 'figures'
colors =  {'zero_shot': 'yellow', 'default':'blue', 'default_w_ins':'purple', 'random':'green', 'fastkassim':'red', 'bm25':'orange'}
dataset_name_to_title = {'ACTER':'ACTER', 'ACL-RD':'ACL RD-TEC 2.0', 'BCGM':'BCGM'}
retrieval_method_to_legend = {'zero_shot': 'Zero-shot', 'default':'BGE-large-en', 'default_w_ins':'BGE-en-icl', 'random':'Random', 'fastkassim':'FastKASSIM', 'bm25': 'BM25'}
default_criteria = [list(reports['temperature']==0.01), 
                    list(reports['do_sample']==False), 
                    list(reports['seed'].isin([42, 1000, 2000, 3000]))]
best_syntatic_score = {}
best_semantic_score = {}

correlations_per_method = {'default':[], 'random':[], 'fastkassim':[], 'bm25':[], 'default_w_ins':[]}


def _rows_matching(criteria):
    """Return the rows of `reports` for which every boolean list in `criteria` is true."""
    keep = [all(flags) for flags in zip(*criteria)]
    return reports[keep]


def _best_run(method_reports):
    """Return (f1, precision, recall, num_shots) of the first run with the highest micro F1."""
    best_score = max(method_reports['micro_f1_score'])
    best_rows = method_reports[method_reports['micro_f1_score'] == best_score]
    return (best_score,
            best_rows['total_precision'].values[0],
            best_rows['total_recall'].values[0],
            best_rows['num_shots'].values[0])


def _ten_shot_run(method_reports):
    """Return (EM * 100, f1, precision, recall) of the first run with num_shots == 10."""
    ten_shot = method_reports[method_reports['num_shots'] == 10]
    return (ten_shot['total_em_score'].values[0] * 100,
            ten_shot['micro_f1_score'].values[0],
            ten_shot['total_precision'].values[0],
            ten_shot['total_recall'].values[0])


for model_name in ['gemma-2-9b-it', 'Meta-Llama-3.1-8B-Instruct', 'Mistral-Nemo-Instruct-2407']:
    for dataset_name in dataset_name_to_title.keys():
        for prompt_style in ['default']:    
            # Filters shared by every retrieval method (few-shot runs only).
            identical_criteria = default_criteria + [list(reports['dataset_name']==dataset_name), list(reports['prompt_style'] == prompt_style), list(reports['model_name']==model_name)]
            identical_criteria = identical_criteria + [list(reports['num_shots']!=0)]

            # Original author's note: cases where the default prompt style is
            # not 0 should also exist.
            default_reports = _rows_matching(identical_criteria + [list(reports['retrieval_method']=='default'), list(reports['seed']==42), list(reports['default_prompt_style']==0)])
            fastkassim_reports = _rows_matching(identical_criteria + [list(reports['retrieval_method']=='fastkassim'), list(reports['default_prompt_style']==0)])
            bm25_reports = _rows_matching(identical_criteria + [list(reports['retrieval_method']=='bm25')])

            for _, r in bm25_reports.iterrows():
                print(f"Dataset: {dataset_name}, Model: {model_name}, Retrieval method: bm25")
                print(f"Num Shots: {r['num_shots']}, Precision: {r['total_precision']} Recall: {r['total_recall']} F1 score: {r['micro_f1_score']}")

            random_reports = _rows_matching(identical_criteria + [list(reports['retrieval_method']=='random'), list(reports['seed'].isin([1000, 3000]))])
            dwi_reports = _rows_matching(identical_criteria + [list(reports['retrieval_method']=='default_w_ins')])

            default_reports = default_reports.sort_values('num_shots')
            random_reports = random_reports.sort_values('num_shots')
            bm25_reports = bm25_reports.sort_values('num_shots')
            fastkassim_reports = fastkassim_reports.sort_values('num_shots')
            dwi_reports = dwi_reports.sort_values('num_shots')

            # Per-shot means over the random seeds. Computed up front so the
            # summary below never reuses values from a previous iteration.
            random_mean_scores = random_reports.groupby('num_shots')['micro_f1_score'].mean()
            ramdom_prec_mean_scores = random_reports.groupby('num_shots')['total_precision'].mean()
            random_recall_mean_scores = random_reports.groupby('num_shots')['total_recall'].mean()

            plt.title(f"{model_name} performance on {dataset_name_to_title[dataset_name]}", fontweight='bold')
            plt.xlabel('Number of Demonstrations')
            plt.ylabel('Micro F1 Score')
            plt.xticks(default_reports['num_shots'])
            for selected_reports in [default_reports, dwi_reports, random_reports, fastkassim_reports, bm25_reports]:
                try:
                    # Raises IndexError for an empty selection, which skips the method.
                    _retrieval_method = selected_reports['retrieval_method'].iloc[0]
                    color = colors[_retrieval_method]
                    corr = np.corrcoef(selected_reports['num_shots'], selected_reports['micro_f1_score'])[0][1]
                    if selected_reports is random_reports:
                        correlations_per_method[_retrieval_method].append(corr)

                        sns.lineplot(x='num_shots', y='micro_f1_score', data=selected_reports, ci=95, label='Random', color=color)
                        # Use the random runs' own shot counts as x values.
                        plt.scatter(random_mean_scores.index, random_mean_scores, color=color)
                    else:
                        figure_label = retrieval_method_to_legend[_retrieval_method]
                        plt.plot(selected_reports['num_shots'], selected_reports['micro_f1_score'], label=figure_label, color=color)
                        plt.scatter(selected_reports['num_shots'], selected_reports['micro_f1_score'], color=color)
                        correlations_per_method[_retrieval_method].append(corr)
                        print(f"{model_name} correlation on {dataset_name_to_title[dataset_name]} for {figure_label} is {corr}")
                except Exception as e:
                    # Best effort: a method without data must not abort the figure.
                    print(f"Error: {e}")
                    continue
            plt.legend()
            # Save the figure
            os.makedirs(figure_output_dir, exist_ok=True)
            plt.savefig(os.path.join(figure_output_dir, f"{model_name}_{dataset_name}_{prompt_style}.png"))
            plt.show()

            default_best_score, default_best_precision, default_best_recall, num_shot_default_best_score = _best_run(default_reports)
            default_10_em, default_10_score, default_10_precision, default_10_recall = _ten_shot_run(default_reports)

            # Random baseline: best mean F1 over seeds, with precision/recall
            # taken at that same shot count.
            num_shot_random_best_score = random_mean_scores.idxmax()
            random_best_score = random_mean_scores.loc[num_shot_random_best_score]
            random_best_precision = ramdom_prec_mean_scores.loc[num_shot_random_best_score]
            random_best_recall = random_recall_mean_scores.loc[num_shot_random_best_score]

            random_10_report = random_reports[random_reports['num_shots'] == 10]
            random_10_em = random_10_report['total_em_score'].mean()*100
            random_10_score = random_10_report['micro_f1_score'].mean()
            random_10_precision = random_10_report['total_precision'].mean()
            random_10_recall = random_10_report['total_recall'].mean()

            fastkassim_best_score, fastkassim_best_precision, fastkassim_best_recall, num_shot_fastkassim_best_score = _best_run(fastkassim_reports)
            fastkassim_10_em, fastkassim_10_score, fastkassim_10_precision, fastkassim_10_recall = _ten_shot_run(fastkassim_reports)

            dwi_best_score, dwi_best_precision, dwi_best_recall, num_shot_dwi_best_score = _best_run(dwi_reports)
            dwi_10_em, dwi_10_score, dwi_10_precision, dwi_10_recall = _ten_shot_run(dwi_reports)

            bm25_best_score, bm25_best_precision, bm25_best_recall, num_shot_bm25_best_score = _best_run(bm25_reports)
            bm25_10_em, bm25_10_score, bm25_10_precision, bm25_10_recall = _ten_shot_run(bm25_reports)

            if dataset_name in ['ACTER', 'GENIA_to_ACL-RD', 'ACL-RD_to_GENIA']:
                best_syntatic_score[model_name] = fastkassim_best_score
            elif dataset_name in ['ACL-RD', 'GENIA']:
                best_semantic_score[model_name] = default_best_score

            print("#################################### REPORTS ####################################")
            print(f"Model: {model_name}, Dataset: {dataset_name}, Prompt style: {prompt_style}")
            print("Default best score - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}".format(default_best_precision, default_best_recall, default_best_score), f"at {num_shot_default_best_score} shots")
            print('Default with Instruction best score - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}'.format(dwi_best_precision, dwi_best_recall, dwi_best_score), f"at {num_shot_dwi_best_score} shots")
            print('BM25 best score - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}'.format(bm25_best_precision, bm25_best_recall, bm25_best_score), f"at {num_shot_bm25_best_score} shots")
            print('Random best score - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}'.format(random_best_precision, random_best_recall, random_best_score), f"at {num_shot_random_best_score} shots")
            print('Fastkassim best score - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}'.format(fastkassim_best_precision, fastkassim_best_recall, fastkassim_best_score), f"at {num_shot_fastkassim_best_score} shots")
            print()

            print("Default 10 shots - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}, EM: {:.1f}".format(default_10_precision, default_10_recall, default_10_score, default_10_em))
            print('Default with Instruction 10 shots - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}, EM: {:.1f}'.format(dwi_10_precision, dwi_10_recall, dwi_10_score, dwi_10_em))
            print('BM25 10 shots - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}, EM: {:.1f}'.format(bm25_10_precision, bm25_10_recall, bm25_10_score, bm25_10_em))
            print('Random 10 shots - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}, EM: {:.1f}'.format(random_10_precision, random_10_recall, random_10_score, random_10_em))
            print('Fastkassim 10 shots - precision: {:.1f}, recall: {:.1f}, f1 score: {:.1f}, EM: {:.1f}'.format(fastkassim_10_precision, fastkassim_10_recall, fastkassim_10_score, fastkassim_10_em))

            # 95% t-interval of the 10-shot random baseline. Each metric uses its
            # own standard error, and the group is picked by shot count rather
            # than by comparing float means.
            confidence_level = 0.95
            if len(random_10_report) > 0:
                degrees_of_freedom = len(random_10_report) - 1
                for metric_label, metric_column in (('F1', 'micro_f1_score'), ('Precision', 'total_precision'), ('Recall', 'total_recall')):
                    metric_values = random_10_report[metric_column].values
                    metric_mean = metric_values.mean()
                    confidence_interval = stats.t.interval(confidence_level, degrees_of_freedom, metric_mean, stats.sem(metric_values))
                    print('Confidence interval for {} random score: {:.1f}, inverval: {:.1f} to {:.1f}, margin of error: {:.1f}'.format(metric_label, metric_mean, confidence_interval[0], confidence_interval[1], (confidence_interval[1]-metric_mean)))
            print("#####################################################################################\n")

In [None]:
# Baseline performance: for each PLM baseline and dataset, report the run with
# the highest micro F1 and remember that score per model.
best_plm_score_acter = {}
best_plm_score_acl_rd = {}
for model_name in plm_model_names:
    for dataset_name in ['ACTER', 'ACL-RD']:
        criteria = [list(reports['dataset_name']==dataset_name), list(reports['model_name']==model_name)]
        # A row is kept only when every criterion holds for it.
        meets_criteria = [all(i) for i in zip(*criteria)]
        plm_reports = reports[meets_criteria]

        best_plm_score = plm_reports['micro_f1_score'].max()
        # Rows achieving the best score; the first one supplies precision/recall.
        best_rows = plm_reports[plm_reports['micro_f1_score'] == best_plm_score]
        best_prec_score = best_rows['total_precision'].values[0]
        best_recall_score = best_rows['total_recall'].values[0]

        print('Datset name: {}, Model name:{}, micro f1 score: {:.1f}, precision: {:.1f}, recall: {:.1f}'.format(dataset_name, model_name, best_plm_score*100, best_prec_score*100, best_recall_score*100))
        if dataset_name == 'ACL-RD':
            best_plm_score_acl_rd[model_name] = best_plm_score
        elif dataset_name == 'ACTER':
            best_plm_score_acter[model_name] = best_plm_score

# Dataset statistics

In [None]:
from datasets import load_from_disk

def _print_split_statistics(dataset_title, dataset_dict):
    """Print, per split, the mean number of space-separated words and of labels per row."""
    for split_name in dataset_dict:
        ds = dataset_dict[split_name]
        all_num_words = 0
        all_num_labels = 0
        for row in ds:
            all_num_words += len(row['text'].split(' '))
            all_num_labels += len(row['label'])
        print(f'Short Dataset name: {dataset_title}, Dataset type: {split_name}, Avg words: {all_num_words/len(ds)}, Avg labels: {all_num_labels/len(ds)}')


# Same statistics for each dataset saved on disk; a blank line separates them.
acter_short_ds = load_from_disk('dataset/ACTER/huggingface')
_print_split_statistics('ACTER', acter_short_ds)
print()

acl_rd_short_ds = load_from_disk('dataset/ACL-RD/huggingface')
_print_split_statistics('ACL-RD', acl_rd_short_ds)
print()

bcgm_short_ds = load_from_disk('dataset/BCGM/huggingface')
_print_split_statistics('BCGM', bcgm_short_ds)
print()