# 1. Import packages

In [None]:
# NOTE(review): loralib is not imported anywhere in this notebook (LoRA is implemented by hand in
# section 3), so this install looks unnecessary. If kept, prefer a pinned, kernel-targeted install:
# `%pip install -q loralib==0.1.2`.
!pip3 install loralib

Collecting loralib
  Downloading loralib-0.1.2-py3-none-any.whl.metadata (15 kB)
Downloading loralib-0.1.2-py3-none-any.whl (10 kB)
Installing collected packages: loralib
Successfully installed loralib-0.1.2


In [None]:
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Seed PyTorch's global RNG (CPU and CUDA) so LoRA initialisation, dropout masks and DataLoader
# shuffling repeat across Restart & Run All (some GPU kernels may still be non-deterministic).
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Data Loader

In [None]:
# The VGG16 backbone used below is frozen and was pretrained on ImageNet images normalised with
# these per-channel statistics, so CIFAR-10 images are normalised the same way (rather than with
# 0.5/0.5) to stay close to the input distribution the frozen filters were trained on.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

BATCH_SIZE = 256
NUM_WORKERS = 2

# Training split, reshuffled each epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

# Test split: identical preprocessing, fixed order.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29731767.09it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# 3. LoRA (Low-Rank Adaptation) layers

In [None]:
class LoRA_layer:
  """Mixin that stores the LoRA hyper-parameters shared by LoRA-enabled layers.

  Args:
    r: rank of the low-rank update (LoRA_Linear treats 0 as "no LoRA branch").
    lora_alpha: numerator of the LoRA scaling factor lora_alpha / r.
    lora_dropout: dropout probability for the LoRA branch input; 0 means no dropout.
    merge_weights: whether the low-rank update may be folded into the base weight in eval mode.
  """
  def __init__ (
      self,
      r: int,
      lora_alpha: int,
      lora_dropout: float,
      merge_weights: bool,
  ):
    self.r = r
    self.lora_alpha = lora_alpha
    # nn.Identity instead of `lambda x: x`: a lambda cannot be pickled, which would make
    # torch.save(model) fail. Both branches assign a submodule, so when mixed into an nn.Module
    # this must run after nn.Module.__init__ (as LoRA_Linear does).
    if lora_dropout > 0.:
      self.lora_dropout = nn.Dropout(p=lora_dropout)
    else:
      self.lora_dropout = nn.Identity()
    # True while the low-rank update is currently added into the base weight.
    self.merged = False
    self.merge_weights = merge_weights

In [None]:
class LoRA_Linear(nn.Linear, LoRA_layer):
  """nn.Linear with an optional low-rank (LoRA) update.

  Output: F.linear(x, weight, bias) + (lora_dropout(x) @ lora_A^T) @ lora_B^T * scaling,
  with scaling = lora_alpha / r. The low-rank term is computed explicitly only while it is
  not merged into ``weight`` (see ``train``).

  Args:
    in_features: input size, forwarded to nn.Linear.
    out_features: output size, forwarded to nn.Linear.
    r: LoRA rank. With r > 0, ``weight`` is frozen and trainable ``lora_A`` (r x in_features)
      and ``lora_B`` (out_features x r) are added; ``bias`` is left trainable. With r == 0 the
      layer is an ordinary, fully trainable nn.Linear.
    lora_alpha: numerator of the scaling factor lora_alpha / r.
    lora_dropout: dropout probability applied to the input of the LoRA branch only.
    merge_weights: if True, switching to eval mode adds the scaled B @ A into ``weight`` and
      switching back to train mode subtracts it again.
    **kwargs: forwarded to nn.Linear (e.g. ``bias``).
  """
  def __init__ (
      self,
      in_features: int,
      out_features: int,
      r: int = 0,
      lora_alpha: int = 1,
      lora_dropout: float = 0.,
      merge_weights: bool = True,
      **kwargs
  ):
    nn.Linear.__init__(self, in_features, out_features, **kwargs)
    LoRA_layer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)

    if r > 0:
      self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
      self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
      self.scaling = self.lora_alpha / self.r
      # Freeze the dense weight; only the low-rank factors (and the bias) receive gradients.
      self.weight.requires_grad = False
    # nn.Linear.__init__ already ran reset_parameters() before lora_A/lora_B existed (hence the
    # hasattr guard below), so run it again to initialise the LoRA factors too.
    self.reset_parameters()

  def reset_parameters(self):
    """Re-initialise the dense layer, then lora_A (Kaiming-uniform) and lora_B (zeros), so B @ A starts at 0."""
    nn.Linear.reset_parameters(self)
    if hasattr(self, 'lora_A'):
      nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
      nn.init.zeros_(self.lora_B)

  def train(self, mode: bool = True):
    """Set train/eval mode and (un)merge the LoRA update into ``weight`` when ``merge_weights`` is set.

    Returns:
      self, matching nn.Module.train so calls such as ``layer.eval()`` can be chained.
    """
    nn.Linear.train(self, mode)

    if mode:
      # Training mode: remove a previously merged update so weight is the frozen base matrix again.
      if self.merge_weights and self.merged:
        if self.r > 0:
          self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
        self.merged = False
    else:
      # Evaluation mode: fold the update into weight once so forward() needs a single F.linear.
      if self.merge_weights and not self.merged:
        if self.r > 0:
          self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
        self.merged = True
    return self

  def forward(self, x: torch.Tensor):
    """Apply the dense layer, plus the LoRA branch whenever it is not merged into ``weight``."""
    result = F.linear(x, self.weight, bias=self.bias)
    if self.r > 0 and not self.merged:
      # Unmerged path (train mode, or merge_weights=False): add the scaled low-rank branch explicitly.
      result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)) @ self.lora_B.transpose(0, 1) * self.scaling
    return result


