# Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils import data
import torchvision.models as models

import io
import tqdm
import lance
import wandb

from PIL import Image
from tqdm import tqdm

import time
import warnings
warnings.simplefilter('ignore')

### Defining the Image Classes, Transformation function and other utilities

We define the image classes that come with the `cinic-10` dataset and the transformation functions that need to be applied to the images.

In [5]:
# Define the image classes
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Values shared by every split, defined once so the pipelines cannot drift apart.
# NOTE(review): the mean/std look like the widely used CIFAR-10 statistics rather
# than values computed on CINIC-10 — confirm before relying on them.
image_size = (32, 32)
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: resize, random horizontal flip as augmentation, then
# convert to a tensor and normalise each channel.
transform_train = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipelines: identical to training minus the random flip.
# Resize runs BEFORE ToTensor, exactly as in training, so evaluation images go
# through the same PIL-based resize and the pipeline does not depend on
# tensor support in Resize.
transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

transform_val = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

In [6]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA is available" if device.type == "cuda" else "CUDA is not available, using CPU instead")

CUDA is available


# Custom Image Dataset Class

We are going to use a custom Dataset class to load the images from the `cinic-10` Image Lance dataset. To learn more about how we created the Lance image dataset, refer to the `convert-any-image-dataset-to-lance.py` script in the `converters` folder.


Along with the table, we pass the tuple of class names and the transformation function that needs to be applied to the images.

To keep the CNN's input shape the same for every image, we convert any image that is not already in `RGB` mode to `RGB`, so that all images share the same 3-channel color space.

In [7]:
# Define the custom dataset class
class CustomImageDataset(data.Dataset):
    """Map-style dataset that decodes images held as encoded bytes in a table.

    Every sample is read from the ``"image"`` and ``"label"`` columns of
    ``table``. The image bytes are opened with PIL, forced into RGB mode,
    optionally passed through ``transform``, and returned together with the
    position of the label inside ``classes``.

    Args:
        table: Table-like object supporting ``len()`` and ``table[column][row]``
            lookups whose cells expose ``.as_py()``.
        classes: Sequence of class names; a label's index in it is the target.
        transform: Optional callable applied to the decoded PIL image.
    """

    def __init__(self, table, classes, transform=None):
        self.table, self.classes, self.transform = table, classes, transform

    def __len__(self):
        return len(self.table)

    def __getitem__(self, idx):
        encoded = self.table["image"][idx].as_py()
        class_name = self.table["label"][idx].as_py()

        picture = Image.open(io.BytesIO(encoded))

        # Anything that is not already RGB (grayscale, for example) is
        # converted so every sample has three channels.
        picture = picture.convert('RGB') if picture.mode != 'RGB' else picture

        if self.transform:
            picture = self.transform(picture)

        # Raises ValueError when the label is not one of the known classes.
        return picture, self.classes.index(class_name)

# Model hyperparameters and Architecture

In [8]:
# Optimiser settings (handed to optim.SGD further down).
lr, momentum = 1e-3, 0.9

# Training schedule.
number_of_epochs = 50
model_batch_size = 64    # samples per mini-batch for every DataLoader
batches_to_train = 256   # upper bound on mini-batches consumed per epoch

# Locations of the Lance datasets, one per split.
train_dataset_path = "cinic/cinic_train.lance"
validation_dataset_path = "cinic/cinic_val.lance"
test_dataset_path = "cinic/cinic_test.lance"

### Using a pre-trained `ResNet-34` architecture

We are going to fine-tune a pre-trained `ResNet-34`: its final fully connected layer is replaced with a new one that has one output per image class.

In [None]:
class Net(nn.Module):
    """Pre-trained ResNet-34 whose final fully connected layer is resized.

    The backbone is loaded with pre-trained weights and its ``fc`` layer is
    replaced by a fresh ``nn.Linear`` producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new classification layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Newer torchvision releases deprecate `pretrained=True` in favour of
        # the weights enum. IMAGENET1K_V1 is the checkpoint `pretrained=True`
        # maps to, so the loaded weights stay the same. Older releases without
        # the enum keep using the legacy flag.
        if hasattr(models, "ResNet34_Weights"):
            self.resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        else:
            self.resnet = models.resnet34(pretrained=True)

        # Swap the original head for one sized for our classes.
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the backbone's output (class scores) for the batch ``x``."""
        return self.resnet(x)

