# Training a ResNet Model on Custom Waste Data

## Introduction

In this tutorial, we will walk through the process of training a ResNet (Residual Network) model on a custom dataset of waste images. Residual Networks, introduced by He et al. in the paper "Deep Residual Learning for Image Recognition," have proven to be highly effective in various image classification tasks.

Our goal is to build a robust image classification model capable of distinguishing between 11 types of waste.

## Objectives

By the end of this tutorial, you will be able to:
1. Set up a ResNet model using the PyTorch Lightning framework.
2. Train the ResNet model on the waste dataset.
3. Evaluate the model's performance and visualize the results.
4. Save the trained model for future inference and deployment.

## Prerequisites

To follow along with this tutorial, you should have:

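- An environment set up with Python, Jupyter Notebook, and the relevant deep learning libraries (PyTorch, torchvision, and PyTorch Lightning).
- The scripts train_resnet.py and test_resnet.py, which provide the helper functions imported in this notebook.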

## Dataset

For this tutorial, we will use a custom dataset of waste images, categorized into different types such as plastic, paper, metal, glass, and organic waste. Ensure the dataset is organized into separate folders for each class, as this structure will facilitate the training process.

Here is an example structure of the dataset:



waste_data/
├── plastic/
│ ├── plastic_1.jpg
│ ├── plastic_2.jpg
│ └── ...
├── paper/
│ ├── paper_1.jpg
│ ├── paper_2.jpg
│ └── ...
├── metal/
│ ├── metal_1.jpg
│ ├── metal_2.jpg
│ └── ...
├── glass/
│ ├── glass_1.jpg
│ ├── glass_2.jpg
│ └── ...
└── biowaste/
  ├── biowaste_1.jpg
  ├── biowaste_2.jpg
  └── ...
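
Before training, it can help to confirm that the folder layout above is what you expect. The following is a minimal sketch, not part of train_resnet.py: the function name count_images_per_class and the list of accepted extensions are assumptions made for illustration. It counts the image files under each class folder, including nested sub-folders.

```python
from pathlib import Path

# Extensions treated as images; an assumption, adjust to match your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Return {class_name: image_count} for each sub-folder directly under root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # ignore stray files next to the class folders
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts


# Usage (hypothetical path):
# for name, n in count_images_per_class("waste_data").items():
#     print(f"{name}: {n} images")
```

A class with very few images, or with zero, usually points to a download or path problem that is cheaper to catch here than after a training run.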

In [2]:
!aws s3 sync s3://notebooks-wastetide/data/resnet/ data/

download: s3://data-10-classes/biowaste/bmasks/biowaste-0-1.png to ResNet/data/biowaste/bmasks/biowaste-0-1.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-11-1.png to ResNet/data/biowaste/bmasks/biowaste-11-1.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-10-0.png to ResNet/data/biowaste/bmasks/biowaste-10-0.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-1-1.png to ResNet/data/biowaste/bmasks/biowaste-1-1.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-12-1.png to ResNet/data/biowaste/bmasks/biowaste-12-1.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-1-0.png to ResNet/data/biowaste/bmasks/biowaste-1-0.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-10-1.png to ResNet/data/biowaste/bmasks/biowaste-10-1.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-0-0.png to ResNet/data/biowaste/bmasks/biowaste-0-0.png
download: s3://data-10-classes/biowaste/bmasks/biowaste-13-0.png to ResNet/data/biowaste/bmasks/

## Outline

The tutorial will be structured as follows:
1. **Data Preparation**: Load and preprocess the custom waste dataset.
2. **Model Setup**: Initialize a ResNet model.
3. **Training the Model**: Train the ResNet model on the prepared dataset.
4. **Evaluation**: Assess the model's performance using various metrics and visualization techniques.
5. **Saving the Model**: Save the trained model for future use.

Let's get started!

In [None]:
%pip install -r requirements.txt
%pip install -r requirements-resnet.txt

## 1. Data Preparation

The first step is to load and preprocess our custom waste images and annotations. 
We need to ensure that the images are properly formatted and resized to match the input requirements of the ResNet model.

### Steps:
- Load the dataset from the specified directory.
- Split the data into training, validation, and test sets.
- Define the image transformations applied during training.

The first two steps are handled by make_train_valid_test_data.
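
The internals of make_train_valid_test_data are not shown in this notebook. As a rough illustration of the splitting step only, here is a minimal sketch that shuffles sample indices with a fixed seed and cuts them into three groups. The function name split_indices and the default fractions are assumptions, and the real helper may split differently (for example, per class).

```python
import random


def split_indices(n_items, val_fraction=0.15, test_fraction=0.15, seed=0):
    """Shuffle range(n_items) and cut it into train/val/test index lists."""
    if not 0 <= val_fraction + test_fraction < 1:
        raise ValueError("val_fraction + test_fraction must be in [0, 1)")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # fixed seed gives a reproducible split
    n_val = int(n_items * val_fraction)
    n_test = int(n_items * test_fraction)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Fixing the seed matters because the test set should stay the same across runs; otherwise scores from different runs are not comparable.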



In [42]:
from train_resnet import make_train_valid_test_data
DATASET_PATH = 'data3/'

train_data,val_data,test_data = make_train_valid_test_data(DATASET_PATH,size=(620,620))

there is 1371 images for training
there is 381 images for testing


In [43]:
from train_resnet import make_loader



train_loader = make_loader(train_data, batch_size=8, shuffle=True)
val_loader = make_loader(val_data, batch_size=8)
test_loader = make_loader(test_data, batch_size=8)

In [44]:
import torchvision
from torchvision import transforms
train_dataset = torchvision.datasets.ImageFolder(DATASET_PATH)
LABELS=train_dataset.classes

## 2 & 3. Model Setup and Training

In this section, we set up the ResNet model and train it using these two functions:

from train_resnet import create_model,train

We use a pre-trained ResNet model and fine-tune it on our dataset.

### Steps:
- Import the necessary libraries.
- Load a pre-trained ResNet model with create_model, which customizes the final layers to match num_class.
- Train the model with train(...).



In [7]:
from train_resnet import create_model, train

# Class names come from the folder names found by ImageFolder.
LABELS = train_dataset.classes
num_class = len(LABELS)

# Architectures accepted by create_model; index 1 selects the ResNeXt variant.
models = ['Resnet', 'Resnext']
model = create_model(num_class, model_type=models[1])

trainer = train(train_loader, val_loader, model, max_epochs=25, limit_train_batches=100, default_root_dir='./logs/resnext')
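
The numbers used above interact in a way that is easy to miss. The short calculation below assumes that train(...) forwards limit_train_batches to the PyTorch Lightning Trainer, where an integer value caps the number of batches per epoch, and that the loader keeps the last partial batch. Neither assumption is confirmed by this notebook.

```python
import math

n_train_images = 1371       # reported by make_train_valid_test_data above
batch_size = 8              # passed to make_loader
limit_train_batches = 100   # passed to train(...)
max_epochs = 25

# Batches a full pass over the training set would need.
batches_available = math.ceil(n_train_images / batch_size)
# Batches actually run per epoch once the cap is applied.
batches_per_epoch = min(batches_available, limit_train_batches)
images_per_epoch = min(batches_per_epoch * batch_size, n_train_images)
total_steps = batches_per_epoch * max_epochs

print(batches_available)   # 172
print(batches_per_epoch)   # 100
print(images_per_epoch)    # 800
print(total_steps)         # 2500
```

Under these assumptions each epoch sees 800 of the 1371 training images. Because the training loader shuffles, a different subset is drawn each epoch, so the whole set is still covered over the course of training.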

## 4. Evaluation

After training the model, it is important to evaluate its performance to check that it generalizes to new, unseen data. Here, the training and test data come from the same internet-based source, so a good test score does not guarantee good performance on real-world images. Keep this in mind when interpreting the results.

### Steps:
- Evaluate the model on the test set to get an overall accuracy score.
- Generate classification reports and confusion matrices to understand the model's performance on each class (TODO).
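
To make the confusion-matrix step concrete, here is a minimal pure-Python sketch. It is not the print_confusion_matrix helper from test_resnet, whose behaviour is not shown in this notebook. The function names and the toy labels are assumptions for illustration; in practice y_true and y_pred would be integer class indices collected over the test set.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """matrix[t][p] = number of samples with true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_recall(matrix):
    """Fraction of each true class predicted correctly (None if the class is absent)."""
    recalls = []
    for i, row in enumerate(matrix):
        total = sum(row)
        recalls.append(row[i] / total if total else None)
    return recalls


# Toy example with 3 classes.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)                    # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_recall(cm))  # [0.5, 1.0, 0.5]
```

Rows are true classes and columns are predictions, so off-diagonal entries show which pairs of waste types the model confuses. A single overall accuracy score hides this.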


In [None]:
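from test_resnet import test, print_confusion_matrix

# ckpt_path is specific to one training run; replace it with a checkpoint from your own logs directory.
test(trainer, model, test_loader, ckpt_path='logs/resnet/lightning_logs/version_1/checkpoints/epoch=10-val_loss=0.42358389496803284.ckpt')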
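
The checkpoint filename above encodes the epoch and the validation loss. Assuming your own checkpoints follow the same naming pattern (inferred from that one path, not from the training script), a small helper can pick the file with the lowest val_loss instead of hard-coding it. The name best_checkpoint is hypothetical.

```python
import re
from pathlib import Path

# Matches names like "epoch=10-val_loss=0.4235.ckpt" (pattern inferred from the path above).
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_loss=(\d+(?:\.\d+)?)\.ckpt$")


def best_checkpoint(checkpoint_dir):
    """Return the checkpoint path with the lowest val_loss in its filename, or None."""
    best_path, best_loss = None, float("inf")
    for path in Path(checkpoint_dir).glob("*.ckpt"):
        match = CKPT_PATTERN.search(path.name)
        if match is None:
            continue  # skip files whose name carries no loss value
        loss = float(match.group(2))
        if loss < best_loss:
            best_path, best_loss = path, loss
    return best_path


# Usage (hypothetical directory):
# ckpt = best_checkpoint("logs/resnext/lightning_logs/version_0/checkpoints")
```

The helper only reads filenames, so it works without loading any checkpoint, but it silently ignores files that do not match the pattern.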

In [None]:
# TODO: generate classification reports and confusion matrices to understand the model's performance on each class.




## 5. Saving the Model

To use the trained model in future applications, we need to save it. This will allow us to load the model later without having to retrain it from scratch.



In [None]:
import datetime

import torch

# Get the current date
current_date = datetime.datetime.now().strftime("%d_%m")

model_name = 'ResNet'

# Define the path where the model will be saved, including the date
model_save_path = f'path/to/save/{model_name}_{current_date}.pth'

# Save the model
torch.save(model.state_dict(), model_save_path)

print(f"Model successfully saved to {model_save_path}")
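
One practical detail: torch.save does not create missing parent directories, so the save directory must exist before the call. The sketch below builds a dated path and creates the folder first. The helper name dated_model_path is an assumption for illustration. Its year-first date format differs from the "%d_%m" format used above, and was chosen so that files sort chronologically.

```python
import datetime
from pathlib import Path


def dated_model_path(save_dir, model_name, when=None):
    """Build <save_dir>/<model_name>_<YYYY-MM-DD>.pth, creating save_dir if needed."""
    when = when or datetime.date.today()
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    return save_dir / f"{model_name}_{when:%Y-%m-%d}.pth"


# Usage (hypothetical directory):
# model_save_path = dated_model_path("saved_models", "ResNet")
# torch.save(model.state_dict(), model_save_path)
```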