In [None]:
What's task arithmetic?

This first snippet comes from the TaLoS pruner.

A pruner is a class that takes a model and a dataset and finds the parameters that change the least, i.e. the ones that are least relevant for the task at hand.

This knowledge is usually used to shrink the network.

TaLoS, on the other hand, wants to find the parameters responsible for identifying a particular class.

This is fundamental for task arithmetic. For example, if we fine-tune a model to recognize dogs, with TaLoS we can find the parameters that are important for this task. We can do the same with cats.

In the end, we can get a model that recognizes both cats and dogs by simply injecting into a vanilla model the parameters we gathered with TaLoS for dogs and for cats.
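The injection step can be sketched with plain Python dictionaries standing in for model weights. This is only an illustration: the helper names `task_vector` and `merge`, the scaling factor `alpha`, and all the numbers are made up here, and the masks are assumed to be 0/1 lists like the ones a pruner would produce.

```python
def task_vector(pretrained, finetuned, mask):
    """Fine-tuned minus pretrained weights, kept only where mask == 1."""
    return {
        name: [m * (f - p) for p, f, m in zip(pretrained[name], finetuned[name], mask[name])]
        for name in pretrained
    }


def merge(pretrained, task_vectors, alpha=1.0):
    """Add every task vector (scaled by alpha) on top of the pretrained weights."""
    merged = {name: list(values) for name, values in pretrained.items()}
    for tv in task_vectors:
        for name, delta in tv.items():
            merged[name] = [w + alpha * d for w, d in zip(merged[name], delta)]
    return merged


# Toy "models" with a single 4-entry parameter (hypothetical values).
pretrained = {"w": [0.0, 0.0, 0.0, 0.0]}
dogs = {"w": [1.0, 0.0, 0.5, 0.0]}    # weights after fine-tuning on dogs
cats = {"w": [0.0, 2.0, 0.0, -1.0]}   # weights after fine-tuning on cats
dog_mask = {"w": [1, 0, 1, 0]}        # entries selected for the dog task
cat_mask = {"w": [0, 1, 0, 1]}        # entries selected for the cat task

merged = merge(
    pretrained,
    [task_vector(pretrained, dogs, dog_mask), task_vector(pretrained, cats, cat_mask)],
)
print(merged["w"])
```

Because the two masks do not overlap in this toy case, the merged weights carry both edits without one overwriting the other.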

This is the first code snippet the TA gave us. It sits inside the score function of the TaLoS pruner, and its job is to compute the diagonal elements of the Fisher Information Matrix.

What does that mean? For every picture given to the model, TaLoS samples a class from the model's own prediction and backpropagates the corresponding output to get its gradient with respect to every parameter. It then squares that gradient (so that only the magnitude matters) and adds it to the squared gradients accumulated so far.

So in the end, every parameter has the sum of its squared gradients. The parameters with the smallest sums are the least significant for our task, while the ones with the largest sums are the parameters that were tuned and "learned" to do the task we requested.
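Written out, the score accumulated for a single parameter looks like the formula below. The symbols are introduced here for illustration and do not appear in the TA's code: z is the vector of logits, ŷ_n is the class sampled for picture x_n, and N counts every (picture, repetition) pair that was processed.

```latex
s_i \;=\; \sum_{n=1}^{N} \left( \frac{\partial\, z_{\hat{y}_n}(x_n;\theta)}{\partial \theta_i} \right)^{2},
\qquad \hat{y}_n \sim \mathrm{softmax}\big(z(x_n;\theta)\big)
```

The textbook diagonal Fisher uses the log-probability of the sampled class, log p(ŷ_n | x_n; θ), in place of the raw logit; the formula above follows what the snippet as shown actually backpropagates.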

In [None]:
for _ in range(self.R):
                # logits holds the outputs of the model's last layer for a batch of pictures (our input for learning a task)
                logits = model(data)
                # Categorical(logits=...) applies a softmax internally, so every value becomes a probability and they sum to 1
                # (e.g. logits [2.2, 0.0] become roughly [0.9, 0.1]).
                # .sample() spins a roulette where every class comes out with its softmax probability: for [0.9, 0.1],
                # running .sample() 10 times gives class 0 about 9 times and class 1 about once.
                # unsqueeze(1) just reshapes the result to (batch, 1) and detach() tells PyTorch to stop tracking gradients.
                outdx = torch.distributions.Categorical(logits=logits).sample().unsqueeze(1).detach()
                # Now outdx is a tensor of indices: for every picture we know which class .sample() chose.
                # logits.gather(1, outdx) picks, for each row (picture), the logit in the column given by outdx.
                # For example, with logits = [[90, 10]]: outdx = [[0]] gives samples = [[90]], outdx = [[1]] gives samples = [[10]].
                samples = logits.gather(1, outdx)

                for idx in range(data.size(0)):
                    # Reset the gradients accumulated so far, so that each picture's gradient is measured on its own
                    model.zero_grad()
                    # This computes the derivative of that specific output with respect to every weight in the network.
                    # samples[idx], as explained before, is the output value of the network for the sampled class of one picture.
                    # retain_graph=True keeps the graph alive, since we backpropagate through it once per picture.
                    torch.autograd.backward(samples[idx], retain_graph=True)
                    # masked_parameters(model) returns all the parameters responsible for defining Linear connections or Convolutional connections, ignoring the others
                    for m, p in masked_parameters(model):
                        if p.requires_grad and hasattr(p, 'grad') and p.grad is not None:
                            # Here we compute, step by step, the diagonal elements of the Fisher Information matrix by summing the squares of the gradients
                            self.scores[id(p)] += torch.clone(p.grad.data.pow(2)).detach().cpu()
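The softmax, sample and gather steps can be reproduced without PyTorch on a toy batch. This is only a sketch: the logits are invented, `softmax` is a helper written here, and `random.choices` stands in for `Categorical(...).sample()`.

```python
import math
import random


def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max keeps exp() numerically stable
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy batch: two pictures, two classes (hypothetical numbers).
logits = [[2.2, 0.0],
          [0.0, 0.0]]

probs = [softmax(row) for row in logits]

rng = random.Random(0)
# One sampled class index per picture, drawn with the softmax probabilities.
sampled = [rng.choices(range(len(p)), weights=p)[0] for p in probs]
# Equivalent of logits.gather(1, outdx): in each row, keep the logit of the sampled column.
samples = [row[j] for row, j in zip(logits, sampled)]

print(probs)
print(sampled, samples)
```

The first row comes out close to [0.9, 0.1], which is the kind of distribution the roulette analogy in the comments refers to.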


Remember that the previous snippet sits inside the TaLoS score() function, in which we compute the diagonal of the Fisher matrix.

We can now understand the second snippet the TA gave us.

In [None]:
# Sparsity is just the fraction of the network we want to retain. With sparsity 0.5 we retain half of the network and prune the other, less significant, half
if sparsity < 1.0:
        # The complete TaLoS code prunes over multiple rounds, I think for research purposes
        for round in range(ROUNDS):
            # Only in the final round does sparse equal sparsity
            sparse = sparsity**((round + 1) / ROUNDS)
            print('[+] Target sparsity:', sparse)
            #We compute the Fisher diagonal Matrix as shown before
            pruner.score(model, None, data_loader, args.device, N_PRUNING_BATCHES)
            # Inside the mask() function there is a switch to choose different ways to compute the threshold for pruning/masking a parameter
            mode = 'global_copy'
            # pruner.mask() takes the scores of our model and, based on them and on the mode, chooses which parameters to mask.
            # In practice, mask() updates self.masked_parameters, an iterator that returns (mask, param) pairs, where mask holds 1s and 0s and param is the parameter.
            # Where does the pruner get the model scores? score() stored them in self.scores, and the pruner is wrapped around the model itself (the model is an __init__ parameter).

            pruner.mask(sparse, mode)
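To see what the schedule `sparsity**((round + 1) / ROUNDS)` does, here is a small standalone sketch. The function name `sparsity_schedule` and the values 0.1 and 4 are chosen only for illustration.

```python
def sparsity_schedule(sparsity, rounds):
    """Target fraction to retain after each round, ending exactly at `sparsity`."""
    return [sparsity ** ((r + 1) / rounds) for r in range(rounds)]


schedule = sparsity_schedule(0.1, 4)
for r, target in enumerate(schedule):
    print(f"round {r}: target sparsity {target:.4f}")
```

The targets shrink geometrically (each round multiplies the retained fraction by the same factor), so the network is pruned a bit at a time instead of in one shot.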

Now that we have seen how to find the most significant parameters for a task, we can implement the SparseSGDM optimizer we were asked for.

The optimizer takes the parameters as input, as usual. Every parameter has its mask saved in self.state[parameter]['mask'].

The only additions to a normal SGDM are step 2, MASKING GRADIENT, and step 4, MASKING MOMENTUM.

With this optimizer we can train or fine-tune only some parts of our model. These parts can be identified with TaLoS or other pruning tools.
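The effect of the mask on an SGD-with-momentum update can be sketched on plain Python lists, one float per parameter entry. The function `masked_sgdm_step` and all the numbers are hypothetical, and weight decay, dampening and Nesterov momentum are left out to keep the idea visible.

```python
def masked_sgdm_step(params, grads, masks, buffers, lr=0.1, momentum=0.9):
    """One in-place SGD-with-momentum step; entries whose mask is 0 never move."""
    for i, (g, m) in enumerate(zip(grads, masks)):
        g = g * m                                # mask the gradient
        buffers[i] = momentum * buffers[i] + g   # the momentum buffer only ever sees masked gradients
        params[i] -= lr * buffers[i]             # update the weight


params = [1.0, 1.0, 1.0]
masks = [1, 0, 1]            # the middle entry is frozen
buffers = [0.0, 0.0, 0.0]

for _ in range(3):
    grads = [0.5, 0.5, -0.5]  # pretend every batch gives the same gradient
    masked_sgdm_step(params, grads, masks, buffers)

print(params)
```

The frozen entry keeps its initial value even though its raw gradient is non-zero at every step, while the other two entries move in opposite directions.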

In [None]:
import torch
from torch.optim import Optimizer

class SparseSGDM(Optimizer):
    r"""
    Implements Stochastic Gradient Descent with Momentum (SGDM)
    where the gradient mask is passed directly into the constructor.
    """

    def __init__(self, params, masks=None, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        """
        Args:
            params (iterable): iterable of parameters to optimize (usually model.parameters())
            masks (iterable, optional): iterable of (mask, param) tuples,
                                        e.g. the output of the TaLoS masked_parameters(model)
            lr (float): learning rate
            momentum (float, optional): momentum factor (default: 0)
            weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            dampening (float, optional): dampening for momentum (default: 0)
            nesterov (bool, optional): enables Nesterov momentum (default: False)
        """
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)

        # 1. Initialize the Base Optimizer with the parameters
        super(SparseSGDM, self).__init__(params, defaults)

        # 2. Automatically register masks if provided in constructor
        if masks is not None:
            self._register_masks(masks)

    def _register_masks(self, masked_parameters):
        """
        Internal helper to populate self.state with masks.
        """
        for mask, param in masked_parameters:
            # self.state creates entries lazily, so make sure this parameter has one
            if param not in self.state:
                self.state[param] = {}

            # Store mask in the specific parameter's state dictionary
            # We ensure device consistency (move mask to CPU/GPU of the param)
            self.state[param]['mask'] = mask.to(param.device)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step with Masking."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                # d_p stands for "delta p": the update direction for p, which starts out as the gradient
                d_p = p.grad

                # Retrieve mask from state
                state = self.state[p]
                mask = state.get('mask')

                # 1. Apply Weight Decay (Before masking)
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)

                # 2. MASKING GRADIENT: Prevent update from current batch
                #    (out-of-place multiply, so p.grad itself is left untouched)
                if mask is not None:
                    d_p = d_p.mul(mask)

                # 3. Momentum Logic
                if momentum != 0:
                    if 'momentum_buffer' not in state:
                        buf = state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    # 4. MASKING MOMENTUM: Prevent drift from history
                    if mask is not None:
                        buf.mul_(mask)

                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # 5. Update Weight
                p.add_(d_p, alpha=-group['lr'])

        return loss