## 📋 PART 1: SETUP & ENVIRONMENT

### Step 1: Check GPU and System Info

In [None]:
import torch
import os

# Environment report: PyTorch build, visible CUDA devices, working directory
# and the free space on /kaggle/working.
print("üñ•Ô∏è KAGGLE ENVIRONMENT INFO")
print("=" * 60)

has_cuda = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {has_cuda}")

if not has_cuda:
    # Without an accelerator the user is pointed at the notebook settings.
    print("‚ö†Ô∏è No GPU detected!")
    print("   Go to: Settings ‚Üí Accelerator ‚Üí GPU T4 x2")
else:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs: {gpu_count}")
    for gpu_index in range(gpu_count):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
        # total_memory is reported in bytes; shown here in decimal gigabytes
        total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
        print(f"  Memory: {total_gb:.2f} GB")
    print("‚úÖ GPU is ready!")

workdir = os.getcwd()
print("\nüìÅ Working Directory:", workdir)
print("üìÅ Available Space:")
os.system('df -h /kaggle/working')  # the table itself is printed by the df child process

print("\n‚úÖ System check complete!")

In [None]:
import torch
import os

# NOTE: this cell repeats the system check of the cell directly before it.
print("üñ•Ô∏è KAGGLE ENVIRONMENT INFO")
print("=" * 60)

# PyTorch / CUDA summary lines
for report_line in (
    f"PyTorch Version: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
):
    print(report_line)

if torch.cuda.is_available():
    device_ids = range(torch.cuda.device_count())
    print(f"Number of GPUs: {len(device_ids)}")
    for device_id in device_ids:
        device_name = torch.cuda.get_device_name(device_id)
        print(f"GPU {device_id}: {device_name}")
        # bytes -> decimal gigabytes
        memory_gb = torch.cuda.get_device_properties(device_id).total_memory / 1e9
        print(f"  Memory: {memory_gb:.2f} GB")
    print("‚úÖ GPU is ready!")
else:
    print("‚ö†Ô∏è No GPU detected!")
    print("   Go to: Settings ‚Üí Accelerator ‚Üí GPU T4 x2")

# Working directory and disk usage of the Kaggle output volume
print("\nüìÅ Working Directory:", os.getcwd())
print("üìÅ Available Space:")
disk_report_command = 'df -h /kaggle/working'
os.system(disk_report_command)

print("\n‚úÖ System check complete!")

In [None]:
# Install required packages
#
# Notebook-only cell: every line starting with `!` is an IPython shell escape.
# It removes all OpenCV wheels, installs pinned NumPy / OpenCV / albumentations
# / transformers builds in a fixed order, then imports them and prints the
# versions that were actually picked up.
print("üì¶ Installing dependencies...")
print("This will take 2-3 minutes...\n")

# Step 1: Aggressively uninstall ALL OpenCV variants
# `2>/dev/null || true` hides pip's error output and keeps the line from
# failing when none of the variants is installed.
print("üßπ Cleaning up existing OpenCV installations...")
!pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless opencv-contrib-python-headless -qqq 2>/dev/null || true

# Step 2: Install NumPy first (critical for compatibility)
print("üìå Installing NumPy 1.26.4...")
!pip install -q "numpy==1.26.4"

# Step 3: Install compatible OpenCV version
print("üìå Installing OpenCV 4.8.0.76...")
!pip install -q opencv-python-headless==4.8.0.76

# Step 4: Install albumentations with pinned version
print("üìå Installing albumentations 1.3.1...")
!pip install -q albumentations==1.3.1

# Step 5: Install other dependencies
# (one shell command continued over several lines - do not insert comments
# between the continuation lines)
print("üìå Installing other packages...")
!pip install -q \
    transformers==4.35.0 \
    "huggingface-hub>=0.20.0,<1.0" \
    tensorboard \
    scikit-learn \
    tqdm \
    pyyaml \
    pillow \
    matplotlib

# Step 6: Ensure NumPy stays at 1.26.4
# Re-pins NumPy in case one of the installs above replaced it.
!pip install -q --force-reinstall "numpy==1.26.4"

print("\n‚úÖ Installation complete!")
print("\n‚ö†Ô∏è  IMPORTANT: RESTART THE KERNEL NOW!")
print("   Click: Runtime ‚Üí Restart Runtime (or Kernel ‚Üí Restart)")
print("   Then re-run this cell to verify imports.\n")

# Verify imports
# NOTE(review): this check runs in the same kernel session as the installs
# above, so modules that were already imported earlier may still be the old
# builds until the kernel is restarted - confirm after restarting.
print("üß™ Verifying installations...")
try:
    import transformers
    import cv2
    import numpy as np
    import albumentations

    print(f"‚úÖ Transformers: {transformers.__version__}")
    print(f"‚úÖ OpenCV: {cv2.__version__}")
    print(f"‚úÖ NumPy: {np.__version__}")
    print(f"‚úÖ Albumentations: {albumentations.__version__}")

    # Verify CV_8U attribute exists
    # (a failed assert is reported by the AssertionError handler below)
    assert hasattr(cv2, 'CV_8U'), "OpenCV missing CV_8U attribute!"
    print(f"‚úÖ OpenCV CV_8U attribute: OK")

    # Verify NumPy version
    # Only warns; a non-1.x NumPy does not abort the cell.
    if np.__version__.startswith('1.'):
        print("‚úÖ NumPy 1.x confirmed (required for compatibility)")
    else:
        print(f"‚ö†Ô∏è WARNING: NumPy {np.__version__} detected, should be 1.x")

    print("\n‚úÖ All packages verified and working!")
except ImportError as e:
    # Print a hint, then re-raise so the cell is still marked as failed.
    print(f"‚ùå Import error: {e}")
    print("\n‚ö†Ô∏è  Please RESTART the kernel and run this cell again!")
    raise
except AssertionError as e:
    # Same policy for the CV_8U compatibility check.
    print(f"‚ùå Compatibility error: {e}")
    print("\n‚ö†Ô∏è  Please RESTART the kernel and run this cell again!")
    raise

### Step 3: Setup Project Structure

**📁 Kaggle File System:**
- `/kaggle/input/` - Read-only input datasets (uploaded by you)
- `/kaggle/working/` - Read-write workspace (lost after session)
- Results saved here will be available in "Output" tab after session ends

In [None]:
# Create project structure in working directory
#
# Creates the folders that later cells copy project code into and write
# checkpoints, logs and results to. The paths are relative, so they are created
# inside the current working directory.
import os  # imported here so the cell also works right after a kernel restart

print("üìÅ Setting up project structure...\n")

# Create directories
directories = [
    'models',
    'data',
    'scripts',
    'utils',
    'checkpoints',
    'logs',
    'results'
]

for dir_name in directories:
    # exist_ok=True keeps the cell safe to re-run
    os.makedirs(dir_name, exist_ok=True)
    print(f"‚úÖ Created: {dir_name}/")

