In [2]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.parametrize as parametrize

In [None]:
# Define hyperparameters for training
batch_size = 64
learning_rate = 1e-3
num_epochs = 10

# Define transformations for training and testing datasets
transform_train = transforms.Compose([
    transforms.Resize((32, 32)), # Resize images to 32x32
    transforms.ToTensor()
])

transform_val = transforms.Compose([
    transforms.Resize((32, 32)), # Resize images to 32x32
    transforms.ToTensor()
])

# Load Fashion MNIST dataset
train_dataset = datasets.FashionMNIST(root='./data', download=True, train=True, transform=transform_train)
val_dataset = datasets.FashionMNIST(root='./data', download=True, train=False, transform=transform_train)



# Create data loaders for training and validation sets
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
print("Batch size:", batch_size)


Training dataset size: 60000
Validation dataset size: 10000
Batch size: 64


In [90]:
# label map
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

In [4]:
# Define hyperparameters for training
batch_size = 64
learning_rate = 1e-3
num_epochs = 10
momentum = 0.9
weight_decay = 5e-4

# Define the CNN model
class FashionMNISTCNN(nn.Module):
    def __init__(self, h1, h2):
        super(FashionMNISTCNN, self).__init__()
        self.linear1 = nn.Linear(32*32, h1)
        self.linear2 = nn.Linear(h1, h2)
        self.linear3 = nn.Linear(h2, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x_s = x.shape
        x = x.reshape(x_s[0], x_s[1], x_s[-2]*x_s[-1])
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [61]:
model = FashionMNISTCNN(h1=1000, h2=2000)
model.to("cuda")
num_params = sum([p.numel() for p in model.parameters()])
print(f"Num parameters in model - {num_params}")

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

# Train the model
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")
        
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = outputs.squeeze(1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {running_loss / (i + 1)}')

# Save the model
torch.save(model.state_dict(), 'model.pth')

Num parameters in model - 3047010
Epoch 1, Loss: 1.1668907664477952
Epoch 2, Loss: 0.6404037685917893
Epoch 3, Loss: 0.5464359337904814
Epoch 4, Loss: 0.5001209350282958
Epoch 5, Loss: 0.474462777868644
Epoch 6, Loss: 0.45565528541739814
Epoch 7, Loss: 0.4417266100486204
Epoch 8, Loss: 0.42736764938465316
Epoch 9, Loss: 0.4191315769831509
Epoch 10, Loss: 0.41034741289834226


In [91]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in val_loader:
            x, y = data
            x = x.to("cuda")
            y = y.to("cuda")
            output = model(x)
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Accuracy: 0.849
wrong counts for the digit 0: 185
wrong counts for the digit 1: 43
wrong counts for the digit 2: 220
wrong counts for the digit 3: 125
wrong counts for the digit 4: 253
wrong counts for the digit 5: 74
wrong counts for the digit 6: 440
wrong counts for the digit 7: 59
wrong counts for the digit 8: 49
wrong counts for the digit 9: 67


## Model Parameterization -- 
### What is it? Why it is necessary? How to do it?

Model reparameterization is a technique of transforming model's parameters for certain operations e.g. Regularization. To reparametrize the model parameters, we need to  instruct pytorch on how we want to transform parameters of certain layers. One way of doing this is to implement the reparameterization by hand and then manually transforming the parameters. Below is an example of reparameterization that enforces orthogonality of weights ->

In [6]:
# function to convert model's parameter into orthogonal matrix
def orthogonal_params(weights):
  q, r = torch.linalg.qr(weights)
  return q

# let's define a simple model
class SimpleModelManual(nn.Module):
    def __init__(self, in_features, out_features):
        super(SimpleModelManual, self).__init__()
        # Define an unconstrained parameter
        self.weight_raw = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))  # Optional bias

    def forward(self, x):
        # Apply the orthogonal reparameterization manually
        Q = orthogonal_params(self.weight_raw)  # QR decomposition to enforce orthogonality
        # Perform linear transformation with reparameterized weight
        return x @ Q.T + self.bias

In [7]:
# let's check if our reparameterization works

# Model parameters
in_features = 5
out_features = 5

# Create the model
model = SimpleModelManual(in_features, out_features)

# Input tensor
x = torch.randn(3, in_features)

# Forward pass
output = model(x)

# Check if the reparameterized weight matrix is orthogonal
Q, _ = torch.linalg.qr(model.weight_raw)  # Apply QR to get orthogonal weight
print("Reparameterized weight matrix (Q):\n", Q)
print("Is orthogonal (Q^T Q = I):\n", torch.allclose(Q.T @ Q, torch.eye(out_features), atol=1e-6))



Reparameterized weight matrix (Q):
 tensor([[-0.6216, -0.2397, -0.6985,  0.2387, -0.1059],
        [ 0.6800, -0.1291, -0.3929,  0.5752,  0.1889],
        [ 0.0138,  0.9426, -0.3252, -0.0023, -0.0750],
        [ 0.3062, -0.1319, -0.1589, -0.2023, -0.9070],
        [ 0.2392, -0.1415, -0.4761, -0.7558,  0.3533]],
       grad_fn=<LinalgQrBackward0>)
Is orthogonal (Q^T Q = I):
 True


In manual implementation of reparameterization, we had to reimplement the linear transformation which is readially available in pytorch as nn.Linear(). It will become tedious to say the least or even impossible to reimplement several layers for such manual implementatins. 

