# 10-714 Final Project: Pre-training with NEEDLE
NEEDLE is a self-contained deep learning framework that supports training and inference with a variety of modern neural networks. Throughout this course, we have implemented several different neural network architectures, including feed-forward, convolutional, and recurrent networks. Another key function of this framework is the ability to use pretrained network weights. In this project, we propose to implement the following two features:

1. Save and load weights with NEEDLE
2. Port external models to a NEEDLE model. In order to support as many frameworks as possible, instead of implementing individual converters for each framework, we will implement a conversion pipeline that takes an ONNX model as input and converts it to a NEEDLE model.

## 1 Save/Load weights with NEEDLE
In this section, "model" refers to any architecture that inherits from the `ndl.nn.Module` class. The first feature of NEEDLE is the ability to save and load model weights. This process is split into two steps: saving and loading.

To save the weights, we iterate through all the parameters in the model, create a signature for each of them, and save them to a file.

Loading the weights is more involved, because we need to ensure that the target model has the same architecture as the model from which the weights were saved. We compare both the signatures and the devices of the two models. If both match, we load the weights correspondingly. Otherwise, we raise an error.

Once the signatures have been verified, we can load the weights from the file and assign them to the corresponding parameters in the target model. This ensures that the weights are applied correctly, allowing the model to continue making predictions using the loaded weights.

Now we explain the implementation step by step.

### Description
#### Save weights
1. Create a signature for the parameters and the model (i.e., the `ndl.nn.Module` object). The `state_dict()` function creates a signature for each parameter and stores them in a dictionary that maps each signature to its corresponding value. The `id()` function creates a signature for the model that describes its architecture. Note that instead of using the attribute name of each submodule, we name submodules by their class name and a counter (parameters keep their attribute names, such as `weight` and `bias`). This avoids the problem of two models with the same architecture having different names for their parameters.
    ```python
    def state_dict(self, prefix=""):
        """Return a dictionary of the module's parameters."""
        # Create an empty dictionary to store the state of the model
        state_dict = {}
        # Create a counter for each module's class name
        count = Counter()
        
        # Iterate through all the attributes of the module
        for k, v in self.__dict__.items():
            # If the attribute is a parameter, add it to the dictionary
            if isinstance(v, Parameter):
                name = k
                state_dict[prefix + name] = v
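            # If the attribute is a module, increment the counter and add the module and its parameters to the dictionary
            elif isinstance(v, Module):
                count[v.__class__.__name__] += 1
                name = v.__class__.__name__ + '-' + str(count[v.__class__.__name__])
                state_dict.update(v.state_dict(prefix + name + "."))
            # If the attribute is a list or tuple, collect the parameters of its
            # elements as well, using the same naming scheme as id()
            elif isinstance(v, (list, tuple)):
                for i, x in enumerate(v):
                    if isinstance(x, Module):
                        count[x.__class__.__name__] += 1
                        name = x.__class__.__name__ + '-' + str(count[x.__class__.__name__])
                        state_dict.update(x.state_dict(prefix + name + "."))
                    elif isinstance(x, Parameter):
                        state_dict[prefix + k + '-' + str(i)] = x
        return state_dict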
    ```

    ```python
    def id(self):
        """Construct a nested, JSON-serializable dictionary that identifies the module's architecture."""
        # Create a counter for each module's class name
        count = Counter()

        # Create a dictionary to store the signatures of the module's parameters
        signatures = dict()
        
        # Add the class name of the module to the dictionary
        signatures[self.__class__.__name__] = dict()
        
        # Set the current dictionary to the inner dictionary that was just created
        signatures = signatures[self.__class__.__name__]
        
        # Iterate through the attributes of the module
        for k, v in self.__dict__.items():
            # If the attribute is a module, increment the counter and add its signature to the dictionary
            if isinstance(v, Module):
                count[v.__class__.__name__] += 1
                name = v.__class__.__name__ + '-' + str(count[v.__class__.__name__])
                signatures[name] = v.id()
            # If the attribute is a parameter, add its shape to the dictionary
            elif isinstance(v, Parameter):
                name = k
                signatures[name] = v.shape
            # If the attribute is a list or tuple, iterate through its elements
            # and add the signatures of the modules and parameters to the dictionary
            elif isinstance(v, (list, tuple)):
                for i, x in enumerate(v):
                    if isinstance(x, Module):
                        count[x.__class__.__name__] += 1
                        name = x.__class__.__name__ + '-' + str(count[x.__class__.__name__])
                        signatures[name] = x.id()
                    elif isinstance(x, Parameter):
                        # Include the index so parameters in the same list do not overwrite each other
                        name = k + '-' + str(i)
                        signatures[name] = x.shape
        # Return the dictionary containing the signatures of the module's parameters
        return signatures
    ```

2. Save the `state_dict` to a `.npy` file, together with the model's signature and its `device`.


    ```python
    def save_state_dict(module, filename):
        """
        Save the state dict of a module to a file.

        Args:
            module (ndl.nn.Module): The module to save.
            filename (str): The file to save to.
        """
        # Create a dictionary to store the state dict and signature of the module
        save_dict = {}
        save_dict['state_dict'] = module.state_dict()

        # Convert the state dict to numpy arrays and store them in the dictionary
        for k, v in save_dict['state_dict'].items():
            save_dict['state_dict'][k] = v.numpy().astype(np.float32)
        # Store the signature of the module in the dictionary
        save_dict['signature'] = str(module.id())

        # Store the device of the module in the dictionary
        save_dict['device'] = str(module.device)

        # Save the dictionary to the specified file using numpy
        np.save(filename, save_dict)
        print("Saved state dict to file: {}".format(filename))
    ```
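Because `save_dict` is a Python dictionary rather than an array, `np.save` stores it as a pickled zero-dimensional object array. The sketch below shows the round trip and why loading needs `allow_pickle=True` followed by `.item()`. It uses a made-up dictionary and a temporary file rather than a real NEEDLE model; the key names and the device string are illustrative assumptions.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for the dictionary built by save_state_dict
save_dict = {
    "state_dict": {"Linear-1.weight": np.ones((2, 3), dtype=np.float32)},
    "signature": "{'Linear-1': {'weight': (2, 3)}}",
    "device": "cpu",
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.npy")
    # np.save wraps the dict in a 0-d object array and pickles it
    np.save(path, save_dict)

    # Pickled payloads are refused unless allow_pickle=True is passed
    raw = np.load(path, allow_pickle=True)
    # .item() unwraps the 0-d object array back into the original dict
    restored = raw.item()

print(raw.shape, raw.dtype)  # () object
print(sorted(restored.keys()))
```

Unpickling can execute arbitrary code, so a file saved this way should only be loaded when it comes from a trusted source.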

#### Load weights
First compare the signature of the target model with that of the source model. Then load the `state_dict` into the model by matching the parameter signatures.

```python
def load_state_dict(module, filename):
    """
    Load the state dict from a file.

    Args:
        module: (ndl.nn.Module): The module to load the state dict to.
        filename (str): The file to load state dict from.
    """
    # Load the dictionary containing the state dict and signature of the module from the file
    save_dict = np.load(filename, allow_pickle=True).item()

    # Check if the signature of the loaded module matches the saved state dict
    assert save_dict['signature'] == str(module.id()), "Module signature does not match saved state dict"

    # Check if the device of the loaded module matches the saved state dict
    assert save_dict['device'] == str(module.device), "Module device does not match saved state dict"

    # Load the state dict to the module
    module.load_state_dict(save_dict['state_dict'])

    # Print a message to indicate that the state dict has been loaded from the file
    print("Loaded state dict from file: {}".format(filename))
```

The `module.load_state_dict()` method called above is defined as follows:

```python
def load_state_dict(self, state_dict):
    """
    Iterate over the module's parameters and submodules and load the values from the state_dict.
    
    Args:
        state_dict (dict): A dictionary containing the values to load.
    """
    # Get a dictionary of the module's parameters
    this = self.state_dict()
    # Iterate through the values in the state_dict
    for k, v in state_dict.items():
        # If the value is a parameter of the module, set its value to the value from the state_dict
        if k in this:
            this[k].data = Tensor(NDArray(v, device=self.device), device=self.device, requires_grad=True)
        else:
            # If the key does not name a parameter of the module, raise an error
            raise KeyError("Module does not contain key: " + k)
```
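The same key-matching idea can be sketched on plain dictionaries of NumPy arrays. The helper below, `load_matching`, is hypothetical and not part of NEEDLE. It assumes parameters are NumPy arrays, additionally checks for missing keys and shape mismatches before copying, and raises exceptions instead of relying on `assert`.

```python
import numpy as np


def load_matching(params, state_dict):
    """Copy values from state_dict into params, key by key (hypothetical helper).

    params:     dict mapping parameter names to np.ndarray (updated in place)
    state_dict: dict mapping parameter names to np.ndarray (the saved values)
    """
    # Keys saved in the file but unknown to the target model
    unexpected = sorted(set(state_dict) - set(params))
    if unexpected:
        raise KeyError(f"Module does not contain keys: {unexpected}")
    # Keys the target model has but the file lacks
    missing = sorted(set(params) - set(state_dict))
    if missing:
        raise KeyError(f"State dict is missing keys: {missing}")
    for name, value in state_dict.items():
        if params[name].shape != value.shape:
            raise ValueError(
                f"Shape mismatch for {name}: "
                f"{params[name].shape} vs {value.shape}"
            )
        # Overwrite the existing array contents in place
        params[name][...] = value


target = {"Linear-1.weight": np.zeros((2, 3), dtype=np.float32)}
saved = {"Linear-1.weight": np.full((2, 3), 0.5, dtype=np.float32)}
load_matching(target, saved)
print(target["Linear-1.weight"])
```

Unlike `assert` statements, these checks still run when Python is started with the `-O` flag.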

#### Model Summary
As a by-product of the above functions, we can also create a model summary function that prints the model's architecture and the shapes of its parameters. This is useful for debugging and for understanding the model's architecture. The summary is pretty-printed and also returned as a JSON string.

```python
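def summary(self):
    """Pretty-print a summary of the module and return it as a JSON string."""
    # Use a local name that does not shadow the built-in id()
    signature = self.id()
    pp = pprint.PrettyPrinter(indent=2)
    pp.pprint(signature)
    return json.dumps(signature)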
```
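To see why the counter-based names in the summary make the signature independent of attribute names, here is a minimal sketch. The `Parameter`, `Module`, `Linear`, `NetA`, and `NetB` classes are simplified placeholders written for this illustration, not the NEEDLE classes, and `id()` is reduced to the two basic cases. Two toy models that use different attribute names for the same layers yield the same signature.

```python
import json
from collections import Counter


class Parameter:
    """Stand-in for a NEEDLE parameter; only the shape matters here."""
    def __init__(self, *shape):
        self.shape = shape


class Module:
    """Stand-in for ndl.nn.Module with a simplified id()."""
    def id(self):
        count = Counter()
        signature = {}
        for key, value in self.__dict__.items():
            if isinstance(value, Module):
                # Name submodules by class and order, not by attribute name
                cls = type(value).__name__
                count[cls] += 1
                signature[f"{cls}-{count[cls]}"] = value.id()
            elif isinstance(value, Parameter):
                signature[key] = value.shape
        return signature


class Linear(Module):
    def __init__(self, n_in, n_out):
        self.weight = Parameter(n_in, n_out)
        self.bias = Parameter(1, n_out)


class NetA(Module):
    def __init__(self):
        self.fc1 = Linear(4, 8)
        self.fc2 = Linear(8, 2)


class NetB(Module):
    def __init__(self):
        self.hidden = Linear(4, 8)  # different attribute names,
        self.out = Linear(8, 2)     # same layers in the same order


sig_a, sig_b = NetA().id(), NetB().id()
print(json.dumps(sig_a, indent=2))
print(sig_a == sig_b)  # True
```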

### Demo
#### Save/Load weights
In this section, we demonstrate NEEDLE's save/load functionality with live code blocks. We use a ResNet9 model developed during HW4, but in principle any model is supported, because saving and loading are implemented on the base `ndl.nn.Module` class.

We start by importing the necessary libraries and defining the ResNet9 model.

```python
from apps.models import ResNet9
import numpy as np
import needle as ndl
from needle.autograd import Tensor

model = ResNet9(device=ndl.cpu())
```

```
<module 'needle.backend_ndarray' from './python/needle/backend_ndarray/__init__.py'>
```


The model's weights are saved by calling the `ndl.save_state_dict()` function. The data is stored as a NumPy object array. Note that both the model weights and the device information are saved.

```python
ndl.save_state_dict(model, 'weights.npy')
```

```
Saved state dict to file: weights.npy
```


To demonstrate the loading of the model's weights, we create a new model and load the saved weights into it. We then compare the weights of the two models to show that they are the same.

```python
new_model = ResNet9(device=ndl.cpu())
```

Before loading the weights, `model` and `new_model` have different weights and therefore different outputs. Let's test that; the assertion below is expected to fail.

```python
x = Tensor(np.random.randn(1, 3, 32, 32).astype(np.float32), device=ndl.cpu())

assert np.allclose(model(x).numpy(), new_model(x).numpy()), 'As expected, the two models have different outputs'
```

```
AssertionError: As expected, the two models have different outputs
```

Now we load the weights we just saved into `new_model`. We can see that the weights of `model` and `new_model` are the same.

```python
ndl.load_state_dict(new_model, 'weights.npy')

assert np.allclose(model(x).numpy(), new_model(x).numpy()), 'The two models have different outputs'

print('Loading Success! The models now have same outputs')
```

```
Loaded state dict from file: weights.npy
Loading Success! The models now have same outputs
```


As expected, they now have the same output.

Things get more complicated when we have several models with different architectures. In these cases, we want to prevent a model from loading weights that are incompatible with its architecture. We check that the model's signature matches the signature of the saved weights; if the signatures do not match, an error is raised. We also prevent the model from loading weights that were saved on a different device.

Here we define a `ResNet9_2` model, which is very similar to `ResNet9` but has different shapes in some layers. We then try to load the weights of `ResNet9` into `ResNet9_2`.

```python
from apps.models import ResNet9_2

differnt_model = ResNet9_2(device=ndl.cpu())

ndl.load_state_dict(differnt_model, 'weights.npy')
```

```
AssertionError: Module signature does not match saved state dict
```

And similarly, we test loading to a different device.

```python
cuda_model = ResNet9(device=ndl.cuda())

ndl.load_state_dict(cuda_model, 'weights.npy')
```

```
AssertionError: Module device does not match saved state dict
```

The last functionality we demonstrate in this section is the model summary, which gives an overview of the model's architecture and the shapes of its parameters. Similar functionality is available in PyTorch and TensorFlow and can be really handy for debugging.

For demonstration purposes, we examine the summaries of the `ResNet9` and `ResNet9_2` models to understand why the loading fails.

```python
_ = model.summary()
```

```
{ 'ConvBN-1': { 'BatchNorm2d-1': {'bias': (16,), 'weight': (16,)},
                'Conv-1': {'bias': (16,), 'weight': (7, 7, 3, 16)},
                'ReLU-1': {}},
  'ConvBN-2': { 'BatchNorm2d-1': {'bias': (32,), 'weight': (32,)},
                'Conv-1': {'bias': (32,), 'weight': (3, 3, 16, 32)},
                'ReLU-1': {}},
  'ConvBN-3': { 'BatchNorm2d-1': {'bias': (64,), 'weight': (64,)},
                'Conv-1': {'bias': (64,), 'weight': (3, 3, 32, 64)},
                'ReLU-1': {}},
  'ConvBN-4': { 'BatchNorm2d-1': {'bias': (128,), 'weight': (128,)},
                'Conv-1': {'bias': (128,), 'weight': (3, 3, 64, 128)},
                'ReLU-1': {}},
  'Flatten-1': {},
  'Linear-1': {'bias': (1, 128), 'weight': (128, 128)},
  'Linear-2': {'bias': (1, 10), 'weight': (128, 10)},
  'ReLU-1': {},
  'Residual-1': { 'Sequential-1': { 'ConvBN-1': { 'BatchNorm2d-1': { 'bias': ( 32,),
                                                                     'weight': ( 32,)},
```

```python
_ = differnt_model.summary()
```

```
{ 'ConvBN-1': { 'BatchNorm2d-1': {'bias': (16,), 'weight': (16,)},
                'Conv-1': {'bias': (16,), 'weight': (7, 7, 3, 16)},
                'ReLU-1': {}},
  'ConvBN-2': { 'BatchNorm2d-1': {'bias': (32,), 'weight': (32,)},
                'Conv-1': {'bias': (32,), 'weight': (3, 3, 16, 32)},
                'ReLU-1': {}},
  'ConvBN-3': { 'BatchNorm2d-1': {'bias': (64,), 'weight': (64,)},
                'Conv-1': {'bias': (64,), 'weight': (3, 3, 32, 64)},
                'ReLU-1': {}},
  'ConvBN-4': { 'BatchNorm2d-1': {'bias': (64,), 'weight': (64,)},
                'Conv-1': {'bias': (64,), 'weight': (3, 3, 64, 64)},
                'ReLU-1': {}},
  'Flatten-1': {},
  'Linear-1': {'bias': (1, 64), 'weight': (64, 64)},
  'Linear-2': {'bias': (1, 10), 'weight': (64, 10)},
  'ReLU-1': {},
  'Residual-1': { 'Sequential-1': { 'ConvBN-1': { 'BatchNorm2d-1': { 'bias': ( 32,),
                                                                     'weight': ( 32,)},
```

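If we look at `Residual-2.Sequential-1.ConvBN-1.BatchNorm2d-1`, we can see that `ResNet9` has 128 channels in the first layer, while `ResNet9_2` has 64 channels. The same difference is visible in the truncated summaries above at `ConvBN-4`, where the `Conv-1` weight has shape `(3, 3, 64, 128)` in `ResNet9` but `(3, 3, 64, 64)` in `ResNet9_2`. This is why the loading fails.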
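When the summaries are long, the mismatch can be located programmatically. The sketch below uses a hypothetical helper, `diff_signatures` (not part of NEEDLE), that walks two nested signature dictionaries and lists the paths whose values differ. The two small dictionaries are excerpts copied from the summaries above.

```python
def diff_signatures(a, b, path=""):
    """Return (path, value_a, value_b) for every entry that differs (hypothetical helper)."""
    diffs = []
    for key in sorted(set(a) | set(b)):
        here = f"{path}.{key}" if path else key
        if key not in a or key not in b:
            # Entry exists in only one of the two signatures
            diffs.append((here, a.get(key), b.get(key)))
        elif isinstance(a[key], dict) and isinstance(b[key], dict):
            # Recurse into submodule signatures
            diffs.extend(diff_signatures(a[key], b[key], here))
        elif a[key] != b[key]:
            diffs.append((here, a[key], b[key]))
    return diffs


# Excerpts of the ResNet9 and ResNet9_2 summaries
resnet9 = {
    "ConvBN-4": {"Conv-1": {"bias": (128,), "weight": (3, 3, 64, 128)}},
    "Linear-2": {"bias": (1, 10), "weight": (128, 10)},
}
resnet9_2 = {
    "ConvBN-4": {"Conv-1": {"bias": (64,), "weight": (3, 3, 64, 64)}},
    "Linear-2": {"bias": (1, 10), "weight": (64, 10)},
}

for entry in diff_signatures(resnet9, resnet9_2):
    print(entry)
```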

## 2 Port ONNX models to NEEDLE

Because Needle lacks training optimizations (such as operator optimization and distributed training), it is often hard to train a large-scale model for multiple epochs on Needle directly. To let Needle run inference on large-scale models, we support loading pre-trained models from other ML frameworks such as PyTorch, TensorFlow, and Caffe.

Instead of building a separate model loader for each ML framework, we implement only the conversion from ONNX to Needle. ONNX is a widely used open format that supports model transfer between ML frameworks such as PyTorch, TensorFlow, and Caffe, so it is a good intermediate format for bridging Needle and other common ML frameworks.

To load an ONNX model into Needle, we need to:

1. Load the protocol buffer stored in the ONNX model file and use a parser to extract each module, its input/output variable names, and its weights.
2. Construct a graph from the parsed modules and traverse it in topological order, converting each intermediate object into an `ndl.nn` module.
3. Construct the target Needle model dynamically from the resulting Needle modules.
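The graph traversal in step 2 can be illustrated with Kahn's algorithm. The sketch below is a simplified stand-in, not the project's converter: nodes are plain dictionaries with hypothetical names, graph inputs and weights are treated as already available, and a node is emitted once all of its inputs have been produced.

```python
from collections import deque


def topological_order(nodes, available):
    """Order nodes so that each one appears after the producers of its inputs.

    nodes:     list of dicts with 'name', 'inputs' (list of str), 'output' (str)
    available: names that need no producer (graph inputs, weights)
    """
    available = set(available)
    # Number of inputs each node is still waiting for
    indegree = {
        n["name"]: sum(1 for i in n["inputs"] if i not in available)
        for n in nodes
    }
    # Map each tensor name to the nodes that consume it
    consumers = {}
    for n in nodes:
        for i in n["inputs"]:
            consumers.setdefault(i, []).append(n)

    ready = deque(n for n in nodes if indegree[n["name"]] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node["name"])
        # The node's output is now available to its consumers
        for nxt in consumers.get(node["output"], []):
            indegree[nxt["name"]] -= 1
            if indegree[nxt["name"]] == 0:
                ready.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("graph has a cycle or an unresolved input")
    return order


# A residual block listed out of order: output = relu(bn(conv(data))) + data
nodes = [
    {"name": "add", "inputs": ["relu_out", "data"], "output": "output"},
    {"name": "relu", "inputs": ["bn_out"], "output": "relu_out"},
    {"name": "bn", "inputs": ["conv_out", "gamma", "beta"], "output": "bn_out"},
    {"name": "conv", "inputs": ["data", "W"], "output": "conv_out"},
]
print(topological_order(nodes, available=["data", "W", "gamma", "beta"]))
# ['conv', 'bn', 'relu', 'add']
```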

We will introduce each step in order.

### Description
#### Parse ONNX module to Dictionary
ONNX stores a model according to a fixed layout. Specifically, `onnx.node` stores the operation graph connecting the modules (inputs, outputs, attribute values, and the names of weights and biases), and `onnx.initializer` stores the values of the weights and biases corresponding to the names in `onnx.node`. In the `onnx` Python package, these are the `node` and `initializer` fields of `model.graph`.

Accordingly, we construct objects to hold the values parsed from the ONNX model. `OnnxNode` mirrors `onnx.node` and records the connections between data, while `OnnxData` mirrors `onnx.initializer` and stores the actual values.

Moreover, for each operator in `onnx.node` we construct a class that inherits from `OnnxNode`; please refer to `onnx_dict.py` for details.

```python
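import numpy as np


class OnnxNode:
    """Base class for a parsed ONNX operator node."""
    def __init__(self, att_dict) -> None:
        self.name = att_dict['name']
        self.inputs = att_dict['inputs']
        # Number of inputs that must be produced before this node can run
        self.indegree = len(self.inputs)
        for input_name in self.inputs:
            # Inputs whose name contains "data" are not counted as dependencies
            if "data" in input_name:
                self.indegree -= 1
        self.output: str = att_dict['output']


class OnnxData:
    """Container for a tensor stored in the ONNX initializer list."""
    def __init__(self, **kwargs) -> None:
        self.name = kwargs['name']
        self.dtype = kwargs['dtype']
        self.category = "Initializer"
        self.data: np.ndarray = kwargs['data']
        self.dims: list = kwargs['dims']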


class ConvOpNode(OnnxNode):
    def __init__(self, att_dict) -> None:
        super().__init__(att_dict)
        
        # attributes
        self.dilations = att_dict["dilations"]
        self.group = att_dict["group"]
        self.kernel_shape = att_dict["kernel_shape"]
        self.padding = att_dict["pads"]
        self.strides = att_dict["strides"]
        
        # data field
        self.X_name = att_dict["X_name"]
        self.Y_name = att_dict["Y_name"]
        self.W_name = att_dict["W_name"]
        self.W: OnnxData = att_dict["W_data"]
        self.use_bias = False
        if "B_name" in att_dict:
            self.use_bias = True
            self.B_name = att_dict["B_name"]
            self.B: OnnxData = att_dict["B_data"]

        # ONNX stores Conv weights as (out_channels, in_channels / group, kH, kW)
        self.out_channels, self.in_channels, _, _ = self.W.dims
        
class BatchNorm2DNode(OnnxNode):
    def __init__(self, att_dict) -> None:
        super().__init__(att_dict)

        # attributes
        self.eps = att_dict["epsilon"]
        self.momentum = att_dict["momentum"]
        self.spatial = att_dict["spatial"]

        # data field
        self.X_name = att_dict["X_name"]
        self.gamma_name = att_dict["gamma_name"]
        self.gamma: OnnxData = att_dict["gamma_data"]
        self.beta_name = att_dict["beta_name"]
        self.beta: OnnxData = att_dict["beta_data"]
        self.running_mean_name = att_dict["running_mean_name"]
        self.running_mean: OnnxData = att_dict["running_mean_data"]
        self.running_var_name = att_dict["running_var_name"]
        self.running_var: OnnxData = att_dict["running_var_data"]
        self.Y_name = att_dict["Y_name"]

        self.dim = self.gamma.dims[0]
```
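One detail is worth making explicit when moving weights across. `ConvOpNode` unpacks the ONNX kernel as `(out_channels, in_channels, kH, kW)`, whereas the NEEDLE summaries in Section 1 list `Conv` weights as `(kH, kW, in_channels, out_channels)`, for example `(7, 7, 3, 16)`. Assuming those two layouts, the axes can be reordered with a single transpose. This is a sketch of the idea on a random array, not the project's converter code; the variable names are illustrative.

```python
import numpy as np

# Hypothetical ONNX-style kernel: 16 output channels, 3 input channels, 7x7
w_onnx = np.random.randn(16, 3, 7, 7).astype(np.float32)

# Move (out, in, kH, kW) -> (kH, kW, in, out) and make the result contiguous
w_needle = np.ascontiguousarray(w_onnx.transpose(2, 3, 1, 0))

print(w_needle.shape)  # (7, 7, 3, 16)

# The same kernel entry is reachable under both index orders
assert w_needle[2, 5, 1, 9] == w_onnx[9, 1, 2, 5]
```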

### Demo
In this small demo, we create a residual block consisting of Conv, BatchNorm, ReLU, and a residual connection. The Needle model converted from ONNX produces the same output as the PyTorch model.

## Appendix
In this section, we briefly discuss how to convert a PyTorch model to an ONNX model. Although the focus of this project is converting ONNX models to Needle, ONNX is mostly used for moving models between different tools and frameworks for training, optimizing, and deploying them, rather than for building a model from scratch. Thus, for the sake of completeness, we briefly show how to build a PyTorch model and convert it to an ONNX model. We demonstrate with two examples: a pretrained ResNet model offered by torchvision, and a simple RNN model built from scratch. Note that this section relies heavily on the official PyTorch documentation.

### Using Pretrained Model from torchvision
Here we demonstrate how to use a pretrained ResNet18 model from torchvision. The model is pretrained on the ImageNet dataset.

```python
# import the resnet18 model from PyTorch's torchvision module
from torchvision.models import resnet18 
import torch

# create a resnet18 model and load it with pre-trained weights
# (newer torchvision releases deprecate `pretrained=True` in favor of the `weights` argument)
model = resnet18(pretrained=True)

# Specify the input and output names of the onnx model
input_names = ['data'] 
output_names = ['output']

# create a dummy input tensor used to trace the model
dummy_input = torch.randn(1, 3, 224, 224, device='cpu') 

# export the model to ONNX format
torch.onnx.export(model, dummy_input, 'resnet18.onnx', verbose=True, input_names=input_names, output_names=output_names) 
```

### Build a model from scratch
The process of building a model from scratch is similar to the process of importing a pretrained model from torchvision. We define a simple RNN model with a specified embedding size, output size, hidden size, and number of layers.
The model is composed of an embedding layer, an RNN layer, and a fully connected layer, similar to what we built in HW4.

```python
import torch
import torch.nn as nn


class SequenceModel(nn.Module):
    def __init__(self, embedding_size, output_size, hidden_size, num_layers=1, device='cpu', dtype=torch.float32):
        # Initialize nn.Module before setting any attributes
        super().__init__()
        self.embedding_size = embedding_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype
        self.embedding = nn.Embedding(output_size, embedding_size)
        self.rnn = nn.RNN(embedding_size, hidden_size, num_layers=num_layers, bidirectional=False, dropout=0)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        seq_len, batch_size = x.shape
        x = self.embedding(x)
        x, h = self.rnn(x)

        x = x.view(seq_len * batch_size, -1)
        x = self.fc(x)

        return x

# Instantiate the model we just defined
model = SequenceModel(embedding_size=10, output_size=16, hidden_size=8, device='cpu')

# Create a dummy input tensor used to trace the model
x = torch.randint(0, 16, (5, 1), dtype=torch.long)

# Specify the input and output names of the onnx model
input_names = ['data']
output_names = ['output']

# Export the model to ONNX format
torch.onnx.export(model, x, 'rnn.onnx', verbose=True, input_names=input_names, output_names=output_names)
```

We have verified with multiple models that the exported ONNX model produces the same output as the original PyTorch model. Consequently, a Needle model that matches the ONNX model's output also matches the original PyTorch model. Since exporting is not the main focus of the project, the code for exporting ONNX models is not part of the submission.