# Introduction

Deep learning is a highly experimental field, so any given problem usually has many satisfactory (though usually non-optimal) solutions. Neural network architectures are often disproportionately large for the complexity of the task, and such models can frequently be "slimmed down" with only a small loss in prediction accuracy.

Pruning a network involves removing individual weights or entire neurons. There are many advantages to this method:

* Reducing the size of the network.
* Speeding up inference.
* Counteracting overfitting.
* Improving results.

To effectively reduce the network's size, we must zero out enough elements of its weight matrices that the model can be stored in a compressed (sparse) form. Zeroing out the weights is not enough to speed up inference, however: a dense matrix multiplication does the same amount of work whether or not its entries are zero, so the computation must also be implemented with sparse matrix operations. An alternative pruning method is to remove entire neurons, which reduces the actual size of the weight matrices.
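To make the difference between storage and computation concrete, here is a minimal pure-Python sketch. The helper names `to_coo` and `coo_matvec` are illustrative and are not part of this notebook or of PyTorch. The sketch stores only the non-zero entries of a matrix as coordinate (COO) triples and multiplies by a vector while touching only those entries; real code would rely on a sparse-matrix library rather than Python loops.

```python
def to_coo(dense):
    """Keep only the non-zero entries as (row, col, value) triples."""
    return [
        (i, j, v)
        for i, row in enumerate(dense)
        for j, v in enumerate(row)
        if v != 0.0
    ]


def coo_matvec(coo, x, n_rows):
    """Multiply a COO matrix by a vector, visiting only stored entries."""
    y = [0.0] * n_rows
    for i, j, v in coo:
        y[i] += v * x[j]
    return y


# A 3 x 4 matrix in which 9 of the 12 entries have been zeroed out.
dense = [
    [0.0, 2.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [1.5, 0.0, 0.0, -3.0],
]
coo = to_coo(dense)                      # 3 stored triples instead of 12 values
x = [1.0, 2.0, 3.0, 4.0]
y = coo_matvec(coo, x, n_rows=len(dense))
print(len(coo), y)
```

The dense matrix still occupies 12 slots no matter how many are zero; only the COO form shrinks with sparsity, and only `coo_matvec` skips the zeroed entries.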

In this task, we will focus only on zeroing out weights within the model. You cannot change the network's architecture (for example, by removing a neuron or an entire hidden layer). We will consider this problem using a regression example.
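One common way to decide which weights to zero is global magnitude pruning: rank all parameters by absolute value and zero the smallest fraction. The sketch below assumes that approach. `magnitude_prune_` and the toy model are hypothetical names, and this is not presented as the intended or best-scoring solution. Zeroing many weights without retraining typically raises the MSE, so this kind of step is often alternated with fine-tuning.

```python
import torch
from torch import nn


def magnitude_prune_(model: nn.Module, sparsity: float) -> nn.Module:
    """Zero (in place) the `sparsity` fraction of parameters with smallest |value|."""
    params = list(model.parameters())
    all_abs = torch.cat([p.detach().abs().flatten() for p in params])
    k = int(sparsity * all_abs.numel())
    if k == 0:
        return model
    # k-th smallest absolute value across weights and biases (k is 1-indexed).
    threshold = torch.kthvalue(all_abs, k).values
    with torch.no_grad():
        for p in params:
            p.masked_fill_(p.abs() <= threshold, 0.0)
    return model


# Toy model with the same layer sizes as the task, randomly initialised.
torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(128, 1024), nn.Sigmoid(), nn.Linear(1024, 10))
magnitude_prune_(toy, sparsity=0.5)

zeros = sum((p == 0).sum().item() for p in toy.parameters())
total = sum(p.numel() for p in toy.parameters())
print(f"sparsity after pruning: {zeros / total:.3f}")
```

The architecture is untouched: every tensor keeps its shape, and only the values change.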


### **Task** 📝

Your goal is to implement the function `your_pruning_algorithm(model: torch.nn.Module) -> torch.nn.Module`. This function will take the model implemented below as input and return its **pruned** version.

The objective is to have the highest possible number of zeroed-out model parameters (**weights and biases**), while maintaining the lowest possible **mean squared error (MSE)** for its predictions.

You'll find a designated cell in the notebook below for your function. The cells you need to modify will be very clearly marked!

---

### **Evaluation**

You will be evaluated based on the result of the following function (the higher the value, the better):

$$
\mathrm{score}(s, \epsilon) = \begin{cases}
    0 & \text{if } \epsilon > 1000 \\
    \left(1 - \frac{\epsilon}{1000}\right)^{1.5} \cdot s^{1.5} & \text{otherwise}
\end{cases}
$$

where:
* **$s$** is the number of zeroed parameters divided by the total number of parameters in the model (**sparsity**).
* **$\epsilon$** is the **mean squared error (MSE)** on the test set.

This scoring criterion and all the functions mentioned above are already implemented for you.
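As a quick worked example of the formula, the sketch below re-implements it in plain Python for scalar inputs. `score_sketch` is an illustrative name. It takes its arguments in the order used in the formula, $(s, \epsilon)$, whereas the notebook's own `score` function takes the MSE first.

```python
def score_sketch(sparsity: float, mse: float,
                 mse_weight: float = 1.5, sparsity_weight: float = 1.5) -> float:
    """Scalar version of score(s, eps) from the task description."""
    if mse > 1000:
        return 0.0
    return (1 - mse / 1000) ** mse_weight * sparsity ** sparsity_weight


# Half of the parameters zeroed, test MSE of 100:
print(round(score_sketch(0.5, 100.0), 4))
# Same sparsity, but an MSE above the cutoff gives no credit:
print(score_sketch(0.5, 2000.0))
```

Because both factors are raised to the same power, halving the sparsity costs as much as letting the MSE climb to 500, so the two goals have to be traded off against each other.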

### **Constraints** 📜

* Your function must return the model in a **maximum of 5 minutes** when run on Google Colab with a GPU.
* The weights file must be saved using the `save_parameters` function with the name `model_parameters.pkl`.
* You **cannot change the model's architecture**. It must have exactly:
    * An input layer of size 128
    * A hidden layer of size 1024
    * A Sigmoid activation function
    * An output layer of size 10
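These layer sizes fix the denominator of the sparsity ratio. The short calculation below counts the parameters of the two linear layers (the Sigmoid has none); the variable names are illustrative.

```python
# (in_features, out_features) of the two linear layers
layer_shapes = [(128, 1024), (1024, 10)]

weights = sum(n_in * n_out for n_in, n_out in layer_shapes)
biases = sum(n_out for _, n_out in layer_shapes)
total_params = weights + biases

print(weights, biases, total_params)  # 141312 1034 142346
```

Almost all of the parameters sit in the first layer's weight matrix, so that is where most of the achievable sparsity comes from.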

---

### **Submission Files** 📁

1.  This notebook.
2.  The model's parameters (weights), saved using the `save_parameters` function. Do not change the name of the generated file: `model_parameters.pkl`.

---

### **Evaluation** 📈

The weights file you provide will be the basis for your grade. However, you must also submit a working notebook that generates this `model_parameters.pkl` weights file in **under 5 minutes** after running all cells with the `FINAL_EVALUATION_MODE` flag set to `True` (timed on Google Colab with GPU access).

For this task, you can earn between **0 and 1.5 points**.
* If your score is below **0.085**, you will receive 0 points.
* If your score is above **0.95**, you will receive the maximum of 1.5 points.
* Between these two thresholds, your points will increase linearly with your score.
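The mapping from score to points is a clipped linear interpolation. A minimal sketch follows (`points_sketch` is an illustrative name that mirrors the thresholds listed above):

```python
def points_sketch(score: float, lower: float = 0.085,
                  upper: float = 0.95, max_points: float = 1.5) -> float:
    """Clip the score to [lower, upper], then rescale linearly to [0, max_points]."""
    clipped = min(max(score, lower), upper)
    return (clipped - lower) / (upper - lower) * max_points


for s in (0.05, 0.3, 0.5175, 0.99):
    print(f"score {s:.4f} -> {points_sketch(s):.3f} points")
```

A score exactly halfway between the thresholds, 0.5175, earns half of the available points.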

In [None]:
FINAL_EVALUATION_MODE = False  # During the evaluation, we will set this flag to True.

In [None]:
######################### DO NOT CHANGE THIS CELL ##########################
import copy
import pickle
import numpy as np
from IPython.display import clear_output
from tqdm.auto import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import SGD

np.random.seed(0)
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

### LOAD DATA

In [None]:
######################### DO NOT CHANGE THIS CELL ###########################

# Function to load training and validation data as np.array
def load_data_from_file(x_train_path, y_train_path, x_valid_path, y_valid_path):
    X_train = np.load(x_train_path)
    y_train = np.load(y_train_path)

    X_valid = np.load(x_valid_path)
    y_valid = np.load(y_valid_path)

    return X_train, y_train, X_valid, y_valid

# Dataset class
class InMemDataset(Dataset):
    def __init__(self, xs, ys, device='cpu'):
        super().__init__()
        self.dataset = []
        for i in tqdm(range(len(xs))):
            self.dataset.append((torch.tensor(xs[i]).to(device).float(), torch.tensor(ys[i]).to(device).float() ))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

In [None]:
######################### DO NOT CHANGE THIS CELL ###########################

# Let's load the data and create dataloaders
X_train, y_train, X_valid, y_valid = load_data_from_file(
    "train_data/X_train.npy",
    "train_data/y_train.npy",
    "valid_data/X_valid.npy",
    "valid_data/y_valid.npy",
)

batch_size = 128
_train = InMemDataset(X_train, y_train, device)
_valid = InMemDataset(X_valid, y_valid, device)

loaders = {
    "train" : DataLoader(_train, batch_size=batch_size, shuffle=True),
    "valid" : DataLoader(_valid, batch_size=batch_size, shuffle=False),
}

In [None]:
######################### DO NOT CHANGE THIS CELL ###########################

# The complete criterion defined in the task description.
def score(mse_loss, sparsity, mse_weight=1.5, sparsity_weight=1.5):
    
    if type(mse_loss) == np.ndarray:
        mse_loss[mse_loss > 1000] = 1000
    else:
        if mse_loss > 1000:
            mse_loss = 1000
            
    score = (1 - mse_loss / 1000) ** mse_weight * sparsity**sparsity_weight
    return score

# Calculates the model's sparsity (ratio of zeroed parameters).
def get_sparsity(model):
    total_params = 0
    zero_params = 0
    
    for name, param in model.named_parameters():
        if "weight" in name or "bias" in name:
            total_params += param.numel()
            zero_params += torch.sum(param == 0).item()
            
    sparsity = zero_params / total_params
    return sparsity

# Computes the Mean Squared Error (MSE).
def compute_error(model, data_loader):
    model.eval()
    
    losses = 0
    num_of_el = 0
    with torch.no_grad():
        for x, y in data_loader:
            outputs = model(x)
            num_of_el += x.shape[0] * y.shape[1]
            losses += model.loss(outputs, y, reduction="sum")
            
    return losses / num_of_el

# Scales the raw score to points.
def points(score):
    def scale(x, lower=0.085, upper=0.95, max_points=1.5):
        scaled = min(max(x, lower), upper)
        return (scaled - lower) / (upper - lower) * max_points
    return scale(score)

In [None]:
######################### DO NOT CHANGE THIS CELL ###########################

# Let's define our network's architecture
class MLP(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(128, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.layers(x)
        return logits

    def loss(self, input, target, reduction="mean"):
        mse_loss = nn.MSELoss(reduction=reduction)
        return mse_loss(input, target)

In [None]:
######################### DO NOT CHANGE THIS CELL ###########################

# Initializing network weights
def init_weights(m):
    ''' Initialize the weights in the module m.'''
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0.01)

# Function to save the model's weights to a file - remember that your weights
# file must be named: model_parameters.pkl
def save_parameters(model, file_name="model_parameters.pkl", to_file=True):

    params_to_save = {}
    for name, param in model.named_parameters():
        params_to_save[name] = param.to("cpu")
    
    if not to_file:
        return params_to_save
    
    with open(f"{file_name}", "wb") as f:
        pickle.dump(params_to_save, f)

# Function to load the model's weights from a file
def load_parameters(model, file_name="model_parameters.pkl", from_file=True, params=None):

    if from_file:
        with open(f"{file_name}", "rb") as f:
            params_to_load = pickle.load(f)
    else:
        params_to_load = params
        
    for name, param in model.named_parameters():
        with torch.no_grad():
            param[...] = params_to_load[name].to(device)

In [None]:
######################### DO NOT CHANGE THIS CELL ###########################

# Function for training the model
def train_model(model: nn.Module,
              data_loaders: dict[str, DataLoader],
              num_epochs: int,
              optimizer: torch.optim.Optimizer,
              verbose: bool = True
              ) -> tuple[list[torch.Tensor], torch.Tensor]:
    """Function to train a model.

    Args:
        model (torch.nn.Module): The neural network to be trained.
        data_loaders (dict[str, DataLoader]): A dictionary containing DataLoaders for the training and validation sets.
        num_epochs (int): The number of epochs for training.
        optimizer (torch.optim.Optimizer): The optimizer used for training the model.
        verbose (bool, optional): If True, shows the training progress.

    Returns:
        tuple[list[torch.Tensor], torch.Tensor]: A tuple containing the best set of model
                                    parameters found during training (one CPU tensor per
                                    parameter) and the corresponding loss value
                                    on the validation set.
    """
    if FINAL_EVALUATION_MODE:
        verbose = False

    best_epoch = None
    best_params = None
    best_val_loss = np.inf

    for epoch in range(num_epochs):
        model.train()
        _iter = 1
        for inputs, targets in data_loaders['train']:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = model.loss(outputs, targets)
            loss.backward()
            optimizer.step()

            if verbose:
                if _iter % 10 == 0:
                    print(f"Minibatch {_iter:>6}    |  loss {loss.item():>5.2f}  |")

            _iter +=1

        val_loss = compute_error(model, data_loaders["valid"])

        if val_loss < best_val_loss:
            best_epoch = epoch
            best_val_loss = val_loss
            best_params = [copy.deepcopy(p.detach().cpu()) for p in model.parameters()]

        if verbose:
            clear_output(True)
            m = f"After epoch {epoch:>2} | valid loss: {val_loss:>5.2f}"
            print("{0}\n{1}\n{0}".format("-" * len(m), m))

    if best_params is not None:
        if verbose:
            print(f"\nLoading best params on validation set in epoch {best_epoch} with loss {best_val_loss:.2f}")
        with torch.no_grad():
            for param, best_param in zip(model.parameters(), best_params):
                param[...] = best_param

    return best_params, best_val_loss

In [None]:
######################### DO NOT CHANGE THIS CELL ##########################
initial_model = MLP().to(device)
initial_model.apply(init_weights)

optimizer = SGD(
    initial_model.parameters(),
    lr = 0.01,
    momentum = 0.95,
    weight_decay = 0.001
)

best_params, best_val_loss = train_model(initial_model, loaders, num_epochs=100, optimizer=optimizer, verbose=True)

loss = compute_error(initial_model, loaders["valid"])
m = f"| Validation loss: {loss:>5.2f} |"
print("{0}\n{1}\n{0}".format("-" * len(m), m))

### Example Solution 💡

Below is a simple solution that is obviously not optimal. It's provided only to demonstrate how the entire notebook is intended to function.
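For a sense of scale, the starter solution below zeroes two input columns of the first layer's weight matrix. The arithmetic below estimates the sparsity this yields and the best score it could reach, assuming no other parameter happens to be exactly zero after training; the variable names are illustrative.

```python
total_params = 128 * 1024 + 1024 + 1024 * 10 + 10

# nn.Linear(128, 1024).weight has shape (1024, 128), so weight[:, 0:2]
# covers 2 columns of 1024 entries each.
zeroed = 1024 * 2
sparsity = zeroed / total_params

# Upper bound on the score: a perfect MSE of 0 leaves only the sparsity term.
best_case_score = sparsity ** 1.5

print(f"sparsity = {sparsity:.4f}, best-case score = {best_case_score:.5f}")
```

Even with zero error, this score stays far below the 0.085 threshold, which is why the starter solution earns no points.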

In [None]:
def starter_pruning_algorithm(model):
    with torch.no_grad():
        model.layers[0].weight[:, 0:2] = 0
    return model

In [None]:
if not FINAL_EVALUATION_MODE:
    # Let's make a deep copy so we don't change the weights of the trained model
    model_to_prune = copy.deepcopy(initial_model)

    # Let's prune the weights with the example solution
    model_to_prune = starter_pruning_algorithm(model_to_prune)

    # Saving the model's parameters (we've changed the filename here, 
    # you should save yours as "model_parameters.pkl")
    save_parameters(model_to_prune, "starter_model_parameters.pkl")

    # Now let's see how to load the previously saved weights into a newly created model
    new_model = MLP().to(device)
    loss = compute_error(new_model, loaders["valid"])
    print(f"The new model has a loss of {loss:.3f}")

    # Loading the model's parameters
    load_parameters(new_model, "starter_model_parameters.pkl")
    loss = compute_error(new_model, loaders["valid"])
    print(f"After loading the parameters, the model has a loss of {loss:.3f}")

    mse = compute_error(new_model, loaders["valid"])
    sparsity = get_sparsity(new_model)

    print(f"Model MSE: {mse:.3f} Sparsity: {sparsity:.3f}")
    model_score = score(mse, sparsity)
    print(f"Your model's score is {model_score:.3f}!")
    print(f"Your solution gets {points(model_score):.3f}/1.5 points!")

### Your Solution 🚀

This section is the only place where you can change the code!

In [None]:
def your_pruning_algorithm(model):
    # TODO
    pruned_model = starter_pruning_algorithm(model)
    
    # Saving the model's parameters 
    save_parameters(pruned_model, "model_parameters.pkl")
    return pruned_model

model_to_prune = copy.deepcopy(initial_model)
your_pruning_algorithm(model_to_prune)

### Evaluation

The code below will be used to evaluate your solution. After you submit your solution, a function almost identical to the `evaluate` function below will be run on a `test_data` directory, which is available only to the graders.

Before submitting, make sure that the entire notebook runs from start to finish without errors (also with the `FINAL_EVALUATION_MODE` flag set to `True`), requires no user interaction, and saves the weights to the `model_parameters.pkl` file after executing the "Run All" command. Also, check that `validation_script.py` returns the expected result.

In [None]:
def evaluate(X_test, y_test):
    """Validator"""
    test_model = MLP().to(device)
    load_parameters(test_model)

    batch_size = 128

    _test = InMemDataset(X_test, y_test, device)
    test_loader = DataLoader(_test, batch_size=batch_size, shuffle=False)

    mse = compute_error(test_model, test_loader)
    sparsity = get_sparsity(test_model)

    print(f"Model had error: {mse:.3f} and sparsity: {sparsity:.3f}")
    model_score = score(mse, sparsity)

    return model_score