## Imports

In [None]:
import warnings
warnings.filterwarnings("ignore")

import os

import zipfile
import random
import json
import re
from matplotlib import pyplot as plt
from pprint import pprint
import numpy as np
from tqdm import tqdm
from PIL import Image
from collections import defaultdict
from pathlib import Path
from IPython.display import Image, display

from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import BlipForQuestionAnswering, BlipProcessor
from peft import get_peft_model, LoraConfig, PeftModel

from sklearn.metrics import accuracy_score, f1_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from bert_score import score as bert_score

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from utils import plot_ngram_analysis, print_all_metrics, plot_type_specific_comparison

## Constants

In [None]:
# NOTE(review): IMG_SIZE is not referenced by the code visible in this
# notebook section; the dataset class resizes images to 224x224 on its own.
# Confirm whether this constant is still needed.
IMG_SIZE = (480, 480)
# Seed applied to the Python, NumPy and PyTorch RNGs by set_global_seed().
GLOBAL_SEED = 42

# Data locations (relative paths, resolved against the current working directory)
DATA_DIR = 'data/'
# English-only SLAKE splits, written with save_to_disk / read with load_from_disk.
DATASET_PATH = os.path.join(DATA_DIR, 'dataset/')
# Directory the dataset's `img_name` values are joined onto when loading images.
IMAGE_PATH = os.path.join(DATA_DIR, 'imgs/')
# Output directory for the LoRA hyperparameter tuner.
TUNING_RESULTS_PATH = os.path.join(DATA_DIR, 'tuning/')
# Output directory for metrics JSON, plots and saved LoRA checkpoints.
FINAL_MODEL_PATH = os.path.join(DATA_DIR, 'final_model_BioMedBLIP/')

# Hugging Face dataset repository providing the SLAKE splits and image archive
repo_id = "BoKelvin/SLAKE"
repo_type = "dataset"
img_file = "imgs.zip"

# Model Definition
# Pretrained HF BLIP VQA model whose architecture/weights the checkpoint is loaded onto.
BASE_MODEL_NAME = "Salesforce/blip-vqa-base"
# Hub repo that hosts the BioMedBLIP checkpoint file below.
MODEL_NAME = "biomedblip/biomedblip"
# NOTE(review): not referenced in the visible code; presumably used when saving later.
MODEL_SAVE_NAME = "biomedblip"
# Path of the checkpoint file inside the MODEL_NAME repo.
CHECKPOINT_LOCATION = "Best BLIP Finetuning/VQA_generation_SLAKE(BLIP-MIMIC&ROCO-10)-006.pth"

# Device: first CUDA GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
def set_global_seed():
    """Seed Python, NumPy and PyTorch RNGs with GLOBAL_SEED for reproducibility.

    When CUDA is available, every CUDA device is seeded as well, and cuDNN is
    switched to deterministic kernels with autotuning disabled.
    """
    seed = GLOBAL_SEED
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; turning off benchmark mode stops cuDNN from
    # autotuning, which could otherwise select different algorithms per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed()

## Dataset Download and Saving

In [None]:
# Utility function for downloading and extracting ZIP files
def download_and_store_ZIP(filename, save_dir):
    """Download ``filename`` from the ``repo_id`` Hub repo and extract it into ``save_dir``.

    The archive is fetched with ``hf_hub_download``, which caches it locally.
    It is then unpacked with ``zipfile``. ``save_dir`` is created if missing.

    Args:
        filename: Path of the ZIP file inside the Hub repository.
        save_dir: Directory the archive contents are extracted into.

    Returns:
        True on success, False if downloading or extracting failed. Failures
        are still reported on stdout rather than raised, so existing callers
        that ignore the return value keep the best-effort behaviour.
    """
    print(f"Fetching file {filename} from {repo_id} repo")

    try:
        # Caches the file locally and returns the path to the cached file
        cached_zip_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type
        )
        print(f"{filename} download complete. Cached at: {cached_zip_path}")

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_dir, exist_ok=True)

        print(f"Extracting to {save_dir}...")
        with zipfile.ZipFile(cached_zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)
    except Exception as e:
        # Deliberate best-effort: report the failure and signal it to the caller.
        print(f"Failed to download or extract {filename}: {e}")
        return False

    print("Extraction complete.")
    print(f"{filename} files are located in: {os.path.abspath(save_dir)}")
    return True

def filter_language(original):
    """Scope ``original`` to English: keep only rows whose ``q_lang`` is ``'en'``."""
    def _is_english(row):
        return row['q_lang'] == 'en'

    return original.filter(_is_english)

# Download and store the dataset
def download_and_store_english_dataset():
    """Download SLAKE from the Hub, keep English rows only, and save it to DATASET_PATH.

    Returns:
        The English-only dataset, as produced by ``load_dataset`` followed by
        ``filter_language``.
    """
    print(f"Downloading dataset from {repo_id} repo")

    # Load from Hugging Face and scope to English only
    original = filter_language(load_dataset(repo_id))

    # Show the dataset formatting
    pprint(original)

    # makedirs also creates missing parents (DATA_DIR). exist_ok replaces the
    # two racy os.path.exists() guards with one idempotent call.
    os.makedirs(DATASET_PATH, exist_ok=True)
    original.save_to_disk(DATASET_PATH)
    return original

# Download and store the image files
def download_and_store_image():
    """Fetch the SLAKE image archive (``img_file``) and extract it into DATA_DIR."""
    download_and_store_ZIP(filename=img_file, save_dir=DATA_DIR)

# Download necessary files
def download_and_store_slake():
    """Download the English SLAKE splits, then the image archive.

    Returns:
        The dataset returned by ``download_and_store_english_dataset``.
    """
    english_dataset = download_and_store_english_dataset()
    download_and_store_image()
    return english_dataset

## Dataset and Dataloader

In [None]:
class TextPreprocessor:
    """Normalise SLAKE question/answer text before tokenisation.

    Normalisation lowercases text, strips surrounding whitespace, and collapses
    every internal whitespace run (spaces, tabs, newlines) to one space.
    """

    # Any run of whitespace, newlines included.
    _WHITESPACE_RE = re.compile(r'\s+')

    def clean_text(self, text):
        """Return ``text`` lowercased, stripped, with whitespace runs collapsed.

        Non-string values (e.g. numeric answers) are converted with ``str``
        first, instead of failing on ``.lower()``.
        """
        text = str(text).lower().strip()
        # ``\s`` already matches '\n', so no separate newline replacement is needed.
        return self._WHITESPACE_RE.sub(' ', text)

    def preprocess_question(self, question):
        """Clean ``question`` and make sure it ends with a question mark."""
        question = self.clean_text(question)
        if not question.endswith('?'):
            question += '?'
        return question

    def preprocess_answer(self, answer):
        """Clean ``answer`` (see ``clean_text``)."""
        return self.clean_text(answer)

    def handle_slake_specifics(self, item):
        """Normalise one SLAKE record, or return None if it is not English.

        Args:
            item: Dict-like record with ``q_lang``, ``question``, ``answer``,
                ``img_name`` and optionally ``answer_type`` / ``qid``.

        Returns:
            A dict with ``question``, ``answer``, ``answer_type`` and
            ``img_name``, plus ``qid`` when the record has one. Returns None
            when ``q_lang`` is not ``'en'``.
        """
        if item.get('q_lang') != 'en':
            return None

        processed = {
            'question': self.preprocess_question(item['question']),
            'answer': self.preprocess_answer(item['answer']),
            'answer_type': item.get('answer_type'),
            'img_name': item['img_name'],
        }
        # Keep the SLAKE question id so downstream code (which reads 'qid')
        # can report it instead of falling back to the positional index.
        if item.get('qid') is not None:
            processed['qid'] = item['qid']
        return processed

