In [6]:
import timm 
import torch
from lora_adapters import LoraConv2d, apply_adapter, mark_only_lora_as_trainable, lora_state_dict, undo_lora
from torch.optim import AdamW

In [7]:
# Build a ResNet-50 with pretrained weights through timm, then move it onto the GPU.
model = timm.create_model('resnet50', pretrained=True)
model = model.to('cuda')

In [8]:
# Baseline optimiser: AdamW over every parameter of the model that currently requires a gradient.
trainable_params = [weight for weight in model.parameters() if weight.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-3)

In [9]:
# Seed the global RNG first so the random batch below (and therefore the loss
# curves printed by the training cells) is identical on every fresh run.
torch.manual_seed(0)

# Put the batch wherever the model's parameters live instead of hard-coding the device,
# so the two can never end up on different devices.
device = next(model.parameters()).device

# A single random image-shaped tensor and a single random class index in [0, 1000).
# Both are sampled on the CPU and then moved, exactly as before.
inputs = torch.randn(1, 3, 224, 224).to(device)
targets = torch.randint(0, 1000, (1,)).to(device)

In [10]:
# Full fine-tuning baseline: ten optimiser steps on the single fixed
# (inputs, targets) pair, printing the loss computed before each update.
criterion = torch.nn.functional.cross_entropy
for _ in range(10):
    # Drop the gradients left over from the previous iteration.
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    # Backpropagate, then let the optimiser update the trainable parameters.
    loss.backward()
    optimizer.step()
    print(loss.item())

7.085061073303223
6.210545063018799
3.2762057781219482
0.6974539756774902
0.05648316815495491
0.01910405606031418
0.00746177276596427
0.0039986190386116505
0.002394310897216201
0.0014627005439251661


In [11]:
# Element counts for the full fine-tuning baseline; compared against the LoRA run later.

# Total number of scalar parameters in the model.
model_parameters = sum(p.numel() for p in model.parameters())

# Number of gradient elements currently held. A trainable parameter whose `.grad`
# is still None (e.g. it took no part in the last backward pass) holds no gradient
# memory, so it is skipped instead of raising AttributeError.
model_grads = sum(
    p.grad.numel()
    for p in model.parameters()
    if p.requires_grad and p.grad is not None
)

# `optimizer.state` maps each parameter to a dict of state entries. Only tensor
# entries are counted, so a state value stored as a plain Python number
# (NOTE(review): older PyTorch releases are believed to keep AdamW's `step` that way)
# does not break the sum.
optimizer_states = sum(
    entry.numel()
    for param_state in optimizer.state.values()
    for entry in param_state.values()
    if torch.is_tensor(entry)
)

In [12]:
# Attach rank-16 LoraConv2d adapters to the model through the project helper.
model = apply_adapter(model, LoraConv2d, rank=16)
# Left disabled by the author; presumably it would further restrict which parameters
# train — confirm in lora_adapters before enabling.
# model = mark_only_lora_as_trainable(model, bias='lora_only')
# Fresh optimiser over whatever still requires a gradient once the adapters are attached.
lora_trainable_params = [weight for weight in model.parameters() if weight.requires_grad]
optimizer = AdamW(lora_trainable_params, lr=1e-3)

In [13]:
# Same ten-step loop as the baseline, now driving the adapter-equipped model
# with the optimiser rebuilt in the previous cell.
criterion = torch.nn.functional.cross_entropy
for _ in range(10):
    # Drop the gradients left over from the previous iteration.
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    # Backpropagate, then update only the parameters handed to the optimiser.
    loss.backward()
    optimizer.step()
    print(loss.item())

0.000910344475414604
0.0002079985715681687
7.962863310240209e-05
3.8980677345534787e-05
2.1934269170742482e-05
1.3589766240329482e-05
9.417489309271332e-06
6.556489552167477e-06
4.887569048150908e-06
3.6954811548639555e-06


In [14]:
# Element counts for the LoRA run, mirroring the baseline measurements.

# Total number of scalar parameters in the adapted model (base + adapter).
lora_model_parameters = sum(p.numel() for p in model.parameters())

# Number of gradient elements currently held. A trainable parameter whose `.grad`
# is still None holds no gradient memory, so it is skipped instead of raising
# AttributeError.
lora_model_grads = sum(
    p.grad.numel()
    for p in model.parameters()
    if p.requires_grad and p.grad is not None
)

# `optimizer.state` maps each parameter to a dict of state entries. Only tensor
# entries are counted, so a state value stored as a plain Python number
# (NOTE(review): older PyTorch releases are believed to keep AdamW's `step` that way)
# does not break the sum.
lora_optimizer_states = sum(
    entry.numel()
    for param_state in optimizer.state.values()
    for entry in param_state.values()
    if torch.is_tensor(entry)
)

In [15]:
# Report baseline -> LoRA element counts together with the LoRA/baseline ratio,
# one line per quantity (parameters, gradients, optimiser state).
param_ratio = lora_model_parameters / model_parameters
print(f"Model parameters: {model_parameters} -> {lora_model_parameters} ratio: {param_ratio:.2f}")

grad_ratio = lora_model_grads / model_grads
print(f"Model grads: {model_grads} -> {lora_model_grads} ratio: {grad_ratio:.2f}")

state_ratio = lora_optimizer_states / optimizer_states
print(f"Optimizer states: {optimizer_states} -> {lora_optimizer_states} ratio: {state_ratio:.2f}")

Model parameters: 25557032 -> 26828120 ratio: 1.05
Model grads: 25557032 -> 3373208 ratio: 0.13
Optimizer states: 51114225 -> 6746630 ratio: 0.13


