In [None]:
# !fusermount -u /content/drive  # Unmount if already mounted
# !rm -rf /content/drive         # Remove any leftover files

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


###Dependancies

In [None]:
!pip install torch torchvision transformers
!pip install datasets accelerate wandb
!pip install opencv-python pillow
!pip install scikit-learn nltk rouge-score
!pip install bert-score
!pip install rouge_score
!pip install medcat spacy scispacy




Collecting git+https://github.com/tylin/coco-caption
  Cloning https://github.com/tylin/coco-caption to /tmp/pip-req-build-_h1_r7g8
  Running command git clone --filter=blob:none --quiet https://github.com/tylin/coco-caption /tmp/pip-req-build-_h1_r7g8
  Resolved https://github.com/tylin/coco-caption to commit 3a9afb2682141a03e1cdc02b0df6770d2c884f6f
[31mERROR: git+https://github.com/tylin/coco-caption does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0m

# Medical Image Captioning with Concept-Aware BLIP Fine-Tuning

This notebook fine-tunes the `Salesforce/blip-image-captioning-base` model to generate medical captions using radiology images and concept prompts.

##  Dataset Structure
- **Images**: JPG files named as `{ID}.jpg`
- **Captions**: `train_captions.csv` with fields: `ID`, `Caption`
- **Concepts**: `train_concepts.csv` with `ID`, `CUIs`
- **Concept Mapping**: `cui_names.csv` maps `CUI` → readable name

##  Data Preprocessing
- Cleans abbreviations (`CT`, `MRI`, etc.)
- Removes noisy/empty/too short/long captions
- Builds anatomical & modality prompt vocab
- Generates 2 types of prompts:
  - **Smart prompt**: based on anatomy + modality
  - **Concept-aware prompt**: directly includes top UMLS concepts

##  Model: BLIP (Base)
- `Salesforce/blip-image-captioning-base` (ViT-GPT2)
- Inputs: image + optional prompt
- Outputs: caption
- Tokenizer: BLIP processor (auto-handled)

##  Fine-Tuning Techniques
- **Prompt-based fine-tuning**: the prompt is prepended to the caption
- **Label masking**: model learns only to predict the caption part (not the prompt)
- **Custom scoring**: promotes captions that mention concepts & avoid repetition
- **Dynamic prompts**: tries 3 generation types: no prompt, smart prompt, and concept-aware

##  Training Details
- Optimizer: `AdamW` with `lr=2e-5`, `weight_decay=0.01`
- Scheduler: `get_linear_schedule_with_warmup` (10% warmup steps)
- Epochs: 5 (adjustable)
- Batch Size: 2 (lowered for memory constraints)
- Loss: CrossEntropy via `labels=inputs['input_ids']` (with masked prompt tokens)
- Gradient clipping: max norm = 1.0
- Train/Val/Test split: 80/10/10% from 10,000 samples
- Early stopping: model is saved if validation loss improves

##  Generation Settings
- `num_beams=4`, `max_new_tokens=40`, `repetition_penalty=2.0`
- `no_repeat_ngram_size=2`, `length_penalty=0.9`
- Optionally adds missing concepts post-generation
- Removes repeated words, excessive punctuation

##  Output
- Trained model saved at: `/content/.../best_medical_blip`
- Processor saved alongside the model
- Example usage:

##Results
- BLEU-4 Average Score: 0.0170
-ROUGE-L Average Score: 0.1995
```python
model = BlipForConditionalGeneration.from_pretrained("/content/.../best_medical_blip")
processor = BlipProcessor.from_pretrained("/content/.../best_medical_blip")


In [None]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.4-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.4


In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    get_linear_schedule_with_warmup, set_seed
)
from sklearn.model_selection import train_test_split
import json
import re
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
set_seed(42)

class MedicalCaptionDataset(Dataset):
    """Optimized dataset for medical image captioning"""

    def __init__(self, image_dir: str, caption_csv: str, concept_csv: str,
                 mapping_csv: str, processor, split: str = 'train',
                 max_caption_length: int = 100):

        self.image_dir = image_dir
        self.processor = processor
        self.max_caption_length = max_caption_length
        self.split = split

        # Load and merge data
        self._load_and_process_data(caption_csv, concept_csv, mapping_csv)

        # Medical vocabulary for better prompting
        self.modality_terms = {
            'radiograph': ['radiograph', 'x-ray', 'xray'],
            'ct': ['ct', 'computed tomography', 'cat scan'],
            'mri': ['mri', 'magnetic resonance'],
            'ultrasound': ['ultrasound', 'sonogram', 'echo'],
            'mammography': ['mammography', 'mammogram']
        }

        self.anatomy_terms = {
            'chest': ['chest', 'thorax', 'lung', 'heart', 'rib'],
            'abdomen': ['abdomen', 'stomach', 'liver', 'kidney'],
            'spine': ['spine', 'vertebra', 'spinal', 'lumbar', 'cervical'],
            'brain': ['brain', 'cerebral', 'cranial', 'head'],
            'pelvis': ['pelvis', 'hip', 'pelvic'],
            'extremity': ['arm', 'leg', 'hand', 'foot', 'bone']
        }

    def _load_and_process_data(self, caption_csv: str, concept_csv: str, mapping_csv: str):
        """Load and preprocess all data files"""
        print("Loading data files...")

        # Load files
        captions_df = pd.read_csv(caption_csv)
        concepts_df = pd.read_csv(concept_csv)
        mapping_df = pd.read_csv(mapping_csv)

        # Standardize column names
        captions_df.columns = [col.strip() for col in captions_df.columns]
        concepts_df.columns = [col.strip() for col in concepts_df.columns]
        mapping_df.columns = [col.strip() for col in mapping_df.columns]

        # Rename to standard names
        if 'ID' in captions_df.columns and 'Caption' in captions_df.columns:
            captions_df = captions_df.rename(columns={"ID": "id", "Caption": "caption"})
        if 'ID' in concepts_df.columns and 'CUIs' in concepts_df.columns:
            concepts_df = concepts_df.rename(columns={"ID": "id", "CUIs": "concepts"})
        if 'CUI' in mapping_df.columns and 'Name' in mapping_df.columns:
            mapping_df = mapping_df.rename(columns={"CUI": "cui", "Name": "name"})

        # Build concept mapping
        self.concept_map = dict(zip(mapping_df["cui"], mapping_df["name"]))

        # Merge data
        merged_df = pd.merge(captions_df, concepts_df, on="id", how="inner")

        merged_df["image_path"] = merged_df["id"].apply(
            lambda x: os.path.join(self.image_dir, f"{x}.jpg")
        )


        existing_mask = merged_df["image_path"].apply(os.path.exists)
        merged_df = merged_df[existing_mask].reset_index(drop=True)

        merged_df["caption"] = merged_df["caption"].apply(self._clean_caption)


        merged_df = merged_df[merged_df["caption"].str.strip() != ""].reset_index(drop=True)

        word_counts = merged_df["caption"].apply(lambda x: len(x.split()))
        merged_df = merged_df[
            (word_counts >= 5) & (word_counts <= self.max_caption_length)
        ].reset_index(drop=True)

        self.df = merged_df
        print(f"Loaded {len(self.df)} valid samples")

    def _clean_caption(self, caption: str) -> str:
        """Clean and normalize captions"""
        if pd.isna(caption) or caption is None:
            return ""


        caption = re.sub(r'\s+', ' ', str(caption).strip())

        abbreviations = {
            r'\bCT\b': 'CT',
            r'\bMRI\b': 'MRI',
            r'\bXR\b': 'X-ray',
            r'\bmm\b': 'millimeters',
            r'\bcm\b': 'centimeters'
        }

        for abbrev, full in abbreviations.items():
            caption = re.sub(abbrev, full, caption, flags=re.IGNORECASE)

        return caption

    def _get_concept_names(self, concept_string: str) -> List[str]:
        """Convert CUI codes to readable concept names"""
        if pd.isna(concept_string) or concept_string == '':
            return []

        cui_codes = str(concept_string).split(';')
        concept_names = []

        for cui in cui_codes:
            cui = cui.strip()
            if cui in self.concept_map:
                name = self.concept_map[cui].strip()
                if name and name not in concept_names:
                    concept_names.append(name)

        return concept_names

    def _create_smart_prompt(self, concept_names: List[str]) -> str:
        """Create intelligent prompts based on medical concepts"""
        if not concept_names:
            return "medical image:"

        modality = "medical image"
        concept_text = ' '.join(concept_names).lower()

        for mod, terms in self.modality_terms.items():
            if any(term in concept_text for term in terms):
                modality = mod
                break

        anatomy = ""
        for anat, terms in self.anatomy_terms.items():
            if any(term in concept_text for term in terms):
                anatomy = anat
                break

        # Create simple, effective prompts
        if anatomy and modality != "medical image":
            return f"{modality} of {anatomy}:"
        elif modality != "medical image":
            return f"{modality}:"
        elif anatomy:
            return f"{anatomy} image:"
        else:
            return "medical image:"

    def _create_concept_aware_prompt(self, concept_names: List[str]) -> str:
        """Create prompts that incorporate medical concepts directly"""
        if not concept_names:
            return "medical image:"

        clean_concepts = []
        for concept in concept_names[:3]:
            if concept and len(concept.strip()) > 2:
                clean_concepts.append(concept.strip().lower())

        if not clean_concepts:
            return "medical image:"

        concept_text = ', '.join(clean_concepts)

        modality_found = False
        anatomy_found = False

        for mod, terms in self.modality_terms.items():
            if any(term in concept_text for term in terms):
                modality_found = mod
                break

        for anat, terms in self.anatomy_terms.items():
            if any(term in concept_text for term in terms):
                anatomy_found = anat
                break

        # Strategy 1: Include concepts directly in prompt
        if modality_found and anatomy_found:
            return f"{modality_found} of {anatomy_found} showing {concept_text}:"
        elif modality_found:
            return f"{modality_found} showing {concept_text}:"
        elif anatomy_found:
            return f"{anatomy_found} image showing {concept_text}:"
        else:
            return f"medical image showing {concept_text}:"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load image
        try:
            image = Image.open(row["image_path"]).convert("RGB")
        except Exception as e:
            print(f"Error loading image {row['image_path']}: {e}")
            # Return a blank image as fallback
            image = Image.new('RGB', (224, 224), color='white')


        concept_names = self._get_concept_names(row["concepts"])

        basic_prompt = self._create_smart_prompt(concept_names)
        concept_prompt = self._create_concept_aware_prompt(concept_names)

        # Get caption
        caption = str(row["caption"])

        return {
            'image': image,
            'prompt': basic_prompt,
            'concept_prompt': concept_prompt,
            'caption': caption,
            'concepts': concept_names,
            'image_id': row["id"]
        }

def collate_fn(batch):
    """Custom collate function for DataLoader"""
    return {
        'images': [item['image'] for item in batch],
        'prompts': [item['prompt'] for item in batch],
        'concept_prompts': [item['concept_prompt'] for item in batch],
        'captions': [item['caption'] for item in batch],
        'concepts': [item['concepts'] for item in batch],
        'image_ids': [item['image_id'] for item in batch]
    }

class MedicalCaptionTrainer:
    """Main trainer class for medical image captioning - COMPLETELY FIXED VERSION"""

    def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load model and processor
        print(f"Loading model: {model_name}")
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForConditionalGeneration.from_pretrained(model_name)
        self.model.to(self.device)

        print(f"Model loaded on device: {self.device}")

    def generate_caption(self, image, prompt: str = None, concepts: List[str] = None, **generation_kwargs):
        """Main caption generation method - ROBUST FIX for duplicate parameter error"""
        self.model.eval()

        if concepts:

            dataset_helper = type('Helper', (), {})()
            dataset_helper.modality_terms = {
                'radiograph': ['radiograph', 'x-ray', 'xray'],
                'ct': ['ct', 'computed tomography', 'cat scan'],
                'mri': ['mri', 'magnetic resonance'],
                'ultrasound': ['ultrasound', 'sonogram', 'echo'],
                'mammography': ['mammography', 'mammogram']
            }
            dataset_helper.anatomy_terms = {
                'chest': ['chest', 'thorax', 'lung', 'heart', 'rib'],
                'abdomen': ['abdomen', 'stomach', 'liver', 'kidney'],
                'spine': ['spine', 'vertebra', 'spinal', 'lumbar', 'cervical'],
                'brain': ['brain', 'cerebral', 'cranial', 'head'],
                'pelvis': ['pelvis', 'hip', 'pelvic'],
                'extremity': ['arm', 'leg', 'hand', 'foot', 'bone']
            }

            def create_smart_prompt(concept_names):
                if not concept_names:
                    return "medical image:"

                modality = "medical image"
                concept_text = ' '.join(concept_names).lower()

                for mod, terms in dataset_helper.modality_terms.items():
                    if any(term in concept_text for term in terms):
                        modality = mod
                        break

                anatomy = ""
                for anat, terms in dataset_helper.anatomy_terms.items():
                    if any(term in concept_text for term in terms):
                        anatomy = anat
                        break

                if anatomy and modality != "medical image":
                    return f"{modality} of {anatomy}:"
                elif modality != "medical image":
                    return f"{modality}:"
                elif anatomy:
                    return f"{anatomy} image:"
                else:
                    return "medical image:"

            concept_prompt = create_smart_prompt(concepts)
        else:
            concept_prompt = prompt or "medical image:"

        approaches = [
            {"prompt": "", "name": "no_prompt"},
            {"prompt": concept_prompt, "name": "concept_prompt"}
        ]

        if prompt and prompt != concept_prompt:
            approaches.insert(1, {"prompt": prompt, "name": "custom_prompt"})

        best_caption = ""
        best_score = -1

        for approach in approaches:
            try:
                current_prompt = approach["prompt"]

                if current_prompt == "":
                    # Image-only processing
                    processor_outputs = self.processor(images=image, return_tensors="pt")
                else:
                    # Image + text processing
                    processor_outputs = self.processor(images=image, text=current_prompt, return_tensors="pt")

                model_inputs = {}


                valid_model_input_keys = {'pixel_values', 'input_ids', 'attention_mask'}

                for key, value in processor_outputs.items():
                    if key in valid_model_input_keys:
                        model_inputs[key] = value.to(self.device)

                generation_params = {
                    'max_new_tokens': 45,
                    'min_new_tokens': 8,
                    'num_beams': 4,
                    'early_stopping': True,
                    'no_repeat_ngram_size': 2,
                    'repetition_penalty': 2.0,
                    'length_penalty': 0.9,
                    'do_sample': False,
                }


                tokenizer = self.processor.tokenizer

                self.model.generation_config.pad_token_id = tokenizer.pad_token_id


                for key, value in generation_kwargs.items():
                    if key not in ['pad_token_id', 'eos_token_id', 'bos_token_id']:
                        generation_params[key] = value

                with torch.no_grad():
                    outputs = self.model.generate(**model_inputs, **generation_params)

                # Decode and clean
                caption = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
                caption = self._clean_generated_caption(caption, current_prompt)

                # Score the caption
                if concepts:
                    score = self._score_caption_with_concepts(caption, concepts)
                else:
                    score = self._score_caption(caption)

                if score > best_score:
                    best_score = score
                    best_caption = caption

            except Exception as e:
                print(f"Generation error with {approach['name']}: {e}")
                continue

        # Post-process to enhance concept coverage
        if concepts and best_caption:
            best_caption = self._enhance_caption_with_concepts(best_caption, concepts)

        return best_caption if best_caption else "Medical image showing pathological findings."

    def _clean_generated_caption(self, caption: str, prompt: str) -> str:
        """Clean and post-process generated captions"""
        if not caption:
            return ""

        # Remove the prompt from the beginning
        if prompt and caption.lower().startswith(prompt.lower()):
            caption = caption[len(prompt):].strip()

        # Remove common repetitive patterns
        caption = re.sub(r'\b(\w+)\s+\1\b', r'\1', caption)
        caption = re.sub(r'\b(\w+)\s+\1\s+\1\b', r'\1', caption)

        # Remove excessive punctuation
        caption = re.sub(r'[,]{2,}', ',', caption)
        caption = re.sub(r'[:]{2,}', ':', caption)

        # Clean up spaces and punctuation
        caption = re.sub(r'\s+', ' ', caption)
        caption = caption.strip(' ,:.-')

        # Ensure proper capitalization
        if caption:
            caption = caption[0].upper() + caption[1:] if len(caption) > 1 else caption.upper()

        return caption

    def _score_caption(self, caption: str) -> float:
        """Score caption quality (higher is better)"""
        if not caption or len(caption.strip()) < 5:
            return -1

        words = caption.split()
        if len(words) < 3:
            return -1

        # Penalize repetitive captions
        unique_words = set(words)
        repetition_ratio = len(unique_words) / len(words)

        # Penalize very short or very long captions
        length_score = min(len(words) / 20, 1.0)

        # Bonus for medical terms
        medical_terms = ['radiograph', 'ct', 'mri', 'scan', 'image', 'showing', 'demonstrates', 'findings']
        medical_bonus = sum(1 for term in medical_terms if term in caption.lower()) * 0.1

        # Overall score
        score = repetition_ratio * length_score + medical_bonus
        return score

    def _score_caption_with_concepts(self, caption: str, concepts: List[str]) -> float:
        """Score caption quality with concept awareness"""
        if not caption or len(caption.strip()) < 5:
            return -1

        words = caption.split()
        if len(words) < 3:
            return -1

        # Basic quality score
        unique_words = set(words)
        repetition_ratio = len(unique_words) / len(words)
        length_score = min(len(words) / 15, 1.0)

        # Concept coverage bonus
        concept_score = 0
        if concepts:
            mentioned_concepts = 0
            for concept in concepts:
                if concept.lower() in caption.lower():
                    mentioned_concepts += 1
            concept_score = mentioned_concepts / len(concepts)

        # Medical terminology bonus
        medical_terms = ['radiograph', 'ct', 'mri', 'scan', 'image', 'showing', 'demonstrates',
                        'findings', 'medical', 'anatomy', 'pathology', 'lesion', 'abnormal']
        medical_bonus = sum(1 for term in medical_terms if term in caption.lower()) * 0.05

        # Weighted final score
        score = (repetition_ratio * 0.4 + length_score * 0.3 + concept_score * 0.2 + medical_bonus * 0.1)
        return score

    def _enhance_caption_with_concepts(self, caption: str, concepts: List[str]) -> str:
        """Post-process caption to better incorporate concepts"""
        if not concepts or not caption:
            return caption

        # Check which concepts are missing
        missing_concepts = []
        for concept in concepts[:2]:  # Only consider top 2 concepts
            if concept.lower() not in caption.lower():
                # Simplify concept names for better integration
                simplified = self._simplify_concept_name(concept)
                if simplified and simplified not in caption.lower():
                    missing_concepts.append(simplified)

        # If important concepts are missing, try to add them naturally
        if missing_concepts and len(caption.split()) < 20:  # Only if caption isn't too long
            # Add missing concepts at the end
            additional_info = ', '.join(missing_concepts[:1])  # Add max 1 missing concept
            if not caption.endswith('.'):
                caption += f" with {additional_info}"
            else:
                caption = caption[:-1] + f" with {additional_info}."

        return caption

    def _simplify_concept_name(self, concept: str) -> str:
        """Simplify medical concept names for natural integration"""
        if not concept:
            return ""

        # Common simplifications
        simplifications = {
            'structure of': '',
            ', unspecified': '',
            'radiograph': 'radiographic findings',
            'ct': 'CT findings',
            'mri': 'MRI findings',
            'magnetic resonance imaging': 'MRI',
            'computed tomography': 'CT'
        }

        simplified = concept.lower()
        for old, new in simplifications.items():
            simplified = simplified.replace(old, new)

        simplified = simplified.strip(' ,-.')

        # Only return if it's a meaningful addition
        if len(simplified) > 3 and simplified not in ['action', 'finding', 'image']:
            return simplified

        return ""

    def _calculate_concept_coverage(self, caption: str, concepts: List[str]) -> float:
        """Calculate what percentage of concepts are mentioned in caption"""
        if not concepts or not caption:
            return 0.0

        caption_lower = caption.lower()
        mentioned = 0

        for concept in concepts:
            # Check if concept or parts of it are mentioned
            concept_words = concept.lower().split()
            if any(word in caption_lower for word in concept_words if len(word) > 3):
                mentioned += 1

        return mentioned / len(concepts)

    def _assess_generation_quality(self, generated: str, ground_truth: str) -> str:
        """Quick assessment of generation quality"""
        if not generated or len(generated.strip()) < 5:
            return " Too short/empty"

        words = generated.split()
        if len(set(words)) < len(words) * 0.7:  # More than 30% repetition
            return "Too repetitive"

        # Check for medical relevance
        medical_terms = ['radiograph', 'ct', 'mri', 'scan', 'image', 'showing', 'demonstrates',
                        'findings', 'medical', 'clinical', 'anatomy', 'pathology']

        has_medical_terms = any(term in generated.lower() for term in medical_terms)

        if len(words) >= 5 and has_medical_terms:
            return " Good quality"
        elif len(words) >= 3:
            return " Okay but could be better"
        else:
            return " Poor quality"

    def train_simple(self, train_dataset, val_dataset,
                    epochs: int = 5, batch_size: int = 2,
                    learning_rate: float = 2e-5,
                    save_dir: str = "./fine_tuned_medical_blip"):
        """Simplified training approach - use this if main train() has issues"""

        # Create data loaders with smaller batch size
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size,
            shuffle=True, collate_fn=collate_fn, num_workers=0
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size,
            shuffle=False, collate_fn=collate_fn, num_workers=0
        )

        # Setup optimizer
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        self.model.train()
        best_val_loss = float('inf')

        for epoch in range(epochs):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch + 1}/{epochs}")
            print(f"{'='*60}")

            # Training phase
            epoch_loss = 0
            num_batches = 0

            for batch_idx, batch in enumerate(train_loader):
                optimizer.zero_grad()
                total_loss = 0

                # Process each sample individually to avoid batch issues
                for i in range(len(batch['images'])):
                    try:
                        # Use concept-rich prompts when available for better training
                        use_concept_prompt = hasattr(batch, 'concept_prompts') and batch.get('concept_prompts')

                        if use_concept_prompt and i < len(batch['concept_prompts']):
                            training_prompt = batch['concept_prompts'][i]
                        else:
                            training_prompt = batch['prompts'][i] if i < len(batch['prompts']) else "medical image:"

                        # Combine prompt and caption for training
                        full_text = f"{training_prompt} {batch['captions'][i]}"

                        inputs = self.processor(
                            images=batch['images'][i],
                            text=full_text,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=120
                        )
                        inputs = {k: v.to(self.device) for k, v in inputs.items()}

                        # Use the model's built-in loss calculation
                        outputs = self.model(**inputs, labels=inputs['input_ids'])
                        loss = outputs.loss
                        total_loss += loss

                    except Exception as e:
                        print(f"Skipping sample {i} due to error: {e}")
                        continue

                if total_loss > 0:
                    # Average loss over valid samples
                    avg_loss = total_loss / len(batch['images'])
                    avg_loss.backward()

                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                    optimizer.step()

                    epoch_loss += avg_loss.item()
                    num_batches += 1

                # Progress reporting
                if batch_idx % 10 == 0:
                    current_loss = epoch_loss / max(num_batches, 1)
                    print(f"  Batch {batch_idx}/{len(train_loader)} - Loss: {current_loss:.4f}")

            # Validation
            val_loss = self._validate_simple(val_loader)

            train_loss = epoch_loss / max(num_batches, 1)
            print(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print(f"New best validation loss: {val_loss:.4f}")
                self._save_model(save_dir)

        print(f"\nTraining completed! Best validation loss: {best_val_loss:.4f}")
        return best_val_loss

    def _validate_simple(self, val_loader):
        self.model.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                batch_loss = 0
                valid_samples = 0

                for i in range(len(batch['images'])):
                    try:

                        use_concept_prompt = hasattr(batch, 'concept_prompts') and batch.get('concept_prompts')

                        if use_concept_prompt and i < len(batch['concept_prompts']):
                            training_prompt = batch['concept_prompts'][i]
                        else:
                            training_prompt = batch['prompts'][i] if i < len(batch['prompts']) else "medical image:"

                        full_text = f"{training_prompt} {batch['captions'][i]}"

                        inputs = self.processor(
                            images=batch['images'][i],
                            text=full_text,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=120
                        )
                        inputs = {k: v.to(self.device) for k, v in inputs.items()}

                        outputs = self.model(**inputs, labels=inputs['input_ids'])
                        loss = outputs.loss

                        batch_loss += loss.item()
                        valid_samples += 1

                    except Exception as e:
                        continue

                if valid_samples > 0:
                    avg_loss = batch_loss / valid_samples
                    total_loss += avg_loss
                    num_batches += 1

        return total_loss / max(num_batches, 1)

    def train(self, train_dataset, val_dataset,
              epochs: int = 5, batch_size: int = 4,
              learning_rate: float = 2e-5, warmup_ratio: float = 0.1,
              save_dir: str = "./fine_tuned_medical_blip"):
        # Create data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size,
            shuffle=True, collate_fn=collate_fn, num_workers=2
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size,
            shuffle=False, collate_fn=collate_fn, num_workers=2
        )

        # Setup optimizer and scheduler
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01,
            eps=1e-8
        )

        total_steps = len(train_loader) * epochs
        warmup_steps = int(total_steps * warmup_ratio)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        # Training loop
        self.model.train()
        best_val_loss = float('inf')

        for epoch in range(epochs):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch + 1}/{epochs}")
            print(f"{'='*60}")

            # Training phase
            train_loss = self._train_epoch(train_loader, optimizer, scheduler)

            # Validation phase
            val_loss = self._validate_epoch(val_loader)

            print(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print(f"New best validation loss: {val_loss:.4f}")
                self._save_model(save_dir)

    def _train_epoch(self, train_loader, optimizer, scheduler):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            batch_loss = 0
            valid_samples = 0

            for i in range(len(batch['images'])):
                try:
                    # FIXED: Process everything together for BLIP
                    full_text = f"{batch['prompts'][i]} {batch['captions'][i]}"

                    # Process image and full text together
                    inputs = self.processor(
                        images=batch['images'][i],
                        text=full_text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=150
                    )
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}

                    # Create labels: shift input_ids to create target sequence
                    input_ids = inputs['input_ids']

                    # For BLIP, we need to create proper labels
                    prompt_tokens = self.processor(
                        text=batch['prompts'][i],
                        return_tensors="pt",
                        add_special_tokens=False
                    )['input_ids']

                    # Create labels by masking the prompt part (-100 ignores these tokens)
                    labels = input_ids.clone()
                    prompt_length = prompt_tokens.shape[1]
                    labels[:, :prompt_length] = -100  # Ignore prompt tokens in loss

                    # Forward pass with proper labels
                    outputs = self.model(
                        input_ids=input_ids,
                        pixel_values=inputs['pixel_values'],
                        labels=labels
                    )
                    loss = outputs.loss

                    batch_loss += loss
                    valid_samples += 1

                except Exception as e:
                    print(f"Error processing sample {i}: {e}")
                    continue

            if valid_samples > 0:
                # Average loss and backward pass
                batch_loss = batch_loss / valid_samples
                batch_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                optimizer.step()
                scheduler.step()

                total_loss += batch_loss.item()
                num_batches += 1

            # Progress reporting
            if batch_idx % 20 == 0 and batch_idx > 0:
                avg_loss = total_loss / num_batches
                lr = scheduler.get_last_lr()[0]
                print(f"  Batch {batch_idx}/{len(train_loader)} - Loss: {avg_loss:.4f}, LR: {lr:.2e}")

        return total_loss / max(num_batches, 1)

    def _validate_epoch(self, val_loader):
        """Validate for one epoch"""
        self.model.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                batch_loss = 0
                valid_samples = 0

                for i in range(len(batch['images'])):
                    try:
                        inputs = self.processor(
                            images=batch['images'][i],
                            text=batch['prompts'][i],
                            return_tensors="pt",
                            padding=True,
                            truncation=True
                        )
                        inputs = {k: v.to(self.device) for k, v in inputs.items()}

                        labels = self.processor(
                            text=batch['captions'][i],
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=128
                        )
                        labels = labels['input_ids'].to(self.device)

                        outputs = self.model(**inputs, labels=labels)
                        loss = outputs.loss

                        batch_loss += loss
                        valid_samples += 1

                    except Exception as e:
                        continue

                if valid_samples > 0:
                    batch_loss = batch_loss / valid_samples
                    total_loss += batch_loss.item()
                    num_batches += 1

        return total_loss / max(num_batches, 1)

    def _save_model(self, save_dir: str):
        """Save model and processor"""
        os.makedirs(save_dir, exist_ok=True)
        self.model.save_pretrained(save_dir)
        self.processor.save_pretrained(save_dir)
        print(f"Model saved to {save_dir}")

    def quick_test(self, test_dataset, num_samples: int = 3):
        """Quick test to verify generation is working properly with concepts"""
        print(f"\n{'='*60}")
        print(f"QUICK TEST - CONCEPT-AWARE GENERATION")
        print(f"{'='*60}")

        for i in range(min(num_samples, len(test_dataset))):
            sample = test_dataset[i]
            image = sample['image']
            concepts = sample['concepts']
            ground_truth = sample['caption']

            print(f"\nExample {i+1}")
            print(f"Concepts: {concepts[:3]}{'...' if len(concepts) > 3 else ''}")
            print(f"Ground Truth: {ground_truth[:80]}{'...' if len(ground_truth) > 80 else ''}")
            print()

            # Test 1: No prompt + concepts
            try:
                caption1 = self.generate_caption(
                    image, "", concepts=concepts,
                    max_new_tokens=35,
                    num_beams=3,
                    repetition_penalty=2.0,
                    no_repeat_ngram_size=2
                )
                print(f" Concept-guided (no prompt): '{caption1}'")
                coverage1 = self._calculate_concept_coverage(caption1, concepts)
                print(f"   Concept coverage: {coverage1:.1%}")
            except Exception as e:
                print(f" Concept-guided failed: {e}")

            # Test 2: Smart prompt + concepts
            try:
                basic_prompt = sample.get('prompt', 'medical image:')
                caption2 = self.generate_caption(
                    image, basic_prompt, concepts=concepts,
                    max_new_tokens=35,
                    num_beams=3,
                    repetition_penalty=2.0,
                    no_repeat_ngram_size=2
                )
                print(f" Smart prompt + concepts: '{caption2}'")
                coverage2 = self._calculate_concept_coverage(caption2, concepts)
                print(f"   Concept coverage: {coverage2:.1%}")
            except Exception as e:
                print(f" Smart prompt failed: {e}")

            # Test 3: Concept-rich prompt
            try:
                concept_prompt = sample.get('concept_prompt', 'medical image:')
                caption3 = self.generate_caption(
                    image, concept_prompt,
                    max_new_tokens=35,
                    num_beams=3,
                    repetition_penalty=2.0,
                    no_repeat_ngram_size=2
                )
                print(f"✨ Concept-rich prompt: '{caption3}'")
                coverage3 = self._calculate_concept_coverage(caption3, concepts)
                print(f"   Concept coverage: {coverage3:.1%}")
            except Exception as e:
                print(f"Concept-rich prompt failed: {e}")

            print("-" * 60)

    def evaluate_model(self, test_dataset, num_samples: int = None):
        """Comprehensive evaluation with concept-aware generation"""
        if num_samples is None:
            num_samples = min(10, len(test_dataset))

        print(f"\n{'='*60}")
        print(f"EVALUATING MODEL ON {num_samples} SAMPLES")
        print(f"{'='*60}")

        # Test different concept-aware strategies
        strategies = [
            {
                'name': 'Concept-Guided Generation (Best)',
                'use_concepts': True,
                'use_prompt': False,
                'params': {
                    'max_new_tokens': 40,
                    'num_beams': 4,
                    'repetition_penalty': 2.0,
                    'no_repeat_ngram_size': 2,
                    'length_penalty': 0.9
                }
            },
            {
                'name': 'Smart Prompt + Concepts',
                'use_concepts': True,
                'use_prompt': True,
                'params': {
                    'max_new_tokens': 40,
                    'num_beams': 4,
                    'repetition_penalty': 1.8,
                    'no_repeat_ngram_size': 2,
                    'early_stopping': True
                }
            },
            {
                'name': 'Pure Image Captioning (Baseline)',
                'use_concepts': False,
                'use_prompt': False,
                'params': {
                    'max_new_tokens': 40,
                    'num_beams': 4,
                    'repetition_penalty': 2.0,
                    'no_repeat_ngram_size': 2,
                    'length_penalty': 0.8
                }
            }
        ]

        for strategy in strategies:
            print(f"\n--- {strategy['name']} ---")

            total_coverage = 0
            valid_samples = 0

            for i in range(min(3, num_samples)):  # Show 3 examples per strategy
                sample = test_dataset[i]
                image = sample['image']
                concepts = sample['concepts']
                ground_truth = sample['caption']

                # Generate caption based on strategy
                if strategy['use_concepts'] and strategy['use_prompt']:
                    prompt = sample.get('prompt', 'medical image:')
                    generated = self.generate_caption(image, prompt, concepts=concepts, **strategy['params'])
                elif strategy['use_concepts']:
                    generated = self.generate_caption(image, "", concepts=concepts, **strategy['params'])
                else:
                    generated = self.generate_caption(image, "", **strategy['params'])

                print(f"\nExample {i+1}:")
                print(f"Concepts: {concepts[:3]}{'...' if len(concepts) > 3 else ''}")
                print(f"Ground Truth: {ground_truth[:80]}{'...' if len(ground_truth) > 80 else ''}")
                print(f"Generated: {generated}")

                # Calculate concept coverage
                coverage = self._calculate_concept_coverage(generated, concepts)
                total_coverage += coverage
                valid_samples += 1

                # Quality assessment
                quality = self._assess_generation_quality(generated, ground_truth)
                print(f"Quality: {quality}")
                print(f"Concept Coverage: {coverage:.1%}")
                print("-" * 40)

            # Strategy summary
            avg_coverage = total_coverage / max(valid_samples, 1)



def main():
    """Main execution function"""

    # Configuration
    config = {
        'image_dir': "/content/drive/MyDrive/medical/development-3/development/train/train",  # UPDATE THIS PATH
        'caption_csv': "/content/drive/MyDrive/medical/development-3/development/train/train_captions.csv",  # UPDATE THIS PATH
        'concept_csv': "/content/drive/MyDrive/medical/development-3/development/train/train_concepts.csv",  # UPDATE THIS PATH
        'mapping_csv': "/content/drive/MyDrive/medical/cui_names.csv",  # UPDATE THIS PATH
        'model_name': "Salesforce/blip-image-captioning-base",
        'epochs': 5,
        'batch_size': 4,
        'learning_rate': 2e-5,
        'save_dir': "/content/drive/MyDrive/medical/best_medical_blip"
    }


    # Load dataset
    print("Loading dataset...")
    full_dataset = MedicalCaptionDataset(
        image_dir=config['image_dir'],
        caption_csv=config['caption_csv'],
        concept_csv=config['concept_csv'],
        mapping_csv=config['mapping_csv'],
        processor=None  # Will be set by trainer
    )

    # Initialize trainer
    trainer = MedicalCaptionTrainer(model_name=config['model_name'])

    # Set processor in dataset
    full_dataset.processor = trainer.processor

        # Take random 10,000 subset
    subset_size = min(10_000, len(full_dataset))
    indices = torch.randperm(len(full_dataset))[:subset_size].tolist()
    full_dataset = torch.utils.data.Subset(full_dataset, indices)

    # Proceed with split
    train_size = int(0.8 * len(full_dataset))
    val_size = int(0.1 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size

    train_dataset, temp_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size + test_size])
    val_dataset, test_dataset = torch.utils.data.random_split(temp_dataset, [val_size, test_size])
    print(f"Dataset split: {len(train_dataset)} train, {len(val_dataset)} val, {len(test_dataset)} test")

    # Quick test before training
    print("\n BEFORE TRAINING - QUICK TEST:")
    trainer.quick_test(test_dataset, num_samples=3)

    # Train model (using simplified approach to avoid batch size issues)
    print("\n STARTING TRAINING (Simplified Method):")
    trainer.train_simple(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        epochs=config['epochs'],
        batch_size=2,  # Smaller batch size to avoid memory issues
        learning_rate=config['learning_rate'],
        save_dir=config['save_dir']
    )

    # Evaluate after training
    print("\nAFTER TRAINING:")
    trainer.evaluate_model(test_dataset, num_samples=5)

    print(f"\n Training complete! Model saved to {config['save_dir']}")
    print("To use the trained model later:")
    print(f"  model = BlipForConditionalGeneration.from_pretrained('{config['save_dir']}')")
    print(f"  processor = BlipProcessor.from_pretrained('{config['save_dir']}')")

if __name__ == "__main__":
    main()

In [None]:
# !pip install -U transformers==4.36.2 peft==0.7.1
# !pip install sacrebleu rouge-score bert-score datasets evaluate

Collecting transformers==4.36.2
  Downloading transformers-4.36.2-py3-none-any.whl.metadata (126 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/126.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m122.9/126.8 kB[0m [31m4.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting peft==0.7.1
  Downloading peft-0.7.1-py3-none-any.whl.metadata (25 kB)
Collecting tokenizers<0.19,>=0.14 (from transformers==4.36.2)
  Downloading tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.36.2-py3-none-any.whl (8.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m80.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading peft-0.7.1-py3-none-any.whl (168 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

# Enhanced Med-BLIP-2: Concept-Aware Medical Captioning

This script fine-tunes `Salesforce/blip2-opt-2.7b` using medical images and concept-guided prompts via LoRA adapters.

## Dataset Inputs
- Images: JPEGs in `/train/{ID}.jpg`
- Captions: `train_captions.csv` (`ID`, `Caption`)
- Concepts: `train_concepts.csv` (`ID`, `CUIs`)
- Mapping: `cui_names.csv` (`CUI`, `Name` → natural text)

## Model
- Backbone: `Salesforce/blip2-opt-2.7b`
- Adapter: `NouRed/Med-BLIP-2-QLoRA` (PEFT LoRA)
- Tokenizer: `AutoProcessor` with `use_fast=False`
- Trainable layers: `qformer`, `language_projection`

## Key Features
- Prompts include mapped medical concepts (UMLS → natural terms)
- Dynamic prompting templates for medical context
- Concept filtering and length-based data cleaning
- Mixed-precision training with gradient accumulation
- Repetition control in generation (`no_repeat_ngram_size=3`, `repetition_penalty=1.3`)

## Training
- Loss: CrossEntropy with label smoothing (`ignore_index=-100`)
- Optimizer: AdamW (`lr=2e-5`, `weight_decay=0.01`)
- Scheduler: Cosine decay with warmup
- Epochs: 5 (configurable)
- Gradient clipping: max norm = 1.0
- Data split: 90% train / 10% val (random)

## Generation
- Beam search with 5 beams
- Concept-driven prompt per image
- Cleaned output (prompt stripped, repetition reduced)

##Results
- AVG_PRED_LENGTH     : 24.1200
- AVG_REF_LENGTH      : 22.3200
- BLEU_SCORE          : 1.5318
- ROUGE_L             : 0.0949
- BERTSCORE_F1        : 0.6185
## Output
- Fine-tuned model saved at: `cfg["save_dir"]`
- Can be reloaded with `from_pretrained()`

> Tip: update paths in `cfg = {...}` before training.


In [None]:
# Enhanced Med-BLIP-2 Training Code
# Includes full dataset and trainer classes with concept integration, decoding, and tuning improvements

import os, gc, json, logging, warnings, re, random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pandas as pd
from tqdm import tqdm
from peft import PeftModel
from transformers import (
    AutoProcessor, Blip2ForConditionalGeneration,
    get_cosine_schedule_with_warmup, set_seed
)
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_

warnings.filterwarnings('ignore')
set_seed(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EnhancedMedicalConceptDataset(Dataset):
    def __init__(self, image_dir, caption_csv, concept_csv, mapping_csv, max_samples=5000, img_size=224):
        caps = pd.read_csv(caption_csv).rename(columns={'ID':'id','Caption':'caption'})
        conc = pd.read_csv(concept_csv).rename(columns={'ID':'id','CUIs':'concepts'})
        mapping = pd.read_csv(mapping_csv).rename(columns={'CUI':'cui','Name':'name'})
        self.concept_map = dict(zip(mapping.cui, mapping.name))

        df = caps.merge(conc, on='id').dropna()
        df['concept_text'] = df['concepts'].apply(self._map_concepts)
        df['image_path'] = df['id'].apply(lambda x: os.path.join(image_dir, f"{x}.jpg"))
        df = df[df['image_path'].apply(os.path.exists)].copy()
        df = self._filter_quality_samples(df)

        if max_samples and len(df) > max_samples:
            df = df.sample(n=max_samples, random_state=42).copy()

        self.df = df.reset_index(drop=True)
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])

    def _filter_quality_samples(self, df):
        df = df[df['caption'].str.len().between(20, 500)]
        df = df[df['concept_text'].str.len() > 5]
        df = df[~df['caption'].str.contains(r'^[A-Za-z0-9_\-\.]+\.(jpg|png|jpeg)$', regex=True)]
        return df

    def _map_concepts(self, cui_str):
        cuis = re.split(r'[;|,\s]+', str(cui_str))
        names = [self.concept_map.get(c.strip(), c.strip()) for c in cuis if c.strip()]
        return ', '.join(sorted(set(filter(lambda x: len(x) > 2, names))))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = Image.open(row.image_path).convert("RGB")
        return {
            "image": img,
            "caption": row.caption,
            "concept_text": row.concept_text,
            "image_id": row.id
        }

    @staticmethod
    def collate_fn(batch):
        imgs = [transforms.Resize((224,224))(b["image"]) for b in batch]
        prompts = [f"Radiological findings: {b['concept_text']}. What does this image show?" for b in batch]
        return {
            "images": imgs,
            "prompts": prompts,
            "captions": [b["caption"] for b in batch],
            "concepts": [b["concept_text"] for b in batch],
            "image_ids": [b["image_id"] for b in batch]
        }

class ImprovedMedBLIPTrainer:
    def __init__(self, base_model="Salesforce/blip2-opt-2.7b", adapter="NouRed/Med-BLIP-2-QLoRA"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.grad_accum_steps = 16
        self._clear_mem()
        self._load_model(base_model, adapter)
        self.scaler = GradScaler()

        self.prompt_templates = [
            "Radiological findings: {concept_text}. What does this image show?",
            "Based on {concept_text}, describe this medical image.",
            "Imaging context: {concept_text}. Explain the findings.",
            "{concept_text}. Describe the anatomical or pathological features."
        ]

    def _clear_mem(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        gc.collect()

    def _load_model(self, base_model, adapter):
        self.processor = AutoProcessor.from_pretrained(base_model, use_fast=False)
        base = Blip2ForConditionalGeneration.from_pretrained(base_model, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)

        try:
            self.model = PeftModel.from_pretrained(base, adapter, torch_dtype=torch.float16, trust_remote_code=True)
        except:
            self.model = base

        self.model.train()
        for name, param in self.model.named_parameters():
            if 'qformer' in name or 'language_projection' in name:
                param.requires_grad = True
                param.data = param.data.float()
            else:
                param.requires_grad = False

    def augment_prompt(self, concept_text):
        if not concept_text.strip():
            return "Describe what is shown in this medical image:"
        return random.choice(self.prompt_templates).format(concept_text=concept_text)

    def compute_enhanced_loss(self, outputs, labels):
        loss_fct = torch.nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=-100)
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    def train(self, dataset, epochs=5, batch_size=1, lr=2e-5, save_dir="./medblip_final"):
        train_len = int(0.9 * len(dataset))
        val_len = len(dataset) - train_len
        train_ds, val_ds = random_split(dataset, [train_len, val_len])

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate_fn)
        val_loader = DataLoader(val_ds, batch_size=1, collate_fn=dataset.collate_fn)

        opt = torch.optim.AdamW([p for p in self.model.parameters() if p.requires_grad], lr=lr, weight_decay=0.01)
        total_steps = len(train_loader) * epochs // self.grad_accum_steps
        sched = get_cosine_schedule_with_warmup(opt, total_steps // 10, total_steps)

        for ep in range(epochs):
            logger.info(f"Epoch {ep+1}/{epochs}")
            self._run_epoch(train_loader, opt, sched, training=True)
            self._run_epoch(val_loader, opt, sched, training=False)
            self._save(save_dir)

        return train_ds

    def _run_epoch(self, loader, opt, sched, training=True):
        self.model.train() if training else self.model.eval()
        bar = tqdm(loader, desc="Train" if training else "Val")
        for i, batch in enumerate(bar):
            if training:
                prompts = [self.augment_prompt(c) for c in batch["concepts"]]
                batch["prompts"] = prompts
            inputs = self.processor(images=batch["images"], text=batch["prompts"], return_tensors="pt", padding=True, truncation=True).to(self.device)
            labels = inputs.input_ids.clone()
            with autocast():
                outputs = self.model(**inputs, labels=labels)
                loss = self.compute_enhanced_loss(outputs, labels)
                loss = loss / self.grad_accum_steps
            if training:
                self.scaler.scale(loss).backward()
                if (i+1) % self.grad_accum_steps == 0:
                    self.scaler.unscale_(opt)
                    clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(opt)
                    self.scaler.update()
                    opt.zero_grad()
                    sched.step()

    def generate_improved(self, image, concept_text="", max_tokens=60):
        prompt = self.augment_prompt(concept_text)
        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
        with torch.no_grad(), autocast():
            outputs = self.model.generate(
                **inputs, max_new_tokens=max_tokens,
                num_beams=5, pad_token_id=self.processor.tokenizer.eos_token_id,
                no_repeat_ngram_size=3, repetition_penalty=1.3, length_penalty=1.1,
                early_stopping=True
            )
        return self.processor.batch_decode(outputs, skip_special_tokens=True)[0].replace(prompt, "").strip()

    def _save(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        self.model.save_pretrained(save_dir)
        self.processor.save_pretrained(save_dir)

# Training config
cfg = {
    "image_dir": "/content/drive/MyDrive/medical/development-3/development/train/train",
    "caption_csv": "/content/drive/MyDrive/medical/development-3/development/train/train_captions.csv",
    "concept_csv": "/content/drive/MyDrive/medical/development-3/development/train/train_concepts.csv",
    "mapping_csv": "/content/drive/MyDrive/medical/cui_names.csv",
    "epochs": 5,
    "batch_size": 1,
    "lr": 2e-5,
    "max_samples": 2000,
    "save_dir": "/content/drive/MyDrive/medical/medblip_improved_final"
}

def main():
    trainer = ImprovedMedBLIPTrainer()
    dataset = EnhancedMedicalConceptDataset(
        cfg["image_dir"], cfg["caption_csv"], cfg["concept_csv"],
        cfg["mapping_csv"], max_samples=cfg["max_samples"]
    )
    trainer.train(dataset, cfg["epochs"], cfg["batch_size"], cfg["lr"], cfg["save_dir"])

if __name__ == "__main__":
    main()


###Evaluation Script

In [16]:
#!/usr/bin/env python3
"""
Standalone evaluation script for trained Medical BLIP-2 model
"""

import os, json, logging, warnings, re
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pandas as pd
from tqdm import tqdm
from peft import PeftModel
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from torch.cuda.amp import autocast

# Try to import metrics, handle gracefully if missing
try:
    from datasets import load_metric
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    print("Datasets library not available for metrics")

try:
    import evaluate
    EVALUATE_AVAILABLE = True
except ImportError:
    EVALUATE_AVAILABLE = False
    print(" Evaluate library not available")

warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MedBLIPEvaluator:
    def __init__(self, model_path, base_model="Salesforce/blip2-opt-2.7b"):
        """
        Load trained model for evaluation
        Args:
            model_path: Path to saved model directory
            base_model: Base model name (if needed for processor)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_path = model_path
        self._load_trained_model(model_path, base_model)

    def _load_trained_model(self, model_path, base_model):
        """Load the trained model from saved directory"""
        logger.info(f"🔍 Loading trained model from: {model_path}")

        try:
            # Try loading processor from saved model first
            self.processor = AutoProcessor.from_pretrained(
                model_path,
                trust_remote_code=True,
                use_fast=False
            )
            logger.info(" Loaded processor from saved model")
        except Exception as e:
            logger.warning(f"Could not load processor from saved model: {e}")
            logger.info(f" Loading processor from base model: {base_model}")
            self.processor = AutoProcessor.from_pretrained(
                base_model,
                trust_remote_code=True,
                use_fast=False
            )

        try:
            # Load the trained model
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(" Loaded trained model successfully")
        except Exception as e:
            logger.error(f" Failed to load trained model: {e}")
            logger.info(" Make sure the model was saved correctly and the path is correct")
            raise

        self.model.eval()

        # Clear memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, image, concept_text="", max_tokens=50):
        """Generate caption for an image with optional concept guidance"""
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        prompt = f"Findings: {concept_text}. What does this image show?" if concept_text else "What does this image show?"

        inputs = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad(), autocast():
            # Handle pad_token_id
            pad_token_id = getattr(self.processor.tokenizer, 'pad_token_id', None)
            if pad_token_id is None:
                pad_token_id = self.processor.tokenizer.eos_token_id

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                num_beams=3,
                pad_token_id=pad_token_id,
                do_sample=False,
                early_stopping=True,
                repetition_penalty=1.1
            )

        generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        # Remove the prompt from generated text
        return generated_text.replace(prompt, "").strip()

    def evaluate_dataset(self, dataset, num_samples=None, save_results=True):
        """Evaluate model on a dataset"""
        logger.info(f" Starting evaluation on {len(dataset)} samples")

        if num_samples is not None:
            num_samples = min(num_samples, len(dataset))
            logger.info(f" Limiting evaluation to {num_samples} samples")
        else:
            num_samples = len(dataset)

        references = []
        predictions = []
        detailed_results = []

        for i in tqdm(range(num_samples), desc="Evaluating"):
            try:
                sample = dataset[i]
                image = sample["image"]
                concept_text = sample.get("concept_text", "")
                ground_truth = sample["caption"]

                # Generate prediction
                prediction = self.generate(image, concept_text)

                references.append(ground_truth)
                predictions.append(prediction)

                detailed_results.append({
                    "sample_id": i,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "concept_text": concept_text
                })

                # Print a few examples
                if i < 5:
                    print(f"\n--- Sample {i+1} ---")
                    print(f"Concepts: {concept_text}")
                    print(f"Ground Truth: {ground_truth}")
                    print(f"Prediction: {prediction}")

            except Exception as e:
                logger.warning(f"Error processing sample {i}: {e}")
                continue

        # Calculate metrics
        metrics = self._calculate_metrics(references, predictions)

        # Save results if requested
        if save_results:
            self._save_evaluation_results(detailed_results, metrics)

        return metrics, detailed_results

    def _calculate_metrics(self, references, predictions):
        """Calculate evaluation metrics"""
        metrics = {}

        # Basic metrics
        metrics['num_samples'] = len(predictions)
        metrics['avg_pred_length'] = sum(len(p.split()) for p in predictions) / len(predictions)
        metrics['avg_ref_length'] = sum(len(r.split()) for r in references) / len(references)

        # Try to calculate BLEU, ROUGE, and BERTScore
        if METRICS_AVAILABLE:
            try:
                # BLEU Score
                bleu = load_metric("sacrebleu")
                bleu_result = bleu.compute(
                    predictions=predictions,
                    references=[[ref] for ref in references]
                )
                metrics['bleu_score'] = bleu_result['score']
                logger.info(f" BLEU Score: {bleu_result['score']:.2f}")
            except Exception as e:
                logger.warning(f"Could not calculate BLEU: {e}")

            try:
                # ROUGE Score
                rouge = load_metric("rouge")
                rouge_result = rouge.compute(
                    predictions=predictions,
                    references=references
                )
                metrics['rouge_l'] = rouge_result['rougeL'].mid.fmeasure
                logger.info(f" ROUGE-L: {rouge_result['rougeL'].mid.fmeasure:.4f}")
            except Exception as e:
                logger.warning(f"Could not calculate ROUGE: {e}")

            try:
                # BERTScore
                bertscore = load_metric("bertscore")
                bert_result = bertscore.compute(
                    predictions=predictions,
                    references=references,
                    lang="en"
                )
                metrics['bertscore_f1'] = sum(bert_result['f1']) / len(bert_result['f1'])
                logger.info(f"BERTScore F1: {metrics['bertscore_f1']:.4f}")
            except Exception as e:
                logger.warning(f"Could not calculate BERTScore: {e}")

        elif EVALUATE_AVAILABLE:
            try:
                # Use evaluate library as fallback
                bleu = evaluate.load("sacrebleu")
                bleu_result = bleu.compute(
                    predictions=predictions,
                    references=[[ref] for ref in references]
                )
                metrics['bleu_score'] = bleu_result['score']
                logger.info(f"✅ BLEU Score: {bleu_result['score']:.2f}")
            except Exception as e:
                logger.warning(f"Could not calculate BLEU with evaluate: {e}")

        else:
            logger.warning("⚠️ No metrics libraries available. Install with:")
            logger.warning("pip install datasets evaluate sacrebleu rouge-score bert-score")

        return metrics

    def _save_evaluation_results(self, detailed_results, metrics):
        """Save evaluation results to files"""
        # Create results directory
        results_dir = os.path.join(self.model_path, "evaluation_results")
        os.makedirs(results_dir, exist_ok=True)

        # Save detailed results
        detailed_path = os.path.join(results_dir, "detailed_results.json")
        with open(detailed_path, 'w') as f:
            json.dump(detailed_results, f, indent=2)

        # Save metrics summary
        metrics_path = os.path.join(results_dir, "metrics_summary.json")
        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=2)

        logger.info(f"💾 Saved results to {results_dir}")

        # Print summary
        print(f"\n{'='*50}")
        print("📊 EVALUATION SUMMARY")
        print(f"{'='*50}")
        for key, value in metrics.items():
            if isinstance(value, float):
                print(f"{key.upper():20}: {value:.4f}")
            else:
                print(f"{key.upper():20}: {value}")
        print(f"{'='*50}")

class MedicalConceptDataset(Dataset):
    """Same dataset class as in training script"""
    def __init__(self, image_dir, caption_csv, concept_csv, mapping_csv, max_samples=None):
        caps = pd.read_csv(caption_csv).rename(columns={'ID':'id','Caption':'caption'})
        conc = pd.read_csv(concept_csv).rename(columns={'ID':'id','CUIs':'concepts'})
        mapping = pd.read_csv(mapping_csv).rename(columns={'CUI':'cui','Name':'name'})
        self.concept_map = dict(zip(mapping.cui, mapping.name))

        df = caps.merge(conc, on='id').dropna()
        df['concept_text'] = df['concepts'].apply(self._map_concepts)
        df['image_path'] = df['id'].apply(lambda x: os.path.join(image_dir, f"{x}.jpg"))
        df = df[df['image_path'].apply(os.path.exists)].copy()

        if max_samples and len(df) > max_samples:
            df = df.sample(n=max_samples, random_state=42).copy()

        self.df = df.reset_index(drop=True)
        print(f"📊 Dataset size: {len(self.df)} samples")

    def _map_concepts(self, cui_str):
        cuis = re.split(r'[;|,\s]+', str(cui_str))
        names = [self.concept_map.get(cui.strip(), cui.strip()) for cui in cuis if cui.strip()]
        return ', '.join(sorted(set(names)))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = Image.open(row.image_path).convert("RGB")
        return {
            "image": img,
            "caption": row.caption,
            "concept_text": row.concept_text
        }

def main():
    """Main evaluation function"""

    # Configuration - UPDATE THESE PATHS
    config = {
        "model_path": "/content/drive/MyDrive/medical/best_medical_blip_fixed",  # Path to your saved model
        "image_dir": "/content/drive/MyDrive/medical/development-3/development/train/train",
        "caption_csv": "/content/drive/MyDrive/medical/development-3/development/train/train_captions.csv",
        "concept_csv": "/content/drive/MyDrive/medical/development-3/development/train/train_concepts.csv",
        "mapping_csv": "/content/drive/MyDrive/medical/cui_names.csv",
        "num_eval_samples": 100,  # Set to None to evaluate all samples
        "base_model": "Salesforce/blip2-opt-2.7b"
    }

    try:
        # Initialize evaluator
        evaluator = MedBLIPEvaluator(
            model_path=config["model_path"],
            base_model=config["base_model"]
        )

        # Load dataset (you can also create a separate test set)
        dataset = MedicalConceptDataset(
            config["image_dir"],
            config["caption_csv"],
            config["concept_csv"],
            config["mapping_csv"],
            max_samples=config["num_eval_samples"]
        )

        # Run evaluation
        metrics, detailed_results = evaluator.evaluate_dataset(
            dataset,
            num_samples=config["num_eval_samples"],
            save_results=True
        )

        print(" Evaluation completed!")

        # Optional: Test single image
        if len(dataset) > 0:
            print(f"\n Testing single image generation...")
            sample = dataset[0]
            result = evaluator.generate(sample["image"], sample["concept_text"])
            print(f"Concept: {sample['concept_text']}")
            print(f"Generated: {result}")

    except Exception as e:
        logger.error(f" Evaluation failed: {e}")
        raise

if __name__ == "__main__":
    main()

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

📊 Dataset size: 100 samples


Evaluating:   1%|          | 1/100 [00:02<04:17,  2.60s/it]


--- Sample 1 ---
Concepts: CT, Internal Nare
Ground Truth: CT scan of the patient shown obstruction on both side of choanae.
Prediction: This image shows the right side of the patient's head. The patient is lying on his back with his head tilted to the left. The patient's head is tilted to the left. The patient's head is tilted to the left. The patient's


Evaluating:   2%|▏         | 2/100 [00:05<04:14,  2.60s/it]


--- Sample 2 ---
Concepts: CT, Maxilla bone structure
Ground Truth: Coronal CT images showing the left anterior superior alveolar nerve (ASAN) branching inferiorly from the infraorbital nerve. The ASAN pathway is disrupted distally within the anterior maxillary wall due to bony sclerosis.
Prediction: This image shows the maxilla bone structure. The maxilla is a bone in the front of the skull that connects the jaw to the rest of the skull. The maxilla is a bone in the front of the skull that connects the jaw to the


Evaluating:   3%|▎         | 3/100 [00:05<02:49,  1.75s/it]


--- Sample 3 ---
Concepts: CT, aortic arch, chest and upper back, enlarged
Ground Truth: CT scan of the thorax showing the enlarged mediastinal mass in front of the aortic arch.
Prediction: 


Evaluating:   4%|▍         | 4/100 [00:06<02:04,  1.30s/it]


--- Sample 4 ---
Concepts: CT, Posterior part of pelvis, Structure of pelvic region, unspecified
Ground Truth: Mean values shown for pelvic incidence (PI) and sacral table angle (STA) in patients with L5 spondylolysis with measures demonstrated on computed tomography from a patient with spondylolysis (PI solid line; STA dashed line).
Prediction: What does it tell us about the patient?


Evaluating:   5%|▌         | 5/100 [00:09<02:45,  1.74s/it]


--- Sample 5 ---
Concepts: CT
Ground Truth: Image showing the progression of the apparently benign left adrenal incidentaloma (arrow in Figure 1) to adrenocortical carcinoma (arrow). CT scan images nine years later showing the change in the characteristics of the lesion, with a rapid interval growth measuring 57.3 mm (maximum diameter).CT: computed tomography
Prediction: The image shows a patient with a large mass in the right lower quadrant of the chest. The mass is located in the right lower quadrant of the chest. The mass is located in the right lower quadrant of the chest.


Evaluating:  84%|████████▍ | 84/100 [02:50<00:32,  2.02s/it]


KeyboardInterrupt: 