In [46]:
import torch
import torchvision.datasets as dataasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm

In [47]:
_=torch.manual_seed(42)

In [48]:
transforms=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081))])
##load mnist dataste
mnist_train=dataasets.MNIST(root='/data',train=True,download=True,transform=transforms)
train_loader=torch.utils.data.DataLoader(mnist_train,batch_size=10,shuffle=True)
mnist_test=dataasets.MNIST(root='/data',train=False,download=True,transform=transforms)
test_loader=torch.utils.data.DataLoader(mnist_test,batch_size=10,shuffle=True)

In [49]:
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [50]:
device

device(type='cuda', index=0)

In [51]:
class mnist(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1=nn.Linear(28*28,2000)
        self.linear2=nn.Linear(2000,1000)
        self.linear3=nn.Linear(1000,10)
        self.relu=nn.ReLU()
    def forward(self,image):
        x=image.view(-1,28*28)
        x=self.relu(self.linear1(x))
        x=self.relu(self.linear2(x))
        x=self.linear3(x)
        return x
model=mnist().to(device)



In [52]:
model

mnist(
  (linear1): Linear(in_features=784, out_features=2000, bias=True)
  (linear2): Linear(in_features=2000, out_features=1000, bias=True)
  (linear3): Linear(in_features=1000, out_features=10, bias=True)
  (relu): ReLU()
)

In [53]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # Move model to the device
def train(train_loader, model, epochs=2, total_iterations_limit=None):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        # Wrap your data loader with tqdm to show progress
        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}/{epochs}")
            
            for input_batch, target in tepoch:
                optimizer.zero_grad()

                # Send data to device
                x = input_batch.to(device)
                y = target.to(device)

                # Forward pass
                output = model(x.view(-1, 28*28))

                # Calculate loss
                curr_loss = loss_fn(output, y)
                total_loss += curr_loss.item()

                # Backward pass and optimize
                curr_loss.backward()
                optimizer.step()

                # Update progress bar description with current loss
                tepoch.set_postfix(loss=curr_loss.item())

        print(f'Epoch [{epoch+1}/{epoch}], Average Loss: {total_loss/len(train_loader):.4f}')
train(train_loader,model,epochs=2)

Epoch 1/2: 100%|██████████| 6000/6000 [01:26<00:00, 69.11batch/s, loss=0.147]   


Epoch [1/0], Average Loss: 0.2307


Epoch 2/2: 100%|██████████| 6000/6000 [01:28<00:00, 68.02batch/s, loss=0.0398]  

Epoch [2/1], Average Loss: 0.1248





In [54]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 1000/1000 [00:09<00:00, 102.76it/s]

Accuracy: 0.962
wrong counts for the digit 0: 28
wrong counts for the digit 1: 10
wrong counts for the digit 2: 42
wrong counts for the digit 3: 19
wrong counts for the digit 4: 25
wrong counts for the digit 5: 70
wrong counts for the digit 6: 25
wrong counts for the digit 7: 106
wrong counts for the digit 8: 14
wrong counts for the digit 9: 44





In [55]:
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()

In [56]:
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(total_parameters_original)
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

1570000
Layer 1: W: torch.Size([2000, 784]) + B: torch.Size([2000])
3571000
Layer 2: W: torch.Size([1000, 2000]) + B: torch.Size([1000])
3581010
Layer 3: W: torch.Size([10, 1000]) + B: torch.Size([10])
Total number of parameters: 3,581,010


In [57]:
class LoRaParameters(nn.Module):
    def __init__(self,feature_out,feature_in,alpha,rank=1,device='cpu'):
        super().__init__()
        self.LoraA=nn.Parameter(torch.zeros((rank,feature_out)).to(device))
        self.LoraB=nn.Parameter(torch.zeros((feature_in,rank)).to(device))
        nn.init.normal_(self.LoraA, mean=0, std=1)
        self.scale=alpha/rank
        self.enabled=True
    def forward(self,original_weights):
        if self.enabled:

            return original_weights + torch.matmul(self.LoraB, self.LoraA).view(original_weights.shape) * self.scale
        else:
            original_weights

        

In [58]:
model.linear1.weight.shape

torch.Size([2000, 784])

In [59]:
import torch.nn.utils.parametrize as parametrize
def linear_layer_parameterization(layer,device,rank=2,lora_alpha=1):
    features_in,features_out=layer.weight.shape
    print(features_in,features_out)
    return LoRaParameters(features_in,features_out,rank=rank,alpha=lora_alpha,device=device)
parametrize.register_parametrization(model.linear1,"weight",linear_layer_parameterization(model.linear1,device))
parametrize.register_parametrization(model.linear2,"weight",linear_layer_parameterization(model.linear2,device))
parametrize.register_parametrization(model.linear3,"weight",linear_layer_parameterization(model.linear3,device))

def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled


2000 784
1000 2000
10 1000


In [60]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].LoraA.nelement() + layer.parametrizations["weight"][0].LoraB.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].LoraA.shape} + Lora_B: {layer.parametrizations["weight"][0].LoraB.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([2000, 784]) + B: torch.Size([2000]) + Lora_A: torch.Size([2, 2000]) + Lora_B: torch.Size([784, 2])
Layer 2: W: torch.Size([1000, 2000]) + B: torch.Size([1000]) + Lora_A: torch.Size([2, 1000]) + Lora_B: torch.Size([2000, 2])
Layer 3: W: torch.Size([10, 1000]) + B: torch.Size([10]) + Lora_A: torch.Size([2, 10]) + Lora_B: torch.Size([1000, 2])
Total number of parameters (original): 3,581,010
Total number of parameters (original + LoRA): 3,594,598
Parameters introduced by LoRA: 13,588
Parameters incremment: 0.379%


In [61]:
model

mnist(
  (linear1): ParametrizedLinear(
    in_features=784, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRaParameters()
      )
    )
  )
  (linear2): ParametrizedLinear(
    in_features=2000, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRaParameters()
      )
    )
  )
  (linear3): ParametrizedLinear(
    in_features=1000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRaParameters()
      )
    )
  )
  (relu): ReLU()
)

In [62]:
for name, param in model.named_parameters():
    if 'Lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


In [63]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'Lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = dataasets.MNIST(root='./data', train=True, download=True, transform=transforms)
exclude_indices = mnist_trainset.targets == 7
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1/1: 100%|██████████| 627/627 [00:09<00:00, 66.81batch/s, loss=1.63e-5] 

Epoch [1/0], Average Loss: 0.0168





In [64]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:10<00:00, 91.24it/s] 

Accuracy: 0.893
wrong counts for the digit 0: 49
wrong counts for the digit 1: 72
wrong counts for the digit 2: 357
wrong counts for the digit 3: 44
wrong counts for the digit 4: 74
wrong counts for the digit 5: 80
wrong counts for the digit 6: 32
wrong counts for the digit 7: 0
wrong counts for the digit 8: 46
wrong counts for the digit 9: 318



