# 📦 PART 1: SETUP

## Step 1: Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Verify connection
print("‚úÖ Google Drive mounted successfully!")
print("\nContents of your Drive:")
!ls -lh /content/drive/MyDrive/

## Step 2: Check GPU Availability

In [None]:
import torch

# Ask PyTorch once whether CUDA is usable; the answer drives both the
# report line and the branch below.
cuda_ready = torch.cuda.is_available()

print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_ready}")

if not cuda_ready:
    # No accelerator attached: point the user at the Colab runtime menu.
    print("‚ö†Ô∏è No GPU detected. Go to: Runtime ‚Üí Change runtime type ‚Üí GPU")
else:
    # Report the name and total memory (in GB, base 1e9) of device 0.
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")
    print("‚úÖ GPU is ready!")

## Step 3: Copy Project from Google Drive

In [None]:
import os

# Set paths to your uploaded folders on Google Drive
# ‚ö†Ô∏è CHANGE THESE to match your Google Drive structure
PROJECT_FOLDER = '/content/drive/MyDrive/VISIHEALTH CODE'
SLAKE_FOLDER = '/content/drive/MyDrive/Slake1.0'

# Check if folders exist
if not os.path.exists(PROJECT_FOLDER):
    print(f"‚ùå ERROR: Project folder not found at: {PROJECT_FOLDER}")
    print("\nYour Drive contents:")
    !ls -lh /content/drive/MyDrive/
    print("\nüëÜ Update PROJECT_FOLDER path to match your folder name above")
else:
    print(f"‚úÖ Found project folder: {PROJECT_FOLDER}")

if not os.path.exists(SLAKE_FOLDER):
    print(f"‚ùå ERROR: SLAKE folder not found at: {SLAKE_FOLDER}")
    print("\nYour Drive contents:")
    !ls -lh /content/drive/MyDrive/
    print("\nüëÜ Update SLAKE_FOLDER path to match your folder name above")
else:
    print(f"‚úÖ Found SLAKE folder: {SLAKE_FOLDER}")

# Copy entire project to Colab workspace for faster access
print("\nüì¶ Copying project to Colab workspace...")
!cp -r "{PROJECT_FOLDER}" /content/VisiHealth

# Change to project directory
%cd /content/VisiHealth

print("\n‚úÖ Project copied!")
print("\nProject structure:")
!ls -lh

## Step 4: Copy SLAKE Dataset

In [None]:
# Create data directory if it doesn't exist
!mkdir -p data/SLAKE

# Check if dataset already exists to avoid re-copying
import os

if os.path.exists('data/SLAKE/train.json') and os.path.exists('data/SLAKE/imgs'):
    num_images = len(os.listdir('data/SLAKE/imgs')) if os.path.exists('data/SLAKE/imgs') else 0
    print(f"‚úÖ Dataset already loaded! Found {num_images} images")
    print("Skipping copy to save time...")
