Hệ thống phát hiện bệnh trên lá lúa sử dụng mô hình SSD300 với backbone VGG16
Phát hiện 4 loại bệnh chính:
- Bacterial Blight (Bạc lá)
- Blast (Đạo ôn)
- Brown Spot (Đốm nâu)
- Twisted Dwarf (Xoắn lá)

Mô hình sử dụng định dạng dữ liệu COCO.

In [None]:
# Mount Google Drive at /content/drive; the dataset and output paths used by
# the later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the runtime dependencies quietly (-q): the detection stack first,
# then the plotting / metrics packages.
!pip install torch torchvision pycocotools -q
!pip install seaborn scikit-learn matplotlib -q

In [None]:
%%writefile config.py
import os
import torch

# Path settings: the dataset root on Google Drive with one sub-folder per split.
DATA_ROOT = "/content/drive/MyDrive/Coco_Dataset"
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR = os.path.join(DATA_ROOT, "valid")
TEST_DIR = os.path.join(DATA_ROOT, "test")

# Each split folder carries its own COCO annotation file.
TRAIN_ANNO = os.path.join(TRAIN_DIR, "_annotations.coco.json")
VAL_ANNO = os.path.join(VAL_DIR, "_annotations.coco.json")
TEST_ANNO = os.path.join(TEST_DIR, "_annotations.coco.json")

# Output directory for saving models and results.
# NOTE: created at import time, so importing this module touches the filesystem.
OUTPUT_DIR = "/content/drive/MyDrive/SSD_Output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model parameters
NUM_CLASSES = 5  # 4 disease classes + background class
MODEL_TYPE = "ssd300_vgg16"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training parameters (consumed by the DataLoaders and the SGD optimizer)
BATCH_SIZE = 16
NUM_WORKERS = 4
LEARNING_RATE = 0.001
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
NUM_EPOCHS = 50

# Scheduler parameters (StepLR: multiply the LR by GAMMA every STEP_SIZE epochs)
STEP_SIZE = 20
GAMMA = 0.1


# Side length, in pixels, that the data transforms resize every image to.
IMAGE_SIZE = 300


# Score / IoU cut-offs; not referenced in the visible code — presumably used
# by the evaluation or inference cells. TODO confirm.
CONFIDENCE_THRESHOLD = 0.5
IOU_THRESHOLD = 0.5


# Human-readable class names, indexed by label id.
# NOTE(review): the notebook header describes the fourth disease as
# "Xoắn lá" while this list says "tungro" — confirm against the dataset's
# category names.
CLASS_NAMES = [
    "background",  # Class 0 (background) is always included in SSD
    "bacterial_blight",
    "blast",
    "brown_spot",
    "tungro"
]

# Paths for saving results
MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "ssd_model.pth")
METRICS_SAVE_PATH = os.path.join(OUTPUT_DIR, "metrics")
CONFUSION_MATRIX_PATH = os.path.join(OUTPUT_DIR, "confusion_matrix.png")
PR_CURVE_PATH = os.path.join(OUTPUT_DIR, "pr_curve.png")
F1_CURVE_PATH = os.path.join(OUTPUT_DIR, "f1_curve.png")

# Create the metrics directory up front (also an import-time side effect).
os.makedirs(METRICS_SAVE_PATH, exist_ok=True)

def print_config():
    """Write a short, human-readable summary of the active settings to stdout."""
    separator = "=" * 50
    rows = (
        ("Dataset", DATA_ROOT),
        ("Model type", MODEL_TYPE),
        ("Number of classes", NUM_CLASSES),
        ("Device", DEVICE),
        ("Batch size", BATCH_SIZE),
        ("Learning rate", LEARNING_RATE),
        ("Number of epochs", NUM_EPOCHS),
    )

    print("\nRice Leaf Disease Detection with SSD - Configuration")
    print(separator)
    for caption, value in rows:
        print(f"{caption}: {value}")
    print(separator)

In [None]:
%%writefile setup.py
import os
import subprocess
import sys
from google.colab import drive

def mount_drive():
    """Mount Google Drive at /content/drive and print a confirmation."""
    # The dataset is stored on Drive, so it must be mounted before any check.
    mount_point = '/content/drive'
    drive.mount(mount_point)
    print("Drive mounted successfully!")

def check_gpu():
    """Report whether ``nvidia-smi`` runs successfully.

    Prints the tool's output when it succeeds.

    Returns:
        True if ``nvidia-smi`` exited with status 0, False otherwise.
    """
    try:
        raw_output = subprocess.check_output('nvidia-smi', shell=True)
    except (subprocess.CalledProcessError, OSError):
        # Only "command missing / command failed" is treated as "no GPU";
        # KeyboardInterrupt and SystemExit are allowed to propagate.
        print("No GPU found or nvidia-smi command failed.")
        return False

    print("GPU information:")
    # errors='replace' keeps an odd byte in the tool output from being
    # misreported as a missing GPU.
    print(raw_output.decode('utf-8', errors='replace'))
    return True

def check_dataset(data_dir="/content/drive/MyDrive/Coco_Dataset"):
    """Verify that the COCO dataset folder has the expected layout.

    The folder must contain ``train``, ``valid`` and ``test`` sub-folders,
    each with an ``_annotations.coco.json`` file and at least one image.

    Args:
        data_dir: Dataset root to inspect. Defaults to the Drive location
            used by the rest of the project.

    Returns:
        True when every path exists and every split has images, else False.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')
    split_dirs = {
        "Train": os.path.join(data_dir, "train"),
        "Validation": os.path.join(data_dir, "valid"),
        "Test": os.path.join(data_dir, "test"),
    }
    annotation_files = [
        os.path.join(split_dir, "_annotations.coco.json")
        for split_dir in split_dirs.values()
    ]

    try:
        for path in [*split_dirs.values(), *annotation_files]:
            if not os.path.exists(path):
                print(f"Missing path: {path}")
                return False

        # Count images per split; the extension check is case-insensitive so
        # files such as "IMG_001.JPG" are not ignored.
        image_counts = {
            split_name: sum(
                1 for file_name in os.listdir(split_dir)
                if file_name.lower().endswith(image_extensions)
            )
            for split_name, split_dir in split_dirs.items()
        }
    except OSError as e:
        print(f"Error checking dataset: {e}")
        return False

    for split_name, count in image_counts.items():
        print(f"{split_name} images: {count}")

    if not all(image_counts.values()):
        print("Warning: One or more directories have no images.")
        return False

    print("Dataset looks good!")
    return True

def install_packages():
    """Run the project's ``pip install`` commands one after another.

    Returns:
        True if every command exits successfully; False as soon as one fails.
    """
    commands = (
        "pip install torch torchvision pycocotools",
        "pip install seaborn scikit-learn matplotlib",
    )

    for command in commands:
        print(f"Running: {command}")
        try:
            subprocess.run(command, shell=True, check=True)
        except subprocess.CalledProcessError as err:
            # Stop at the first failing command; later ones are not attempted.
            print(f"Failed to run: {command}")
            print(f"Error: {err}")
            return False
        print("Installation successful")

    return True

def check_environment():
    """Check that torch, torchvision and pycocotools can be imported.

    Prints the PyTorch version and CUDA availability on success.

    Returns:
        True if all imports succeed, False if any raises ImportError.
    """
    try:
        import torch
        import torchvision
        from pycocotools.coco import COCO

        cuda_ready = torch.cuda.is_available()
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {cuda_ready}")
        if cuda_ready:
            print(f"CUDA version: {torch.version.cuda}")
    except ImportError as err:
        print(f"Environment check failed: {err}")
        print("Please run setup.install_packages() first.")
        return False

    return True

def setup_all():
    """Run the full setup sequence: Drive, GPU, dataset, packages, imports.

    GPU and dataset problems only produce warnings; a failed package install
    or import check aborts the sequence.

    Returns:
        True when package installation and the import check both succeed.
    """
    mount_drive()

    if not check_gpu():
        print("Warning: GPU issues detected. Training may be slow.")

    if not check_dataset():
        print("Warning: Dataset issues detected.")

    # The two remaining steps are hard requirements.
    if not install_packages():
        print("Error installing packages.")
        return False

    if not check_environment():
        print("Environment check failed.")
        return False

    print("Setup completed successfully!")
    return True

# Script entry point: run the whole setup sequence when executed directly.
if __name__ == "__main__":
    print("Setting up Rice Leaf Disease Detection with SSD...")
    setup_all()

In [None]:
%%writefile data.py
import torch
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pycocotools.coco import COCO
from PIL import Image
import numpy as np
import os
from config import *

class RiceLeafDataset(Dataset):
    """COCO-format object-detection dataset for rice leaf images.

    Each item is ``(image, target)`` where ``target`` is a dict with the keys
    ``boxes`` (N x 4, ``[x_min, y_min, x_max, y_max]``), ``labels``,
    ``image_id``, ``area`` and ``iscrowd``.

    When ``transform`` changes the image size (e.g. a ``Resize``), the boxes
    and areas are rescaled to the new size so they keep describing the same
    image regions. Other geometric transforms (flips, rotations) are NOT
    compensated for and must not be used in ``transform``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        """Index the annotation file and build the category-id mapping.

        Args:
            root: Directory containing the image files.
            annFile: Path to the COCO annotation JSON.
            transform: Optional callable applied to the PIL image.
            target_transform: Optional callable applied to the target dict.
        """
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

        # Load the category list (needed to remap category IDs)
        with open(annFile, 'r') as f:
            data = json.load(f)

        # Map original category IDs to consecutive IDs starting at 1
        # (0 is reserved for the background class in SSD).
        # NOTE(review): every entry of data['categories'] receives a label,
        # so an extra super-category entry in the export would produce one
        # label more than the configured number of classes — verify against
        # the annotation file.
        self.cat_mapping = {
            cat['id']: position + 1
            for position, cat in enumerate(data['categories'])
        }

        print(f"Loaded {len(self.ids)} images")
        print(f"Category mapping: {self.cat_mapping}")

    @staticmethod
    def _image_size(img):
        """Return ``(width, height)`` of a PIL image or tensor, else None."""
        if isinstance(img, torch.Tensor):
            if img.dim() < 2:
                return None
            # Tensors are laid out (..., H, W).
            return int(img.shape[-1]), int(img.shape[-2])
        size = getattr(img, 'size', None)
        if isinstance(size, tuple) and len(size) == 2:
            return int(size[0]), int(size[1])
        return None

    @staticmethod
    def _scale_boxes(boxes, old_size, new_size):
        """Rescale ``[x1, y1, x2, y2]`` boxes from ``old_size`` to ``new_size``.

        Both sizes are ``(width, height)`` tuples.
        """
        scale_x = new_size[0] / old_size[0]
        scale_y = new_size[1] / old_size[1]
        factors = torch.tensor([scale_x, scale_y, scale_x, scale_y], dtype=boxes.dtype)
        return boxes * factors

    def __getitem__(self, index):
        """Load image ``index`` and its annotations as ``(image, target)``."""
        img_id = self.ids[index]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        annotations = self.coco.loadAnns(ann_ids)

        # Load image
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')

        # Original size, needed to rescale the boxes after the transform
        width, height = img.size

        boxes = []
        labels = []
        area = []
        iscrowd = []

        for ann in annotations:
            # COCO boxes are [x_min, y_min, width, height]
            x, y, w, h = ann['bbox']

            # Skip degenerate boxes
            if w <= 0 or h <= 0:
                continue

            # Convert to [x_min, y_min, x_max, y_max]
            boxes.append([x, y, x + w, y + h])

            # Remap the category ID when it is known; unknown IDs pass through
            cat_id = ann['category_id']
            labels.append(self.cat_mapping.get(cat_id, cat_id))

            area.append(ann.get('area', w * h))
            iscrowd.append(ann.get('iscrowd', 0))

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            area = torch.as_tensor(area, dtype=torch.float32)
            iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        else:
            # Images without (valid) annotations get empty tensors
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)
            area = torch.zeros(0, dtype=torch.float32)
            iscrowd = torch.zeros(0, dtype=torch.int64)

        if self.transform is not None:
            img = self.transform(img)

            # Keep the boxes aligned with the (possibly resized) image.
            new_size = self._image_size(img)
            if new_size is not None and new_size != (width, height):
                boxes = self._scale_boxes(boxes, (width, height), new_size)
                area = area * ((new_size[0] / width) * (new_size[1] / height))

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([img_id]),
            'area': area,
            'iscrowd': iscrowd
        }

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