# Training Function

`train_model` is the standard training function that we are going to use to train our CNN model. We pass it the relevant dataloaders, the model, the loss function, the optimizer, the device, the number of epochs, and the number of batches to train on per epoch.

In [11]:
def train_model(train_loader, val_loader, model, criterion, optimizer, device, num_epochs, batch_to_train):
    """Train ``model`` for ``num_epochs`` epochs, then measure validation accuracy.

    Each epoch consumes at most ``batch_to_train`` mini-batches from
    ``train_loader``. After every epoch the average loss and the epoch duration
    are printed and logged to wandb. Once all epochs are done the model is
    switched to evaluation mode and its accuracy on ``val_loader`` is printed
    and logged.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` training batches.
        val_loader: Iterable yielding ``(inputs, labels)`` validation batches.
        model: The network to optimise.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to before the forward pass.
        num_epochs: Number of passes of the training loop.
        batch_to_train: Maximum number of mini-batches used per epoch.

    Returns:
        The trained model, left in evaluation mode.

    Raises:
        ValueError: If ``batch_to_train`` is not positive or ``train_loader``
            yields no batches.
    """
    if batch_to_train <= 0:
        raise ValueError(f"batch_to_train must be positive, got {batch_to_train}")

    # Size the progress bar by what will really be consumed: the loader may be
    # shorter than the cap, and some loaders do not define a length at all.
    try:
        steps_per_epoch = min(batch_to_train, len(train_loader))
    except TypeError:
        steps_per_epoch = batch_to_train

    total_start = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        batches_done = 0
        last_loss = None
        epoch_start = time.time()

        # zip() consults range() first, so once the cap is reached it stops
        # without pulling (and decoding) one more batch from the loader.
        capped_batches = zip(range(batch_to_train), train_loader)

        with tqdm(capped_batches, total=steps_per_epoch, desc=f"Epoch {epoch+1}") as pbar_epoch:
            for i, batch in pbar_epoch:
                inputs, labels = batch[0].to(device), batch[1].to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                last_loss = loss.item()
                running_loss += last_loss
                batches_done += 1

                # tqdm already advances the bar for every item it yields, so
                # only the postfix is refreshed here; an extra manual update()
                # would count the same batches twice.
                if i % 10 == 0:
                    pbar_epoch.set_postfix({'Loss': last_loss})

        if batches_done == 0:
            raise ValueError("train_loader yielded no batches")

        per_epoch_time = time.time() - epoch_start
        # Average over the batches actually processed, which can be fewer than
        # batch_to_train when the loader is short.
        avg_loss = running_loss / batches_done
        print(f'Epoch {epoch+1}/{num_epochs} | Avg Loss: {avg_loss:.4f} | Time: {per_epoch_time:.4f} sec')
        # A single log call keeps all of this epoch's metrics on the same wandb
        # step. "Loss" remains the last batch's loss; "Avg Loss" is the epoch mean.
        wandb.log({"Loss": last_loss, "Avg Loss": avg_loss, "Epoch Duration": per_epoch_time})

    total_training_time = (time.time() - total_start) / 60
    print(f"Total Training Time: {total_training_time:.4f} mins")

    # Validation
    model.eval()
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for batch in val_loader:
            images_val, labels_val = batch[0].to(device), batch[1].to(device)
            predicted_val = model(images_val).argmax(dim=1)
            total_val += labels_val.size(0)
            correct_val += (predicted_val == labels_val).sum().item()

    # An empty validation loader yields NaN instead of crashing after training.
    val_accuracy = 100 * correct_val / total_val if total_val else float("nan")
    print(f'Validation Accuracy: {val_accuracy:.2f}%')
    wandb.log({"Validation Accuracy": val_accuracy})
    print('Finished Training')
    return model

In [12]:
# Open the three Lance datasets.
train_ds = lance.dataset(train_dataset_path)
test_ds = lance.dataset(test_dataset_path)
val_ds = lance.dataset(validation_dataset_path)