In [16]:
# Reference forward pass with the LoRA adapters still attached.
# This result is only used for a comparison, so run it under no_grad: otherwise
# `output` keeps the whole autograd graph (and its activations) alive in memory.
# The model's train/eval mode is deliberately left untouched.
with torch.no_grad():
    output = model(inputs)

In [17]:
# Run the project helper `undo_lora` on the adapted model and rebind `model` to its return value.
# NOTE(review): the recorded equality check further down returns True, which suggests the
# learned update is kept (folded into the base weights) rather than discarded — confirm in lora_adapters.
model = undo_lora(model)

In [18]:
# Forward pass on the same batch after `undo_lora`, for comparison with the adapter-attached output.
# Comparison only, so run it under no_grad: otherwise `normal_outputs` keeps the whole
# autograd graph (and its activations) alive in memory.
with torch.no_grad():
    normal_outputs = model(inputs)

In [19]:
# Exact comparison (same size and identical elements) between the output produced with
# the adapters attached and the output produced after undo_lora; the result is displayed.
outputs_match = torch.equal(output, normal_outputs)
outputs_match

True

In [20]:
# Attach LoraConv2d adapters (rank 16, the same setting as the first run) to the model again,
# so that the adapter entries ('...lora_A' / '...lora_B') show up in the state dict inspected below.
model = apply_adapter(model, LoraConv2d, rank=16)

In [21]:
# Snapshot the complete state dict of the adapter-equipped model.
state_dict = model.state_dict()

In [22]:
# Narrow the state dict down to the entries whose key contains 'lora_'.
state_dict = {
    entry_name: tensor
    for entry_name, tensor in state_dict.items()
    if 'lora_' in entry_name
}

In [23]:
# Display the names of the entries that remain after filtering on 'lora_'.
state_dict.keys()

dict_keys(['conv1.lora_A', 'conv1.lora_B', 'layer1.0.conv1.lora_A', 'layer1.0.conv1.lora_B', 'layer1.0.conv2.lora_A', 'layer1.0.conv2.lora_B', 'layer1.0.conv3.lora_A', 'layer1.0.conv3.lora_B', 'layer1.0.downsample.0.lora_A', 'layer1.0.downsample.0.lora_B', 'layer1.1.conv1.lora_A', 'layer1.1.conv1.lora_B', 'layer1.1.conv2.lora_A', 'layer1.1.conv2.lora_B', 'layer1.1.conv3.lora_A', 'layer1.1.conv3.lora_B', 'layer1.2.conv1.lora_A', 'layer1.2.conv1.lora_B', 'layer1.2.conv2.lora_A', 'layer1.2.conv2.lora_B', 'layer1.2.conv3.lora_A', 'layer1.2.conv3.lora_B', 'layer2.0.conv1.lora_A', 'layer2.0.conv1.lora_B', 'layer2.0.conv2.lora_A', 'layer2.0.conv2.lora_B', 'layer2.0.conv3.lora_A', 'layer2.0.conv3.lora_B', 'layer2.0.downsample.0.lora_A', 'layer2.0.downsample.0.lora_B', 'layer2.1.conv1.lora_A', 'layer2.1.conv1.lora_B', 'layer2.1.conv2.lora_A', 'layer2.1.conv2.lora_B', 'layer2.1.conv3.lora_A', 'layer2.1.conv3.lora_B', 'layer2.2.conv1.lora_A', 'layer2.2.conv1.lora_B', 'layer2.2.conv2.lora_A', 'lay

In [24]:
# Every sub-module carrying a truthy `is_lora` attribute, keyed by its dotted module name.
lora_dict = {
    layer_name: layer
    for layer_name, layer in model.named_modules()
    if getattr(layer, 'is_lora', False)
}
# The `weight` attribute of each of those modules.
lora_dict_weights = {
    layer_name: layer.weight
    for layer_name, layer in lora_dict.items()
}
# The `bias` attribute, kept only for modules that have one and where it is not None.
lora_dict_bias = {
    layer_name: layer.bias
    for layer_name, layer in lora_dict.items()
    if getattr(layer, 'bias', None) is not None
}

In [25]:
# Display the keys returned by the project helper `lora_state_dict` when called with bias='lora_only'.
# NOTE(review): the recorded output lists '<layer>.weight' keys rather than the 'lora_A'/'lora_B'
# names seen in model.state_dict() — confirm this is the intended export format.
lora_state_dict(model, bias='lora_only').keys()

dict_keys(['conv1.weight', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.0.conv3.weight', 'layer1.0.downsample.0.weight', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer1.1.conv3.weight', 'layer1.2.conv1.weight', 'layer1.2.conv2.weight', 'layer1.2.conv3.weight', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.conv3.weight', 'layer2.0.downsample.0.weight', 'layer2.1.conv1.weight', 'layer2.1.conv2.weight', 'layer2.1.conv3.weight', 'layer2.2.conv1.weight', 'layer2.2.conv2.weight', 'layer2.2.conv3.weight', 'layer2.3.conv1.weight', 'layer2.3.conv2.weight', 'layer2.3.conv3.weight', 'layer3.0.conv1.weight', 'layer3.0.conv2.weight', 'layer3.0.conv3.weight', 'layer3.0.downsample.0.weight', 'layer3.1.conv1.weight', 'layer3.1.conv2.weight', 'layer3.1.conv3.weight', 'layer3.2.conv1.weight', 'layer3.2.conv2.weight', 'layer3.2.conv3.weight', 'layer3.3.conv1.weight', 'layer3.3.conv2.weight', 'layer3.3.conv3.weight', 'layer3.4.conv1.weight', 'layer3.4.conv2.weight', 'l