# 4. Modeling

In [None]:
def trainable_params(model):
  """Print how many of ``model``'s parameters require gradients, in millions (2 decimals).

  Returns None; the count is only printed.
  """
  n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
  print(f"Trainable params: {n_trainable/1e6:.2f} M")

## 4.1. Load VGG16 pre-trained on ImageNet-1k

In [None]:
from torchvision.models import VGG16_Weights, vgg16

# VGG16 with its ImageNet-1k weights; torchvision downloads the checkpoint once and caches it
# (see the download log below).
vgg16_model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# Snapshot of the pretrained classifier head's state_dict, loaded into the LoRA classifier in 4.3.
vgg16_classifier_ckpt = vgg16_model.classifier.state_dict()

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 107MB/s]


In [None]:
# Baseline before any freezing: every VGG16 parameter (convolutional features and classifier) is trainable.
trainable_params(vgg16_model)

Trainable params: 138.36 M


## 4.2. Define classifier with LoRA

In [None]:
# Same layer order as VGG16's classifier, with the Linear layers at indices 0, 3 and 6, so the
# pretrained checkpoint's keys line up (the load in 4.3 reports only lora_A/lora_B as missing).
# Every LoRA_Linear uses rank r=16 and the default lora_alpha.
lora_classifier = nn.Sequential(
    LoRA_Linear(512 * 7 * 7, 4096, r=16, bias=True),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    LoRA_Linear(4096, 4096, r=16, bias=True),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    LoRA_Linear(4096, 1000, r=16, bias=True),
)

## 4.3. Load pre-trained weights into LoRA classifier

In [None]:
# strict=False because the LoRA factors have no counterpart in the ImageNet checkpoint. strict=False
# would also silently ignore any other missing key, so check that lora_A / lora_B are the only gaps.
load_result = lora_classifier.load_state_dict(vgg16_classifier_ckpt, strict=False)
assert not load_result.unexpected_keys, load_result.unexpected_keys
assert all(key.endswith(('.lora_A', '.lora_B')) for key in load_result.missing_keys), load_result.missing_keys
load_result

_IncompatibleKeys(missing_keys=['0.lora_A', '0.lora_B', '3.lora_A', '3.lora_B', '6.lora_A', '6.lora_B'], unexpected_keys=[])

## 4.4. Add a 10-way output layer for CIFAR-10 classification

In [None]:
# Reuse the LoRA classifier's layers as-is (they are not copied) and stack a CIFAR-10 head on its
# 1000 outputs. The new LoRA_Linear keeps the default r=0, so it is a plain, fully trainable layer.
new_classifier = nn.Sequential(
    *lora_classifier,
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    LoRA_Linear(1000, 10, bias=True),
)

In [None]:
# Only the LoRA factors, the pretrained biases (LoRA_Linear freezes just `weight`) and the new
# 10-way head (r=0, so fully trainable) remain trainable here.
trainable_params(new_classifier)

Trainable params: 0.70 M


## 4.5. Combine the frozen pre-trained VGG16 features with the LoRA classifier