# Materialise each split as a table; CustomImageDataset looks rows up by
# column name ("image" / "label") and row index.
train_ds_table = train_ds.to_table()
test_ds_table = test_ds.to_table()
val_ds_table = val_ds.to_table()

# Wrap every table with the transform pipeline of its split.
train_dataset = CustomImageDataset(train_ds_table, classes, transform=transform_train)
test_dataset = CustomImageDataset(test_ds_table, classes, transform=transform_test)
val_dataset = CustomImageDataset(val_ds_table, classes, transform=transform_val)

# Only the training data is shuffled. Accuracy does not depend on sample
# order, so the evaluation loaders keep a fixed, reproducible order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=model_batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=model_batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=model_batch_size, shuffle=False)

In [15]:
# Start the Weights & Biases run that train_model logs its metrics to.
wandb.init(project="cinic")

# One output unit per class; the network is moved to the selected device.
net = Net(len(classes))
net = net.to(device)

# Cross-entropy loss optimised with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

trained_model = train_model(
    train_loader=train_loader,
    val_loader=val_loader,
    model=net,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=number_of_epochs,
    batch_to_train=batches_to_train,
)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 152MB/s]
Epoch 1: 100%|██████████| 256/256 [00:17<00:00, 14.87it/s, Loss=1.42]


Epoch 1/50 | Avg Loss: 1.6279 | Time: 17.2249 sec


Epoch 2: 100%|██████████| 256/256 [00:17<00:00, 14.70it/s, Loss=1.12]


Epoch 2/50 | Avg Loss: 1.2631 | Time: 17.4203 sec


Epoch 3: 100%|██████████| 256/256 [00:16<00:00, 15.91it/s, Loss=0.899]


Epoch 3/50 | Avg Loss: 1.1635 | Time: 16.0954 sec


Epoch 4: 100%|██████████| 256/256 [00:15<00:00, 16.45it/s, Loss=1.15]


Epoch 4/50 | Avg Loss: 1.0783 | Time: 15.5664 sec


Epoch 5: 100%|██████████| 256/256 [00:16<00:00, 15.94it/s, Loss=1.1]


Epoch 5/50 | Avg Loss: 1.0439 | Time: 16.0744 sec


Epoch 6: 100%|██████████| 256/256 [00:15<00:00, 16.41it/s, Loss=0.916]


Epoch 6/50 | Avg Loss: 0.9950 | Time: 15.6104 sec


Epoch 7: 100%|██████████| 256/256 [00:16<00:00, 15.94it/s, Loss=1.13]


Epoch 7/50 | Avg Loss: 0.9595 | Time: 16.0680 sec


Epoch 8: 100%|██████████| 256/256 [00:16<00:00, 15.52it/s, Loss=0.907]


Epoch 8/50 | Avg Loss: 0.9356 | Time: 16.5057 sec


Epoch 9: 100%|██████████| 256/256 [00:15<00:00, 16.29it/s, Loss=0.842]


Epoch 9/50 | Avg Loss: 0.9075 | Time: 15.7251 sec


Epoch 10: 100%|██████████| 256/256 [00:15<00:00, 16.56it/s, Loss=0.767]


Epoch 10/50 | Avg Loss: 0.8924 | Time: 15.4662 sec


Epoch 11: 100%|██████████| 256/256 [00:15<00:00, 16.72it/s, Loss=0.916]


Epoch 11/50 | Avg Loss: 0.8635 | Time: 15.3165 sec


Epoch 12: 100%|██████████| 256/256 [00:15<00:00, 16.52it/s, Loss=0.988]


Epoch 12/50 | Avg Loss: 0.8510 | Time: 15.5075 sec


Epoch 13: 100%|██████████| 256/256 [00:15<00:00, 16.29it/s, Loss=0.735]


Epoch 13/50 | Avg Loss: 0.8338 | Time: 15.7251 sec


Epoch 14: 100%|██████████| 256/256 [00:16<00:00, 15.82it/s, Loss=0.813]


Epoch 14/50 | Avg Loss: 0.8213 | Time: 16.1966 sec


Epoch 15: 100%|██████████| 256/256 [00:15<00:00, 16.58it/s, Loss=0.686]


Epoch 15/50 | Avg Loss: 0.7982 | Time: 15.4514 sec


