# **Image Classification with ConvNeXt**

Written By: Tharnarch Thoranisttakul

This project is for the internship program in Japan at Kanazawa University.

The topic is as stated above. First, we will look at classifying images with ConvNeXt. Then, if time permits, we will look at classifying images with Transformers.

Right now, we will only focus on **ConvNeXt**.

## **Data Exploration**

The task for this lab is to use **CIFAR-100** as the dataset for image classification.

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List
from tqdm.notebook import tqdm
import glob
import os

%matplotlib inline

# Import scikit-learn libraries and packages
from sklearn.preprocessing import LabelEncoder

# Import PyTorch libraries and packages
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset, Dataset
import torchvision
from torchvision import transforms
from torchvision.models import resnet50, ConvNeXt, ConvNeXt_Base_Weights
from torchvision.ops import StochasticDepth

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Extract dataset
# Uncomment the code below to extract the dataset
# !tar xvzf data/cifar-100-python.tar.gz -C data

In [None]:
import pickle

def unpickle(file):
    # Load a pickled CIFAR file; the keys come back as bytes (e.g. b"data")
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return data

In [None]:
# Load the data
train = unpickle('data/cifar-100-python/train')
test = unpickle('data/cifar-100-python/test')

In [None]:
print(f'Train Keys: {train.keys()}\nTest Keys: {test.keys()}')
print('-'*20)
print(f'Train Data Shape: {train[b"data"].shape}\nTest Data Shape: {test[b"data"].shape}')
print('-'*20)
print(f'Train Data Type: {type(train[b"data"])}\nTest Data Type: {type(test[b"data"])}')

### **Dataset Description**

According to the [official website](https://www.cs.toronto.edu/~kriz/cifar.html), the CIFAR-100 dataset contains 100 classes. Each class contains 600 images: 500 training images and 100 testing images.

As with CIFAR-10, each image is stored as a row of 3072 values (32 x 32 pixels x 3 color channels), so the full set of 60,000 images would form a 60000x3072 numpy array. With 500 training images and 100 testing images in each class, this is split into a 50000x3072 training array and a 10000x3072 testing array, which matches the shapes printed above.
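
The arithmetic behind these shapes can be checked with a few lines. This is only a sanity-check sketch; the variable names are illustrative and are not used elsewhere in the notebook.

```python
# Dataset facts taken from the description above
num_classes = 100
train_per_class, test_per_class = 500, 100
height, width, channels = 32, 32, 3

# Each image is flattened into one row of height * width * channels values
values_per_image = height * width * channels
expected_train_shape = (num_classes * train_per_class, values_per_image)
expected_test_shape = (num_classes * test_per_class, values_per_image)

print(expected_train_shape, expected_test_shape)
```

These tuples can be compared directly against `train[b"data"].shape` and `test[b"data"].shape`.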

In [None]:
print(f'Train Coarse Labels: {train[b"coarse_labels"]}\nTest Coarse Labels: {test[b"coarse_labels"]}')
print('-'*20)
print(f'Train Fine Labels: {train[b"fine_labels"]}\nTest Fine Labels: {test[b"fine_labels"]}')
print('-'*20)
print(f'Train Batch Label: {train[b"batch_label"]}\nTest Batch Label: {test[b"batch_label"]}')

The fine labels are the 100 classes, while the coarse labels are the 20 superclasses listed below.

|**Superclass No.**|**Superclass**|**Classes**|
|:-:|:-:|:-:|
|1|aquatic mammals|beaver, dolphin, otter, seal, whale|
|2|fish|aquarium fish, flatfish, ray, shark, trout|
|3|flowers|orchids, poppies, roses, sunflowers, tulips|
|4|food containers|bottles, bowls, cans, cups, plates|
|5|fruit and vegetables|apples, mushrooms, oranges, pears, sweet peppers|
|6|household electrical devices|clock, computer keyboard, lamp, telephone, television|
|7|household furniture|bed, chair, couch, table, wardrobe|
|8|insects|bee, beetle, butterfly, caterpillar, cockroach|
|9|large carnivores|bear, leopard, lion, tiger, wolf|
|10|large man-made outdoor things|bridge, castle, house, road, skyscraper|
|11|large natural outdoor scenes|cloud, forest, mountain, plain, sea|
|12|large omnivores and herbivores|camel, cattle, chimpanzee, elephant, kangaroo|
|13|medium-sized mammals|fox, porcupine, possum, raccoon, skunk|
|14|non-insect invertebrates|crab, lobster, snail, spider, worm|
|15|people|baby, boy, girl, man, woman|
|16|reptiles|crocodile, dinosaur, lizard, snake, turtle|
|17|small mammals|hamster, mouse, rabbit, shrew, squirrel|
|18|trees|maple, oak, palm, pine, willow|
|19|vehicles 1|bicycle, bus, motorcycle, pickup truck, train|
|20|vehicles 2|lawn-mower, rocket, streetcar, tank, tractor|
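
The grouping shown in the table can also be recovered from the label arrays themselves, since every image carries both a fine and a coarse label. The sketch below uses a hypothetical helper, `group_fine_by_coarse`, and a handful of toy label values standing in for `train[b"fine_labels"]` and `train[b"coarse_labels"]`.

```python
from collections import defaultdict

def group_fine_by_coarse(fine_labels, coarse_labels):
    """Map each coarse (superclass) id to the sorted fine (class) ids seen with it."""
    groups = defaultdict(set)
    for fine, coarse in zip(fine_labels, coarse_labels):
        groups[coarse].add(fine)
    return {coarse: sorted(fines) for coarse, fines in sorted(groups.items())}

# Toy values for illustration only; pass the real label lists in practice
toy_fine = [4, 30, 55, 4, 1, 32]
toy_coarse = [0, 0, 0, 0, 1, 1]

toy_groups = group_fine_by_coarse(toy_fine, toy_coarse)
print(toy_groups)
```

On the real data, each of the 20 superclasses should map to exactly five class ids.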

The labels are kept in the meta file. Therefore, we will extract the label names (fine and coarse) from the meta file directly.

In [None]:
# Load the labels
label = unpickle('data/cifar-100-python/meta')
print(f'Label Keys: {label.keys()}')
print('-'*20)
print(f'Label Fine Label Names: {label[b"fine_label_names"]}')
print('-'*20)
print(f'Label Coarse Label Names: {label[b"coarse_label_names"]}')

label_fine = label[b"fine_label_names"]
label_coarse = label[b"coarse_label_names"]

# Decode labels
label_fine = [x.decode('utf-8') for x in label_fine]
label_coarse = [x.decode('utf-8') for x in label_coarse]

label2id = {label_fine[i]: i for i in range(len(label_fine))}
id2label = {i: label_fine[i] for i in range(len(label_fine))}

In [None]:
train_data = train[b"data"]
train_coarse_labels = train[b"coarse_labels"]
train_fine_labels = train[b"fine_labels"]
train_batch_label = train[b"batch_label"]

test_data = test[b"data"]
test_coarse_labels = test[b"coarse_labels"]
test_fine_labels = test[b"fine_labels"]
test_batch_label = test[b"batch_label"]

In [None]:
train_data[0], train_coarse_labels[0], train_fine_labels[0], train_batch_label, test_data[0], test_coarse_labels[0], test_fine_labels[0], test_batch_label

In [None]:
print(f'Train data min: {train_data.min()}\nTrain data max: {train_data.max()}')
print(f'Test data min: {test_data.min()}\nTest data max: {test_data.max()}')
print('-'*20)
print(f'Train data mean: {train_data.mean()}\nTrain data std: {train_data.std()}')
print(f'Test data mean: {test_data.mean()}\nTest data std: {test_data.std()}')

## **Data Preprocessing**

Since the pixel values lie in the range [0, 255], we will normalize them to [0, 1].

In [None]:
# Reshape the data
train_data = train_data.reshape(50000, 32, 32, 3)
test_data = test_data.reshape(10000, 32, 32, 3)

print(f'Train Data Shape: {train_data.shape}\nTest Data Shape: {test_data.shape}')

In [None]:
# Normalize the data
train_data, test_data = train_data/255.0, test_data/255.0

In [None]:
def plot_images(data, datalabel, fulllabel):
    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(data[i])
        plt.xlabel(fulllabel[datalabel[i]])

In [None]:
plot_images(train_data, train_fine_labels, label_fine)

It seems that we reshaped the data incorrectly. We will reshape it again by separating the color channels (R, G, B) and restacking them.

In [None]:
train_data = train[b'data']
test_data = test[b'data']

train_data_R = train_data[:, :1024].reshape(50000, 32, 32)
train_data_G = train_data[:, 1024:2048].reshape(50000, 32, 32)
train_data_B = train_data[:, 2048:].reshape(50000, 32, 32)

test_data_R = test_data[:, :1024].reshape(10000, 32, 32)
test_data_G = test_data[:, 1024:2048].reshape(10000, 32, 32)
test_data_B = test_data[:, 2048:].reshape(10000, 32, 32)

# Stack the data
train_data = np.stack((train_data_R, train_data_G, train_data_B), axis=3)
test_data = np.stack((test_data_R, test_data_G, test_data_B), axis=3)

In [None]:
plot_images(train_data, train_fine_labels, label_fine)

In [None]:
print(f'Train Data Shape: {train_data.shape}\nTest Data Shape: {test_data.shape}')

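We have to separate the channels like this instead of reshaping directly because each image is stored as a flattened row in channel-major order: the first 1024 values are the red channel, the next 1024 the green channel, and the last 1024 the blue channel, each in row-major order. Reshaping straight to (32, 32, 3) would instead treat every three consecutive values as the R, G, B of a single pixel.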

Credit to https://www.kaggle.com/code/yipengzhou3/cifar100-pytorch, which takes the same 'wrong' approach to reshaping the data.
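
The split-and-stack reshaping can also be written as a single reshape followed by a transpose. The sketch below checks that the two agree on synthetic data; `rows_to_images` is a hypothetical helper, and the random array only stands in for `train[b'data']`.

```python
import numpy as np

def rows_to_images(flat):
    """Convert (N, 3072) channel-major rows into (N, 32, 32, 3) images."""
    # (N, 3072) -> (N, channel, row, col) -> (N, row, col, channel)
    return flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# Synthetic stand-in for the raw data: 4 rows of 3072 uint8 values
rng = np.random.default_rng(0)
flat = rng.integers(0, 256, size=(4, 3072), dtype=np.uint8)

# Reference result: split the three channels and stack them on the last axis
red = flat[:, :1024].reshape(-1, 32, 32)
green = flat[:, 1024:2048].reshape(-1, 32, 32)
blue = flat[:, 2048:].reshape(-1, 32, 32)
stacked = np.stack((red, green, blue), axis=3)

images = rows_to_images(flat)
naive = flat.reshape(-1, 32, 32, 3)  # the direct reshape that scrambles the images

print(images.shape)
print(np.array_equal(images, stacked), np.array_equal(images, naive))
```

The transpose returns a view rather than a copy, so it avoids the three intermediate channel arrays.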

In [None]:
X_train, y_train = train_data, train_fine_labels
X_test, y_test = test_data, test_fine_labels

num_classes = 100
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train)
y_test_encoded = label_encoder.transform(y_test)

In [None]:
# Convert images to tensors; ToTensor also rescales uint8 pixels from [0, 255] to [0.0, 1.0]
transform_data = transforms.Compose([transforms.ToTensor()])

In [None]:
# Convert every image to a tensor (no augmentation is applied yet)
augmented_train_data = []
augmented_train_label = []
for image, label in zip(X_train, y_train_encoded):
    augmented_train_data.append(transform_data(image))
    augmented_train_label.append(label)

augmented_test_data = []
augmented_test_label = []
for image, label in zip(X_test, y_test_encoded):
    augmented_test_data.append(transform_data(image))
    augmented_test_label.append(label)

augmented_train_data = torch.stack(augmented_train_data)
augmented_test_data = torch.stack(augmented_test_data)

print(f'Augmented Train Data Shape: {augmented_train_data.shape}\nAugmented Test Data Shape: {augmented_test_data.shape}')

In [None]:
# Create the dataset
train_dataset = TensorDataset(augmented_train_data, torch.Tensor(augmented_train_label))
test_dataset = TensorDataset(augmented_test_data, torch.Tensor(augmented_test_label))

# Create the dataloader
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
# Training
def train_model(models, modelnames, criterion, optimizer, scheduler, num_epochs):

    # Store the loss and accuracy
    traintestData = {}

    ### Train the model ###
    for model, modelname in zip(models, modelnames):
        train_loss = []
        test_acc = []
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for _, batch in enumerate(train_loader):
                images = batch[0].to(device)
                labels = batch[1].to(device).long()

                optimizer.zero_grad()
                
                outputs = model(images)

                # Convert outputs to tensor
                if isinstance(outputs, ImageClassifierOutputWithNoAttention):
                    outputs = outputs.logits

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * images.size(0)

            scheduler.step()

            ### Test the model ###

            model.eval()
            num_correct = 0
            total = 0
            with torch.no_grad():
                for _, batch in enumerate(test_loader):
                    images = batch[0].to(device)
                    labels = batch[1].to(device).long()

                    outputs = model(images)

                    if isinstance(outputs, ImageClassifierOutputWithNoAttention):
                        outputs = outputs.logits
                    
                    _, predicted = torch.max(outputs, dim=1)
                    total += labels.size(0)
                    num_correct += (predicted == labels).sum().item()

            epoch_loss = running_loss/len(train_loader.dataset)

            # Save the loss and accuracy for train and test set
            train_loss.append(epoch_loss)
            test_acc.append(num_correct/total)

            print(f'ConvNeXt {modelname} -> Epoch: {epoch+1}/{num_epochs} | Loss: {epoch_loss} | Test Accuracy: {num_correct/total}')

        traintestData[f'Training Loss_{modelname}'] = train_loss
        traintestData[f'Test Accuracy_{modelname}'] = test_acc

    return traintestData

In [None]:
def lossAccPlot(lossAccDict, modelnames, cutoff=0):
    names = ['Training Loss', 'Test Accuracy']
    label = ['Loss', 'Accuracy']

    # One row per model so that every model is plotted; squeeze=False keeps axes 2-D
    _, axes = plt.subplots(len(modelnames), 2, figsize=(20, 5*len(modelnames)), squeeze=False)
    for row, modelname in enumerate(modelnames):
        for i in range(len(names)):
            val = lossAccDict[f'{names[i]}_{modelname}']
            axes[row][i].plot(val, label=f'{names[i]} ({modelname})')
            if cutoff != 0:
                # Plot cutoff line
                axes[row][i].axvline(x=cutoff, color='r', linestyle='--')
            axes[row][i].set_xlabel('Epoch')
            axes[row][i].set_ylabel(label[i])
            axes[row][i].legend()

## **ConvNeXt**

In [None]:
from transformers import AutoModelForImageClassification

convnext_tiny = AutoModelForImageClassification.from_pretrained("facebook/convnext-tiny-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_tiny_wd = AutoModelForImageClassification.from_pretrained("facebook/convnext-tiny-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_base = AutoModelForImageClassification.from_pretrained("facebook/convnext-base-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_base_wd = AutoModelForImageClassification.from_pretrained("facebook/convnext-base-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_large = AutoModelForImageClassification.from_pretrained("facebook/convnext-large-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_large_wd = AutoModelForImageClassification.from_pretrained("facebook/convnext-large-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_tiny.to(device)
convnext_tiny_wd.to(device)
convnext_base.to(device)
convnext_base_wd.to(device)
convnext_large.to(device)
convnext_large_wd.to(device)

models = [convnext_tiny, convnext_base, convnext_large]
modelnames = ['Tiny', 'Base', 'Large']
models_wd = [convnext_tiny_wd, convnext_base_wd, convnext_large_wd]
modelnames_wd = ['TinyWD', 'BaseWD', 'LargeWD']

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.AdamW([{'params': model.parameters()} for model in models], lr=5e-5)
optimizer_wd = torch.optim.AdamW([{'params': model.parameters()} for model in models_wd], lr=5e-5, weight_decay=1e-8)

# Learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

In [None]:
noWDDict = train_model(models, modelnames, criterion, optimizer, lr_scheduler, num_epochs=30)

In [None]:
lossAccPlot(noWDDict, modelnames)

In [None]:
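lr_scheduler_wd = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_wd, T_max=10)  # the scheduler must wrap optimizer_wd, not optimizer
WDDict = train_model(models_wd, modelnames_wd, criterion, optimizer_wd, lr_scheduler_wd, num_epochs=30)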

In [None]:
lossAccPlot(WDDict, modelnames_wd)

In [None]:
# Get last and max test accuracy
def print_accuracy_summary(results, modelnames):
    for modelname in modelnames:
        acc = results[f'Test Accuracy_{modelname}']
        best = max(acc)
        # +1 so the epoch matches the 1-based numbering printed during training
        print(f'ConvNeXt {modelname} -> Last Test Accuracy: {acc[-1]} | Max Test Accuracy: {best} (Epoch:{acc.index(best)+1})')

print_accuracy_summary(noWDDict, modelnames)

print('-'*50)

print_accuracy_summary(WDDict, modelnames_wd)

From the results above, applying weight decay to the optimizer improved the test accuracy by 1%. We will revisit the preferable weight decay value and learning rate after we try augmenting the data.

We will stick with a learning rate of 5e-5 and a weight decay of 1e-8, which come from the fine-tuning settings section of the ConvNeXt paper.
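
As a point of comparison, here is a minimal sketch of an alternative setup in which every model gets its own AdamW optimizer and cosine schedule instead of sharing one optimizer across all models. This is not what the cells in this notebook do. The helper name `make_optim_and_sched`, the choice `T_max=num_epochs`, and the small `nn.Linear` stand-in models are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

def make_optim_and_sched(model, num_epochs, lr=5e-5, weight_decay=1e-8):
    """Build one AdamW optimizer and one cosine schedule for a single model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Assumption: let one cosine half-period span this model's whole training run
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    return optimizer, scheduler

# Tiny stand-in models so the sketch runs without downloading any weights
toy_models = {'A': nn.Linear(8, 4), 'B': nn.Linear(8, 4)}
num_epochs = 5

lr_history = {}
for name, model in toy_models.items():
    optimizer, scheduler = make_optim_and_sched(model, num_epochs)
    lrs = []
    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]['lr'])
        optimizer.step()   # stands in for one full epoch of training
        scheduler.step()   # advance the schedule once per epoch
    lr_history[name] = lrs

print(lr_history)
```

With this arrangement each model starts at the full learning rate and follows the same schedule, regardless of the order in which the models are trained.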

In [None]:
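def save_model(models, modelnames):
    # Take the names as an argument so the file names match the models being saved
    os.makedirs('models', exist_ok=True)
    for model, modelname in zip(models, modelnames):
        torch.save(model.state_dict(), f'models/ConvNeXt_{modelname}.pt')

In [None]:
save_model(models_wd, modelnames_wd)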

del models, models_wd
torch.cuda.empty_cache()

## **Data Augmentation**

In [None]:
X_train, y_train = train_data, train_fine_labels
X_test, y_test = test_data, test_fine_labels

num_classes = 100
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train)
y_test_encoded = label_encoder.transform(y_test)

In [None]:
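# Augmentation pipeline: convert to a tensor first, because RandomCrop and
# RandomHorizontalFlip accept PIL images or tensors but not numpy arrays
transform_data = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip()
    ])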

In [None]:
# Apply the augmentation
augmented_train_data = []
augmented_train_label = []
for image, label in zip(X_train, y_train_encoded):
    augmented_train_data.append(transform_data(image))
    augmented_train_label.append(label)

# Keep the test set unaugmented: only convert it to tensors
to_tensor = transforms.ToTensor()
augmented_test_data = []
augmented_test_label = []
for image, label in zip(X_test, y_test_encoded):
    augmented_test_data.append(to_tensor(image))
    augmented_test_label.append(label)

augmented_train_data = torch.stack(augmented_train_data)
augmented_test_data = torch.stack(augmented_test_data)

print(f'Augmented Train Data Shape: {augmented_train_data.shape}\nAugmented Test Data Shape: {augmented_test_data.shape}')

In [None]:
# Create the dataset
train_dataset = TensorDataset(augmented_train_data, torch.Tensor(augmented_train_label))
test_dataset = TensorDataset(augmented_test_data, torch.Tensor(augmented_test_label))

# Create the dataloader
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
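
Note that the cells above apply the random transforms once, before training, so every epoch sees the same augmented copy of each image. A common alternative is to apply the transform inside a `Dataset`, so that each epoch draws a fresh random variant. The sketch below is one way to do that under stated assumptions: `OnTheFlyDataset` and `random_hflip` are hypothetical names, and the random tensors stand in for the real images.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class OnTheFlyDataset(Dataset):
    """Apply `transform` each time an item is fetched instead of once up front."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

def random_hflip(image, p=0.5):
    """Stand-in transform: flip a (C, H, W) tensor left-right with probability p."""
    if torch.rand(1).item() < p:
        return torch.flip(image, dims=[-1])
    return image

# Random tensors standing in for the real training images and labels
toy_images = torch.rand(8, 3, 32, 32)
toy_labels = torch.arange(8)

toy_dataset = OnTheFlyDataset(toy_images, toy_labels, transform=random_hflip)
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)

batch_images, batch_labels = next(iter(toy_loader))
print(batch_images.shape, batch_labels.shape)
```

Any callable can be passed as `transform`, including a `transforms.Compose` pipeline, as long as it accepts the item type stored in `images`.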

In [None]:
from transformers import AutoModelForImageClassification

convnext_tiny = AutoModelForImageClassification.from_pretrained("facebook/convnext-tiny-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_base = AutoModelForImageClassification.from_pretrained("facebook/convnext-base-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)

convnext_large = AutoModelForImageClassification.from_pretrained("facebook/convnext-large-224",
                                                        num_labels=len(label_fine),
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes=True)


convnext_tiny.to(device)
convnext_base.to(device)
convnext_large.to(device)

models = [convnext_tiny, convnext_base, convnext_large]
modelnames = ['Tiny', 'Base', 'Large']

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.AdamW([{'params': model.parameters()} for model in models], lr=5e-5, weight_decay=1e-8)

# Learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

In [None]:
# Train the model
augmentDataModel = train_model(models, modelnames, criterion, optimizer, lr_scheduler, num_epochs=30)

In [None]:
lossAccPlot(augmentDataModel, modelnames)