In [1]:
import torch
import torch.nn as nn
from torchsummary import summary
import torch.nn.utils.parametrize as parametrize

In [2]:
import data_loader
from mnist_model import MNIST_MODEL
import config
import train
import test
import model_utils
import lora_parameterization
from fine_tune import fine_tune

In [3]:
_ = torch.manual_seed(42)

In [4]:
train_dataset = data_loader.load_mnist_dataset(train=True)
train_data_loader = data_loader.data_loader(dataset=train_dataset)

test_dataset = data_loader.load_mnist_dataset(train=False)
test_data_loader = data_loader.data_loader(dataset=test_dataset)

In [5]:
model = MNIST_MODEL()
summary(model, input_size=(1,28,28))
model = model.to(config.DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1000]         785,000
              ReLU-2                 [-1, 1000]               0
            Linear-3                 [-1, 2000]       2,002,000
              ReLU-4                 [-1, 2000]               0
            Linear-5                   [-1, 10]          20,010
              ReLU-6                   [-1, 10]               0
Total params: 2,807,010
Trainable params: 2,807,010
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 10.71
Estimated Total Size (MB): 10.76
----------------------------------------------------------------


#### Train the model

In [6]:
train.train(model=model, data_loader=train_data_loader, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:50<00:00, 118.15it/s, loss=2.07]


#### The Trained model's original weights are preserved, this is represented by `W` matrix

In [7]:
weights_dict = model_utils.extract_weights(model=model)

#### Test the model

In [8]:
test.test(model=model, data_loader=test_data_loader)

Testing: 100%|██████████| 1000/1000 [00:10<00:00, 97.61it/s]

Accuracy : 0.21
Wrong Counts for the digit 0 : 1
Wrong Counts for the digit 1 : 13
Wrong Counts for the digit 2 : 1032
Wrong Counts for the digit 3 : 1010
Wrong Counts for the digit 4 : 982
Wrong Counts for the digit 5 : 892
Wrong Counts for the digit 6 : 958
Wrong Counts for the digit 7 : 1028
Wrong Counts for the digit 8 : 974
Wrong Counts for the digit 9 : 1009





#### Let's look at the parameters for each layer in the original network, before applying LoRA matrices

In [9]:
original_parameters_count = model_utils.count_total_parameters(
    model_layers=[model.linear_layer_1, model.linear_layer_2, model.linear_layer_3])
print(f'Total number of parameters: {original_parameters_count:,}')

Layer 1: W (Matrix): torch.Size([1000, 784]) + Bias : torch.Size([1000])
Layer 2: W (Matrix): torch.Size([2000, 1000]) + Bias : torch.Size([2000])
Layer 3: W (Matrix): torch.Size([10, 2000]) + Bias : torch.Size([10])
Total number of parameters: 2,807,010


#### Adding the LoRA parameterization to the trained Model

In [10]:
parametrize.register_parametrization(
    model.linear_layer_1, "weight", lora_parameterization.layer_parameterization(layer=model.linear_layer_1)
)

parametrize.register_parametrization(
    model.linear_layer_2, "weight", lora_parameterization.layer_parameterization(layer=model.linear_layer_2)
)

parametrize.register_parametrization(
    model.linear_layer_3, "weight", lora_parameterization.layer_parameterization(layer=model.linear_layer_3)
)

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA_Parameterization()
    )
  )
)

##### Let's visualize the Number of Parameters added by LoRA

In [11]:
lora_parameters_count = 0
non_lora_parameters_count = 0

for index, layer in enumerate([model.linear_layer_1, model.linear_layer_2, model.linear_layer_3]):
    lora_parameters_count += layer.parametrizations["weight"][0].LoRA_Matrix_A.nelement() + layer.parametrizations["weight"][0].LoRA_Matrix_B.nelement()
    non_lora_parameters_count += layer.weight.nelement() + layer.bias.nelement()

    print(f'Layer - {index+1}: W (weight matrix) : {layer.weight.shape} + B (bias): {layer.bias.shape} + LoRA A Matrix: {layer.parametrizations["weight"][0].LoRA_Matrix_A.shape}+ LoRA B Matrix: {layer.parametrizations["weight"][0].LoRA_Matrix_B.shape}')

Layer - 1: W (weight matrix) : torch.Size([1000, 784]) + B (bias): torch.Size([1000]) + LoRA A Matrix: torch.Size([1, 784])+ LoRA B Matrix: torch.Size([1000, 1])
Layer - 2: W (weight matrix) : torch.Size([2000, 1000]) + B (bias): torch.Size([2000]) + LoRA A Matrix: torch.Size([1, 1000])+ LoRA B Matrix: torch.Size([2000, 1])
Layer - 3: W (weight matrix) : torch.Size([10, 2000]) + B (bias): torch.Size([10]) + LoRA A Matrix: torch.Size([1, 2000])+ LoRA B Matrix: torch.Size([10, 1])


##### Validate whether the parameter counts were distrubed to the original trained model or not

In [12]:
assert non_lora_parameters_count == original_parameters_count
print(f'Original Parameters count : {original_parameters_count}')
print(f'Total Parameters after registering LoRA: Original Parameters Count + LoRA Parameters Count = {original_parameters_count + lora_parameters_count}')
print(f'Parameters introduced by LoRA : {lora_parameters_count}')

parameters_incremented_by_lora = (lora_parameters_count / non_lora_parameters_count) * 100
print(f'Parameters incremented by LoRA : {parameters_incremented_by_lora:.3f}%')


