In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch import tensor
from matplotlib import pyplot as plt
from torchvision import datasets

In [None]:
# Choose where tensors live: Apple-silicon GPU through MPS if the backend
# reports it is usable, otherwise fall back to the CPU.
dtype = torch.float
use_mps = torch.backends.mps.is_available()
device = torch.device('mps' if use_mps else 'cpu')
print("Using GPU: Metal Performance Shaders (MPS)" if use_mps else "Using CPU")

# Sanity check: build a tiny tensor directly on the selected device.
x = tensor([1.0, 2.0, 3.0], device=device, dtype=dtype)
print(f"Tensor: {x}, Device: {x.device}")

In [None]:
# MPS diagnostics (MPS needs macOS 12.3+). First line: can MPS be used right
# now? Second line: was this PyTorch build compiled with MPS support?
print(torch.backends.mps.is_available(), torch.backends.mps.is_built(), sep="\n")

In [None]:
# Loading torch modules

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Root of the iNaturalist-12K dataset; it must contain `train/` and `val/`
# folders in torchvision ImageFolder layout. Set the INATURALIST_DIR
# environment variable to point elsewhere instead of editing a user-specific
# absolute path in the notebook.
dataset_path = os.environ.get(
    "INATURALIST_DIR", "/Users/jalajtrivedi/Downloads/inaturalist_12K"
)

# Resize to a fixed 224x224, convert to tensors, and normalize with ImageNet
# channel statistics (matching the ImageNet-pretrained ResNet-50 used later).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load training and validation splits; class labels come from subfolder names.
train_dataset = datasets.ImageFolder(root=os.path.join(dataset_path, "train"), transform=transform)
val_dataset = datasets.ImageFolder(root=os.path.join(dataset_path, "val"), transform=transform)

In [None]:
# Number of images in each split (train first, then validation).
print(len(train_dataset), len(val_dataset), sep="\n")

In [None]:
print(f"{train_dataset.classes}")  # class names inferred from folder names

In [None]:
print(f"{train_dataset.samples[0]}")  # first indexed training sample entry

In [None]:
# Fetch one transformed sample and show its tensor shape, integer label and
# the human-readable class name for that label.
image, label = train_dataset[0]
print(image.shape, label, train_dataset.classes[label], sep="\n")

In [None]:
# Batch both splits. The training order is reshuffled every epoch; a seeded
# generator makes that order reproducible across kernel restarts.
BATCH_SIZE = 32
SEED = 42

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    generator=torch.Generator().manual_seed(SEED),
)
# Validation order is fixed so metrics are comparable between runs.
validation_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# Batches per epoch for each loader (the last batch may be partial).
for caption, loader in (
    ("No.of batches in train data:", train_loader),
    ("No.of batches in val data:", validation_loader),
):
    print(caption, len(loader))

In [None]:
# Draw a single batch from a fresh iterator to inspect tensor shapes.
images, labels = next(iter(train_loader))
for caption, batch_part in (('Size of batch', images), ('Size of labels', labels)):
    print(caption, batch_part.shape)

In [None]:
# Keep one persistent iterator over the (shuffled) training loader so the next
# cell draws a new batch via next(x_t) each time it is re-run.
x_t = iter(train_loader)

In [None]:
# Show the first four images of the next training batch with their class names.
imgs, labels = next(x_t)

# Undo transforms.Normalize from the dataset cell (ImageNet statistics):
# x_orig = x_norm * std + mean, per channel. The previous `img/2 + 0.5`
# only inverts Normalize(0.5, 0.5) and produced wrong colours.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

fig, ax = plt.subplots(1, 4, figsize=(12, 3))
for i, (img, label) in enumerate(zip(imgs[:4], labels[:4])):
    img = torch.clamp(img * norm_std + norm_mean, 0, 1)
    # matplotlib expects (H, W, C); the tensors are channel-first.
    ax[i].imshow(img.permute(1, 2, 0).numpy())
    ax[i].set_xlabel(train_dataset.classes[int(label)])
    ax[i].set_xticks([])
    ax[i].set_yticks([])
plt.tight_layout()
plt.show()

In [None]:
# Fine-tuning a pretrained model (ResNet-50).
from torchvision import models

# `pretrained=True` is deprecated in favour of the `weights=` API.
# IMAGENET1K_V1 is the checkpoint `pretrained=True` loads, so the
# starting weights are unchanged.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Size the new classifier head from the data instead of hard-coding it.
num_classes = len(train_dataset.classes)

In [None]:
# Replace the pretrained classification head with a freshly initialized linear
# layer that keeps the backbone's feature width but outputs `num_classes` logits.
resnet50.fc = nn.Linear(in_features=resnet50.fc.in_features, out_features=num_classes)

In [None]:
# Freeze the pretrained backbone, then unfreeze only what should be fine-tuned.
for param in resnet50.parameters():
    param.requires_grad = False

# Fine-tune the last residual stage. Earlier stages (layer2, layer3, ...) can be
# unfrozen the same way if more capacity is needed.
for param in resnet50.layer4.parameters():
    param.requires_grad = True

# The new classification head was frozen by the first loop as well; it must be
# trainable, otherwise it would stay at its random initialization.
for param in resnet50.fc.parameters():
    param.requires_grad = True

assert all(p.requires_grad for p in resnet50.fc.parameters()), "classifier head must be trainable"

# Move the model to the selected device.
resnet50 = resnet50.to(device)

In [None]:
# Train the unfrozen layers with SGD + momentum and report per-epoch
# training loss and accuracy.
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that actually receive gradients.
trainable_params = [p for p in resnet50.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

num_epochs = 5

train_losses = []  # mean training loss of each epoch
best_val_loss = float('inf')  # Initialize with a very large value

for epoch in range(num_epochs):
    resnet50.train()

    running_loss = 0.0
    running_correct = 0
    total_samples = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_batch = inputs.size(0)
        # `loss` is a batch mean; weight it by batch size so the epoch mean is
        # exact even when the last batch is smaller (drop_last=False).
        running_loss += loss.item() * n_batch
        # Accumulate as a Python int rather than a device tensor.
        running_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += n_batch

    epoch_loss = running_loss / total_samples
    epoch_accuracy = running_correct / total_samples
    train_losses.append(epoch_loss)

    print(f'Epoch {epoch + 1}/{num_epochs} - Training Loss: {epoch_loss:.3f}, Training Accuracy: {epoch_accuracy:.3f}')

In [None]:
# Evaluate the fine-tuned model on the validation split.
# (Re-moving to `device` is a no-op if already there; it keeps the cell standalone.)
resnet50 = resnet50.to(device)

resnet50.eval()
val_loss = 0.0
val_correct = 0
total_samples = 0
with torch.no_grad():
    for inputs, labels in validation_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)

        # `loss` is a batch mean; weight by batch size for an exact overall mean.
        val_loss += loss.item() * inputs.size(0)
        # Accumulate as a Python int so `val_acc` is a plain float, not a device tensor.
        val_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += labels.size(0)

val_loss /= total_samples
val_acc = val_correct / total_samples
print(f"Validation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}")