def get_data_transforms():
    """Build the image transforms for training and for validation/testing.

    Only photometric augmentation (colour jitter) is applied during training.
    These transforms see the image alone, so a random flip or rotation here
    would move the pixels without moving the bounding boxes and silently
    corrupt the detection labels; such augmentations need a pipeline that
    transforms image and target together.

    Returns:
        ``(train_transform, val_transform)``.
    """
    # NOTE(review): torchvision's SSD models also normalise images inside the
    # model; confirm whether the Normalize step here should stay (the plotting
    # helper in this project assumes these ImageNet statistics).
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training transforms: resize + colour augmentation
    train_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ])

    # Validation and test transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ])

    return train_transform, val_transform

def collate_fn(batch):
    """Split a batch of ``(image, target)`` pairs into two parallel lists.

    Images are kept in a list rather than stacked, and so are the targets.
    """
    pairs = [(image, target) for image, target in batch]
    images = [pair[0] for pair in pairs]
    targets = [pair[1] for pair in pairs]
    return images, targets

def get_dataloaders():
    """Create the train, validation and test DataLoaders.

    The training split uses the augmenting transform and is shuffled; the
    validation and test splits share the plain transform and keep their order.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    train_tf, eval_tf = get_data_transforms()

    train_dataset = RiceLeafDataset(root=TRAIN_DIR, annFile=TRAIN_ANNO, transform=train_tf)
    val_dataset = RiceLeafDataset(root=VAL_DIR, annFile=VAL_ANNO, transform=eval_tf)
    test_dataset = RiceLeafDataset(root=TEST_DIR, annFile=TEST_ANNO, transform=eval_tf)

    # Settings shared by all three loaders
    common = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **common)
    val_loader = DataLoader(val_dataset, shuffle=False, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    for caption, dataset in (
        ("Train", train_dataset),
        ("Validation", val_dataset),
        ("Test", test_dataset),
    ):
        print(f"{caption} dataset: {len(dataset)} images")

    return train_loader, val_loader, test_loader

def get_class_weights(train_loader, num_classes=None, class_names=None, device=None):
    """Compute inverse-frequency class weights from the training labels.

    Classes that occur in the data get a weight proportional to
    ``1 / count``, scaled so those weights average 1. Classes that never
    occur (the background class never appears as a box label) get the neutral
    weight 1.0. Dividing by a near-zero count instead would give such a class
    a huge weight and push every real class weight towards zero after
    normalisation. The weights therefore always sum to ``num_classes``.

    Args:
        train_loader: Iterable yielding ``(images, targets)`` batches, where
            each target is a dict with a ``'labels'`` entry.
        num_classes: Number of classes including background; defaults to
            ``NUM_CLASSES``.
        class_names: Names used for the printed report; defaults to
            ``CLASS_NAMES``.
        device: Device for the returned tensor; defaults to ``DEVICE``.

    Returns:
        A float tensor of shape ``(num_classes,)`` on ``device``.

    Raises:
        IndexError: If a label is outside ``[0, num_classes)``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if class_names is None:
        class_names = CLASS_NAMES
    if device is None:
        device = DEVICE

    class_counts = torch.zeros(num_classes)

    for _, targets in train_loader:
        for target in targets:
            labels = torch.as_tensor(target['labels']).detach().cpu().to(torch.int64)
            if labels.numel() == 0:
                continue
            counts = torch.bincount(labels.reshape(-1), minlength=num_classes)
            if counts.numel() > num_classes:
                raise IndexError(
                    f"label {int(labels.max())} is out of range for {num_classes} classes"
                )
            class_counts += counts.to(class_counts.dtype)

    # Inverse-frequency weighting over the classes that are actually present
    class_weights = torch.ones(num_classes)
    present = class_counts > 0
    if present.any():
        inverse = 1.0 / class_counts[present]
        # Scale so the present classes average 1.0
        class_weights[present] = inverse * (present.sum() / inverse.sum())

    print("Class weights:")
    for i, weight in enumerate(class_weights):
        if i < len(class_names):
            print(f"  {class_names[i]}: {weight.item():.4f}")
        else:
            print(f"  Class {i}: {weight.item():.4f}")

    return class_weights.to(device)

def analyze_dataset():
    """Print per-split statistics for the train, validation and test sets.

    For each split this prints the number of images, the category IDs, the
    instance count per category, and the min / max / mean bounding-box area.
    """
    datasets = [
        ('Training', TRAIN_ANNO),
        ('Validation', VAL_ANNO),
        ('Testing', TEST_ANNO)
    ]

    for name, annFile in datasets:
        print(f"\n{name} Dataset Analysis:")
        coco = COCO(annFile)

        # Number of images
        img_ids = coco.getImgIds()
        print(f"Number of images: {len(img_ids)}")

        # Number of instances per category
        cat_ids = coco.getCatIds()
        print(f"Categories: {cat_ids}")

        for cat_id in cat_ids:
            cat_name = coco.loadCats(cat_id)[0]['name']
            ann_ids = coco.getAnnIds(catIds=cat_id)
            print(f"  {cat_name} (ID: {cat_id}): {len(ann_ids)} instances")

        # Box size distribution. Fall back to width * height when an
        # annotation has no 'area' field, as the dataset class does.
        all_anns = coco.loadAnns(coco.getAnnIds())
        areas = [
            ann['area'] if 'area' in ann else ann['bbox'][2] * ann['bbox'][3]
            for ann in all_anns
        ]

        if not areas:
            print("No annotations found.")
            continue

        print("Bounding box areas:")
        print(f"  Min: {min(areas):.2f} pixels²")
        print(f"  Max: {max(areas):.2f} pixels²")
        print(f"  Avg: {sum(areas) / len(areas):.2f} pixels²")

# Script entry point: print dataset statistics, then smoke-test the loaders.
if __name__ == "__main__":
    # If run directly, analyze the dataset
    analyze_dataset()

    # Test the data loading
    print("\nTesting data loading...")
    train_loader, val_loader, test_loader = get_dataloaders()

    # Show the first training batch only, then stop
    for images, targets in train_loader:
        print(f"Batch size: {len(images)}")
        print(f"Image shape: {images[0].shape}")
        print(f"Target example: {targets[0]}")
        break

In [None]:
%%writefile model.py
import torch
import torch.nn as nn
from torchvision.models.detection import ssd300_vgg16
from torchvision.models.detection.ssd import SSDHead
import torchvision
import os
from config import *

def get_model(num_classes=NUM_CLASSES, pretrained=True):
    """Build an SSD300-VGG16 detector with a head sized for ``num_classes``.

    The torchvision model is created first and its classification head is
    then replaced by a freshly initialised one; the box-regression head and
    the backbone are kept as created.

    Args:
        num_classes: Number of output classes, including background.
        pretrained: Passed through to ``ssd300_vgg16``.

    Returns:
        The model, already moved to ``DEVICE``.
    """
    # Local import so this function does not rely on the module's import list.
    from torchvision.models.detection.ssd import SSDClassificationHead

    print(f"Creating {MODEL_TYPE} model with {num_classes} classes...")

    # Load the (optionally pretrained) model
    model = ssd300_vgg16(pretrained=pretrained)

    # SSD has one classification conv per feature map. Read each layer's
    # input width from the existing head, and the anchor count per feature
    # map from the anchor generator, since the count differs between maps.
    old_head = model.head.classification_head
    in_channels = [conv.in_channels for conv in old_head.module_list]
    num_anchors = model.anchor_generator.num_anchors_per_location()

    # Replace the classification head with one for our number of classes
    model.head.classification_head = SSDClassificationHead(
        in_channels, num_anchors, num_classes
    )

    # Move model to the correct device
    model.to(DEVICE)

    print(f"Model created and moved to {DEVICE}")
    return model

def save_model(model, path=MODEL_SAVE_PATH):
    """Save the model's state dict plus basic metadata to ``path``.

    The checkpoint is a dict with the keys ``model_state_dict``,
    ``num_classes`` and ``model_type``.

    Args:
        model: Model whose ``state_dict()`` is saved.
        path: Destination file; missing parent directories are created.
    """
    # A bare file name has no directory part, and os.makedirs('') would raise.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # NOTE(review): num_classes is taken from the config, not from the model;
    # confirm they cannot differ for models built with a custom num_classes.
    torch.save({
        'model_state_dict': model.state_dict(),
        'num_classes': NUM_CLASSES,
        'model_type': MODEL_TYPE
    }, path)

    print(f"Model saved to {path}")

def load_model(path=MODEL_SAVE_PATH):
    """Rebuild a model from a checkpoint written by ``save_model``.

    Args:
        path: Checkpoint file to read.

    Returns:
        The restored model, or None when ``path`` does not exist.
    """
    if os.path.exists(path):
        # Read the checkpoint onto the configured device
        checkpoint = torch.load(path, map_location=DEVICE)

        # Recreate the architecture with the stored class count (falling back
        # to the configured one), then restore the weights
        stored_classes = checkpoint.get('num_classes', NUM_CLASSES)
        model = get_model(num_classes=stored_classes, pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])

        print(f"Model loaded from {path}")
        return model

    print(f"Model file not found at {path}")
    return None

