In [11]:
from networks import DemoNet
import peft
import torch
from utils import get_dataset_torch, AverageMeter
from torchmetrics.classification import MulticlassAccuracy
from torchsummary import summary
from tqdm import tqdm
import copy

In [2]:
channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader, valoader   = get_dataset_torch('MNIST', 'data', 32)

In [3]:
lr = 1e-3
epochs = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ACCURACY = MulticlassAccuracy(num_classes=10).to(device)

In [4]:
def train(model, optimizer, criterion, trainloader, valoader, epochs):
    for epoch in range(epochs):
        model.train()
        data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')
        train_accuracy = AverageMeter()
        train_loss = AverageMeter()
        for _, data in data_loop_train:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            accuracy = ACCURACY(outputs, labels)

            train_accuracy.update(accuracy.item(), inputs.size(0))
            train_loss.update(loss.item(), inputs.size(0))
            data_loop_train.set_description(f'Epoch {epoch+1}/{epochs}')
            data_loop_train.set_postfix(loss=train_loss.avg, accuracy=train_accuracy.avg)

        with torch.no_grad():
            model.eval()
            data_loop_val = tqdm(enumerate(valoader), total=len(valoader), colour='green')
            val_accuracy = AverageMeter()
            val_loss = AverageMeter()
            for _, data in data_loop_val:
                inputs, labels = data
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = model(inputs)
                accuracy = ACCURACY(outputs, labels)
                loss = criterion(outputs, labels)
                val_loss.update(loss.item(), inputs.size(0))
                val_accuracy.update(accuracy.item(), inputs.size(0))
                data_loop_val.set_description(f'Epoch {epoch+1}/{epochs}')
                data_loop_val.set_postfix(loss=val_loss.avg, accuracy=accuracy.item())

    torch.save(model.state_dict(), 'model_without_lora.pth')

## **Train without lora**

In [13]:
model = DemoNet().to(device)
print(summary(model, (1, 28, 28)))
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 2, 28, 28]              20
            Linear-2                   [-1, 10]           3,930
Total params: 3,950
Trainable params: 3,950
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.02
Estimated Total Size (MB): 0.03
----------------------------------------------------------------
None


In [6]:
train(model, optimizer, criterion, trainloader, valoader, epochs)

Epoch 1/5: 100%|[31m██████████[0m| 1688/1688 [00:19<00:00, 88.43it/s, accuracy=0.867, loss=0.454]
Epoch 1/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 107.74it/s, accuracy=0.738, loss=0.264]
Epoch 2/5: 100%|[31m██████████[0m| 1688/1688 [00:19<00:00, 85.84it/s, accuracy=0.924, loss=0.228]
Epoch 2/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 105.22it/s, accuracy=0.817, loss=0.213]
Epoch 3/5: 100%|[31m██████████[0m| 1688/1688 [00:19<00:00, 85.48it/s, accuracy=0.935, loss=0.196]
Epoch 3/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 103.21it/s, accuracy=0.9, loss=0.195]  
Epoch 4/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 83.51it/s, accuracy=0.941, loss=0.179]
Epoch 4/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 100.71it/s, accuracy=1, loss=0.186]    
Epoch 5/5: 100%|[31m██████████[0m| 1688/1688 [00:18<00:00, 90.19it/s, accuracy=0.942, loss=0.171]
Epoch 5/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 104.01it/s, accuracy=0.812, loss=0.172]


## **Train with lora**

In [8]:
[(n, type(m)) for n, m in DemoNet().named_modules()]

[('', networks.DemoNet),
 ('conv1', torch.nn.modules.conv.Conv2d),
 ('linear', torch.nn.modules.linear.Linear)]

In [9]:
config = peft.LoraConfig(
    r=8,
    target_modules=["conv1", "linear"],
)

In [15]:
model = DemoNet().to(device)
model_copy = copy.deepcopy(model)
peft_model = peft.get_peft_model(model, config)
optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
peft_model.print_trainable_parameters()

trainable params: 3,304 || all params: 7,254 || trainable%: 45.5473


In [16]:
train(peft_model, optimizer, criterion, trainloader, valoader, epochs)

Epoch 1/5: 100%|[31m██████████[0m| 1688/1688 [00:19<00:00, 85.57it/s, accuracy=0.807, loss=0.58] 
Epoch 1/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 111.40it/s, accuracy=1, loss=0.367]    
Epoch 2/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 81.79it/s, accuracy=0.89, loss=0.336] 
Epoch 2/5: 100%|[32m██████████[0m| 188/188 [00:02<00:00, 89.94it/s, accuracy=0.889, loss=0.32] 
Epoch 3/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 81.15it/s, accuracy=0.901, loss=0.312]
Epoch 3/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 97.93it/s, accuracy=0.906, loss=0.319] 
Epoch 4/5: 100%|[31m██████████[0m| 1688/1688 [00:21<00:00, 79.66it/s, accuracy=0.901, loss=0.303]
Epoch 4/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 100.56it/s, accuracy=1, loss=0.309]    
Epoch 5/5: 100%|[31m██████████[0m| 1688/1688 [00:20<00:00, 84.05it/s, accuracy=0.905, loss=0.296]
Epoch 5/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 94.77it/s, accuracy=1, loss=0.302]    


## **Check params**

In [24]:
for name, param in peft_model.base_model.named_parameters():
    if "lora" not in name:
        continue

    print(f"New parameter {name:<30} | {param.numel():>5} parameters | updated")

New parameter model.conv1.lora_A.default.weight |    72 parameters | updated
New parameter model.conv1.lora_B.default.weight |    16 parameters | updated
New parameter model.linear.lora_A.default.weight |  3136 parameters | updated
New parameter model.linear.lora_B.default.weight |    80 parameters | updated


In [25]:
params_before = dict(model_copy.named_parameters())
for name, param in peft_model.base_model.named_parameters():
    if "lora" in name:
        continue

    print(f"Parameter {name:<30} | {param.numel():>5} parameters | not updated")

Parameter model.conv1.base_layer.weight  |    18 parameters | not updated
Parameter model.conv1.base_layer.bias    |     2 parameters | not updated
Parameter model.linear.base_layer.weight |  3920 parameters | not updated
Parameter model.linear.base_layer.bias   |    10 parameters | not updated
