In [1]:
import pandas as pd
import re
import json
from collections import defaultdict, Counter
# import matplotlib.pyplot as plt

In [2]:
def load_medical_dict(file_path):
    """Load a JSON medical dictionary and flatten it into a set of terms.

    The file must contain a JSON object. Its keys are ignored; each value is
    iterated and every item found is collected into a single set, so terms
    that appear under several keys are kept once.

    NOTE(review): a function with this same name is defined again in the
    following cell; when the cells run in order, that later definition is the
    one in effect.

    Args:
        file_path: Path to the JSON file.

    Returns:
        set: Every term found across the object's values.
    """
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so non-ASCII terms load the same way on every machine.
    with open(file_path, 'r', encoding='utf-8') as file:
        medical_dict = json.load(file)
    return {term for terms in medical_dict.values() for term in terms}

In [3]:
def load_medical_dict(file_path):
    """Load a JSON medical dictionary and flatten it into a set of terms.

    The file must contain a JSON object. Its keys are ignored; each value is
    iterated and every item found is collected into a single set, so terms
    that appear under several keys are kept once.

    Args:
        file_path: Path to the JSON file.

    Returns:
        set: Every term found across the object's values.
    """
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so non-ASCII terms load the same way on every machine.
    with open(file_path, 'r', encoding='utf-8') as file:
        medical_dict = json.load(file)
    return {term for terms in medical_dict.values() for term in terms}

def preprocess(text):
    """Lower-case ``text`` and strip every character that is neither a word
    character nor whitespace.

    Whitespace is left exactly as it was, so removing punctuation can fuse
    the pieces around it (``"type-2"`` becomes ``"type2"``).
    """
    lowered = text.lower()
    punctuation = re.compile(r'[^\w\s]')
    return punctuation.sub('', lowered)

def find_incorrect_terms(trans, refs, medical_terms):
    """Return the medical terms found in ``refs`` that are missing from ``trans``.

    Both texts are normalised with ``preprocess`` and split on whitespace, so
    the comparison works on single tokens: an entry of ``medical_terms`` can
    only match if it equals one whole token of the normalised reference.

    Args:
        trans: Transcript text to check.
        refs: Reference text.
        medical_terms: Collection of terms to look for.

    Returns:
        set: Tokens of ``refs`` that are in ``medical_terms`` but do not occur
        among the tokens of ``trans``.
    """
    transcript_tokens = set(preprocess(trans).split())
    reference_tokens = set(preprocess(refs).split())
    expected_medical = reference_tokens.intersection(medical_terms)
    return expected_medical.difference(transcript_tokens)

def check_corrections(incorrect_terms, model_corrected):
    """Return the subset of ``incorrect_terms`` that shows up in ``model_corrected``.

    ``model_corrected`` is normalised with ``preprocess`` and split on
    whitespace; a term counts as corrected when it equals one of those tokens.

    Args:
        incorrect_terms: Set of terms that were missing from the transcript.
        model_corrected: Text produced by the model.

    Returns:
        set: The terms of ``incorrect_terms`` present among the tokens.
    """
    corrected_tokens = preprocess(model_corrected).split()
    return incorrect_terms.intersection(corrected_tokens)

In [33]:
def get_med_stats(data_path, all_medical_terms):
    """Print the total medical-term counts of the ``refs`` and ``trans`` columns.

    Loads the dataset at ``data_path``, converts its ``'test'`` split to a
    DataFrame, counts medical terms per row with ``count_medical_terms`` and
    prints the column totals. Nothing is returned.

    Args:
        data_path: Passed unchanged to ``load_dataset``.
        all_medical_terms: Terms forwarded to ``count_medical_terms``.
    """
    # NOTE(review): `load_dataset` is not imported in the visible cells --
    # presumably `datasets.load_dataset`; confirm and add the import.
    test_frame = load_dataset(data_path)['test'].to_pandas()

    def _term_count(text):
        return count_medical_terms(text, all_medical_terms)

    for column in ('refs', 'trans'):
        test_frame['med_count_' + column] = test_frame[column].apply(_term_count)

    total_medical_terms_count_refs = test_frame['med_count_refs'].sum()
    total_medical_terms_count_trans = test_frame['med_count_trans'].sum()

    print(f"Total count of all medical terms in refs: {total_medical_terms_count_refs}")
    print(f"Total count of all medical terms in trans: {total_medical_terms_count_trans}")

In [42]:
def count_medical_terms(text, medical_terms):
    """Count whole-term occurrences of every medical term in ``text``.

    ``text`` is normalised with ``preprocess_text`` and each term is then
    searched for as a complete term: a hit must not be directly preceded or
    followed by a word character. This keeps short terms from being counted
    inside longer words (for example "flu" inside "influenza"), while
    multi-word terms still match. Terms are compared as given, so they need
    to be in the same case as the normalised text.

    Args:
        text: Text to scan.
        medical_terms: Iterable of terms to look for.

    Returns:
        int: Sum over all terms of their non-overlapping occurrences.
    """
    text = preprocess_text(text)
    total = 0
    for term in medical_terms:
        # An empty term would match at every position; a plain substring test
        # cheaply rules out the (usually vast) majority of absent terms
        # before any regular expression is built.
        if not term or term not in text:
            continue
        # Lookarounds instead of \b so terms that start or end with a
        # non-word character are still delimited correctly.
        pattern = r'(?<!\w)' + re.escape(term) + r'(?!\w)'
        total += len(re.findall(pattern, text))
    return total