def get_model_summary(model):
    """Return a text summary of ``model``.

    Uses ``torchinfo.summary`` when torchinfo can be imported; otherwise
    falls back to ``str(model)``.
    """
    try:
        from torchinfo import summary

        batch_shape = (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
        report = summary(model, input_size=batch_shape, verbose=0, device=DEVICE)
        return str(report)
    except ImportError:
        # torchinfo is optional
        return str(model)

def test_model():
    """Smoke-test a freshly built model in training and evaluation mode.

    Failures in either mode are caught and printed rather than raised.

    Returns:
        The model that was created for the test.
    """
    model = get_model()
    model.eval()

    # Random two-image batch
    dummy_input = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE).to(DEVICE)

    def _dummy_target(box, label):
        # One ground-truth box with one label, placed on the configured device
        return {
            'boxes': torch.tensor([box], dtype=torch.float32).to(DEVICE),
            'labels': torch.tensor([label], dtype=torch.int64).to(DEVICE)
        }

    # Training mode needs targets alongside the images
    model.train()
    dummy_target = [
        _dummy_target([10, 10, 100, 100], 1),
        _dummy_target([50, 50, 150, 150], 2),
    ]

    try:
        loss_dict = model(dummy_input, dummy_target)
        print("Model training mode test successful")
        print(f"Loss dictionary: {loss_dict}")
    except Exception as e:
        print(f"Model training mode test failed: {e}")

    # Evaluation mode: images only, no gradients
    model.eval()
    with torch.no_grad():
        try:
            predictions = model(dummy_input)
            first = predictions[0]
            print("\nModel evaluation mode test successful")
            print(f"Predictions: {len(predictions)} items")

            # Keys available on a single prediction
            print(f"Prediction keys: {first.keys()}")

            # Tensor shapes of the first prediction
            print(f"Boxes shape: {first['boxes'].shape}")
            print(f"Labels shape: {first['labels'].shape}")
            print(f"Scores shape: {first['scores'].shape}")

        except Exception as e:
            print(f"Model evaluation mode test failed: {e}")

    return model

# Script entry point: smoke-test the model, then print its architecture.
if __name__ == "__main__":
    # Test the model
    model = test_model()

    # Print model summary
    print("\nModel Architecture:")
    print(get_model_summary(model))

In [None]:
%%writefile train.py
import torch
import time
import datetime
import os
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import numpy as np

from model import get_model, save_model
from data import get_dataloaders, get_class_weights
from config import *

