In [165]:
import torch.nn as nn
import torch 
import os

# Project root; os.chdir below moves the kernel there, presumably so the local
# modules imported next (model, datasets, trainer, paths) resolve — TODO confirm.
# NOTE(review): `...` is a placeholder — set a real path before running.
# NOTE(review): `dir` shadows the builtin dir(); kept because later cells read it.
dir = ...
os.chdir(dir)

# Re-import edited local modules automatically before executing each cell.
%reload_ext autoreload
%autoreload 2
from model import DigitClassification, LoRaParametrisation, LoRa_model, enable_disable_lora
from datasets import Datasets
from trainer import Trainer
from paths import *
import os
import copy

import warnings



In [166]:
# Create the output directory for this run (via the project helper
# create_results_dir) plus one sub-folder per entry in sub_dirs.
# dic_pth maps each sub-folder name to its path for use by later cells.
save_dir = create_results_dir(dir_root = dir)
sub_dirs = ["checkpoints", "results"]
dic_pth = {}
for sub_ in sub_dirs:
    s_path = os.path.join(save_dir, sub_)
    dic_pth[sub_] = s_path
    # Reuse the already-joined path rather than rebuilding it.
    os.makedirs(s_path, exist_ok = True)
    



In [167]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


### **Load training and testing datasets**

In [168]:
# Build the MNIST dataset and wrap its train / test splits in loaders of
# batch size 16. Warnings are silenced only while loading (not globally for
# the whole kernel session), so warnings raised later, e.g. during training,
# still surface.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    dataset = Datasets(dataset_name = "MNIST")
    # train set
    train_loader = dataset.get_dataloader(batch_size = 16)
    # test set
    test_loader = dataset.get_dataloader(batch_size = 16, train_status = False)

In [169]:
# Pull the first mini-batch from the training loader to sanity-check shapes.
iterator = iter(train_loader)
batch = next(iterator)
images, labels = batch

images.shape

torch.Size([16, 1, 28, 28])

### **Load Model**

In [170]:
# Images are 28x28 with a single channel (see the batch shape above); the
# classifier takes im_size*im_size inputs and predicts one of 10 digit classes.
im_size, classes = 28, 10
n_inputs = im_size ** 2
model = DigitClassification(input_size = n_inputs, output_size = classes)

### **Trainer module**

In [171]:
# Trainer for the base model on the full MNIST training set.
lr = 1e-3
# Checkpoint frequency handed to Trainer (unit defined by Trainer —
# presumably iterations; confirm). The keyword is spelled `save_ckeckpoint`
# (sic) to match Trainer's signature.
save_ckeckpoint = 100
path_checkpoint = dic_pth["checkpoints"]
trainer = Trainer(
    model, train_loader, test_loader,
    lr = lr, device = device,
    path_checkpoint = path_checkpoint,
    save_ckeckpoint = save_ckeckpoint,
)

### **Training**

In [None]:
# Train the base model for a single epoch.
# NOTE(review): this cell was saved with execution count None, so later
# outputs may not come from this exact run — re-run top to bottom.
n_epochs = 1
trainer.training_loop(epochs = n_epochs)

### **Copy weights**

In [173]:
# Snapshot every parameter of the trained base model (detached, independent
# storage), presumably for comparison after LoRA fine-tuning.
# `detach().clone()` is the recommended replacement for `.data.clone()`.
original_weights = {
    name: param.detach().clone()
    for name, param in model.named_parameters()
}

# Independent deep copy of the model; LoRA is applied to this copy later, so
# `model` itself keeps plain Linear layers.
L_model = copy.deepcopy(model)

In [174]:
# count parameters in the model
def count_parameters(model):
    """Return the total number of weight and bias elements in ``model.layers``.

    Every layer is counted regardless of ``requires_grad``. For a model whose
    base weights are frozen, this is therefore the size of the full network,
    not the number of trainable parameters.

    Args:
        model: object exposing an iterable ``layers`` whose items have a
            ``weight`` tensor and a ``bias`` that is a tensor or ``None``.

    Returns:
        int: total number of elements across all weights and biases.
    """
    tparas = 0
    for layer in model.layers:
        tparas += layer.weight.nelement()
        # Layers created with bias=False have bias None; they contribute 0.
        if layer.bias is not None:
            tparas += layer.bias.nelement()
    return tparas

# Size of the copied model before LoRA wrapping. count_parameters counts every
# weight/bias in model.layers, whether or not it requires grad.
print(f"Total trainable parameters of the model: {count_parameters(L_model)}") #type: ignore

Total trainable parameters of the model: 2807010


In [175]:
# Per-layer element counts: weight first, then bias.
for layer in L_model.layers:
    n_weight, n_bias = layer.weight.numel(), layer.bias.numel()
    print(n_weight, n_bias)

784000 1000
2000000 2000
20000 10


### **LoRa model**

We first load the specific dataset on which we would like to fine-tune our model. Setting `exclude_tgs=9` in `Datasets` specifies that the fine-tuning is done on the digit `9`.

