In [1]:
from models.networks import DemoNet
import peft
import torch
from utils import get_dataset_torch, AverageMeter
from torchmetrics.classification import MulticlassAccuracy
from torchsummary import summary
from tqdm import tqdm
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader, valoader   = get_dataset_torch('FMNIST', 'data', 32)

In [3]:
lr = 1e-3
epochs = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ACCURACY = MulticlassAccuracy(num_classes=10).to(device)

In [4]:
def train(model, optimizer, criterion, trainloader, valoader, epochs):
    for epoch in range(epochs):
        model.train()
        data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')
        train_accuracy = AverageMeter()
        train_loss = AverageMeter()
        for _, data in data_loop_train:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            accuracy = ACCURACY(outputs, labels)

            train_accuracy.update(accuracy.item(), inputs.size(0))
            train_loss.update(loss.item(), inputs.size(0))
            data_loop_train.set_description(f'Epoch {epoch+1}/{epochs}')
            data_loop_train.set_postfix(loss=train_loss.avg, accuracy=train_accuracy.avg)

        with torch.no_grad():
            model.eval()
            data_loop_val = tqdm(enumerate(valoader), total=len(valoader), colour='green')
            val_accuracy = AverageMeter()
            val_loss = AverageMeter()
            for _, data in data_loop_val:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                accuracy = ACCURACY(outputs, labels)
                loss = criterion(outputs, labels)
                val_loss.update(loss.item(), inputs.size(0))
                val_accuracy.update(accuracy.item(), inputs.size(0))
                data_loop_val.set_description(f'Epoch {epoch+1}/{epochs}')
                data_loop_val.set_postfix(loss=val_loss.avg, accuracy=accuracy.item())

## **Train without lora**

In [5]:
model = DemoNet().to(device)
print(summary(model, (1, 28, 28)))
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 100, 28, 28]           1,000
            Conv2d-2          [-1, 100, 14, 14]          90,100
            Conv2d-3              [-1, 1, 7, 7]             901
            Linear-4                   [-1, 10]             500
Total params: 92,501
Trainable params: 92,501
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.75
Params size (MB): 0.35
Estimated Total Size (MB): 1.10
----------------------------------------------------------------
None


In [6]:
train(model, optimizer, criterion, trainloader, valoader, epochs)

Epoch 1/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 80.42it/s, accuracy=0.781, loss=0.569]
Epoch 1/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 108.04it/s, accuracy=0.812, loss=0.397]
Epoch 2/5: 100%|[31m██████████[0m| 1688/1688 [00:19<00:00, 86.04it/s, accuracy=0.852, loss=0.382]
Epoch 2/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 102.02it/s, accuracy=0.398, loss=0.349]
Epoch 3/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 80.64it/s, accuracy=0.868, loss=0.337]
Epoch 3/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 105.98it/s, accuracy=0.952, loss=0.341]
Epoch 4/5: 100%|[31m██████████[0m| 1688/1688 [00:21<00:00, 78.27it/s, accuracy=0.88, loss=0.308] 
Epoch 4/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 108.36it/s, accuracy=0.714, loss=0.321]
Epoch 5/5: 100%|[31m██████████[0m| 1688/1688 [00:19<00:00, 85.37it/s, accuracy=0.888, loss=0.287]
Epoch 5/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 104.45it/s, accuracy=1, loss=0.31]     


## **Train with lora**

In [7]:
[(n, type(m)) for n, m in DemoNet().named_modules()]

[('', networks.DemoNet),
 ('conv1', torch.nn.modules.conv.Conv2d),
 ('conv2', torch.nn.modules.conv.Conv2d),
 ('conv3', torch.nn.modules.conv.Conv2d),
 ('linear', torch.nn.modules.linear.Linear)]

In [8]:
config = peft.LoraConfig(
    r=1,
    target_modules=["conv1", "linear"],
)

In [9]:
model = DemoNet().to(device)
model_copy = copy.deepcopy(model)
peft_model = peft.get_peft_model(model, config)
optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
peft_model.print_trainable_parameters() # previous wights + A, B matrices

trainable params: 168 || all params: 92,669 || trainable%: 0.1813


In [10]:
train(peft_model, optimizer, criterion, trainloader, valoader, epochs)

Epoch 1/5: 100%|[31m██████████[0m| 1688/1688 [00:22<00:00, 76.60it/s, accuracy=0.297, loss=1.72]
Epoch 1/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 99.35it/s, accuracy=0.308, loss=1.4]  
Epoch 2/5: 100%|[31m██████████[0m| 1688/1688 [00:21<00:00, 76.79it/s, accuracy=0.516, loss=1.28]
Epoch 2/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 97.70it/s, accuracy=0.537, loss=1.23]
Epoch 3/5: 100%|[31m██████████[0m| 1688/1688 [00:22<00:00, 75.44it/s, accuracy=0.546, loss=1.19]
Epoch 3/5: 100%|[32m██████████[0m| 188/188 [00:02<00:00, 91.26it/s, accuracy=0.317, loss=1.18]
Epoch 4/5: 100%|[31m██████████[0m| 1688/1688 [00:23<00:00, 72.50it/s, accuracy=0.562, loss=1.14]
Epoch 4/5: 100%|[32m██████████[0m| 188/188 [00:02<00:00, 91.47it/s, accuracy=0.567, loss=1.14] 
Epoch 5/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 80.64it/s, accuracy=0.585, loss=1.1]
Epoch 5/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 97.71it/s, accuracy=0.444, loss=1.11]


## **Check params**

In [11]:
for name, param in peft_model.base_model.named_parameters():
    if "lora" not in name:
        continue

    print(f"New parameter {name:<30} | {param.numel():>5} parameters | updated")

New parameter model.conv1.lora_A.default.weight |     9 parameters | updated
New parameter model.conv1.lora_B.default.weight |   100 parameters | updated
New parameter model.linear.lora_A.default.weight |    49 parameters | updated
New parameter model.linear.lora_B.default.weight |    10 parameters | updated


In [12]:
for name, param in peft_model.base_model.named_parameters():
    if "lora" in name:
        continue

    print(f"Parameter {name:<30} | {param.numel():>5} parameters | not updated")

Parameter model.conv1.base_layer.weight  |   900 parameters | not updated
Parameter model.conv1.base_layer.bias    |   100 parameters | not updated
Parameter model.conv2.weight             | 90000 parameters | not updated
Parameter model.conv2.bias               |   100 parameters | not updated
Parameter model.conv3.weight             |   900 parameters | not updated
Parameter model.conv3.bias               |     1 parameters | not updated
Parameter model.linear.base_layer.weight |   490 parameters | not updated
Parameter model.linear.base_layer.bias   |    10 parameters | not updated
