In [None]:
import io
import time

import lance
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import wandb
from PIL import Image
from torch.utils import data
from tqdm import tqdm

### Custom Image Dataset Class for Loading Images from a Lance Dataset

In [None]:
class CustomImageDataset(data.Dataset):
    """Map-style dataset that serves ``(image, label)`` pairs from a DataFrame.

    Two columns of ``df`` are read per sample: ``'image'``, whose value is
    wrapped in ``io.BytesIO`` and opened with PIL, and ``'category'``, whose
    position inside ``classes`` becomes the integer label.
    """

    def __init__(self, df, classes, transform=None):
        """Keep references to the frame, the ordered class names and the transform.

        Args:
            df: DataFrame indexed positionally via ``.iloc``.
            classes: Sequence of category names; must support ``.index()``.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.transform = transform
        self.classes = classes
        self.df = df

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for the row at position ``idx``.

        ``img`` is the decoded image converted to RGB when needed and passed
        through ``transform`` if one was given; ``label`` is the index of the
        row's category in ``classes`` (``.index()`` raises ``ValueError`` for
        an unknown category).
        """
        record = self.df.iloc[idx]

        picture = Image.open(io.BytesIO(record['image']))
        # Any non-RGB mode (e.g. grayscale) is converted so every sample has 3 channels
        picture = picture if picture.mode == 'RGB' else picture.convert('RGB')

        if self.transform:
            picture = self.transform(picture)

        return picture, self.classes.index(record['category'])

### Loading into Pandas

In [None]:
def loading_into_pandas(uri, columns=("image", "filename", "category", "data_type"), batch_size=10):
    """Read a Lance dataset into a single pandas DataFrame.

    The dataset is streamed batch by batch (with a tqdm progress bar), each
    batch is converted to a DataFrame, and the pieces are concatenated with a
    fresh 0..n-1 index.

    Args:
        uri: Location passed to ``lance.dataset``.
        columns: Names of the columns to read. Defaults to the four columns
            this notebook uses.
        batch_size: Value forwarded to ``to_batches`` as ``batch_size``.

    Returns:
        A DataFrame holding every row of the requested columns. When the
        dataset yields no batches, an empty DataFrame with the requested
        column names is returned.
    """
    ds = lance.dataset(uri)
    columns = list(columns)

    # One DataFrame per record batch. The list is deliberately not named
    # `data`: that name is what this notebook imports torch.utils.data as.
    frames = [
        batch.to_pandas()
        for batch in tqdm(
            ds.to_batches(columns=columns, batch_size=batch_size),
            desc="Loading batches",
        )
    ]

    # pd.concat raises ValueError on an empty list, so an empty dataset is
    # mapped to an empty frame that still carries the requested columns.
    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=columns)

    print("Pandas DataFrame is ready")
    print("Total Rows: ", df.shape[0])
    return df

### Training Function

