In [1]:
import torch 
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn as nn
import tqdm
# Seed torch's global RNG for reproducible runs; manual_seed returns a Generator,
# which is bound to `_` only so it is not echoed as cell output.
_ = torch.manual_seed(0)

In [2]:
# Tensor conversion followed by per-channel standardization
# (mean/std values as commonly quoted for the MNIST training set).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Training split, reshuffled on every pass.
mnist_train = dsets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=10, shuffle=True)

# Test split, iterated in a fixed order.
mnist_test = dsets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=10, shuffle=False)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
class unoptimized_model(nn.Module):
    """Three-layer ReLU MLP classifier over flattened images.

    Layer sizes: in_features -> hidden_size_1 -> hidden_size_2 -> num_classes
    (784 -> 1000 -> 2000 -> 10 with the defaults).

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
        in_features: number of input features after flattening (default 28*28).
        num_classes: number of output logits (default 10).
    """
    def __init__(self, hidden_size_1 = 1000, hidden_size_2 = 2000, in_features = 28*28, num_classes = 10):
        super().__init__()
        self.in_features = in_features
        self.linear1 = nn.Linear(in_features, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Return raw logits of shape (N, num_classes).

        `img` is flattened to (-1, in_features), so e.g. (N, 1, 28, 28) and
        (N, 784) inputs both work. `reshape` (not `view`) is used so that
        non-contiguous inputs such as transposed tensors do not raise.
        """
        x = img.reshape(-1, self.in_features)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)
# Instantiate the base MLP with its default layer sizes and place it on `device`.
net = unoptimized_model()
net = net.to(device)

In [4]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None, lr = 0.001):
    """Train `net` with Adam and cross-entropy loss on batches from `train_loader`.

    Args:
        train_loader: iterable of (inputs, labels) batches; inputs are flattened
            to (-1, 784) before the forward pass.
        net: model to train. Only parameters with ``requires_grad=True`` are
            given to the optimizer, so frozen weights are left untouched.
        epochs: maximum number of passes over `train_loader`.
        total_iterations_limit: if not None, stop after this many optimizer
            steps in total, counted across all epochs.
        lr: Adam learning rate.

    Batches are moved to the device of `net`'s first parameter (CPU if `net`
    has no parameters). The running mean loss is shown in a tqdm progress bar.
    """
    cross_el = nn.CrossEntropyLoss()
    trainable = [p for p in net.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr = lr)
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    total_iterations = 0

    for epoch in range(epochs):
        # Covers limits <= 0 as well: no step is taken once the budget is spent.
        if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
            return
        net.train()

        loss_sum = 0
        niters = 0

        # Size the bar to the iterations this epoch will actually run
        # (None lets tqdm use len(train_loader)).
        bar_total = None
        if total_iterations_limit is not None:
            remaining = total_iterations_limit - total_iterations
            try:
                bar_total = min(len(train_loader), remaining)
            except TypeError:  # loader without __len__
                bar_total = remaining

        with tqdm.tqdm(train_loader, desc=f'Epoch {epoch +1}', total=bar_total) as data_iter:
            for x, y in data_iter:
                niters += 1
                total_iterations += 1
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()

                output = net(x.reshape(-1, 28*28))
                loss = cross_el(output, y)
                loss_sum += loss.item()
                data_iter.set_postfix(loss = loss_sum / niters)
                loss.backward()
                optimizer.step()

                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return
                   
        
        
# Pre-train the base model for a single pass over the full training loader.
train(train_loader, net, epochs=1)

Epoch 1:   0%|          | 0/6000 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 6000/6000 [00:48<00:00, 124.61it/s, loss=0.237]


In [5]:
# Snapshot every parameter (detached copies) before any LoRA parametrization is
# attached, so later cells can verify the base weights are left unchanged.
ogweight = {
    param_name: tensor.detach().clone()
    for param_name, tensor in net.named_parameters()
}

In [6]:
def test():
    cor = 0
    ttl = 0
    
    wrong_counts = [0 for i in range(10)]
    
    with torch.no_grad():
        for data in tqdm.tqdm(test_loader, desc = 'Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    cor += 1
                else:
                    wrong_counts[y[idx]] += 1
                ttl += 1
    print(f'Accuracy: {round(cor/ttl, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for {i}:{wrong_counts[i]}')
# Baseline evaluation of the pre-trained model, before any LoRA parametrization is attached.
test()

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 188.90it/s]

Accuracy: 0.959
wrong counts for 0:11
wrong counts for 1:11
wrong counts for 2:21
wrong counts for 3:76
wrong counts for 4:25
wrong counts for 5:46
wrong counts for 6:31
wrong counts for 7:33
wrong counts for 8:24
wrong counts for 9:137





In [7]:
# Count the weights and biases of each linear layer in the base model.
total_params_og = 0
for layer_no, lin in enumerate((net.linear1, net.linear2, net.linear3), start=1):
    n_weight = lin.weight.nelement()
    n_bias = lin.bias.nelement()
    total_params_og += n_weight + n_bias
    print(f'Layer {layer_no} : {lin.weight.shape} weights and {lin.bias.shape} biases')

print(f'Total params: {total_params_og:,}')

Layer 1 : torch.Size([1000, 784]) weights and torch.Size([1000]) biases
Layer 2 : torch.Size([2000, 1000]) weights and torch.Size([2000]) biases
Layer 3 : torch.Size([10, 2000]) weights and torch.Size([10]) biases
Total params: 2,807,010


In [8]:

class LoRaParam(nn.Module):
    '''LoRA parametrization of a weight matrix: W -> W + (B @ A) * scale.

    Naming note: when built from an ``nn.Linear`` weight, whose shape is
    ``(out_features, in_features)``, ``features_in`` is the ROW count of the
    adapted weight and ``features_out`` is its COLUMN count:

       features_in  : d  (rows)     -> lora_B has shape (d, r)
       features_out : k  (columns)  -> lora_A has shape (r, k)
       rank         : r

    lora_A is drawn from N(0, 1) and lora_B starts at zero, so B @ A == 0 and
    the adapted weight initially equals the original one.

    Args:
        features_in: number of rows of the weight being adapted (d).
        features_out: number of columns of the weight being adapted (k).
        rank: inner dimension r of the low-rank update.
        alpha: scaling numerator; the update is multiplied by alpha / rank.
        device: device on which the LoRA factors are allocated.
    '''
    def __init__(self, features_in, features_out, rank = 1, alpha = 1, device = 'cpu'):
        super().__init__()
        # Allocate directly on the target device rather than building on CPU and copying.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device = device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device = device))

        # Random A, zero B: the initial update B @ A is exactly zero.
        nn.init.normal_(self.lora_A, mean = 0, std = 1)

        self.scale = alpha/rank
        # Flipped from outside to switch between adapted and original weights.
        self.enabled = True

    def forward(self, ogweights):
        '''Return ``ogweights + (B @ A) * scale`` when enabled, else ``ogweights``.

        Raises:
            ValueError: if ``ogweights`` is not shaped (features_in, features_out).
                The update is never reshaped, so a mis-shaped weight fails
                loudly instead of being silently reinterpreted.
        '''
        if not self.enabled:
            return ogweights
        delta = torch.matmul(self.lora_B, self.lora_A)
        if delta.shape != ogweights.shape:
            raise ValueError(
                f'LoRA update has shape {tuple(delta.shape)} but the weight has shape {tuple(ogweights.shape)}'
            )
        return ogweights + delta * self.scale

In [9]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    """Build a LoRaParam sized for ``layer.weight``.

    An ``nn.Linear`` weight is stored as (out_features, in_features); those two
    sizes are passed, in that order, as LoRaParam's ``features_in`` and
    ``features_out`` so that ``lora_B @ lora_A`` has the weight's shape.
    """
    n_rows, n_cols = layer.weight.shape
    return LoRaParam(n_rows, n_cols, rank=rank, alpha=lora_alpha, device=device)


In [10]:

# Attach a LoRA parametrization to the weight of each linear layer.
# The guard keeps the cell idempotent: re-running it would otherwise stack a
# second LoRaParam on top of the first one.
for layer in [net.linear1, net.linear2, net.linear3]:
    if not parametrize.is_parametrized(layer, "weight"):
        parametrize.register_parametrization(
            layer, "weight", linear_layer_parameterization(layer, device)
        )


def enable_disable_lora(enabled=True):
    """Switch the LoRA update on (True) or off (False) for all three linear layers.

    Only the first parametrization registered on each ``weight`` is toggled.
    """
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [11]:
# Compare the size of the LoRA factors against the base weights and biases.
total_params_lora = 0
total_params_nonlora = 0
for layer_no, lin in enumerate((net.linear1, net.linear2, net.linear3), start=1):
    lora = lin.parametrizations["weight"][0]
    total_params_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    total_params_nonlora += lin.weight.nelement() + lin.bias.nelement()
    print(
        f'Layer {layer_no}: W: {lin.weight.shape} + B: {lin.bias.shape} '
        f'+ Lora_A: {lora.lora_A.shape} + Lora_B: {lora.lora_B.shape}'
    )

# Attaching LoRA must not change the count of the original weights/biases.
assert total_params_og == total_params_nonlora
print(f'Total params og: {total_params_nonlora:,}')
print(f'Total params og+lora: {total_params_nonlora+total_params_lora:,}')
print(f'Total params lora: {total_params_lora:,}')
param_incr = (total_params_lora/total_params_nonlora)*100
print(f'Parameter increase: {param_incr:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total params og: 2,807,010
Total params og+lora: 2,813,804
Total params lora: 6,794
Parameter increase: 0.242%


In [12]:

# Freeze every non-LoRA parameter (base weights and biases); only the
# lora_A / lora_B factors remain trainable.
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing  {name}')
        param.requires_grad = False

# Fine-tune on a single digit class only.
FINETUNE_DIGIT = 9
mnist_trainset = dsets.MNIST(root='./data', train=True, download=True, transform=transform)
keep_mask = mnist_trainset.targets == FINETUNE_DIGIT  # True for samples to KEEP
mnist_trainset.data = mnist_trainset.data[keep_mask]
mnist_trainset.targets = mnist_trainset.targets[keep_mask]

# A dedicated name, so the full-dataset `train_loader` created in the
# data-loading cell is not silently replaced (re-running the pre-training cell
# afterwards would otherwise train on this single digit only).
train_loader_finetune = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

train(train_loader_finetune, net, epochs=1, total_iterations_limit=100)

Freezing  linear1.bias
Freezing  linear1.parametrizations.weight.original
Freezing  linear2.bias
Freezing  linear2.parametrizations.weight.original
Freezing  linear3.bias
Freezing  linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:01<00:00, 89.25it/s, loss=0.0601] 


In [13]:
# Check that the frozen parameters are still unchanged by the finetuning
for layer_name in ('linear1', 'linear2', 'linear3'):
    frozen = getattr(net, layer_name).parametrizations.weight.original
    assert torch.all(frozen == ogweight[f'{layer_name}.weight'])


In [None]:
# With LoRA on, each layer's effective weight should be W + (B @ A) * scale.
# allclose rather than exact equality: the expected value is recomputed here
# and the check should not hinge on bit-identical floating-point results.
enable_disable_lora(enabled=True)
for layer in (net.linear1, net.linear2, net.linear3):
    lora = layer.parametrizations.weight[0]
    expected = layer.parametrizations.weight.original + (lora.lora_B @ lora.lora_A) * lora.scale
    assert torch.allclose(layer.weight, expected)

# If we disable LoRA, every layer's weight is the original one
enable_disable_lora(enabled=False)
for layer_name in ('linear1', 'linear2', 'linear3'):
    assert torch.equal(getattr(net, layer_name).weight, ogweight[f'{layer_name}.weight'])

In [14]:
# Evaluate with the LoRA update active.
enable_disable_lora(True)
test()

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 163.97it/s]

Accuracy: 0.902
wrong counts for 0:13
wrong counts for 1:17
wrong counts for 2:27
wrong counts for 3:218
wrong counts for 4:162
wrong counts for 5:83
wrong counts for 6:29
wrong counts for 7:290
wrong counts for 8:130
wrong counts for 9:10





In [15]:
# Evaluate with LoRA switched off, i.e. using only the original pre-trained weights.
enable_disable_lora(False)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 203.40it/s]

Accuracy: 0.959
wrong counts for 0:11
wrong counts for 1:11
wrong counts for 2:21
wrong counts for 3:76
wrong counts for 4:25
wrong counts for 5:46
wrong counts for 6:31
wrong counts for 7:33
wrong counts for 8:24
wrong counts for 9:137