else:
    # Copy SLAKE dataset from Google Drive (this may take 5-10 minutes)
    print("üìä Copying SLAKE dataset...")
    print("This will take a few minutes for 642 images...")
    !cp -r {SLAKE_FOLDER}/* data/SLAKE/
    print("‚úÖ Dataset copied!")

print("\nDataset structure:")
!ls -lh data/SLAKE/
print("\nNumber of image folders:")
!ls data/SLAKE/imgs/ | wc -l

## Step 5: Install Dependencies

In [None]:
# Step 5: Install Dependencies (only runs if packages not already installed)

import importlib.util

# Check if packages are already installed
try:
    import cv2
    import transformers
    packages_installed = True
    print("‚úÖ Packages already installed. Skipping installation...")
except ImportError:
    packages_installed = False

if not packages_installed:
    print("üì¶ Installing dependencies...")
    
    # Uninstall conflicting packages first
    !pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless -q
    
    # Install packages in specific order
    !pip install -q "numpy<2.0"
    !pip install -q opencv-python-headless==4.10.0.84
    !pip install -q transformers==4.35.0
    !pip install -q "huggingface-hub>=0.20.0,<1.0"
    !pip install -q albumentations==1.4.0
    !pip install -q tensorboard scikit-learn tqdm pyyaml pillow matplotlib
    
    print("\n‚úÖ Installation complete!")
    print("\n‚ö†Ô∏è Please restart runtime manually and run all cells again.")
else:
    print("‚úÖ Ready to continue!")

## Step 6: Verify Setup

In [None]:
# Step 6: Verify Setup
#
# Confirms that the Python packages import, that the core project files are
# present in the current working directory, and that every dataset path the
# later steps read is really there.

import os


def _report_paths(paths):
    """Print one status line per path and report whether all of them exist.

    Args:
        paths: iterable of filesystem paths, relative to the current
            working directory or absolute.

    Returns:
        True only if every path exists; False otherwise.
    """
    all_present = True
    for path in paths:
        present = os.path.exists(path)
        marker = "‚úÖ" if present else "‚ùå"
        print(f"  {marker} {path}")
        all_present = all_present and present
    return all_present


print("üß™ Verifying setup...\n")

# First, verify Python packages
print("=" * 60)
print("Checking installed packages...")
print("=" * 60)

try:
    import transformers
    import cv2
    import numpy as np
    import torch

    print(f"‚úÖ PyTorch: {torch.__version__}")
    print(f"‚úÖ Transformers: {transformers.__version__}")
    print(f"‚úÖ OpenCV: {cv2.__version__}")
    print(f"‚úÖ NumPy: {np.__version__}")

    # Verify NumPy is version 1.x
    if np.__version__.startswith('2'):
        print("\n‚ö†Ô∏è WARNING: NumPy 2.x detected. May cause issues.")
        print("   Consider restarting runtime and running Step 5 again.")

    print("\n‚úÖ All packages verified!")

except ImportError as e:
    print(f"\n‚ùå Package import failed: {e}")
    print("\nüî¥ CRITICAL ERROR:")
    print("   Packages not installed correctly or runtime not restarted.")
    print("\nüìã To fix:")
    print("   1. Go back to Step 5")
    print("   2. Run Step 5 again")
    print("   3. Restart runtime: Runtime ‚Üí Restart runtime")
    print("   4. Re-run from Step 1")
    # Re-raise so "Run all" stops here instead of continuing with a broken env.
    raise

# Now check project files
print("\n" + "=" * 60)
print("Checking project files...")
print("=" * 60)

core_files = [
    'models/cnn_model.py',
    'models/bert_model.py',
    'models/fusion_model.py',
    'scripts/train.py',
    'scripts/demo.py',
    'data/dataset.py',
    'utils/knowledge_graph.py',
    'config.yaml',
    'requirements.txt'
]

all_good = _report_paths(core_files)

print("\n" + "=" * 60)
print("Checking dataset...")
print("=" * 60)

# Each dataset path is checked on its own, so a partially copied dataset
# (e.g. train.json present but imgs/ missing) is reported accurately.
dataset_paths = [
    'data/SLAKE/train.json',
    'data/SLAKE/test.json',
    'data/SLAKE/imgs/',
]

dataset_ok = _report_paths(dataset_paths)
all_good = all_good and dataset_ok

if all_good:
    print("\n" + "=" * 60)
    print("‚úÖ SETUP VERIFICATION COMPLETE! READY TO TRAIN.")
    print("=" * 60)
else:
    print("\n" + "=" * 60)
    print("‚ö†Ô∏è SOME FILES ARE MISSING")
    print("=" * 60)
    print("Check the paths in Step 3 and Step 4, then try again.")

## Step 7: Create Checkpoint Directories in Google Drive

In [None]:
# Create directories in Google Drive to save checkpoints permanently and
# link them into the workspace so files written there land on Drive.
import os
import shutil

# (persistent Drive directory, path inside the workspace that points to it)
DRIVE_LINKS = [
    ('/content/drive/MyDrive/VisiHealth_Checkpoints', '/content/VisiHealth/checkpoints'),
    ('/content/drive/MyDrive/VisiHealth_Logs', '/content/VisiHealth/logs'),
]

for drive_dir, link_path in DRIVE_LINKS:
    os.makedirs(drive_dir, exist_ok=True)

    if os.path.islink(link_path):
        # Re-run: replace the old link. `ln -sf` would instead follow it and
        # drop a new link *inside* the Drive folder.
        os.unlink(link_path)
    elif os.path.exists(link_path):
        # A real file/directory already sits at the link location (e.g. it
        # came along with the project copy). Linking over it would place the
        # link inside that directory, so move it aside without deleting it.
        backup_path = link_path + '_local'
        print(f"Moving existing {link_path} to {backup_path}")
        shutil.move(link_path, backup_path)

    os.symlink(drive_dir, link_path)

print("‚úÖ Checkpoints will be saved to Google Drive")
print("   Location: MyDrive/VisiHealth_Checkpoints/")
print("‚úÖ Logs will be saved to Google Drive")
print("   Location: MyDrive/VisiHealth_Logs/")

---

# 🏋️ PART 2: TRAINING

## Step 8: Load and Visualize Dataset

In [None]:
import yaml
import matplotlib.pyplot as plt
import numpy as np
from data import get_dataloader

# Read the experiment configuration from the project root.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("üìã Loading SLAKE dataset...\n")

# Build the training loader/dataset pair from the copied SLAKE data.
batch_size = config['training']['batch_size']
train_loader, train_dataset = get_dataloader(
    data_dir='data/SLAKE',
    split='train',
    batch_size=batch_size,
    num_workers=2,
    tokenizer_name=config['model']['bert']['model_name'],
)

print(f"‚úÖ Train dataset: {len(train_dataset)} samples")
print(f"‚úÖ Number of classes: {train_dataset.num_classes}")
print(f"‚úÖ Batch size: {config['training']['batch_size']}")
print(f"‚úÖ Total batches: {len(train_loader)}")

# Pull one batch and show up to six of its images with question/answer titles.
print("\nüì∏ Visualizing sample data...")
sample_batch = next(iter(train_loader))

# Per-channel constants used to undo the input normalisation for display
# (assumed to match the transform applied by the dataset -- TODO confirm).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panels = axes.flat
num_shown = min(len(sample_batch['image']), axes.size)

for i in range(num_shown):
    panel = panels[i]

    # Channel-first tensor -> channel-last array, then de-normalise and clip
    # to [0, 1] so imshow can render it.
    pixels = sample_batch['image'][i].cpu().numpy().transpose(1, 2, 0)
    pixels = np.clip(std * pixels + mean, 0, 1)
    panel.imshow(pixels)

    question = sample_batch['question_text'][i][:50] + "..."
    answer = sample_batch['answer_text'][i]
    panel.set_title(f"Q: {question}\nA: {answer}", fontsize=8)
    panel.axis('off')

plt.tight_layout()
plt.show()

print("\n‚úÖ Dataset loaded and verified!")

## Step 9: Start Training

**⚠️ IMPORTANT:**
- Training will take 2-4 hours on T4 GPU
- Checkpoints are automatically saved to Google Drive
- You can monitor progress in real-time below
- If session disconnects, your checkpoint is safe in Drive

In [None]:
# Start training (with auto-resume from checkpoint if available)
print("üèãÔ∏è Starting training...\n")
print("This will:")
print("  - Auto-resume from latest checkpoint if available")
print("  - Train for up to 50 epochs (with early stopping)")
print("  - Save checkpoints to Google Drive every 5 epochs")
print("  - Log metrics for TensorBoard")
print("  - Show progress bars for each epoch")
print("\n" + "="*60 + "\n")

# Always use --resume flag to automatically resume if checkpoints exist
!python scripts/train.py --resume

print("\n" + "="*60)
print("‚úÖ Training complete!")

## Step 10: Monitor Training with TensorBoard (Optional)

In [None]:
# Load TensorBoard to visualize training metrics.
# `%load_ext` registers the notebook extension; `%tensorboard` then opens the
# dashboard inline. `logs` is a relative path, so it is resolved against the
# notebook's current working directory.
%load_ext tensorboard
%tensorboard --logdir logs

## Step 11: Check Training Results

In [None]:
import glob
import torch

# Enumerate every .pth file in the Drive checkpoint folder with its size.
checkpoints = glob.glob('/content/drive/MyDrive/VisiHealth_Checkpoints/*.pth')
print(f"üìä Found {len(checkpoints)} checkpoint(s):")
for ckpt_path in sorted(checkpoints):
    size_mb = os.path.getsize(ckpt_path) / (1024*1024)
    ckpt_name = os.path.basename(ckpt_path)
    print(f"  - {ckpt_name} ({size_mb:.1f} MB)")

# Summarise the best checkpoint, if there is one.
best_ckpt = '/content/drive/MyDrive/VisiHealth_Checkpoints/best_checkpoint.pth'
if not os.path.exists(best_ckpt):
    print("\n‚ö†Ô∏è No best checkpoint found yet.")
    print("Training may still be in progress or hasn't completed successfully.")
else:
    print("\n‚úÖ Best checkpoint found!")

    # Load on CPU so this inspection does not need GPU memory.
    checkpoint = torch.load(best_ckpt, map_location='cpu')
    best_epoch = checkpoint.get('epoch', 'N/A')
    best_acc = checkpoint.get('best_val_acc', 0)
    best_loss = checkpoint.get('best_val_loss', 0)

    print("\nBest model metrics:")
    print(f"  Epoch: {best_epoch}")
    print(f"  Best Val Accuracy: {best_acc:.2f}%")
    print(f"  Best Val Loss: {best_loss:.4f}")

---

# 🔮 PART 3: INFERENCE

## Step 12: Load Trained Model

In [None]:
import yaml
import torch
from models import get_cnn_model, get_bert_model, build_visihealth_model
from data import get_dataloader

# Read the same configuration file the training script uses.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Test split, one sample per batch.
print("üìä Loading test dataset...")
bert_name = config['model']['bert']['model_name']
test_loader, test_dataset = get_dataloader(
    data_dir='data/SLAKE',
    split='test',
    batch_size=1,
    num_workers=2,
    tokenizer_name=bert_name,
)
print(f"‚úÖ Test dataset: {len(test_dataset)} samples")

# Assemble the network from its CNN and BERT parts.
print("\nüß† Building model...")
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

cnn = get_cnn_model(config)
bert = get_bert_model(config)
model = build_visihealth_model(config, cnn, bert)

# Restore the weights of the best checkpoint saved on Drive.
CHECKPOINT_PATH = '/content/drive/MyDrive/VisiHealth_Checkpoints/best_checkpoint.pth'
print(f"üì¶ Loading checkpoint from: {CHECKPOINT_PATH}")

checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Move to the selected device and switch to evaluation mode.
model = model.to(device)
model.eval()

trained_epochs = checkpoint.get('epoch', 'N/A')
best_val_acc = checkpoint.get('best_val_acc', 0)
print("‚úÖ Model loaded successfully!")
print(f"   Trained for {trained_epochs} epochs")
print(f"   Best validation accuracy: {best_val_acc:.2f}%")

## Step 13: Run Inference on Random Samples

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Per-channel constants used to undo the input normalisation for display
# (assumed to match the transform applied by the dataset -- TODO confirm).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Get random samples. The count is clamped to the dataset size because
# sampling without replacement raises when asked for more items than exist.
num_samples = min(6, len(test_dataset))
indices = np.random.choice(len(test_dataset), num_samples, replace=False)

correct_predictions = 0

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.ravel()

# Blank out panels that will not receive an image.
for unused_ax in axes[num_samples:]:
    unused_ax.axis('off')

for i, idx in enumerate(indices):
    sample = test_dataset[idx]

    # Prepare input: add a leading batch dimension and move to the device.
    image = sample['image'].unsqueeze(0).to(device)
    input_ids = sample['input_ids'].unsqueeze(0).to(device)
    attention_mask = sample['attention_mask'].unsqueeze(0).to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(image, input_ids, attention_mask, return_attention=True)
        answer_logits = outputs['answer_logits']
        roi_scores = outputs['roi_scores']

    # Get predictions: arg-max class and its softmax probability.
    answer_probs = F.softmax(answer_logits, dim=1)
    pred_idx = answer_logits.argmax(dim=1).item()
    confidence = answer_probs[0, pred_idx].item()

    pred_answer = test_dataset.get_answer_text(pred_idx)
    true_answer = sample['answer_text']

    # Case-insensitive comparison of predicted and reference answer text.
    is_correct = pred_answer.lower() == true_answer.lower()
    if is_correct:
        correct_predictions += 1

    # Get top ROI and its softmax share of the ROI scores.
    top_roi_idx = roi_scores.argmax(dim=1).item()
    roi_confidence = F.softmax(roi_scores, dim=1)[0, top_roi_idx].item()

    # Visualize: channel-first tensor -> channel-last array in [0, 1].
    img_display = sample['image'].cpu().numpy().transpose(1, 2, 0)
    img_display = np.clip(std * img_display + mean, 0, 1)

    axes[i].imshow(img_display)

    status = "‚úÖ CORRECT" if is_correct else "‚ùå WRONG"
    title = (
        f"Q: {sample['question_text'][:40]}...\n"
        f"Pred: {pred_answer} ({confidence:.1%}) | True: {true_answer}\n"
        f"{status} | ROI: {top_roi_idx} ({roi_confidence:.1%})"
    )
    axes[i].set_title(title, fontsize=8)
    axes[i].axis('off')

plt.tight_layout()
plt.show()

# Guard the percentage against an empty test set.
if num_samples:
    print(f"\nüìä Accuracy on these samples: {correct_predictions}/{num_samples} ({100*correct_predictions/num_samples:.1f}%)")
else:
    print("\nNo test samples available to evaluate.")

## Step 14: Generate Rationales with Knowledge Graph

In [None]:
from utils.knowledge_graph import load_knowledge_graph, RationaleGenerator

# Load KG. If the file is absent, a three-triplet sample is written first so
# the loader below always has something to read.
kg_file = 'data/SLAKE/kg.txt'
kg_found = os.path.exists(kg_file)
if not kg_found:
    print("‚ö†Ô∏è KG file not found. Creating sample...")
    os.makedirs(os.path.dirname(kg_file), exist_ok=True)
    sample_triplets = (
        "liver\tis_located_in\tabdomen\n"
        "lung\tis_located_in\tchest\n"
        "heart\tis_located_in\tchest\n"
    )
    with open(kg_file, 'w') as f:
        f.write(sample_triplets)

kg = load_knowledge_graph(kg_file)
rationale_gen = RationaleGenerator(kg)
if kg_found:
    print(f"‚úÖ Knowledge graph loaded: {len(kg.triplets)} triplets")

# Generate rationales for first 3 samples
banner = "=" * 80
print("\n" + banner)
print("GENERATING RATIONALES")
print(banner)

for idx in indices[:3]:
    sample = test_dataset[idx]

    # Prepare input: add a batch dimension and move each tensor to the device.
    image, input_ids, attention_mask = (
        sample[key].unsqueeze(0).to(device)
        for key in ('image', 'input_ids', 'attention_mask')
    )

    # Run inference
    with torch.no_grad():
        outputs = model(image, input_ids, attention_mask)
        answer_logits = outputs['answer_logits']
        roi_scores = outputs['roi_scores']

    # Get predictions: arg-max class, its softmax probability, and its text.
    pred_idx = answer_logits.argmax(dim=1).item()
    confidence = F.softmax(answer_logits, dim=1)[0, pred_idx].item()
    pred_answer = test_dataset.get_answer_text(pred_idx)

    # Three highest-scoring ROIs (raw scores and their indices).
    top_values, top_indices = torch.topk(roi_scores[0], k=3)

    # Generate rationale
    rationale = rationale_gen.generate_rationale(
        predicted_answer=pred_answer,
        confidence=confidence,
        top_roi_indices=top_indices.tolist(),
        roi_scores=top_values.tolist(),
        question=sample['question_text'],
    )

    print("\n" + banner)
    print(f"Image: {sample['img_name']}")
    print(f"Question: {sample['question_text']}")
    print(f"True Answer: {sample['answer_text']}")
    print(f"Predicted Answer: {pred_answer} (confidence: {confidence:.2%})")
    print(f"\nüìù Rationale:\n{rationale}")
    print(banner)

## Step 15: Calculate Full Test Set Accuracy

In [None]:
import json

from tqdm import tqdm

# Evaluate on entire test set
print("üìä Evaluating on full test set...\n")

model.eval()

# A second loader over the same split, batched by 16 to speed up the pass.
test_loader_batch, _ = get_dataloader(
    data_dir='data/SLAKE',
    split='test',
    batch_size=16,
    num_workers=2,
    tokenizer_name=config['model']['bert']['model_name'],
)

correct = 0
total = 0

with torch.no_grad():
    for batch in tqdm(test_loader_batch, desc="Evaluating"):
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        answers = batch['answer'].to(device)

        # Predicted class = arg-max over the answer logits.
        logits = model(images, input_ids, attention_mask)['answer_logits']
        predictions = logits.argmax(dim=1)

        hits = (predictions == answers).sum().item()
        correct += hits
        total += answers.size(0)

accuracy = 100 * correct / total

divider = "=" * 60
print(f"\n{divider}")
print("üìä TEST SET RESULTS")
print(divider)
print(f"Total Samples: {total}")
print(f"Correct Predictions: {correct}")
print(f"Accuracy: {accuracy:.2f}%")
print(divider)

# Save results to Drive so they persist after the session ends.
results = {
    'test_accuracy': accuracy,
    'correct': correct,
    'total': total,
    'checkpoint': CHECKPOINT_PATH
}

results_file = '/content/drive/MyDrive/VisiHealth_Results.json'
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n‚úÖ Results saved to: {results_file}")

## Step 16: Export Model Info for Frontend

In [None]:
# Export model info for frontend integration.
# Collects checkpoint location, accuracy, vocabulary and training facts into
# one JSON file on Drive.

# Swap keys and values of the dataset's answer vocabulary
# (presumably answer -> index becomes index -> answer; verify against dataset).
inverted_vocab = {
    value: key for key, value in test_dataset.answer_vocab.items()
}

usage_spec = {
    'input': 'Medical image (224x224) + question text',
    'output': 'Answer + confidence + ROI scores + rationale'
}

training_info = {
    'dataset': 'SLAKE 1.0',
    'total_samples': len(train_dataset),
    'epochs': checkpoint.get('epoch', 'N/A'),
    'best_val_acc': checkpoint.get('best_val_acc', 0)
}

export_info = {
    'model_type': 'VisiHealth AI',
    'checkpoint_path': CHECKPOINT_PATH,
    'test_accuracy': accuracy,
    'num_classes': test_dataset.num_classes,
    'answer_vocab': inverted_vocab,
    'image_size': 224,
    'bert_model': config['model']['bert']['model_name'],
    'usage': usage_spec,
    'training_info': training_info
}

info_file = '/content/drive/MyDrive/VisiHealth_Model_Info.json'
with open(info_file, 'w') as f:
    json.dump(export_info, f, indent=2)

print(f"‚úÖ Model info exported to: {info_file}")
print("\nThis file contains:")
for item in (
    "Model checkpoint path",
    "Test accuracy",
    "Answer vocabulary (for mapping predictions)",
    "Input/output specifications",
    "Training information",
):
    print(f"  - {item}")
print("\nüåê Ready for frontend integration!")

---

# ✅ ALL DONE!

## 🎉 Complete Workflow Finished!

### What Was Accomplished:
1. ✅ **Setup Complete** - Project and dataset copied, dependencies installed
2. ✅ **Training Complete** - Model trained on SLAKE dataset with GPU
3. ✅ **Inference Complete** - Model tested, accuracy calculated, results saved

### Files Saved to Google Drive:
- 📁 **VisiHealth_Checkpoints/** - Model checkpoints
  - `best_checkpoint.pth` - Best model (~500MB)
  - `checkpoint_epoch_XX.pth` - Regular checkpoints
- 📁 **VisiHealth_Logs/** - TensorBoard training logs
- 📄 **VisiHealth_Results.json** - Test accuracy and metrics
- 📄 **VisiHealth_Model_Info.json** - Model specifications for frontend

### Next Steps:
1. **Download checkpoint** from Google Drive to your laptop:
   - Go to: https://drive.google.com
   - Navigate to: `VisiHealth_Checkpoints/`
   - Download: `best_checkpoint.pth`

2. **Use locally** on your laptop:
   ```bash
   python scripts/demo.py --checkpoint checkpoints/best_checkpoint.pth
   ```

3. **Build frontend** (Flask/FastAPI backend + React/Vue frontend)

### Performance Summary:
- **Test Accuracy:** Shown above
- **Training Time:** ~2-4 hours
- **Model Size:** ~500 MB
- **Inference Speed:** ~200-300ms per image (GPU)

---

## 🚀 Congratulations! Your Medical VQA System is Ready!

**Questions? Check the documentation in your project folder.**