In [None]:
def train_model(train_loader, val_loader, model, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` and report validation accuracy after every epoch.

    For each epoch the model is put in training mode and optimised over
    ``train_loader``; it is then put in evaluation mode and scored on
    ``val_loader`` under ``torch.no_grad()``. Metrics are sent to wandb:
    ``"Loss"`` once per training batch, ``"Validation Accuracy"`` (percent)
    and ``"Epoch Duration"`` (seconds, training plus validation) once per epoch.

    Args:
        train_loader: Iterable of batches; element 0 is the inputs and
            element 1 the labels.
        val_loader: Iterable of batches with the same layout. May be empty, in
            which case the accuracy is neither printed nor logged.
        model: ``torch.nn.Module`` to train; moved to ``device`` in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        device: Device the model and every batch are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. After at least one epoch the model is left in evaluation mode.
    """
    model.to(device)  # Move model to the specified device
    for epoch in range(num_epochs):
        start_epoch_time = time.time()
        running_loss = 0.0

        # Re-enable training behaviour (dropout, batch-norm statistics) every
        # epoch, because the validation pass below switches to eval mode.
        model.train()
        for i, batch in enumerate(train_loader):
            # batch is a sequence of [inputs, labels]
            inputs, labels = batch[0].to(device), batch[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for logging and the running mean
            batch_loss = loss.item()
            wandb.log({"Loss": batch_loss})

            # print statistics
            running_loss += batch_loss
            if i % 64 == 63:  # Print the mean loss every 64 mini-batches
                print(f'[{epoch + 1}, {i + 1:2d}] loss: {running_loss / 64:.3f}')
                running_loss = 0.0

        # Validation
        model.eval()
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for batch_val in val_loader:
                images_val, labels_val = batch_val[0].to(device), batch_val[1].to(device)
                outputs_val = model(images_val)
                # Index of the largest score along dim 1 is the predicted class
                _, predicted_val = torch.max(outputs_val, 1)
                total_val += labels_val.size(0)
                correct_val += (predicted_val == labels_val).sum().item()

        if total_val:
            val_accuracy = 100 * correct_val / total_val
            print('Validation accuracy of the network: %.2f %%' % val_accuracy)
            wandb.log({"Validation Accuracy": val_accuracy})
        else:
            # Nothing to score: avoid dividing by zero and logging a made-up value
            print('Validation loader is empty; skipping validation accuracy')

        end_epoch_time = time.time()
        epoch_duration = end_epoch_time - start_epoch_time
        wandb.log({"Epoch Duration": epoch_duration})

    print('Finished Training')

### Main Block

In [None]:
# Category names in label order: a sample's integer label is the position of
# its category in this tuple.
classes = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Preprocessing pipeline applied to every PIL image, in this order
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),  # fixed 32x32 input size
        transforms.ToTensor(),  # PIL image -> tensor
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),  # per-channel mean 0.5, std 0.5
    ]
)

In [None]:
wandb.init(project="cinic-10")

# Pick the compute device: Apple MPS first, then CUDA, then CPU.
# The hasattr guard keeps this working on torch builds without torch.backends.mps.
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS Device:", device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

# Load datasets
training_df = loading_into_pandas('cinic/cinic_train.lance')
testing_df = loading_into_pandas('cinic/cinic_test.lance')
validation_df = loading_into_pandas('cinic/cinic_val.lance')

# Create datasets
train_dataset = CustomImageDataset(training_df, classes, transform=transform)
test_dataset = CustomImageDataset(testing_df, classes, transform=transform)
val_dataset = CustomImageDataset(validation_df, classes, transform=transform)

# Create data loaders. Only the training loader shuffles: accuracy does not
# depend on sample order, so the evaluation loaders keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

# Define the neural network
class Net(nn.Module):
    """Small LeNet-style CNN: two conv+pool stages followed by three linear layers.

    ``fc1`` expects 16 * 5 * 5 features, which is what the convolutional
    stages produce for 3-channel 32x32 inputs
    (32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5).

    Args:
        num_classes: Width of the final layer, i.e. the number of scores
            returned per sample. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3 input channels, 6 output channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)  # 2x2 pooling, shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 input channels, 16 output channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)  # one output score per class

    def forward(self, x):
        """Return raw class scores of shape ``(batch, num_classes)`` for ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the last layer: the scores are returned as-is
        return self.fc3(x)

# Instantiate the model
net = Net()

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the model
train_model(train_loader, val_loader, net, criterion, optimizer, device, num_epochs=10)

PATH = './cinic_net_lance.pth'
torch.save(net.state_dict(), PATH)

# Reload the saved weights into a fresh network. Loading onto the CPU first
# works regardless of the device the checkpoint tensors were saved from; the
# network is then moved to the same device that was used for training.
net = Net()
net.load_state_dict(torch.load(PATH, map_location="cpu"))
net.to(device)
net.eval()  # inference mode for the test pass

correct = 0
total = 0

# Testing
# The loop variable is not called `data`: that name is bound to
# torch.utils.data by the import cell, and rebinding it here would break
# re-running the CustomImageDataset cell.
with torch.no_grad():
    for test_batch in test_loader:
        images, labels = test_batch[0].to(device), test_batch[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print('Accuracy of the network on the test images: %.2f %%' % accuracy)
wandb.log({"Test Accuracy": accuracy})