Original Parameters count : 2807010
Total Parameters after registering LoRA: Original Parameters Count + LoRA Parameters Count = 2813804
Parameters introduced by LoRA : 6794
Parameters incremented by LoRA : 0.242%


#### Now, lets work on improving the accuracy for the digit `2`. To do that, lets freeze all the Parameters in the original model and finetune only the parameters introduced by LoRA. The finetuning is done only on the digit `2`

In [13]:
model_utils.freeze_model_parameters(model=model, unfreeze_layer="LoRA")

Freezing Non LoRA layer parameter linear_layer_1.bias
Freezing Non LoRA layer parameter linear_layer_1.parametrizations.weight.original
Unfreeze Layer found :  linear_layer_1.parametrizations.weight.0.LoRA_Matrix_A
Unfreeze Layer found :  linear_layer_1.parametrizations.weight.0.LoRA_Matrix_B
Freezing Non LoRA layer parameter linear_layer_2.bias
Freezing Non LoRA layer parameter linear_layer_2.parametrizations.weight.original
Unfreeze Layer found :  linear_layer_2.parametrizations.weight.0.LoRA_Matrix_A
Unfreeze Layer found :  linear_layer_2.parametrizations.weight.0.LoRA_Matrix_B
Freezing Non LoRA layer parameter linear_layer_3.bias
Freezing Non LoRA layer parameter linear_layer_3.parametrizations.weight.original
Unfreeze Layer found :  linear_layer_3.parametrizations.weight.0.LoRA_Matrix_A
Unfreeze Layer found :  linear_layer_3.parametrizations.weight.0.LoRA_Matrix_B


In [14]:
train_dataset = data_loader.load_mnist_dataset(train=True)
indices_digit_2 = train_dataset.targets == 2
train_dataset.data = train_dataset.data[indices_digit_2]
train_dataset.targets = train_dataset.targets[indices_digit_2]
train_data_loader = data_loader.data_loader(dataset=train_dataset)

#### Finetuning the model with the updated parameters

In [15]:
fine_tune(epochs=1, model=model, total_iterations_limit=100, train_loader=train_data_loader)

Epoch 1:  99%|█████████▉| 99/100 [00:01<00:00, 62.66it/s, loss=2.32] 


In [16]:
assert torch.all(model.linear_layer_1.parametrizations.weight.original == weights_dict['linear_layer_1.weight'])
assert torch.all(model.linear_layer_2.parametrizations.weight.original == weights_dict['linear_layer_2.weight'])
assert torch.all(model.linear_layer_3.parametrizations.weight.original == weights_dict['linear_layer_3.weight'])

lora_parameterization.enable_or_disable_lora(enabled=True, model_layers=[model.linear_layer_1, model.linear_layer_2, model.linear_layer_3])
assert torch.equal(model.linear_layer_1.weight, model.linear_layer_1.parametrizations.weight.original + 
                   (model.linear_layer_1.parametrizations.weight[0].LoRA_Matrix_B 
                    @ model.linear_layer_1.parametrizations.weight[0].LoRA_Matrix_A) 
                   * model.linear_layer_1.parametrizations.weight[0].scale)

lora_parameterization.enable_or_disable_lora(enabled=False, model_layers=[model.linear_layer_1, model.linear_layer_2, model.linear_layer_3])
assert torch.equal(model.linear_layer_1.weight, weights_dict['linear_layer_1.weight'])

In [17]:
model

MNIST_MODEL(
  (linear_layer_1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRA_Parameterization()
      )
    )
  )
  (linear_layer_2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRA_Parameterization()
      )
    )
  )
  (linear_layer_3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRA_Parameterization()
      )
    )
  )
  (relu): ReLU()
)

In [18]:
lora_parameterization.enable_or_disable_lora(enabled=True, 
                                             model_layers=[model.linear_layer_1, model.linear_layer_2, model.linear_layer_3])
test.test(model=model, data_loader=test_data_loader)

Testing: 100%|██████████| 1000/1000 [00:11<00:00, 86.23it/s]

Accuracy : 0.21
Wrong Counts for the digit 0 : 1
Wrong Counts for the digit 1 : 13
Wrong Counts for the digit 2 : 1032
Wrong Counts for the digit 3 : 1010
Wrong Counts for the digit 4 : 982
Wrong Counts for the digit 5 : 892
Wrong Counts for the digit 6 : 958
Wrong Counts for the digit 7 : 1028
Wrong Counts for the digit 8 : 974
Wrong Counts for the digit 9 : 1009





In [19]:
lora_parameterization.enable_or_disable_lora(enabled=False, 
                                             model_layers=[model.linear_layer_1, model.linear_layer_2, model.linear_layer_3])
test.test(model=model, data_loader=test_data_loader)

Testing: 100%|██████████| 1000/1000 [00:10<00:00, 95.50it/s]

Accuracy : 0.21
Wrong Counts for the digit 0 : 1
Wrong Counts for the digit 1 : 13
Wrong Counts for the digit 2 : 1032
Wrong Counts for the digit 3 : 1010
Wrong Counts for the digit 4 : 982
Wrong Counts for the digit 5 : 892
Wrong Counts for the digit 6 : 958
Wrong Counts for the digit 7 : 1028
Wrong Counts for the digit 8 : 974
Wrong Counts for the digit 9 : 1009



