## Pre-trained model

We create a **linear model** with 5 input features and 1 output feature:

$y = m(x) = w \cdot x + b = \sum_{i=1}^{5} w_i \, x_i + b$

where: 

* $x = [x_1, x_2, x_3, x_4, x_5]$ is the input feature vector, and

* $w = [w_1, w_2, w_3, w_4, w_5]$ and $b$ are the model parameters (weights and bias).
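The formula above can be written as a tiny plain-Python function to make the dot product explicit. This is only an illustration independent of PyTorch; `linear_model` is a hypothetical helper, not part of the notebook:

```python
def linear_model(w: list[float], b: float, x: list[float]) -> float:
    """Compute y = w . x + b = sum_i w_i * x_i + b."""
    if len(w) != len(x):
        raise ValueError("w and x must have the same length")
    return sum(w_i * x_i for w_i, x_i in zip(w, x)) + b


# With w = [1, 2, 3, 4, 5], b = 0.5 and x = (1, 1, 1, 1, 1):
# y = 1 + 2 + 3 + 4 + 5 + 0.5 = 15.5
print(linear_model([1.0, 2.0, 3.0, 4.0, 5.0], 0.5, [1.0] * 5))
```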


In [1]:
```python
import torch
from torch import nn
```

In [2]:
```python
model = nn.Linear(5, 1)
print(model)
```

```
Linear(in_features=5, out_features=1, bias=True)
```


Model parameters (*weights* and *bias*) at initialization

In [3]:
```python
for n, p in model.named_parameters():
    print(f"{n}: {p}\n")
```

```
weight: Parameter containing:
tensor([[0.0034, 0.4410, 0.4370, 0.4056, 0.3660]], requires_grad=True)

bias: Parameter containing:
tensor([-0.1786], requires_grad=True)
```



We can test the model by computing its output $y = m(x)$ for $x=(1,1,1,1,1)$:

In [4]:
```python
ones = torch.tensor([1., 1., 1., 1., 1.])
```

In [5]:
```python
with torch.no_grad():
    y = model(ones)
y.item()
```

```
1.4744436740875244
```

In this case, for $x=(1,1,1,1,1)$, $y = m(x) = w_1 + w_2 + w_3 + w_4 + w_5 + b$. Let's check this:

In [6]:
```python
(model.weight.data.sum() + model.bias.data).item()
```

```
1.4744436740875244
```
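The same check can be done by hand with the initial values printed above. Since those values are rounded to 4 decimals, the result agrees with the model output only up to that precision:

```python
# Initial parameters as printed above (rounded to 4 decimals)
w = [0.0034, 0.4410, 0.4370, 0.4056, 0.3660]
b = -0.1786

# For x = (1, 1, 1, 1, 1): y = w1 + w2 + w3 + w4 + w5 + b
y = sum(w) + b
print(round(y, 4))  # 1.4744, matching 1.47444... up to rounding
```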

## Create SpaRTA adapter

We use the `SpaRTA` class to add a SpaRTA adapter to the model. 

By choosing a sparsity close to zero (here $10^{-9}$), the SpaRTA adapter marks all the model parameters $(w_1, \dots, w_5)$ and $b$ as trainable.
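How SpaRTA actually picks the trainable entries is internal to `peft_sparta`. The sketch below only illustrates the idea under two assumptions of ours: that `sparsity` is the fraction of scalar entries kept frozen, and that the trainable entries are drawn uniformly at random. The helpers `unravel` and `select_trainable_indices` are hypothetical, not the library's API. With a sparsity close to zero every entry gets selected, and the coordinates coincide with the indices the adapter prints for the weight and the bias:

```python
import math
import random


def unravel(flat_index: int, shape: tuple[int, ...]) -> list[int]:
    """Convert a flat (row-major) index into one coordinate per dimension."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat_index % dim)
        flat_index //= dim
    return coords[::-1]


def select_trainable_indices(shape, sparsity, seed=0):
    """Hypothetical sketch: randomly pick about (1 - sparsity) of the entries.

    Assumption: `sparsity` is the fraction of entries that stay frozen.
    Returns the sorted coordinates of the trainable entries.
    """
    numel = math.prod(shape)
    num_trainable = math.ceil((1.0 - sparsity) * numel)
    rng = random.Random(seed)
    chosen = sorted(rng.sample(range(numel), num_trainable))
    return [unravel(i, shape) for i in chosen]


print(select_trainable_indices((1, 5), 1e-9))  # weight of shape (1, 5)
print(select_trainable_indices((1,), 1e-9))    # bias of shape (1,)
```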

In [7]:
```python
from peft_sparta import SpaRTA
```

In [8]:
```python
# sparsity close to 0: all the model parameters become trainable
model = SpaRTA(model, 1e-9)
```

Let's look at the internals of the newly created SpaRTA adapter (at initialization):

* Non-trainable **indices** for each of the pre-trained model parameters:

In [9]:
```python
for n, p in model.indices._buffers.items():
    print(f"{n}:\n {p}\n")
```

```
weight:
 tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4]], dtype=torch.int32)

bias:
 tensor([[0]], dtype=torch.int32)
```



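The **indices** (non-trainable) keep track of which specific scalar entries of each model parameter are selected as trainable and can therefore be updated during training. Each row is the coordinate of one selected entry: for example, `[0, 3]` is row 0, column 3 of the $1 \times 5$ weight matrix, i.e. $w_4$. Note that here all the model parameters have been selected as trainable.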

* Trainable **deltas** (adapter parameters) for each of the pre-trained model parameters:  

In [10]:
```python
for n, p in model.deltas._parameters.items():  # model.named_parameters()
    print(f"{n}: {p}\n")
```

```
weight: Parameter containing:
tensor([0., 0., 0., 0., 0.], requires_grad=True)

bias: Parameter containing:
tensor([0.], requires_grad=True)
```



The **deltas** (trainable) represent changes to the original model parameters at the selected indices; they are added to the original parameter values (which are kept frozen) at the beginning of each forward pass. Note that they are initialized at zero. Therefore, at initialization (before any training takes place), the adapted model is unchanged and coincides with the original pre-trained model. We can verify this by evaluating the model at $x=(1,1,1,1,1)$, which gives the same output as before:
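Conceptually, the forward pass uses effective parameters obtained by adding each delta to a copy of the frozen parameter at its index. Below is a minimal plain-Python sketch of that step for the $1 \times 5$ weight matrix, with nested lists standing in for tensors. The function `effective_params` is hypothetical and is not the library's implementation:

```python
import copy


def effective_params(frozen, indices, deltas):
    """Return a copy of `frozen` with each delta added at its (row, col) index.

    `frozen` itself is not modified, mirroring how the base parameters stay frozen.
    """
    params = copy.deepcopy(frozen)
    for (row, col), delta in zip(indices, deltas):
        params[row][col] += delta
    return params


# Frozen weight (rounded, as printed above) and its selected indices
frozen_weight = [[0.0034, 0.4410, 0.4370, 0.4056, 0.3660]]
weight_indices = [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]]

# Zero deltas (the initialization): the adapted weight equals the original one
print(effective_params(frozen_weight, weight_indices, [0.0] * 5) == frozen_weight)  # True

# A nonzero delta on a single selected entry (here w_2) changes only that entry
print(effective_params(frozen_weight, [[0, 1]], [0.5]))
```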

In [11]:
```python
with torch.no_grad():
    y = model(ones)
y.item()
```

```
1.4744436740875244
```

* The parameters of the pre-trained base model are now frozen (non-trainable); note that `requires_grad=True` no longer appears in their printout:

In [12]:
```python
# original (frozen) parameters of our pre-trained linear model
for n, p in model.model.named_parameters():
    print(f"{n}: {p}")
```

```
weight: Parameter containing:
tensor([[0.0034, 0.4410, 0.4370, 0.4056, 0.3660]])
bias: Parameter containing:
tensor([-0.1786])
```


* The adapter also keeps a copy of the original values of the selected parameter entries (`original_chosen_params`), so that the original model parameters can be recovered (unmerging the deltas) after the deltas have been merged into them.
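A minimal sketch of why such a copy is useful, on plain nested lists, assuming that merging means writing `original + delta` into the selected entries in place. The helpers `merge` and `unmerge` and the toy values are hypothetical, not the `peft_sparta` API. Unmerging restores the saved values instead of subtracting the deltas again:

```python
def merge(weight, indices, deltas):
    """Write original + delta into the selected entries of `weight` (in place)."""
    for (row, col), delta in zip(indices, deltas):
        weight[row][col] += delta


def unmerge(weight, indices, original_chosen):
    """Restore the saved original values of the selected entries (in place)."""
    for (row, col), value in zip(indices, original_chosen):
        weight[row][col] = value


# Toy values (not the notebook's): a 1 x 3 weight with two selected entries
weight = [[0.1, 0.5, 0.7]]
indices = [[0, 0], [0, 2]]

# Copy of the original values at the selected entries, taken once up front
original_chosen = [weight[row][col] for row, col in indices]

merge(weight, indices, [0.2, 0.3])
print(weight)                       # merged: original + delta
print(weight[0][0] - 0.2 == 0.1)    # False: subtracting the delta is not exact here
unmerge(weight, indices, original_chosen)
print(weight == [[0.1, 0.5, 0.7]])  # True: restoring the saved copy is exact
```

Restoring the saved copy makes the round trip exact, whereas in floating point `(a + d) - d` is not always equal to `a`.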

Let's see all the trainable and non-trainable parameters of our SpaRTA adapter:

In [13]:
```python
# all trainable and non-trainable params (adapter state dict)
for n, p in model.state_dict().items():
    print(f"{n}:\n {p}\n")
```

```
model.weight:
 tensor([[0.0034, 0.4410, 0.4370, 0.4056, 0.3660]])

model.bias:
 tensor([-0.1786])

indices.weight:
 tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4]], dtype=torch.int32)

indices.bias:
 tensor([[0]], dtype=torch.int32)

deltas.weight:
 tensor([0., 0., 0., 0., 0.])

deltas.bias:
 tensor([0.])

original_chosen_params.weight:
 tensor([0.0034, 0.4410, 0.4370, 0.4056, 0.3660])

original_chosen_params.bias:
 tensor([-0.1786])
```



In [14]:
```python
print('Pre-trained base model params:')
for n, _ in model.model.named_parameters():
    print(' -', n)

print('\nSpaRTA trainable (deltas) params:')
for n, _ in model.named_parameters():
    print(' -', n)

print('\nSpaRTA non-trainable (indices, original_chosen_params) params:')
for n, _ in model.named_buffers():
    print(' -', n)
```

```
Pre-trained base model params:
 - weight
 - bias

SpaRTA trainable (deltas) params:
 - deltas.weight
 - deltas.bias

SpaRTA non-trainable (indices, original_chosen_params) params:
 - indices.weight
 - indices.bias
 - original_chosen_params.weight
 - original_chosen_params.bias
```


Thus, for each of the original pre-trained model parameters, the SpaRTA adapter uses *indices* to record which subset of scalar entries within that parameter tensor can be modified (adapted) during training, and *deltas* to hold the changes applied to those entries:
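Since the indices list exactly which entries are trainable, they also determine what fraction of the model's scalar parameters the adapter trains. A small plain-Python sketch (the helper `trainable_fraction` is hypothetical) using the shapes and indices printed in this section:

```python
import math


def trainable_fraction(shapes, indices):
    """Fraction of scalar parameters that have a trainable delta.

    shapes:  {param_name: shape of the base parameter}
    indices: {param_name: list of selected coordinates}
    """
    total = sum(math.prod(shape) for shape in shapes.values())
    trainable = sum(len(idx) for idx in indices.values())
    return trainable / total


shapes = {"weight": (1, 5), "bias": (1,)}
indices = {
    "weight": [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]],
    "bias": [[0]],
}
print(trainable_fraction(shapes, indices))  # 1.0: 6 of 6 scalars are trainable
```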

In [15]:
```python
for param_name, param in model.model.named_parameters():
    print('* model param:', param_name, '\n')

    param_indices = model.indices[param_name]
    print("  - indices:", param_indices, '\n')

    param_deltas = model.deltas[param_name]
    print("  - deltas:", param_deltas, '\n')
```

```
* model param: weight

  - indices: tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4]], dtype=torch.int32)

  - deltas: Parameter containing:
tensor([0., 0., 0., 0., 0.], requires_grad=True)

* model param: bias

  - indices: tensor([[0]], dtype=torch.int32)

  - deltas: Parameter containing:
tensor([0.], requires_grad=True)
```



## Training the SpaRTA adapter

We now train the SpaRTA adapter (specifically, its deltas) with standard Stochastic Gradient Descent (SGD) to fit the following dataset of $(x,y)$ pairs, where the target $y$ is simply the sum of the input features in $x$. That is, the linear model to be learnt (the one that generated the dataset) has weights $w_1=w_2=\dots=w_5=1$ and bias $b=0$. Since each step uses all 50 examples, the updates are in effect full-batch gradient descent.
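Because the base parameters stay frozen, fitting this target means the deltas must converge to the difference between the target parameters and the initial ones. A quick check with the (rounded) initial values printed earlier:

```python
# Initial (frozen) parameters, as printed earlier (rounded to 4 decimals)
w_init = [0.0034, 0.4410, 0.4370, 0.4056, 0.3660]
b_init = -0.1786

# Parameters of the model that generated the data
w_target = [1.0] * 5
b_target = 0.0

# Deltas that make frozen + delta equal the target parameters
w_delta = [round(t - w, 4) for t, w in zip(w_target, w_init)]
b_delta = round(b_target - b_init, 4)
print(w_delta, b_delta)  # [0.9966, 0.559, 0.563, 0.5944, 0.634] 0.1786
```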

#### Dataset

In [16]:
```python
# dataset generated from a linear model with
# w1 = w2 = w3 = w4 = w5 = 1; b = 0
x = torch.randn(50, 5)
y = x.sum(-1)  # y = x1 + x2 + x3 + x4 + x5
```

#### Optimizer

In [17]:
```python
import torch.optim as optim
```

In [18]:
```python
optimizer = optim.SGD(model.parameters(), lr=0.01)
```

#### Training loop

In the following training loop, we measure the discrepancy between the targets and the model predictions with the sum of squared errors over the 50 examples (i.e. the Mean Squared Error (MSE) multiplied by the number of examples):
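The loop can be mimicked in plain Python to see what the optimizer does: with the base parameters frozen, each step moves only the deltas along the negative gradient of the sum of squared errors. This is a sketch, not SpaRTA's code or the notebook's exact run: it draws its own seeded random data, starts from the rounded initial parameters printed earlier, and uses a smaller learning rate (0.002) and more steps than the notebook, a conservative choice so that gradient descent on the summed loss stays stable for any reasonable random draw:

```python
import random

rng = random.Random(0)

# Noiseless data from the target model: y = x1 + x2 + x3 + x4 + x5
X = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(50)]
Y = [sum(row) for row in X]

# Frozen base parameters (rounded initial values printed earlier)
w0 = [0.0034, 0.4410, 0.4370, 0.4056, 0.3660]
b0 = -0.1786

# Trainable deltas, initialized at zero
dw = [0.0] * 5
db = 0.0
lr = 0.002

for step in range(1500):
    # Gradient of L = sum_n (y_pred_n - y_n)^2 with respect to the deltas
    grad_w = [0.0] * 5
    grad_b = 0.0
    loss = 0.0
    for x_n, y_n in zip(X, Y):
        y_pred = sum((w + d) * xi for w, d, xi in zip(w0, dw, x_n)) + (b0 + db)
        err = y_pred - y_n
        loss += err * err
        for i in range(5):
            grad_w[i] += 2.0 * err * x_n[i]
        grad_b += 2.0 * err
    # Gradient descent update of the deltas only; w0 and b0 never change
    dw = [d - lr * g for d, g in zip(dw, grad_w)]
    db -= lr * grad_b

w_final = [w + d for w, d in zip(w0, dw)]
b_final = b0 + db
print(f"last loss: {loss:.3e}")
print([round(w, 6) for w in w_final], round(b_final, 6))
```

The gradient of the summed loss is $N$ times the gradient of the MSE (here $N = 50$), which is why the step size has to be smaller than it would be for the mean.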

In [19]:
```python
model.train()

for _ in range(40):

    # forward pass: frozen base parameters plus the deltas
    y_pred = model(x).squeeze()
    # sum of squared errors over the whole dataset
    loss = (y_pred - y).square().sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # with max_norm=inf nothing is clipped: the call only returns the
    # total L2 norm of the gradients computed by loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=float('inf'))

    print(f'Training Loss: {loss.item():9.6f};  Gradient norm: {grad_norm.item():7.4f}')
```

```
Training Loss: 60.549877;  Gradient norm: 82.7082
Training Loss: 14.997356;  Gradient norm: 38.1405
Training Loss:  4.755784;  Gradient norm: 21.2599
Training Loss:  1.579162;  Gradient norm: 12.2680
Training Loss:  0.530923;  Gradient norm:  7.1481
Training Loss:  0.179143;  Gradient norm:  4.1777
Training Loss:  0.060528;  Gradient norm:  2.4450
Training Loss:  0.020467;  Gradient norm:  1.4324
Training Loss:  0.006926;  Gradient norm:  0.8399
Training Loss:  0.002345;  Gradient norm:  0.4930
Training Loss:  0.000795;  Gradient norm:  0.2897
Training Loss:  0.000270;  Gradient norm:  0.1704
Training Loss:  0.000092;  Gradient norm:  0.1003
Training Loss:  0.000031;  Gradient norm:  0.0591
Training Loss:  0.000011;  Gradient norm:  0.0349
Training Loss:  0.000004;  Gradient norm:  0.0206
Training Loss:  0.000001;  Gradient norm:  0.0122
Training Loss:  0.000000;  Gradient norm:  0.0072
Training Loss:  0.000000;  Gradient norm:  0.0043
Training Loss:  0.000000;  Gradient norm:  0.0025
```


The gradients of the adapter's delta parameters are now close to zero, indicating convergence:

In [20]:
```python
for param in model.parameters():
    print(param.grad)
```

```
tensor([ 4.7546e-07,  2.5568e-06, -1.0386e-06,  5.9224e-07, -1.2583e-06])
tensor([-1.4901e-07])
```
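The "Gradient norm" reported during training is the total L2 norm over all gradient entries, which is what `torch.nn.utils.clip_grad_norm_` returns (with `max_norm=float('inf')` nothing is actually clipped). Computing it by hand from the final gradients printed above gives a value of about $3 \times 10^{-6}$:

```python
import math

# Final gradients of deltas.weight and deltas.bias, as printed above
grad_weight = [4.7546e-07, 2.5568e-06, -1.0386e-06, 5.9224e-07, -1.2583e-06]
grad_bias = [-1.4901e-07]

# Total L2 norm over all gradient entries, treated as one long vector
total_norm = math.sqrt(sum(g * g for g in grad_weight + grad_bias))
print(f"{total_norm:.4e}")  # about 3.13e-06
```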


#### Evaluation

Let's now check if the learnt model is actually the one that generated the data.

In [21]:
```python
model.eval()
```

The final parameters (weights and bias) of the adapted model are the original pre-trained (frozen) parameters plus the learnt deltas. Calling the `eval` method on the adapter merges the deltas into the original pre-trained model parameters (if they are not already merged), so the final values can be read directly from the base model:

In [22]:
```python
model.model.weight.data
```

```
tensor([[1., 1., 1., 1., 1.]])
```

In [23]:
```python
model.model.bias.data
```

```
tensor([2.9802e-08])
```

Finally, we can also check this by evaluating the final adapted model at $x=(1,1,1,1,1)$, which should give a value very close to $y = x_1 + x_2 + x_3 + x_4 + x_5 = 5.0$:

In [24]:
```python
model(ones)
```

```
tensor([5.])
```