In [None]:
import os

from google.colab import drive

# Mount Google Drive so the shared team folder is reachable from this runtime
drive.mount('/content/drive')

# Single source of truth for the project folder; the working directory below is derived from it
DATASET_PATH = '/content/drive/MyDrive/Computer Vision/50.035 CV Team 9'

# Validate BEFORE changing directory, so a wrong path fails with this message
# instead of an opaque chdir error
assert os.path.isdir(DATASET_PATH), "[!] Dataset path does not exist. Please check the path."

# Relative paths used later ('plantvillage_dataset/...', 'checkpoints/...',
# 'models/...') resolve against this working directory
os.chdir(DATASET_PATH)
os.getcwd()

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1Mdz9CpJD5zYhDk1e3Ch93fV4o95Ud7HJ/50.035 CV Team 9


In [None]:
%%capture
def is_running_in_colab():
    """Report whether this kernel runs inside Google Colab.

    Returns:
        True if ``google.colab`` can be imported, False if the import fails.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    return True

if is_running_in_colab():
  # Normal packages
  %pip install lightning polars
  # Dev packages
  %pip install icecream rich tqdm

In [None]:
import torch
import polars as pl
import json
import numpy as np
import lightning as L
from lightning.pytorch.callbacks import RichProgressBar

from icecream import i

!pip install torchmetrics
!pip install torch torchvision lightning torchmetrics polars icecream



In [None]:
import os
from pathlib import Path

import numpy as np
import polars as pl
import torch
import torchmetrics
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar
from torchvision.io import ImageReadMode, decode_image
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.transforms import v2

class ImageDataset(torch.utils.data.Dataset):
    """PlantVillage images with integer disease labels, driven by a metadata frame.

    Args:
        dataframe: Polars frame with ``image_path`` and ``disease_type`` columns.
            Only the last two components of each ``image_path``
            (``<class folder>/<file name>``) are kept; they are re-rooted under
            ``image_root``.
        training: when True, random flip / rotation / resized-crop / erasing are
            applied before the shared resize and float conversion.
        image_root: directory containing the ``<class folder>/<file>`` tree.
            Defaults to the relative ``plantvillage_dataset/color``.
        disease_to_idx: label -> class-index mapping to use. When None it is built
            from the sorted unique labels of *this* frame. Validation/test sets
            should receive the training set's mapping so all splits agree.

    Raises:
        ValueError: if ``disease_to_idx`` is given and the frame contains a label
            the mapping does not contain.
    """

    def __init__(self, dataframe: pl.DataFrame, training=False,
                 image_root: Path = Path('plantvillage_dataset/color'),
                 disease_to_idx=None):
        super().__init__()
        image_root = Path(image_root)

        # `get_column(...).to_list()` is always flat; `.to_numpy().squeeze()` turned
        # a one-row frame into a 0-d array that cannot be iterated.
        raw_paths = dataframe.get_column('image_path').to_list()
        self.image_path = np.array([
            str(image_root.joinpath(*Path(p).parts[-2:]))
            for p in raw_paths
        ])

        labels = dataframe.get_column('disease_type').to_list()
        self.disease_type = np.array(labels, dtype=object)

        if disease_to_idx is None:
            disease_to_idx = {disease: i for i, disease in enumerate(sorted(set(labels)))}
        else:
            # Fail loudly instead of raising KeyError mid-epoch in __getitem__
            unknown = set(labels) - disease_to_idx.keys()
            if unknown:
                raise ValueError(f"Labels not present in disease_to_idx: {sorted(unknown)}")
            disease_to_idx = dict(disease_to_idx)
        self.disease_to_idx = disease_to_idx

        self.training = training
        self.train_transforms = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            v2.RandomRotation(30),
            v2.RandomResizedCrop(224, scale=(0.8, 1.0)),  # EfficientNet expects 224x224
            v2.RandomErasing(),
        ])
        # NOTE(review): no ImageNet mean/std Normalize is applied, although the
        # backbone is ImageNet-pretrained. Adding one would change what existing
        # checkpoints/exports expect, so it is intentionally left out here.
        self.transforms = v2.Compose([
            # Explicit (H, W): `Resize(224)` only fixes the shorter edge, so a
            # non-square image would produce a tensor that cannot be batched.
            v2.Resize((224, 224)),
            v2.ToDtype(torch.float32, scale=True),
        ])

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        # Force 3 channels so grayscale / RGBA files match the rest of the batch
        image = decode_image(self.image_path[idx], mode=ImageReadMode.RGB)
        if self.training:
            image = self.train_transforms(image)
        image = self.transforms(image)
        disease = self.disease_to_idx[self.disease_type[idx]]
        return image, disease


class PlantVillageData(L.LightningDataModule):
    """Train/validation/test loaders over the PlantVillage metadata CSVs.

    Args:
        ws_root: workspace root containing ``plantvillage_dataset/metadata`` (CSVs)
            and ``plantvillage_dataset/color`` (images).
        batch_size: training batch size; validation/test use twice this size.
        num_workers: DataLoader worker processes (workers persist when > 0).
    """

    def __init__(self, ws_root: Path = Path("."), batch_size=32, num_workers=0):
        super().__init__()
        dataset_root = Path(ws_root) / 'plantvillage_dataset'
        metadata_path = dataset_root / 'metadata'
        image_root = dataset_root / 'color'

        # Rows whose path mentions 'augment' are dropped from the training split
        train_df = pl.read_csv(metadata_path / 'resampled_training_set.csv').filter(
            ~pl.col('image_path').str.contains('augment', literal=True)
        )
        self.train_ds = ImageDataset(train_df, training=True, image_root=image_root)

        # Validation/test reuse the training label map so class indices line up
        # even when a split is missing a class.
        label_map = self.train_ds.disease_to_idx
        self.val_ds = ImageDataset(
            pl.read_csv(metadata_path / 'validation_set.csv'),
            image_root=image_root,
            disease_to_idx=label_map,
        )
        self.test_ds = ImageDataset(
            pl.read_csv(metadata_path / 'test_set.csv'),
            image_root=image_root,
            disease_to_idx=label_map,
        )

        self.n_classes = len(label_map)
        self.idx_to_disease = {v: k for k, v in label_map.items()}
        self.batch_size = batch_size

        self.dataloader_extras = dict(
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=num_workers > 0
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, **self.dataloader_extras)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size * 2, **self.dataloader_extras)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_ds, batch_size=self.batch_size * 2, **self.dataloader_extras)

class EfficientNetModel(torch.nn.Module):
    """torchvision EfficientNet-B0 with its classifier replaced by an ``n_classes`` head.

    Args:
        n_classes: number of output logits.
        weights: weights enum forwarded to ``efficientnet_b0``. Defaults to the
            pretrained ``EfficientNet_B0_Weights.DEFAULT``; pass None to skip the
            download when parameters are about to be overwritten (e.g. from a
            checkpoint).
    """

    def __init__(self, n_classes, weights=EfficientNet_B0_Weights.DEFAULT):
        super().__init__()
        self.model = efficientnet_b0(weights=weights)
        # Take the head's input width from the stock classifier rather than
        # hard-coding 1280, so the new head always matches the backbone.
        in_features = self.model.classifier[-1].in_features
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(in_features, n_classes),
        )
        self.model.num_classes = n_classes

    def forward(self, x):
        return self.model(x)

class LitEfficientNet(L.LightningModule):
    """LightningModule fine-tuning ``EfficientNetModel`` with cross-entropy loss.

    Logs ``train_loss`` per step, ``val_loss`` and ``val_accuracy`` / ``val_f1`` /
    ``val_auroc`` per validation epoch, and ``test_accuracy`` / ``test_f1`` /
    ``test_auroc`` per test run. ``val_accuracy`` drives the LR scheduler.

    Args:
        n_classes: number of target classes.
        learning_rate: initial AdamW learning rate.
    """

    def __init__(self, n_classes, learning_rate=1e-3):
        super().__init__()
        # Records n_classes / learning_rate so load_from_checkpoint can rebuild this module
        self.save_hyperparameters()
        self.model = EfficientNetModel(n_classes)
        self.n_classes = n_classes
        self.learning_rate = learning_rate

        # NOTE(review): test_accuracy and test_f1 came out identical, which suggests
        # the wrappers' default averaging is micro; confirm whether macro F1 was intended.
        self.val_metrics = torchmetrics.MetricCollection(
            {
                "accuracy": torchmetrics.classification.Accuracy(task="multiclass", num_classes=n_classes),
                "f1": torchmetrics.classification.F1Score(task="multiclass", num_classes=n_classes),
                "auroc": torchmetrics.classification.AUROC(task="multiclass", num_classes=n_classes)
            },
            prefix="val_",
        )
        self.test_metrics = self.val_metrics.clone(prefix="test_")

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        # Only accumulate here. Metrics are computed once over the whole epoch:
        # macro AUROC on a single batch is unreliable because most classes are
        # absent from any one batch.
        self.val_metrics.update(logits, y)

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute(), prog_bar=True)
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.test_metrics.update(self.model(x), y)

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics.compute(), prog_bar=True)
        self.test_metrics.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=0.01
        )
        # Halve the LR when val_accuracy has not improved for 5 epochs.
        # (`verbose=` is deprecated in recent PyTorch and has been dropped.)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=5,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_accuracy",
                "frequency": 1
            },
        }

# Seed Python/NumPy/torch (and dataloader workers) so augmentation and shuffling are reproducible
L.seed_everything(42, workers=True)

# Directory shared by the checkpoint callback and the warm-start lookup below
CHECKPOINT_DIR = Path('checkpoints/efficientnet')

# Initialize data module. Colab runtimes usually expose only a few CPUs, and more
# workers than cores only adds overhead, so cap at what the machine has.
plantvillage_data = PlantVillageData(num_workers=min(15, os.cpu_count() or 1), batch_size=32)

# Keep the 3 best checkpoints by validation accuracy
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename='plantvillage-{epoch:02d}-{val_accuracy:.2f}',
    save_top_k=3,
    monitor='val_accuracy',
    mode='max'
)

# 'auto' uses the GPU when present and still runs on CPU-only runtimes
trainer = L.Trainer(
    max_epochs=10,
    accelerator='auto',
    callbacks=[checkpoint_callback, RichProgressBar()],
    logger=True,  # Lightning's default logger
)

# Warm-start from the most recently written checkpoint, if any
checkpoints = list(CHECKPOINT_DIR.glob('*.ckpt'))
if checkpoints:
    latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime)
    # load_from_checkpoint restores weights + hparams only. Optimizer, scheduler and
    # epoch counter start fresh and `fit` runs the full `max_epochs` again; pass
    # `ckpt_path=latest_checkpoint` to `trainer.fit` for a true resume instead.
    print(f"Warm-starting from checkpoint weights: {latest_checkpoint}")
    lit_model = LitEfficientNet.load_from_checkpoint(
        str(latest_checkpoint),
        n_classes=plantvillage_data.n_classes
    )
else:
    print("Starting fresh training...")
    lit_model = LitEfficientNet(
        n_classes=plantvillage_data.n_classes,
        learning_rate=1e-3
    )

# Train the model
trainer.fit(model=lit_model, datamodule=plantvillage_data)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Resuming from checkpoint: checkpoints/efficientnet/plantvillage-epoch=07-val_accuracy=0.99.ckpt


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 86.8MB/s]
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /content/drive/.shortcut-targets-by-id/1Mdz9CpJD5zYhDk1e3Ch93fV4o95Ud7HJ/50.035 CV Team 9/checkpoints/efficientnet exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


In [None]:
# Evaluate the best checkpoint (by val_accuracy) tracked during the fit above,
# not whatever weights the final epoch left behind. The checkpoint is loaded into
# `lit_model`, so later cells (export) use the best weights too. If no checkpoint
# was written, the current in-memory weights are tested.
best_ckpt = checkpoint_callback.best_model_path or None
trainer.test(model=lit_model, datamodule=plantvillage_data, ckpt_path=best_ckpt)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'test_accuracy': 0.9901208877563477,
  'test_auroc': 0.824338436126709,
  'test_f1': 0.9901208877563477}]

In [None]:
from pathlib import Path
import torch

# Define experiment name
exp_name = "efficientnet_plantvillage"

# Create save directory models/classification/<exp_name>
model_save_path = Path("models") / "classification" / exp_name
model_save_path.mkdir(exist_ok=True, parents=True)

# Plain nn.Module under the Lightning wrapper, switched to inference mode.
# Note: eval()/cpu() act in place, so lit_model.model is now on CPU as well.
model = lit_model.model.eval().cpu()

# Weights only: reload into a freshly constructed EfficientNetModel
torch.save(model.state_dict(), model_save_path / f"weights_{exp_name}.pt")

# Whole pickled module: loading requires the EfficientNetModel class to be importable
torch.save(model, model_save_path / f"model_{exp_name}.pt")

# Dynamic-shape export: batch 1..256, height/width 224..256, channels fixed at 3
_height = torch.export.Dim('_height', min=224, max=256)
_width = torch.export.Dim('_width', min=224, max=256)

dynamic_shapes = {
    "x": {  # key must match the parameter name of EfficientNetModel.forward
        0: torch.export.Dim("batch", min=1, max=256),
        1: 3,  # RGB channels (fixed)
        2: _height,
        3: _width,
    }
}

# Example input used for tracing
example_input = torch.randn(2, 3, 224, 224)

ep = torch.export.export(
    model,
    (example_input,),
    dynamic_shapes=dynamic_shapes,
)
torch.export.save(ep, model_save_path / f"export_{exp_name}.pt2")

# Sanity check: the exported program must match eager mode, including on an input
# whose batch/height/width differ from the tracing example.
exported_module = ep.module()
with torch.no_grad():
    for probe in (example_input, torch.randn(3, 3, 240, 256)):
        torch.testing.assert_close(exported_module(probe), model(probe), rtol=1e-4, atol=1e-4)

print(f"Model saved to {model_save_path}")

Model saved to models/classification/efficientnet_plantvillage
