# Imports

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils import data

import pandas as pd
from PIL import Image

import io
import lance
import wandb
import tqdm

import time
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation/user warnings coming from torch,
# torchvision, PIL, lance and wandb; consider narrowing it with a
# `category=` or `module=` filter once the noisy warning is identified.
warnings.simplefilter('ignore')

### Defining the Image Classes, Transformation function and other utilities

Here we define the image classes that come with the `cinic-10` dataset and the transformation function that is applied to every image.

In [None]:
# The ten CINIC-10 categories. A name's position in this tuple is the integer
# target the dataset returns for it.
classes = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Resize to 32x32, convert to a float tensor in [0, 1], then shift/scale every
# channel with mean 0.5 and std 0.5, which maps values into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Initialize W&B
wandb.init(project="cinic-10-test")

# Pick the compute device: Apple MPS when present, otherwise CUDA, otherwise CPU.
if not torch.backends.mps.is_available():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
else:
    device = torch.device("mps")
    print("MPS Device:", device)

# Custom Image Dataset Class

We use a custom `Dataset` class to load the images from the `cinic-10` Lance image dataset. To learn how the Lance image dataset was created, refer to the `convert-any-image-dataset-to-lance.py` script in the `converters` folder.


We also pass the dataset the list of class names and the transformation function to apply to each image.

To keep the CNN architecture the same for every kind of image, we convert any non-RGB image (for example, grayscale) to RGB. This way all inputs share the same 3-channel color space.

In [None]:
# Define the custom dataset class
class CustomImageDataset(data.Dataset):
    """Map-style dataset that decodes encoded images stored in a DataFrame.

    Each row of ``df`` must provide an ``'image'`` column with the encoded
    image bytes and a ``'label'`` column whose value appears in ``classes``.

    Args:
        df: DataFrame with ``'image'`` and ``'label'`` columns.
        classes: Ordered sequence of class names; a label's position in this
            sequence is the integer target returned by ``__getitem__``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, df, classes, transform=None):
        self.df = df
        self.classes = classes
        self.transform = transform
        # Build name -> index once instead of a linear ``classes.index`` scan
        # per sample. setdefault keeps the first occurrence, matching
        # ``index`` semantics if a name is ever duplicated.
        self._class_to_idx = {}
        for position, name in enumerate(classes):
            self._class_to_idx.setdefault(name, position)

    def __len__(self):
        return len(self.df)

    def _label_index(self, label):
        """Return the integer target for ``label``; raise ValueError if unknown."""
        try:
            return self._class_to_idx[label]
        except (KeyError, TypeError):
            # Same exception type ``classes.index`` would have raised.
            raise ValueError(f"{label!r} is not in classes") from None

    def __getitem__(self, idx):
        # Fetch the row once; every ``iloc`` lookup materialises a new Series.
        row = self.df.iloc[idx]
        img = Image.open(io.BytesIO(row['image']))

        # Convert grayscale (or any other non-RGB mode) images to RGB so
        # every sample has the 3 channels the model expects.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, self._label_index(row['label'])

# Model hyperparameters and Architecture

In [None]:
# Optimizer settings (used for SGD)
lr = 0.001
momentum = 0.9

# Training schedule
number_of_epochs = 5

# Lance dataset location for each split
train_dataset_path = "cinic/cinic_train.lance"
test_dataset_path = "cinic/cinic_test.lance"
validation_dataset_path = "cinic/cinic_val.lance"

# Samples per DataLoader batch (train_model also uses it as the default
# loss-print interval) and rows per Lance record batch when reading into pandas
model_batch_size = 64
dataframe_batch_size = 10

In [None]:
def loading_into_pandas(uri, batch_size=None):
    """Read the Lance dataset at ``uri`` into a single pandas DataFrame.

    Rows are streamed as record batches (with a tqdm progress bar) and then
    concatenated.

    Args:
        uri: Path or URI of the Lance dataset.
        batch_size: Rows per record batch; defaults to the module-level
            ``dataframe_batch_size`` when omitted.

    Returns:
        A DataFrame with every row of the dataset and a fresh 0..n-1 index.
    """
    if batch_size is None:
        batch_size = dataframe_batch_size
    ds = lance.dataset(uri)
    # Named ``frames`` (not ``data``) so it does not shadow the
    # ``torch.utils.data`` module imported as ``data``.
    frames = [
        batch.to_pandas()
        for batch in tqdm.tqdm(ds.to_batches(batch_size=batch_size), desc="Loading batches")
    ]
    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        # pd.concat([]) raises ValueError; an empty dataset instead yields an
        # empty frame that still carries the dataset's columns.
        df = ds.schema.empty_table().to_pandas()
    print("Pandas DataFrame is ready")
    print("Total Rows: ", df.shape[0])
    return df

# Training Function

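`train_model` is the standard training function that we use to train our CNN model. We pass it the training and validation dataloaders, the model, the loss function, the optimizer, the device and the number of epochs. After every epoch it evaluates the model on the validation set and logs the loss, validation accuracy and epoch duration to W&B.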

In [None]:
# Define the training function
def train_model(train_loader, val_loader, model, criterion, optimizer, device, num_epochs=10,
                log_interval=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Every optimisation step logs the batch loss to W&B under ``"Loss"``, and
    every ``log_interval`` steps the mean loss over those steps is printed.
    After each epoch the model is evaluated on ``val_loader``. The accuracy
    (percent) is printed and logged as ``"Validation Accuracy"``, and the
    wall-clock epoch duration (seconds) is logged as ``"Epoch Duration"``.

    Args:
        train_loader: Iterable of batches whose items 0/1 are inputs/labels.
        val_loader: Iterable of validation batches with the same layout.
        model: ``nn.Module`` to train; it is moved to ``device`` in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that the model, inputs and labels are moved to.
        num_epochs: Number of passes over ``train_loader``.
        log_interval: Steps between printed running-loss averages; defaults
            to the module-level ``model_batch_size`` (the previous behaviour).
    """
    if log_interval is None:
        log_interval = model_batch_size
    model.to(device)
    for epoch in range(num_epochs):
        start_epoch_time = time.time()

        # train() re-enables dropout / batch-norm updates that the eval()
        # call in the validation phase below switches off.
        model.train()
        running_loss = 0.0
        for i, batch in enumerate(train_loader):
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() synchronises with the device; call it once per step.
            loss_value = loss.item()
            wandb.log({"Loss": loss_value})
            running_loss += loss_value
            if i % log_interval == log_interval - 1:
                print(f'[{epoch + 1}, {i + 1:2d}] loss: {running_loss / log_interval:.2f}')
                running_loss = 0.0

        # Validation: inference mode layers, no autograd graph.
        model.eval()
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for batch_val in val_loader:
                images_val, labels_val = batch_val[0].to(device), batch_val[1].to(device)
                predicted_val = model(images_val).argmax(dim=1)
                total_val += labels_val.size(0)
                correct_val += (predicted_val == labels_val).sum().item()

        if total_val:
            val_accuracy = 100 * correct_val / total_val
            print('Validation accuracy of the network: %.2f %%' % val_accuracy)
            wandb.log({"Validation Accuracy": val_accuracy})
        else:
            # An empty validation loader would otherwise raise ZeroDivisionError.
            print('Validation skipped: the validation loader yielded no samples')

        epoch_duration = time.time() - start_epoch_time
        wandb.log({"Epoch Duration": epoch_duration})

    print('Finished Training')

In [None]:
# Define the neural network
class Net(nn.Module):
    """LeNet-style CNN for 3x32x32 inputs producing 10 class logits.

    Two conv(5x5) -> ReLU -> 2x2 max-pool stages shrink a 32x32 image to a
    16x5x5 feature map. That map is flattened and passed through a
    120 -> 84 -> 10 fully connected head; ReLU follows every layer except
    the last.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so parameter initialisation
        # consumes the RNG identically and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.flatten(start_dim=1)
        # Hidden fully connected layers with ReLU; the final layer emits raw logits.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [None]:
# Load datasets
training_df = loading_into_pandas(train_dataset_path)
testing_df = loading_into_pandas(test_dataset_path)
validation_df = loading_into_pandas(validation_dataset_path)

# Create datasets
train_dataset = CustomImageDataset(training_df, classes, transform=transform)
test_dataset = CustomImageDataset(testing_df, classes, transform=transform)
val_dataset = CustomImageDataset(validation_df, classes, transform=transform)

# Create data loaders. Only training data is shuffled: accuracy on the
# test/validation sets does not depend on sample order, so shuffling them
# would only add nondeterminism.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=model_batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=model_batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=model_batch_size, shuffle=False)

# Instantiate the model
net = Net()

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

# Train the model
train_model(train_loader, val_loader, net, criterion, optimizer, device, number_of_epochs)

# Save the trained weights, then reload them into a fresh model to confirm the
# checkpoint round-trips. map_location places the tensors on ``device``
# regardless of which device they were saved from.
PATH = './cinic_lance.pth'
torch.save(net.state_dict(), PATH)
net = Net()
net.load_state_dict(torch.load(PATH, map_location=device))
net.to(device)
# Evaluation mode for inference (matters for dropout / batch-norm layers).
net.eval()

# Test the model. The batch is unpacked directly rather than bound to a name
# like ``data``, which would overwrite the ``torch.utils.data`` module
# imported as ``data`` at the top of the notebook.
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print('Accuracy of the network on the test images: %.2f %%' % accuracy)
wandb.log({"Test Accuracy": accuracy})