In [2]:
import torch
import torch.nn as nn
from torchsummary import summary
import torch.nn.utils.parametrize as parametrize
from tqdm import tqdm

In [3]:
import data_loader
from mnist_model import MNIST_MODEL
import config
import train
import test
import model_utils
from lora_model import LoRA_Parameterization

In [4]:
SEED = 42  # global seed for torch's RNG; change here to re-run with a different seed
# torch.manual_seed returns a torch.Generator; binding it to `_` keeps the cell from echoing it.
_ = torch.manual_seed(SEED)

#### Data Loader

In [5]:
# Unfiltered MNIST splits wrapped by the project's loader helper: the train split is used to
# train the base model, the test split for both evaluations later in the notebook.
# NOTE(review): batch size / shuffling are decided inside data_loader.data_loader — confirm there.
train_dataset = data_loader.load_mnist_dataset(train=True)
train_data_loader = data_loader.data_loader(dataset=train_dataset)

test_dataset = data_loader.load_mnist_dataset(train=False)
test_data_loader = data_loader.data_loader(dataset=test_dataset)

#### Model Initialization

In [6]:
model = MNIST_MODEL()
# torchsummary.summary defaults to device="cuda" and, when CUDA is available, feeds CUDA
# tensors through the model. The model is only moved to config.DEVICE on the next line
# (so it is presumably still on the CPU here); pin the summary to the CPU so this cell
# does not fail with a device mismatch on GPU machines.
summary(model, input_size=(1, 28, 28), device="cpu")
model = model.to(config.DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1000]         785,000
              ReLU-2                 [-1, 1000]               0
            Linear-3                 [-1, 2000]       2,002,000
              ReLU-4                 [-1, 2000]               0
            Linear-5                   [-1, 10]          20,010
              ReLU-6                   [-1, 10]               0
Total params: 2,807,010
Trainable params: 2,807,010
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 10.71
Estimated Total Size (MB): 10.76
----------------------------------------------------------------


#### Model Training

In [7]:
# Train the base model for a single epoch over the unfiltered MNIST training split.
train.train(data_loader=train_data_loader, epochs=1, model=model)

Epoch 1: 100%|██████████| 6000/6000 [00:48<00:00, 124.20it/s, loss=2.07]


##### Clone Model Weights

In [8]:
# Snapshot of the trained base-model weights, taken before any LoRA parametrization is
# registered. NOTE(review): whether extract_weights returns copies or live references is
# defined in model_utils — confirm it clones, as the section title implies.
weights_dict = model_utils.extract_weights(model=model)

#### Model Testing

In [9]:
# Baseline evaluation of the base model on the MNIST test split, before LoRA is attached.
test.test(data_loader=test_data_loader, model=model)

Testing: 100%|██████████| 1000/1000 [00:08<00:00, 112.49it/s]

Accuracy : 0.21
Wrong Counts for the digit 0 : 1
Wrong Counts for the digit 1 : 13
Wrong Counts for the digit 2 : 1032
Wrong Counts for the digit 3 : 1010
Wrong Counts for the digit 4 : 982
Wrong Counts for the digit 5 : 892
Wrong Counts for the digit 6 : 958
Wrong Counts for the digit 7 : 1028
Wrong Counts for the digit 8 : 974
Wrong Counts for the digit 9 : 1009





#### Number of Parameters in the Model Before Adding LoRA

In [10]:
# Count weights + biases of the three linear layers before any LoRA parameters exist.
original_parameters_count = model_utils.count_total_parameters(
    model_layers=[
        model.linear_layer_1,
        model.linear_layer_2,
        model.linear_layer_3,
    ]
)
print('Total number of parameters: ' + format(original_parameters_count, ','))

Layer 1: W (Matrix): torch.Size([1000, 784]) + Bias : torch.Size([1000])
Layer 2: W (Matrix): torch.Size([2000, 1000]) + Bias : torch.Size([2000])
Layer 3: W (Matrix): torch.Size([10, 2000]) + Bias : torch.Size([10])
Total number of parameters: 2,807,010


#### Linear Layer Parameterization

In [11]:
def linear_layer_parameterization(layer, device):
    """Build a LoRA parametrization sized to ``layer``'s weight matrix.

    Args:
        layer: module with a 2-D ``weight`` (e.g. ``nn.Linear``); the weight shape
            ``(d, k)`` sizes the LoRA parametrization.
        device: device the LoRA parameters are moved to. It should match the device of
            ``layer`` so the parametrized forward pass does not mix devices. ``None``
            leaves the parametrization where it was created.

    Returns:
        A ``LoRA_Parameterization`` built with ``rank=config.RANK`` and
        ``alpha=config.ALPHA``, placed on ``device``.
    """
    d, k = layer.weight.shape
    lora = LoRA_Parameterization(d=d, k=k, rank=config.RANK, alpha=config.ALPHA)
    # Previously `device` was accepted but ignored, so the LoRA parameters stayed on their
    # creation device even though the model lives on config.DEVICE.
    # register_parametrization requires an nn.Module, so `.to` is available here.
    return lora if device is None else lora.to(device)

##### Registering the Parameterization

In [12]:
# Attach a LoRA parametrization to the weight of each linear layer, in order 1 -> 3.
for lora_target in (model.linear_layer_1, model.linear_layer_2, model.linear_layer_3):
    parametrize.register_parametrization(
        lora_target,
        "weight",
        linear_layer_parameterization(lora_target, device=config.DEVICE),
    )
# Echo the last parametrized layer, as the unrolled version of this cell did.
model.linear_layer_3

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA_Parameterization()
    )
  )
)

In [13]:
def enable_or_disable_lora(enabled=True, layers=None):
    """Set the ``enabled`` flag on the LoRA parametrization of each given layer.

    Args:
        enabled: value written to every LoRA parametrization's ``enabled`` attribute
            (presumably gating whether the low-rank update is applied — see
            LoRA_Parameterization).
        layers: modules whose ``weight`` carries a LoRA parametrization. Defaults to the
            three linear layers of the notebook-global ``model`` (the original implicit
            behavior); pass them explicitly to avoid depending on that global.
    """
    if layers is None:
        layers = (model.linear_layer_1, model.linear_layer_2, model.linear_layer_3)
    for layer in layers:
        # parametrizations["weight"] is a ParametrizationList; entry 0 is the
        # LoRA_Parameterization registered on that weight.
        layer.parametrizations["weight"][0].enabled = enabled

#### Freeze Model Parameters

In [14]:
# Freeze everything except the LoRA factors (i.e. the biases and the original weights kept
# by the parametrizations). NOTE(review): relies on LoRA_Parameterization naming its
# parameters with "LoRA" in the attribute name.
for param_name, parameter in model.named_parameters():
    if 'LoRA' in param_name:
        continue  # LoRA parameters stay trainable
    print(f'Freezing the Non-LoRA parameter {param_name}')
    parameter.requires_grad = False

Freezing the Non-LoRA parameter linear_layer_1.bias
Freezing the Non-LoRA parameter linear_layer_1.parametrizations.weight.original
Freezing the Non-LoRA parameter linear_layer_2.bias
Freezing the Non-LoRA parameter linear_layer_2.parametrizations.weight.original
Freezing the Non-LoRA parameter linear_layer_3.bias
Freezing the Non-LoRA parameter linear_layer_3.parametrizations.weight.original


#### Prepare a Domain-Specific Dataset (Digit 9 Only) for Fine-Tuning the Model

In [15]:
# Domain-specific subset: the MNIST training split restricted to samples labelled 9.
version_2_dataset = data_loader.load_mnist_dataset(train=True)
dataset_with_index_2 = version_2_dataset.targets == 9  # boolean mask over samples
# Filter images and labels with the same mask so they stay aligned.
version_2_dataset.data, version_2_dataset.targets = (
    version_2_dataset.data[dataset_with_index_2],
    version_2_dataset.targets[dataset_with_index_2],
)
version_2_train_loader = data_loader.data_loader(dataset=version_2_dataset)

#### Model Fine-tuning

In [16]:
def fine_tune(train_loader, model, epochs=5, total_iterations_limit=None, device=None):
    """Update ``model`` in place with Adam (lr=0.001) and cross-entropy loss.

    Only parameters with ``requires_grad=True`` are given to the optimizer, so after the
    freezing cell only the LoRA factors are trained.

    Args:
        train_loader: iterable of ``(x, y)`` batches. ``x`` is flattened with
            ``view(-1, 28*28)`` before the forward pass; ``y`` is the target passed to
            ``nn.CrossEntropyLoss``.
        model: network to fine-tune.
        epochs: maximum number of passes over ``train_loader``.
        total_iterations_limit: if not ``None``, stop after this many optimizer steps in
            total, counted across epochs. A value ``<= 0`` performs no steps.
        device: device batches are moved to; defaults to ``config.DEVICE``.

    Raises:
        ValueError: from ``torch.optim.Adam`` if ``model`` has no trainable parameters.
    """
    if device is None:
        device = config.DEVICE
    if total_iterations_limit is not None and total_iterations_limit <= 0:
        return  # exhausted budget: no optimizer steps at all

    cross_el = nn.CrossEntropyLoss()
    # Frozen parameters never receive gradients; keep them out of the optimizer entirely.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=0.001)

    try:
        batches_per_epoch = len(train_loader)
    except TypeError:  # loaders without a length (e.g. iterable datasets)
        batches_per_epoch = None

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0.0
        num_iterations = 0

        # Size the bar by the steps this epoch can actually take: the remaining budget,
        # not the whole limit every epoch.
        epoch_total = batches_per_epoch
        if total_iterations_limit is not None:
            remaining = total_iterations_limit - total_iterations
            epoch_total = remaining if epoch_total is None else min(epoch_total, remaining)

        # Manual updates count each completed step (an iterator-driven bar lags by one on
        # early return), and the context manager closes the bar on every exit path.
        with tqdm(total=epoch_total, desc=f'Epoch {epoch+1}') as progress:
            for x, y in train_loader:
                num_iterations += 1
                total_iterations += 1
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()
                output = model(x.view(-1, 28*28))
                loss = cross_el(output, y)
                loss_sum += loss.item()
                loss.backward()
                optimizer.step()

                progress.update(1)
                progress.set_postfix(loss=loss_sum / num_iterations)

                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return

# Brief LoRA fine-tune on the digit-9 subset: a single epoch, capped at 100 optimizer steps.
fine_tune(version_2_train_loader, model, total_iterations_limit=100, epochs=1)

Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 100.53it/s, loss=2.31]


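#### Test the Domain-Specific Fine-Tuned Model

Note: in the recorded run below, the accuracy (0.21) and every per-digit wrong count are identical to the baseline test before fine-tuning. The 100-step LoRA fine-tune therefore had no measurable effect. **FIXME**: investigate. The base model itself only reaches 0.21 accuracy, and the model summary shows a ReLU after the final Linear layer, which may be clamping the logits and blocking useful gradients.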

In [17]:
# Make sure the LoRA update is switched on, then evaluate on the same MNIST test split
# as the baseline so the two results are directly comparable.
enable_or_disable_lora(True)
test.test(model=model, data_loader=test_data_loader)

Testing: 100%|██████████| 1000/1000 [00:11<00:00, 86.79it/s]

Accuracy : 0.21
Wrong Counts for the digit 0 : 1
Wrong Counts for the digit 1 : 13
Wrong Counts for the digit 2 : 1032
Wrong Counts for the digit 3 : 1010
Wrong Counts for the digit 4 : 982
Wrong Counts for the digit 5 : 892
Wrong Counts for the digit 6 : 958
Wrong Counts for the digit 7 : 1028
Wrong Counts for the digit 8 : 974
Wrong Counts for the digit 9 : 1009





In [None]:
# Display the model: its linear layers now appear as ParametrizedLinear modules with a
# LoRA_Parameterization registered on their weight.
model