# Edit the circuit
To find the circuit (the important subnetwork), we need to "mask out" the other parts of the network. Differentiable masking does this with trainable masks. Below is one convenient way to set it up.

## Trainable masks 
The computation graph below shows a forward pass that routes the gradient to the mask rather than to the model parameters. Here `w` is the model's parameter, `x` is the input, and `loss` is the output. An external, trainable "mask parameter" `m` "covers up" the model parameters that are not necessary.  
(You may need a Mermaid plugin, e.g. the one for VS Code, to render the graph.)  

```mermaid
graph LR;
	w -..-> new_w
	new_w --> Opmask((*))
	m --> Opmask
	Opmask --> Oploss((*))
	x --> Oploss
	Oploss --> Loss
```

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
# A minimal example that avoids the "one of the variables needed for gradient
# computation has been modified by an inplace operation" error

w = torch.tensor(1., requires_grad=False)  # Weight of the model
m = torch.tensor(0.5, requires_grad=True)  # Mask
x = torch.tensor(3., requires_grad=True)  # Input tensor

# Step 1: Copy the model's weight to a detached temporary tensor
tmp_w = w.detach().clone()

# Step 2: Apply the mask to the copy; only m requires grad here
new_w = m * tmp_w

# Step 3: Copy the masked weight back into the model in place
w.copy_(new_w)
print("w=", w)

# Now do the regular forward pass of the model
loss = x * w
loss.backward()

# The weight receives no gradient; the gradient goes to the mask.
# (Reading w.grad raises a UserWarning because w is no longer a leaf tensor.)
print("loss=", loss)
print("m.grad:{}, w.grad:{}, x.grad: {}".format(m.grad, w.grad, x.grad))
```

```
w= tensor(0.5000, grad_fn=<CopyBackwards>)
loss= tensor(1.5000, grad_fn=<MulBackward0>)
m.grad:3.0, w.grad:None, x.grad: 0.5
```
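
As a sanity check on the printed values: writing $w_0 = 1$ for the original weight, the graph computes $\mathrm{loss} = x \cdot (m \cdot w_0)$, with $w_0$ treated as a constant because `tmp_w` is detached. The chain rule then gives

$$
\frac{\partial\,\mathrm{loss}}{\partial m} = x\,w_0 = 3 \cdot 1 = 3,
\qquad
\frac{\partial\,\mathrm{loss}}{\partial x} = m\,w_0 = 0.5 \cdot 1 = 0.5,
$$

which matches `m.grad` and `x.grad`. `w.grad` is `None` because, after `copy_`, `w` is an intermediate (non-leaf) tensor, and PyTorch only stores `.grad` on leaf tensors by default.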




This approach lets us set up an external module that masks a pre-trained DNN model (e.g., GPT-2) without rewriting the Hugging Face model code.
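
The same "external mask" idea can also be sketched without any in-place copy by using `torch.func.functional_call` (PyTorch 2.0+), which runs a module with substituted parameter tensors. This is a swapped-in alternative, not necessarily how `masked_model.py` works: `ExternalMask` and `init_logit` are hypothetical names, and the mask here is a plain sigmoid with no sparsity pressure.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # assumes PyTorch >= 2.0


class ExternalMask(nn.Module):
    """Hypothetical wrapper that learns a soft mask over every weight of a frozen model."""

    def __init__(self, model: nn.Module, init_logit: float = 3.0):
        super().__init__()
        self.model = model
        for p in self.model.parameters():
            p.requires_grad_(False)  # pre-trained weights are never updated
        # ParameterDict keys cannot contain ".", so "h.0.attn" becomes "h__0__attn".
        # init_logit = 3.0 gives sigmoid(3) ~ 0.95, i.e. the masks start nearly open.
        self.logits = nn.ParameterDict({
            name.replace(".", "__"): nn.Parameter(torch.full_like(p, init_logit))
            for name, p in self.model.named_parameters()
        })

    def forward(self, *args, **kwargs):
        # Build masked copies of the weights; the model's own tensors stay untouched.
        masked = {
            name: p * torch.sigmoid(self.logits[name.replace(".", "__")])
            for name, p in self.model.named_parameters()
        }
        return functional_call(self.model, masked, args, kwargs)


# Toy usage on a single linear layer.
torch.manual_seed(0)
base = nn.Linear(4, 2)
wrapper = ExternalMask(base)
x = torch.randn(3, 4)
out = wrapper(x)
out.sum().backward()
print(base.weight.grad)                     # None: the frozen weight gets no gradient
print(wrapper.logits["weight"].grad.shape)  # the gradient lands on the mask logits
```

Only `wrapper.logits.parameters()` would be handed to the optimizer. Because the model's tensors are never overwritten, repeated forward/backward passes do not run into stale autograd history.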

## Find a sparse circuit

A circuit consists of a sparse collection of parameters, so we need most mask values to be exactly 0 or 1 rather than somewhere in between.  
Section 3 of [Cao et al. (2021)](https://arxiv.org/pdf/2104.03514.pdf) proposes a nice method to do so.  
We implemented it in `masked_model.py`.  
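
One common way to push masks to exactly 0 or 1 is the hard concrete gate from Louizos et al. (2018), "Learning Sparse Neural Networks through L0 Regularization". The sketch below shows that technique; we have not checked it against Section 3 of Cao et al. (2021) or `masked_model.py`, so treat it as an illustrative assumption. The stretch-and-clip step puts nonzero probability on exactly 0 and exactly 1, and `expected_l0` is a differentiable penalty on the expected number of open gates. `HardConcreteMask` is a hypothetical name; the defaults `beta=2/3`, `gamma=-0.1`, `zeta=1.1` are the values commonly used with this relaxation.

```python
import math

import torch
import torch.nn as nn


class HardConcreteMask(nn.Module):
    """Hypothetical hard concrete gate (Louizos et al., 2018), one gate per weight."""

    def __init__(self, shape, beta=2 / 3, gamma=-0.1, zeta=1.1, init_logit=3.0):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.full(shape, init_logit))
        self.beta, self.gamma, self.zeta = beta, gamma, zeta

    def forward(self):
        if self.training:
            # Sample a relaxed Bernoulli gate with the reparameterization trick.
            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
            s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta)
        else:
            # Deterministic gate at evaluation time.
            s = torch.sigmoid(self.log_alpha)
        # Stretch to (gamma, zeta), then clip to [0, 1] so gates can be exactly 0 or 1.
        return (s * (self.zeta - self.gamma) + self.gamma).clamp(0.0, 1.0)

    def expected_l0(self):
        # Sum over gates of P(gate != 0): a differentiable sparsity penalty.
        shift = self.beta * math.log(-self.gamma / self.zeta)
        return torch.sigmoid(self.log_alpha - shift).sum()


# Toy usage: mask a frozen weight and add the L0 penalty to a stand-in task loss.
torch.manual_seed(0)
gate = HardConcreteMask((4, 4))
weight = torch.randn(4, 4)  # frozen model weight
z = gate()
loss = (weight * z).pow(2).sum() + 0.01 * gate.expected_l0()
loss.backward()

gate.eval()
z_eval = gate()  # deterministic mask used after training
```

The penalty weight (`0.01` here) trades task performance against sparsity; a larger value closes more gates.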

## Apply edits on the circuit
Now we have a `MaskedModel`, whose `self.masks` holds the names of the masked parameters.  
In this step, we fine-tune the circuit. During back-propagation, if a weight is masked (i.e., `m=0`), we skip its update.  
TODO: Can this be done without rewriting the PyTorch optimizer?  
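
One possible answer to the TODO, sketched on a toy layer and not tested against `MaskedModel`: register a hook with `Tensor.register_hook` on each parameter that multiplies its gradient by a binary mask, so masked entries reach the optimizer with a zero gradient. `binary_masks` is a hypothetical stand-in for masks thresholded from the learned values. This assumes an optimizer where a zero gradient means a zero update (plain SGD here). Optimizers with weight decay applied independently of the gradient, such as AdamW, would still move masked weights, so with those one could copy the masked entries back after each `opt.step()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3)

# Hypothetical binary masks keyed by parameter name (1 = in the circuit, 0 = masked out).
binary_masks = {name: (torch.rand_like(p) > 0.5).float() for name, p in model.named_parameters()}

# Zero the gradient of every masked entry before it is accumulated into p.grad.
for name, p in model.named_parameters():
    keep = binary_masks[name]
    p.register_hook(lambda grad, keep=keep: grad * keep)  # default arg binds this mask

before = {name: p.detach().clone() for name, p in model.named_parameters()}

# SGD without weight decay: a zero gradient keeps the momentum buffer and the update at 0.
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for _ in range(3):
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()  # stand-in fine-tuning loss
    loss.backward()
    opt.step()
```

After training, the masked entries of `model.weight` and `model.bias` still equal their values in `before`, while the unmasked entries have been updated, with no change to the optimizer itself.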