# Example of the `aitlas` toolbox in the context of multi-class image classification

This notebook shows a sample implementation of multi-class image classification using the `aitlas` toolbox and the CLRS dataset.

In [None]:
from aitlas.datasets import CLRSDataset
from aitlas.models import ResNet50
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor
from aitlas.utils import image_loader

## Load the dataset

In [None]:
# Configuration for browsing the data: the CLRS data directory and the
# CSV file of the train split.
dataset_config = dict(
    data_dir="./data/CLRS",
    csv_file="./data/CLRS/train.csv",
)
dataset = CLRSDataset(dataset_config)

## Show images from the dataset

In [None]:
# Preview two single images and then one batch of images.
# NOTE(review): 1000 and 80 are presumably sample indices and 15 the number
# of images in the batch -- confirm against the aitlas dataset API.
fig1, fig2 = (dataset.show_image(index) for index in (1000, 80))
fig3 = dataset.show_batch(15)

## Inspect the data

In [None]:
# Last expression of the cell, so its return value is rendered as the cell
# output (presumably a preview of the dataset's entries -- confirm).
dataset.show_samples()

In [None]:
# Last expression of the cell, so its return value is rendered as the cell
# output (presumably the number of samples per class -- confirm).
dataset.data_distribution_table()

In [None]:
# Keep the returned object in `fig`; presumably a bar-chart figure of the
# class distribution -- confirm against the aitlas dataset API.
fig = dataset.data_distribution_barchart()

## Load train and test splits

In [None]:
# Training split: shuffled batches of 16, loaded with 4 workers.
train_dataset_config = dict(
    batch_size=16,
    shuffle=True,
    num_workers=4,
    data_dir="./data/CLRS",
    csv_file="./data/CLRS/train.csv",
)
train_dataset = CLRSDataset(train_dataset_config)
# NOTE(review): the training transform is attached to the dataset after
# construction, while the test transform below is supplied through the
# "transforms" config key -- consider unifying the two approaches.
train_dataset.transform = ResizeCenterCropFlipHVToTensor()

# Test split: unshuffled batches of 4, with the transform named in the config.
test_dataset_config = dict(
    batch_size=4,
    shuffle=False,
    num_workers=4,
    data_dir="./data/CLRS",
    csv_file="./data/CLRS/test.csv",
    transforms=["aitlas.transforms.ResizeCenterCropToTensor"],
)
test_dataset = CLRSDataset(test_dataset_config)

# Final expression of the cell: shows the size of both splits.
(len(train_dataset), len(test_dataset))

## Setup and create the model for training

In [None]:
# ResNet50 configured for 25 output classes, reporting accuracy,
# precision, recall and F1 score.
model_config = dict(
    num_classes=25,
    learning_rate=0.0001,
    pretrained=True,
    metrics=["accuracy", "precision", "recall", "f1_score"],
)
model = ResNet50(model_config)
model.prepare()

# Settings consumed by the training cell.
epochs = 10
model_directory = "./experiments/CLRS"

## Training and evaluation

In [None]:
# Train on the training split; the test split is passed as the validation
# dataset. Outputs are written under `model_directory` for run "1".
model.train_and_evaluate_model(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    epochs=epochs,
    model_directory=model_directory,
    run_id="1",
)

## Predictions

In [None]:
# Restore the trained weights from the checkpoint before predicting.
model_path = "./experiments/CLRS/checkpoint.pth.tar"
model.load_model(model_path)

# Class names used to label the predictions.
# NOTE(review): the names are hard-coded here; an earlier draft of this cell
# took them from `CLRSDataset.labels` -- confirm which source is preferred.
labels = [
    "airport", "bare-land", "beach", "bridge", "commercial",
    "desert", "farmland", "forest", "golf-course", "highway",
    "industrial", "meadow", "mountain", "overpass", "park",
    "parking", "playground", "port", "railway", "railway-station",
    "residential", "river", "runway", "stadium", "storage-tank",
]
transform = ResizeCenterCropToTensor()

# Load and classify the sample images one at a time. `map` is lazy, so each
# image is read immediately before its prediction; `image` and `fig` end up
# bound to the last image and its prediction result.
for image in map(
    image_loader,
    (
        "./data/predict/image1.tif",
        "./data/predict/image2.tif",
        "./data/predict/image3.tif",
        "./data/predict/image4.tif",
    ),
):
    fig = model.predict_image(image, labels, transform)