In [2]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader, Subset, random_split

# Define the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Download the dataset

n_classes = 10
batch_size = 32
lora_rank = 0  # when 0, only train the last layer 

trainset = torchvision.datasets.Food101(root='../data', split='train', download=True, transform=transform)
testset = torchvision.datasets.Food101(root='../data', split='test', download=True, transform=transform)

# Select top k most frequent classes
# classes, counts = np.unique(trainset._labels, return_counts=True)
# idx = np.argsort(counts)[-n_classes:]
# target_classes = classes[idx]
# print('target_classes', target_classes)

target_classes = list(range(n_classes))



# Filter the dataset to include only samples from the target classes
def filter_classes(dataset, target_classes):
    targets = np.array(dataset._labels)
    mask = np.isin(targets, target_classes)
    indices = np.where(mask)[0]
    return Subset(dataset, indices)

trainset_filtered = filter_classes(trainset, target_classes)
testset_filtered = filter_classes(testset, target_classes)

# TODO select only a couple of classes? 

# Create a smaller subset for training
# subset_indices = torch.randperm(len(dataset)) # [:100]  # Using 100 samples for training
# trainset_subset = Subset(dataset, subset_indices)

trainloader = DataLoader(trainset_filtered, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset_filtered, batch_size=batch_size, shuffle=False, num_workers=2)


print(f'Training Samples - Train {len(trainloader) * batch_size} Test {len(testloader) * batch_size}')
print(f'Batch size: {batch_size}')

Training Samples - Train 7520 Test 2528
Batch size: 32


In [3]:
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
import math

class LoRALayer(nn.Module):
    def __init__(self, base_layer, rank=4):
        super(LoRALayer, self).__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.lora_A = nn.Parameter(torch.randn(base_layer.weight.size(0), rank))
        self.lora_B = nn.Parameter(torch.randn(rank, base_layer.weight.size(1)))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))

    def forward(self, x):
        lora_weight = torch.matmul(self.lora_A, self.lora_B)
        new_weight = self.base_layer.weight + lora_weight
        return nn.functional.linear(x, new_weight, self.base_layer.bias)


def apply_lora(model, rank):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            lora_layer = LoRALayer(module, rank)
            setattr(model, name, lora_layer)
            lora_layer.lora_A.requires_grad = True
            lora_layer.lora_B.requires_grad = True
    return model
    

def print_trainable_parameters(model): 
    # Calculate total number of parameters and trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f'Total number of parameters: {total_params}')
    print(f'Number of trainable parameters: {trainable_params}')


# net = resnet18(pretrained=True)
net = resnet18(weights=ResNet18_Weights.DEFAULT)


# Freeze all parameters
for param in net.parameters():
    param.requires_grad = False

if lora_rank > 0: 
    net = apply_lora(net, rank=lora_rank)


net.fc = nn.Linear(512, n_classes) 
net.fc.requires_grad = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device', device)
net.to(device)


print_trainable_parameters(net)

print('done')


device cuda:0
Total number of parameters: 11181642
Number of trainable parameters: 5130
done


In [None]:
import os 
from os.path import join

lr = 0.01 #  0.001 

criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)


print('Training ...')


best_valid_loss = 9999
model_dir = './models'
os.makedirs(model_dir, exist_ok=True)
export_path = join(model_dir, f'model_lorarank{lora_rank}.pth')

for epoch in range(21): 
    # --- Compute test loss 
    if epoch % 5 == 0: 
        net.eval()
        test_losses = []
        correct, total = 0, 0

        with torch.no_grad():
            for i, (images, labels) in enumerate(testloader):
                if i == 200:
                    break
                images, labels = images.to(device), labels.to(device)
                logits = net(images) 
                
                test_loss = criterion(logits, labels).item()
                test_losses.append(test_loss)
 
                _, predicted = torch.max(logits.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
        test_loss = np.mean(test_losses)
        test_accuracy = correct / total
        print(f'Epoch {epoch + 1}, Test loss: {test_loss:.3f}, Test accuracy: {test_accuracy:.3f}')

        if test_loss < best_valid_loss:
            best_valid_loss = test_loss
            if epoch > 0: 
                torch.save(net.state_dict(), export_path)
                print('Model exported')
        
        net.train()

    # --- Train the model
    
    train_losses = []
    
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
 
    train_loss = np.mean(train_losses)
    print(f'[{epoch + 1}, {i + 1}] loss: {train_loss:.3f}')

print('Finished Training')

Training ...
Epoch 1, Test loss: 2.647, Test accuracy: 0.100
[1, 235] loss: 1.282
[2, 235] loss: 1.107
[3, 235] loss: 1.078
[4, 235] loss: 1.040
[5, 235] loss: 1.041
Epoch 6, Test loss: 1.023, Test accuracy: 0.716
Model exported
[6, 235] loss: 0.998
[7, 235] loss: 0.978
[8, 235] loss: 0.936
[9, 235] loss: 1.042
[10, 235] loss: 1.015
Epoch 11, Test loss: 1.037, Test accuracy: 0.728