# Static overview of where things live on Kaggle (informational only)
print("\nüìÇ Kaggle Directory Structure:")
print("\n/kaggle/input/ (Read-Only):")
print("  ‚îî‚îÄ‚îÄ Your uploaded datasets appear here")
print("\n/kaggle/working/ (Read-Write):")
print("  ‚îú‚îÄ‚îÄ models/")
print("  ‚îú‚îÄ‚îÄ data/")
print("  ‚îú‚îÄ‚îÄ scripts/")
print("  ‚îú‚îÄ‚îÄ utils/")
print("  ‚îú‚îÄ‚îÄ checkpoints/  ‚Üê Models saved here")
print("  ‚îú‚îÄ‚îÄ logs/         ‚Üê Training logs")
print("  ‚îî‚îÄ‚îÄ results/      ‚Üê Test results")

print("\n‚úÖ Project structure created!")

### Step 4: Copy Project Files from Input

**⚠️ IMPORTANT: Upload Your Datasets First!**

Before running this cell:

1. **Upload Project Code to Kaggle:**
   - Go to: https://www.kaggle.com/datasets
   - Click "New Dataset"
   - Upload your `VISIHEALTH CODE` folder (can ZIP it first)
   - Suggested name: `visihealth-code`
   - Click "Create"

2. **Upload SLAKE Dataset to Kaggle:**
   - Go to: https://www.kaggle.com/datasets  
   - Click "New Dataset"
   - Upload your `Slake1.0` folder (can ZIP it first)
   - Suggested name: `slake-medical-vqa` or `my-slake-dataset`
   - Click "Create"

3. **Add Both Datasets to This Notebook:**
   - In this notebook, click "+ Add Data" (right sidebar → Input tab)
   - Click "Your Datasets" tab
   - Add both datasets you just uploaded

4. **Update the paths below** to match YOUR dataset names exactly

In [None]:
import os
import shutil

# IMPORTANT: update these two paths to match YOUR uploaded dataset names!
# After uploading, check the "Available input datasets" listing printed below
# to see the exact folder names.
PROJECT_INPUT = '/kaggle/input/visihealth-code/VISIHEALTH CODE'      # Your project code dataset
SLAKE_INPUT = '/kaggle/input/slake-medical-vqa/Slake1.0'      # Your SLAKE dataset

# Examples of what paths might look like:
# PROJECT_INPUT = '/kaggle/input/my-visihealth-code'
# SLAKE_INPUT = '/kaggle/input/my-slake-dataset'
# SLAKE_INPUT = '/kaggle/input/slake1-0'

KAGGLE_INPUT_ROOT = '/kaggle/input'
# Copied in this order from PROJECT_INPUT into the working directory.
PROJECT_DIRS = ('models', 'scripts', 'utils', 'data')
PROJECT_FILES = ('config.yaml', 'requirements.txt')
PREVIEW_LIMIT = 5  # entries shown per input dataset


def link_dataset(source, dest):
    """Make ``dest`` a symlink to ``source``, replacing whatever is at ``dest``.

    A previous symlink (including a dangling one, which ``os.path.exists``
    reports as missing), a directory, or a plain file at ``dest`` is removed
    first. The parent directory of ``dest`` is created when absent.
    """
    if os.path.islink(dest):
        os.unlink(dest)
    elif os.path.isdir(dest):
        shutil.rmtree(dest)
    elif os.path.lexists(dest):
        os.remove(dest)

    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    os.symlink(source, dest)


print("üîç Checking input datasets...\n")

# Check what's available in /kaggle/input/
print("Available input datasets:")
datasets = os.listdir(KAGGLE_INPUT_ROOT) if os.path.isdir(KAGGLE_INPUT_ROOT) else []
for dataset in datasets:
    print(f"  üìÅ {dataset}")
    dataset_path = os.path.join(KAGGLE_INPUT_ROOT, dataset)
    if os.path.isdir(dataset_path):
        # List the folder once and preview only the first few entries
        contents = os.listdir(dataset_path)
        for item in contents[:PREVIEW_LIMIT]:
            print(f"      - {item}")
        if len(contents) > PREVIEW_LIMIT:
            print(f"      ... and {len(contents) - PREVIEW_LIMIT} more")
if not datasets:
    # Covers both a missing and an empty /kaggle/input
    print("  ‚ö†Ô∏è No datasets found!")
    print("\nüëÜ You need to add datasets using '+ Add Data' button above")

print("\n" + "="*60)

# Check if project files exist
if os.path.exists(PROJECT_INPUT):
    print(f"‚úÖ Found project at: {PROJECT_INPUT}")

    # Copy project files
    print("\nüì¶ Copying project files...")

    # Package folders: merged into the already existing working-dir folders
    for dir_name in PROJECT_DIRS:
        source_dir = os.path.join(PROJECT_INPUT, dir_name)
        if os.path.exists(source_dir):
            shutil.copytree(source_dir, dir_name, dirs_exist_ok=True)
            print(f"  ‚úÖ Copied {dir_name}/")

    # Top-level config files
    for file_name in PROJECT_FILES:
        source_file = os.path.join(PROJECT_INPUT, file_name)
        if os.path.exists(source_file):
            shutil.copy(source_file, file_name)
            print(f"  ‚úÖ Copied {file_name}")

    print("\n‚úÖ Project files copied!")
else:
    print(f"‚ùå Project not found at: {PROJECT_INPUT}")
    print("\nüìã TO FIX:")
    print("1. Click '+ Add Data' button above")
    print("2. Go to 'Your Datasets'")
    print("3. Upload your VISIHEALTH CODE folder")
    print("4. Update PROJECT_INPUT path in this cell")
    print("5. Re-run this cell")

print("\n" + "="*60)

# Check SLAKE dataset
if os.path.exists(SLAKE_INPUT):
    print(f"‚úÖ Found SLAKE dataset at: {SLAKE_INPUT}")

    # A symlink avoids copying the whole dataset into the working directory
    print("\nüìä Linking SLAKE dataset...")

    slake_dest = "data/SLAKE"
    link_dataset(SLAKE_INPUT, slake_dest)

    # Verify
    if os.path.exists(f"{slake_dest}/train.json"):
        print("  ‚úÖ train.json found")
    if os.path.exists(f"{slake_dest}/test.json"):
        print("  ‚úÖ test.json found")
    if os.path.exists(f"{slake_dest}/imgs"):
        num_imgs = len(os.listdir(f"{slake_dest}/imgs"))
        print(f"  ‚úÖ imgs/ found ({num_imgs} images)")

    print("\n‚úÖ SLAKE dataset linked!")
else:
    print(f"‚ùå SLAKE dataset not found at: {SLAKE_INPUT}")
    print("\nüìã TO FIX:")
    print("1. Click '+ Add Data' button above")
    print("2. Go to 'Your Datasets'")
    print("3. Upload your Slake1.0 folder")
    print("4. Update SLAKE_INPUT path in this cell")
    print("5. Re-run this cell")

print("\n" + "="*60)
print("‚úÖ Setup check complete!")

### Step 5: Verify Project Setup

In [None]:
# Check that the project sources and the SLAKE dataset are where the later
# cells expect them, and summarise the outcome in `all_good`.
def _report_paths(paths):
    """Print one status line per path; return True only if all of them exist."""
    everything_found = True
    for path in paths:
        present = os.path.exists(path)
        marker = "‚úÖ" if present else "‚ùå"
        print(f"  {marker} {path}")
        everything_found = everything_found and present
    return everything_found


print("üß™ Verifying project setup...\n")
print("=" * 60)

