- Python >=3.10
- Poetry
If simply using the package on CPU:
```bash
poetry install --with cpu
```
If developing, you can add any of the following options:
```bash
poetry install --with dev,test,gpu
```
Then use `poetry shell` to enter the virtualenv.
You need to split the images into train and validation folders.
For each class, place all the images in a folder with the class's name.
You then need to create a classes.names file next to the train and validation folders, with the names of the classes (one per line).
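The `classes.names` file can be written by hand, or generated from the folder names. A small sketch (the `write_classes_names` helper is illustrative and not part of this repository; it assumes the `Train`/`Validation` layout described above):

```python
from pathlib import Path

def write_classes_names(dataset_dir: str) -> list[str]:
    """Write a classes.names file (one class name per line) next to the
    Train/ and Validation/ folders, using the Train/ subfolder names.
    Hypothetical helper, not part of this repository."""
    root = Path(dataset_dir)
    classes = sorted(p.name for p in (root / "Train").iterdir() if p.is_dir())
    (root / "classes.names").write_text("\n".join(classes) + "\n")
    return classes
```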
Structure example
```
cifar-10/
├── Train/
│   ├── airplane
│   ├── automobile
│   ├── bird
│   ├── cat
│   └── ...
├── Validation/
│   ├── airplane
│   ├── automobile
│   ├── bird
│   ├── cat
│   └── ...
└── classes.names
```
CIFAR-10 instructions
The commands below will download, extract and format the CIFAR-10 dataset into the ./data/cifar_10_images folder.
```bash
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz -P data
tar -xvf data/cifar-10-python.tar.gz -C data
python utils/cifar_10.py data/cifar-10-batches-py
rm data/cifar-10-python.tar.gz
rm -r data/cifar-10-batches-py/
```
Note:
You'll need to modify a few values in `config/model_config.py` in the next step, since CIFAR-10's images are small.
```python
CROP_IMAGE_SIZES: tuple[int, int] = (32, 32)  # Center crop
RESIZE_IMAGE_SIZES: tuple[int, int] = (32, 32)  # All images will be resized to this size
...
CHANNELS: list[int] = field(default_factory=lambda: [3, 16, 32, 16])
SIZES: list[int | tuple[int, int]] = field(default_factory=lambda: [3, 3, 3])  # Kernel sizes
STRIDES: list[int | tuple[int, int]] = field(default_factory=lambda: [2, 2, 2])
PADDINGS: list[int | tuple[int, int]] = field(default_factory=lambda: [1, 1, 1])
BLOCKS: list[int] = field(default_factory=lambda: [1, 2, 1])
```
Imagenette instructions
The commands below will download, extract and format the Imagenette dataset into the ./data/imagenette2 folder.
```bash
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz -P data
tar -xvf data/imagenette2.tgz -C data
python utils/preprocess_imagenette.py data/imagenette2
rm data/imagenette2.tgz
```
In the config folder of this repo you will find two config template files. Copy them and remove the "_template" part, like this:
```bash
cp config/data_config_template.py config/data_config.py
cp config/model_config_template.py config/model_config.py
```
`data_config.py` contains the config for recording TensorBoard logs and checkpoints. You probably just want to modify `_training_name`.
`model_config.py` contains the parameters that influence training. Most default values should work reasonably well, but you'll need to modify a few:
- `MAX_EPOCHS`: usually around 400 epochs is enough; you will need to train at least once to get a feel for your particular dataset.
- `IMG_MEAN` and `IMG_STD`: the defaults are the ImageNet ones. You can keep them as long as they are not too different from your dataset's actual values (especially if using a pretrained model).
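If you want to check how far your dataset is from the ImageNet statistics, per-channel mean and std are easy to estimate from a sample of pixels. A minimal pure-Python sketch (the `channel_stats` helper is illustrative, not part of this repo), for RGB values already scaled to [0, 1]:

```python
import statistics

def channel_stats(pixels: list[tuple[float, float, float]]) -> tuple[list[float], list[float]]:
    """Per-channel mean and (population) std over a sample of RGB pixels
    in [0, 1]. Illustrative helper, not part of this repository."""
    means = [statistics.fmean(p[c] for p in pixels) for c in range(3)]
    stds = [statistics.pstdev(p[c] for p in pixels) for c in range(3)]
    return means, stds
```

Compare the result against the ImageNet defaults, roughly (0.485, 0.456, 0.406) for the mean and (0.229, 0.224, 0.225) for the std.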
Imagenette example
The default committed config should give a decent result (~85% validation accuracy).
Cifar-10 example
If training on CIFAR-10, you'll need to modify the model in the config `src/classfication/configs/train_config.py` since CIFAR-10's images are small. You'll also need to remove or modify the resize hardcoded in `src/classfication/train.py`.
```python
MODEL: ModelHelper = ModelHelper.SmallDarknet
CHANNELS: list[int] = field(default_factory=lambda: [3, 16, 32, 16])
SIZES: list[int | tuple[int, int]] = field(default_factory=lambda: [3, 3, 3])  # Kernel sizes
STRIDES: list[int | tuple[int, int]] = field(default_factory=lambda: [2, 2, 2])
PADDINGS: list[int | tuple[int, int]] = field(default_factory=lambda: [1, 1, 1])
BLOCKS: list[int] = field(default_factory=lambda: [1, 2, 1])
```
Once you have the environment set up and your two config files ready, training a model is straightforward.
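As a sanity check that the CIFAR-10 strides above suit 32×32 inputs, the standard convolution output-size formula shows three stride-2, kernel-3, padding-1 layers shrinking 32×32 down to 4×4. A back-of-the-envelope sketch (not code from this repo):

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Standard convolution output-size formula (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32
for k, s, p in zip([3, 3, 3], [2, 2, 2], [1, 1, 1]):  # SIZES, STRIDES, PADDINGS
    size = conv_out(size, k, s, p)
print(size)  # 4: the spatial size reaching the classifier head
```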
```bash
classification-train \
    --train_data_path <path to train dataset> \
    --val_data_path <path to val dataset> \
    --classes_names_path <path to classes.names file>
```
Imagenette example
```bash
classification-train \
    --train_data_path data/imagenette2/train/ \
    --val_data_path data/imagenette2/val/ \
    --classes_names_path data/imagenette2/classes.names
```
The resulting checkpoints can be found in `CHECKPOINTS_DIR` (see the `RecordConfig`).
The TensorBoard logs can be found in `TB_DIR` (see the `RecordConfig`).
```bash
classification-test \
    checkpoints/imagenette_resnet32/train_50.pt \
    data/imagenette2/val \
    --classes_names_path data/imagenette2/classes.names \
    --limit 100
```
```bash
classification-gradcam \
    checkpoints/imagenette_resnet32/train_50.pt \
    data/imagenette2/val \
    --classes_names_path data/imagenette2/classes.names \
    --limit 10
```
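For context, the core of Grad-CAM is a gradient-weighted combination of the last convolutional layer's activation maps, followed by a ReLU. A framework-agnostic sketch of that math (not this repo's implementation, which operates on the actual model):

```python
def grad_cam(activations, gradients):
    """Grad-CAM heatmap from the last conv layer's activations and the
    gradients of the target class score w.r.t. them, both given as nested
    lists shaped [C][H][W]. Illustrative sketch, not this repo's code."""
    C, H, W = len(activations), len(activations[0]), len(activations[0][0])
    # Global-average-pool the gradients to get one weight per channel
    weights = [sum(sum(row) for row in gradients[c]) / (H * W) for c in range(C)]
    # Weighted sum over channels, then ReLU to keep positive evidence only
    return [[max(0.0, sum(weights[c] * activations[c][i][j] for c in range(C)))
             for j in range(W)] for i in range(H)]
```

In practice the heatmap is then normalized and upsampled to the input image size for display.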