Luckily, pytorch has made life easier once again, and such reparameterizations are made easy using `register_parametrization`. Below is an example of Pytorch's register_parametrization.

In [8]:
class OrthogonalParam(nn.Module):
    def forward(self, W):
        # Use QR decomposition to make W orthogonal
        Q, R = torch.linalg.qr(W)
        return Q


class SimpleModel(nn.Module):
    def __init__(self, in_features, out_features):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        
        # Apply orthogonal parameterization to the weight of the linear layer
        parametrize.register_parametrization(self.linear, "weight", OrthogonalParam())

    def forward(self, x):
        return self.linear(x)


In [9]:
# Model parameters
in_features = 5
out_features = 5

# Create the model
model = SimpleModel(in_features, out_features)

# Input tensor
x = torch.randn(3, in_features)

# Forward pass
output = model(x)

# Verify orthogonality
W = model.linear.weight  # Get the parameterized weight
# W_original = model.linear.parametrizations.weight.original
print("Weight matrix:\n", W)
print("Is orthogonal (W^T W = I):\n", torch.allclose(W.T @ W, torch.eye(out_features), atol=1e-6))


Weight matrix:
 tensor([[-0.5994, -0.3149,  0.1313,  0.6633, -0.2904],
        [ 0.3432, -0.3821, -0.3499, -0.1396, -0.7710],
        [ 0.4510,  0.6129, -0.0295,  0.6162, -0.2012],
        [-0.5445,  0.6158, -0.1332, -0.3582, -0.4222],
        [-0.1521, -0.0084, -0.9175,  0.1804,  0.3202]],
       grad_fn=<LinalgQrBackward0>)
Is orthogonal (W^T W = I):
 True


See how easy it was to parameterize model when using register_parametrization. 

Now we need to apply same trick for LoRA.. Why?
Because we need to transform the model's parameters to a rank deficiant matrix. i.e. we need to transform the paramters for certain operations...

## LoRA implementation

In [57]:
class LoRA(nn.Module):
  def __init__(self, in_feat, out_feat, rank, alpha, device="cuda") -> None:
    super(LoRA, self).__init__()

    # as per paper, we need to break down ∆W into A and B matrices such that ∆W = BA. Additionally ∆W = BA should be zeros at the beginning 
    # The A matrix is initialized randomly from a normal distribution, and B is all zeros

    self.lora_b = nn.Parameter(torch.zeros(in_feat, rank)).to(device)
    self.lora_a = nn.Parameter(torch.zeros(rank, out_feat)).to(device)
    nn.init.normal_(self.lora_a, mean=0, std=1)

    # scaling the ∆W by using alpha -- section 4.1 in paper "We then scale ∆Wx by α/r , where α is a constant in r"
    self.scale = alpha / rank
    self.enabled = True

  def forward(self, original_w):
    delta_w = torch.matmul(self.lora_b, self.lora_a).view(original_w.shape) * self.scale
    return original_w 

now, let's add this reparametrization to out FashionMNIST network

In [83]:
def parametrize_model(curr_layer, device, rank=1, lora_alpha=1):
  """adds matirx A and B of LoRA to each liner layer
  """
  in_feat, out_feat = curr_layer.weight.shape
  return LoRA(in_feat, out_feat, rank, lora_alpha, device)

parametrize.register_parametrization(
    model.linear1, "weight", parametrize_model(model.linear1, "cuda")
)
parametrize.register_parametrization(
    model.linear2, "weight", parametrize_model(model.linear2, "cuda")
)
parametrize.register_parametrization(
    model.linear3, "weight", parametrize_model(model.linear3, "cuda")
)


def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
  

In [88]:
total_parameters_lora = 0
total_parameters_non_lora = 0

for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_a.nelement() + \
    layer.parametrizations["weight"][0].lora_b.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == num_params
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')


Total number of parameters (original): 3,047,010
Total number of parameters (original + LoRA): 3,054,044
Parameters introduced by LoRA: 7,034
Parameters incremment: 0.231%


Let's finetune our model on worst performing class i.e. cls_id = 6

In [92]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


In [93]:
# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
exclude_indices = mnist_trainset.targets == 6
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)


In [None]:
# Train the model
for epoch in range(2):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")
        
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = outputs.squeeze(1)

        loss = criterion(outputs, labels)
        loss.requires_grad = True
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {running_loss / (i + 1)}')


Epoch 1, Loss: 0.9729501763556866
Epoch 2, Loss: 0.9728141014880323
Epoch 3, Loss: 0.973069206831303
Epoch 4, Loss: 0.9727423666639531
Epoch 5, Loss: 0.9722493982061426
Epoch 6, Loss: 0.9724092033315213
Epoch 7, Loss: 0.9729931601818572
Epoch 8, Loss: 0.9727127234986488
Epoch 9, Loss: 0.9728179035034585
Epoch 10, Loss: 0.9722128377315846


In [97]:
# testing fine-tuned model

enable_disable_lora(enabled=False)
test()

Accuracy: 0.849
wrong counts for the digit 0: 185
wrong counts for the digit 1: 43
wrong counts for the digit 2: 220
wrong counts for the digit 3: 125
wrong counts for the digit 4: 253
wrong counts for the digit 5: 74
wrong counts for the digit 6: 440
wrong counts for the digit 7: 59
wrong counts for the digit 8: 49
wrong counts for the digit 9: 67
