In [5]:
import os
import zipfile
import shutil
from pathlib import Path

def extract_and_flatten(zip_path, final_extract_path):
    temp_extract_path = Path("/content/temp_extract")

    # Step 0: Cleanup temp if it exists
    if temp_extract_path.exists():
        print("🧹 Cleaning up old temp directory...")
        shutil.rmtree(temp_extract_path)

    print("📦 Extracting ZIP to temporary directory...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(temp_extract_path)

    current = temp_extract_path
    while True:
        subdirs = [d for d in current.iterdir() if d.is_dir()]
        image_dirs = [
            d for d in subdirs if any(f.suffix.lower() in [".jpg", ".jpeg", ".png"] for f in d.glob("*"))
        ]

        if image_dirs:
            break  

        if len(subdirs) != 1:
            print("❌ Could not flatten — ambiguous or unexpected folder structure.")
            return

        current = subdirs[0]

    if final_extract_path.exists():
        print("🧹 Removing old extract folder.")
        shutil.rmtree(final_extract_path)

    print(f"📁 Moving extracted image folder to: {final_extract_path}")
    shutil.move(str(current), final_extract_path)

    if temp_extract_path.exists():
        shutil.rmtree(temp_extract_path)

    print(f"✅ Extraction and flattening complete: {final_extract_path}")


# === USAGE ===
zip_path = "/content/drive/MyDrive/Tomato_.zip"
final_extract_path = Path("/content/tomato_ds")

extract_and_flatten(zip_path, final_extract_path)




📦 Extracting ZIP to temporary directory...
🧹 Removing old extract folder.
📁 Moving extracted image folder to: /content/tomato_ds
✅ Extraction and flattening complete: /content/tomato_ds


In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [6]:
import os
import shutil
import random
from pathlib import Path

def prepare_dataset(base_path, train_ratio=0.8):
    base_path = Path(base_path)
    train_path = base_path.parent / 'train'
    val_path = base_path.parent / 'val'

    # if train_path.exists() and val_path.exists():
    #     print("✔️ Dataset already split.")
    #     return

    train_path.mkdir(parents=True, exist_ok=True)
    val_path.mkdir(parents=True, exist_ok=True)

    # Each subfolder in PlantVillage is a class
    class_folders = [d for d in base_path.iterdir() if d.is_dir()]

    for class_dir in class_folders:
        images = list(class_dir.glob("*.JPG")) + list(class_dir.glob("*.png")) + list(class_dir.glob("*.jpeg"))

        if len(images) < 2:
            print(f"⚠️ Skipping {class_dir.name} — not enough images.")
            continue

        random.shuffle(images)
        split_idx = int(len(images) * train_ratio)
        train_imgs = images[:split_idx]
        val_imgs = images[split_idx:]

        train_class_dir = train_path / class_dir.name
        val_class_dir = val_path / class_dir.name
        train_class_dir.mkdir(parents=True, exist_ok=True)
        val_class_dir.mkdir(parents=True, exist_ok=True)

        for img in train_imgs:
            shutil.copy2(img, train_class_dir / img.name)

        for img in val_imgs:
            shutil.copy2(img, val_class_dir / img.name)

        print(f"✅ {class_dir.name}: {len(train_imgs)} train, {len(val_imgs)} val")

    print("🎉 Dataset split into train/val complete.")

# === USAGE ===
prepare_dataset("/content/tomato_ds")


✅ Tomato_Early_blight: 800 train, 200 val
✅ Tomato_Late_blight: 1405 train, 352 val
✅ Tomato_healthy: 1272 train, 318 val
🎉 Dataset split into train/val complete.


In [10]:
# 🚀 Colab Setup: Install Dependencies
# !pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
# !pip install -q albumentations==1.3.1
# !pip install -q scikit-learn
# !pip install -q matplotlib seaborn
# !pip install -q tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models
from torchvision.models import EfficientNet_B0_Weights
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import f1_score, confusion_matrix
import numpy as np
import cv2
import copy
import os
from pathlib import Path
import logging
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import urllib.request
import hashlib

# --------------------- Setup Logging ---------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(), logging.FileHandler("training_log_v5.txt")]
)
logger = logging.getLogger(__name__)

# Set random seed for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

# --------------------- Setup ---------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"🚀 Device: {device}")

DATA_ROOT = Path("/content")
TRAIN_DIR = DATA_ROOT / "train"
VAL_DIR = DATA_ROOT / "val"

