# Grain Variety Classification – Generalization Track

**Course:** M1 Artificial Intelligence – AI Challenge  
**Group:** Group 1 – Grain (Generalization)  
**Institution:** Université Paris-Saclay  

## Objective
This notebook provides a **baseline starting kit** for the Grain 1 generalization task.
It demonstrates dataset loading, a simple baseline model, and an evaluation pipeline.

This baseline is intentionally simple and serves as a reference for further improvements.

The boolean variable `COLAB` records whether the notebook is running on Google Colab. When it is `True`, the setup cell below clones the challenge repository and switches into its `Starting_Kit` directory.

In [2]:
# Detect whether we are running on Google Colab
# (the check looks for "google.colab" in the string form of the active IPython shell object)
COLAB = "google.colab" in str(get_ipython())

if COLAB:
    # Shallow clone (--depth 1): only the latest commit of the challenge repository is fetched.
    # NOTE(review): re-running this cell in the same session repeats the clone and the
    # relative %cd below — presumably it is meant to run once per fresh runtime; confirm.
    !git clone --depth 1 https://github.com/md-naim-hassan-saykat/grain-1-generalization-ai-challenge.git
    # Work from the starting-kit folder so relative paths used later (e.g. ./data) resolve there.
    %cd grain-1-generalization-ai-challenge/Starting_Kit

Cloning into 'grain-1-generalization-ai-challenge'...
remote: Enumerating objects: 15, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (14/14), done.
remote: Total 15 (delta 0), reused 9 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (15/15), 223.07 KiB | 1.66 MiB/s, done.
/content/grain-1-generalization-ai-challenge/Starting_Kit


# 0 - Imports & Settings

In [3]:
import os
import random
import json
import zipfile
import datetime

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score, classification_report

# 1 - Data
This section handles dataset loading and basic preprocessing for the Grain 1
generalization challenge.

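We assume the dataset is organized as follows, with one sub-folder per class inside every split (the validation folder may be named `val` or `valid`):

```text
data/
 ├── train/
 │    ├── class_1/
 │    ├── class_2/
 │    └── ...
 ├── val/
 │    └── <class_name>/...
 └── test/
      └── <class_name>/...
```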

In [4]:
class Data:
    """
    Dataset / DataLoader factory for the Grain 1 Generalization challenge.

    Expected folder structure:
      data_dir/
        train/<class_name>/*.jpg
        val/<class_name>/*.jpg   (or valid/)
        test/<class_name>/*.jpg

    Every image is resized to ``img_size x img_size`` and converted to a tensor.

    Args:
        data_dir (str): Root folder containing the split sub-folders.
        batch_size (int): Batch size used by every DataLoader.
        img_size (int): Side length images are resized to.
        num_workers (int): Number of DataLoader worker processes.
    """

    def __init__(self, data_dir, batch_size=32, img_size=224, num_workers=2):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def load_split(self, split):
        """
        Build the dataset and DataLoader for one split folder.

        Args:
            split (str): Name of the sub-folder of ``data_dir`` to load.

        Returns:
            tuple: ``(dataset, loader)``.

        Raises:
            FileNotFoundError: If ``data_dir/split`` is not a directory.
        """
        split_dir = os.path.join(self.data_dir, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"Directory not found: {split_dir}")

        dataset = datasets.ImageFolder(root=split_dir, transform=self.transform)
        is_train = (split == "train")

        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=is_train,
            num_workers=self.num_workers,
            # Pinned host memory only helps host-to-GPU copies; requesting it on a
            # CPU-only machine just triggers a warning.
            pin_memory=torch.cuda.is_available(),
            persistent_workers=(self.num_workers > 0),
            # Dropping the last partial batch of a training set that is smaller
            # than one batch would leave the loader with no batches at all.
            drop_last=(is_train and len(dataset) >= self.batch_size),
        )

        return dataset, loader

    def load_data(self):
        """
        Load every available split.

        The validation split is looked up under "val" first, then "valid".
        Splits whose folder cannot be loaded are reported and skipped.

        Returns:
            dict: Maps "train" / "val" / "test" to a dict with the keys
            "dataset", "loader", "num_samples", "num_classes", "classes"
            and "root". Empty when no split was found.
        """
        data = {}

        # allow "val" or "valid"
        split_candidates = {
            "train": ["train"],
            "val": ["val", "valid"],
            "test": ["test"],
        }

        for canonical, candidates in split_candidates.items():
            loaded = False
            for split in candidates:
                try:
                    dataset, loader = self.load_split(split)
                except FileNotFoundError:
                    continue

                data[canonical] = {
                    "dataset": dataset,
                    "loader": loader,
                    "num_samples": len(dataset),
                    "num_classes": len(dataset.classes),
                    "classes": dataset.classes,
                    "root": os.path.join(self.data_dir, split),
                }
                loaded = True
                break

            if not loaded:
                print(f"Skipping missing split: {canonical} (looked for: {candidates})")

        if not data:
            print(
                "\nNo data found.\n"
                "Expected:\n"
                f"  {self.data_dir}/train/<class_name>/...\n"
                f"  {self.data_dir}/val(or valid)/<class_name>/...\n"
                f"  {self.data_dir}/test/<class_name>/...\n"
            )

        return data

In [5]:
# Example usage: build the loaders, then report the size of every split that was found
data = Data(data_dir="./data", batch_size=32, img_size=224, num_workers=2)
data_dict = data.load_data()

if data_dict:
    # One summary line per available split
    for split, info in data_dict.items():
        print(
            f"{split}: {info['num_samples']} samples | "
            f"{info['num_classes']} classes"
        )
else:
    print("data_dict is empty (dataset not available yet).")

Skipping missing split: train (looked for: ['train'])
Skipping missing split: val (looked for: ['val', 'valid'])
Skipping missing split: test (looked for: ['test'])

No data found.
Expected:
  ./data/train/<class_name>/...
  ./data/val(or valid)/<class_name>/...
  ./data/test/<class_name>/...

data_dict is empty (dataset not available yet).


# 2 - Visualization
This section provides basic visualization utilities to inspect the dataset.
It helps verify that images and labels are loaded correctly before training.


In [6]:
class Visualize:
    """Plotting helpers for inspecting the splits returned by ``Data.load_data()``."""

    def __init__(self, data_dict):
        """
        Args:
            data_dict (dict): Output of Data.load_data()
        """
        self.data_dict = data_dict

    @staticmethod
    def _to_displayable(img):
        """
        Convert a sample into something ``imshow`` can draw.

        Tensors in channel-first layout (C, H, W) with 1, 3 or 4 channels are
        rearranged to channel-last (H, W, C); a single channel is squeezed to
        (H, W). Non-tensor inputs are returned unchanged.
        """
        if torch.is_tensor(img):
            img = img.detach().cpu()
            if img.ndim == 3 and img.shape[0] in (1, 3, 4):
                img = img.permute(1, 2, 0)
                if img.shape[-1] == 1:
                    img = img.squeeze(-1)
            img = img.numpy()
        return img

    def plot_samples(self, split="train", n_samples=5):
        """
        Plot the first few images of a given split in a single row.

        Args:
            split (str): Key of ``data_dict`` to draw samples from.
            n_samples (int): Maximum number of images to show.

        Returns:
            None. Prints a message and returns early when the split is missing
            or there is nothing to plot.
        """
        if split not in self.data_dict:
            print(f"Split '{split}' not found. Available: {list(self.data_dict.keys())}")
            return

        dataset = self.data_dict[split]["dataset"]
        n_samples = min(n_samples, len(dataset))
        if n_samples <= 0:
            # plt.subplots cannot create a grid with zero columns
            print(f"Nothing to plot for split '{split}' (no samples requested or available).")
            return

        # squeeze=False always yields a 2-D array of axes, even for one sample
        fig, axes = plt.subplots(1, n_samples, figsize=(15, 4), squeeze=False)

        for idx, ax in enumerate(axes[0]):
            img, label = dataset[idx]
            ax.imshow(self._to_displayable(img))
            ax.set_title(f"Label: {label}")
            ax.axis("off")

        fig.suptitle(f"{split.capitalize()} samples")
        plt.show()

In [7]:
# Example usage: reload the data and preview a handful of training images
data = Data(data_dir="./data", batch_size=32)
data_dict = data.load_data()

if "train" not in data_dict:
    print("Skipping visualization: 'train' split not available.")
else:
    # Only plot when a training split was actually found
    visualize = Visualize(data_dict)
    visualize.plot_samples(split="train", n_samples=5)

Skipping missing split: train (looked for: ['train'])
Skipping missing split: val (looked for: ['val', 'valid'])
Skipping missing split: test (looked for: ['test'])

No data found.
Expected:
  ./data/train/<class_name>/...
  ./data/val(or valid)/<class_name>/...
  ./data/test/<class_name>/...

Skipping visualization: 'train' split not available.


# 3 - Training
This section implements a simple **baseline training pipeline**.
It supports two modes:

- **Dummy mode** (used when no dataset is available yet)
- **Real data mode** (automatically activated once data is provided)

The goal is to validate the end-to-end training and evaluation workflow.

In [8]:
class SimpleCNN(nn.Module):
    """
    Minimal CNN baseline for image classification.

    Two conv + ReLU + max-pool stages followed by a two-layer classifier.

    Args:
        num_classes (int): Number of output logits.
        img_size (int): Side length of the square input images. Defaults to
            224, which reproduces the original fixed-size network.

    Raises:
        ValueError: If ``img_size`` is too small to survive both pooling stages.
    """

    def __init__(self, num_classes, img_size=224):
        super().__init__()
        # Each MaxPool2d(2) halves the spatial size (rounding down); the padded
        # 3x3 convolutions leave it unchanged.
        feature_side = (img_size // 2) // 2
        if feature_side < 1:
            raise ValueError(f"img_size must be at least 4, got {img_size}")

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 32 channels x feature_side x feature_side (56 x 56 for 224 x 224 input)
            nn.Linear(32 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images shaped (N, 3, img_size, img_size)."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [9]:
class Train:
    """
    Baseline training driver.

    Modes:
      - dummy: trains on random tensors, used when ``data_dict`` has no "train" split;
      - real: trains on ``data_dict["train"]["loader"]``.

    The trained network is stored in ``self.model`` (``None`` until a training
    run has completed).

    Args:
        data_dict (dict): Output of Data.load_data()
        device (str): 'cpu' or 'cuda'; picked automatically when omitted.
    """

    def __init__(self, data_dict, device=None):
        self.data_dict = data_dict
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None

    def _fit(self, model, loader, epochs, lr=3e-5, weight_decay=1e-4):
        """Run the shared Adam / cross-entropy training loop and return the model."""
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        for epoch in range(epochs):
            model.train()
            losses = []

            for xb, yb in loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                optimizer.zero_grad()
                outputs = model(xb)
                loss = criterion(outputs, yb)
                loss.backward()
                # Clip the global gradient norm to keep updates bounded
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                losses.append(loss.item())

            if losses:
                print(
                    f"Epoch {epoch+1}/{epochs} | "
                    f"Loss (mean): {np.mean(losses):.4f} | "
                    f"Loss (max): {np.max(losses):.4f}"
                )
            else:
                # np.max raises on an empty list, so report the empty epoch instead
                print(f"Epoch {epoch+1}/{epochs} | no batches produced by the loader")

        return model

    def train_dummy(self, epochs=2):
        """Train on 128 random 3x224x224 samples with 5 random classes (pipeline smoke test)."""
        print("Training in dummy mode (no real data found).")

        X = torch.randn(128, 3, 224, 224)
        y = torch.randint(0, 5, (128,))

        ds = TensorDataset(X, y)
        dl = DataLoader(ds, batch_size=16, shuffle=True)

        model = SimpleCNN(num_classes=5).to(self.device)
        self.model = self._fit(model, dl, epochs)

    def train_real(self, epochs=5):
        """
        Train on the "train" split of ``data_dict``.

        Leaves ``self.model`` as ``None`` when no "train" split is available.
        """
        if "train" not in self.data_dict:
            print("Real data mode requires a 'train' split; none was found.")
            self.model = None
            return

        info = self.data_dict["train"]
        print(
            f"Training on real data: {info['num_samples']} samples | "
            f"{info['num_classes']} classes"
        )

        # NOTE: SimpleCNN is built with its default input size (224x224), so the
        # Data loader must be created with img_size=224 (its default).
        model = SimpleCNN(num_classes=info["num_classes"]).to(self.device)
        self.model = self._fit(model, info["loader"], epochs)

    def train(self, epochs=None):
        """
        Train in real mode when a "train" split exists, otherwise in dummy mode.

        Args:
            epochs (int | None): Number of epochs; ``None`` keeps each mode's default.
        """
        run = self.train_real if "train" in self.data_dict else self.train_dummy
        if epochs is None:
            run()
        else:
            run(epochs=epochs)

In [10]:
# Example usage: Train falls back to dummy mode when data_dict has no "train" split
trainer = Train(data_dict=data_dict)
trainer.train()

Training in dummy mode (no real data found).
Epoch 1/2 | Loss (mean): 1.6985 | Loss (max): 1.8643
Epoch 2/2 | Loss (mean): 1.5838 | Loss (max): 1.6636


# 4 - Scoring
This section evaluates a trained model using standard classification metrics.
If validation or test data is not available yet, scoring is skipped gracefully.

Once real data is provided, this section can be used to compare performance
across domains and assess generalization.

In [11]:
class Score:
    """Evaluate a trained model on the splits returned by ``Data.load_data()``."""

    def __init__(self, model, data_dict, device=None):
        """
        Args:
            model (nn.Module): trained model
            data_dict (dict): output of Data.load_data()
            device (str): 'cpu' or 'cuda'
        """
        self.model = model
        self.data_dict = data_dict
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.results = {}

    def evaluate(self, split="val"):
        """
        Evaluate model on a given split.

        Args:
            split (str): Key of ``data_dict`` whose loader is evaluated.

        Returns:
            dict | None: ``{"accuracy": float, "report": dict}`` for the split,
            or ``None`` when the split is missing, no model is available, or
            the loader yields no samples.
        """
        if split not in self.data_dict:
            print(f"Cannot evaluate: split '{split}' not available.")
            return None

        if self.model is None:
            print("Cannot evaluate: no trained model available.")
            return None

        loader = self.data_dict[split]["loader"]

        # Inputs are moved to self.device below, so the model must live there too.
        self.model.to(self.device)
        self.model.eval()

        y_true, y_pred = [], []

        with torch.no_grad():
            for images, labels in loader:
                images = images.to(self.device)
                outputs = self.model(images)
                preds = torch.argmax(outputs, dim=1)

                y_true.extend(labels.cpu().tolist())
                y_pred.extend(preds.cpu().tolist())

        if not y_true:
            # The metrics below are undefined on empty inputs
            print(f"Cannot evaluate: split '{split}' produced no samples.")
            return None

        acc = accuracy_score(y_true, y_pred)

        self.results[split] = {
            "accuracy": acc,
            # zero_division=0 scores classes that were never predicted as 0
            # instead of emitting an undefined-metric warning for each of them
            "report": classification_report(
                y_true, y_pred, output_dict=True, zero_division=0
            ),
        }

        print(f"{split.capitalize()} Accuracy: {acc:.4f}")
        return self.results[split]

    def summary(self):
        """
        Print a summary of all computed scores.
        """
        if not self.results:
            print("No scores computed yet.")
            return

        print("\nScoring Summary")
        for split, res in self.results.items():
            print(f"- {split}: accuracy = {res['accuracy']:.4f}")

In [12]:
# Example usage: score the trained model on the first available held-out split
if trainer.model is not None:
    scorer = Score(trainer.model, data_dict)

    # Validation data takes priority; the test split is only a fallback
    if "val" in data_dict:
        scorer.evaluate("val")
    elif "test" in data_dict:
        scorer.evaluate("test")
    else:
        print(
            "Scoring skipped: no 'val' or 'test' split found.\n"
            f"Available splits: {list(data_dict.keys())}\n"
            "Add data to ./data/val or ./data/test to enable scoring."
        )

    scorer.summary()
else:
    print("No trained model available for scoring yet.")

Scoring skipped: no 'val' or 'test' split found.
Available splits: []
Add data to ./data/val or ./data/test to enable scoring.
No scores computed yet.


# 5 - (Optional) Prepare submission for Codabench
This section prepares a submission ZIP compatible with Codabench.

Depending on the competition setup, the submission can contain:
- the trained **model checkpoint** (code submission), or
- the **predictions file** (result submission).

This is provided as a template and can be adapted once the final
Codabench submission format is defined.

***

In this section, you should prepare a zip of the trained model (if your competition is a code-submission competition) or a zip of the predictions (if your competition is a result-submission competition).

***

In [19]:
class Submission:
    """
    Creates a clean Codabench submission zip.

    Best practice:
    - Save artifacts into ./submission/
    - Create the .zip OUTSIDE ./submission/ to avoid nesting old zips

    Args:
        submission_dir (str): Folder that collects the files to be zipped;
            created if it does not exist.
        zip_file_name (str | None): Name of the archive; a timestamped name is
            generated when omitted.
        clean_dir (bool): When True, delete the regular files already present
            in ``submission_dir`` (sub-folders are left alone).
    """

    def __init__(self, submission_dir="./submission", zip_file_name=None, clean_dir=True):
        self.submission_dir = submission_dir
        os.makedirs(self.submission_dir, exist_ok=True)

        # Optional: clean the folder to avoid zipping older artifacts
        if clean_dir:
            for fn in os.listdir(self.submission_dir):
                fp = os.path.join(self.submission_dir, fn)
                if os.path.isfile(fp):
                    os.remove(fp)

        if zip_file_name is None:
            zip_file_name = f"Submission_{datetime.datetime.now().strftime('%y-%m-%d-%H-%M')}.zip"
        self.zip_file_name = zip_file_name

        # IMPORTANT: zip path is outside submission_dir
        self.zip_path = os.path.join(".", self.zip_file_name)

    def write_readme(self):
        """Write README.txt into the submission folder and return its path."""
        readme_path = os.path.join(self.submission_dir, "README.txt")
        # The text contains a non-ASCII dash, so the encoding is fixed explicitly
        # instead of relying on the platform default.
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(
                "Grain 1 – Generalization AI Challenge\n"
                "Submission generated by starting kit.\n"
                "Contains model checkpoint OR predictions.\n"
            )
        return readme_path

    def save_code(self, model):
        """
        Save trained model checkpoint (code submission).

        Returns:
            str | None: Path of ``model.pth``, or ``None`` when ``model`` is None.
        """
        if model is None:
            print("No trained model available to save.")
            return None

        model_path = os.path.join(self.submission_dir, "model.pth")

        # If it's a torch.nn.Module, save state_dict; otherwise try saving directly
        if hasattr(model, "state_dict"):
            torch.save(model.state_dict(), model_path)
        else:
            torch.save(model, model_path)

        print(f"Model saved at: {model_path}")
        return model_path

    def save_result(self, predictions=None):
        """
        Save predictions file (result submission).

        Args:
            predictions: Iterable of values convertible to ``int``. When None,
                ten zeros are written as dummy predictions.

        Returns:
            str: Path of ``predictions.json``.
        """
        if predictions is None:
            print("No predictions provided. Saving dummy predictions.")
            predictions = np.zeros(10, dtype=int)

        pred_path = os.path.join(self.submission_dir, "predictions.json")
        with open(pred_path, "w", encoding="utf-8") as f:
            json.dump({"predictions": [int(p) for p in predictions]}, f)

        print(f"Predictions saved at: {pred_path}")
        return pred_path

    def zip_submission(self):
        """
        Create zip from submission_dir contents (no nested zip issue).

        Only regular files directly inside ``submission_dir`` are archived, in
        sorted name order so the archive layout is reproducible.

        Returns:
            str: Path of the created zip file.
        """
        # Remove existing zip if present
        if os.path.exists(self.zip_path):
            os.remove(self.zip_path)

        # Collect members before the archive is created, so the archive can
        # never list itself even if it lives inside submission_dir.
        members = []
        for filename in sorted(os.listdir(self.submission_dir)):
            file_path = os.path.join(self.submission_dir, filename)
            if os.path.isfile(file_path):
                members.append((file_path, filename))

        with zipfile.ZipFile(self.zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for file_path, filename in members:
                zf.write(file_path, arcname=filename)

        size_mb = os.path.getsize(self.zip_path) / (1024 * 1024)
        print(f"Submission ZIP saved at: {self.zip_path} ({size_mb:.2f} MB)")

        # Quick integrity test
        with zipfile.ZipFile(self.zip_path, "r") as zf:
            ok = (zf.testzip() is None)
            print("ZIP test:", "OK" if ok else "Corrupt")
            print("Files in ZIP:", zf.namelist())

        return self.zip_path

In [20]:
# Example usage: collect the artifacts, then bundle them into a single archive
submission = Submission("./submission", clean_dir=True)
submission.write_readme()
submission.save_code(trainer.model)
zip_path = submission.zip_submission()

# Offer the archive as a browser download when running on Colab;
# everywhere else just report where it was written.
try:
    from google.colab import files
    files.download(zip_path)
except Exception as err:
    print("Not running in Colab or download failed:", err)
    print("ZIP is available at:", zip_path)

Model saved at: ./submission/model.pth
Submission ZIP saved at: ./Submission_26-01-10-02-04.zip (45.48 MB)
ZIP test: OK
Files in ZIP: ['model.pth', 'README.txt']


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>