<a href="https://colab.research.google.com/github/gapac/ML_AI_examples/blob/main/Ga%C5%A1per_Jezernik_bird_or_forest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bird or Forest Classification Using ResNet-18 Model

#### Main task: Classify images into two categories: **birds** and **forests**.

1. Download Images - a set of images belonging to two categories: birds and forests.
2. Use Torchvision Transforms and PyTorch DataLoader - efficient data loading and augmentation
3. Load ResNet-18 Model - pre-trained on the ImageNet dataset.
4. Alter ResNet-18 Model Architecture - to fit our specific task
5. Fine-Tune ResNet-18 Model - training some model layers on our dataset

### ResNet (Residual Network) Models

- Since the first CNN-based models, image classification networks have grown steadily deeper, using ever more layers.

- Adding more layers is usually beneficial, but beyond a certain point it triggers a common deep learning problem known as vanishing gradients.

- As the gradient values are passed through many layers, they can get smaller and smaller, essentially "vanishing" to zero. This makes it very difficult to update the weights in the earlier layers of the network, leading to very slow or stagnant training progress, particularly in networks with many layers.

- Residual Connection: To address the vanishing gradient problem, this architecture introduces **skip connections**. A skip connection passes the activations of a layer directly to a later layer, skipping the layers in between. These connections make it easier for gradients to flow back through the network during training, helping even the early layers get updated effectively. This results in faster convergence and often better overall performance.

<img src="https://github.com/Pubec/ml-workshop/blob/main/resnet/assets/skip_connection.png?raw=true" alt="Plain vs ResNet" style="max-height: 300px;">

<img src="https://github.com/Pubec/ml-workshop/blob/main/resnet/assets/normal_resnet.png?raw=true" alt="Plain vs ResNet" style="max-height: 300px;">


- This residual connection enables ResNet networks to be significantly deeper (e.g., 18, 34, 50, 101, or 152 layers) without the accuracy degradation that plain networks of similar depth suffer.
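The skip connection described above can be sketched as a small PyTorch module. This is a simplified block in the spirit of ResNet-18's basic block (two 3x3 convolutions with batch normalization), assuming the input and output have the same number of channels so the identity can be added directly. The class name `SimpleResidualBlock` is hypothetical; it is not torchvision's implementation.

```python
import torch
from torch import nn


class SimpleResidualBlock(nn.Module):
    """Hypothetical residual block: output = ReLU(F(x) + x)."""

    def __init__(self, channels: int):
        super().__init__()
        # F(x): two 3x3 convolutions that preserve the spatial size and channel count
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x  # the skip connection carries x unchanged
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # adding x gives gradients a direct path back to earlier layers
        return self.relu(out + identity)


block = SimpleResidualBlock(channels=64)
x = torch.randn(2, 64, 56, 56)
y = block(x)
print(y.shape)  # torch.Size([2, 64, 56, 56])
```

Because the block computes `F(x) + x`, the addition contributes an identity term to the gradient with respect to `x` (before the final ReLU), which is why gradients can still reach early layers when the gradients through `F` are small.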

## Step 0: Download dataset

We will use the following libraries:
- **fast.ai** - image download and verification [https://github.com/fastai/fastai](https://github.com/fastai/fastai)
- **duckduckgo_search** - web search for images [https://github.com/deedy5/duckduckgo_search](https://github.com/deedy5/duckduckgo_search)

In [None]:
!pip install -q duckduckgo_search fastai

In [None]:
from time import sleep
from pathlib import Path

from fastai.vision.all import download_images, resize_images, verify_images, get_image_files
from duckduckgo_search import DDGS

DATASET_PATH = 'bird_or_not'
DO_DOWNLOAD = True

In [None]:
def search_images(term, max_images=5):
    """
    Search the DuckDuckGo engine for images matching `term` and return their URLs (at most `max_images`).
    """
    print(f"Searching for '{term}'")
    urls = []
    with DDGS() as ddgs:
        # iterate over the results instead of indexing them: indexing fails for generators
        # and raises IndexError when there are fewer results than requested
        for image in ddgs.images(term, max_results=max_images):
            # get image url (skip results without one)
            url = image.get('image')
            if not url:
                continue
            # if the url contains arguments, remove them
            urls.append(url.split("?")[0])
            if len(urls) >= max_images:
                break
    return urls

In [None]:
urls = search_images('bird photos', max_images=5)
for url in urls:
    print(url)

In [None]:
urls = search_images('forest photos', max_images=5)
for url in urls:
    print(url)

### Next, download images into train, validation, and test splits

1. search for different images of forest and birds
2. save images to different folders
3. resize images to max_size

*NOTE: Pause between searches to avoid overloading the server, which can make the searches fail*

*NOTE: Because GPU sessions are time-limited, run this part on a CPU runtime, then set `DO_DOWNLOAD = False` so the entire notebook can be rerun without downloading again*

In [None]:
if DO_DOWNLOAD:
    searches = ['forest', 'bird']
    # Pause between searches to avoid over-loading server
    TIME_SLEEP = 0.5

    # split -> (number of images per query, query templates)
    splits = {
        "train": (25, ['{} photo', '{} sun photo', '{} shade photo']),
        "val": (5, ['{} photography', '{} sun photography', '{} shade photography']),
        "test": (10, ['{} image']),
    }

    for split, (nr_images, templates) in splits.items():
        save_dir = Path(DATASET_PATH) / split
        for search_term in searches:
            dest = save_dir / search_term
            dest.mkdir(exist_ok=True, parents=True)

            # one search + download per query template
            for template in templates:
                term = template.format(search_term)
                download_images(dest, urls=search_images(term, max_images=nr_images))
                print("Downloaded", term)
                sleep(TIME_SLEEP)

            # resize, writing the results back into `dest`
            resize_images(dest, max_size=400, dest=dest)


In [None]:
# Verify images (check whether each file can be opened) and delete those that cannot

path = Path(DATASET_PATH)
failed = verify_images(get_image_files(path))
for fail in failed:
    print("could not open:", fail)
failed.map(Path.unlink)
print("Failed images:", len(failed))

## Step 1: Load and Augment Dataset

We will use the following library:
- **Torchvision** - to transform or augment data [https://pytorch.org/vision/stable/transforms.html](https://pytorch.org/vision/stable/transforms.html)

### Image Folder

```bash
- train/bird/xxx.png
- train/bird/yyy.png
- train/forest/xxx.png
- train/forest/yyy.png
```
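`ImageFolder` derives the labels from this layout: each subdirectory of the root is one class, and class indices are assigned in sorted order of the folder names, so `bird` gets 0 and `forest` gets 1. The pure-Python sketch below mimics that mapping on a temporary directory; `find_class_mapping` is a hypothetical helper, not part of torchvision.

```python
import tempfile
from pathlib import Path


def find_class_mapping(root):
    """Hypothetical helper: map each sorted subfolder name of `root` to an integer label."""
    classes = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as tmp:
    # build the layout shown above (empty placeholder files are enough here)
    for cls in ["forest", "bird"]:
        (Path(tmp) / cls).mkdir()
        (Path(tmp) / cls / "xxx.png").touch()
    classes, class_to_idx = find_class_mapping(tmp)

print(classes)       # ['bird', 'forest']
print(class_to_idx)  # {'bird': 0, 'forest': 1}
```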

In [None]:
from pathlib import Path
from torchvision.datasets import ImageFolder

train_path = Path(DATASET_PATH) / "train"

train_folder = ImageFolder(train_path)
print(train_folder)

In [None]:
print("Classes:", train_folder.classes)
print("Class to index:", train_folder.class_to_idx)

In [None]:
for img in train_folder.imgs[69:80]:
    print(img)

In [None]:
from PIL import Image
from matplotlib import pyplot as plt

img_path, img_class = train_folder.imgs[12]

image = Image.open(img_path)

plt.suptitle(train_folder.classes[img_class])
plt.imshow(image)
plt.axis('off')
plt.show()

### Transforms

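- preprocessing (resizing, cropping, conversion to tensors, normalization)
- augmentation (random crops, flips, rotations, color changes, ...)
- chained together with `transforms.Compose`
- most transforms accept a PIL Image, a Tensor Image, or a batch of Tensor Images as input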

Note:
- Tensor Image Shape: `C x H x W`
- Tensor Batch Shape: `B x C x H x W`
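PIL images (and NumPy arrays read from them) are laid out as `H x W x C`, while PyTorch tensors use `C x H x W`. For 8-bit RGB images, `transforms.ToTensor()` performs this transpose and scales pixel values from [0, 255] to [0.0, 1.0]. The NumPy sketch below reproduces those two steps on random stand-in images and then stacks them into a `B x C x H x W` batch, the shape a `DataLoader` yields; `to_chw` is a hypothetical helper for illustration only.

```python
import numpy as np


def to_chw(img_hwc: np.ndarray) -> np.ndarray:
    """Hypothetical helper: mimic ToTensor for uint8 images (H x W x C -> C x H x W, scaled to [0, 1])."""
    return np.transpose(img_hwc, (2, 0, 1)).astype(np.float32) / 255.0


# two fake 224 x 224 RGB images in PIL/NumPy layout
images = [np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8) for _ in range(2)]

tensors = [to_chw(img) for img in images]
batch = np.stack(tensors)  # stack along a new leading batch axis

print(tensors[0].shape)  # (3, 224, 224)
print(batch.shape)       # (2, 3, 224, 224)
```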

In [None]:
from torchvision.transforms import transforms

# TODO: Try Flip, Rotate, Padding, Crop, RandomResizedCrop, RandomPerspective, ColorJitter


transform = transforms.Compose([
    # transforms.Resize(256),
    transforms.RandomResizedCrop(112),
    # transforms.CenterCrop(224),
    # transforms.Pad(padding=50),
    # transforms.RandomRotation(120),
    # transforms.RandomVerticalFlip(0.5),
    # transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.3, hue=0.4),
    # transforms.ColorJitter(contrast=(0, 10)),
    # transforms.RandomSolarize(threshold=5.0)
])

transformed_image = transform(image)


fig, ax = plt.subplots(1, 2)
ax[0].imshow(image)
ax[1].imshow(transformed_image)
plt.show()

In [None]:
# Define the transformations to apply to the images during training and validation

# Note: Normalize subtracts a per-channel mean and divides by a per-channel standard deviation,
# giving each channel roughly zero mean and unit variance. Ideally these statistics come from your own dataset,
# but the pretrained ResNet-18 was trained on ImageNet images normalized this way, so we reuse them.
# These numbers are the ImageNet mean and standard deviation!

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
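`transforms.Normalize(mean, std)` standardizes each channel independently: `output[c] = (input[c] - mean[c]) / std[c]`. The NumPy sketch below uses a hypothetical random batch to show how per-channel statistics of your own dataset could be computed (over the batch, height, and width axes of a `B x C x H x W` array), and that normalizing with them gives roughly zero mean and unit standard deviation per channel.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical dataset: 8 RGB images, already scaled to [0, 1] like ToTensor output
batch = rng.uniform(0.0, 1.0, size=(8, 3, 32, 32)).astype(np.float32)

# per-channel statistics over batch, height and width (axes 0, 2, 3)
mean = batch.mean(axis=(0, 2, 3))
std = batch.std(axis=(0, 2, 3))

# same formula as transforms.Normalize: (x - mean[c]) / std[c]
normalized = (batch - mean.reshape(1, 3, 1, 1)) / std.reshape(1, 3, 1, 1)

print(mean.round(3), std.round(3))
print(normalized.mean(axis=(0, 2, 3)).round(3))  # close to [0, 0, 0]
print(normalized.std(axis=(0, 2, 3)).round(3))   # close to [1, 1, 1]
```

For a real dataset that does not fit in memory, the same sums could be accumulated batch by batch over a `DataLoader` instead of stacking every image into one array.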

1. Pass the transforms to `ImageFolder`
2. Wrap each `ImageFolder` in a `DataLoader`

In [None]:
from torch.utils.data import DataLoader

# TODO:
# train_dataset = ...
# val_dataset = ...


In [None]:
import numpy as np

# TODO:
# check train_dataset loader outputs
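A sketch for the check: take one batch with `inputs, labels = next(iter(train_dataset))`, print `inputs.shape` (expected `B x 3 x 224 x 224` with the transforms above) and `labels`, and display an image. To display it, the image must be un-normalized and moved back to `H x W x C` for `plt.imshow`. The NumPy helper below (hypothetical name `to_displayable`) does that and clips to [0, 1], since round-off or arbitrary inputs can leave values outside that range. It is demonstrated on a random stand-in array rather than a real batch.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def to_displayable(img_chw):
    """Hypothetical helper: undo Normalize and convert C x H x W to H x W x C in [0, 1]."""
    img = img_chw * IMAGENET_STD + IMAGENET_MEAN  # inverse of (x - mean) / std
    img = np.clip(img, 0.0, 1.0)
    return np.transpose(img, (1, 2, 0))


# stand-in for inputs[0].numpy() from a real batch
original = np.random.rand(3, 224, 224)
normalized = (original - IMAGENET_MEAN) / IMAGENET_STD

restored = to_displayable(normalized)
print(restored.shape)  # (224, 224, 3)
print(np.allclose(restored, np.transpose(original, (1, 2, 0))))  # True
```

In the notebook, `plt.imshow(to_displayable(inputs[0].numpy()))` would then show the first augmented training image.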


## Step 2: ResNet-18

In [None]:
import torchvision.models as models

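# ImageNet-pretrained weights (replaces the deprecated `pretrained=True` argument)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)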
print(model)

In [None]:
print("Number of output features:", model.fc.out_features)

Replace the final fully connected layer (`model.fc`) so that the model predicts our 2 classes instead of the 1000 ImageNet classes

In [None]:
import torch

# TODO:
# alter model.fc output layer


In [None]:
print("Number of output features:", model.fc.out_features)

## Step 3: Test ResNet on the 'bird/forest' test dataset

We test the model on the test dataset before fine-tuning. Because the new output layer is still randomly initialized, expect roughly chance-level predictions.

In [None]:
# Helper Scripts
import torchvision.transforms.functional as F

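def unnormalize(img):
    # invert Normalize (x * std + mean), then clamp to [0, 1] so to_pil_image receives valid values
    img = img * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = img + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    return img.clamp(0, 1)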


def display_images(images, true_labels, predicted_labels, class_names):
    fig = plt.figure(figsize=(10, 10))
    for i in range(len(images)):
        ax = fig.add_subplot(4, 4, i + 1, xticks=[], yticks=[])
        img = unnormalize(images[i])
        img = F.to_pil_image(img)
        ax.imshow(img)
        ax.set_title(f'True: {class_names[true_labels[i]]}\nPred: {class_names[predicted_labels[i]]}', color=("green" if true_labels[i] == predicted_labels[i] else "red"))
    plt.show()

In [None]:
# use validation transforms
test_path = Path(DATASET_PATH) / "test"
test_folder = ImageFolder(test_path, transform=val_transform)
test_dataset = DataLoader(test_folder, batch_size=4, shuffle=True)

model.eval()
with torch.no_grad():
    count = 0
    for inputs, classes in test_dataset:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        display_images(inputs, classes, preds, test_folder.classes)
        if count >= 2:
            break
        count += 1
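Looking at a few batches gives a visual impression; a single accuracy number over the whole test set is easier to compare before and after fine-tuning. The function below is a sketch (the name `evaluate_accuracy` is hypothetical) that counts correct `argmax` predictions. It is demonstrated on a toy linear model and hand-made data; in the notebook it could be called as `evaluate_accuracy(model, test_dataset)`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader, device="cpu"):
    """Hypothetical helper: fraction of samples whose argmax prediction matches the label."""
    model.to(device)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


# toy stand-in: identity weights, so the model predicts class 0 when x[0] > x[1]
toy_model = torch.nn.Linear(2, 2, bias=False)
with torch.no_grad():
    toy_model.weight.copy_(torch.eye(2))
x = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 4.0], [1.0, 0.5]])
y = torch.tensor([0, 1, 0, 1])  # the last sample will be misclassified
loader = DataLoader(TensorDataset(x, y), batch_size=2)

print(evaluate_accuracy(toy_model, loader))  # 0.75
```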

## Step 4: Train ResNet

1. Train only the last layer (FC)
    - larger learning rate
2. Fine-tune the entire ResNet
    - smaller learning rate

In [None]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """
    Train `model` for `num_epochs` epochs using the given training and validation dataloaders, loss function, optimizer and device; after each epoch, print the training and validation loss and accuracy.
    """
    # Train the model for the specified number of epochs
    for epoch in range(num_epochs):
        # Set the model to train mode
        model.train()

        # Initialize the running loss and accuracy
        running_loss = 0.0
        running_corrects = 0

        # Iterate over the batches of the train loader
        for inputs, labels in train_loader:
            # Move the inputs and labels to the device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the optimizer gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()

            # Update the running loss and accuracy
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        # Calculate the train loss and accuracy
        train_loss = running_loss / len(train_loader.dataset)
        train_acc = running_corrects.double() / len(train_loader.dataset)

        # Set the model to evaluation mode
        model.eval()

        # Initialize the running loss and accuracy
        running_loss = 0.0
        running_corrects = 0

        # Iterate over the batches of the validation loader
        with torch.no_grad():
            for inputs, labels in val_loader:
                # Move the inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Update the running loss and accuracy
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        # Calculate the validation loss and accuracy
        val_loss = running_loss / len(val_loader.dataset)
        val_acc = running_corrects.double() / len(val_loader.dataset)

        # Print the epoch results
        print(f'Epoch [{epoch+1}/{num_epochs}], train loss: {train_loss:.4f}, train acc: {train_acc:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
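`train` multiplies each batch's mean loss by `inputs.size(0)` before summing, then divides by the dataset size. This gives the true per-sample average even when the last batch is smaller than the others, whereas simply averaging the batch means would overweight that last batch. A small worked example with hypothetical per-sample losses:

```python
# hypothetical per-sample losses for 5 images, batched with batch_size=2
per_sample = [0.9, 0.7, 0.5, 0.3, 0.1]
batches = [per_sample[0:2], per_sample[2:4], per_sample[4:5]]  # last batch has 1 sample

batch_means = [sum(b) / len(b) for b in batches]  # what loss.item() returns per batch

# approach used in `train`: weight each batch mean by its size, divide by dataset size
weighted = sum(m * len(b) for m, b in zip(batch_means, batches)) / len(per_sample)

# naive alternative: average of batch means
naive = sum(batch_means) / len(batch_means)

print(round(weighted, 4))  # 0.5 (equals the true mean of per_sample)
print(round(naive, 4))     # 0.4333
```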



### 4.1: Fine-tune the last layer for a few epochs

Freeze all layers except FC

In [None]:
# TODO:
# freeze all model.parameters()
# unfreeze model.fc.parameters()
# check results


In [None]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model.to(device)
train(model, train_dataset, val_dataset, loss_function, optimizer, num_epochs=20, device=device)

### 4.2: Fine-tune the entire model for a few epochs

Unfreeze all the layers

In [None]:
for param in model.parameters():
    param.requires_grad = True

In [None]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
train(model, train_dataset, val_dataset, loss_function, optimizer, num_epochs=10, device=device)

## Step 5: Test ResNet on the 'bird/forest' test dataset

We test the model on the test dataset again, **after** fine-tuning

In [None]:
# test_dataset already applies the validation transforms
model.to('cpu')
model.eval()
with torch.no_grad():
    count = 0
    for inputs, classes in test_dataset:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        display_images(inputs, classes, preds, test_folder.classes)
        if count >= 20:
            break
        count += 1