In [176]:
# Build the fine-tuning dataset. Per the notes, `exclude_tgs=9` selects digit 9
# as the fine-tuning target (see Datasets for the exact semantics).
# Warnings are silenced only while loading, not globally for the session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    dataset_9 = Datasets(dataset_name = "MNIST", exclude_tgs=9)
    # train set
    train_loader_9 = dataset_9.get_dataloader(batch_size = 16)


Here we construct the LoRA model. All parameters of the original model are frozen, and for each layer we create two low-rank matrices whose parameters will be learned during fine-tuning.

In [177]:
# Wrap a copy of the trained base model with rank-2 LoRA adapters (LoRa_model
# also freezes the original weights, per its printed report).
# Building from a fresh deepcopy of `model`, instead of the existing `L_model`,
# keeps this cell idempotent. Re-running it would otherwise presumably register
# a second parametrisation on layers that are already wrapped.
L_model = LoRa_model(copy.deepcopy(model), rank = 2, device = device)

Number of Layers frozen: 6
Layer 1 -- Original Weights: 784000  Bias: 1000 ---- Lora_A weight: 2000 + Lora_B weight: 1568
Layer 2 -- Original Weights: 2000000  Bias: 2000 ---- Lora_A weight: 4000 + Lora_B weight: 2000
Layer 3 -- Original Weights: 20000  Bias: 10 ---- Lora_A weight: 20 + Lora_B weight: 4000
Total trainable parameters of the model: 2807010 (non-LoRa) vs 13588 (LoRa) Ratio: 0.48% of the original model


### **Trainer for LoRa model**

In [178]:
# Trainer for the LoRA-wrapped model: trains on the digit-9 loader and
# evaluates on the full MNIST test loader.
lr = 1e-3
# Checkpoint frequency handed to Trainer (unit defined by Trainer —
# presumably iterations; confirm). `save_ckeckpoint` (sic) matches Trainer.
save_ckeckpoint = 100
path_checkpoint = dic_pth["checkpoints"]
trainer_LoRa = Trainer(
    L_model, train_loader_9, test_loader,
    lr = lr, device = device,
    path_checkpoint = path_checkpoint,
    save_ckeckpoint = save_ckeckpoint,
)


Fine-tune the LoRA model on the digit-`9` training set.

In [None]:

# Fine-tune for a single epoch; the base weights were frozen by LoRa_model,
# so only the low-rank adapters should change.
n_epochs_lora = 1
trainer_LoRa.training_loop(epochs = n_epochs_lora)

In [187]:
# Evaluate the fine-tuned model on the full MNIST test loader.
# Returns a 2-tuple — presumably (mean loss, accuracy); confirm in Trainer.test_epoch.
# NOTE(review): saved as In[187], i.e. run after the enable_disable_lora cell
# (In[186]) that appears further down — re-run cells in order.
trainer_LoRa.test_epoch()

(4.200552679777146, 0.5629)

In [185]:
# Per-digit error counts recorded by the trainer; the index is the digit label.
# NOTE(review): saved as In[185], i.e. before the test_epoch cell above
# (In[187]), so the counts shown come from whatever earlier run populated
# wrong_counts — re-run in order to get consistent output.
for ii, w in enumerate(trainer_LoRa.wrong_counts):
    print(f"Misclassified {w} times for digit {ii}")

Misclassified 975 times for digit 0
Misclassified 10 times for digit 1
Misclassified 556 times for digit 2
Misclassified 454 times for digit 3
Misclassified 627 times for digit 4
Misclassified 451 times for digit 5
Misclassified 140 times for digit 6
Misclassified 598 times for digit 7
Misclassified 559 times for digit 8
Misclassified 1 times for digit 9


In [186]:
# Switch the LoRA adapters on for every wrapped layer (prints parameter names).
# NOTE(review): saved as In[186], i.e. executed before the test_epoch cell
# above (In[187]); move this cell above the evaluation to reproduce that order.
enable_disable_lora(L_model, enabled=True)

bias
parametrizations.weight.original
parametrizations.weight.0.lora_A
parametrizations.weight.0.lora_B
bias
parametrizations.weight.original
parametrizations.weight.0.lora_A
parametrizations.weight.0.lora_B
bias
parametrizations.weight.original
parametrizations.weight.0.lora_A
parametrizations.weight.0.lora_B


In [183]:
# Inspect the LoRA-wrapped copy: each Linear is now a ParametrizedLinear whose
# weight carries a LoRaParametrisation.
L_model

DigitClassification(
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (layers): ModuleList(
    (0): ParametrizedLinear(
      in_features=784, out_features=1000, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): LoRaParametrisation()
        )
      )
    )
    (1): ParametrizedLinear(
      in_features=1000, out_features=2000, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): LoRaParametrisation()
        )
      )
    )
    (2): ParametrizedLinear(
      in_features=2000, out_features=10, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): LoRaParametrisation()
        )
      )
    )
  )
)

# (An incomplete `aa =` statement stood here; it raised SyntaxError under Restart & Run All.)

In [140]:
# Base model for comparison: its layers are still plain nn.Linear modules,
# because LoRA was applied only to the deep copy held in L_model.
# NOTE(review): saved as In[140], i.e. executed well before the cells above.
model

DigitClassification(
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): Linear(in_features=1000, out_features=2000, bias=True)
    (2): Linear(in_features=2000, out_features=10, bias=True)
  )
)