core_files = [
    'models/cnn_model.py',
    'models/bert_model.py',
    'models/fusion_model.py',
    'models/__init__.py',
    'scripts/train.py',
    'scripts/demo.py',
    'data/dataset.py',
    'data/__init__.py',
    'utils/knowledge_graph.py',
    'utils/__init__.py',
    'config.yaml',
    'requirements.txt'
]

dataset_files = [
    'data/SLAKE/train.json',
    'data/SLAKE/test.json',
    'data/SLAKE/imgs'
]

print("Checking project files:")
all_good = _report_paths(core_files)

print("\n" + "=" * 60)
print("Checking dataset:")
# Report first, combine afterwards, so every dataset path is always listed
dataset_ok = _report_paths(dataset_files)
all_good = all_good and dataset_ok

print("\n" + "=" * 60)

if not all_good:
    print("‚ö†Ô∏è SOME FILES ARE MISSING")
    print("Go back to Step 4 and check the instructions")
else:
    print("‚úÖ SETUP VERIFICATION COMPLETE!")
    print("‚úÖ READY TO TRAIN!")

print("=" * 60)

### Step 5.5: Fix Module Imports (Run this if you get ImportError)

In [None]:
# Overwrite the three package __init__ files so that each package re-exports
# the names the notebook imports from it.
print("üîß Fixing module imports...\n")

data_init_content = """\"\"\"Data module for VisiHealth.\"\"\"
from .dataset import SLAKEDataset, get_dataloader

__all__ = ['SLAKEDataset', 'get_dataloader']
"""

models_init_content = """\"\"\"Models module for VisiHealth.\"\"\"
from .cnn_model import get_cnn_model
from .bert_model import get_bert_model
from .fusion_model import build_visihealth_model

__all__ = ['get_cnn_model', 'get_bert_model', 'build_visihealth_model']
"""

utils_init_content = """\"\"\"Utilities module for VisiHealth.\"\"\"
from .knowledge_graph import KnowledgeGraph, RationaleGenerator, load_knowledge_graph

__all__ = ['KnowledgeGraph', 'RationaleGenerator', 'load_knowledge_graph']
"""

# (target file, replacement source), processed in this order
init_rewrites = (
    ('data/__init__.py', data_init_content),
    ('models/__init__.py', models_init_content),
    ('utils/__init__.py', utils_init_content),
)

for init_path, init_source in init_rewrites:
    # Any existing file content is replaced, not merged
    with open(init_path, 'w') as handle:
        handle.write(init_source)
    print(f"‚úÖ Fixed {init_path}")

print("\n‚úÖ All module imports fixed! You can now proceed with training.")
print("‚ö†Ô∏è If you still get errors, restart the kernel and re-run from Step 1.")

In [None]:
# Fix initialization order bug in train.py
#
# The bug: self.class_weights = self._compute_class_weights() is executed
# before self.num_classes is set. The patch comments out the early call and
# re-adds it right after the num_classes assignment. Running this cell again
# on an already patched file leaves the file untouched.
import os
import re

TRAIN_SCRIPT = 'scripts/train.py'
WEIGHTS_CALL = 'self.class_weights = self._compute_class_weights()'
NUM_CLASSES_ASSIGNMENT = 'self.num_classes = self.train_dataset.num_classes'

# The call on a line of its own, at any indentation.
_WEIGHTS_LINE = re.compile(
    r'^(?P<indent>[ \t]*)' + re.escape(WEIGHTS_CALL) + r'[ \t]*$',
    re.MULTILINE,
)


def patch_class_weights_order(source):
    """Move the class-weights call behind the ``num_classes`` assignment.

    Args:
        source: full text of the training script.

    Returns:
        ``(text, status)`` where ``status`` is one of
        ``'patched'`` (text was rewritten),
        ``'already_ordered'`` (the call already follows the assignment),
        ``'no_anchor'`` (the ``num_classes`` assignment was not found), or
        ``'unrecognized'`` (the script does not contain the expected call).
        For every status except ``'patched'`` the text is returned unchanged.
    """
    early_call = _WEIGHTS_LINE.search(source)
    if early_call is None or 'self.num_classes' not in source:
        return source, 'unrecognized'

    anchor_at = source.find(NUM_CLASSES_ASSIGNMENT)
    if anchor_at == -1:
        return source, 'no_anchor'
    if early_call.start() > anchor_at:
        return source, 'already_ordered'

    indent = early_call.group('indent')
    # Only calls located before the assignment are disabled.
    head = _WEIGHTS_LINE.sub(
        lambda m: m.group('indent') + '# self.class_weights moved after num_classes is set',
        source[:anchor_at],
    )
    tail = source[anchor_at:].replace(
        NUM_CLASSES_ASSIGNMENT,
        NUM_CLASSES_ASSIGNMENT
        + '\n\n'
        + indent + '# Compute class weights for balanced training\n'
        + indent + WEIGHTS_CALL,
        1,
    )
    return head + tail, 'patched'


print("üîß Fixing train.py initialization bug...\n")

if not os.path.exists(TRAIN_SCRIPT):
    print(f"WARNING: {TRAIN_SCRIPT} not found - copy the project files (Step 4) first")
else:
    # Read train.py
    with open(TRAIN_SCRIPT, 'r', encoding='utf-8') as f:
        content = f.read()

    content_fixed, patch_status = patch_class_weights_order(content)

    if patch_status == 'patched':
        # Write back
        with open(TRAIN_SCRIPT, 'w', encoding='utf-8') as f:
            f.write(content_fixed)
        print("‚úÖ Fixed train.py initialization order")
        print("   Moved class_weights computation after num_classes is set")
    elif patch_status == 'no_anchor':
        print("‚ö†Ô∏è Could not find num_classes assignment pattern")
        print("   You may need to manually fix train.py")
    else:
        print("‚úÖ train.py appears to be already fixed or has different structure")

    # Only claim success when the script is known to be in the right order
    if patch_status in ('patched', 'already_ordered'):
        print("\n‚úÖ Training script patched!")

### Step 5.6: Fix Training Script Bug (Critical!)

---

## 🏋️ PART 2: TRAINING

### Step 6: Load and Visualize Dataset

In [None]:
import yaml
import matplotlib.pyplot as plt
import numpy as np
import sys

# Put the working directory first on sys.path so the copied project packages
# (`data`, `models`, `utils`) are the ones that get imported.
sys.path.insert(0, '/kaggle/working')

from data import get_dataloader

# Training configuration
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

print("üìã Loading SLAKE dataset...\n")

# Train split loader plus its dataset object
batch_size = config['training']['batch_size']
train_loader, train_dataset = get_dataloader(
    data_dir='data/SLAKE',
    split='train',
    batch_size=batch_size,
    num_workers=2,
    tokenizer_name=config['model']['bert']['model_name'],
)

print(f"‚úÖ Train dataset: {len(train_dataset)} samples")
print(f"‚úÖ Number of classes: {train_dataset.num_classes}")
print(f"‚úÖ Batch size: {batch_size}")
print(f"‚úÖ Total batches: {len(train_loader)}")

# Show up to six samples of the first batch in a 2x3 grid
print("\nüì∏ Visualizing sample data...")
sample_batch = next(iter(train_loader))