CONFIG = {
    'batch_size': 16,
    'epochs': 5,
    'patience': 3,
    'lr': 1e-3,
    'model_name': 'efficientnet_b0',
    'num_workers': 0,
}

# ------------------ Transforms -------------------
train_tfms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

val_tfms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

class AlbumentationsDataset(datasets.ImageFolder):
    def __init__(self, root, transform=None):
        super().__init__(root, transform=None, is_valid_file=self.is_valid_file)
        self.transform = transform
        self.samples = self._validate_samples()
        logger.info(f"Valid samples in {root}: {len(self.samples)}")

    @staticmethod
    def is_valid_file(filename):
        valid_extensions = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
        return filename.lower().endswith(valid_extensions)

    def _validate_samples(self):
        if not self.samples:
            logger.error(f"No valid images found in {self.root}. Check directory structure and file extensions.")
            return []
        valid_samples = []
        for path, target in self.samples:
            try:
                img = cv2.imread(str(path))
                if img is None or img.size == 0:
                    logger.warning(f"Skipping invalid image: {path}")
                    continue
                valid_samples.append((path, target))
            except Exception as e:
                logger.warning(f"Error validating {path}: {e}")
        if not valid_samples:
            logger.error(f"No valid images after validation in {self.root}.")
        return valid_samples

    def __getitem__(self, index):
        path, target = self.samples[index]
        try:
            image = cv2.imread(str(path))
            if image is None:
                raise ValueError(f"Failed to load image: {path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if self.transform:
                image = self.transform(image=image)['image']
            return image, target
        except Exception as e:
            logger.error(f"Error loading {path}: {e}")
            dummy = np.zeros((224, 224, 3), dtype=np.uint8)
            if self.transform:
                dummy = self.transform(image=dummy)['image']
            return dummy, target

# ------------------ Load Data --------------------
try:
    if not TRAIN_DIR.exists() or not VAL_DIR.exists():
        raise FileNotFoundError(f"Dataset directories not found: {TRAIN_DIR}, {VAL_DIR}")
    train_ds = AlbumentationsDataset(TRAIN_DIR, transform=train_tfms)
    val_ds = AlbumentationsDataset(VAL_DIR, transform=val_tfms)
except Exception as e:
    logger.error(f"Failed to load dataset: {e}")
    raise

if len(train_ds) == 0 or len(val_ds) == 0:
    logger.error("One or both datasets are empty. Check directory contents and file extensions.")
    raise ValueError("Empty dataset")

class_names = train_ds.classes
num_classes = len(class_names)
logger.info(f"📊 Classes: {num_classes} | Train: {len(train_ds)} | Val: {len(val_ds)}")

train_loader = DataLoader(
    train_ds, batch_size=CONFIG['batch_size'], shuffle=True,
    num_workers=CONFIG['num_workers']
)
val_loader = DataLoader(
    val_ds, batch_size=CONFIG['batch_size'], shuffle=False,
    num_workers=CONFIG['num_workers']
)
dataloaders = {'train': train_loader, 'val': val_loader}

# ------------------- Model Builder -------------------
def build_model(model_name=CONFIG['model_name'], num_classes=15):
    try:
        weights_url = "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth"
        cache_dir = Path("/root/.cache/torch/hub/checkpoints")
        cache_dir.mkdir(parents=True, exist_ok=True)
        weights_path = cache_dir / "efficientnet_b0_rwightman-7f5810bc.pth"

        if not weights_path.exists():
            logger.info(f"Downloading weights from {weights_url}")
            urllib.request.urlretrieve(weights_url, weights_path)

        model = models.efficientnet_b0(weights=None)
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict, strict=False)
        in_features = model.classifier[1].in_features
        model.classifier = nn.Linear(in_features, num_classes)
        logger.info(f"Model {model_name} initialized with manual weights")
        return model.to(device)
    except Exception as e:
        logger.warning(f"Failed to load pretrained weights: {e}. Falling back to random initialization.")
        model = models.efficientnet_b0(weights=None)
        in_features = model.classifier[1].in_features
        model.classifier = nn.Linear(in_features, num_classes)
        logger.info(f"Model {model_name} initialized without pretrained weights")
        return model.to(device)

model = build_model(num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CONFIG['lr'])
scheduler = CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'], eta_min=1e-6)

