# Source: https://github.com/apple/corenet/blob/main/tutorials/train_a_new_model_on_a_new_dataset_from_scratch.ipynb

# Install dependencies

## Repository: https://github.com/apple/corenet/tree/main

# Create structure

In [1]:
# Create the project folder that the config cell below writes cifar10.yaml into
# (-p also creates missing parent folders and does not fail if it already exists).
! mkdir -p corenet/projects/playground_cifar10/classification

# Create config

In [2]:
%%file corenet/projects/playground_cifar10/classification/cifar10.yaml

common:
    log_freq: 2000                 # Log the training metrics every 2000 iterations.

dataset:
    category: classification
    name: "cifar10"                # We'll register the "cifar10" name at DATASET_REGISTRY later in this tutorial.

    # The `corenet-train` entrypoint uses train_batch_size0 and val_batch_size0 values to construct 
    # training/validation batches during training. The `corenet-eval` entrypoint uses eval_batch_size0 to 
    # construct batches during evaluation (i.e., testing).
    #
    # The effective batch size is: num_nodes x num_gpus x train_batch_size0
    train_batch_size0: 4
    val_batch_size0: 4
    eval_batch_size0: 1

    workers: 2
    persistent_workers: true
    pin_memory: true

model:
    classification:
        name: "two_layer"          # We'll register the "two_layer" name at MODEL_REGISTRY later in this tutorial.
        n_classes: 10

    layer:
        # Weight initialization parameters:
        conv_init: "kaiming_normal"
        linear_init: "trunc_normal"
        linear_init_std_dev: 0.02


sampler:
    name: batch_sampler

    # The following dimensions are passed by the sampler to the dataset's __getitem__ method, and the
    # dataset produces samples resized to the requested dimensions.
    bs:
        crop_size_width: 32
        crop_size_height: 32

loss:
    category: classification
    classification:
        name: cross_entropy       # The implementation is available in "corenet/loss_fn/" folder.

optim:
    name: sgd
    sgd:
        momentum: 0.9

scheduler:
    name: fixed                    # The implementation is available in "corenet/optims/scheduler/" folder.
    max_epochs: 2
    fixed:
        lr: 0.001                  # Fixed Learning Rate

stats:
  val: ["loss", "top1"]            # Metrics to log
  train: ["loss", "top1"]
  checkpoint_metric: top1          # Assigns a checkpoint to results/checkpoint_best.pt
  checkpoint_metric_max: true

Writing corenet/projects/playground_cifar10/classification/cifar10.yaml


# Register Dataset

In [3]:
%%file corenet/corenet/data/datasets/classification/playground_dataset.py

from argparse import Namespace
from typing import Any, Dict, Tuple

import torchvision
import torchvision.transforms as transforms

from corenet.data.datasets import DATASET_REGISTRY
from corenet.data.datasets.dataset_base import BaseDataset


@DATASET_REGISTRY.register(name="cifar10", type="classification")
class Cifar10(BaseDataset):
    """CIFAR-10 classification dataset backed by ``torchvision.datasets.CIFAR10``.

    The class is registered under the name ``"cifar10"`` (type
    ``"classification"``) so that it can be selected from a config file via
    ``dataset.name``. The data is downloaded to ``/tmp/cifar10_cache`` on first
    use.
    """

    # Human-readable label for each CIFAR-10 class index.
    CLASS_NAMES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    def __init__(self, opts: Namespace, **kwargs) -> None:
        """Create the underlying torchvision dataset.

        Args:
            opts: Parsed command-line/config options, forwarded to ``BaseDataset``.
            **kwargs: Extra keyword arguments forwarded to ``BaseDataset``.
        """
        super().__init__(opts, **kwargs)
        # `self.is_training` selects between the train and test split.
        self._torchvision_dataset = torchvision.datasets.CIFAR10(
            "/tmp/cifar10_cache",
            train=self.is_training,
            download=True,
        )
        # The tensor conversion and normalization do not depend on the sample,
        # so build them once here instead of on every `__getitem__` call.
        self._to_normalized_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._torchvision_dataset)

    def __getitem__(self, sample_size_and_index: Tuple[int, int, int]) -> Dict[str, Any]:
        """Return one normalized, resized sample and its label.

        Args:
            sample_size_and_index: A ``(crop_size_h, crop_size_w, sample_index)``
                tuple produced by the sampler.

        Returns:
            A dict with the image tensor under ``"samples"`` and the class
            index under ``"targets"``.
        """
        # In CoreNet, not only does the sampler determine the index of the samples, but
        # also the sampler determines the crop size dynamically for each batch. This
        # allows samplers to train multi-scale models more efficiently.
        # See: corenet/data/sampler/variable_batch_sampler.py
        crop_size_h, crop_size_w, sample_index = sample_size_and_index

        img, target = self._torchvision_dataset[sample_index]

        img = self._to_normalized_tensor(img)
        # Only the resize depends on the per-batch size requested by the sampler.
        img = transforms.Resize(size=(crop_size_h, crop_size_w))(img)
        return {
            "samples": img,
            "targets": target,
        }