Epoch 16: 100%|██████████| 256/256 [00:15<00:00, 16.52it/s, Loss=0.826]


Epoch 16/50 | Avg Loss: 0.7822 | Time: 15.5051 sec


Epoch 17: 100%|██████████| 256/256 [00:15<00:00, 16.68it/s, Loss=0.852]


Epoch 17/50 | Avg Loss: 0.7720 | Time: 15.3532 sec


Epoch 18: 100%|██████████| 256/256 [00:15<00:00, 16.41it/s, Loss=0.839]


Epoch 18/50 | Avg Loss: 0.7562 | Time: 15.6068 sec


Epoch 19: 100%|██████████| 256/256 [00:16<00:00, 15.99it/s, Loss=0.687]


Epoch 19/50 | Avg Loss: 0.7459 | Time: 16.0246 sec


Epoch 20: 100%|██████████| 256/256 [00:16<00:00, 15.91it/s, Loss=0.841]


Epoch 20/50 | Avg Loss: 0.7116 | Time: 16.1007 sec


Epoch 21: 100%|██████████| 256/256 [00:15<00:00, 16.49it/s, Loss=0.697]


Epoch 21/50 | Avg Loss: 0.7126 | Time: 15.5290 sec


Epoch 22: 100%|██████████| 256/256 [00:15<00:00, 16.40it/s, Loss=0.727]


Epoch 22/50 | Avg Loss: 0.7104 | Time: 15.6198 sec


Epoch 23: 100%|██████████| 256/256 [00:15<00:00, 16.45it/s, Loss=0.63]


Epoch 23/50 | Avg Loss: 0.6885 | Time: 15.5733 sec


Epoch 24: 100%|██████████| 256/256 [00:15<00:00, 16.42it/s, Loss=0.568]


Epoch 24/50 | Avg Loss: 0.6788 | Time: 15.5968 sec


Epoch 25: 100%|██████████| 256/256 [00:16<00:00, 15.92it/s, Loss=0.576]


Epoch 25/50 | Avg Loss: 0.6698 | Time: 16.0884 sec


Epoch 26: 100%|██████████| 256/256 [00:15<00:00, 16.11it/s, Loss=0.622]


Epoch 26/50 | Avg Loss: 0.6736 | Time: 15.8965 sec


Epoch 27: 100%|██████████| 256/256 [00:15<00:00, 16.48it/s, Loss=0.531]


Epoch 27/50 | Avg Loss: 0.6551 | Time: 15.5472 sec


Epoch 28: 100%|██████████| 256/256 [00:15<00:00, 16.67it/s, Loss=0.599]


Epoch 28/50 | Avg Loss: 0.6349 | Time: 15.3674 sec


Epoch 29: 100%|██████████| 256/256 [00:15<00:00, 16.46it/s, Loss=0.781]


Epoch 29/50 | Avg Loss: 0.6415 | Time: 15.5601 sec


Epoch 30: 100%|██████████| 256/256 [00:15<00:00, 16.45it/s, Loss=0.478]


Epoch 30/50 | Avg Loss: 0.6220 | Time: 15.5733 sec


Epoch 31: 100%|██████████| 256/256 [00:16<00:00, 15.97it/s, Loss=0.568]


Epoch 31/50 | Avg Loss: 0.6186 | Time: 16.0451 sec


Epoch 32: 100%|██████████| 256/256 [00:15<00:00, 16.14it/s, Loss=0.589]


Epoch 32/50 | Avg Loss: 0.6073 | Time: 15.8692 sec


Epoch 33: 100%|██████████| 256/256 [00:15<00:00, 16.20it/s, Loss=0.473]


Epoch 33/50 | Avg Loss: 0.5920 | Time: 15.8131 sec


Epoch 34: 100%|██████████| 256/256 [00:15<00:00, 16.33it/s, Loss=0.493]


Epoch 34/50 | Avg Loss: 0.5873 | Time: 15.6820 sec


Epoch 35: 100%|██████████| 256/256 [00:15<00:00, 16.53it/s, Loss=0.805]


Epoch 35/50 | Avg Loss: 0.5717 | Time: 15.4927 sec


