# 🍅 DenseNet169 Tomato Disease Training - Google Colab

This notebook trains a DenseNet169 model for tomato disease detection with GPU acceleration.

**Expected Results:**
- 🎯 **Accuracy**: 99.72% (with the complete dataset laid out as shown in Step 2)
- ⏱️ **Training Time**: 30-90 minutes on a GPU (vs. 6-12 hours on CPU)
- 📊 **Dataset**: 32,022 tomato images, 10 disease classes

## 📋 Step 1: Setup Environment

In [None]:
# Check GPU availability
import torch
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🚀 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"📊 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ Running on CPU - training will be slower")

In [None]:
# Install required packages (if needed)
!pip install torch torchvision matplotlib seaborn pandas scikit-learn pillow

## 📁 Step 2: Upload Files

Upload these files to Colab:
1. `densenet_trainer.py`
2. `densenet_tomato_model.py` 
3. Your dataset (zipped): `tomato_disease.zip`

**Dataset Structure Expected:**
```
dataset/tomato_disease/
├── train/
│   ├── Tomato___Bacterial_spot/
│   ├── Tomato___Early_blight/
│   ├── Tomato___Late_blight/
│   ├── Tomato___Leaf_Mold/
│   ├── Tomato___Septoria_leaf_spot/
│   ├── Tomato___Spider_mites_Two_spotted_spider_mite/
│   ├── Tomato___Target_Spot/
│   ├── Tomato___Tomato_YellowLeaf__Curl_Virus/
│   ├── Tomato___Tomato_mosaic_virus/
│   └── Tomato___healthy/
└── val/
    └── (same structure as train)
```
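Training assumes that `train/` and `val/` contain the same class folders. As a minimal sketch (the helper name `compare_splits` is hypothetical and not part of the uploaded scripts), the following check reports class folders that exist in one split but not the other:

```python
import os

def compare_splits(root):
    """Return (missing_in_val, missing_in_train) class-folder names under root."""
    def class_dirs(split):
        path = os.path.join(root, split)
        if not os.path.isdir(path):
            return set()
        # Only directories count as classes; stray files are ignored
        return {d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))}

    train, val = class_dirs("train"), class_dirs("val")
    return sorted(train - val), sorted(val - train)

if __name__ == "__main__":
    missing_in_val, missing_in_train = compare_splits("dataset/tomato_disease")
    print("Missing in val:  ", missing_in_val)
    print("Missing in train:", missing_in_train)
```

Two empty lists mean the splits agree. Note that a missing `dataset/tomato_disease` folder also yields two empty lists, so run this after the dataset is extracted.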

## 📊 Optional: Add Visualization Files

For model interpretability and analysis, also upload:
4. `densenet_saliency.py` - Generate saliency maps
5. `densenet_occlusion.py` - Occlusion analysis 
6. `densenet_plot.py` - Training plots
7. `torchvis_util.py` - Visualization utilities

**Note**: These are optional but provide valuable insights into model decisions.

In [None]:
# Check if visualization files are available
import os  # needed here: this cell runs before the dataset cell that also imports os

viz_files = [
    'densenet_saliency.py',
    'densenet_occlusion.py', 
    'densenet_plot.py',
    'torchvis_util.py'
]

print("🔍 Checking for visualization files:")
available_viz = []
for file in viz_files:
    if os.path.exists(file):
        available_viz.append(file)
        print(f"   ✅ {file}")
    else:
        print(f"   ❌ {file} (optional)")

if available_viz:
    print(f"\n🎉 Found {len(available_viz)} visualization files!")
    print("   You can generate saliency maps and occlusion analysis after training.")
else:
    print("\n💡 No visualization files found. You can still train the model!")
    print("   Upload the visualization files to enable advanced analysis.")

In [None]:
# Extract dataset if uploaded as zip
import zipfile
import os

# Uncomment and modify if you uploaded a zip file
# with zipfile.ZipFile('tomato_disease.zip', 'r') as zip_ref:
#     zip_ref.extractall('dataset/')

# Check dataset structure
if os.path.exists('dataset/tomato_disease'):
    train_dir = 'dataset/tomato_disease/train'
    val_dir = 'dataset/tomato_disease/val'
    
    if os.path.exists(train_dir):
        # Keep only class directories; stray files such as .DS_Store would break the count
        train_classes = [d for d in os.listdir(train_dir)
                         if os.path.isdir(os.path.join(train_dir, d))]
        print(f"✅ Found {len(train_classes)} training classes:")
        for cls in sorted(train_classes):
            count = len(os.listdir(os.path.join(train_dir, cls)))
            print(f"   📁 {cls}: {count} images")
    
    if os.path.exists(val_dir):
        val_classes = [d for d in os.listdir(val_dir)
                       if os.path.isdir(os.path.join(val_dir, d))]
        print(f"\n✅ Found {len(val_classes)} validation classes")
        total_val = sum(len(os.listdir(os.path.join(val_dir, cls))) for cls in val_classes)
        print(f"   📊 Total validation images: {total_val}")
else:
    print("❌ Dataset not found. Please upload your dataset.")

## 🏋️ Step 3: Start Training

In [None]:
# Import training modules
import sys
sys.path.append('.')

from densenet_trainer import DenseNetTrainer
from densenet_tomato_model import DenseNetTomatoClassifier

print("🍅 DenseNet169 training modules imported")
print("=" * 60)

In [None]:
# Initialize trainer with GPU optimization
trainer = DenseNetTrainer(
    data_dir='dataset/tomato_disease',
    save_dir='trained_models'
)

# Start training
print("🚀 Starting DenseNet169 training with GPU acceleration...")
model_path = trainer.run_training()

print(f"\n🎉 Training completed!")
print(f"📁 Model saved to: {model_path}")
print("\n💾 Download your trained model from the 'trained_models/' folder")

## 📊 Step 4: Test the Trained Model

In [None]:
# Test the trained model
import torch
from PIL import Image
import matplotlib.pyplot as plt

# Load the trained model
classifier = DenseNetTomatoClassifier()
if os.path.exists(model_path):
    result = classifier.load_model(model_path)
    print(f"✅ Model loaded: {result}")
    
    # Test with a sample image (modify path as needed)
    # sample_image_path = "path/to/your/test/image.jpg"
    # if os.path.exists(sample_image_path):
    #     result = classifier.predict_image(sample_image_path)
    #     print(f"🔍 Prediction: {result}")
else:
    print("❌ Model file not found")

## 📥 Step 5: Download Results

After training completes:

1. **Download the trained model**:
   - File: `trained_models/densenet169_tomato.pth`
   - Size: ~100-200 MB

2. **Download training plots** (if generated):
   - Training/validation curves
   - Accuracy metrics

3. **Transfer to your local project**:
   - Place `.pth` file in your local `trained_models/` folder
   - Update your FastAPI backend to use the new model
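A checkpoint of this size can be truncated during download or transfer. One optional safeguard, sketched here with a hypothetical helper `sha256_of`, is to compute a checksum in Colab and compare it with the checksum of the copy on your local machine:

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so a large checkpoint never sits fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

if __name__ == "__main__":
    import os
    # Path taken from the list above; adjust it if your trainer saves elsewhere
    checkpoint = "trained_models/densenet169_tomato.pth"
    if os.path.exists(checkpoint):
        print(sha256_of(checkpoint))
```

Run the same function on both machines. Matching hex strings indicate the file arrived intact.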

In [None]:
# Download files helper
from google.colab import files

# Download the trained model
if os.path.exists(model_path):
    print("📥 Downloading trained model...")
    files.download(model_path)
    print("✅ Download started!")
else:
    print("❌ Model file not found for download")

# List all available files
print("\n📁 Available files in trained_models/:")
if os.path.exists('trained_models'):
    for file in os.listdir('trained_models'):
        file_path = os.path.join('trained_models', file)
        size_mb = os.path.getsize(file_path) / (1024*1024)
        print(f"   📄 {file} ({size_mb:.1f} MB)")

## 🔬 Step 6: Model Visualization & Analysis

Generate saliency maps and occlusion analysis to understand what the model learned.
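For intuition, occlusion analysis hides one region of the image at a time and records how much the model's score for the correct class drops. The sketch below illustrates the idea in plain Python. It is not the implementation in `densenet_occlusion.py`, whose contents are not shown here. The names `occlusion_map` and `score_fn` are hypothetical, and the image is a 2-D list standing in for a real image tensor.

```python
def occlusion_map(image, score_fn, patch=2, fill=0.0):
    """Return a grid of score drops, one entry per occluded patch.

    image    -- 2-D list of floats (H x W), a stand-in for a real image tensor
    score_fn -- hypothetical callable returning the model's score for the true class
    """
    height, width = len(image), len(image[0])
    baseline = score_fn(image)
    drops = []
    for top in range(0, height, patch):
        row = []
        for left in range(0, width, patch):
            occluded = [r[:] for r in image]  # copy so the original stays untouched
            for y in range(top, min(top + patch, height)):
                for x in range(left, min(left + patch, width)):
                    occluded[y][x] = fill
            # A large drop means the hidden region mattered for the prediction
            row.append(baseline - score_fn(occluded))
        drops.append(row)
    return drops
```

With a real model, `score_fn` would run a forward pass and return the softmax probability of the labelled class. Patches with the largest drops should ideally coincide with the lesions on the leaf rather than the background.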

In [None]:
# Saliency Map Generation
if 'densenet_saliency.py' in available_viz and 'torchvis_util.py' in available_viz:
    print("🔍 Generating Saliency Maps...")
    
    # Create visualization directory structure
    os.makedirs('visualization', exist_ok=True)
    if os.path.exists('torchvis_util.py'):
        import shutil
        shutil.copy('torchvis_util.py', 'visualization/')
        # Create __init__.py for the visualization package
        with open('visualization/__init__.py', 'w') as f:
            f.write('# Visualization utilities\n')
    
    # Load saliency utilities into the notebook namespace
    with open('densenet_saliency.py') as f:
        exec(f.read())
    
    # Generate saliency maps for sample images
    sample_classes = ['Tomato___healthy', 'Tomato___Late_blight', 'Tomato___Early_blight']
    
    for class_name in sample_classes:
        class_path = f'dataset/tomato_disease/val/{class_name}'
        if os.path.exists(class_path):
            images = os.listdir(class_path)[:2]  # Take first 2 images
            for img_name in images:
                img_path = os.path.join(class_path, img_name)
                try:
                    print(f"   📊 Generating saliency for {class_name}/{img_name}")
                    
                    # Create saliency map (you'll need to call the actual function)
                    # result = generate_saliency_map(model_path, img_path, 'saliency_output/')
                    
                except Exception as e:
                    print(f"   ❌ Error: {e}")
    
    print("✅ Saliency loop finished. Maps are written to 'saliency_output/' once the call above is uncommented.")
else:
    print("❌ Saliency files not found. Upload densenet_saliency.py and torchvis_util.py for visualization.")

In [None]:
# Occlusion Analysis
if 'densenet_occlusion.py' in available_viz:
    print("🔍 Running Occlusion Analysis...")
    
    # Load occlusion utilities into the notebook namespace
    with open('densenet_occlusion.py') as f:
        exec(f.read())
    
    # Run occlusion analysis on sample images
    sample_classes = ['Tomato___healthy', 'Tomato___Late_blight']
    
    for class_name in sample_classes:
        class_path = f'dataset/tomato_disease/val/{class_name}'
        if os.path.exists(class_path):
            images = os.listdir(class_path)[:1]  # Take first image
            for img_name in images:
                img_path = os.path.join(class_path, img_name)
                try:
                    print(f"   🔍 Occlusion analysis for {class_name}/{img_name}")
                    
                    # Run occlusion experiment (you'll need to call the actual function)
                    # result = run_occlusion_experiment(model_path, img_path, 'occlusion_output/')
                    
                except Exception as e:
                    print(f"   ❌ Error: {e}")
    
    print("✅ Occlusion loop finished. Results are written to 'occlusion_output/' once the call above is uncommented.")
else:
    print("❌ Occlusion file not found. Upload densenet_occlusion.py for occlusion analysis.")

In [None]:
# Display Sample Visualizations
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def display_visualizations():
    """Display generated saliency maps and occlusion results"""
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    # No emoji in the title: Matplotlib's default font lacks the glyph and emits a warning
    fig.suptitle('DenseNet169 Model Visualizations', fontsize=16, fontweight='bold')
    
    # Try to display saliency maps
    saliency_dir = 'saliency_output'
    occlusion_dir = 'occlusion_output'
    
    row_titles = ['Saliency Maps', 'Occlusion Analysis']
    col_titles = ['Healthy', 'Late Blight', 'Early Blight']
    
    for i, (row_title, directory) in enumerate([(row_titles[0], saliency_dir), 
                                                (row_titles[1], occlusion_dir)]):
        for j, col_title in enumerate(col_titles):
            ax = axes[i, j]
            
            # Look for visualization files
            viz_found = False
            if os.path.exists(directory):
                for file in os.listdir(directory):
                    if col_title.lower().replace(' ', '_') in file.lower() and file.endswith(('.png', '.jpg')):
                        try:
                            img = mpimg.imread(os.path.join(directory, file))
                            ax.imshow(img)
                            ax.set_title(f"{row_title}: {col_title}")
                            ax.axis('off')
                            viz_found = True
                            break
                        except Exception as e:
                            # Report unreadable images instead of skipping them silently
                            print(f"   ⚠️ Could not read {file}: {e}")
            
            if not viz_found:
                ax.text(0.5, 0.5, f"No {row_title.lower()}\nfor {col_title}", 
                       ha='center', va='center', transform=ax.transAxes,
                       fontsize=10, style='italic')
                ax.set_title(f"{row_title}: {col_title}")
                ax.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Summary
    print("📊 Visualization Summary:")
    if os.path.exists(saliency_dir):
        saliency_count = len([f for f in os.listdir(saliency_dir) if f.endswith(('.png', '.jpg'))])
        print(f"   🔍 Saliency maps: {saliency_count} generated")
    
    if os.path.exists(occlusion_dir):
        occlusion_count = len([f for f in os.listdir(occlusion_dir) if f.endswith(('.png', '.jpg'))])
        print(f"   🔍 Occlusion maps: {occlusion_count} generated")

# Display visualizations if available
if any(os.path.exists(d) for d in ['saliency_output', 'occlusion_output']):
    display_visualizations()
else:
    print("📊 No visualizations generated yet. Run the visualization cells above first.")

In [None]:
# Training Progress Visualization
if 'densenet_plot.py' in available_viz:
    print("📈 Generating Training Plots...")
    
    # Load plotting utilities into the notebook namespace
    with open('densenet_plot.py') as f:
        exec(f.read())
    
    # Look for training history or logs
    if os.path.exists('trained_models'):
        for file in os.listdir('trained_models'):
            if 'history' in file.lower() or 'log' in file.lower():
                print(f"   📊 Found training history: {file}")
                # You can add code here to plot training curves
    
    print("✅ History scan finished (no curves are drawn until plotting code is added above)")
else:
    print("❌ Plot file not found. Upload densenet_plot.py for training visualizations.")
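
If the trainer does save a history file, its per-epoch values can be summarized before plotting. The sketch below assumes a JSON file holding one list per metric, for example `{"val_acc": [0.91, 0.97, 0.99]}`. That format, the key `val_acc`, and the helper name `best_epoch` are assumptions for illustration; check what `densenet_trainer.py` actually writes.

```python
import json

def best_epoch(history_path, metric="val_acc"):
    """Return (1-based epoch, value) for the highest entry in the metric's list."""
    with open(history_path) as f:
        history = json.load(f)
    values = history[metric]  # assumed layout: one number per epoch
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]
```

The same lists could then be passed to `plt.plot` to draw training and validation curves.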

## 📥 Step 7: Download All Results

Download trained model and visualizations for local use.

In [None]:
# Download All Generated Files
import zipfile
from google.colab import files

def create_download_package():
    """Create a zip package with all generated files"""
    
    package_name = "densenet_training_results.zip"
    
    with zipfile.ZipFile(package_name, 'w') as zipf:
        # Add trained model
        if os.path.exists('trained_models'):
            for file in os.listdir('trained_models'):
                file_path = os.path.join('trained_models', file)
                zipf.write(file_path, f"trained_models/{file}")
                print(f"   📄 Added: {file}")
        
        # Add visualizations
        viz_dirs = ['saliency_output', 'occlusion_output']
        for viz_dir in viz_dirs:
            if os.path.exists(viz_dir):
                for file in os.listdir(viz_dir):
                    file_path = os.path.join(viz_dir, file)
                    zipf.write(file_path, f"{viz_dir}/{file}")
                    print(f"   🖼️ Added: {viz_dir}/{file}")
    
    return package_name

# Create and download package
print("📦 Creating download package...")
try:
    package_file = create_download_package()
    package_size = os.path.getsize(package_file) / (1024*1024)
    
    print(f"✅ Package created: {package_file} ({package_size:.1f} MB)")
    print("📥 Starting download...")
    
    files.download(package_file)
    print("🎉 Download started!")
    
except Exception as e:
    print(f"❌ Error creating package: {e}")
    print("💡 Downloading files individually...")
    
    # Download individual files as fallback
    if os.path.exists(model_path):
        files.download(model_path)
    
    # Download visualizations
    for viz_dir in ['saliency_output', 'occlusion_output']:
        if os.path.exists(viz_dir):
            for file in os.listdir(viz_dir)[:5]:  # Limit to first 5 files
                try:
                    files.download(os.path.join(viz_dir, file))
                except Exception as e:
                    print(f"   ⚠️ Could not download {file}: {e}")