def train_one_epoch(model, dataloader, optimizer, scaler, device, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Detection model that returns a dict of losses when called with
            ``(images, targets)`` in training mode.
        dataloader: Yields ``(images, targets)`` batches.
        optimizer: Optimizer updating the model parameters.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        device: ``torch.device`` the batches are moved to.
        epoch: Zero-based epoch index (used for logging only).

    Returns:
        Dict of per-batch average losses with the keys ``total``,
        ``classifier``, ``box_reg``, ``objectness`` and ``rpn_box_reg``.
        Components the model does not report stay at 0.
    """
    # Local import: only needed to skip autocast on non-CUDA devices.
    from contextlib import nullcontext

    model.train()
    running = {
        'total': 0.0,
        'classifier': 0.0,
        'box_reg': 0.0,
        'objectness': 0.0,
        'rpn_box_reg': 0.0,
    }

    # Progress tracking
    start_time = time.time()
    num_batches = len(dataloader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: Training...")

    use_amp = device.type == 'cuda'

    for i, (images, targets) in enumerate(dataloader):
        # Move data to device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        # Mixed precision on CUDA only. torch.autocast is used because
        # torch.cuda.amp.autocast does not accept a device_type argument.
        amp_context = torch.autocast(device_type='cuda') if use_amp else nullcontext()
        with amp_context:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # Backward pass with gradient scaling
        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()

        running['total'] += losses.item()

        # Bucket the individual losses. The patterns cover both the
        # SSD-style names ('classification', 'bbox_regression') and the
        # R-CNN-style names ('loss_classifier', 'loss_box_reg', ...).
        for loss_name, loss_value in loss_dict.items():
            if 'rpn_box_reg' in loss_name:
                running['rpn_box_reg'] += loss_value.item()
            elif 'objectness' in loss_name:
                running['objectness'] += loss_value.item()
            elif 'classif' in loss_name:
                running['classifier'] += loss_value.item()
            elif 'box_reg' in loss_name:
                running['box_reg'] += loss_value.item()

        # Print progress
        if (i + 1) % 10 == 0 or (i + 1) == num_batches:
            elapsed = time.time() - start_time
            elapsed_str = str(datetime.timedelta(seconds=int(elapsed)))
            eta = elapsed * (num_batches - i - 1) / (i + 1)
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            print(f"  Batch {i+1}/{num_batches}, Loss: {losses.item():.4f}, Time: {elapsed_str}, ETA: {eta_str}", end='\r')

    # Average over batches; an empty loader yields zeros instead of a crash
    divisor = max(num_batches, 1)
    avg_loss = running['total'] / divisor
    avg_loss_classifier = running['classifier'] / divisor
    avg_loss_box_reg = running['box_reg'] / divisor
    avg_loss_objectness = running['objectness'] / divisor
    avg_loss_rpn_box_reg = running['rpn_box_reg'] / divisor

    print(f"\nEpoch {epoch+1}: Avg Loss: {avg_loss:.4f}, Classifier: {avg_loss_classifier:.4f}, "
          f"Box Reg: {avg_loss_box_reg:.4f}, Objectness: {avg_loss_objectness:.4f}, "
          f"RPN Box Reg: {avg_loss_rpn_box_reg:.4f}")

    return {
        'total': avg_loss,
        'classifier': avg_loss_classifier,
        'box_reg': avg_loss_box_reg,
        'objectness': avg_loss_objectness,
        'rpn_box_reg': avg_loss_rpn_box_reg
    }

def validate(model, dataloader, device, epoch):
    """Compute the average losses of ``model`` over ``dataloader``.

    torchvision detection models return the loss dict only in training mode;
    in eval mode they return detections. The losses are therefore computed
    with the model in training mode under ``torch.no_grad()``, and the model
    is switched to eval mode before returning.

    Args:
        model: Detection model returning a dict of losses for
            ``(images, targets)`` in training mode.
        dataloader: Yields ``(images, targets)`` batches.
        device: ``torch.device`` the batches are moved to.
        epoch: Zero-based epoch index (used for logging only).

    Returns:
        Dict of per-batch average losses with the keys ``total``,
        ``classifier``, ``box_reg``, ``objectness`` and ``rpn_box_reg``.
    """
    running = {
        'total': 0.0,
        'classifier': 0.0,
        'box_reg': 0.0,
        'objectness': 0.0,
        'rpn_box_reg': 0.0,
    }

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: Validating...")

    # NOTE(review): training mode also activates dropout / batch-norm updates
    # if the model contains such layers — confirm none are present.
    model.train()
    try:
        with torch.no_grad():
            for images, targets in dataloader:
                # Move data to device
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                # Forward pass (no mixed precision for validation)
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                running['total'] += losses.item()

                # Bucket the individual losses; patterns cover SSD-style
                # ('classification', 'bbox_regression') and R-CNN-style names.
                for loss_name, loss_value in loss_dict.items():
                    if 'rpn_box_reg' in loss_name:
                        running['rpn_box_reg'] += loss_value.item()
                    elif 'objectness' in loss_name:
                        running['objectness'] += loss_value.item()
                    elif 'classif' in loss_name:
                        running['classifier'] += loss_value.item()
                    elif 'box_reg' in loss_name:
                        running['box_reg'] += loss_value.item()
    finally:
        # Leave the model in eval mode, as callers expect after validation
        model.eval()

    # Average over batches; an empty loader yields zeros instead of a crash
    divisor = max(len(dataloader), 1)
    avg_loss = running['total'] / divisor
    avg_loss_classifier = running['classifier'] / divisor
    avg_loss_box_reg = running['box_reg'] / divisor
    avg_loss_objectness = running['objectness'] / divisor
    avg_loss_rpn_box_reg = running['rpn_box_reg'] / divisor

    print(f"Validation Loss: {avg_loss:.4f}, Classifier: {avg_loss_classifier:.4f}, "
          f"Box Reg: {avg_loss_box_reg:.4f}, Objectness: {avg_loss_objectness:.4f}, "
          f"RPN Box Reg: {avg_loss_rpn_box_reg:.4f}")

    return {
        'total': avg_loss,
        'classifier': avg_loss_classifier,
        'box_reg': avg_loss_box_reg,
        'objectness': avg_loss_objectness,
        'rpn_box_reg': avg_loss_rpn_box_reg
    }

def plot_losses(train_losses, val_losses, save_path=None):
    """Plot training vs. validation loss curves on a 2x2 grid.

    Args:
        train_losses: One loss dict per epoch with the keys ``total``,
            ``classifier``, ``box_reg`` and ``objectness``.
        val_losses: Validation loss dicts with the same keys.
        save_path: When given, the figure is written there; otherwise it is
            shown interactively.
    """
    plt.figure(figsize=(12, 8))
    epochs = range(1, len(train_losses) + 1)

    # (loss key, panel title, training legend label, validation legend label)
    panels = (
        ('total', 'Total Loss', 'Training Loss', 'Validation Loss'),
        ('classifier', 'Classifier Loss', 'Training', 'Validation'),
        ('box_reg', 'Box Regression Loss', 'Training', 'Validation'),
        ('objectness', 'Objectness Loss', 'Training', 'Validation'),
    )

    for position, (key, title, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs, [entry[key] for entry in train_losses], 'b-', label=train_label)
        plt.plot(epochs, [entry[key] for entry in val_losses], 'r-', label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path)
        print(f"Loss plot saved to {save_path}")

    plt.close()

def train_model(resume_from=None):
    """
    Run the full training loop for the SSD model.

    Args:
        resume_from: Checkpoint path to continue from. When it is None or
            the file does not exist, a new pretrained model is created.

    Returns:
        ``(model, (train_losses, val_losses))`` where each loss list holds
        one loss dict per epoch.
    """
    train_loader, val_loader, _ = get_dataloaders()

    # Either restore a checkpoint or start from the pretrained model
    resuming = bool(resume_from) and os.path.exists(resume_from)
    if not resuming:
        model = get_model(num_classes=NUM_CLASSES, pretrained=True)
        print("Starting fresh training")
    else:
        from model import load_model
        model = load_model(resume_from)
        print(f"Resuming training from {resume_from}")

    # SGD with a step-wise learning-rate decay
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

    # Gradient scaling is only enabled when running on CUDA
    scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))

    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    best_model_path = os.path.join(OUTPUT_DIR, "best_model.pth")

    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        epoch_number = epoch + 1

        train_losses.append(
            train_one_epoch(model, train_loader, optimizer, scaler, DEVICE, epoch)
        )

        val_loss = validate(model, val_loader, DEVICE, epoch)
        val_losses.append(val_loss)

        scheduler.step()

        # Keep the weights with the lowest validation loss seen so far
        if val_loss['total'] < best_val_loss:
            best_val_loss = val_loss['total']
            save_model(model, best_model_path)
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

        # Every 10th epoch: write a checkpoint and an intermediate loss plot
        if epoch_number % 10 == 0:
            save_model(model, os.path.join(OUTPUT_DIR, f"checkpoint_epoch{epoch_number}.pth"))
            plot_losses(
                train_losses,
                val_losses,
                save_path=os.path.join(METRICS_SAVE_PATH, f"losses_epoch{epoch_number}.png"),
            )

    total_time = time.time() - start_time
    print(f"Training completed in {datetime.timedelta(seconds=int(total_time))}")

    # Final weights go to the default save path
    save_model(model)
    print(f"Saved final model to {MODEL_SAVE_PATH}")

    plot_losses(
        train_losses,
        val_losses,
        save_path=os.path.join(METRICS_SAVE_PATH, "losses_final.png"),
    )

    return model, (train_losses, val_losses)

# Script entry point: start a fresh training run with the default settings.
if __name__ == "__main__":
    # Train the model
    model, losses = train_model()

In [None]:
%%writefile utils.py
import os
import torch
import numpy as np
import json
import matplotlib.pyplot as plt
from torchvision.ops import box_iou
from config import *

def calculate_map(pred_boxes, pred_labels, pred_scores, gt_boxes, gt_labels, iou_threshold=0.5, num_classes=NUM_CLASSES):
    """Compute mean average precision for one set of predictions.

    Args:
        pred_boxes: Predicted boxes as ``[x1, y1, x2, y2]``.
        pred_labels: Predicted class label per box.
        pred_scores: Confidence score per predicted box.
        gt_boxes: Ground-truth boxes as ``[x1, y1, x2, y2]``.
        gt_labels: Ground-truth class label per box.
        iou_threshold: Minimum IoU for a prediction to match a ground truth.
        num_classes: Number of classes including background (class 0).

    Returns:
        ``(map_value, aps)`` where ``aps`` maps each foreground class id to
        its AP. A class with no predictions or no ground truth has AP 0.
    """
    foreground_classes = range(1, num_classes)

    # Nothing to score: every class gets AP 0
    if len(pred_boxes) == 0 or len(gt_boxes) == 0:
        return 0.0, dict.fromkeys(foreground_classes, 0.0)

    # Inputs are converted together, keyed on the type of pred_boxes
    if not isinstance(pred_boxes, torch.Tensor):
        pred_boxes = torch.tensor(pred_boxes, dtype=torch.float32)
        pred_labels = torch.tensor(pred_labels, dtype=torch.int64)
        pred_scores = torch.tensor(pred_scores, dtype=torch.float32)
        gt_boxes = torch.tensor(gt_boxes, dtype=torch.float32)
        gt_labels = torch.tensor(gt_labels, dtype=torch.int64)

    # Pairwise IoU between every prediction and every ground-truth box
    iou_matrix = box_iou(pred_boxes, gt_boxes)

    aps = {}
    for class_id in foreground_classes:
        pred_idx = (pred_labels == class_id).nonzero(as_tuple=True)[0]
        gt_idx = (gt_labels == class_id).nonzero(as_tuple=True)[0]

        if len(pred_idx) == 0 or len(gt_idx) == 0:
            aps[class_id] = 0.0
            continue

        # Rank this class's predictions by descending confidence
        ranking = torch.argsort(pred_scores[pred_idx], descending=True)
        pred_idx = pred_idx[ranking]

        class_iou = iou_matrix[pred_idx][:, gt_idx]

        num_preds = len(pred_idx)
        tp = torch.zeros(num_preds)
        gt_claimed = torch.zeros(len(gt_idx), dtype=torch.bool)

        # Greedy matching: each prediction looks only at its best-overlapping
        # ground truth, and each ground truth can be claimed once
        for rank in range(num_preds):
            best_iou, best_gt = torch.max(class_iou[rank], dim=0)
            if best_iou >= iou_threshold and not gt_claimed[best_gt]:
                tp[rank] = 1
                gt_claimed[best_gt] = True

        # Every prediction that is not a true positive is a false positive
        fp = 1.0 - tp

        tp_running = torch.cumsum(tp, dim=0)
        fp_running = torch.cumsum(fp, dim=0)

        precision = tp_running / (tp_running + fp_running)
        recall = tp_running / len(gt_idx)

        # Sentinels: precision 1 at recall 0
        precision = torch.cat([torch.tensor([1.0]), precision])
        recall = torch.cat([torch.tensor([0.0]), recall])

        # Make precision monotonically non-increasing, right to left
        for pos in reversed(range(len(precision) - 1)):
            precision[pos] = max(precision[pos], precision[pos + 1])

        # Area under the interpolated precision-recall curve
        ap = 0.0
        for prev_recall, cur_recall, cur_precision in zip(recall[:-1], recall[1:], precision[1:]):
            ap += (cur_recall - prev_recall) * cur_precision

        aps[class_id] = float(ap)

    # Mean over the foreground classes (background excluded)
    map_value = sum(aps.values()) / (num_classes - 1)

    return map_value, aps

def plot_image_with_boxes(image, boxes, labels, scores=None, class_names=None, figsize=(10, 10), title=None):
    """Show an image with bounding boxes, class captions and optional scores.

    Args:
        image: Anything ``Axes.imshow`` accepts. A ``torch.Tensor`` is permuted
            with ``(1, 2, 0)`` (channels moved last), de-normalized with the
            mean/std constants below and clipped to [0, 1]. The
            de-normalization is applied to every tensor input.
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes.
        labels: Iterable of integer class indices, parallel to ``boxes``.
        scores: Optional sequence indexed like ``boxes``; when given, each
            caption is suffixed with the score formatted to two decimals.
        class_names: Optional sequence of names indexed by label. Labels
            outside its range (or a missing list) fall back to ``"Class <n>"``.
        figsize: Figure size forwarded to ``plt.subplots``.
        title: Optional axis title.
    """
    if isinstance(image, torch.Tensor):
        channel_std = np.array([0.229, 0.224, 0.225])
        channel_mean = np.array([0.485, 0.456, 0.406])
        pixels = image.permute(1, 2, 0).cpu().numpy()
        image = np.clip(pixels * channel_std + channel_mean, 0, 1)

    _, axis = plt.subplots(1, figsize=figsize)
    axis.imshow(image)

    # One HSV colour per class index; labels wrap around the palette.
    palette = plt.cm.hsv(np.linspace(0, 1, NUM_CLASSES))

    for idx, (box, label) in enumerate(zip(boxes, labels)):
        left, top, right, bottom = box
        box_color = palette[label % len(palette)]

        outline = plt.Rectangle((left, top), right - left, bottom - top, linewidth=2,
                                edgecolor=box_color, facecolor='none')
        axis.add_patch(outline)

        if class_names and label < len(class_names):
            caption = class_names[label]
        else:
            caption = f"Class {label}"
        if scores is not None:
            caption += f": {scores[idx]:.2f}"

        axis.text(left, top, caption, backgroundcolor=box_color, color='white', fontsize=8)

    if title:
        axis.set_title(title)

    # Pixel ticks add nothing to a detection preview.
    axis.set_xticks([])
    axis.set_yticks([])

    plt.tight_layout()
    plt.show()

def save_model_info(model, losses, metrics, output_dir=None):
    """Write the model summary, loss history and metrics to ``output_dir``.

    Each of the three artifacts is written independently; a failure in one is
    reported with ``print`` and does not prevent the others from being saved.

    Args:
        model: Object passed to ``model.get_model_summary``.
        losses: Pair ``(train_losses, val_losses)``; each is an iterable of
            dicts whose values are convertible with ``float``. Skipped when
            falsy.
        metrics: JSON-serializable object. Skipped when falsy.
        output_dir: Destination directory (created if missing). Defaults to
            ``METRICS_SAVE_PATH``.
    """
    target_dir = METRICS_SAVE_PATH if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)

    # --- architecture summary -------------------------------------------
    try:
        from model import get_model_summary
        summary_text = get_model_summary(model)

        summary_path = os.path.join(target_dir, "model_summary.txt")
        with open(summary_path, "w") as handle:
            handle.write(summary_text)
    except Exception as e:
        print(f"Failed to save model summary: {e}")

    # --- loss history ----------------------------------------------------
    if losses:
        try:
            train_losses, val_losses = losses

            # Cast every value to a plain float so the history is JSON-serializable.
            train_losses_json = [
                {k: float(v) for k, v in entry.items()} for entry in train_losses
            ]
            val_losses_json = [
                {k: float(v) for k, v in entry.items()} for entry in val_losses
            ]
            history = {"train": train_losses_json, "val": val_losses_json}

            history_path = os.path.join(target_dir, "training_losses.json")
            with open(history_path, "w") as handle:
                json.dump(history, handle, indent=4)
        except Exception as e:
            print(f"Failed to save losses: {e}")

    # --- evaluation metrics ----------------------------------------------
    if metrics:
        try:
            metrics_path = os.path.join(target_dir, "evaluation_metrics.json")
            with open(metrics_path, "w") as handle:
                json.dump(metrics, handle, indent=4)
        except Exception as e:
            print(f"Failed to save metrics: {e}")

def check_dataset_balance(coco_json_path):
    """Print and return the number of annotations per category of a COCO file.

    Args:
        coco_json_path: Path to a COCO annotation JSON containing
            ``"categories"`` (with ``id``/``name``) and ``"annotations"``
            (with ``category_id``).

    Returns:
        dict mapping category id -> annotation count. Annotations whose
        ``category_id`` is not listed in ``"categories"`` are ignored.
    """
    with open(coco_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    categories = {cat["id"]: cat["name"] for cat in data["categories"]}

    class_counts = {cat_id: 0 for cat_id in categories}
    for ann in data["annotations"]:
        cat_id = ann["category_id"]
        if cat_id in class_counts:
            class_counts[cat_id] += 1

    print(f"Class distribution in {os.path.basename(coco_json_path)}:")
    for cat_id, count in class_counts.items():
        print(f"  - {categories[cat_id]} (ID: {cat_id}): {count}")

    total_annotations = sum(class_counts.values())
    print(f"Total annotations: {total_annotations}")

    if not class_counts:
        # max()/min() raise ValueError on an empty sequence.
        print("Max/min ratio: undefined (no categories defined)")
        return class_counts

    max_count = max(class_counts.values())
    min_count = min(class_counts.values())

    if min_count == 0:
        # A category without annotations (for example a parent category that
        # is listed but never annotated) would make the ratio divide by zero.
        empty = [categories[cat_id] for cat_id, count in class_counts.items() if count == 0]
        print(f"Max/min ratio: undefined (no annotations for: {', '.join(map(str, empty))})")
    else:
        print(f"Max/min ratio: {max_count/min_count:.2f}")

    return class_counts

def model_complexity(model):
    """Report parameter counts and, when ``torchprofile`` is importable, MACs.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        dict with ``total_params``, ``trainable_params``,
        ``non_trainable_params`` and ``flops``. ``flops`` holds the value
        returned by ``torchprofile.profile_macs`` for a single
        ``IMAGE_SIZE`` x ``IMAGE_SIZE`` random input, or ``None`` when
        ``torchprofile`` is not installed.
    """
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    frozen_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {frozen_params:,}")

    macs = None
    try:
        from torchprofile import profile_macs
        # Profile on whatever device the model's first parameter lives on.
        model_device = next(model.parameters()).device
        dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(model_device)
        macs = profile_macs(model, dummy_input)
        print(f"Estimated FLOPs: {macs:,}")
    except ImportError:
        print("torchprofile not installed. FLOPs calculation skipped.")

    return {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "non_trainable_params": frozen_params,
        "flops": macs
    }

if __name__ == "__main__":
    # Script entry point: report the class balance of every split whose
    # annotation file is present on disk.
    print("Rice Leaf Disease Detection with SSD - Utilities")

    for json_path in (TRAIN_ANNO, VAL_ANNO, TEST_ANNO):
        if not os.path.exists(json_path):
            continue  # split not available; nothing to report
        check_dataset_balance(json_path)
        print()

In [None]:
%%writefile visualize.py
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageFont
import cv2
import os
import random
from torchvision import transforms
from torch.cuda.amp import autocast

from model import load_model
from data import RiceLeafDataset
from config import *

def visualize_dataset_samples(dataset, num_samples=5, output_dir=None):
    """Plot randomly chosen dataset items with their ground-truth boxes.

    Args:
        dataset: Indexable object with ``len`` whose items are
            ``(image, target)`` pairs; ``target`` provides ``'boxes'``
            (``x1, y1, x2, y2``) and ``'labels'``.
        num_samples: Maximum number of items to draw (capped at the dataset
            size).
        output_dir: When truthy, figures are saved there as
            ``sample_<n>.png`` and closed; otherwise they are shown.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    how_many = min(num_samples, len(dataset))
    chosen = random.sample(range(len(dataset)), how_many)

    for sample_no, dataset_idx in enumerate(chosen, start=1):
        img, target = dataset[dataset_idx]

        if isinstance(img, torch.Tensor):
            # Channels last, then undo the mean/std normalization.
            img = img.permute(1, 2, 0).numpy()
            img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img = np.clip(img, 0, 1)

        _, axis = plt.subplots(1, figsize=(12, 9))
        axis.imshow(img)

        boxes = target['boxes']
        if isinstance(boxes, torch.Tensor):
            boxes = boxes.numpy()
        labels = target['labels']
        if isinstance(labels, torch.Tensor):
            labels = labels.numpy()

        for (x1, y1, x2, y2), label in zip(boxes, labels):
            outline = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                        linewidth=2, edgecolor='r', facecolor='none')
            axis.add_patch(outline)

            if label < len(CLASS_NAMES):
                caption = CLASS_NAMES[label]
            else:
                caption = f"Class {label}"
            axis.text(x1, y1, caption, backgroundcolor='red', color='white', fontsize=8)

        axis.set_title(f"Sample {sample_no}: {len(boxes)} annotations")
        axis.set_xticks([])
        axis.set_yticks([])

        if not output_dir:
            plt.show()
        else:
            plt.savefig(os.path.join(output_dir, f"sample_{sample_no}.png"))
            plt.close()

