In [63]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import torch.nn.utils.parametrize as parametrize

#### Make torch deterministic

In [43]:
torch.manual_seed(42)

<torch._C.Generator at 0x10e3c6550>

### Load MNSIT dataset

In [44]:
BATCH_SIZE = 32
DATA_DIR = './data'
EPOCHS = 50
LEARNING_RATE = 0.001
NUM_CLASSES = 10
EARLY_STOPPING_PATIENCE = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [45]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform)
validation_dataset = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Creating a simple Neural Network model

In [46]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, NUM_CLASSES)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [47]:
model = MNISTClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [69]:
def validation(epoch, validation_loader):
    model.eval()
    validation_loss = 0
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(validation_loader, desc=f"Validation Epoch {epoch+1}"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            validation_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    avg_loss = validation_loss / len(validation_loader)
    print(f"Epoch: {epoch+1} Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

In [70]:
def train_one_epoch(epoch, train_loader):
    model.train()
    epoch_loss = 0.0

    for batch_index, data in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{EPOCHS}")):
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
        if (batch_index + 1) % 1000 == 0:
            print(f"Epoch: {epoch+1}, Batch: {batch_index+1}, Loss: {epoch_loss:.4f}")
    return epoch_loss

In [71]:
def train(train_loader, validation_loader):
    stop_counter = 0
    best_val_loss = float('inf')
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(epoch, train_loader)
        val_loss, val_accuracy = validation(epoch, validation_loader)
        print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

        if val_loss < best_val_loss:
            stop_counter = 0
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Model saved with validation loss: {best_val_loss:.4f}")
        else:
            stop_counter += 1
            if stop_counter >= EARLY_STOPPING_PATIENCE:
                print("Early stopping triggered")
                break

### Train the MNIST classifier model

In [51]:
train()

Training Epoch 1/50:  54%|█████▍    | 1008/1875 [00:17<00:15, 56.82it/s]

Epoch: 1, Batch: 1000, Loss: 195.1587


Training Epoch 1/50: 100%|██████████| 1875/1875 [00:33<00:00, 56.51it/s]
Validation Epoch 1: 100%|██████████| 313/313 [00:01<00:00, 157.65it/s]


Epoch: 1 Validation Loss: 0.0442, Accuracy: 98.57%
Epoch 1/50, Train Loss: 246.8885, Validation Loss: 0.0442, Validation Accuracy: 98.57%
Model saved with validation loss: 0.0442


Training Epoch 2/50:  54%|█████▍    | 1008/1875 [00:18<00:15, 57.02it/s]

Epoch: 2, Batch: 1000, Loss: 42.6863


Training Epoch 2/50: 100%|██████████| 1875/1875 [00:33<00:00, 56.41it/s]
Validation Epoch 2: 100%|██████████| 313/313 [00:01<00:00, 158.43it/s]


Epoch: 2 Validation Loss: 0.0366, Accuracy: 98.75%
Epoch 2/50, Train Loss: 76.3925, Validation Loss: 0.0366, Validation Accuracy: 98.75%
Model saved with validation loss: 0.0366


Training Epoch 3/50:  54%|█████▍    | 1008/1875 [00:17<00:15, 56.75it/s]

Epoch: 3, Batch: 1000, Loss: 28.9526


Training Epoch 3/50: 100%|██████████| 1875/1875 [00:33<00:00, 55.97it/s]
Validation Epoch 3: 100%|██████████| 313/313 [00:02<00:00, 154.19it/s]


Epoch: 3 Validation Loss: 0.0319, Accuracy: 98.92%
Epoch 3/50, Train Loss: 53.0841, Validation Loss: 0.0319, Validation Accuracy: 98.92%
Model saved with validation loss: 0.0319


Training Epoch 4/50:  54%|█████▍    | 1008/1875 [00:18<00:15, 56.02it/s]

Epoch: 4, Batch: 1000, Loss: 18.8826


Training Epoch 4/50: 100%|██████████| 1875/1875 [00:33<00:00, 55.42it/s]
Validation Epoch 4: 100%|██████████| 313/313 [00:01<00:00, 158.47it/s]


Epoch: 4 Validation Loss: 0.0448, Accuracy: 98.64%
Epoch 4/50, Train Loss: 38.9755, Validation Loss: 0.0448, Validation Accuracy: 98.64%


Training Epoch 5/50:  54%|█████▍    | 1008/1875 [00:18<00:15, 55.44it/s]

Epoch: 5, Batch: 1000, Loss: 14.4855


Training Epoch 5/50: 100%|██████████| 1875/1875 [00:33<00:00, 55.41it/s]
Validation Epoch 5: 100%|██████████| 313/313 [00:02<00:00, 146.05it/s]


Epoch: 5 Validation Loss: 0.0326, Accuracy: 98.99%
Epoch 5/50, Train Loss: 29.1234, Validation Loss: 0.0326, Validation Accuracy: 98.99%


Training Epoch 6/50:  54%|█████▍    | 1008/1875 [00:17<00:15, 56.83it/s]

Epoch: 6, Batch: 1000, Loss: 11.7264


Training Epoch 6/50: 100%|██████████| 1875/1875 [00:33<00:00, 56.13it/s]
Validation Epoch 6: 100%|██████████| 313/313 [00:01<00:00, 160.40it/s]


Epoch: 6 Validation Loss: 0.0349, Accuracy: 98.96%
Epoch 6/50, Train Loss: 21.9345, Validation Loss: 0.0349, Validation Accuracy: 98.96%


Training Epoch 7/50:  54%|█████▍    | 1008/1875 [00:18<00:15, 56.70it/s]

Epoch: 7, Batch: 1000, Loss: 7.7661


Training Epoch 7/50: 100%|██████████| 1875/1875 [00:33<00:00, 55.66it/s]
Validation Epoch 7: 100%|██████████| 313/313 [00:02<00:00, 153.00it/s]


Epoch: 7 Validation Loss: 0.0393, Accuracy: 98.95%
Epoch 7/50, Train Loss: 19.7054, Validation Loss: 0.0393, Validation Accuracy: 98.95%


Training Epoch 8/50:  54%|█████▍    | 1010/1875 [00:19<00:15, 54.36it/s]

Epoch: 8, Batch: 1000, Loss: 6.9122


Training Epoch 8/50: 100%|██████████| 1875/1875 [00:34<00:00, 54.25it/s]
Validation Epoch 8: 100%|██████████| 313/313 [00:01<00:00, 160.22it/s]

Epoch: 8 Validation Loss: 0.0377, Accuracy: 99.08%
Epoch 8/50, Train Loss: 15.8468, Validation Loss: 0.0377, Validation Accuracy: 99.08%
Early stopping triggered





### Keep a copy of the trained model (compare after performing LoRA)

In [56]:
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()

### Check performance of the trained model all digits

In [62]:
def test():
    model.load_state_dict(torch.load('best_model.pth'))
    model.eval()
    test_loss = 0
    correct, total = 0, 0
    digit_count = {i: 0 for i in range(NUM_CLASSES)}
    digit_correct = {i: 0 for i in range(NUM_CLASSES)}
    with torch.no_grad():
        for inputs, labels in tqdm(validation_loader, desc="Testing"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            for i in range(len(labels)):
                digit_count[labels[i].item()] += 1
                if predicted[i] == labels[i]:
                    digit_correct[labels[i].item()] += 1
    accuracy = 100 * correct / total
    avg_loss = test_loss / len(validation_loader)
    print(f"Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    for digit in range(NUM_CLASSES):
        if digit_count[digit] > 0:
            digit_accuracy = 100 * digit_correct[digit] / digit_count[digit]
        else:
            digit_accuracy = 0
        print(f"Digit {digit}: {digit_accuracy:.2f}% Incorrect: {digit_count[digit] - digit_correct[digit]} Total: {digit_count[digit]}")
    return avg_loss, accuracy

test()

Testing: 100%|██████████| 313/313 [00:02<00:00, 149.27it/s]

Test Loss: 0.0319, Accuracy: 98.92%
Digit 0: 99.59% Incorrect: 4 Total: 980
Digit 1: 99.65% Incorrect: 4 Total: 1135
Digit 2: 99.42% Incorrect: 6 Total: 1032
Digit 3: 99.70% Incorrect: 3 Total: 1010
Digit 4: 98.57% Incorrect: 14 Total: 982
Digit 5: 98.32% Incorrect: 15 Total: 892
Digit 6: 98.64% Incorrect: 13 Total: 958
Digit 7: 98.93% Incorrect: 11 Total: 1028
Digit 8: 98.05% Incorrect: 19 Total: 974
Digit 9: 98.12% Incorrect: 19 Total: 1009





(0.031945871058609375, 98.92)

### Number of parameters in the MNIST classifier model

In [60]:
total_params = 0
for index, layer in enumerate([model.conv1, model.conv2, model.fc1, model.fc2]):
    num_params = sum(p.numel() for p in layer.parameters())
    total_params += num_params
    print(f"Layer {index+1}: {num_params} parameters")
print(f"Total parameters in the model: {total_params}")

Layer 1: 320 parameters
Layer 2: 18496 parameters
Layer 3: 401536 parameters
Layer 4: 1290 parameters
Total parameters in the model: 421642


### LoRA Parameterization

##### Use a gaussian initialization for A and zero for B, so delta W = BA is zero at the begining of training. We then scale delta Wx by alpha/r, where alpha is a constant in r. When optimizing with Adam, tuning alpha is roughly the same as tuning the learning rate if we scale the initialization appropriately. As a result, we simply set alpha to the first r we try and do not tune it. This scaling helps to reduce the need to retune hyperparameters when we vary r.

In [66]:
class LoRAParameterization(nn.Module):
    def __init__(self, dim_input, dim_output, rank=1, alpha=1, device='cpu'):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(rank, dim_output).to(device))
        self.lora_B = nn.Parameter(torch.zeros(dim_input, rank).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            return original_weights + (self.lora_B @ self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

#### Create a LoRA parameterization instance for a given linear layer by inspecting its weight shape

In [67]:
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    lora = LoRAParameterization(features_in, features_out, rank=rank, alpha=lora_alpha, device=device)
    return lora

RANK = 4
parametrize.register_parametrization(
    model.fc1, "weight", linear_layer_parameterization(model.fc1, device=DEVICE, rank=RANK)
)

parametrize.register_parametrization(
    model.fc2, "weight", linear_layer_parameterization(model.fc2, device=DEVICE, rank=RANK)
)

def enable_disable_lora(enabled=True):
    for layer in [model.fc1, model.fc2]:
        layer.parametrizations["weight"][0].enabled = enabled

#### Number of parameters added by LoRA

In [68]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.fc1, model.fc2]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.numel() + layer.parametrizations["weight"][0].lora_B.numel()
    total_parameters_non_lora += layer.weight.numel() + layer.bias.numel()
    print(f"Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + A: {layer.parametrizations['weight'][0].lora_A.shape} + B: {layer.parametrizations['weight'][0].lora_B.shape}")

print(f"Total parameters in LoRA: {total_parameters_lora}")
print(f"Total parameters in non-LoRA: {total_parameters_non_lora}")

Layer 1: W: torch.Size([128, 3136]) + B: torch.Size([128]) + A: torch.Size([4, 3136]) + B: torch.Size([128, 4])
Layer 2: W: torch.Size([10, 128]) + B: torch.Size([10]) + A: torch.Size([4, 128]) + B: torch.Size([10, 4])
Total parameters in LoRA: 13608
Total parameters in non-LoRA: 402826


#### Freeze all linear layer parameters and fine tune the LoRA parameters

In [72]:
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f"Freezing non-lora parameters: {name}")
        param.requires_grad = False

Freezing non-lora parameters: conv1.weight
Freezing non-lora parameters: conv1.bias
Freezing non-lora parameters: conv2.weight
Freezing non-lora parameters: conv2.bias
Freezing non-lora parameters: fc1.bias
Freezing non-lora parameters: fc1.parametrizations.weight.original
Freezing non-lora parameters: fc2.bias
Freezing non-lora parameters: fc2.parametrizations.weight.original


#### Load MNIST data and retrain the least accurate digit again using LoRA params

In [74]:
mnist_lora_train = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform)
exclude_indices = mnist_lora_train.targets == 8
mnist_lora_train.data = train_dataset.data[exclude_indices]
mnist_lora_train.targets = train_dataset.targets[exclude_indices]
mnist_lora_train_loader = DataLoader(mnist_lora_train, batch_size=BATCH_SIZE, shuffle=True)
mnist_lora_validation = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transform)
exclude_indices = mnist_lora_validation.targets == 8
mnist_lora_validation.data = validation_dataset.data[exclude_indices]
mnist_lora_validation.targets = validation_dataset.targets[exclude_indices]
mnist_lora_validation_loader = DataLoader(mnist_lora_validation, batch_size=BATCH_SIZE, shuffle=False)

In [75]:
train(mnist_lora_train_loader, mnist_lora_validation_loader)

Training Epoch 1/50: 100%|██████████| 183/183 [00:01<00:00, 114.10it/s]
Validation Epoch 1: 100%|██████████| 31/31 [00:00<00:00, 132.90it/s]


Epoch: 1 Validation Loss: 0.0596, Accuracy: 98.05%
Epoch 1/50, Train Loss: 5.3257, Validation Loss: 0.0596, Validation Accuracy: 98.05%
Model saved with validation loss: 0.0596


Training Epoch 2/50: 100%|██████████| 183/183 [00:01<00:00, 117.30it/s]
Validation Epoch 2: 100%|██████████| 31/31 [00:00<00:00, 127.96it/s]


Epoch: 2 Validation Loss: 0.0596, Accuracy: 98.05%
Epoch 2/50, Train Loss: 5.3175, Validation Loss: 0.0596, Validation Accuracy: 98.05%


Training Epoch 3/50: 100%|██████████| 183/183 [00:01<00:00, 113.52it/s]
Validation Epoch 3: 100%|██████████| 31/31 [00:00<00:00, 135.09it/s]


Epoch: 3 Validation Loss: 0.0596, Accuracy: 98.05%
Epoch 3/50, Train Loss: 5.3237, Validation Loss: 0.0596, Validation Accuracy: 98.05%


Training Epoch 4/50: 100%|██████████| 183/183 [00:01<00:00, 120.71it/s]
Validation Epoch 4: 100%|██████████| 31/31 [00:00<00:00, 138.48it/s]


Epoch: 4 Validation Loss: 0.0596, Accuracy: 98.05%
Epoch 4/50, Train Loss: 5.3211, Validation Loss: 0.0596, Validation Accuracy: 98.05%


Training Epoch 5/50: 100%|██████████| 183/183 [00:01<00:00, 119.82it/s]
Validation Epoch 5: 100%|██████████| 31/31 [00:00<00:00, 137.90it/s]


Epoch: 5 Validation Loss: 0.0596, Accuracy: 98.05%
Epoch 5/50, Train Loss: 5.3171, Validation Loss: 0.0596, Validation Accuracy: 98.05%


Training Epoch 6/50: 100%|██████████| 183/183 [00:01<00:00, 121.32it/s]
Validation Epoch 6: 100%|██████████| 31/31 [00:00<00:00, 131.21it/s]

Epoch: 6 Validation Loss: 0.0596, Accuracy: 98.05%
Epoch 6/50, Train Loss: 5.3196, Validation Loss: 0.0596, Validation Accuracy: 98.05%
Early stopping triggered





#### Verify fine-tuning didn't alter original weights but only the ones introduced by LoRA

In [77]:
# assert torch.all(model.fc1.parametrizations.weight.original == original_weights['fc1.weight'])
# assert torch.all(model.fc2.parametrizations.weight.original == original_weights['fc2.weight'])

enable_disable_lora(enabled=True)
assert torch.equal(model.fc1.weight, model.fc1.parametrizations.weight.original + (model.fc1.parametrizations.weight[0].lora_B @ model.fc1.parametrizations.weight[0].lora_A) * model.fc1.parametrizations.weight[0].scale)
enable_disable_lora(enabled=False)
assert torch.equal(model.fc1.weight, model.fc1.parametrizations.weight.original)

#### Test with LoRA enabled

In [78]:
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 313/313 [00:02<00:00, 136.09it/s]

Test Loss: 0.0319, Accuracy: 98.92%
Digit 0: 99.59% Incorrect: 4 Total: 980
Digit 1: 99.65% Incorrect: 4 Total: 1135
Digit 2: 99.42% Incorrect: 6 Total: 1032
Digit 3: 99.70% Incorrect: 3 Total: 1010
Digit 4: 98.57% Incorrect: 14 Total: 982
Digit 5: 98.32% Incorrect: 15 Total: 892
Digit 6: 98.64% Incorrect: 13 Total: 958
Digit 7: 98.93% Incorrect: 11 Total: 1028
Digit 8: 98.05% Incorrect: 19 Total: 974
Digit 9: 98.12% Incorrect: 19 Total: 1009





(0.031945871058609375, 98.92)

#### Test network with LoRA disabled (accuracy and error counts must be the same as the original network)

In [79]:
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 313/313 [00:02<00:00, 150.14it/s]

Test Loss: 0.0319, Accuracy: 98.92%
Digit 0: 99.59% Incorrect: 4 Total: 980
Digit 1: 99.65% Incorrect: 4 Total: 1135
Digit 2: 99.42% Incorrect: 6 Total: 1032
Digit 3: 99.70% Incorrect: 3 Total: 1010
Digit 4: 98.57% Incorrect: 14 Total: 982
Digit 5: 98.32% Incorrect: 15 Total: 892
Digit 6: 98.64% Incorrect: 13 Total: 958
Digit 7: 98.93% Incorrect: 11 Total: 1028
Digit 8: 98.05% Incorrect: 19 Total: 974
Digit 9: 98.12% Incorrect: 19 Total: 1009





(0.031945871058609375, 98.92)