# Per-channel statistics used to undo the input normalisation
# (presumably the ones applied by the dataset transform - TODO confirm)
channel_mean = np.array([0.485, 0.456, 0.406])
channel_std = np.array([0.229, 0.224, 0.225])

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
shown = range(len(sample_batch['image']))
for i, ax in zip(shown, axes.flat):
    # channel-first tensor -> channel-last array for imshow
    pixels = sample_batch['image'][i].cpu().numpy().transpose(1, 2, 0)
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)
    ax.imshow(pixels)

    question = sample_batch['question_text'][i][:50] + "..."
    answer = sample_batch['answer_text'][i]
    ax.set_title(f"Q: {question}\nA: {answer}", fontsize=8)
    ax.axis('off')

plt.tight_layout()
plt.show()

print("\n‚úÖ Dataset loaded and verified!")

### Step 7: Start Training

**⏱️ Training Time Estimates:**
- Kaggle GPU T4: ~2-4 hours for 50 epochs
- With early stopping: May finish earlier

**💾 Checkpoint Saving:**
- Checkpoints saved to `/kaggle/working/checkpoints/`
- After session ends, download from "Output" tab
- Best model automatically saved as `best_checkpoint.pth`

### Step 6.5: Load Previous Checkpoint (Resume Training)

**⚠️ IMPORTANT: Only run this if you want to resume from a previous checkpoint!**

If you uploaded your checkpoint as a Kaggle dataset, this will copy it to the working directory so training continues from where it left off.

In [None]:
import os
import shutil

# UPDATE THIS PATH to match YOUR uploaded checkpoint dataset name!
CHECKPOINT_DATASET = '/kaggle/input/visihealth-checkpoint-epoch45'  # Change this to your dataset name


def pick_checkpoint(filenames):
    """Choose which ``.pth`` file of a directory listing to restore.

    ``best_checkpoint.pth`` wins when present; otherwise the last ``.pth``
    name in lexicographic order is used, so the choice does not depend on
    the arbitrary order of ``os.listdir``.

    Returns:
        The chosen file name, or ``None`` when there is no ``.pth`` file.
    """
    candidates = sorted(name for name in filenames if name.endswith('.pth'))
    if not candidates:
        return None
    if 'best_checkpoint.pth' in candidates:
        return 'best_checkpoint.pth'
    return candidates[-1]


print("üì¶ Checking for previous checkpoint to resume training...\n")

# Check if checkpoint dataset exists
if os.path.exists(CHECKPOINT_DATASET):
    print(f"‚úÖ Found checkpoint dataset: {CHECKPOINT_DATASET}")

    # Look for checkpoint file
    checkpoint_file = pick_checkpoint(os.listdir(CHECKPOINT_DATASET))

    if checkpoint_file:
        source_path = os.path.join(CHECKPOINT_DATASET, checkpoint_file)
        dest_path = f'/kaggle/working/checkpoints/{checkpoint_file}'

        # Create checkpoints directory
        os.makedirs('/kaggle/working/checkpoints', exist_ok=True)

        # Copy checkpoint out of the read-only input area
        print(f"üìã Copying checkpoint: {checkpoint_file}")
        shutil.copy(source_path, dest_path)

        # Verify
        size_mb = os.path.getsize(dest_path) / (1024*1024)
        print(f"‚úÖ Checkpoint copied successfully!")
        print(f"   Location: {dest_path}")
        print(f"   Size: {size_mb:.1f} MB")

        # Load and show checkpoint info
        import torch
        # NOTE: torch.load unpickles the file - only load checkpoints you created
        ckpt = torch.load(dest_path, map_location='cpu')
        print(f"\nüìä Checkpoint Info:")
        print(f"   Previous Epoch: {ckpt.get('epoch', 'N/A')}")
        print(f"   Best Val Accuracy: {ckpt.get('best_val_acc', 0):.2f}%")
        print(f"   Best Val Loss: {ckpt.get('best_val_loss', 0):.4f}")
        print(f"\nüîÑ Training will resume from epoch {ckpt.get('epoch', 0) + 1}")
    else:
        print("‚ùå No .pth file found in checkpoint dataset!")
else:
    print(f"‚ö†Ô∏è No checkpoint dataset found at: {CHECKPOINT_DATASET}")
    print("   Training will start from scratch (epoch 0)")
    print("\nüìã To resume training:")
    print("   1. Upload your checkpoint as a Kaggle dataset")
    print("   2. Add it to this notebook (+ Add Data)")
    print("   3. Update CHECKPOINT_DATASET path in this cell")

print("\n" + "="*60)

In [None]:
# Start training
print("üèãÔ∏è Starting training...\n")
print("Training configuration:")
print("  - Auto-resume from latest checkpoint if available")
print("  - Train for up to 100 epochs (with early stopping)")
print("  - Save checkpoints every 5 epochs")
print("  - Log metrics for TensorBoard")
print("  - Show progress bars for each epoch")
print("\n" + "="*60 + "\n")

# Change to working directory
os.chdir('/kaggle/working')

# Check if checkpoint exists and pass explicit path
checkpoint_path = '/kaggle/working/checkpoints/best_checkpoint.pth'
if os.path.exists(checkpoint_path):
    print(f"‚úÖ Found checkpoint: {checkpoint_path}")
    print("   Training will resume from this checkpoint\n")
    !python scripts/train.py --resume --checkpoint {checkpoint_path}
else:
    print("‚ö†Ô∏è No checkpoint found - starting from scratch\n")
    !python scripts/train.py

print("\n" + "="*60)
print("‚úÖ Training complete!")


### Step 8: Monitor Training with TensorBoard (Optional)

In [None]:
# Load TensorBoard to visualize training metrics
# Both lines are IPython magics (notebook only): the first loads the
# tensorboard notebook extension, the second launches TensorBoard pointed at
# the relative `logs` directory.
%load_ext tensorboard
%tensorboard --logdir logs

### Step 9: Check Training Results

In [None]:
import glob
import torch

# Summarise what training left behind in checkpoints/.
def _megabytes(path):
    """Size of the file at ``path`` in units of 1024*1024 bytes."""
    return os.path.getsize(path) / (1024*1024)


# Every saved checkpoint with its size, sorted by path
checkpoints = glob.glob('checkpoints/*.pth')
print(f"üìä Found {len(checkpoints)} checkpoint(s):")
for ckpt in sorted(checkpoints):
    print(f"  - {os.path.basename(ckpt)} ({_megabytes(ckpt):.1f} MB)")

# Details of the best checkpoint, when there is one
best_ckpt = 'checkpoints/best_checkpoint.pth'
if not os.path.exists(best_ckpt):
    print("\n‚ö†Ô∏è No best checkpoint found yet.")
    print("Training may still be in progress or hasn't completed successfully.")
else:
    print("\n‚úÖ Best checkpoint found!")

    # Loaded on the CPU; only the stored metadata is printed
    checkpoint = torch.load(best_ckpt, map_location='cpu')
    best_epoch = checkpoint.get('epoch', 'N/A')
    best_acc = checkpoint.get('best_val_acc', 0)
    best_loss = checkpoint.get('best_val_loss', 0)
    print("\nBest model metrics:")
    print(f"  Epoch: {best_epoch}")
    print(f"  Best Val Accuracy: {best_acc:.2f}%")
    print(f"  Best Val Loss: {best_loss:.4f}")

    print(f"\nüíæ Checkpoint size: {_megabytes(best_ckpt):.1f} MB")
    print(f"üìÅ Location: {os.path.abspath(best_ckpt)}")