Epoch 36: 100%|██████████| 256/256 [00:15<00:00, 16.08it/s, Loss=0.507]


Epoch 36/50 | Avg Loss: 0.5633 | Time: 15.9360 sec


Epoch 37: 100%|██████████| 256/256 [00:16<00:00, 15.94it/s, Loss=0.534]


Epoch 37/50 | Avg Loss: 0.5508 | Time: 16.0752 sec


Epoch 38: 100%|██████████| 256/256 [00:15<00:00, 16.52it/s, Loss=0.635]


Epoch 38/50 | Avg Loss: 0.5446 | Time: 15.5040 sec


Epoch 39: 100%|██████████| 256/256 [00:15<00:00, 16.18it/s, Loss=0.654]


Epoch 39/50 | Avg Loss: 0.5367 | Time: 15.8333 sec


Epoch 40: 100%|██████████| 256/256 [00:15<00:00, 16.43it/s, Loss=0.707]


Epoch 40/50 | Avg Loss: 0.5360 | Time: 15.5939 sec


Epoch 41: 100%|██████████| 256/256 [00:15<00:00, 16.06it/s, Loss=0.538]


Epoch 41/50 | Avg Loss: 0.5348 | Time: 15.9457 sec


Epoch 42: 100%|██████████| 256/256 [00:16<00:00, 15.51it/s, Loss=0.449]


Epoch 42/50 | Avg Loss: 0.5194 | Time: 16.5133 sec


Epoch 43: 100%|██████████| 256/256 [00:15<00:00, 16.36it/s, Loss=0.491]


Epoch 43/50 | Avg Loss: 0.5121 | Time: 15.6591 sec


Epoch 44: 100%|██████████| 256/256 [00:15<00:00, 16.32it/s, Loss=0.495]


Epoch 44/50 | Avg Loss: 0.5129 | Time: 15.7013 sec


Epoch 45: 100%|██████████| 256/256 [00:15<00:00, 16.43it/s, Loss=0.419]


Epoch 45/50 | Avg Loss: 0.4898 | Time: 15.5894 sec


Epoch 46: 100%|██████████| 256/256 [00:15<00:00, 16.43it/s, Loss=0.748]


Epoch 46/50 | Avg Loss: 0.4845 | Time: 15.5904 sec


Epoch 47: 100%|██████████| 256/256 [00:16<00:00, 15.88it/s, Loss=0.506]


Epoch 47/50 | Avg Loss: 0.4817 | Time: 16.1347 sec


Epoch 48: 100%|██████████| 256/256 [00:16<00:00, 15.71it/s, Loss=0.284]


Epoch 48/50 | Avg Loss: 0.4673 | Time: 16.3104 sec


Epoch 49: 100%|██████████| 256/256 [00:16<00:00, 15.81it/s, Loss=0.588]


Epoch 49/50 | Avg Loss: 0.4621 | Time: 16.2003 sec


Epoch 50: 100%|██████████| 256/256 [00:15<00:00, 16.13it/s, Loss=0.682]


Epoch 50/50 | Avg Loss: 0.4576 | Time: 15.8811 sec
Total Training Time: 13.2076 mins
Validation Accuracy: 70.49%
Finished Training


In [17]:
# Save the weights next to the Lance datasets. The path is relative, like the
# dataset paths, so the notebook does not depend on a `/cinic` directory at the
# filesystem root; the `cinic/` folder already exists because the datasets
# were loaded from it.
PATH = 'cinic/cinic_resnet.pth'
torch.save(trained_model.state_dict(), PATH)

In [18]:
def test_model(test_loader, model, device):
    model.eval()
    correct_test = 0
    total_test = 0

    with torch.no_grad():
        for data in test_loader:
            images_test, labels_test = data[0].to(device), data[1].to(device)
            outputs_test = model(images_test)
            _, predicted_test = torch.max(outputs_test.data, 1)
            total_test += labels_test.size(0)
            correct_test += (predicted_test == labels_test).sum().item()

    test_accuracy = 100 * correct_test / total_test
    print(f'Test Accuracy: {test_accuracy:.2f}%')

# Score the trained network on the held-out test split.
test_model(test_loader=test_loader, model=trained_model, device=device)

Test Accuracy: 70.32%
