In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# --- Inputs: one metrics JSON file per model being compared ---
METRIC_FILE_BASELINE = 'data/final_model/all_metrics_baseline.json'
METRIC_FILE_BIOMEDBLIP = 'data/tuning/all_metrics_BioMedBLIP_Final Finetuning.json'

# Legend labels: index 0 labels the baseline bars, index 1 the BioMedBLIP bars.
model_names = [
    'CNN-BiLSTM',
    'BioMedBLIP',
]

# --- Outputs: where the two comparison figures are written ---
output_image = 'data/model_comparison.png'
type_specific_output_image = 'data/model_comparison_type_specific.png'

# Metric groups drawn in the overall comparison (one subplot per entry).
groups = [
    'Exact Match',
    'F1 Scores',
    'BLEU',
    'ROUGE',
    'METEOR',
    'BERTScore',
]

# Question types drawn in the per-type comparison (one subplot per entry).
question_types = [
    'OPEN',
    'CLOSED',
]

In [None]:
def load_metrics(filepath):
    """Read a metrics JSON file and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file's top-level value.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Load both metric files up front. If either file is missing, report it and
# stop here: swallowing the error would leave `baseline_metrics` /
# `biomedblip_metrics` undefined (or stale from an earlier run of this cell),
# and the plotting cells would then fail with an unrelated NameError or plot
# outdated numbers.
try:
    baseline_metrics = load_metrics(METRIC_FILE_BASELINE)
    biomedblip_metrics = load_metrics(METRIC_FILE_BIOMEDBLIP)
except FileNotFoundError as e:
    print(f"Error: {e}.")
    raise

In [None]:
def get_metric_group(data, group_name):
    """Collect the labelled metric values that belong to one metric group.

    Args:
        data: Nested metrics mapping, read as ``data[section][key]``.
        group_name: One of 'Exact Match', 'F1 Scores', 'BLEU', 'ROUGE',
            'METEOR' or 'BERTScore'.

    Returns:
        A dict mapping display label -> value, in display order. Any other
        ``group_name`` yields an empty dict.

    Raises:
        KeyError: If ``data`` is a dict that lacks a section or key the
            requested group reads.
    """
    # Each row: (group name, section of `data`, ((display label, key), ...)).
    layout = (
        ('Exact Match', 'exact_match', (
            ('Exact Match', 'exact_match'),
        )),
        ('F1 Scores', 'classification', (
            ('Macro F1', 'macro_f1'),
            ('Weighted F1', 'weighted_f1'),
        )),
        ('BLEU', 'bleu', (
            ('BLEU-1', 'bleu1'),
            ('BLEU-2', 'bleu2'),
            ('BLEU-3', 'bleu3'),
            ('BLEU-4', 'bleu4'),
        )),
        ('ROUGE', 'rouge', (
            ('ROUGE-1', 'rouge1'),
            ('ROUGE-2', 'rouge2'),
            ('ROUGE-L', 'rougeL'),
        )),
        ('METEOR', 'meteor', (
            ('METEOR', 'meteor'),
        )),
        ('BERTScore', 'bertscore', (
            ('Precision', 'bertscore_precision'),
            ('Recall', 'bertscore_recall'),
            ('F1', 'bertscore_f1'),
        )),
    )

    for name, section, fields in layout:
        if group_name == name:
            return {label: data[section][key] for label, key in fields}

    # Unrecognised group: nothing to report.
    return {}

In [None]:
# Overall comparison: a 2x3 grid with one panel per metric group; each panel
# shows the two models as side-by-side bars with their values printed on top.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 10))
axes = axes.flatten()
bar_width = 0.35
half_width = bar_width / 2

for panel_idx, group in enumerate(groups):
    ax = axes[panel_idx]

    # Metric values of both models for this group; the baseline's labels
    # decide which metrics are shown and in which order.
    baseline_group = get_metric_group(baseline_metrics, group)
    biomedblip_group = get_metric_group(biomedblip_metrics, group)
    labels = list(baseline_group.keys())
    baseline_vals = [baseline_group[label] for label in labels]
    biomedblip_vals = [biomedblip_group[label] for label in labels]

    positions = np.arange(len(labels))

    # One bar set per model: baseline shifted left, BioMedBLIP shifted right.
    series = (
        (baseline_vals, -half_width, 'skyblue'),
        (biomedblip_vals, +half_width, 'salmon'),
    )
    bar_sets = []
    for model_idx, (values, shift, colour) in enumerate(series):
        bar_sets.append(
            ax.bar(positions + shift, values, bar_width,
                   label=model_names[model_idx], color=colour, edgecolor='grey')
        )

    # Panel styling
    ax.set_title(group, fontsize=12, fontweight='bold')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    # Fixed range with headroom above the bars for the value labels
    # (presumably the metrics are on a 0-100 scale -- confirm).
    ax.set_ylim(0, 115)
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    ax.legend()

    # Print each bar's height just above its top edge.
    for bars in bar_sets:
        for bar in bars:
            top = bar.get_height()
            ax.annotate(f'{top:.1f}',
                        xy=(bar.get_x() + bar.get_width() / 2, top),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig(output_image)
print(f"Plot saved to {output_image}")
plt.show()

In [None]:
# Per-question-type comparison: one panel per question type, each showing the
# two models' values for the metrics below as side-by-side bars.
metrics_keys = ['f1', 'exact_match']          # keys read from each by_type entry
metrics_labels = ['F1 Score', 'Exact Match']  # tick labels, same order as metrics_keys

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
bar_width = 0.35

# Look the per-type sections up once. A file without a 'by_type' section is
# treated like a file in which every type is missing, so it reaches the
# warning / zero-bar fallback below instead of raising KeyError.
baseline_by_type = baseline_metrics.get('by_type', {})
biomedblip_by_type = biomedblip_metrics.get('by_type', {})

for i, q_type in enumerate(question_types):
    ax = axes[i]

    # Extract values for the current question type; fall back to zero-height
    # bars (with a warning) when either file has no entry for it.
    if q_type in baseline_by_type and q_type in biomedblip_by_type:
        vals1 = [baseline_by_type[q_type][k] for k in metrics_keys]
        vals2 = [biomedblip_by_type[q_type][k] for k in metrics_keys]
    else:
        print(f"Warning: Type '{q_type}' not found in one of the files.")
        # One zero per metric so the bar heights always match the x
        # positions, even if metrics_keys is edited.
        vals1 = [0] * len(metrics_keys)
        vals2 = [0] * len(metrics_keys)

    x = np.arange(len(metrics_labels))

    # Create bars: baseline shifted left, BioMedBLIP shifted right.
    rects1 = ax.bar(x - bar_width/2, vals1, bar_width, label=model_names[0], color='skyblue', edgecolor='grey')
    rects2 = ax.bar(x + bar_width/2, vals2, bar_width, label=model_names[1], color='salmon', edgecolor='grey')

    # Styling
    ax.set_title(f'Question Type: {q_type}', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics_labels, fontsize=12)
    ax.set_ylim(0, 115)  # Extend Y-axis so the value labels fit above the bars
    ax.grid(axis='y', linestyle='--', alpha=0.5)

    # Add legend
    ax.legend()

    # Print each bar's height just above its top edge (done inline rather
    # than through a helper that would be redefined on every iteration).
    for rects in (rects1, rects2):
        for rect in rects:
            height = rect.get_height()
            ax.annotate(f'{height:.1f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig(type_specific_output_image)
print(f"Plot saved to {type_specific_output_image}")
plt.show()