---

## 🔮 PART 3: INFERENCE & EVALUATION

### Step 10: Load Trained Model

In [None]:
import yaml
import torch
from models import get_cnn_model, get_bert_model, build_visihealth_model
from data.dataset import SLAKEDataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Load config
#
# This cell restores the trained model for inference. Order matters:
#   1. read config.yaml,
#   2. load the checkpoint,
#   3. obtain the answer vocabulary (checkpoint first, train split as fallback),
#   4. build the test dataset with that vocabulary,
#   5. write the class count into the config, build the model, load weights.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Inference runs on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load checkpoint FIRST to get the training vocabulary
CHECKPOINT_PATH = 'checkpoints/best_checkpoint.pth'
print(f"üì¶ Loading checkpoint from: {CHECKPOINT_PATH}")

# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Get answer vocabulary from checkpoint or training data
if 'answer_vocab' in checkpoint:
    answer_vocab = checkpoint['answer_vocab']
    num_classes = len(answer_vocab)
    print(f"‚úÖ Loaded answer vocabulary from checkpoint: {num_classes} classes")
else:
    print("‚ö†Ô∏è No answer vocab in checkpoint, loading from train dataset...")
    # Import get_dataloader to load train dataset
    from data import get_dataloader
    # Only the dataset object is needed here; the loader itself is discarded.
    _, train_dataset = get_dataloader(
        data_dir='data/SLAKE',
        split='train',
        batch_size=1,
        num_workers=2,
        tokenizer_name=config['model']['bert']['model_name']
    )
    answer_vocab = train_dataset.answer_vocab
    num_classes = train_dataset.num_classes
    print(f"‚úÖ Loaded answer vocabulary from train dataset: {num_classes} classes")

# Load test dataset with the SAME vocabulary as training
print("\nüìä Loading test dataset with training vocabulary...")

test_dataset = SLAKEDataset(
    data_dir='data/SLAKE',
    split='test',
    tokenizer_name=config['model']['bert']['model_name'],
    answer_vocab=answer_vocab,  # Use training vocabulary!
    max_length=config['model']['bert']['max_length']
)

# Single-sample, unshuffled loader over the test split.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2
)

print(f"‚úÖ Test dataset: {len(test_dataset)} samples")
print(f"‚úÖ Using {num_classes} answer classes (from training)")

# Build model with correct number of classes
print("\nüß† Building model...")

# Update config with correct num_classes in the EXACT location the model reads from
# The value is written under several keys; which of them the project builders
# actually read is not visible from this cell.
config['num_classes'] = num_classes
config['model']['num_classes'] = num_classes
config['model']['cnn']['num_classes'] = num_classes  # This is the key one!
if 'fusion' not in config:
    config['fusion'] = {}
config['fusion']['num_classes'] = num_classes

print(f"‚öôÔ∏è Building model with {num_classes} answer classes...")
print(f"üîç Config model.cnn.num_classes: {config['model']['cnn'].get('num_classes', 'NOT SET')}")

cnn = get_cnn_model(config)
bert = get_bert_model(config)

# Build model - it reads num_classes from config['model']['cnn']['num_classes']
model = build_visihealth_model(config, cnn, bert)

# Load checkpoint weights
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# eval() switches the module to inference mode for the cells that follow
model.eval()

print(f"‚úÖ Model loaded successfully!")
print(f"   Trained for {checkpoint.get('epoch', 'N/A')} epochs")
print(f"   Best validation accuracy: {checkpoint.get('best_val_acc', 0):.2f}%")

### Step 11: Run Inference on Random Samples

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Predict answers for six randomly chosen test samples and show them in a
# 2x3 grid together with the ground truth and the top-scoring ROI.
num_samples = 6
indices = np.random.choice(len(test_dataset), num_samples, replace=False)

correct_predictions = 0

# Per-channel statistics used to undo the input normalisation for display
display_mean = np.array([0.485, 0.456, 0.406])
display_std = np.array([0.229, 0.224, 0.225])

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flat

for slot, idx in enumerate(indices):
    panel = axes[slot]
    sample = test_dataset[idx]

    # Add the batch dimension and move the tensors to the model's device
    image = sample['image'].unsqueeze(0).to(device)
    input_ids = sample['input_ids'].unsqueeze(0).to(device)
    attention_mask = sample['attention_mask'].unsqueeze(0).to(device)

    # Forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(image, input_ids, attention_mask, return_attention=True)
        answer_logits = outputs['answer_logits']
        roi_scores = outputs['roi_scores']

    # Predicted class and its softmax probability
    pred_idx = answer_logits.argmax(dim=1).item()
    answer_probs = F.softmax(answer_logits, dim=1)
    confidence = answer_probs[0, pred_idx].item()

    pred_answer = test_dataset.get_answer_text(pred_idx)
    true_answer = sample['answer_text']

    # Case-insensitive string comparison decides correctness
    is_correct = pred_answer.lower() == true_answer.lower()
    correct_predictions += 1 if is_correct else 0

    # Highest-scoring ROI and its softmax probability
    top_roi_idx = roi_scores.argmax(dim=1).item()
    roi_probs = F.softmax(roi_scores, dim=1)
    roi_confidence = roi_probs[0, top_roi_idx].item()

    # channel-first tensor -> channel-last array, then de-normalise
    img_display = sample['image'].cpu().numpy().transpose(1, 2, 0)
    img_display = np.clip(display_std * img_display + display_mean, 0, 1)
    panel.imshow(img_display)

    if is_correct:
        status = "‚úÖ CORRECT"
    else:
        status = "‚ùå WRONG"
    title = (
        f"Q: {sample['question_text'][:40]}...\n"
        f"Pred: {pred_answer} ({confidence:.1%}) | True: {true_answer}\n"
        f"{status} | ROI: {top_roi_idx} ({roi_confidence:.1%})"
    )
    panel.set_title(title, fontsize=8)
    panel.axis('off')

plt.tight_layout()
plt.show()

sample_accuracy = 100*correct_predictions/num_samples
print(f"\nüìä Accuracy on these samples: {correct_predictions}/{num_samples} ({sample_accuracy:.1f}%)")

### Step 11.5: Convert KG CSV Files to kg.txt (Optional Enhancement)

In [None]:
import os

import pandas as pd

# Convert the `#`-separated KG CSV files shipped with SLAKE into a single
# tab-separated triplet file (entity, relation, value) in the working directory.
kg_dir = 'data/SLAKE/KG'
kg_output_file = 'kg.txt'  # Write to writable /kaggle/working/ directory


def row_to_triplets(entity, relation, value):
    """Turn one CSV row into zero or more ``entity\\trelation\\tvalue`` lines.

    All three fields are stripped and lower-cased, spaces in the relation
    become underscores, and a comma-separated value yields one triplet per
    item. Rows whose entity or relation is missing (``nan``/empty), rows whose
    entity is the literal ``organ`` (treated as a stray header row), and
    missing or empty values produce nothing.
    """
    entity = str(entity).strip().lower()
    relation = str(relation).strip().replace(' ', '_').lower()
    if entity in ('', 'nan', 'organ') or relation in ('', 'nan'):
        return []

    items = (item.strip() for item in str(value).lower().split(','))
    return [
        f"{entity}\t{relation}\t{item}"
        for item in items
        if item and item != 'nan'
    ]