In [None]:
class CLS_model(nn.Module):
  """Image classifier: frozen conv backbone -> AdaptiveAvgPool2d((7, 7)) -> flatten -> classifier head.

  Args:
    features: backbone module. Defaults to ``vgg16_model.features`` (the pretrained VGG16 from 4.1).
      It is put in eval mode and all of its parameters are frozen in place.
    classifier: head applied to the flattened pooled features. Defaults to ``new_classifier`` (4.4).
  """
  def __init__(self, features=None, classifier=None) -> None:
    super().__init__()
    # None defaults keep the original zero-argument CLS_model() call bound to the notebook globals.
    backbone = vgg16_model.features if features is None else features
    self.features = backbone.eval()
    for param in self.features.parameters():
      param.requires_grad = False
    # NOTE(review): for 32x32 CIFAR-10 images the backbone's feature map is presumably smaller than
    # 7x7, in which case this pool upsamples it by repetition to fit the 512*7*7 classifier input — confirm.
    self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
    self.classifier = new_classifier if classifier is None else classifier

  def train(self, mode: bool = True):
    """Set ``mode`` on the model but keep the frozen backbone in eval mode. Returns self."""
    super().train(mode)
    # nn.Module.train() recurses into every child and would undo the .eval() from __init__;
    # re-pin it so any mode-dependent backbone layers (dropout / batch-norm) stay in inference mode.
    self.features.eval()
    return self

  def forward(self, x):
    """Return classifier logits for a batch of images ``x``."""
    x = self.features(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    return self.classifier(x)

In [None]:
# Assemble the frozen VGG16 features and the LoRA classifier, then move parameters/buffers to `device`.
model = CLS_model().to(device)

# 5. Training

In [None]:
# Cross-entropy (default mean reduction) over the model's 10 CIFAR-10 logits.
criterion = nn.CrossEntropyLoss()

# Give Adam only the parameters that still require gradients (LoRA factors, biases and the new
# 10-way head). Frozen parameters never get a .grad, so Adam would skip them anyway; listing them
# only clutters optimizer.param_groups. 1e-3 is Adam's default lr, spelled out so it is easy to tune.
LEARNING_RATE = 1e-3
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)

In [None]:
NUM_EPOCHS = 10

# train() enables dropout and, if an earlier model.eval() merged the LoRA updates into the base
# weights, subtracts them again (LoRA_Linear.train).
model.train()

start = time.time()
for epoch in range(NUM_EPOCHS):
  running_loss = 0.0
  n_samples = 0
  for inputs, labels in trainloader:
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # criterion averages over the batch; re-weight by batch size so the epoch figure is a true
    # per-sample mean even when the last batch is smaller than BATCH_SIZE.
    n_batch = labels.size(0)
    running_loss += loss.item() * n_batch
    n_samples += n_batch

  avg_loss = running_loss / max(n_samples, 1)
  # CUDA memory counters are only meaningful when training on the GPU.
  if device.type == 'cuda':
    mem_info = (f", GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
                f", GPU Memory Reserved: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB")
  else:
    mem_info = ""
  print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Average Loss: {avg_loss: .4f}{mem_info}")


print(f"Training time: {time.time() - start:.2f}s")


Epoch [1/10], Average Loss:  1.5352, GPU Memory Allocated: 561.50 MB, GPU Memory Reserved: 1272.00 MB
Epoch [2/10], Average Loss:  1.2858, GPU Memory Allocated: 561.50 MB, GPU Memory Reserved: 1272.00 MB
Epoch [3/10], Average Loss:  1.2096, GPU Memory Allocated: 558.74 MB, GPU Memory Reserved: 1272.00 MB
Epoch [4/10], Average Loss:  1.1554, GPU Memory Allocated: 558.49 MB, GPU Memory Reserved: 1272.00 MB
Epoch [5/10], Average Loss:  1.1154, GPU Memory Allocated: 558.75 MB, GPU Memory Reserved: 1272.00 MB
Epoch [6/10], Average Loss:  1.0886, GPU Memory Allocated: 561.50 MB, GPU Memory Reserved: 1272.00 MB
Epoch [7/10], Average Loss:  1.0636, GPU Memory Allocated: 561.50 MB, GPU Memory Reserved: 1272.00 MB
Epoch [8/10], Average Loss:  1.0433, GPU Memory Allocated: 558.74 MB, GPU Memory Reserved: 1272.00 MB
Epoch [9/10], Average Loss:  1.0229, GPU Memory Allocated: 558.49 MB, GPU Memory Reserved: 1272.00 MB
Epoch [10/10], Average Loss:  1.0019, GPU Memory Allocated: 561.50 MB, GPU Memory 

# 6. Evaluate

In [None]:
# eval() disables dropout and, for LoRA_Linear layers with merge_weights=True, folds the scaled
# B @ A into the base weight so each layer runs a single matmul; model.train() reverses this.
model.eval()
correct = 0
total = 0
with torch.no_grad():
  for images, labels in testloader:
    images, labels = images.to(device), labels.to(device)
    # Index of the largest logit per row is the predicted class.
    predicted = model(images).argmax(dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f"Accuracy of the model on {total} test images: {100 * correct / total:.2f}%")

Accuracy of the model on 10000 test images: 64.92%
