# torchID approximate search
By: Daniel Redder

## The same walkthrough as in `test_simple.ipynb`

In [1]:
import torch 
from torchID import tag_model, identify_tensors, find_leaves
from simple_model import SimpleModel

### Preparation

Let's load some simple sleep data from Kaggle: https://www.kaggle.com/code/tanshihjen/eda-timeseries-fitbitsleepscoredata/input

In [2]:
import pandas as pd
df = pd.read_csv('sleep_score_data_fitbit.csv')[['overall_score', 'revitalization_score', 'deep_sleep_in_minutes', 'resting_heart_rate','restlessness']]
print(df.head())

print(df.values.shape)
dataset = torch.tensor(df.values, dtype=torch.float32)


   overall_score  revitalization_score  deep_sleep_in_minutes  \
0             83                    83                    104   
1             87                    87                    114   
2             84                    84                     99   
3             81                    81                     73   
4             76                    76                     64   

   resting_heart_rate  restlessness  
0                  63      0.068100  
1                  63      0.053283  
2                  64      0.051408  
3                  65      0.046679  
4                  65      0.076923  
(291, 5)


Now we initialize our model, optimizer, and loss

In [3]:
model = SimpleModel()

data_lab = dataset[:, -1].unsqueeze(1)
data_lab = torch.nn.functional.normalize(data_lab, p=2, dim=1, eps=1e-12, out=None)
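model.train() #?this puts the model in training mode (it affects layers such as dropout and batch norm; it does not touch requires_grad)

#?parameters require grad by default, but just to be really sure
for n,m in model.named_parameters(): m.requires_grad = True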

optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = torch.nn.MSELoss()

### Finding Leaves

We want to know which variables are leaves in this model. To make this interesting, let's slightly modify the output with an additional leaf.

In [4]:
EPOCHS = 20


#?note that this variable is in neither the optimizer nor the model, so we cannot locate it through the optimizer's parameter lists or the model's named_parameters
delta = torch.tensor([-1.4], requires_grad=True)

for epoch in range(EPOCHS):

    #? get our model's predicted output
    y_pred = model(dataset[:, :-1])

    Y_pred = y_pred + delta

    #? calculate the loss
    loss = loss_fn(Y_pred, data_lab)
    print(loss.item())

    #?typical backpropagation commands
    optim.zero_grad()
    loss.backward()
    optim.step()



5.4803242683410645
5.054694175720215
4.308633804321289
3.3790018558502197
2.4075212478637695
1.5183870792388916
0.8027776479721069
0.31126001477241516
0.05371517688035965
0.00536576472222805
0.1168447807431221
0.32603558897972107
0.5695803761482239
0.7924039363861084
0.9542089700698853
1.0325487852096558
1.0226702690124512
0.9347599148750305
0.7894994616508484
0.6129088997840881


The loss initially went down, but we still have this leaf `delta` with a lingering `.grad` value that keeps accumulating, because `optim.zero_grad()` only resets tensors registered with the optimizer.

In [5]:
print(delta.grad)

tensor([-9.4057])


This won't impact the model's predictive performance, but it will impact memory usage, and with larger models (we hit this in a diffusion meta-learning case) it can make learning **intractable**.
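The accumulation can be reproduced in isolation. The following is a minimal sketch, not part of torchID: the tensors `w` and `delta` and the toy loss are made up for illustration. Only `w` is registered with the optimizer, so `optim.zero_grad()` never resets `delta.grad`.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)       # registered with the optimizer
delta = torch.tensor([-1.4], requires_grad=True)  # a stray leaf, not registered
optim = torch.optim.SGD([w], lr=0.01)

for step in range(3):
    loss = ((w + delta) ** 2).sum()
    optim.zero_grad()   # resets w.grad only
    loss.backward()     # adds this step's gradient to both .grad fields
    optim.step()
    print(step, w.grad.item(), delta.grad.item())

# w.grad holds one step's gradient; delta.grad holds the running sum of all steps
stale = delta.grad.clone()
delta.grad = None       # resetting the stray leaf by hand
```

Resetting by hand only works once you know the stray leaf exists, which is the problem the rest of this notebook addresses.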

So, let's take the same loop and use **torchID** to find this tensor.

torchID consists of 3 functions:

`find_leaves( grad_fn )`: this recursively follows `grad_fn.next_functions` to find `grad_fn` objects that have a `.variable` attribute. It also records the path of `grad_fn`s taken to reach each one **(used in our approximate search method)**.

`tag_model( torch.nn.Module )`: this is the straightforward solution. It iterates over `model.named_parameters()` and assigns `leaf.nmm = name` for each parameter.

`identify_tensors( [tensor] )`: this approach finds tensors which are not parameters, but are leaves in the computational graph. It works by searching `sys.modules`, i.e. it **searches all references defined in all loaded modules**. We explain why we chose to do it this way in the README. To mitigate the overhead this incurs, there is a `limited_system_search` parameter, which checks whether a specific variable name exists in each module before searching that module for the tensor.

For example:

```py
# -- other package --
MyTestVar = ...

# -- other other package --
# (does not contain MyTestVar)

# -- main --
identify_tensors(leaves, limited_system_search = "MyTestVar")
```
This will only check in "other package".
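To make the `sys.modules` idea concrete, here is a rough sketch of an identity search over loaded modules. It is not torchID's implementation: the helper `find_names_by_identity`, its `only_name` argument, and the module name `torchid_demo_pkg` are hypothetical names chosen for this illustration, and a plain list stands in for the tensor.

```python
import sys
import types


def find_names_by_identity(target, only_name=None):
    """Return 'module.attr' names in loaded modules that are bound to `target`."""
    hits = []
    for mod_name, module in list(sys.modules.items()):
        try:
            namespace = vars(module)
        except TypeError:      # entry without a __dict__ (e.g. None)
            continue
        if only_name is not None:
            # limited search: look at one attribute name per module
            if namespace.get(only_name) is target:
                hits.append(f"{mod_name}.{only_name}")
            continue
        # full search: compare every module-level attribute by identity
        for attr, value in list(namespace.items()):
            if value is target:
                hits.append(f"{mod_name}.{attr}")
    return hits


# Hypothetical "other package" that holds a reference to our object.
marker = [1.0, 2.0]
demo = types.ModuleType("torchid_demo_pkg")
demo.MyTestVar = marker
sys.modules["torchid_demo_pkg"] = demo

limited_hits = find_names_by_identity(marker, only_name="MyTestVar")
full_hits = find_names_by_identity(marker)
print(limited_hits)
print(full_hits)
```

The full search touches every attribute of every loaded module, which is where the overhead comes from. The limited variant does a single dictionary lookup per module.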

## Handling Overwritten Leaves

A common pattern in torch code is overwriting a variable with a modified version of itself: `var = var + 1`. PyTorch does not allow the in-place `var += 1` on a leaf that requires grad, but it does allow `var = var + 1`, which causes a problem for this package.

Because `var + 1` creates a new tensor, the assignment rebinds the name `var` to that new tensor (`var* = var + 1`), and the name no longer refers to the original leaf. **This is a common spot for computational graph problems, so it is important we can handle it.**
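The effect of the overwrite can be checked directly in plain PyTorch. This is a small sketch with a made-up `delta`; the extra `original` handle exists only so the two objects can be compared, and would not normally be available.

```python
import torch

delta = torch.tensor([-1.4], requires_grad=True)
original = delta          # extra handle, kept only for the comparison below

try:
    delta += 1            # in-place update of a leaf that requires grad
    inplace_allowed = True
except RuntimeError:
    inplace_allowed = False

delta = delta + 1         # allowed: builds a new tensor and rebinds the name

print(inplace_allowed)
print(original.is_leaf, delta.is_leaf)

# The original leaf is still reachable, but only through the graph:
# the first input of delta.grad_fn is the leaf's accumulation node,
# and its .variable attribute is the original tensor.
accumulate_node = delta.grad_fn.next_functions[0][0]
same_storage = accumulate_node.variable.data_ptr() == original.data_ptr()
print(same_storage)
```

After the overwrite, a search by name finds only the new non-leaf tensor, while the leaf that actually accumulates `.grad` has no name left in the module.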

To solve this we use an approximation approach: we find the nearest `grad_fn` "before" (in terms of backprop) the variable is overwritten, and we use that as the name. If the `identify_tensors` function encounters any of these cases, it will print a warning.

To do this, pass the `outputDepthList` object returned from `find_leaves` to `identify_tensors` as `approximate_search`. If an approximate match is found, you will see `approximated. ...` in the `nmm` attribute of the leaf, with a measure of its separation from the original leaf.
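The leaf-and-path walk that approximate search relies on can be sketched in a few lines of plain PyTorch. This is an illustration under assumptions, not the code of `find_leaves`: the helper name `walk_leaves` and the toy tensors are hypothetical, and how torchID turns a path into its separation measure is not shown here.

```python
import torch


def walk_leaves(grad_fn, path=()):
    """Yield (leaf_tensor, path_of_grad_fns) pairs reachable from grad_fn."""
    if grad_fn is None:                   # input that does not require grad
        return
    path = path + (grad_fn,)
    if hasattr(grad_fn, "variable"):      # accumulation node of a leaf
        yield grad_fn.variable, path
        return
    for next_fn, _input_index in grad_fn.next_functions:
        yield from walk_leaves(next_fn, path)


w = torch.tensor([2.0], requires_grad=True)
delta = torch.tensor([-1.4], requires_grad=True)
x = torch.tensor([3.0])                   # plain input, no grad

y_pred = w * x
delta = delta / y_pred                    # overwrites the name `delta`
loss = ((y_pred + delta) ** 2).sum()

found = list(walk_leaves(loss.grad_fn))
for leaf, leaf_path in found:
    print(leaf.tolist(), [type(fn).__name__ for fn in leaf_path])
```

Because `y_pred` feeds the loss through two branches, `w` is reported twice, once per path. The path to the original `delta` passes through the division node, and that node is the kind of nearby `grad_fn` an approximate search can fall back on when the leaf itself has no name.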

In [6]:

#?we only need one forward pass to build the computational graph
EPOCHS = 1

for epoch in range(EPOCHS):

    #? get our model's predicted output
    y_pred = model(dataset[:, :-1])

    #!let's introduce some problems
    delta = delta / y_pred

    Y_pred = y_pred + delta

    #? calculate the loss
    loss = loss_fn(Y_pred, data_lab)
    print(loss.item())

    #?first we recursively find all the leaves of the computational graph; this works through the .variable attribute, which links a leaf's node in the graph back to its tensor
    #?leaves is a list of all the leaves of the computational graph (as tensors), paths is a list of lists holding the path of grad_fn's to each leaf
    leaves, paths = find_leaves(loss.grad_fn)


    #?now we identify the tensors that are in the computational graph
    #first we tag the model parameters so that they are handled separately
    tag_model(model)

    #we then identify all leaves that are not model parameters
    identify_tensors(leaves, approximate_search=paths)

    for leaf in leaves: print(leaf.shape, leaf.grad_fn, "  ",leaf.nmm)


2.5555741786956787
Tensor object not found, using approximate search for the following objects: 
index: 2, shape: torch.Size([1]), dtype: torch.float32
torch.Size([1]) None    ['SimpleModel_l1.bias', '__main__.m', '__main__.m', '__mp_main__.m', '__mp_main__.m']
torch.Size([1, 4]) None    ['SimpleModel_l1.weight']
torch.Size([1]) None    ['approximated.2seperated.__main__.delta']
torch.Size([1]) None    ['SimpleModel_l1.bias', '__main__.m', '__main__.m', '__mp_main__.m', '__mp_main__.m']
torch.Size([1, 4]) None    ['SimpleModel_l1.weight']