def visualize_prediction(model, image_path, confidence_threshold=CONFIDENCE_THRESHOLD, output_path=None):
    """Run the detector on one image file and draw its detections.

    The image is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and normalized
    before inference; the predicted boxes are scaled back to the original
    resolution and drawn on a copy of the original image.

    Args:
        model: Detection model; ``model(batch)[0]`` must be a dict with
            ``'boxes'`` (``x1, y1, x2, y2``), ``'scores'`` and ``'labels'``.
        image_path: Path of the image to load.
        confidence_threshold: Detections scoring below this are dropped.
        output_path: When truthy, the annotated image is saved there;
            otherwise it is displayed with Matplotlib.

    Returns:
        The annotated ``PIL.Image``.
    """
    model.eval()
    device = next(model.parameters()).device

    orig_image = Image.open(image_path).convert('RGB')
    orig_width, orig_height = orig_image.size

    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_image = transform(orig_image).unsqueeze(0).to(device)

    # torch.cuda.amp.autocast has no ``device_type`` keyword (passing it raises
    # TypeError), so the device-generic torch.autocast is used instead. Mixed
    # precision is enabled on CUDA only: CPU autocast defaults to bfloat16,
    # which NumPy cannot represent.
    use_cuda = device.type == 'cuda'
    with torch.no_grad():
        with torch.autocast(device_type='cuda' if use_cuda else 'cpu', enabled=use_cuda):
            predictions = model(input_image)[0]

    # Cast to float32 so half-precision outputs convert and format cleanly.
    boxes = predictions['boxes'].float().cpu().numpy().reshape(-1, 4)
    scores = predictions['scores'].float().cpu().numpy()
    labels = predictions['labels'].cpu().numpy()

    keep = scores >= confidence_threshold
    scores = scores[keep]
    labels = labels[keep]

    # Map boxes from the resized network input back to the original image.
    scale_x = orig_width / IMAGE_SIZE
    scale_y = orig_height / IMAGE_SIZE
    boxes = boxes[keep] * np.array([scale_x, scale_y, scale_x, scale_y])

    draw_image = orig_image.copy()
    draw = ImageDraw.Draw(draw_image)

    # Fall back to PIL's built-in font when Arial is not installed.
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()

    # RGB colours, indexed by class label (wrapping around).
    colors = [
        (255, 0, 0),    # Red
        (0, 255, 0),    # Green
        (0, 0, 255),    # Blue
        (255, 255, 0),  # Yellow
        (255, 0, 255),  # Magenta
    ]

    # tolist() yields plain Python numbers for PIL and for list indexing.
    for (x1, y1, x2, y2), score, label in zip(boxes.tolist(), scores.tolist(), labels.tolist()):
        color = colors[label % len(colors)]

        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        class_name = CLASS_NAMES[label] if label < len(CLASS_NAMES) else f"Class {label}"
        text = f"{class_name}: {score:.2f}"
        text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:4]

        # Filled background keeps the white caption readable on any image.
        draw.rectangle([x1, y1, x1 + text_w, y1 + text_h], fill=color)
        draw.text((x1, y1), text, fill=(255, 255, 255), font=font)

    if output_path:
        draw_image.save(output_path)
        print(f"Saved visualization to {output_path}")
    else:
        plt.figure(figsize=(12, 9))
        plt.imshow(np.array(draw_image))
        plt.axis('off')
        plt.title(f"Detection Results: {len(boxes)} objects")
        plt.show()

    return draw_image

def _draw_labeled_boxes(ax, boxes, labels, colors, scores=None):
    """Draw boxes with class captions (plus scores, when given) on ``ax``."""
    for idx, (box, label) in enumerate(zip(boxes, labels)):
        x1, y1, x2, y2 = box

        # Matplotlib expects RGB components in [0, 1].
        color_rgb = [c / 255 for c in colors[label % len(colors)]]

        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                 edgecolor=color_rgb, facecolor='none')
        ax.add_patch(rect)

        text = CLASS_NAMES[label] if label < len(CLASS_NAMES) else f"Class {label}"
        if scores is not None:
            text += f": {scores[idx]:.2f}"
        ax.text(x1, y1, text, backgroundcolor=color_rgb, color='white', fontsize=8)


def visualize_batch_predictions(model, dataloader, num_samples=5, confidence_threshold=CONFIDENCE_THRESHOLD, output_dir=None):
    """Plot ground truth next to predictions for the first images of a loader.

    Args:
        model: Detection model called with a list of image tensors and
            returning one dict per image with ``'boxes'``, ``'scores'`` and
            ``'labels'``.
        dataloader: Iterable of ``(images, targets)`` batches; each target
            provides ``'boxes'`` and ``'labels'`` tensors.
        num_samples: Number of images to visualize before stopping.
        confidence_threshold: Predictions scoring below this are hidden.
        output_dir: When truthy, figures are saved there as
            ``comparison_<n>.png`` and closed; otherwise they are shown.
    """
    model.eval()
    device = next(model.parameters()).device
    use_cuda = device.type == 'cuda'

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # RGB colours, indexed by class label (wrapping around).
    colors = [
        (255, 0, 0),    # Red
        (0, 255, 0),    # Green
        (0, 0, 255),    # Blue
        (255, 255, 0),  # Yellow
        (255, 0, 255),  # Magenta
    ]

    count = 0

    for images, targets in dataloader:
        if count >= num_samples:
            break

        images = [img.to(device) for img in images]

        # torch.cuda.amp.autocast has no ``device_type`` keyword (passing it
        # raises TypeError); use the device-generic torch.autocast and enable
        # mixed precision on CUDA only.
        with torch.no_grad():
            with torch.autocast(device_type='cuda' if use_cuda else 'cpu', enabled=use_cuda):
                predictions = model(images)

        for img, target, pred in zip(images, targets, predictions):
            if count >= num_samples:
                break

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

            # Channels last, then undo the mean/std normalization.
            img_np = img.cpu().permute(1, 2, 0).numpy()
            img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img_np = np.clip(img_np, 0, 1)

            # Left panel: ground truth.
            ax1.imshow(img_np)
            ax1.set_title("Ground Truth")
            gt_boxes = target['boxes'].cpu().numpy()
            gt_labels = target['labels'].cpu().numpy()
            _draw_labeled_boxes(ax1, gt_boxes, gt_labels, colors)

            # Right panel: predictions above the confidence threshold.
            ax2.imshow(img_np)
            ax2.set_title("Predictions")
            # float() keeps half-precision outputs safe to convert and format.
            pred_boxes = pred['boxes'].float().cpu().numpy()
            pred_scores = pred['scores'].float().cpu().numpy()
            pred_labels = pred['labels'].cpu().numpy()

            keep = pred_scores >= confidence_threshold
            pred_boxes = pred_boxes[keep]
            pred_scores = pred_scores[keep]
            pred_labels = pred_labels[keep]
            _draw_labeled_boxes(ax2, pred_boxes, pred_labels, colors, scores=pred_scores)

            for ax in (ax1, ax2):
                ax.set_xticks([])
                ax.set_yticks([])

            fig.suptitle(f"Sample {count+1}: {len(gt_boxes)} ground truth, {len(pred_boxes)} predictions")
            plt.tight_layout()

            if output_dir:
                plt.savefig(os.path.join(output_dir, f"comparison_{count+1}.png"))
                plt.close()
            else:
                plt.show()

            count += 1