In [None]:
class SLAKEDatasetBioMedBLIP(Dataset):
    """Map-style dataset turning SLAKE records into BLIP VQA model inputs.

    Records are normalised with ``TextPreprocessor.handle_slake_specifics``.
    Records it rejects (it returns None unless ``q_lang == 'en'``) are dropped,
    so indexing never yields a None item.

    Args:
        data: Iterable of dict-like SLAKE records with ``question``, ``answer``,
            ``q_lang``, ``img_name`` and optionally ``answer_type`` / ``qid``.
        processor: BLIP processor. It encodes image + question, and its
            ``tokenizer`` encodes the answer into ``labels``.
        transform: Optional callable applied to each resized PIL image just
            before it is passed to ``processor``. It runs on every access, so
            random augmentations stay random even with ``cache_images=True``.
        cache_images: If True, load, RGB-convert and resize every referenced
            image once up front and keep it in memory.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
        max_question_length: Pad/truncate length for the question tokens.
        max_answer_length: Pad/truncate length for the answer label tokens.
    """

    def __init__(self, data, processor, transform=None, cache_images=True,
                 image_size=(224, 224), max_question_length=512,
                 max_answer_length=32):
        self.processor = processor
        self.transform = transform
        self.cache_images = cache_images
        self.image_size = tuple(image_size)
        self.max_question_length = max_question_length
        self.max_answer_length = max_answer_length

        self.text_preprocessor = TextPreprocessor()
        # handle_slake_specifics returns None for non-English records; drop
        # them here instead of crashing later on item['img_name'].
        processed = (self.text_preprocessor.handle_slake_specifics(item) for item in data)
        self.data = [item for item in processed if item is not None]

        tokenizer = self.processor.tokenizer
        self.vocab_size = len(tokenizer)
        self.pad_token_id = tokenizer.pad_token_id
        if self.pad_token_id is None:
            self.pad_token_id = 0

        # Caching: each distinct image is loaded only once.
        self.image_cache = {}
        if self.cache_images:
            print("Caching images into RAM...")
            unique_imgs = {item['img_name'] for item in self.data}
            for img_name in unique_imgs:
                self.image_cache[img_name] = self._load_image(img_name)
            print(f"Cached {len(self.image_cache)} images.")

    def _load_image(self, img_name):
        """Load ``IMAGE_PATH/img_name`` as an RGB image resized to ``image_size``."""
        # Imported locally: the notebook's module-level
        # ``from IPython.display import Image`` shadows PIL's ``Image``, which
        # would make ``Image.open`` fail with AttributeError.
        from PIL import Image as PILImage

        path = os.path.join(IMAGE_PATH, img_name)
        # convert()/resize() return new in-memory images, so the file can be closed.
        with PILImage.open(path) as img:
            return img.convert('RGB').resize(self.image_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded example at ``idx``.

        The result holds the processor's outputs (batch dim removed), plus
        ``labels``, which are answer token ids with padding set to -100 so the
        loss ignores them. It also carries metadata: ``question_type``,
        ``answer_text``, ``question_text`` and ``id`` (the SLAKE ``qid`` if
        present, else ``idx``).
        """
        item = self.data[idx]
        img_name = item['img_name']

        if self.cache_images:
            image = self.image_cache[img_name]
        else:
            image = self._load_image(img_name)
        if self.transform is not None:
            image = self.transform(image)

        question = item['question']
        answer = str(item.get('answer', ''))

        # Image + question through the BLIP processor
        encoding = self.processor(
            images=image,
            text=question,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_question_length,
            truncation=True
        )
        # Remove batch dimension
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # Target (answer) tokens
        answer_encoding = self.processor.tokenizer(
            answer,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_answer_length,
            truncation=True
        )
        labels = answer_encoding['input_ids'].squeeze(0)
        # Clamp token IDs to valid range
        labels = torch.clamp(labels, max=self.vocab_size - 1)
        # Replace padding token IDs with -100 (ignored in loss)
        labels[labels == self.pad_token_id] = -100
        encoding['labels'] = labels

        # Metadata. answer_type is always present in processed records but may
        # be None, so use `or` for the fallback rather than a .get default.
        encoding['question_type'] = item.get('answer_type') or 'UNKNOWN'
        encoding['answer_text'] = answer
        encoding['question_text'] = question
        encoding['id'] = item.get('qid', idx)

        return encoding

In [None]:
def collate_fn(batch):
    """Merge per-example dicts from SLAKEDatasetBioMedBLIP into one batch dict.

    Model tensors are stacked along a new leading batch dimension. Metadata
    fields are gathered into plain lists and renamed to their plural keys.
    """
    stacked_keys = ('pixel_values', 'input_ids', 'attention_mask', 'labels')
    collated = {key: torch.stack([sample[key] for sample in batch]) for key in stacked_keys}

    # output key -> per-example key
    metadata_fields = {
        'question_types': 'question_type',
        'answers': 'answer_text',
        'questions': 'question_text',
        'ids': 'id',
    }
    for out_key, sample_key in metadata_fields.items():
        collated[out_key] = [sample[sample_key] for sample in batch]

    return collated

## Preparation

In [None]:
# First run only: uncomment to download the dataset and images instead of loading them.
# dataset = download_and_store_slake()
dataset = load_from_disk(DATASET_PATH)

# Toggle to print one sample batch as a sanity check. It is kept separate from
# `test_data` (the test split) so that switching it off does not leave
# `processor` / `test_loader` undefined for the evaluation cells below.
inspect_test_batch = True

# Initialize processor
processor = BlipProcessor.from_pretrained(
    BASE_MODEL_NAME,
    use_fast=True
)

# Test split and loader used by the evaluation cells
test_data = dataset['test']
test_dataset = SLAKEDatasetBioMedBLIP(
    data=test_data,
    processor=processor,
    cache_images=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    # Pinned memory only helps host-to-GPU copies.
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn
)

if inspect_test_batch:
    # Test loading a batch
    print("\nTesting data loading...")
    batch = next(iter(test_loader))

    print(f"Batch keys: {batch.keys()}")
    print(f"Pixel values shape: {batch['pixel_values'].shape}")
    print(f"Input IDs shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    print(f"Sample questions: {batch['questions'][:2]}")
    print(f"Sample answers: {batch['answers'][:2]}")
    print(f"Question types: {batch['question_types'][:2]}")
    print(f"IDs: {batch['ids'][:2]}")

### Load BioMedBLIP From Checkpoint Name

In [None]:
def load_model_with_checkpoint(model_repo=None, checkpoint_file=None, base_model_name=None):
    """Build a BLIP VQA model and overlay the BioMedBLIP checkpoint weights on it.

    Args:
        model_repo: Hub repo holding the checkpoint (default: MODEL_NAME).
        checkpoint_file: Path of the checkpoint inside ``model_repo``
            (default: CHECKPOINT_LOCATION).
        base_model_name: Pretrained ``BlipForQuestionAnswering`` providing the
            architecture and any weights the checkpoint does not cover
            (default: BASE_MODEL_NAME).

    Returns:
        (model, checkpoint): the model and the raw object returned by torch.load.
    """
    # Resolve defaults at call time so later reassignment of the constants is honoured.
    model_repo = MODEL_NAME if model_repo is None else model_repo
    checkpoint_file = CHECKPOINT_LOCATION if checkpoint_file is None else checkpoint_file
    base_model_name = BASE_MODEL_NAME if base_model_name is None else base_model_name

    local_checkpoint_path = hf_hub_download(repo_id=model_repo, filename=checkpoint_file)
    print(f"File {checkpoint_file} downloaded to: {local_checkpoint_path}")

    model = BlipForQuestionAnswering.from_pretrained(base_model_name)
    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted repos.
    checkpoint = torch.load(local_checkpoint_path, map_location='cpu')
    print(f"Checkpoint Keys:\n{checkpoint.keys()}")

    # Checkpoints normally nest the weights under "model"; accept a bare state dict too.
    state_dict = checkpoint.get("model", checkpoint)

    # strict=False tolerates naming differences, but it silently skips any
    # tensor whose name does not exist in the model. Report how much matched,
    # so a checkpoint with incompatible names is not mistaken for a loaded one.
    result = model.load_state_dict(state_dict, strict=False)
    matched = len(state_dict) - len(result.unexpected_keys)
    print(
        f"Loaded {matched}/{len(state_dict)} checkpoint tensors "
        f"({len(result.missing_keys)} model tensors not present in the checkpoint)."
    )
    if matched == 0:
        print("WARNING: no checkpoint tensor matched the model's parameter names; "
              "the model still carries only the base weights.")

    return model, checkpoint

In [None]:
model, chkpoint = load_model_with_checkpoint()

# Write the name and type of every attention-related submodule to
# DATA_DIR/model_layers.txt, to help pick LoRA target_modules.
layers_file = os.path.join(DATA_DIR, 'model_layers.txt')
with open(layers_file, 'w') as f:
    for name, module in model.named_modules():
        lowered_name = name.lower()
        if "attn" in lowered_name or "attention" in lowered_name:
            f.write(f"{name}{type(module)}\n")

### Calculate Metrics

In [None]:
def calculate_classification_metrics(predictions, targets):
    """Score answers as a multi-class problem where each distinct string is a class.

    Args:
        predictions: Predicted answer strings.
        targets: Ground-truth answer strings, aligned with ``predictions``.

    Returns:
        dict with 'accuracy', 'macro_f1' and 'weighted_f1', each in percent.
        Labels with no predicted or true samples get F1 0 (``zero_division=0``),
        as in calculate_type_specific_metrics. The value is unchanged from
        sklearn's default, minus the warning.
    """
    return {
        'accuracy': accuracy_score(targets, predictions) * 100,
        'macro_f1': f1_score(targets, predictions, average='macro', zero_division=0) * 100,
        'weighted_f1': f1_score(targets, predictions, average='weighted', zero_division=0) * 100,
    }

def calculate_bleu_scores(predictions, references):
    """Return mean sentence-level BLEU-1..4 (in percent) over prediction/reference pairs.

    Each prediction is lowercased, whitespace-tokenised and scored against its
    single reference with NLTK ``sentence_bleu``. Smoothing method 1 is used so
    short answers without higher-order n-gram matches do not collapse to 0.

    Returns:
        dict with keys 'bleu1'..'bleu4'; each value is the mean score x 100.
        A pair whose scoring raises contributes 0.0 to every key.
    """
    smoothing = SmoothingFunction().method1

    # Uniform n-gram weights for BLEU-n. BLEU-3 uses exact thirds: the previous
    # (0.33, 0.33, 0.33) summed to 0.99, which slightly inflated the score.
    weights_by_key = {
        'bleu1': (1, 0, 0, 0),
        'bleu2': (0.5, 0.5, 0, 0),
        'bleu3': (1 / 3, 1 / 3, 1 / 3, 0),
        'bleu4': (0.25, 0.25, 0.25, 0.25),
    }
    bleu_scores = {key: [] for key in weights_by_key}

    for pred, ref in zip(predictions, references):
        pred_tokens = pred.lower().split()
        ref_tokens = [ref.lower().split()]

        try:
            row = {
                key: sentence_bleu(ref_tokens, pred_tokens, weights=weights, smoothing_function=smoothing)
                for key, weights in weights_by_key.items()
            }
        except Exception:
            # Best effort per pair, but no longer a bare `except`, so
            # KeyboardInterrupt / SystemExit still propagate.
            row = dict.fromkeys(weights_by_key, 0.0)

        for key, value in row.items():
            bleu_scores[key].append(value)

    # Average scores
    return {key: np.mean(scores) * 100 for key, scores in bleu_scores.items()}

def calculate_rouge_scores(predictions, references):
    """Return mean ROUGE-1/2/L F-measures (in percent) over prediction/reference pairs.

    Scoring uses ``rouge_score.RougeScorer`` with Porter stemming enabled.
    """
    rouge_types = ['rouge1', 'rouge2', 'rougeL']
    scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)

    per_pair = [scorer.score(ref, pred) for pred, ref in zip(predictions, references)]

    return {
        rouge_type: np.mean([pair_scores[rouge_type].fmeasure for pair_scores in per_pair]) * 100
        for rouge_type in rouge_types
    }

def calculate_meteor_score(predictions, references):
    """Return the mean NLTK METEOR score (in percent) over prediction/reference pairs.

    Texts are lowercased and whitespace-tokenised before scoring.

    Returns:
        {'meteor': mean x 100}. A pair whose scoring raises contributes 0.0.
        If NLTK reports a missing resource (LookupError), METEOR cannot be
        computed for any pair. In that case the problem is printed and
        {'meteor': nan} is returned, rather than silently reporting 0.
    """
    meteor_scores = []

    for pred, ref in zip(predictions, references):
        pred_tokens = pred.lower().split()
        ref_tokens = ref.lower().split()

        try:
            meteor_scores.append(meteor_score([ref_tokens], pred_tokens))
        except LookupError as err:
            # A missing NLTK resource affects every pair; report it once
            # instead of scoring the whole set as 0.
            print(f"METEOR unavailable, NLTK resource missing: {err}")
            return {'meteor': float('nan')}
        except Exception:
            meteor_scores.append(0.0)

    return {
        'meteor': np.mean(meteor_scores) * 100
    }

def calculate_bertscore(predictions, references):
    """Return mean BERTScore precision/recall/F1 (in percent).

    Embeddings come from ``distilbert-base-uncased`` (English).
    """
    precision, recall, f1 = bert_score(
        predictions,
        references,
        lang='en',
        model_type='distilbert-base-uncased',
        verbose=False
    )

    mean_precision = precision.mean().item()
    mean_recall = recall.mean().item()
    mean_f1 = f1.mean().item()

    return {
        'bertscore_precision': mean_precision * 100,
        'bertscore_recall': mean_recall * 100,
        'bertscore_f1': mean_f1 * 100,
    }

def calculate_exact_match(predictions, references):
    """Return the percentage of predictions equal to their reference.

    The comparison ignores case and surrounding whitespace. Pairs are formed
    with ``zip``, and the denominator is ``len(predictions)``. Empty input
    returns 0.0 instead of raising ZeroDivisionError.
    """
    if len(predictions) == 0:
        return {'exact_match': 0.0}

    exact_matches = sum(
        pred.lower().strip() == ref.lower().strip()
        for pred, ref in zip(predictions, references)
    )

    return {
        'exact_match': (exact_matches / len(predictions)) * 100
    }

def calculate_type_specific_metrics(pred_texts, target_texts, question_types):
    """Group predictions by question type and score each group separately.

    Returns:
        {question_type: {'accuracy', 'f1', 'exact_match', 'count'}}.
        'accuracy' and macro 'f1' compare raw strings, as in
        calculate_classification_metrics. 'exact_match' ignores case and
        surrounding whitespace, as in calculate_exact_match. All scores are
        in percent.
    """
    grouped = defaultdict(lambda: {'pred_texts': [], 'target_texts': []})
    for pred_text, target_text, q_type in zip(pred_texts, target_texts, question_types):
        grouped[q_type]['pred_texts'].append(pred_text)
        grouped[q_type]['target_texts'].append(target_text)

    results = {}
    # Groups are only created on append, so every group has >= 1 example.
    for q_type, data in grouped.items():
        preds = data['pred_texts']
        targets = data['target_texts']
        count = len(preds)
        matches = sum(
            p.lower().strip() == t.lower().strip() for p, t in zip(preds, targets)
        )
        results[q_type] = {
            'accuracy': accuracy_score(targets, preds) * 100,
            'f1': f1_score(targets, preds, average='macro', zero_division=0) * 100,
            'exact_match': matches / count * 100,
            'count': count,
        }

    return results

def evaluate_all_metrics(pred_texts, target_texts, question_types, model_state="Zero Shot"):
    """Compute every metric for a prediction set, then print, save and plot it.

    Metrics computed: classification (accuracy, macro/weighted F1), BLEU-1..4,
    ROUGE-1/2/L, METEOR, BERTScore, exact match, and per-question-type scores.
    Results are printed with ``print_all_metrics`` and written to JSON by
    ``save_metrics``. Type-specific and n-gram plots go to FINAL_MODEL_PATH.

    Returns:
        dict mapping metric family name -> that family's result dict.
    """
    print("CALCULATING ALL METRICS")

    run_label = f'BioMedBLIP {model_state}'

    # (result key, progress message, scorer taking (pred_texts, target_texts))
    pairwise_steps = [
        ('classification', "\nCalculating Classification Metrics...", calculate_classification_metrics),
        ('bleu', "Calculating BLEU (1-4) Scores...", calculate_bleu_scores),
        ('rouge', "Calculating ROUGE Scores...", calculate_rouge_scores),
        ('meteor', "Calculating METEOR Score...", calculate_meteor_score),
        ('bertscore', "Calculating BERTScore... (Might take some time)", calculate_bertscore),
        ('exact_match', "Calculating Exact Match...", calculate_exact_match),
    ]

    metrics = {}
    for key, message, scorer in pairwise_steps:
        print(message)
        metrics[key] = scorer(pred_texts, target_texts)

    # Per-type scores also need the question types.
    print("Calculating Type-Specific Metrics...")
    metrics['by_type'] = calculate_type_specific_metrics(
        pred_texts, target_texts, question_types
    )

    print_all_metrics(metrics, f'{run_label} Evaluation')
    save_metrics(metrics, model_state=model_state)

    plot_type_specific_comparison(metrics['by_type'], FINAL_MODEL_PATH, run_label)
    plot_ngram_analysis(metrics['bleu'], FINAL_MODEL_PATH, run_label)

    return metrics

def save_metrics(metrics, model_state="Zero Shot"):
    """Write ``metrics`` as indented JSON to FINAL_MODEL_PATH.

    The file is named ``all_metrics_BioMedBLIP_<model_state>.json``; the
    directory is created if needed.
    """
    os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
    file_name = f'all_metrics_BioMedBLIP_{model_state}.json'
    output_file = os.path.join(FINAL_MODEL_PATH, file_name)

    with open(output_file, 'w') as handle:
        json.dump(metrics, handle, indent=2)

    print(f"\nAll Metrics saved to: {output_file}")

In [None]:
def calculate_all_metrics(model, processor, test_loader, model_state="Zero Shot",
                          max_length=32, num_beams=5):
    """Generate an answer for every example in ``test_loader``.

    Despite its name, this function only collects predictions; scoring is done
    separately by ``evaluate_all_metrics``.

    Args:
        model: BLIP VQA model (possibly PEFT-wrapped); moved to ``device``
            and put in eval mode.
        processor: Processor whose ``batch_decode`` turns ids into text.
        test_loader: DataLoader yielding batches built by ``collate_fn``.
        model_state: Label shown on the progress bar.
        max_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam width for ``model.generate``.

    Returns:
        (pred_texts, target_texts, question_types): lists aligned by example.
    """
    model.to(device)
    model.eval()

    all_pred_texts = []
    all_target_texts = []
    all_question_types = []

    progress_bar = tqdm(test_loader, desc=f"{model_state} Evaluation")

    with torch.no_grad():
        for batch in progress_bar:
            outputs = model.generate(
                pixel_values=batch['pixel_values'].to(device),
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                max_length=max_length,
                num_beams=num_beams,  # Beam search for better quality
                early_stopping=True
            )

            all_pred_texts.extend(processor.batch_decode(outputs, skip_special_tokens=True))
            all_target_texts.extend(batch['answers'])
            all_question_types.extend(batch['question_types'])

    return all_pred_texts, all_target_texts, all_question_types

### Zero-Shot Evaluation

In [None]:
# Run the loaded (not yet fine-tuned) model over the test split and collect
# predictions, gold answers and answer types for scoring.
all_pred_texts, all_target_texts, all_question_types = calculate_all_metrics(
    model=model,
    processor=processor,
    test_loader=test_loader,
)

In [None]:
# Score the zero-shot predictions (also saves a metrics JSON and plots under FINAL_MODEL_PATH).
zero_shot_metrics = evaluate_all_metrics(
    pred_texts=all_pred_texts,
    target_texts=all_target_texts,
    question_types=all_question_types,
)

## LoRA

In [None]:
def setup_lora_model(model, lora_config=None):
    """Wrap ``model`` with LoRA adapters and print configuration and parameter stats.

    Args:
        model: Base model to adapt.
        lora_config: ``LoraConfig`` to apply. If None, a default is used:
            r=16, alpha=32, dropout 0.1, no bias training, targeting the
            query/key/value/dense/qkv/projection layers.

    Returns:
        The model returned by ``get_peft_model``.
    """
    config = lora_config
    if config is None:
        # Default LoRA configuration for BLIP
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["query", "key", "value", "dense", "qkv", "projection"],
            lora_dropout=0.1,
            bias="none",
        )

    print("LoRA Configuration:")
    print(f"  Rank (r):            {config.r}")
    print(f"  Alpha:               {config.lora_alpha}")
    print(f"  Target modules:      {config.target_modules}")
    print(f"  Dropout:             {config.lora_dropout}")

    peft_model = get_peft_model(model, config)

    # Count total and trainable parameters in a single pass.
    trainable_params = 0
    total_params = 0
    for param in peft_model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements

    print("\nParameter Statistics:")
    print(f"  Trainable params:    {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
    print(f"  Total params:        {total_params:,}")

    return peft_model

In [None]:
def train_single_epoch(
        model,
        train_loader,
        gradient_accumulation_steps,
        optimizer,
        epoch,
        num_epochs,
        max_grad_norm=1.0
    ):
    """Run one training epoch with gradient accumulation and clipping.

    Args:
        model: BLIP VQA model (possibly PEFT-wrapped). It must already be on
            ``device``, because batches are moved there.
        train_loader: DataLoader yielding batches built by ``collate_fn``.
        gradient_accumulation_steps: Batches accumulated per optimiser step.
        optimizer: Optimiser over the trainable parameters.
        epoch: Zero-based epoch index (progress bar only).
        num_epochs: Total epochs (progress bar only).
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        Mean per-batch loss over the epoch.
    """
    model.train()
    total_loss = 0
    text_config = model.config.text_config
    pad_token_id = text_config.pad_token_id
    bos_token_id = getattr(text_config, 'bos_token_id', None)
    num_batches = len(train_loader)

    optimizer.zero_grad()

    print("Starting Training Loop")
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for step, batch in enumerate(progress_bar):
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Transformers' BLIP text decoder shifts `labels` by one internally
        # when computing its next-token loss. The decoder inputs must therefore
        # be the UNshifted label ids. Shifting here as well trained the model
        # to predict two tokens ahead. -100 padding is not a valid embedding
        # index, so it becomes the pad id.
        decoder_input_ids = labels.masked_fill(labels == -100, pad_token_id)
        # BLIP's generate() starts decoding from text_config.bos_token_id, so
        # train with the same first token.
        if bos_token_id is not None:
            decoder_input_ids[:, 0] = bos_token_id
        decoder_attention_mask = (labels != -100).long()

        # Forward pass
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels
        )
        loss = outputs.loss

        # Average over the accumulation group this batch belongs to. The last
        # group is smaller when num_batches is not a multiple of the step count.
        group_start = (step // gradient_accumulation_steps) * gradient_accumulation_steps
        group_size = min(gradient_accumulation_steps, num_batches - group_start)
        (loss / group_size).backward()

        # Step at the end of every group, including a trailing partial group,
        # whose gradients were previously discarded.
        if (step + 1) % gradient_accumulation_steps == 0 or step + 1 == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return total_loss / len(train_loader)

In [None]:
def validate_(model, processor, val_loader):
    """Evaluate ``model`` on ``val_loader`` with teacher-forced loss and generation.

    Args:
        model: BLIP VQA model (possibly PEFT-wrapped). It must already be on
            ``device``.
        processor: Processor whose ``batch_decode`` turns generated ids into text.
        val_loader: DataLoader yielding batches built by ``collate_fn``.

    Returns:
        (avg_loss, exact_match_acc, P, R, F1):
        - avg_loss: mean per-batch loss.
        - exact_match_acc: exact-match accuracy in percent, ignoring case
          and surrounding whitespace.
        - P, R, F1: mean BERTScore precision/recall/F1 as fractions, not
          percentages.
    """
    model.eval()
    total_loss = 0
    exact_matches = 0
    total = 0

    predictions = []
    references = []
    text_config = model.config.text_config
    pad_token_id = text_config.pad_token_id
    bos_token_id = getattr(text_config, 'bos_token_id', None)

    print("Starting Validating Loop")
    progress_bar = tqdm(val_loader, desc="Validating")

    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # The BLIP text decoder shifts `labels` internally for its loss, so
            # decoder inputs are the UNshifted labels (with -100 -> pad). The
            # first token is the BOS id that generate() also starts from.
            decoder_input_ids = labels.masked_fill(labels == -100, pad_token_id)
            if bos_token_id is not None:
                decoder_input_ids[:, 0] = bos_token_id
            decoder_attention_mask = (labels != -100).long()

            # Forward pass
            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item()

            # Generate for exact match
            generated = model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=32
            )

            pred_texts = processor.batch_decode(generated, skip_special_tokens=True)
            predictions.extend(pred_texts)
            references.extend(batch['answers'])

            for pred, true_ans in zip(pred_texts, batch['answers']):
                if pred.lower().strip() == true_ans.lower().strip():
                    exact_matches += 1
                total += 1

    avg_loss = total_loss / len(val_loader)
    exact_match_acc = 100 * exact_matches / total

    P, R, F1 = bert_score(
        predictions,
        references,
        lang='en',
        model_type='distilbert-base-uncased',
        verbose=False
    )

    P = P.mean().item()
    R = R.mean().item()
    F1 = F1.mean().item()

    return avg_loss, exact_match_acc, P, R, F1

## BioMedBLIP Trainer

In [None]:
class BioMedBLIPTrainer:
    """Fine-tune a BLIP VQA model (typically LoRA-wrapped) with early stopping.

    Each epoch runs ``train_single_epoch`` followed by ``validate_``. The model
    with the best validation exact match is saved via ``save_pretrained`` under
    FINAL_MODEL_PATH. Per-epoch metrics accumulate in ``self.history`` and are
    plotted to ``training_curves.png`` when training ends.
    """

    def __init__(self, model, processor, train_dataset, val_dataset):
        self.model = model
        self.processor = processor
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.results_dir = Path(FINAL_MODEL_PATH)
        # parents=True: the results dir is nested, and its parent may not exist yet.
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Per-epoch metrics. Exact match is a percentage. BERTScore values are
        # stored as returned by validate_ (fractions, not percentages).
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_exact_match': [],
            'learning_rates': [],
            'bertscore_precision': [],
            'bertscore_recall': [],
            'bertscore_f1': []
        }

    def train(self, num_epochs=20, batch_size=8, learning_rate=5e-5,
              patience=5, gradient_accumulation_steps=4):
        """Train with AdamW + cosine annealing and early stopping; return ``self.model``.

        Args:
            num_epochs: Maximum epochs; also the cosine schedule's ``T_max``.
            batch_size: Batch size for both train and validation loaders.
            learning_rate: Initial AdamW learning rate.
            patience: Epochs without an exact-match improvement before stopping.
            gradient_accumulation_steps: Batches accumulated per optimiser step.
        """
        print(f"Epochs:              {num_epochs}")
        print(f"Batch size:          {batch_size}")
        print(f"Learning rate:       {learning_rate}")
        print(f"Gradient accum:      {gradient_accumulation_steps}")
        print(f"Effective batch:     {batch_size * gradient_accumulation_steps}")

        # The epoch loops move every batch to `device`; the model must be there too.
        self.model.to(device)

        # Create dataloaders
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=collate_fn
        )

        # Only trainable (e.g. LoRA) parameters need optimiser state.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = AdamW(
            trainable_params,
            lr=learning_rate,
            weight_decay=0.01
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=num_epochs,
            eta_min=1e-6
        )

        best_exact_match = 0.0
        patience_counter = 0

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")

            train_loss = train_single_epoch(
                model=self.model,
                train_loader=train_loader,
                gradient_accumulation_steps=gradient_accumulation_steps,
                optimizer=optimizer,
                epoch=epoch,
                num_epochs=num_epochs,
            )

            val_loss, val_exact_match, P, R, F1 = validate_(self.model, self.processor, val_loader)

            current_lr = optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_exact_match'].append(val_exact_match)
            self.history['learning_rates'].append(current_lr)
            self.history['bertscore_precision'].append(P)
            self.history['bertscore_recall'].append(R)
            self.history['bertscore_f1'].append(F1)

            print(f"Train Loss:      {train_loss:.4f}")
            print(f"Val Loss:        {val_loss:.4f}")
            print(f"Val Exact Match: {val_exact_match:.2f}%")
            # validate_ returns BERTScore as fractions; scale for the "%" display.
            print(f"Val BERTScore Precision: {P * 100:.2f}%")
            print(f"Val BERTScore Recall: {R * 100:.2f}%")
            print(f"Val BERTScore F1: {F1 * 100:.2f}%")
            print(f"Learning Rate:   {current_lr:.6f}")

            # Save best model
            if val_exact_match > best_exact_match:
                best_exact_match = val_exact_match
                patience_counter = 0
                self.save_checkpoint(epoch, val_exact_match, 'best_lora_model')
                print(f"New best model! Exact Match: {val_exact_match:.2f}%")
            else:
                patience_counter += 1
                print(f"No improvement ({patience_counter}/{patience})")

            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break

            scheduler.step()

        self.plot_training_curves()

        print("FINE-TUNING COMPLETE")
        print(f"Best Exact Match: {best_exact_match:.2f}%")

        return self.model

    def save_checkpoint(self, epoch, exact_match, prefix):
        """Save the model (LoRA weights when PEFT-wrapped) and the training history."""
        self.model.save_pretrained(self.results_dir / f'{prefix}_epoch_{epoch}')

        with open(self.results_dir / f'{prefix}_history.json', 'w') as f:
            json.dump(self.history, f, indent=2)

    def load_checkpoint(self, checkpoint_path):
        """Wrap ``self.model`` with the PEFT adapter saved at ``checkpoint_path``."""
        # NOTE(review): if self.model is already a PeftModel (e.g. after
        # setup_lora_model), this nests a second wrapper; confirm it is only
        # called on a plain base model.
        self.model = PeftModel.from_pretrained(self.model, checkpoint_path)
        print(f"Loaded model from {checkpoint_path}")

    def plot_training_curves(self):
        """Plot loss, exact match, learning rate and BERTScore curves to training_curves.png."""
        _, axes = plt.subplots(2, 3, figsize=(18, 10))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # Loss
        axes[0][0].plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0][0].plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        axes[0][0].set_xlabel('Epoch')
        axes[0][0].set_ylabel('Loss')
        axes[0][0].set_title('Training and Validation Loss')
        axes[0][0].legend()
        axes[0][0].grid(True, alpha=0.3)

        # Exact Match
        axes[0][1].plot(epochs, self.history['val_exact_match'], 'g-', linewidth=2)
        axes[0][1].set_xlabel('Epoch')
        axes[0][1].set_ylabel('Exact Match (%)')
        axes[0][1].set_title('Validation Exact Match')
        axes[0][1].grid(True, alpha=0.3)

        # Learning Rate
        axes[0][2].plot(epochs, self.history['learning_rates'], 'purple', linewidth=2)
        axes[0][2].set_xlabel('Epoch')
        axes[0][2].set_ylabel('Learning Rate')
        axes[0][2].set_title('Learning Rate Schedule')
        axes[0][2].set_yscale('log')
        axes[0][2].grid(True, alpha=0.3)

        # BERTScores: history holds fractions, so scale them to match the "(%)" labels.
        bertscore_panels = (
            (axes[1][0], 'bertscore_precision', 'Precision'),
            (axes[1][1], 'bertscore_recall', 'Recall'),
            (axes[1][2], 'bertscore_f1', 'F1'),
        )
        for ax, key, label in bertscore_panels:
            ax.plot(epochs, [value * 100 for value in self.history[key]], 'g-', linewidth=2)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(f'BERTScore {label} (%)')
            ax.set_title(f'Validation BERTScore {label}')
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.results_dir / 'training_curves.png', dpi=300)
        plt.close()

## Hyperparameter Tuning

In [None]:
class QuickLoRATuner:
    """Quick grid search over a handful of LoRA hyperparameter configurations.

    Every configuration is trained from a freshly loaded checkpoint and scored
    by its best validation exact match. ``run`` returns the winning
    configuration and writes a summary to ``short_history.json`` inside
    ``TUNING_RESULTS_PATH``.
    """

    # Predefined configurations based on common LoRA practice (alpha = 2 * r).
    DEFAULT_CONFIGS = (
        # Conservative
        {'r': 8, 'alpha': 16, 'lr': 5e-5, 'batch_size': 4, 'grad_accum': 8, 'name': 'conservative'},

        # Balanced
        {'r': 16, 'alpha': 32, 'lr': 5e-5, 'batch_size': 4, 'grad_accum': 8, 'name': 'balanced'},

        # Higher capacity
        {'r': 32, 'alpha': 64, 'lr': 3e-5, 'batch_size': 4, 'grad_accum': 8, 'name': 'high_capacity'},

        # Balanced with Bigger Batch
        {'r': 16, 'alpha': 32, 'lr': 5e-5, 'batch_size': 8, 'grad_accum': 4, 'name': 'balanced_with_bigger_batch'},

        # Faster learning
        {'r': 16, 'alpha': 32, 'lr': 1e-4, 'batch_size': 8, 'grad_accum': 4, 'name': 'fast_learning'},

        # Higher rank with higher learning
        {'r': 32, 'alpha': 64, 'lr': 1e-4, 'batch_size': 8, 'grad_accum': 4, 'name': 'high_rank_fast_learning'},
    )

    # Submodule names the LoRA adapters are attached to.
    TARGET_MODULES = ("query", "key", "value", "dense", "qkv", "projection")

    def __init__(self, processor, train_dataset, val_dataset):
        self.processor = processor
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        # Per-epoch metrics of every trial, keyed by config name.
        self.trial_histories = {}
        self.results_dir = Path(TUNING_RESULTS_PATH)
        # parents=True so a missing parent data directory does not abort tuning.
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def test_config(self, config, num_epochs=10, patience=3):
        """Train one LoRA configuration and return its best validation exact match.

        Args:
            config: dict with keys 'r', 'alpha', 'lr', 'batch_size',
                'grad_accum' and 'name'.
            num_epochs: maximum number of epochs for this trial (also the
                cosine schedule length).
            patience: epochs without exact-match improvement before stopping.

        Returns:
            The best validation exact match (percent) reached during the trial,
            or 0.0 if the trial raised an exception.
        """
        model = optimizer = scheduler = None
        try:
            history = {
                'train_loss': [],
                'val_loss': [],
                'val_exact_match': [],
                'learning_rates': [],
                'bertscore_precision': [],
                'bertscore_recall': [],
                'bertscore_f1': []
            }
            # Keep the (possibly partial) per-epoch metrics after the trial ends.
            self.trial_histories[config['name']] = history

            # 1. Load fresh model for this trial
            model, _ = load_model_with_checkpoint()

            # 2. Apply LoRA with the sampled config. alpha is passed through
            # unchanged so the trial uses the same scaling as the final
            # fine-tuning run, which builds its LoraConfig from the same dict.
            lora_config = LoraConfig(
                r=config['r'],
                lora_alpha=config['alpha'],
                target_modules=list(self.TARGET_MODULES),
                lora_dropout=0.1,
                bias="none"
            )

            model = setup_lora_model(model, lora_config).to(device)

            # 3. Create Data Loaders
            train_loader = DataLoader(
                self.train_dataset,
                batch_size=config['batch_size'],
                shuffle=True,
                num_workers=0,
                collate_fn=collate_fn
            )

            val_loader = DataLoader(
                self.val_dataset,
                batch_size=config['batch_size'],
                shuffle=False,
                num_workers=0,
                collate_fn=collate_fn
            )

            # 4. Optimizer and cosine LR schedule over the full epoch budget
            optimizer = AdamW(
                model.parameters(),
                lr=config['lr'],
                weight_decay=0.01
            )

            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=num_epochs,
                eta_min=1e-6
            )

            best_exact_match = 0.0
            patience_counter = 0

            for epoch in range(num_epochs):
                print(f"Epoch [{epoch+1}/{num_epochs}]")

                # Train
                train_loss = train_single_epoch(
                    model=model,
                    train_loader=train_loader,
                    gradient_accumulation_steps=config['grad_accum'],
                    optimizer=optimizer,
                    epoch=epoch,
                    num_epochs=num_epochs,
                )

                # Validate
                val_loss, val_exact_match, P, R, F1 = validate_(model, self.processor, val_loader)

                # Learning rate used for this epoch
                current_lr = optimizer.param_groups[0]['lr']

                # Update history
                history['train_loss'].append(train_loss)
                history['val_loss'].append(val_loss)
                history['val_exact_match'].append(val_exact_match)
                history['learning_rates'].append(current_lr)
                history['bertscore_precision'].append(P)
                history['bertscore_recall'].append(R)
                history['bertscore_f1'].append(F1)

                print(f"Train Loss:      {train_loss:.4f}")
                print(f"Val Loss:        {val_loss:.4f}")
                print(f"Val Exact Match: {val_exact_match:.2f}%")
                print(f"Val BERTScore Precision: {P:.2f}%")
                print(f"Val BERTScore Recall: {R:.2f}%")
                print(f"Val BERTScore F1: {F1:.2f}%")
                print(f"Learning Rate:   {current_lr:.6f}")

                if val_exact_match > best_exact_match:
                    best_exact_match = val_exact_match
                    patience_counter = 0
                else:
                    patience_counter += 1
                    print(f"No improvement ({patience_counter}/{patience})")

                # Early stopping
                if patience_counter >= patience:
                    print(f"\nEarly stopping at epoch {epoch+1}")
                    break

                scheduler.step()

            print(f"CONFIG {config['name'].upper()} TESTING COMPLETE")
            print(f"Best Exact Match: {best_exact_match:.2f}%")

            return best_exact_match
        except Exception as e:
            print(f"Trial failed with error: {e}")
            return 0.0
        finally:
            # Runs on success and failure alike. The optimizer (and the
            # scheduler, through it) also reference the model's parameters and
            # optimizer state, so drop every reference before releasing cached
            # CUDA memory.
            del model, optimizer, scheduler
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def run(self, configs=None):
        """Evaluate each configuration and save a summary of the results.

        Args:
            configs: optional sequence of config dicts (same keys as
                ``DEFAULT_CONFIGS``); defaults to ``DEFAULT_CONFIGS``.

        Returns:
            ``(best_config, best_score)``. ``best_config`` is None when no
            trial scored above 0% (for example when every trial failed).
        """
        source = self.DEFAULT_CONFIGS if configs is None else configs
        # Copy so that mutating the returned best_config cannot alter the defaults.
        configs = [dict(config) for config in source]

        print(f"QUICK LORA TUNING ({len(configs)} CONFIGS)")

        results = []
        best_score = 0.0
        best_config = None

        for i, config in enumerate(configs, start=1):
            print(f"\n[{i}/{len(configs)}] Testing: {config['name']}")
            print(json.dumps(config, indent=2))

            score = self.test_config(config)

            results.append({
                'config': config,
                'exact_match': f"{score:.2f}"
            })

            if score > best_score:
                best_score = score
                best_config = config

            print(f"  Result: {score:.2f}% exact match")

        # Save results
        with open(self.results_dir / 'short_history.json', 'w') as f:
            json.dump({
                'best_config': best_config,
                'best_score': best_score,
                'all_results': results
            }, f, indent=2)

        print("QUICK TUNING COMPLETE")
        if best_config is None:
            # Guard: indexing None below would raise TypeError.
            print("No configuration scored above 0% exact match; check the trial logs above.")
        else:
            print(f"Best: {best_config['name']} with {best_score:.2f}%")
            print(f"Config: r={best_config['r']}, alpha={best_config['alpha']}, lr={best_config['lr']}")

        return best_config, best_score

In [None]:
# Wrap the train and validation splits with the BioMedBLIP dataset class for tuning.
train_dataset_blip = SLAKEDatasetBioMedBLIP(dataset['train'], processor)
val_dataset_blip = SLAKEDatasetBioMedBLIP(dataset['validation'], processor)

In [None]:
# Tuner that trains and scores each candidate LoRA configuration on these splits.
tuner = QuickLoRATuner(
    processor=processor,
    train_dataset=train_dataset_blip,
    val_dataset=val_dataset_blip
)

In [None]:
# Evaluate every predefined LoRA configuration; returns the winning config dict
# and its best validation exact match (%).
best_params, best_score = tuner.run()

# Display the selected configuration in the notebook output.
best_params

In [None]:
# Load best params from short_history.json if needed
# with open(os.path.join(TUNING_RESULTS_PATH, 'short_history.json'), 'r') as f:
#     history = json.load(f)
# best_params = history['best_config']

# best_params

## Fine-tune with the best hyperparameters

In [None]:
# Fresh base model plus LoRA adapters built from the configuration chosen by the tuner.
model, _ = load_model_with_checkpoint()
processor = BlipProcessor.from_pretrained(
    BASE_MODEL_NAME,
    use_fast=True
)

# NOTE(review): keep this LoraConfig consistent with how QuickLoRATuner.test_config
# builds the one for its trials (alpha handling, target modules, dropout);
# otherwise the configuration trained here is not the one that was tuned.
lora_config = LoraConfig(
    r=best_params['r'],
    lora_alpha=best_params['alpha'],
    target_modules=[
        "query",
        "key",
        "value",
        "dense",
        "qkv",
        "projection",
    ],
    lora_dropout=0.1,
    bias="none"
)

model = setup_lora_model(model, lora_config).to(device)

In [None]:
# Re-wrap the splits so they use the processor reloaded in the previous cell.
train_dataset_blip = SLAKEDatasetBioMedBLIP(dataset['train'], processor)
val_dataset_blip = SLAKEDatasetBioMedBLIP(dataset['validation'], processor)

In [None]:
# Trainer for the final fine-tuning run of the LoRA model built above.
trainer = BioMedBLIPTrainer(model, processor, train_dataset_blip, val_dataset_blip)

In [None]:
# Final fine-tuning (20 epochs requested) with the tuned batch size, learning rate
# and gradient accumulation; the returned model replaces `model`.
model = trainer.train(
    num_epochs=20, 
    batch_size=best_params['batch_size'], 
    learning_rate=best_params['lr'], 
    gradient_accumulation_steps=best_params['grad_accum']
)

In [None]:
# Load the best fine-tuned LoRA adapter on top of a fresh base model for evaluation.
model, _ = load_model_with_checkpoint()
processor = BlipProcessor.from_pretrained(
    BASE_MODEL_NAME,
    use_fast=True
)

# Epoch whose adapter was saved as the best during fine-tuning. In the original
# run this was epoch 9; update it after retraining.
best_lora_epoch = 9
best_adapter_dir = os.path.join(FINAL_MODEL_PATH, f'best_lora_model_epoch_{best_lora_epoch}')
model = PeftModel.from_pretrained(model, best_adapter_dir).to(device)

In [None]:
# Test split loader: dataset order (no shuffling), tuned batch size.
test_dataset_blip = SLAKEDatasetBioMedBLIP(dataset['test'], processor)
test_loader = DataLoader(
    test_dataset_blip,
    batch_size=best_params['batch_size'],
    shuffle=False,
    num_workers=0,
    # Pinned host memory only helps host-to-GPU copies; on CPU-only runs
    # PyTorch just warns and ignores it, so tie it to the device in use.
    pin_memory=(device == "cuda"),
    collate_fn=collate_fn
)

In [None]:
# Generate test-set predictions; the call returns predicted texts, target texts and question types.
all_pred_texts, all_target_texts, all_question_types = calculate_all_metrics(model, processor, test_loader, model_state="Final Finetuning")

In [None]:
# Compute the evaluation metrics for the test-set predictions.
final_finetune_metrics = evaluate_all_metrics(all_pred_texts, all_target_texts, all_question_types, model_state="Final Finetuning")

## Error Analysis

In [None]:
# Question text and image file name of every non-None test record, in data order.
valid_items = [item for item in test_dataset_blip.data if item is not None]
questions = [item['question'] for item in valid_items]
images = [item['img_name'] for item in valid_items]

# Show up to five questions as a sanity check
print("Sample Questions:")
for question in questions[:5]:
    print(f"  Question: {question}")

In [None]:
def show_vqa_case(case):
    """Display a VQA case: its image, then the question, prediction and answer.

    Args:
        case: dict with an 'image' file name (relative to IMAGE_PATH) plus
            'question', 'prediction' and 'ground_truth' entries.
    """
    # The notebook imports both PIL's and IPython's ``Image`` at module level,
    # so the global name depends on import order. Bind IPython's explicitly.
    from IPython.display import Image as IPythonImage, display as ipython_display

    image_path = os.path.join(IMAGE_PATH, case['image'])
    ipython_display(IPythonImage(filename=image_path, width=400))

    # Print text on subsequent lines
    print(f"Question: {case['question']}")
    print(f"Predicted Answer: {case['prediction']}")
    print(f"Actual Answer: {case['ground_truth']}")

In [None]:
# Every test case whose prediction differs from the reference after lower-casing
# and trimming, together with its question, image and question type.
wrong_predictions = [
    {
        'question': questions[idx],
        'ground_truth': all_target_texts[idx],
        'prediction': prediction,
        'image': images[idx],
        'question_type': all_question_types[idx],
    }
    for idx, prediction in enumerate(all_pred_texts)
    if prediction.lower().strip() != all_target_texts[idx].lower().strip()
]

### Directional Error Analysis

In [None]:
class DirectionalErrorAnalyzer:
    """Categorize VQA answers by the directional words they contain.

    Ground truths and predictions are scanned for lateral (left/right),
    vertical (upper/lower/top/bottom) and medial (central/center/middle)
    terms. Each prediction is then filed as non-directional, correct, or one
    of the directional error types in ``ERROR_TYPES``.
    """

    # Error buckets, in the order they are reported and saved.
    ERROR_TYPES = (
        'left_right_confusion',
        'upper_lower_confusion',
        'partial_match',
        'extra_medial_direction',
        'missing_medial_direction',
        'missing_lateral_direction',
        'extra_lateral_direction',
        'missing_vertical_direction',
        'extra_vertical_direction',
        'missing_direction',
        'extra_direction',
    )

    # (extra, missing, confusion) labels for the ordered term groups.
    _MISMATCH_LABELS = {
        'lateral': ('extra_lateral_direction', 'missing_lateral_direction', 'left_right_confusion'),
        'vertical': ('extra_vertical_direction', 'missing_vertical_direction', 'upper_lower_confusion'),
    }

    def __init__(self):
        self.directional_terms = {
            'lateral': ['left', 'right'],
            'vertical': ['upper', 'lower', 'top', 'bottom'],
            'medial': ['central', 'center', 'middle']
        }

    def extract_directional_info(self, text):
        """Return the directional terms found in *text*, grouped by type.

        Matching is case-insensitive and on whole words. Within each group,
        terms are listed in ``self.directional_terms`` order (not in the order
        they appear in the text), each at most once.
        """
        text_lower = text.lower().strip()

        directions = {
            'lateral': [],
            'vertical': [],
            'medial': []
        }

        for direction_type, terms in self.directional_terms.items():
            directions[direction_type].extend(
                term for term in terms
                if re.search(r'\b' + re.escape(term) + r'\b', text_lower)
            )

        return directions

    def compare_directions(self, gt_directions, pred_directions):
        """Return the first directional error label, or None if none applies.

        Groups are checked in the order lateral, vertical, medial:
        - lateral/vertical: more terms in the prediction -> 'extra_*', fewer
          -> 'missing_*', same count but different terms -> the group's
          confusion label.
        - medial: only presence versus absence is compared.
        """
        for direction_type in ('lateral', 'vertical', 'medial'):
            gt_terms = gt_directions[direction_type]
            pred_terms = pred_directions[direction_type]

            if direction_type == 'medial':
                if pred_terms and not gt_terms:
                    return 'extra_medial_direction'
                if gt_terms and not pred_terms:
                    return 'missing_medial_direction'
                continue

            extra_label, missing_label, confusion_label = self._MISMATCH_LABELS[direction_type]
            if len(gt_terms) < len(pred_terms):
                return extra_label
            if len(gt_terms) > len(pred_terms):
                return missing_label
            if any(g != p for g, p in zip(gt_terms, pred_terms)):
                return confusion_label

        return None

    def has_directional_info(self, text):
        """Return True if *text* contains at least one directional term."""
        directions = self.extract_directional_info(text)
        return any(len(terms) > 0 for terms in directions.values())

    def analyze_predictions(self, predictions, ground_truths, questions, images):
        """Bucket each (prediction, ground truth) pair.

        Args:
            predictions, ground_truths, questions, images: parallel sequences;
                iteration stops at the shortest one.

        Returns:
            dict with 'errors' (one list per entry of ``ERROR_TYPES``),
            'correct' and 'non_directional' lists of example dicts.
        """
        results = {
            'errors': {error_type: [] for error_type in self.ERROR_TYPES},
            'correct': [],
            'non_directional': []
        }

        for pred, gt, question, img in zip(predictions, ground_truths, questions, images):
            record = {
                'image': img,
                'question': question,
                'ground_truth': gt,
                'prediction': pred
            }
            is_exact = pred.lower().strip() == gt.lower().strip()

            # Extract once per answer and reuse for both presence tests and comparison.
            gt_directions = self.extract_directional_info(gt)
            pred_directions = self.extract_directional_info(pred)
            has_gt_direction = any(gt_directions.values())
            has_pred_direction = any(pred_directions.values())

            if not has_gt_direction and not has_pred_direction:
                results['non_directional'].append({**record, 'correct': is_exact})
                continue

            if not has_gt_direction:
                # The prediction introduces a direction the ground truth lacks.
                results['errors']['extra_direction'].append(record)
                continue

            if is_exact:
                results['correct'].append(record)
                continue

            if not has_pred_direction:
                results['errors']['missing_direction'].append({
                    **record,
                    'gt_directions': gt_directions
                })
                continue

            directional_record = {
                **record,
                'gt_directions': gt_directions,
                'pred_directions': pred_directions
            }
            error_type = self.compare_directions(gt_directions, pred_directions)

            if error_type:
                results['errors'][error_type].append(directional_record)
            elif gt_directions != pred_directions:
                # Same group counts but different wording, e.g. 'center' vs 'middle'.
                results['errors']['partial_match'].append(directional_record)
            else:
                results['correct'].append({
                    **record,
                    'note': 'Directionally correct but overall wrong'
                })

        return results

    def calculate_statistics(self, results):
        """Summarize *results*; returns None when there is no directional case."""
        correct = len(results['correct'])
        total_errors = sum(len(errors) for errors in results['errors'].values())
        total_directional = correct + total_errors

        if total_directional == 0:
            return None

        return {
            'total_directional_questions': total_directional,
            'correct': correct,
            'directional_accuracy': 100 * correct / total_directional,
            'error_breakdown': {
                error_type: {
                    'count': len(errors),
                    'percentage': 100 * len(errors) / total_directional
                }
                for error_type, errors in results['errors'].items()
            },
            'total_errors': total_errors,
            'error_rate': 100 * total_errors / total_directional,
        }

    def print_analysis(self, results, stats):
        """Print the summary, the error breakdown and up to 3 examples per error type."""
        print("DIRECTIONAL ERROR ANALYSIS")

        if stats is None:
            print("No directional questions found in dataset.")
            return

        print(f"\nTotal questions with directional info: {stats['total_directional_questions']}")
        print(f"Correct: {stats['correct']} ({stats['directional_accuracy']:.2f}%)")
        print(f"Errors: {stats['total_errors']} ({stats['error_rate']:.2f}%)")

        print("Error Breakdown:")

        sorted_errors = sorted(
            stats['error_breakdown'].items(),
            key=lambda x: x[1]['count'],
            reverse=True
        )

        for error_type, error_stats in sorted_errors:
            if error_stats['count'] > 0:
                error_name = error_type.replace('_', ' ').title()
                print(f"  {error_name:30s}: {error_stats['count']:3d} ({error_stats['percentage']:5.2f}%)")

        for error_type, errors in results['errors'].items():
            if len(errors) > 0:
                print(f"\n {error_type.replace('_', ' ').title()} Examples:")
                for i, example in enumerate(errors[:3]):
                    print(f"\n  Example {i+1}:")
                    print(f"    Question:     {example['question']}")
                    print(f"    Ground Truth: {example['ground_truth']}")
                    print(f"    Prediction:   {example['prediction']}")

                    if 'gt_directions' in example and 'pred_directions' in example:
                        print(f"    GT Directions: {example['gt_directions']}")
                        print(f"    Pred Directions: {example['pred_directions']}")

    def save_detailed_report(self, results, stats, model_name, output_dir=None):
        """Write the statistics plus up to 10 examples per bucket to a JSON file.

        Args:
            results: output of ``analyze_predictions``.
            stats: output of ``calculate_statistics`` (may be None).
            model_name: used in the report and, lower-cased, in the file name.
            output_dir: destination directory; defaults to FINAL_MODEL_PATH
                (resolved at call time). Missing parent directories are created.
        """
        output_dir = Path(FINAL_MODEL_PATH if output_dir is None else output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        report = {
            'model': model_name,
            'statistics': stats,
            'examples': {
                # Save first 10 of each error type
                error_type: errors[:10]
                for error_type, errors in results['errors'].items()
            }
        }

        # Add correct examples
        report['examples']['correct'] = results['correct'][:10]

        output_file = output_dir / f'{model_name.lower()}_directional_analysis.json'
        with open(output_file, 'w') as f:
            json.dump(report, f, indent=2)

        print(f"Detailed report saved to {output_file}")

In [None]:
def analyze_biomedblip_directional_errors(predictions, ground_truths, questions, images):
    """Run the directional error analysis for BioMedBLIP predictions.

    Prints the analysis, writes the detailed JSON report under the name
    'BioMedBLIP', and returns the ``(results, stats)`` pair.
    """
    analyzer = DirectionalErrorAnalyzer()

    outcome = analyzer.analyze_predictions(predictions, ground_truths, questions, images)
    summary = analyzer.calculate_statistics(outcome)

    analyzer.print_analysis(outcome, summary)
    analyzer.save_detailed_report(outcome, summary, 'BioMedBLIP')

    return outcome, summary

In [None]:
# Directional error analysis of the test-set predictions; also prints a summary
# and saves a JSON report.
results, stats = analyze_biomedblip_directional_errors(
    all_pred_texts,
    all_target_texts,
    questions,
    images
)

In [None]:
# Case study visualization: left/right confusion. Indices are hard-coded and
# refer to this run's results; they point elsewhere if the model or data change.
case = results['errors']['left_right_confusion'][3]
show_vqa_case(case)

In [None]:
# Another left/right confusion case (hard-coded index).
case = results['errors']['left_right_confusion'][7]
show_vqa_case(case)

In [None]:
# Case study: upper/lower confusion (hard-coded index).
case = results['errors']['upper_lower_confusion'][3]
show_vqa_case(case)

In [None]:
# Another upper/lower confusion case (hard-coded index).
case = results['errors']['upper_lower_confusion'][4]
show_vqa_case(case)

### Similar but not exact answers

In [None]:
# Wrong predictions whose ground truth is one of the listed modality-style answers
# and which differ from it only by whitespace (e.g. "x - ray" predicted for "x-ray").
modality_answers = {'ct', 'x-ray', 'mri', 't2'}
total_modality_confusion = 0
for i, wrong in enumerate(wrong_predictions):
    gt_lower = wrong['ground_truth'].lower().strip()
    if gt_lower not in modality_answers:
        continue
    # Lower-case the prediction like the ground truth (matching how wrong
    # predictions were identified) and drop all spaces so that tokenizer
    # spacing artifacts compare equal.
    clean_prediction = wrong['prediction'].lower().replace(' ', '')
    if clean_prediction == gt_lower:
        print(f"\nModality Confusion Case: {i}")
        print(f"  Question: {wrong['question']}")
        print(f"  Predicted Answer: {wrong['prediction']}")
        print(f"  Actual Answer: {wrong['ground_truth']}")

        total_modality_confusion += 1

print(f"\nTotal Modality Confusion Cases: {total_modality_confusion}")

In [None]:
# Inspect one wrong prediction by hard-coded index into wrong_predictions.
case = wrong_predictions[16]
show_vqa_case(case)

In [None]:
# Inspect another wrong prediction (hard-coded index into wrong_predictions).
case = wrong_predictions[127]
show_vqa_case(case)

In [None]:
# For testing: rewrite every spelling of "x-ray" (any case, optional spaces around
# the hyphen, e.g. "x - ray") to the canonical "x-ray", then re-evaluate.
xray_spelling = re.compile(r'x\s*-\s*ray', flags=re.IGNORECASE)
corrected_predictions = [xray_spelling.sub('x-ray', pred) for pred in all_pred_texts]

final_finetune_metrics_corrected = evaluate_all_metrics(corrected_predictions, all_target_texts, all_question_types, model_state="Final Finetuning - Corrected")

In [None]:
# Keep only the wrong predictions whose question type is 'OPEN'.
open_ended_wrongs = [
    wrong for wrong in wrong_predictions
    if wrong['question_type'] == 'OPEN'
]

In [None]:
# Wrong open-ended cases whose ground-truth answer has more than 3 words.
long_open_ended_wrongs = [
    wrong for wrong in open_ended_wrongs
    if len(wrong['ground_truth'].split()) > 3
]

In [None]:
# Inspect a long-answer open-ended error (hard-coded index).
case = long_open_ended_wrongs[4]
show_vqa_case(case)

In [None]:
# Inspect a long-answer open-ended error (hard-coded index, counted from the end).
case = long_open_ended_wrongs[-4]
show_vqa_case(case)

In [None]:
# Inspect a long-answer open-ended error (hard-coded index, counted from the end).
case = long_open_ended_wrongs[-3]
show_vqa_case(case)

In [None]:
# Inspect a long-answer open-ended error (hard-coded index, counted from the end).
case = long_open_ended_wrongs[-6]
show_vqa_case(case)

In [None]:
# Inspect a long-answer open-ended error (hard-coded index).
case = long_open_ended_wrongs[12]
show_vqa_case(case)

In [None]:
# Inspect a long-answer open-ended error (hard-coded index).
case = long_open_ended_wrongs[14]
show_vqa_case(case)