if os.path.exists(kg_dir):
    print(f"üîç Found KG directory: {kg_dir}")
    print("üìä Converting CSV files to kg.txt...\n")

    # List of CSV files to process
    csv_files = [
        'en_disease.csv',
        'en_organ.csv',
        'en_organ_rel.csv'
    ]

    all_triplets = []

    for csv_file in csv_files:
        csv_path = os.path.join(kg_dir, csv_file)
        if not os.path.exists(csv_path):
            continue

        print(f"  üìÑ Reading {csv_file}...")
        # header=0 + names: the file's own first row is replaced by these names
        df = pd.read_csv(csv_path, sep='#', header=0, names=['entity', 'relation', 'value'])

        for entity, relation, value in zip(df['entity'], df['relation'], df['value']):
            all_triplets.extend(row_to_triplets(entity, relation, value))

        print(f"    ‚úÖ Extracted {len(all_triplets)} triplets so far")

    # Remove duplicates; sorted for a stable file
    all_triplets = sorted(set(all_triplets))

    if all_triplets:
        with open(kg_output_file, 'w', encoding='utf-8') as f:
            f.writelines(triplet + '\n' for triplet in all_triplets)

        print(f"\n‚úÖ Created {kg_output_file} with {len(all_triplets)} unique triplets!")
        print(f"üìÅ Location: {os.path.abspath(kg_output_file)}")
        print(f"   (Saved to writable /kaggle/working/ directory)")

        # Show sample triplets
        print("\nüìã Sample triplets:")
        for triplet in all_triplets[:10]:
            parts = triplet.split('\t')
            print(f"  ‚Ä¢ {parts[0]} ‚Üí {parts[1]} ‚Üí {parts[2]}")
    else:
        # An empty kg.txt would be picked up as a valid (empty) knowledge
        # graph later on, so nothing is written in this case.
        print(f"\nWARNING: no triplets extracted from {kg_dir}; {kg_output_file} was not written")

else:
    print(f"‚ö†Ô∏è KG directory not found at: {kg_dir}")
    print("   This step is optional. A sample KG will be created in Step 12 if needed.")

### Step 12: Generate Rationales with Knowledge Graph

In [None]:
from utils.knowledge_graph import load_knowledge_graph, RationaleGenerator

# Pick the knowledge-graph file: the one inside the (read-only) dataset wins,
# then a kg.txt in the working directory, otherwise a five-triplet sample file
# is written to the working directory.
kg_file_readonly = 'data/SLAKE/kg.txt'  # Read-only location
kg_file_writable = 'kg.txt'  # Writable location in /kaggle/working

sample_kg_lines = (
    "liver\tis_located_in\tabdomen\n",
    "lung\tis_located_in\tchest\n",
    "heart\tis_located_in\tchest\n",
    "brain\tis_located_in\thead\n",
    "kidney\tis_located_in\tabdomen\n",
)

if os.path.exists(kg_file_readonly):
    kg_file = kg_file_readonly
    loaded_note = "‚úÖ Knowledge graph loaded from dataset: {count} triplets"
elif os.path.exists(kg_file_writable):
    kg_file = kg_file_writable
    loaded_note = "‚úÖ Knowledge graph loaded: {count} triplets"
else:
    print("‚ö†Ô∏è KG file not found. Creating sample knowledge graph...")
    kg_file = kg_file_writable
    with open(kg_file, 'w') as f:
        f.writelines(sample_kg_lines)
    loaded_note = "‚úÖ Created sample knowledge graph: {count} triplets"

# Same loading path for all three cases
kg = load_knowledge_graph(kg_file)
rationale_gen = RationaleGenerator(kg)
print(loaded_note.format(count=len(kg.triplets)))

# Explain the model's prediction for the first three of the random samples
# selected by the inference cell (`indices`)
wide_rule = "=" * 80
print("\n" + wide_rule)
print("GENERATING RATIONALES")
print(wide_rule)

for idx in indices[:3]:
    sample = test_dataset[idx]

    # Add the batch dimension and move the tensors to the model's device
    image = sample['image'].unsqueeze(0).to(device)
    input_ids = sample['input_ids'].unsqueeze(0).to(device)
    attention_mask = sample['attention_mask'].unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image, input_ids, attention_mask)
        answer_logits = outputs['answer_logits']
        roi_scores = outputs['roi_scores']

    # Predicted class, its softmax probability and its text
    pred_idx = answer_logits.argmax(dim=1).item()
    answer_probs = F.softmax(answer_logits, dim=1)
    confidence = answer_probs[0, pred_idx].item()
    pred_answer = test_dataset.get_answer_text(pred_idx)

    # Three highest-scoring ROIs of the single sample in the batch
    top_k_rois = torch.topk(roi_scores[0], k=3)
    top_roi_ids = top_k_rois.indices.tolist()
    top_roi_values = top_k_rois.values.tolist()

    rationale = rationale_gen.generate_rationale(
        predicted_answer=pred_answer,
        confidence=confidence,
        top_roi_indices=top_roi_ids,
        roi_scores=top_roi_values,
        question=sample['question_text'],
    )

    print(f"\n{wide_rule}")
    print(f"Image: {sample['img_name']}")
    print(f"Question: {sample['question_text']}")
    print(f"True Answer: {sample['answer_text']}")
    print(f"Predicted Answer: {pred_answer} (confidence: {confidence:.2%})")
    print(f"\nüìù Rationale:\n{rationale}")
    print(wide_rule)

### Step 13: Calculate Full Test Set Accuracy

In [None]:
import json

from tqdm import tqdm

# Step 13: score the model on the whole test split and persist the summary
# to results/VisiHealth_Results.json.
print("üìä Evaluating on full test set...\n")

# Switch the model to evaluation mode before scoring
model.eval()
correct = 0
total = 0

# Create test dataset with SAME vocabulary as training (must pass answer_vocab!)
test_dataset_batch = SLAKEDataset(
    data_dir='data/SLAKE',
    split='test',
    tokenizer_name=config['model']['bert']['model_name'],
    answer_vocab=answer_vocab,  # CRITICAL: Use training vocabulary!
    max_length=config['model']['bert']['max_length']
)

# shuffle=False keeps the evaluation order fixed from run to run
test_loader_batch = DataLoader(
    test_dataset_batch,
    batch_size=16,
    shuffle=False,
    num_workers=2
)

print(f"‚úÖ Test dataset: {len(test_dataset_batch)} samples with {len(answer_vocab)} classes (training vocab)")

# Count exact matches between the arg-max class and the label, batch by batch
with torch.no_grad():
    for batch in tqdm(test_loader_batch, desc="Evaluating"):
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        answers = batch['answer'].to(device)

        outputs = model(images, input_ids, attention_mask)
        predictions = outputs['answer_logits'].argmax(dim=1)

        correct += (predictions == answers).sum().item()
        total += answers.size(0)

# An empty test split would otherwise raise ZeroDivisionError here
if total > 0:
    accuracy = 100 * correct / total
else:
    accuracy = 0.0
    print("WARNING: no test samples were evaluated; accuracy reported as 0.00%")

print(f"\n{'='*60}")
print(f"üìä TEST SET RESULTS")
print(f"{'='*60}")
print(f"Total Samples: {total}")
print(f"Correct Predictions: {correct}")
print(f"Accuracy: {accuracy:.2f}%")
print(f"{'='*60}")