Writing corenet/corenet/data/datasets/classification/playground_dataset.py


# Register Model

In [4]:
%%file corenet/corenet/modeling/models/classification/playground_model.py

import argparse

import torch
import torch.nn.functional as F
from torch import nn

from corenet.modeling.models import MODEL_REGISTRY
from corenet.modeling.models.base_model import BaseAnyNNModel


@MODEL_REGISTRY.register("two_layer", type="classification")
class Net(BaseAnyNNModel):
    """A simple 2-layer CNN, inspired by https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

    The network is two 5x5 convolutions, each followed by ReLU and 2x2
    max-pooling, and then three fully connected layers. The first fully
    connected layer expects ``16 * 5 * 5`` features, which is what the
    convolutional stack produces for 3x32x32 inputs.
    """

    def __init__(self, opts: argparse.Namespace) -> None:
        """Build the layers and initialize their weights.

        Args:
            opts: Parsed command-line/config options. The number of output
                classes is read from ``model.classification.n_classes`` and
                falls back to 10 (the CIFAR-10 class count) when the option is
                missing or unset.
        """
        super().__init__(opts)
        # Honor the `n_classes` value from the config instead of hard-coding it.
        n_classes = getattr(opts, "model.classification.n_classes", None)
        if n_classes is None:
            n_classes = 10

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, n_classes)
        self.reset_parameters(opts)  # Initialize the weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; the layer sizes require shape ``(N, 3, 32, 32)``.

        Returns:
            Unnormalized class scores (logits), one row per input image.
        """
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 10x10 -> 5x5
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Writing corenet/corenet/modeling/models/classification/playground_model.py


# Start training (in shell in corenet directory)

# Evaluate after training (in shell in corenet directory)

# Visualize result

In [6]:
from corenet.options.opts import get_training_arguments
from corenet.modeling import get_model
from PIL import Image
import torch
from torchvision.transforms import Compose, Resize, PILToTensor, CenterCrop
from torchvision.transforms import ToPILImage
from corenet.data.datasets.classification.playground_dataset import Cifar10

# NOTE(review): the recorded cell output is "ModuleNotFoundError: No module named
# 'corenet.data'". The paths below are relative to the parent of the `corenet`
# checkout, so `import corenet` may be resolving to the repository folder rather
# than the installed package — TODO confirm the working directory / installation.
config_file = "corenet/projects/playground_cifar10/classification/cifar10.yaml"
pretrained_weights = "corenet/results/run_1/checkpoint_best.pt"

opts = get_training_arguments(
    args=[
        "--common.config-file",
        config_file,
        "--model.classification.pretrained",
        pretrained_weights,
    ]
)

# Load the model
model = get_model(opts)
model.eval()

# The preprocessing pipeline is the same for every image, so build it once.
img_transforms = Compose([CenterCrop(600), Resize(size=(32, 32)), PILToTensor()])

for image_path in ["corenet/assets/cat.jpeg", "corenet/assets/dog.jpeg"]:
    # Open inside a `with` block so the underlying file handle is closed promptly.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")

    # Crop, resize to the 32x32 input size of the model, and convert to a uint8 tensor
    input_tensor = img_transforms(image)

    # Show the transformed image
    ToPILImage()(input_tensor).show()

    # Scale to [0, 1], then apply the same normalization as the training dataset
    # (mean 0.5, std 0.5 per channel) so the model sees the input range it was
    # trained on.
    input_tensor = input_tensor.to(torch.float).div(255.0)
    input_tensor = (input_tensor - 0.5) / 0.5

    # add dummy batch dimension
    input_tensor = input_tensor[None, ...]

    with torch.no_grad():
        logits = model(input_tensor)[0]
        probs = torch.softmax(logits, dim=-1)
        # Pair each probability with its class name and sort from most to least likely.
        predictions = sorted(zip(probs.tolist(), Cifar10.CLASS_NAMES), reverse=True)
        print(
            "Top 3 Predictions:",
            [f"{cls}: {prob:.1%}" for prob, cls in predictions[:3]],
        )

ModuleNotFoundError: No module named 'corenet.data'