In [1]:
from networks import DemoNet
import peft
import torch
from utils import get_dataset_torch, AverageMeter
from torchmetrics.classification import MulticlassAccuracy
from torchsummary import summary
from tqdm import tqdm
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the 'FMNIST' dataset through the project helper and unpack everything it
# returns: dataset metadata, the train/test datasets and the three data loaders.
# NOTE(review): 'data' is presumably the download/cache directory and 32 the
# batch size -- confirm against utils.get_dataset_torch.
(
    channel,
    im_size,
    num_classes,
    class_names,
    dst_train,
    dst_test,
    testloader,
    trainloader,
    valoader,
) = get_dataset_torch('FMNIST', 'data', 32)

In [None]:
# Optimisation hyper-parameters shared by both training runs in this notebook.
lr = 1e-3
epochs = 5

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Accuracy metric called on every batch inside train(). The class count comes
# from the dataset loader rather than a hard-coded 10, so the metric stays
# consistent with whatever dataset was loaded above.
# NOTE(review): torchmetrics documents average='macro' as the default for
# MulticlassAccuracy; if plain (micro) accuracy is what should be reported,
# pass average='micro' explicitly -- confirm the intent.
ACCURACY = MulticlassAccuracy(num_classes=num_classes).to(device)

In [None]:
def train(model, optimizer, criterion, trainloader, valoader, epochs):
    """Train ``model`` for ``epochs`` epochs, running a validation pass after each.

    Progress is reported only through tqdm progress bars (red for training,
    green for validation); nothing is returned.

    Relies on two notebook-level globals: ``device`` (where batches are moved)
    and ``ACCURACY`` (the metric called on each batch of outputs and labels).

    Args:
        model: Module to optimise; called as ``model(inputs)``.
        optimizer: Optimiser stepped once per training batch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        trainloader: Iterable of ``(inputs, labels)`` batches used for training.
        valoader: Iterable of ``(inputs, labels)`` batches used for validation.
        epochs: Number of passes over ``trainloader``.
    """
    for epoch in range(epochs):
        epoch_label = f'Epoch {epoch+1}/{epochs}'

        # ---- training pass ----
        model.train()
        data_loop_train = tqdm(trainloader, total=len(trainloader), colour='red')
        # The label is constant for the whole pass, so set it once up front.
        data_loop_train.set_description(epoch_label)
        # NOTE(review): AverageMeter comes from utils; it is used here as a
        # running mean weighted by the second argument of update() -- confirm.
        train_accuracy = AverageMeter()
        train_loss = AverageMeter()
        for inputs, labels in data_loop_train:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metric is computed on the outputs produced before this step.
            accuracy = ACCURACY(outputs, labels)

            train_accuracy.update(accuracy.item(), inputs.size(0))
            train_loss.update(loss.item(), inputs.size(0))
            data_loop_train.set_postfix(loss=train_loss.avg, accuracy=train_accuracy.avg)

        # ---- validation pass (no gradients, eval-mode layers) ----
        model.eval()
        with torch.no_grad():
            data_loop_val = tqdm(valoader, total=len(valoader), colour='green')
            data_loop_val.set_description(epoch_label)
            val_accuracy = AverageMeter()
            val_loss = AverageMeter()
            for inputs, labels in data_loop_val:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                accuracy = ACCURACY(outputs, labels)
                loss = criterion(outputs, labels)
                val_loss.update(loss.item(), inputs.size(0))
                val_accuracy.update(accuracy.item(), inputs.size(0))
                # Show the running average, matching the training bar; the
                # previous code displayed only the last batch's accuracy.
                data_loop_val.set_postfix(loss=val_loss.avg, accuracy=val_accuracy.avg)

## **Train without LoRA**

In [None]:
# Baseline run: a fresh DemoNet whose full parameter set is handed to Adam.
model = DemoNet().to(device)
# torchsummary's summary() prints the layer table itself; wrapping it in
# print() only appended its return value to the cell output.
summary(model, (1, 28, 28))
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Fit the baseline model with the optimiser and loss created in the cell above.
train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    trainloader=trainloader,
    valoader=valoader,
    epochs=epochs,
)

## **Train with LoRA**

In [None]:
# Show each sub-module name of a fresh DemoNet together with its class.
[(module_name, type(module)) for module_name, module in DemoNet().named_modules()]

In [None]:
# LoRA settings: r=1 and the names of the modules to receive adapters.
# NOTE(review): "conv1" and "linear" are presumably module names reported by
# DemoNet().named_modules() -- confirm they match the network definition.
config = peft.LoraConfig(target_modules=["conv1", "linear"], r=1)

In [None]:
# Start again from a fresh DemoNet for the LoRA run.
model = DemoNet().to(device)
# Snapshot the untouched network before it is handed to PEFT for wrapping.
model_copy = copy.deepcopy(model)
peft_model = peft.get_peft_model(model, config)
# Report how many of the wrapped model's parameters are trainable.
peft_model.print_trainable_parameters()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)

In [None]:
# Fit the PEFT-wrapped model with the same loaders and epoch budget.
train(
    model=peft_model,
    optimizer=optimizer,
    criterion=criterion,
    trainloader=trainloader,
    valoader=valoader,
    epochs=epochs,
)

## **Check params**

In [None]:
# Print every parameter whose name contains "lora", with its element count.
# The trailing "updated" is fixed text: this loop selects by name only and does
# not compare values before and after training.
for name, param in peft_model.base_model.named_parameters():
    if "lora" in name:
        print(f"New parameter {name:<30} | {param.numel():>5} parameters | updated")

In [None]:
# Print every parameter whose name does not contain "lora", with its element
# count. The trailing "not updated" is fixed text: this loop selects by name
# only and does not compare values before and after training.
for name, param in peft_model.base_model.named_parameters():
    if "lora" not in name:
        print(f"Parameter {name:<30} | {param.numel():>5} parameters | not updated")