In [45]:
def preprocess_text(text):
    """Return ``text`` converted to lower case; nothing else is changed."""
    lowered = text.lower()
    return lowered


In [46]:
def eval_med_terms(med_dict, model_outs, output_path):
    """Measure how many missed medical terms a model correction restores.

    Reads the CSV at ``model_outs`` (the columns ``trans``, ``refs`` and
    ``model_corrected`` are used). For every row it finds the medical terms
    present in ``refs`` but absent from ``trans`` and checks which of them
    reappear in ``model_corrected``. The totals over all rows, together with
    the medical-term counts of ``refs`` and ``trans``, are written to
    ``output_path`` as plain text.

    Args:
        med_dict: Path to the JSON medical dictionary (see ``load_medical_dict``).
        model_outs: Path to the CSV file with the model outputs.
        output_path: Path of the text report; an existing file is overwritten.

    Returns:
        None. The result is the report written to ``output_path``.
    """
    medical_terms = load_medical_dict(med_dict)

    # Rows with a missing value in any column are discarded before scoring.
    df = pd.read_csv(model_outs).dropna()

    total_incorrect = 0
    total_corrected = 0
    for trans, refs, model_corrected in zip(df['trans'], df['refs'], df['model_corrected']):
        incorrect_terms = find_incorrect_terms(trans, refs, medical_terms)
        corrected_terms = check_corrections(incorrect_terms, model_corrected)
        total_incorrect += len(incorrect_terms)
        total_corrected += len(corrected_terms)

    # Derived once from the final totals so it is always defined, even when
    # no row survives dropna() and the loop body never runs.
    if total_incorrect > 0:
        improvement_percentage = (total_corrected / total_incorrect) * 100
    else:
        improvement_percentage = 0

    total_medical_terms_count_refs = df['refs'].apply(
        lambda x: count_medical_terms(x, medical_terms)).sum()
    total_medical_terms_count_trans = df['trans'].apply(
        lambda x: count_medical_terms(x, medical_terms)).sum()

    with open(output_path, 'w', encoding='utf-8') as text_file:
        # The report's own path goes on a line of its own; without the
        # newline it would run straight into the first statistic.
        text_file.write('{}\n'.format(output_path))
        text_file.write('Total count of all medical terms in refs : {}\n'.format(total_medical_terms_count_refs))
        text_file.write('Total count of all medical terms in trans: {}\n'.format(total_medical_terms_count_trans))
        text_file.write('Total incorrect/missing medical terms: {}\n'.format(total_incorrect))
        text_file.write('Total corrected terms in model_corrected : {}\n'.format(total_corrected))
        text_file.write('Improvement Percentage: {}\n'.format(improvement_percentage))

In [47]:
# Score every CSV under n-shot/gpt3-5/gcd (the file names suggest zero-, one-,
# two- and few-shot runs) and write each report next to its input.
for _model_outs, _report_path in (
    ('n-shot/gpt3-5/gcd/zero_gcd.csv', 'n-shot/gpt3-5/gcd/zero_gcd_med.txt'),
    ('n-shot/gpt3-5/gcd/one_gcd.csv', 'n-shot/gpt3-5/gcd/one_gcd_med.txt'),
    ('n-shot/gpt3-5/gcd/two_gcd.csv', 'n-shot/gpt3-5/gcd/two_gcd_med.txt'),
    ('n-shot/gpt3-5/gcd/few_gcd.csv', 'n-shot/gpt3-5/gcd/few_gcd_med.txt'),
):
    eval_med_terms('medical_terms.json', _model_outs, _report_path)

In [48]:
# Score every CSV under n-shot/gpt3-5/babylon (the file names suggest zero-,
# one-, two- and few-shot runs) and write each report next to its input.
for _model_outs, _report_path in (
    ('n-shot/gpt3-5/babylon/zero_babylon.csv', 'n-shot/gpt3-5/babylon/zero_babylon_med.txt'),
    ('n-shot/gpt3-5/babylon/one_babylon.csv', 'n-shot/gpt3-5/babylon/one_babylon_med.txt'),
    ('n-shot/gpt3-5/babylon/two_babylon.csv', 'n-shot/gpt3-5/babylon/two_babylon_med.txt'),
    ('n-shot/gpt3-5/babylon/few_babylon.csv', 'n-shot/gpt3-5/babylon/few_babylon_med.txt'),
):
    eval_med_terms('medical_terms.json', _model_outs, _report_path)

In [49]:
# Score every CSV under n-shot/gpt3-5/kaggle (the file names suggest zero-,
# one-, two- and few-shot runs) and write each report next to its input.
for _model_outs, _report_path in (
    ('n-shot/gpt3-5/kaggle/zero_kaggle.csv', 'n-shot/gpt3-5/kaggle/zero_kaggle_med.txt'),
    ('n-shot/gpt3-5/kaggle/one_kaggle.csv', 'n-shot/gpt3-5/kaggle/one_kaggle_med.txt'),
    ('n-shot/gpt3-5/kaggle/two_kaggle.csv', 'n-shot/gpt3-5/kaggle/two_kaggle_med.txt'),
    ('n-shot/gpt3-5/kaggle/few_kaggle.csv', 'n-shot/gpt3-5/kaggle/few_kaggle_med.txt'),
):
    eval_med_terms('medical_terms.json', _model_outs, _report_path)