In [1]:
from models.networks import DemoNet
import peft
import torch
from utils import get_dataset_torch, AverageMeter
from torchmetrics.classification import MulticlassAccuracy
from torchsummary import summary
from tqdm import tqdm
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader, valoader   = get_dataset_torch('FMNIST', 'data', 32)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:04<00:00, 5329890.27it/s]


Extracting data\FashionMNIST\raw\train-images-idx3-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 155700.97it/s]


Extracting data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:04<00:00, 933231.12it/s] 


Extracting data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<?, ?it/s]


Extracting data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw



In [3]:
lr = 1e-3
epochs = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ACCURACY = MulticlassAccuracy(num_classes=10).to(device)

In [4]:
def train(model, optimizer, criterion, trainloader, valoader, epochs):
    for epoch in range(epochs):
        model.train()
        data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')
        train_accuracy = AverageMeter()
        train_loss = AverageMeter()
        for _, data in data_loop_train:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            accuracy = ACCURACY(outputs, labels)

            train_accuracy.update(accuracy.item(), inputs.size(0))
            train_loss.update(loss.item(), inputs.size(0))
            data_loop_train.set_description(f'Epoch {epoch+1}/{epochs}')
            data_loop_train.set_postfix(loss=train_loss.avg, accuracy=train_accuracy.avg)

        with torch.no_grad():
            model.eval()
            data_loop_val = tqdm(enumerate(valoader), total=len(valoader), colour='green')
            val_accuracy = AverageMeter()
            val_loss = AverageMeter()
            for _, data in data_loop_val:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                accuracy = ACCURACY(outputs, labels)
                loss = criterion(outputs, labels)
                val_loss.update(loss.item(), inputs.size(0))
                val_accuracy.update(accuracy.item(), inputs.size(0))
                data_loop_val.set_description(f'Epoch {epoch+1}/{epochs}')
                data_loop_val.set_postfix(loss=val_loss.avg, accuracy=accuracy.item())

## **Train without lora**

In [5]:
model = DemoNet().to(device)
print(summary(model, (1, 28, 28)))
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 100, 28, 28]           1,000
            Conv2d-2          [-1, 100, 14, 14]          90,100
            Conv2d-3              [-1, 1, 7, 7]             901
            Linear-4                   [-1, 10]             500
Total params: 92,501
Trainable params: 92,501
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.75
Params size (MB): 0.35
Estimated Total Size (MB): 1.10
----------------------------------------------------------------
None


In [6]:
train(model, optimizer, criterion, trainloader, valoader, epochs)