# Save results
results = {
    'test_accuracy': accuracy,
    'correct': correct,
    'total': total,
    'checkpoint': CHECKPOINT_PATH,
    'platform': 'Kaggle'
}

results_file = 'results/VisiHealth_Results.json'
# Path is relative to the current working directory; create the folder first
os.makedirs('results', exist_ok=True)
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n‚úÖ Results saved to: {results_file}")

### Step 13.5: Verify Files Are Saved

**⚠️ CRITICAL: How Kaggle Output Works**

Files in `/kaggle/working` are ONLY saved to Output if:
1. The notebook runs ALL cells to completion
2. OR you click **"Save & Run All (Commit)"** button

DO NOT manually click "Save Version" while running - files will be lost!

In [None]:
import os
import glob

# Step 13.5: report which checkpoint (.pth) and result (.json) files exist
# under /kaggle/working, then print the instructions for getting them into
# the Kaggle Output tab. The closing banner only claims the files are ready
# when both kinds of file were actually found.
print("üîç Checking saved model files in /kaggle/working...\n")
print("=" * 70)

# Check checkpoints folder
checkpoint_dir = '/kaggle/working/checkpoints'
if os.path.exists(checkpoint_dir):
    checkpoints = glob.glob(f'{checkpoint_dir}/*.pth')
    print(f"\n‚úÖ CHECKPOINTS FOLDER EXISTS:")
    print(f"   Location: {checkpoint_dir}")
    print(f"   Files found: {len(checkpoints)}")
    for ckpt in checkpoints:
        size_mb = os.path.getsize(ckpt) / (1024*1024)
        print(f"   ‚Ä¢ {os.path.basename(ckpt)} ({size_mb:.1f} MB)")
else:
    # Keep the name defined so the readiness check below never hits a NameError
    checkpoints = []
    print(f"\n‚ùå CHECKPOINTS FOLDER NOT FOUND!")
    print(f"   Expected at: {checkpoint_dir}")

# Check results folder
results_dir = '/kaggle/working/results'
if os.path.exists(results_dir):
    results = glob.glob(f'{results_dir}/*.json')
    print(f"\n‚úÖ RESULTS FOLDER EXISTS:")
    print(f"   Location: {results_dir}")
    print(f"   Files found: {len(results)}")
    for res in results:
        size_kb = os.path.getsize(res) / 1024
        print(f"   ‚Ä¢ {os.path.basename(res)} ({size_kb:.1f} KB)")
else:
    results = []
    print(f"\n‚ùå RESULTS FOLDER NOT FOUND!")
    print(f"   Expected at: {results_dir}")

print("\n" + "=" * 70)
print("üìã IMPORTANT: HOW TO SAVE FILES TO KAGGLE OUTPUT")
print("=" * 70)
print("\nüö® METHOD 1 (RECOMMENDED):")
print("   1. Click: 'Save & Run All (Commit)' button (top right)")
print("   2. Wait for ALL cells to complete")
print("   3. Go to 'Versions' tab ‚Üí Find completed version")
print("   4. Click on version ‚Üí Go to 'Output' tab")
print("   5. Download 'checkpoints' folder (contains models)")

print("\nüö® METHOD 2 (If session is active):")
print("   1. Let notebook run to the VERY LAST cell")
print("   2. After last cell finishes, files save automatically")
print("   3. Session ends ‚Üí Go to 'Output' tab")
print("   4. Download files")

print("\n‚ö†Ô∏è  DO NOT:")
print("   ‚ùå Click 'Save Version' while notebook is running")
print("   ‚ùå Stop session before notebook completes")
print("   ‚ùå These will cause files to be LOST!")

# Only announce success when at least one checkpoint AND one result file exist
files_ready = bool(checkpoints and results)

print("\n" + "=" * 70)
if files_ready:
    print("‚úÖ Files are ready in /kaggle/working/")
    print("   They will appear in Output tab AFTER notebook completes!")
else:
    print("WARNING: expected files are missing from /kaggle/working/")
    print("   Check the training and evaluation cells before saving!")
print("=" * 70)

In [None]:
# ==============================
# 📦 FINAL KAGGLE DOWNLOAD CELL
# ==============================
# Bundles the checkpoints and results folders into one ZIP under
# /kaggle/working and shows a download link for it.
# NOTE(review): this cell sits before Step 14 in the notebook, so
# results/VisiHealth_Model_Info.json is only included if this cell is re-run
# after Step 14 - confirm the intended order.

import os
import zipfile
from IPython.display import FileLink, display

ZIP_NAME = "VisiHealth_Training_Output.zip"
ZIP_PATH = f"/kaggle/working/{ZIP_NAME}"

# Archive top-level folder name -> directory on disk to pack into it
folders_to_zip = {
    "checkpoints": "/kaggle/working/checkpoints",
    "results": "/kaggle/working/results"
}

with zipfile.ZipFile(ZIP_PATH, "w", zipfile.ZIP_DEFLATED) as zipf:
    for folder_name, folder_path in folders_to_zip.items():
        if not os.path.exists(folder_path):
            print(f"‚ö†Ô∏è Skipping missing folder: {folder_path}")
            continue

        for root, _, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                # Keep the path relative to the folder root: files found in
                # sub-directories retain their location inside the archive and
                # same-named files from different sub-directories no longer
                # collapse onto one archive entry.
                arcname = os.path.join(
                    folder_name, os.path.relpath(file_path, folder_path)
                )
                zipf.write(file_path, arcname)

print("‚úÖ ZIP file created successfully!")
display(FileLink(ZIP_PATH))
print("üì• You can now download the ZIP or get it from the Output tab.")


### Step 14: Export Model Info

In [None]:
# Step 14: write a JSON description of the trained model for frontend use.
# NOTE(review): class count and vocabulary are read from `test_dataset`;
# confirm that dataset was built with the training answer vocabulary.
export_info = dict(
    model_type='VisiHealth AI',
    checkpoint_path=CHECKPOINT_PATH,
    test_accuracy=accuracy,
    num_classes=test_dataset.num_classes,
    # Swap keys and values of the dataset vocabulary
    # (presumably answer -> index becomes index -> answer; verify)
    answer_vocab={
        index: answer
        for answer, index in test_dataset.answer_vocab.items()
    },
    image_size=224,
    bert_model=config['model']['bert']['model_name'],
    usage=dict(
        input='Medical image (224x224) + question text',
        output='Answer + confidence + ROI scores + rationale',
    ),
    training_info=dict(
        dataset='SLAKE 1.0',
        total_samples=len(train_dataset),
        # Fall back to placeholders when the checkpoint lacks these keys
        epochs=checkpoint.get('epoch', 'N/A'),
        best_val_acc=checkpoint.get('best_val_acc', 0),
        platform='Kaggle',
    ),
)

info_file = 'results/VisiHealth_Model_Info.json'
with open(info_file, 'w') as handle:
    json.dump(export_info, handle, indent=2)

# Confirmation plus a short table of contents for the exported file
print(f"‚úÖ Model info exported to: {info_file}")
print("\n".join([
    "\nThis file contains:",
    "  - Model checkpoint path",
    "  - Test accuracy",
    "  - Answer vocabulary (for mapping predictions)",
    "  - Input/output specifications",
    "  - Training information",
]))

