# 🫁 TB Detection with Proven AST - FIXED VERSION

**Uses the EXACT same AST code that achieved 88.98% energy savings on malaria!**

This notebook:
- ✅ Properly handles TBX11K dataset structure
- ✅ Uses proven `train_ast.py` from Malaria project
- ✅ Expected: 85-90% energy savings + 90%+ accuracy

---

**⚙️ Setup**: Runtime → Change runtime type → GPU (T4 recommended)

**⏱️ Time**: ~2-3 hours with GPU

## Step 1: Clone Malaria Project

In [None]:
!git clone https://github.com/oluwafemidiakhoa/Malaria.git
%cd Malaria
!git pull origin main

print("✅ Malaria project cloned!")

## Step 2: Setup Kaggle API

In [None]:
from google.colab import files
import os

print("📁 Upload your kaggle.json:")
uploaded = files.upload()

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

print("✅ Kaggle API configured!")
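
Before downloading anything, it can help to confirm that the uploaded file is a well-formed Kaggle token. The sketch below assumes the usual `kaggle.json` layout with `username` and `key` fields; `check_kaggle_token` is a hypothetical helper written for this notebook, not part of the Kaggle CLI.

```python
import json
from pathlib import Path


def check_kaggle_token(path):
    """Return a list of problems found in a kaggle.json file (empty if none)."""
    path = Path(path)
    if not path.is_file():
        return [f"{path} does not exist"]
    try:
        token = json.loads(path.read_text())
    except json.JSONDecodeError as exc:
        return [f"{path} is not valid JSON: {exc}"]
    if not isinstance(token, dict):
        return [f"{path} should contain a JSON object"]
    # Assumed layout: {"username": "...", "key": "..."}
    return [
        f"missing or empty field: {field}"
        for field in ("username", "key")
        if not token.get(field)
    ]
```

Calling `check_kaggle_token(Path.home() / ".kaggle" / "kaggle.json")` right after the upload cell should return an empty list when the token is usable.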

## Step 3: Install Dependencies

In [None]:
# Quote the version spec so the shell does not treat '>' as an output redirect
!pip install -q torch torchvision timm "adaptive-sparse-training>=1.0.1" \
    scikit-learn matplotlib seaborn pyyaml tqdm kaggle pillow numpy

print("✅ All dependencies installed!")

import torch
if torch.cuda.is_available():
    print(f"\n🖥️ GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("\n⚠️ No GPU detected!")

## Step 4: Download TBX11K Dataset

In [None]:
!kaggle datasets download -d usmanshams/tbx-11
!unzip -q tbx-11.zip -d tb_data

print("✅ TB dataset downloaded!")
print("\n📁 Exploring dataset structure...")
!find tb_data -type d | head -20

## Step 5: Explore TBX11K Structure (IMPORTANT!)

In [None]:
from pathlib import Path

# Explore the actual structure
tb_root = Path('tb_data')

print("📊 TBX11K Dataset Structure:")
print("="*60)

# Find all directories with images
for item in sorted(tb_root.rglob('*')):
    if item.is_dir():
        png_count = len(list(item.glob('*.png')))
        jpg_count = len(list(item.glob('*.jpg')))
        total = png_count + jpg_count
        if total > 0:
            rel_path = item.relative_to(tb_root)
            print(f"  {rel_path}: {total} images")

# Count all images
all_images = list(tb_root.rglob('*.png')) + list(tb_root.rglob('*.jpg'))
print(f"\n📈 Total images found: {len(all_images)}")

# Sample some paths to understand labeling
print(f"\n📝 Sample image paths:")
for img in all_images[:10]:
    print(f"  {img.relative_to(tb_root)}")
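
The `glob('*.png')` and `glob('*.jpg')` patterns above are case-sensitive on Linux and skip `.jpeg` files. As a rough sketch, the same per-directory count can be done with a case-insensitive suffix check. `count_images_by_dir` and `IMAGE_EXTS` are hypothetical names, and the extension set is an assumption to adjust for the dataset at hand.

```python
from collections import Counter
from pathlib import Path

# Assumption: these are the only image extensions in the download
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def count_images_by_dir(root):
    """Map each directory (relative to root) to the number of images directly inside it."""
    root = Path(root)
    counts = Counter()
    for path in root.rglob("*"):
        if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
            counts[path.parent.relative_to(root).as_posix()] += 1
    return counts
```

For example, `for folder, n in sorted(count_images_by_dir('tb_data').items()): print(folder, n)` prints one line per image-bearing folder.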

## Step 6: Smart Data Organization (Handles TBX11K structure)

In [None]:
from pathlib import Path
import shutil
from sklearn.model_selection import train_test_split
from collections import Counter
import random

random.seed(42)

tb_root = Path('tb_data')
all_images = list(tb_root.rglob('*.png')) + list(tb_root.rglob('*.jpg'))

print(f"🔍 Analyzing {len(all_images)} images...\n")

# Smart classification based on TBX11K structure
data = []
for img_path in all_images:
    path_lower = str(img_path).lower()
    parts = [p.lower() for p in img_path.parts]
    
    # TBX11K structure detection
    # Common patterns: 'Tuberculosis', 'Normal', 'Sick', 'Healthy'
    is_tb = False
    is_normal = False
    
    # Check each path component
    for part in parts:
        if 'tuberculosis' in part or 'tb' == part or 'sick' in part or 'abnormal' in part:
            is_tb = True
        # 'abnormal' contains 'normal', so exclude it explicitly
        if ('normal' in part and 'abnormal' not in part) or 'healthy' in part:
            is_normal = True
    
    # Assign label
    if is_tb and not is_normal:
        label = 'TB'
    elif is_normal and not is_tb:
        label = 'Normal'
    else:
        # Ambiguous - try filename
        fname_lower = img_path.name.lower()
        if 'normal' in fname_lower or 'healthy' in fname_lower:
            label = 'Normal'
        elif 'tb' in fname_lower or 'sick' in fname_lower:
            label = 'TB'
        else:
            # Skip if we can't determine
            continue
    
    data.append((img_path, label))

# Check distribution
label_counts = Counter([d[1] for d in data])
print("📊 Label distribution:")
for label, count in label_counts.items():
    print(f"  {label}: {count:,} ({count/len(data)*100:.1f}%)")

# Verify we have both classes
if len(label_counts) < 2:
    print("\n❌ ERROR: Only one class found!")
    print("\nTBX11K appears to be single-class. Let's check the structure:")
    print("\n📂 Directory tree:")
    !ls -R tb_data/ | head -100
    
    print("\n⚠️ SOLUTION: We need to use a different dataset with both Normal and TB classes.")
    print("Try: 'tawsifurrahman/tuberculosis-tb-chest-xray-dataset' instead.")
else:
    # We have both classes - proceed with split
    print(f"\n✅ Found {len(data)} usable images with both classes!")
    
    # Split into train/val (80/20)
    train_data, val_data = train_test_split(
        data, test_size=0.2, random_state=42, 
        stratify=[d[1] for d in data]
    )
    
    # Create directory structure
    print("\n📁 Creating data directories...")
    for split, split_data in [('train', train_data), ('val', val_data)]:
        for label in ['Normal', 'TB']:
            dest = Path(f'data/{split}/{label}')
            dest.mkdir(parents=True, exist_ok=True)
        
        for img_path, label in split_data:
            dest_path = Path(f'data/{split}/{label}/{img_path.name}')
            if not dest_path.exists():  # Avoid duplicates
                shutil.copy(img_path, dest_path)
    
    print("\n✅ Data organized:")
    print(f"\n   Train: {len(train_data):,} images")
    for label in ['Normal', 'TB']:
        count = len(list(Path(f'data/train/{label}').glob('*')))
        print(f"      {label}: {count:,}")
    
    print(f"\n   Val: {len(val_data):,} images")
    for label in ['Normal', 'TB']:
        count = len(list(Path(f'data/val/{label}').glob('*')))
        print(f"      {label}: {count:,}")
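
The directory-name heuristic in this step is easier to reason about as a pure function, because edge cases such as `abnormal` (which contains the substring `normal`) can then be checked directly. The sketch below is one possible version: `label_from_path` and the keyword tuples are hypothetical, they mirror the patterns listed in the cell above, and they are not an official TBX11K labeling scheme.

```python
from pathlib import PurePath

# Assumed keywords, taken from the patterns named in the cell above
TB_KEYWORDS = ("tuberculosis", "sick", "abnormal")
NORMAL_KEYWORDS = ("normal", "healthy")


def label_from_path(path):
    """Return 'TB', 'Normal', or None when the directory names are ambiguous."""
    # Only directory components are inspected, not the filename itself
    dirs = [part.lower() for part in PurePath(path).parent.parts]
    is_tb = any(
        part == "tb" or any(key in part for key in TB_KEYWORDS) for part in dirs
    )
    is_normal = any(
        "abnormal" not in part and any(key in part for key in NORMAL_KEYWORDS)
        for part in dirs
    )
    if is_tb and not is_normal:
        return "TB"
    if is_normal and not is_tb:
        return "Normal"
    return None
```

Paths that match both classes or neither return `None`, which corresponds to the images the cell above either resolves by filename or skips.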

## Step 6B: Alternative - Use Better TB Dataset

**If TBX11K doesn't have both classes, use this instead:**

In [None]:
# ALTERNATIVE DATASET with confirmed Normal + TB classes
# Only run this if Step 6 failed

# Clean up previous download
!rm -rf tb_data
!rm -f *.zip

# Download alternative TB dataset
!kaggle datasets download -d tawsifurrahman/tuberculosis-tb-chest-xray-dataset
!unzip -q tuberculosis-tb-chest-xray-dataset.zip -d tb_data

print("✅ Alternative TB dataset downloaded!")
print("\n📁 Dataset structure:")
!ls -la tb_data/

# This dataset has clear Normal/ and Tuberculosis/ folders
from pathlib import Path
import shutil
from sklearn.model_selection import train_test_split
import random

random.seed(42)

# Find images in Normal and TB folders
tb_root = Path('tb_data')
data = []

# Look for Normal images
for normal_dir in tb_root.rglob('Normal'):
    if normal_dir.is_dir():
        for img in normal_dir.glob('*.png'):
            data.append((img, 'Normal'))
        for img in normal_dir.glob('*.jpg'):
            data.append((img, 'Normal'))

# Look for TB images
for tb_dir in tb_root.rglob('Tuberculosis'):
    if tb_dir.is_dir():
        for img in tb_dir.glob('*.png'):
            data.append((img, 'TB'))
        for img in tb_dir.glob('*.jpg'):
            data.append((img, 'TB'))

from collections import Counter
label_counts = Counter([d[1] for d in data])
print(f"\n📊 Label distribution:")
for label, count in label_counts.items():
    print(f"  {label}: {count:,}")

if len(label_counts) == 2:
    # Split and organize
    train_data, val_data = train_test_split(
        data, test_size=0.2, random_state=42, 
        stratify=[d[1] for d in data]
    )
    
    for split, split_data in [('train', train_data), ('val', val_data)]:
        for label in ['Normal', 'TB']:
            dest = Path(f'data/{split}/{label}')
            dest.mkdir(parents=True, exist_ok=True)
        
        for img_path, label in split_data:
            dest_path = Path(f'data/{split}/{label}/{img_path.name}')
            shutil.copy(img_path, dest_path)
    
    print("\n✅ Data organized successfully!")
    print(f"   Train: {len(train_data):,} | Val: {len(val_data):,}")
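
Both organization cells name each copy after `img_path.name`, so two source images with the same filename in different folders map to the same destination and only one of them ends up in `data/`. A minimal sketch of one way to keep both is shown below. `unique_dest` is a hypothetical helper, and the `_1`, `_2` suffix scheme is just one possible convention.

```python
from pathlib import Path


def unique_dest(dest_dir, src, taken):
    """Pick a destination path for src inside dest_dir that is not already in `taken`."""
    dest_dir, src = Path(dest_dir), Path(src)
    candidate = dest_dir / src.name
    counter = 1
    # Append _1, _2, ... to the stem until the name is free
    while candidate in taken:
        candidate = dest_dir / f"{src.stem}_{counter}{src.suffix}"
        counter += 1
    taken.add(candidate)
    return candidate
```

With a single `taken = set()` created before the copy loop, the copy line would become `shutil.copy(img_path, unique_dest(f'data/{split}/{label}', img_path, taken))`.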

## Step 7: Create TB Config

In [None]:
import yaml
from pathlib import Path

config = {
    "model_name": "efficientnet_b0",
    "num_classes": 2,
    "image_size": 224,
    "epochs": 50,
    "batch_size": 32,
    "learning_rate": 0.0003,
    "weight_decay": 0.0001,
    "num_workers": 2,
    "amp": True,
    "train_dir": "data/train",
    "val_dir": "data/val",
    "save_dir": "checkpoints_tb_ast",
    "resume": True,
    "patience": 15,
    # AST settings - EXACT same as malaria
    "ast_target_activation_rate": 0.40,
    "ast_initial_threshold": 3.0,
    "ast_adapt_kp": 0.005,
    "ast_adapt_ki": 0.0001,
    "ast_ema_alpha": 0.1,
    "ast_warmup_epochs": 2,
}

config_path = Path("configs/config_tb_ast.yaml")
config_path.parent.mkdir(exist_ok=True)

with open(config_path, "w") as f:
    yaml.dump(config, f, default_flow_style=False)

print(f"✅ Config created!")
print(f"\n⚙️ AST Settings:")
print(f"  Activation rate: {config['ast_target_activation_rate']*100:.0f}%")
print(f"  Expected savings: ~{(1-config['ast_target_activation_rate'])*100:.0f}%")
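
The `ast_*` keys read like the parameters of a PI (proportional-integral) controller that nudges a sample-selection threshold toward the target activation rate. The sketch below only illustrates that idea. It is an assumption about how such parameters are commonly used, not the update rule in `train_ast.py`, and `ThresholdController` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ThresholdController:
    """Illustrative PI controller; defaults copy the config values above."""
    target_rate: float = 0.40   # ast_target_activation_rate
    threshold: float = 3.0      # ast_initial_threshold
    kp: float = 0.005           # ast_adapt_kp
    ki: float = 0.0001          # ast_adapt_ki
    ema_alpha: float = 0.1      # ast_ema_alpha
    ema_rate: Optional[float] = None
    integral: float = 0.0

    def update(self, batch_rate):
        """Feed the fraction of samples activated in one batch; return the new threshold."""
        if self.ema_rate is None:
            self.ema_rate = batch_rate  # first batch seeds the moving average
        else:
            self.ema_rate = (1 - self.ema_alpha) * self.ema_rate + self.ema_alpha * batch_rate
        # Positive error means too many samples are active, so the threshold rises
        error = self.ema_rate - self.target_rate
        self.integral += error
        self.threshold += self.kp * error + self.ki * self.integral
        return self.threshold
```

Under this reading, a 40% activation rate means roughly 60% of samples are skipped per step, which is where the "Expected savings" figure printed above comes from.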

## Step 8: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!mkdir -p '/content/drive/MyDrive/TB_AST_Results'
print("✅ Drive mounted!")

## Step 9: Train TB Model with AST

**This will now work with proper data organization!**

In [None]:
# Verify data exists before training
!ls -la data/train/
!ls -la data/val/

# Train with proven AST code
!python train_ast.py --config configs/config_tb_ast.yaml
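
The `ls` calls above show that the folders exist but not whether each class actually contains images. A small pre-flight check, sketched below, can catch an empty class before a multi-hour run starts. `preflight` is a hypothetical helper, and the `data/<split>/<class>` layout it assumes is the one created in Step 6.

```python
from pathlib import Path

# Assumption: these are the only image extensions in the dataset
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def preflight(data_root, splits=("train", "val"), classes=("Normal", "TB")):
    """Return one message per split/class folder that is missing or holds no images."""
    problems = []
    for split in splits:
        for cls in classes:
            cls_dir = Path(data_root) / split / cls
            n_images = 0
            if cls_dir.is_dir():
                n_images = sum(
                    1
                    for p in cls_dir.iterdir()
                    if p.is_file() and p.suffix.lower() in IMAGE_EXTS
                )
            if n_images == 0:
                problems.append(f"{split}/{cls} has no images")
    return problems
```

Running `preflight('data')` before the training command and stopping when the list is non-empty avoids launching `train_ast.py` on an incomplete dataset.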

## Step 10: View Results

In [None]:
import json
import pandas as pd

metrics = []
with open('checkpoints_tb_ast/metrics_ast.jsonl', 'r') as f:
    for line in f:
        metrics.append(json.loads(line))

df = pd.DataFrame(metrics)

print("="*80)
print("🎉 TB DETECTION TRAINING COMPLETE")
print("="*80)

best_acc = df['val_acc'].max() * 100
best_epoch = df.loc[df['val_acc'].idxmax(), 'epoch']
print(f"\n🎯 Best Accuracy: {best_acc:.2f}% (Epoch {best_epoch})")

warmup_epochs = 2  # matches ast_warmup_epochs in the config
non_warmup = df[df['epoch'] > warmup_epochs]
avg_savings = float('nan')  # stays NaN if no post-warmup epochs were logged
if len(non_warmup) > 0:
    avg_savings = non_warmup['energy_savings'].mean()
    avg_activation = non_warmup['activation_rate'].mean()
    print(f"\n⚡ Energy Efficiency:")
    print(f"   Average Energy Savings: {avg_savings:.1f}%")
    print(f"   Average Activation Rate: {avg_activation*100:.1f}%")

print("\n📊 Last 10 Epochs:")
print(df[['epoch', 'val_acc', 'activation_rate', 'energy_savings']].tail(10))

print("\n" + "="*80)
print(f"🎤 Results: '{best_acc:.1f}% TB detection with {avg_savings:.0f}% energy savings'")
print("="*80)
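
If training is interrupted, the last line of `metrics_ast.jsonl` may be only partially written, and `json.loads` then raises on it. The sketch below reads the log with the standard library only and skips lines it cannot parse. `load_jsonl` and `summarize` are hypothetical helpers, and the field names (`epoch`, `val_acc`, `energy_savings`) are the ones used in the cell above.

```python
import json


def load_jsonl(lines):
    """Parse JSON Lines input, skipping blank or malformed lines."""
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError:
            continue  # e.g. a partially written final line
    return records


def summarize(records, warmup_epochs=2):
    """Return best epoch, best val_acc, and mean post-warmup energy savings."""
    if not records:
        return None
    best = max(records, key=lambda r: r["val_acc"])
    post = [r["energy_savings"] for r in records if r["epoch"] > warmup_epochs]
    return {
        "best_epoch": best["epoch"],
        "best_val_acc": best["val_acc"],
        "avg_energy_savings": sum(post) / len(post) if post else None,
    }
```

Usage would look like `with open('checkpoints_tb_ast/metrics_ast.jsonl') as f: print(summarize(load_jsonl(f)))`.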

## Step 11: Save to Drive

In [None]:
!cp -r checkpoints_tb_ast /content/drive/MyDrive/TB_AST_Results/
!cp configs/config_tb_ast.yaml /content/drive/MyDrive/TB_AST_Results/

print("✅ Results saved to Google Drive!")