# ------------------- Training ---------------------
def train_model(model, dataloaders, epochs, patience):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_losses, val_losses = [], []
    trigger = 0
    checkpoint_dir = Path("checkpoints_v5")
    checkpoint_dir.mkdir(exist_ok=True)

    with open("metrics_v5.csv", "w") as f:
        f.write("epoch,phase,loss,accuracy\n")

    for epoch in range(epochs):
        logger.info(f"\n📅 Epoch {epoch+1}/{epochs}")
        for phase in ['train', 'val']:
            model.train() if phase == 'train' else model.eval()
            running_loss, correct, total = 0.0, 0, 0
            preds_all, labels_all = [], []

            try:
                for batch_idx, (inputs, labels) in enumerate(tqdm(dataloaders[phase], desc=phase.upper())):
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    preds = outputs.argmax(dim=1)
                    preds_all.extend(preds.cpu().numpy())
                    labels_all.extend(labels.cpu().numpy())
                    running_loss += loss.item() * inputs.size(0)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

                    batch_acc = (preds == labels).float().mean().item()
                    logger.info(f"Batch {batch_idx+1}: Loss: {loss.item():.4f}, Acc: {batch_acc:.4f}")

            except Exception as e:
                logger.error(f"Error in {phase} phase: {e}")
                raise

            epoch_loss = running_loss / total if total > 0 else float('inf')
            epoch_acc = correct / total if total > 0 else 0.0
            epoch_f1 = f1_score(labels_all, preds_all, average='macro') if labels_all else 0.0

            logger.info(f"🔸 {phase.title()} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f} | F1: {epoch_f1:.4f}")

            with open("metrics_v5.csv", "a") as f:
                f.write(f"{epoch+1},{phase},{epoch_loss:.4f},{epoch_acc:.4f}\n")

            if phase == 'val':
                scheduler.step()
                val_losses.append(epoch_loss)

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': epoch_loss,
                    'accuracy': epoch_acc
                }, checkpoint_dir / f"model_epoch_{epoch}.pth")
                logger.info(f"💾 Checkpoint saved: model_epoch_{epoch}.pth")

                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(best_model_wts, checkpoint_dir / "best_model_v5.pth")
                    logger.info("💾 Best model saved: best_model_v5.pth")
                else:
                    trigger += 1
            else:
                train_losses.append(epoch_loss)

            if trigger >= patience:
                logger.info("⛔ Early stopping activated.")
                break

    model.load_state_dict(best_model_wts)
    logger.info(f"\n🏆 Best Val Accuracy: {best_acc:.4f}")

    torch.save(model.state_dict(), "nucleusV5_weights.pth")
    torch.save(model, "nucleusV5_model.pth")
    logger.info("💾 Final model saved as nucleusV5_model.pth and nucleusV5_weights.pth")

    return model, train_losses, val_losses, preds_all, labels_all

# ------------------ Train & Save -------------------
try:
    model, train_losses, val_losses, preds_all, labels_all = train_model(model, dataloaders, CONFIG['epochs'], CONFIG['patience'])
except Exception as e:
    logger.error(f"Training failed: {e}")
    raise

# ------------------- Visualizations -----------------






TRAIN: 100%|██████████| 218/218 [00:23<00:00,  9.32it/s]
VAL: 100%|██████████| 55/55 [00:02<00:00, 24.42it/s]
TRAIN: 100%|██████████| 218/218 [00:24<00:00,  8.94it/s]
VAL: 100%|██████████| 55/55 [00:02<00:00, 23.32it/s]
TRAIN: 100%|██████████| 218/218 [00:23<00:00,  9.10it/s]
VAL: 100%|██████████| 55/55 [00:02<00:00, 23.95it/s]
TRAIN: 100%|██████████| 218/218 [00:24<00:00,  9.04it/s]
VAL: 100%|██████████| 55/55 [00:02<00:00, 23.76it/s]
TRAIN: 100%|██████████| 218/218 [00:23<00:00,  9.23it/s]
VAL: 100%|██████████| 55/55 [00:02<00:00, 23.57it/s]


In [12]:
# ------------------ Train & Save -------------------
# model, train_losses, val_losses, preds_all, labels_all = train_model(model, dataloaders, CONFIG['epochs'], CONFIG['patience'])

# Save both state dict and full model
# torch.save(model.state_dict(), "nucleusV4_weights.pth")
# torch.save(model, "nucleusV4_model.pth")
# logger.info("💾 Model saved as nucleusV4_model.pth and nucleusV4_weights.pth")