def create_video_visualization(model, input_video_path, output_video_path, confidence_threshold=CONFIDENCE_THRESHOLD):
    """Run the detector on every frame of a video and write an annotated copy.

    Each frame is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and normalized
    for inference; the resulting boxes are scaled back to the frame size and
    drawn on the original frame, which is written to an ``mp4v`` video.

    Args:
        model: Detection model; ``model(batch)[0]`` must be a dict with
            ``'boxes'`` (``x1, y1, x2, y2``), ``'scores'`` and ``'labels'``.
        input_video_path: Path of the video to read.
        output_video_path: Path of the annotated video to write.
        confidence_threshold: Detections scoring below this are dropped.

    Returns:
        None. An error is printed and nothing is processed when the input
        video or the output writer cannot be opened.
    """
    model.eval()
    device = next(model.parameters()).device
    use_cuda = device.type == 'cuda'

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {input_video_path}")
        return

    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            print(f"Error: Could not open video writer for {output_video_path}")
            return

        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Class colours are specified as RGB, like in the image visualizers.
        colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green
            (0, 0, 255),    # Blue
            (255, 255, 0),  # Yellow
            (255, 0, 255),  # Magenta
        ]
        # The frame is drawn on in OpenCV's BGR channel order, so reverse each
        # tuple; otherwise red and blue would be swapped.
        colors_bgr = [color[::-1] for color in colors]

        # Boxes come out in network-input coordinates; scale to frame size.
        scale_x = frame_width / IMAGE_SIZE
        scale_y = frame_height / IMAGE_SIZE
        scale = np.array([scale_x, scale_y, scale_x, scale_y])

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            input_tensor = transform(frame_rgb).unsqueeze(0).to(device)

            # torch.cuda.amp.autocast has no ``device_type`` keyword (passing
            # it raises TypeError); use torch.autocast, enabled on CUDA only.
            with torch.no_grad():
                with torch.autocast(device_type='cuda' if use_cuda else 'cpu', enabled=use_cuda):
                    predictions = model(input_tensor)[0]

            boxes = predictions['boxes'].float().cpu().numpy().reshape(-1, 4)
            scores = predictions['scores'].float().cpu().numpy()
            labels = predictions['labels'].cpu().numpy()

            keep = scores >= confidence_threshold
            scores = scores[keep]
            labels = labels[keep]
            # astype truncates toward zero, matching int() on each coordinate.
            boxes = (boxes[keep] * scale).astype(np.int32)

            # tolist() hands plain Python ints/floats to OpenCV.
            for (x1, y1, x2, y2), score, label in zip(boxes.tolist(), scores.tolist(), labels.tolist()):
                color = colors_bgr[label % len(colors_bgr)]

                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

                class_name = CLASS_NAMES[label] if label < len(CLASS_NAMES) else f"Class {label}"
                text = f"{class_name}: {score:.2f}"

                (text_w, text_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)

                # Filled background just above the box's top-left corner.
                cv2.rectangle(frame, (x1, y1), (x1 + text_w, y1 - text_h - 5), color, -1)
                cv2.putText(frame, text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            cv2.putText(frame, f"Frame: {frame_count}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            out.write(frame)
            frame_count += 1

            if frame_count % 100 == 0:
                print(f"Processed {frame_count} frames")
    finally:
        # Release the capture and writer even if inference or drawing fails.
        cap.release()
        if out is not None:
            out.release()

    print(f"Video processing completed. Output saved to {output_video_path}")

if __name__ == "__main__":
    # Script entry point: save a few raw test samples and a few
    # ground-truth-vs-prediction comparisons under OUTPUT_DIR/visualizations.
    model = load_model()
    if model is None:
        print("Failed to load model.")
        exit(1)

    visualization_root = os.path.join(OUTPUT_DIR, "visualizations")

    # Test split built with the validation-time transform.
    from data import get_data_transforms
    _, val_transform = get_data_transforms()
    test_dataset = RiceLeafDataset(root=TEST_DIR, annFile=TEST_ANNO, transform=val_transform)

    # 1) Annotated dataset samples.
    samples_dir = os.path.join(visualization_root, "samples")
    os.makedirs(samples_dir, exist_ok=True)
    visualize_dataset_samples(test_dataset, num_samples=5, output_dir=samples_dir)

    # 2) Side-by-side predictions from the test loader.
    from data import get_dataloaders
    _, _, test_loader = get_dataloaders()

    predictions_dir = os.path.join(visualization_root, "predictions")
    os.makedirs(predictions_dir, exist_ok=True)
    visualize_batch_predictions(model, test_loader, num_samples=5, output_dir=predictions_dir)

In [None]:
%%writefile evaluate.py
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, f1_score
from torch.cuda.amp import autocast
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from model import load_model
from data import get_dataloaders
from config import *

def convert_to_coco_format(predictions, image_ids, output_file=None):
    """Flatten per-image detections into a list of COCO result entries.

    Args:
        predictions: Iterable of dicts with ``'boxes'`` (``x1, y1, x2, y2``),
            ``'scores'`` and ``'labels'`` tensors, one dict per image.
        image_ids: Image ids, parallel to ``predictions``.
        output_file: When truthy, the entries are also dumped there as JSON.

    Returns:
        list of dicts with ``image_id``, ``category_id``, ``bbox``
        (``[x, y, width, height]``) and ``score``. Detections scoring below
        ``CONFIDENCE_THRESHOLD`` are left out.
    """
    coco_predictions = []

    for pred, img_id in zip(predictions, image_ids):
        image_boxes = pred['boxes'].cpu().numpy()
        image_scores = pred['scores'].cpu().numpy()
        image_labels = pred['labels'].cpu().numpy()

        for (x1, y1, x2, y2), score, label in zip(image_boxes, image_scores, image_labels):
            if score < CONFIDENCE_THRESHOLD:
                continue  # low-confidence detection

            # COCO stores boxes as top-left corner plus width/height.
            coco_predictions.append({
                'image_id': int(img_id),
                'category_id': int(label),
                'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                'score': float(score)
            })

    if output_file:
        with open(output_file, 'w') as handle:
            json.dump(coco_predictions, handle)
        print(f"Saved COCO format predictions to {output_file}")

    return coco_predictions

def evaluate_coco(model, dataloader, anno_file, output_dir=None):
    """Run the model over a loader and score the detections with COCOeval.

    Args:
        model: Detection model called with a list of image tensors and
            returning one dict per image with ``'boxes'``, ``'scores'`` and
            ``'labels'``.
        dataloader: Iterable of ``(images, targets)`` batches; each target
            provides an ``'image_id'`` tensor.
        anno_file: Path to the COCO ground-truth annotation JSON.
        output_dir: When truthy, ``coco_predictions.json`` and
            ``coco_results.json`` are written there.

    Returns:
        ``(results, coco_eval)`` where ``results`` maps the twelve COCO bbox
        metric names to floats and ``coco_eval`` is the ``COCOeval`` object.
        When no detection passes the confidence filter, every metric is
        ``0.0`` and ``coco_eval`` is ``None``.
    """
    model.eval()
    device = next(model.parameters()).device
    use_cuda = device.type == 'cuda'

    all_predictions = []
    all_image_ids = []

    print("Running inference on dataset...")

    with torch.no_grad():
        for i, (images, targets) in enumerate(dataloader):
            images = [img.to(device) for img in images]
            image_ids = [target['image_id'].item() for target in targets]

            # torch.cuda.amp.autocast has no ``device_type`` keyword (passing
            # it raises TypeError); use the device-generic torch.autocast and
            # enable mixed precision on CUDA only.
            with torch.autocast(device_type='cuda' if use_cuda else 'cpu', enabled=use_cuda):
                outputs = model(images)

            # Keep results on the CPU so GPU memory is not held for the whole
            # dataset.
            all_predictions.extend(
                {
                    key: value.cpu() if isinstance(value, torch.Tensor) else value
                    for key, value in output.items()
                }
                for output in outputs
            )
            all_image_ids.extend(image_ids)

            if (i + 1) % 10 == 0:
                print(f"Processed {i+1}/{len(dataloader)} batches", end='\r')

    print(f"\nProcessed {len(all_predictions)} images.")

    output_file = os.path.join(output_dir, "coco_predictions.json") if output_dir else None
    coco_predictions = convert_to_coco_format(all_predictions, all_image_ids, output_file)

    # Order matches the entries of COCOeval.stats for bbox evaluation.
    metric_names = [
        'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl',
        'ARmax1', 'ARmax10', 'ARmax100', 'ARs', 'ARm', 'ARl',
    ]

    if not coco_predictions:
        # COCO.loadRes indexes the first result entry and fails on an empty
        # list, so report zeros instead of crashing.
        print("No predictions above the confidence threshold; reporting 0.0 for all COCO metrics.")
        coco_eval = None
        results = dict.fromkeys(metric_names, 0.0)
    else:
        coco_gt = COCO(anno_file)
        coco_dt = coco_gt.loadRes(coco_predictions)

        print("Running COCO evaluation...")
        coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        # Plain floats keep the results JSON-serializable.
        results = {name: float(value) for name, value in zip(metric_names, coco_eval.stats)}

    if output_dir:
        results_file = os.path.join(output_dir, "coco_results.json")
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Saved evaluation results to {results_file}")

    return results, coco_eval

def compute_confusion_matrix(model, dataloader, num_classes=NUM_CLASSES):
    """Build a class confusion matrix from IoU-matched detections.

    Every prediction that passes ``CONFIDENCE_THRESHOLD`` is paired with the
    ground-truth box it overlaps most; the pair is counted only when that IoU
    reaches ``IOU_THRESHOLD``. Several predictions may pair with the same
    ground-truth box, and unmatched boxes on either side are not counted.

    Args:
        model: Detection model called with a list of image tensors and
            returning one dict per image with ``'boxes'``, ``'labels'`` and
            ``'scores'``.
        dataloader: Iterable of ``(images, targets)`` batches; each target
            provides ``'boxes'`` and ``'labels'`` tensors.
        num_classes: Number of classes including background (label 0).

    Returns:
        ``(cm, y_true, y_pred, y_scores)``: the confusion matrix over labels
        ``1 .. num_classes - 1`` and the matched ground-truth labels,
        predicted labels and prediction scores.
    """
    model.eval()
    device = next(model.parameters()).device

    y_true, y_pred, y_scores = [], [], []

    print("Computing confusion matrix...")

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            for output, target in zip(outputs, targets):
                pred_boxes = output['boxes'].cpu().numpy()
                pred_labels = output['labels'].cpu().numpy()
                pred_scores = output['scores'].cpu().numpy()

                gt_boxes = target['boxes'].cpu().numpy()
                gt_labels = target['labels'].cpu().numpy()

                # Nothing to pair when either side is empty.
                if len(gt_boxes) == 0 or len(pred_boxes) == 0:
                    continue

                confident = pred_scores >= CONFIDENCE_THRESHOLD
                pred_boxes = pred_boxes[confident]
                pred_labels = pred_labels[confident]
                pred_scores = pred_scores[confident]

                if len(pred_boxes) == 0:
                    continue

                # Pairwise IoU via broadcasting: predictions along axis 0,
                # ground-truth boxes along axis 1.
                p = pred_boxes[:, None, :]
                g = gt_boxes[None, :, :]

                overlap_w = np.maximum(0, np.minimum(p[..., 2], g[..., 2]) - np.maximum(p[..., 0], g[..., 0]))
                overlap_h = np.maximum(0, np.minimum(p[..., 3], g[..., 3]) - np.maximum(p[..., 1], g[..., 1]))
                intersection = overlap_w * overlap_h

                pred_area = (p[..., 2] - p[..., 0]) * (p[..., 3] - p[..., 1])
                gt_area = (g[..., 2] - g[..., 0]) * (g[..., 3] - g[..., 1])
                union = pred_area + gt_area - intersection

                # IoU is 0 wherever the union is not positive.
                ratio = np.divide(intersection, union,
                                  out=np.zeros_like(intersection), where=union > 0)
                ious = np.zeros((len(pred_boxes), len(gt_boxes)))
                ious[...] = ratio

                best_gt = ious.argmax(axis=1)
                best_iou = ious.max(axis=1)

                for p_idx in np.flatnonzero(best_iou >= IOU_THRESHOLD):
                    y_true.append(gt_labels[best_gt[p_idx]])
                    y_pred.append(pred_labels[p_idx])
                    y_scores.append(pred_scores[p_idx])

    # Background (label 0) is excluded from the matrix.
    cm = confusion_matrix(y_true, y_pred, labels=range(1, num_classes))

    return cm, y_true, y_pred, y_scores

def plot_confusion_matrix(cm, class_names=None, output_file=None):
    """Render a confusion matrix as an annotated heatmap.

    Args:
        cm: Integer matrix (cells are formatted with ``'d'``).
        class_names: Tick labels for both axes. Defaults to every entry of
            ``CLASS_NAMES`` except the first (background).
        output_file: When truthy, the figure is saved there; otherwise it is
            shown. The figure is closed in both cases.
    """
    tick_labels = class_names
    if tick_labels is None:
        tick_labels = list(CLASS_NAMES[1:])  # drop the background entry

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    if not output_file:
        plt.show()
    else:
        plt.savefig(output_file)
        print(f"Saved confusion matrix to {output_file}")

    plt.close()

def plot_precision_recall_curve(y_true, y_pred, y_scores, num_classes=NUM_CLASSES, output_file=None):
    """Plot one-vs-rest precision-recall curves for every non-background class.

    For class ``c`` a sample is positive when its true label is ``c``; its
    score is the prediction score when the predicted label is ``c`` and 0
    otherwise.

    Args:
        y_true: True labels of the matched detections.
        y_pred: Predicted labels, parallel to ``y_true``.
        y_scores: Prediction scores, parallel to ``y_true``.
        num_classes: Number of classes including background (label 0).
        output_file: When truthy, the figure is saved there; otherwise it is
            shown. The figure is closed in both cases.
    """
    truth = np.array(y_true)
    predicted = np.array(y_pred)
    confidences = np.array(y_scores)

    plt.figure(figsize=(10, 8))

    for cls in range(1, num_classes):
        is_positive = (truth == cls).astype(int)

        # Score stays 0 for samples that were not predicted as this class.
        predicted_as_cls = predicted == cls
        cls_scores = np.zeros_like(confidences)
        cls_scores[predicted_as_cls] = confidences[predicted_as_cls]

        precision, recall, _ = precision_recall_curve(is_positive, cls_scores)
        ap = average_precision_score(is_positive, cls_scores, average='macro')

        plt.plot(recall, precision, lw=2, label=f'{CLASS_NAMES[cls]} (AP={ap:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='best')
    plt.grid(True)

    if not output_file:
        plt.show()
    else:
        plt.savefig(output_file)
        print(f"Saved precision-recall curve to {output_file}")

    plt.close()

def plot_f1_curve(y_true, y_pred, y_scores, num_classes=NUM_CLASSES, output_file=None):
    """Plot F1 against recall for every non-background class (one-vs-rest).

    For class ``c`` a sample is positive when its true label is ``c``; its
    score is the prediction score when the predicted label is ``c`` and 0
    otherwise. F1 is computed at every point of the precision-recall curve
    and is 0 where precision + recall is not positive.

    Args:
        y_true: True labels of the matched detections.
        y_pred: Predicted labels, parallel to ``y_true``.
        y_scores: Prediction scores, parallel to ``y_true``.
        num_classes: Number of classes including background (label 0).
        output_file: When truthy, the figure is saved there; otherwise it is
            shown. The figure is closed in both cases.
    """
    truth = np.array(y_true)
    predicted = np.array(y_pred)
    confidences = np.array(y_scores)

    plt.figure(figsize=(10, 8))

    for cls in range(1, num_classes):
        is_positive = (truth == cls).astype(int)

        predicted_as_cls = predicted == cls
        cls_scores = np.zeros_like(confidences)
        cls_scores[predicted_as_cls] = confidences[predicted_as_cls]

        precision, recall, _ = precision_recall_curve(is_positive, cls_scores)

        # Harmonic mean of precision and recall, left at 0 where undefined.
        f1_scores = np.zeros_like(precision)
        defined = (precision + recall) > 0
        f1_scores[defined] = (
            2 * precision[defined] * recall[defined] / (precision[defined] + recall[defined])
        )

        plt.plot(recall, f1_scores, lw=2, label=f'{CLASS_NAMES[cls]}')

    plt.xlabel('Recall')
    plt.ylabel('F1 Score')
    plt.title('F1 Score vs Recall Curve')
    plt.legend(loc='best')
    plt.grid(True)

    if not output_file:
        plt.show()
    else:
        plt.savefig(output_file)
        print(f"Saved F1 curve to {output_file}")

    plt.close()

def calculate_metrics(y_true, y_pred, num_classes=NUM_CLASSES):
    """Compute overall accuracy and per-class precision, recall and F1.

    Args:
        y_true: True labels.
        y_pred: Predicted labels, parallel to ``y_true``.
        num_classes: Number of classes including background (label 0), which
            gets no per-class entries.

    Returns:
        dict with ``'accuracy'`` plus ``class_<c>_precision``,
        ``class_<c>_recall`` and ``class_<c>_f1`` for every class ``c`` in
        ``1 .. num_classes - 1``. A ratio with a zero denominator is
        reported as 0.
    """
    truth = np.array(y_true)
    predicted = np.array(y_pred)

    metrics = {'accuracy': np.mean(truth == predicted)}

    for cls in range(1, num_classes):
        # One-vs-rest view of the labels for this class.
        is_cls = truth == cls
        said_cls = predicted == cls

        tp = np.sum(is_cls & said_cls)
        fp = np.sum(~is_cls & said_cls)
        fn = np.sum(is_cls & ~said_cls)

        if (tp + fp) > 0:
            precision = tp / (tp + fp)
        else:
            precision = 0
        if (tp + fn) > 0:
            recall = tp / (tp + fn)
        else:
            recall = 0
        if (precision + recall) > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0

        metrics[f'class_{cls}_precision'] = precision
        metrics[f'class_{cls}_recall'] = recall
        metrics[f'class_{cls}_f1'] = f1

    return metrics

def evaluate_model(model_path=None, output_dir=None):
    """Evaluate a saved model on the test loader and write the results to JSON.

    The function loads the model, runs the COCO evaluation, computes the
    confusion matrix and, when there is at least one matched prediction,
    also draws the confusion-matrix / PR / F1 plots and derives per-class
    metrics. The plots are written to ``CONFUSION_MATRIX_PATH``,
    ``PR_CURVE_PATH`` and ``F1_CURVE_PATH`` (not to ``output_dir``).

    Args:
        model_path: Weights to load; ``MODEL_SAVE_PATH`` when ``None``.
        output_dir: Directory for ``evaluation_results.json``;
            ``METRICS_SAVE_PATH`` when ``None``. Created if missing.

    Returns:
        dict | None: The results dictionary that was written to disk, or
        ``None`` when ``load_model`` returned ``None``. Without matched
        predictions the dictionary only holds ``'coco_metrics'`` and
        ``'num_matched_predictions'``.
    """
    # Resolve defaults, then make sure the destination exists.
    model_path = MODEL_SAVE_PATH if model_path is None else model_path
    output_dir = METRICS_SAVE_PATH if output_dir is None else output_dir
    os.makedirs(output_dir, exist_ok=True)

    model = load_model(model_path)
    if model is None:
        print(f"Failed to load model from {model_path}")
        return None

    # Inference only from here on.
    model.eval()

    # Only the third loader returned by get_dataloaders() is needed.
    _, _, test_loader = get_dataloaders()

    print("\nRunning COCO evaluation...")
    coco_results, coco_eval = evaluate_coco(model, test_loader, TEST_ANNO, output_dir)

    print("\nComputing confusion matrix...")
    cm, y_true, y_pred, y_scores = compute_confusion_matrix(model, test_loader)

    results_file = os.path.join(output_dir, "evaluation_results.json")

    # Guard clause: nothing matched, so only the COCO numbers can be reported.
    if len(y_true) == 0:
        print("No matched predictions found. Cannot compute confusion matrix and metrics.")

        results = {
            'coco_metrics': coco_results,
            'num_matched_predictions': 0
        }
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=4)

        print(f"\nSaved partial evaluation results to {results_file}")
        return results

    # Plots cover the foreground classes only (background id 0 is left out).
    foreground_names = [CLASS_NAMES[i] for i in range(1, NUM_CLASSES)]
    plot_confusion_matrix(cm, class_names=foreground_names,
                          output_file=CONFUSION_MATRIX_PATH)
    plot_precision_recall_curve(y_true, y_pred, y_scores,
                                output_file=PR_CURVE_PATH)
    plot_f1_curve(y_true, y_pred, y_scores,
                  output_file=F1_CURVE_PATH)

    metrics = calculate_metrics(y_true, y_pred)

    results = {
        'coco_metrics': coco_results,
        'confusion_matrix': cm.tolist(),
        'classification_metrics': metrics,
        'num_matched_predictions': len(y_true)
    }
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\nSaved complete evaluation results to {results_file}")

    # Console summary of the headline numbers.
    print("\nEvaluation Summary:")
    print(f"- mAP (IoU=0.50:0.95): {coco_results['AP']:.4f}")
    print(f"- mAP (IoU=0.50): {coco_results['AP50']:.4f}")
    print(f"- mAP (IoU=0.75): {coco_results['AP75']:.4f}")
    print(f"- Overall accuracy: {metrics['accuracy']:.4f}")

    print("\nPer-class metrics:")
    for cls in range(1, NUM_CLASSES):
        print(f"- {CLASS_NAMES[cls]}:")
        print(f"  - Precision: {metrics[f'class_{cls}_precision']:.4f}")
        print(f"  - Recall: {metrics[f'class_{cls}_recall']:.4f}")
        print(f"  - F1 Score: {metrics[f'class_{cls}_f1']:.4f}")

    return results

if __name__ == "__main__":
    # Script entry point: evaluate with the default model path and output
    # directory (both arguments of evaluate_model are left at None).
    evaluate_model()

In [None]:
%%writefile main.py
import os
import time
import argparse
import torch
from config import *

def print_header(text):
    """Print ``text`` as a section banner between two 50-character ``=`` rules.

    A blank line is emitted before the top rule.
    """
    rule = "=" * 50
    print(f"\n{rule}\n{text}\n{rule}")

def main():
    """Run the setup -> train -> evaluate -> visualize pipeline from the CLI.

    Every stage can be disabled with its ``--skip-*`` flag. Stage modules are
    imported lazily, only when the stage actually runs.

    Command-line behaviour:
        * ``--resume`` only has an effect together with ``--model-path``;
          without a path a warning is printed and training starts from
          scratch (``resume_from=None``).
        * ``--model-path`` also selects the weights used for evaluation and
          visualization; ``MODEL_SAVE_PATH`` is used otherwise.
        * ``--test-image`` / ``--test-video`` pointing at a missing file are
          reported with a warning instead of being ignored silently.

    Returns:
        None. The function returns early (without printing the completion
        banner) when setup fails or when no model can be loaded for the
        visualization stage.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Rice Leaf Disease Detection with SSD")
    parser.add_argument("--skip-setup", action="store_true", help="Skip setup steps")
    parser.add_argument("--skip-train", action="store_true", help="Skip training")
    parser.add_argument("--skip-evaluate", action="store_true", help="Skip evaluation")
    parser.add_argument("--skip-visualize", action="store_true", help="Skip visualization")
    parser.add_argument("--resume", action="store_true", help="Resume training from checkpoint")
    parser.add_argument("--model-path", type=str, help="Path to model weights")
    parser.add_argument("--test-image", type=str, help="Path to single test image for visualization")
    parser.add_argument("--test-video", type=str, help="Path to test video for visualization")
    args = parser.parse_args()

    # Start timer
    start_time = time.time()

    # Step 1: Setup
    if not args.skip_setup:
        print_header("SETUP")
        from setup import setup_all

        if not setup_all():
            print("Setup failed. Exiting.")
            return

    # Show the active configuration (config itself is imported at module load).
    print_config()

    # Step 2: Training
    model = None
    if not args.skip_train:
        print_header("TRAINING")
        from train import train_model

        # Resuming needs an explicit checkpoint path; say so instead of
        # quietly falling back to a fresh run.
        resume_from = None
        if args.resume:
            if args.model_path:
                resume_from = args.model_path
            else:
                print("Warning: --resume given without --model-path; training starts from scratch.")

        # The loss history returned alongside the model is not used here.
        model, _ = train_model(resume_from=resume_from)

    # Step 3: Evaluation
    if not args.skip_evaluate:
        print_header("EVALUATION")
        from evaluate import evaluate_model

        # Use specified model path or the default save location
        model_path = args.model_path if args.model_path else MODEL_SAVE_PATH

        evaluation_results = evaluate_model(model_path=model_path)

        if evaluation_results:
            print("Evaluation complete.")
        else:
            print("Evaluation failed.")

    # Step 4: Visualization
    if not args.skip_visualize:
        print_header("VISUALIZATION")
        from visualize import visualize_dataset_samples, visualize_prediction, visualize_batch_predictions
        from visualize import create_video_visualization
        from data import RiceLeafDataset, get_dataloaders, get_data_transforms

        # Load model if training was skipped in this run
        if model is None:
            from model import load_model
            model_path = args.model_path if args.model_path else MODEL_SAVE_PATH
            model = load_model(model_path)

            if model is None:
                print("Failed to load model for visualization. Skipping.")
                return

        # Create visualization directory
        vis_dir = os.path.join(OUTPUT_DIR, "visualizations")
        os.makedirs(vis_dir, exist_ok=True)

        # Visualize dataset samples
        print("Visualizing dataset samples...")
        _, val_transform = get_data_transforms()
        test_dataset = RiceLeafDataset(root=TEST_DIR, annFile=TEST_ANNO, transform=val_transform)

        sample_dir = os.path.join(vis_dir, "samples")
        os.makedirs(sample_dir, exist_ok=True)
        visualize_dataset_samples(test_dataset, num_samples=5, output_dir=sample_dir)

        # Visualize model predictions
        print("Visualizing model predictions...")
        _, _, test_loader = get_dataloaders()

        pred_dir = os.path.join(vis_dir, "predictions")
        os.makedirs(pred_dir, exist_ok=True)
        visualize_batch_predictions(model, test_loader, num_samples=5, output_dir=pred_dir)

        # Visualize single test image if specified
        if args.test_image:
            if os.path.exists(args.test_image):
                print(f"Visualizing predictions on {args.test_image}...")
                output_path = os.path.join(vis_dir, "test_image_prediction.jpg")
                visualize_prediction(model, args.test_image, output_path=output_path)
            else:
                print(f"Warning: test image not found, skipping: {args.test_image}")

        # Process test video if specified
        if args.test_video:
            if os.path.exists(args.test_video):
                print(f"Processing video {args.test_video}...")
                output_path = os.path.join(vis_dir, "test_video_prediction.mp4")
                create_video_visualization(model, args.test_video, output_path)
            else:
                print(f"Warning: test video not found, skipping: {args.test_video}")

    # Print total execution time
    total_time = time.time() - start_time
    print_header("COMPLETED")

    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Total execution time: {int(hours)}h {int(minutes)}m {int(seconds)}s")
    print("Rice Leaf Disease Detection with SSD completed successfully!")

if __name__ == "__main__":
    # Startup banner
    for banner_line in (
        "Rice Leaf Disease Detection with SSD",
        "A PyTorch implementation for detecting rice leaf diseases using SSD",
    ):
        print(banner_line)

    # Report whether a GPU is usable and, if so, the name of device 0
    cuda_ready = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Hand over to the command-line pipeline
    main()


import os
import json

def check_dataset_layout(data_root):
    """Print a report on the COCO dataset folders found under ``data_root``.

    For each of the ``train``, ``valid`` and ``test`` sub-folders the report
    states whether the folder exists, summarises ``_annotations.coco.json``
    (image, annotation and category counts plus the category list) and counts
    the image files in the folder.

    Args:
        data_root: Path of the dataset root directory.

    Notes:
        * Image extensions are matched case-insensitively, so ``.JPG`` and
          ``.PNG`` files are counted too.
        * An annotation file that cannot be read or parsed is reported and
          the check continues with the remaining folders.
    """
    if not os.path.isdir(data_root):
        print(f"❌ THIẾU thư mục chính: {data_root}")
        print("Vui lòng tạo thư mục Coco_Dataset trong Google Drive của bạn")
        return

    print(f"✅ Thư mục chính tồn tại: {data_root}")

    image_suffixes = ('.jpg', '.jpeg', '.png')

    for folder in ("train", "valid", "test"):
        folder_path = os.path.join(data_root, folder)
        if not os.path.isdir(folder_path):
            print(f"❌ THIẾU thư mục {folder}")
            continue

        print(f"✅ Thư mục {folder} tồn tại")

        # Summarise the annotation file of this split
        anno_file = os.path.join(folder_path, "_annotations.coco.json")
        if os.path.isfile(anno_file):
            try:
                with open(anno_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError) as err:
                # ValueError covers both invalid JSON and undecodable bytes.
                print(f"❌ Không đọc được file annotations: {anno_file} ({err})")
            else:
                print(f"   - Số lượng ảnh: {len(data.get('images', []))}")
                print(f"   - Số lượng annotations: {len(data.get('annotations', []))}")
                print(f"   - Số lượng categories: {len(data.get('categories', []))}")

                print("   - Danh sách categories:")
                for cat in data.get('categories', []):
                    print(f"      - ID: {cat['id']}, Name: {cat['name']}")
        else:
            print(f"❌ THIẾU file annotations: {anno_file}")

        # Count the image files, ignoring the case of the extension
        img_files = [
            name for name in os.listdir(folder_path)
            if name.lower().endswith(image_suffixes)
        ]
        print(f"   - Số lượng file ảnh trong thư mục: {len(img_files)}")


# Dataset root on the mounted Google Drive
data_root = "/content/drive/MyDrive/Coco_Dataset"
check_dataset_layout(data_root)

In [None]:
from setup import setup_all

# Notebook cell: print a "project setup" banner, then run setup_all().
print("==== THIẾT LẬP DỰ ÁN ====")
setup_all()

In [None]:
from data import analyze_dataset, get_dataloaders, get_class_weights

# Notebook cell: dataset analysis, dataloader creation and class weights.
print("==== PHÂN TÍCH DỮ LIỆU ====")
# Analyze the dataset
analyze_dataset()

# Create the dataloaders and check the number of images
print("\nTạo dataloaders:")
train_loader, val_loader, test_loader = get_dataloaders()

# Compute class weights from the training data
print("\nTính toán trọng số lớp:")
class_weights = get_class_weights(train_loader)


In [None]:
from model import test_model, get_model_summary

# Notebook cell: run test_model() and print the summary of the model it returns.
print("==== KIỂM TRA MÔ HÌNH ====")
model = test_model()

# Show a summary of the model
print("\nCấu trúc mô hình:")
model_summary = get_model_summary(model)
print(model_summary)

In [None]:
from train import train_model

# Notebook cell: start training.
print("==== BẮT ĐẦU HUẤN LUYỆN ====")
# Alternative, currently disabled: train the model from scratch
# model, losses = train_model()

# Active call: pass a checkpoint path on Google Drive as resume_from.
model, losses = train_model(resume_from="/content/drive/MyDrive/Coco_Dataset/checkpoint.pth")

In [None]:
from evaluate import evaluate_model

# Notebook cell: evaluate with evaluate_model()'s default arguments and keep
# the returned results in evaluation_results.
print("==== ĐÁNH GIÁ MÔ HÌNH ====")
evaluation_results = evaluate_model()

In [None]:
from visualize import visualize_dataset_samples, visualize_batch_predictions, visualize_prediction
from data import RiceLeafDataset, get_data_transforms, get_dataloaders
from model import load_model
# The constants below are used in this cell but none of the imports above
# brings them into the notebook namespace; import them explicitly so the cell
# does not depend on an earlier `from config import *`.
from config import OUTPUT_DIR, TEST_DIR, TEST_ANNO
import os

# Notebook cell: show a few dataset samples and a few model predictions.
print("==== TRỰC QUAN HÓA KẾT QUẢ ====")

# Create the visualizations directory
vis_dir = os.path.join(OUTPUT_DIR, "visualizations")
os.makedirs(vis_dir, exist_ok=True)

# Show some samples from the dataset
print("Hiển thị mẫu dữ liệu:")
_, val_transform = get_data_transforms()
test_dataset = RiceLeafDataset(root=TEST_DIR, annFile=TEST_ANNO, transform=val_transform)

# NOTE(review): sample_dir is created but not passed as output_dir below, so
# the call runs without an output directory — confirm that is intended.
sample_dir = os.path.join(vis_dir, "samples")
os.makedirs(sample_dir, exist_ok=True)
visualize_dataset_samples(test_dataset, num_samples=3)

# Show prediction results
print("\nHiển thị kết quả dự đoán:")
# Load the trained model
model = load_model()
if model is not None:
    # Build the dataloader used for prediction
    _, _, test_loader = get_dataloaders()

    # NOTE(review): pred_dir is likewise created but not passed on.
    pred_dir = os.path.join(vis_dir, "predictions")
    os.makedirs(pred_dir, exist_ok=True)
    visualize_batch_predictions(model, test_loader, num_samples=3)
else:
    print("Không thể tải mô hình. Vui lòng huấn luyện mô hình trước.")

In [None]:
import glob
import os

from google.colab import files
from IPython.display import Image, display

# OUTPUT_DIR and TEST_DIR are used below; import them explicitly so the cell
# does not depend on an earlier `from config import *`.
from config import OUTPUT_DIR, TEST_DIR
from model import load_model
from visualize import visualize_prediction

# Notebook cell: run the model on user-uploaded images (or on one test-set
# image when the upload fails) and display the saved prediction images.
print("==== DỰ ĐOÁN TRÊN ẢNH TÙY CHỈNH ====")

# Upload images
print("Tải lên ảnh lá lúa để dự đoán:")
# Track the fallback with a flag: inspecting the dict value instead would
# raise TypeError for a real upload named 'sample.jpg', whose value is bytes.
use_sample = False
try:
    uploaded = files.upload()
except Exception:
    # Best-effort fallback to a sample image; KeyboardInterrupt/SystemExit
    # are deliberately not swallowed.
    uploaded = {'sample.jpg': 'Sử dụng ảnh mẫu'}
    use_sample = True
    print("Không thể tải ảnh lên. Sử dụng ảnh mẫu từ tập test.")

if len(uploaded) > 0:
    # Load the model
    model = load_model()
    if model is not None:
        # Create the output directory once, before processing any image
        output_dir = os.path.join(OUTPUT_DIR, "custom_predictions")
        os.makedirs(output_dir, exist_ok=True)

        # Predict on every uploaded image
        for filename in uploaded:
            print(f"\nDự đoán trên ảnh: {filename}")

            if use_sample:
                # Take the first .jpg of the test set; sorted so the choice
                # is deterministic.
                test_images = sorted(glob.glob(os.path.join(TEST_DIR, "*.jpg")))
                if not test_images:
                    print("Không tìm thấy ảnh mẫu trong tập test.")
                    continue
                img_path = test_images[0]
            else:
                img_path = filename

            # Path where the result is saved
            output_path = os.path.join(output_dir, f"pred_{filename}")

            # Run the prediction and save the result
            visualize_prediction(model, img_path, output_path=output_path)

            # Display the result image
            print(f"Kết quả dự đoán được lưu tại: {output_path}")
            display(Image(output_path))
    else:
        print("Không thể tải mô hình. Vui lòng huấn luyện mô hình trước.")
else:
    print("Không có ảnh nào được tải lên.")