### Step 15: Prepare Files for Download

In [None]:
# Step 15: list the artefacts worth downloading, by category, with sizes in MB
def _path_size_mb(path):
    """Return the size of a file, or the summed size of all files beneath a
    directory, in MB; return None when the path is neither."""
    if os.path.isfile(path):
        byte_count = os.path.getsize(path)
    elif os.path.isdir(path):
        byte_count = sum(
            os.path.getsize(os.path.join(parent, name))
            for parent, _subdirs, names in os.walk(path)
            for name in names
        )
    else:
        return None
    return byte_count / (1024*1024)


print("üì¶ Preparing files for download...\n")
print("=" * 60)

# Paths are relative to the current working directory
download_files = {
    'Checkpoints': glob.glob('checkpoints/*.pth'),
    'Results': glob.glob('results/*.json'),
    'Logs': [log_dir for log_dir in ['logs/'] if os.path.exists(log_dir)]
}

total_size = 0

for category, entries in download_files.items():
    print(f"\n{category}:")
    if not entries:
        print("  (none)")
        continue

    for entry in entries:
        entry_mb = _path_size_mb(entry)
        if entry_mb is None:
            # Neither a regular file nor a directory: nothing to report
            continue
        total_size += entry_mb
        print(f"  ‚úÖ {entry} ({entry_mb:.1f} MB)")

print("\n" + "=" * 60)
print(f"üìä Total size: {total_size:.1f} MB")
print("\n".join([
    "\nüíæ After session ends, download from:",
    "   1. Click 'Output' tab (top right)",
    "   2. Download all files",
    "   3. Extract and use locally",
]))
print("=" * 60)

---

## ✅ TRAINING COMPLETE!

### 🎉 Congratulations! Your Model is Trained!

### 📁 What You Have Now:

**In Kaggle Output (Download after session ends):**
- ✅ `checkpoints/best_checkpoint.pth` - Your trained model (~500MB)
- ✅ `checkpoints/checkpoint_epoch_XX.pth` - Training checkpoints
- ✅ `results/VisiHealth_Results.json` - Test accuracy and metrics
- ✅ `results/VisiHealth_Model_Info.json` - Model specifications
- ✅ `logs/` - TensorBoard training logs

### 📥 How to Download Your Files:

1. **Wait for the run to finish** (only click "Stop Session" after the last cell has completed; stopping earlier can lose your files)
2. **Click "Output" tab** (top right of notebook)
3. **Download all files** - especially the checkpoints folder
4. **Extract on your local machine**

### 🚀 Next Steps:

#### Option 1: Use Locally
```bash
# On your laptop
python scripts/demo.py --checkpoint checkpoints/best_checkpoint.pth
```

#### Option 2: Deploy to Web
- Use Flask/FastAPI for backend
- Load checkpoint in API endpoint
- Build React/Vue frontend
- Deploy on Heroku/AWS/Azure

#### Option 3: Continue Training
- Add this notebook to a new Kaggle session
- Add your checkpoint as a dataset
- Resume training with `--resume` flag

### 📊 Performance Metrics:
- **Training Platform:** Kaggle (GPU T4)
- **Dataset:** SLAKE 1.0
- **Test Accuracy:** See Step 13 results above
- **Model Size:** ~500 MB
- **Inference Speed:** ~200-300ms per image (GPU)

### 💡 Tips:
- Keep your checkpoint file safe - it contains all your training!
- The answer vocabulary is stored in `results/VisiHealth_Model_Info.json` - you need it to map predicted class indices back to answer text
- Test locally before deploying to ensure everything works

---

## 🎯 You're All Set!

Your medical VQA system is trained and ready to use. Download your files and start building amazing applications!

**Questions?** Check the project documentation or experiment with the demo script.

### ⚠️ FINAL SAFETY CHECK - READ THIS BEFORE STOPPING!

**🚨 CRITICAL: If you're seeing this, DON'T STOP THE SESSION YET! 🚨**

Run the cell below to verify your files will be saved to Output.

In [None]:
import os
import glob

# Final check: list the checkpoint and result files under /kaggle/working and
# print the matching save instructions (or troubleshooting hints).
banner = "=" * 80

print(banner)
print("üîç FINAL FILE VERIFICATION - DO NOT SKIP THIS!")
print(banner)

# Check if training completed
checkpoint_dir = '/kaggle/working/checkpoints'
results_dir = '/kaggle/working/results'

# A missing folder simply means "no files of that kind"
checkpoints = []
if os.path.exists(checkpoint_dir):
    checkpoints = glob.glob(f'{checkpoint_dir}/*.pth')

results = []
if os.path.exists(results_dir):
    results = glob.glob(f'{results_dir}/*.json')

print("\nüìä FILES CURRENTLY IN /kaggle/working/:\n")

if not checkpoints:
    print("‚ùå NO CHECKPOINTS FOUND!")
    print("   Training may have failed!")
else:
    print("‚úÖ CHECKPOINTS FOUND:")
    for ckpt_path in checkpoints:
        ckpt_mb = os.path.getsize(ckpt_path) / (1024*1024)
        print(f"   ‚Ä¢ {os.path.basename(ckpt_path)} ({ckpt_mb:.1f} MB)")

if not results:
    print("\n‚ùå NO RESULTS FOUND!")
else:
    print("\n‚úÖ RESULTS FOUND:")
    for result_path in results:
        result_kb = os.path.getsize(result_path) / 1024
        print(f"   ‚Ä¢ {os.path.basename(result_path)} ({result_kb:.1f} KB)")

print("\n" + banner)
print("üö® CRITICAL INSTRUCTIONS - READ CAREFULLY:")
print(banner)

# Pick the closing message: save instructions when both kinds of file exist,
# troubleshooting hints otherwise
if checkpoints and results:
    closing_lines = [
        "\n‚úÖ YOUR FILES EXIST IN /kaggle/working/",
        "\nüìã HOW TO SAVE THEM TO OUTPUT TAB:",
        "\n   IF YOU USED 'Save & Run All (Commit)':",
        "   ‚úÖ This cell is the LAST cell",
        "   ‚úÖ When this finishes, Kaggle will AUTO-SAVE your files",
        "   ‚úÖ Go to: Versions tab ‚Üí Find this version ‚Üí Output tab",
        "   ‚úÖ Download your files from there",
        "\n   IF YOU MANUALLY RAN CELLS:",
        "   ‚ö†Ô∏è  Your files will be LOST when session ends!",
        "   ‚ö†Ô∏è  You MUST use 'Save & Run All (Commit)' to save them!",
        "   ‚ö†Ô∏è  DO NOT manually stop this session!",
        "\n" + banner,
        "‚úÖ IF THIS IS THE LAST CELL TO RUN:",
        "   Wait 30 seconds after this completes",
        "   Files will auto-save to Output",
        "   Then you can safely close/stop",
        banner,
    ]
else:
    closing_lines = [
        "\n‚ùå FILES ARE MISSING!",
        "\n   Possible reasons:",
        "   1. Training failed (check earlier cells)",
        "   2. Training didn't run (did you skip Step 7?)",
        "   3. File paths are wrong",
        "\n   ‚ö†Ô∏è  DO NOT STOP - go back and check training cell!",
        banner,
    ]

for line in closing_lines:
    print(line)