Epoch 1/5: 100%|[31m██████████[0m| 1688/1688 [00:09<00:00, 171.71it/s, accuracy=0.787, loss=0.555]
Epoch 1/5: 100%|[32m██████████[0m| 188/188 [00:00<00:00, 221.13it/s, accuracy=0.75, loss=0.395] 
Epoch 2/5: 100%|[31m██████████[0m| 1688/1688 [00:17<00:00, 97.77it/s, accuracy=0.86, loss=0.367]  
Epoch 2/5: 100%|[32m██████████[0m| 188/188 [00:01<00:00, 165.91it/s, accuracy=0.889, loss=0.355]
Epoch 3/5: 100%|[31m██████████[0m| 1688/1688 [00:11<00:00, 143.25it/s, accuracy=0.873, loss=0.324]
Epoch 3/5: 100%|[32m██████████[0m| 188/188 [00:00<00:00, 207.11it/s, accuracy=0.952, loss=0.374]
Epoch 4/5: 100%|[31m██████████[0m| 1688/1688 [00:11<00:00, 149.49it/s, accuracy=0.884, loss=0.299]
Epoch 4/5: 100%|[32m██████████[0m| 188/188 [00:00<00:00, 218.35it/s, accuracy=0.833, loss=0.339]
Epoch 5/5: 100%|[31m██████████[0m| 1688/1688 [00:27<00:00, 61.45it/s, accuracy=0.893, loss=0.279] 
Epoch 5/5: 100%|[32m██████████[0m| 188/188 [00:05<00:00, 35.38it/s, accuracy=1, loss=0.331]    


## **Train with lora**

In [7]:
[(n, type(m)) for n, m in DemoNet().named_modules()]

[('', models.networks.DemoNet),
 ('conv1', torch.nn.modules.conv.Conv2d),
 ('conv2', torch.nn.modules.conv.Conv2d),
 ('conv3', torch.nn.modules.conv.Conv2d),
 ('linear', torch.nn.modules.linear.Linear)]

In [8]:
config = peft.LoraConfig(
    r=1,
    target_modules=["conv1", "linear"],
)

In [9]:
model = DemoNet().to(device)
model_copy = copy.deepcopy(model)

peft_model = peft.get_peft_model(model, config)
optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
peft_model.print_trainable_parameters() # previous wights + A, B matrices

trainable params: 168 || all params: 92,669 || trainable%: 0.1813


In [14]:
print(summary(peft_model, (1, 28, 28)))
print(summary(model_copy, (1, 28, 28)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 100, 28, 28]           1,000
          Identity-2            [-1, 1, 28, 28]               0
            Conv2d-3            [-1, 1, 28, 28]               9
            Conv2d-4          [-1, 100, 28, 28]             100
            Conv2d-5          [-1, 100, 28, 28]           1,000
            Conv2d-6          [-1, 100, 14, 14]          90,100
            Conv2d-7              [-1, 1, 7, 7]             901
            Linear-8                   [-1, 10]             500
          Identity-9                   [-1, 49]               0
           Linear-10                    [-1, 1]              49
           Linear-11                   [-1, 10]              10
           Linear-12                   [-1, 10]             500
          DemoNet-13                   [-1, 10]               0
Total params: 94,169
Trainable params: 

In [10]:
train(peft_model, optimizer, criterion, trainloader, valoader, epochs)

Epoch 1/5: 100%|[31m██████████[0m| 1688/1688 [01:03<00:00, 26.66it/s, accuracy=0.294, loss=1.58]
Epoch 1/5: 100%|[32m██████████[0m| 188/188 [00:05<00:00, 33.35it/s, accuracy=0.167, loss=1.47]
Epoch 2/5: 100%|[31m██████████[0m| 1688/1688 [01:02<00:00, 26.84it/s, accuracy=0.383, loss=1.44]
Epoch 2/5: 100%|[32m██████████[0m| 188/188 [00:05<00:00, 35.66it/s, accuracy=0.556, loss=1.38]
Epoch 3/5: 100%|[31m██████████[0m| 1688/1688 [01:04<00:00, 26.27it/s, accuracy=0.44, loss=1.37] 
Epoch 3/5: 100%|[32m██████████[0m| 188/188 [00:05<00:00, 33.72it/s, accuracy=0.407, loss=1.33]
Epoch 4/5: 100%|[31m██████████[0m| 1688/1688 [01:02<00:00, 27.04it/s, accuracy=0.466, loss=1.32]
Epoch 4/5: 100%|[32m██████████[0m| 188/188 [00:05<00:00, 32.03it/s, accuracy=0.407, loss=1.3] 
Epoch 5/5: 100%|[31m██████████[0m| 1688/1688 [01:00<00:00, 27.76it/s, accuracy=0.481, loss=1.29]
Epoch 5/5: 100%|[32m██████████[0m| 188/188 [00:05<00:00, 35.25it/s, accuracy=0.378, loss=1.29]


## **Check params**

In [11]:
for name, param in peft_model.base_model.named_parameters():
    if "lora" not in name:
        continue

    print(f"New parameter {name:<30} | {param.numel():>5} parameters | updated")

New parameter model.conv1.lora_A.default.weight |     9 parameters | updated
New parameter model.conv1.lora_B.default.weight |   100 parameters | updated
New parameter model.linear.lora_A.default.weight |    49 parameters | updated
New parameter model.linear.lora_B.default.weight |    10 parameters | updated


In [12]:
for name, param in peft_model.base_model.named_parameters():
    if "lora" in name:
        continue

    print(f"Parameter {name:<30} | {param.numel():>5} parameters | not updated")

Parameter model.conv1.base_layer.weight  |   900 parameters | not updated
Parameter model.conv1.base_layer.bias    |   100 parameters | not updated
Parameter model.conv2.weight             | 90000 parameters | not updated
Parameter model.conv2.bias               |   100 parameters | not updated
Parameter model.conv3.weight             |   900 parameters | not updated
Parameter model.conv3.bias               |     1 parameters | not updated
Parameter model.linear.base_layer.weight |   490 parameters | not updated
Parameter model.linear.base_layer.bias   |    10 parameters | not updated
