# Deep Learning Homework \#05
### Deep Learning Course $\in$ DSSC @ UniTS (Spring 2021)  

#### Submitted by [Emanuele Ballarin](mailto:emanuele@ballarin.cc)  

### Preliminaries:

#### Imports:

We start off by importing all the libraries, modules, classes and functions we are going to use *today*...

In [None]:
# System interaction
import os

# Typing
from torch import Tensor

# Tensor computation and ANNs
import torch        # Backward compatibility
import torch as th  # Forward compatibility

# Scripted easers
from scripts import mnist, train_utils, architectures, train
from scripts.torch_utils import use_gpu_if_possible
from scripts.train_utils import accuracy, AverageMeter

#### Magnitude-pruning function and related utilities

Taken from the provided *Jupyter* notebook, and slightly adapted.

In [None]:
def magnitude_pruning(model, pruning_rate, layers_to_prune=("1", "4", "7", "10"), init_mask=None):
    """Zero out, in place, the smallest-magnitude parameters of the selected layers.

    A single global threshold is computed over the absolute values of all the
    parameters whose name contains any of the strings in ``layers_to_prune``;
    every selected entry whose magnitude is below that threshold is set to 0.

    Args:
        model: module exposing ``named_parameters()``.
        pruning_rate: fraction of the selected entries that should fall below
            the threshold. The resulting index is clamped to the valid range,
            so a rate >= 1 keeps only the entries equal to the largest
            magnitude, and a rate <= 0 prunes nothing.
        layers_to_prune: substrings matched (with ``in``) against parameter
            names to select what gets pruned. Note that this is a substring
            match, so e.g. "1" also matches a parameter named "10.weight".
        init_mask: optional sequence with one entry per named parameter; a
            parameter is pruned only if its entry compares equal to 1.
            Defaults to all ones (every selected parameter is prunable).

    Returns:
        A list with one mask tensor per named parameter, in
        ``named_parameters()`` order: a 0/1 mask for pruned parameters and an
        all-ones tensor for the others.
    """
    # named_parameters() returns a generator: materialise it once so that it
    # can be both measured and iterated more than once.
    named_params = list(model.named_parameters())

    # Handle base case
    if init_mask is None:
        init_mask = [1] * len(named_params)   # Identity mask

    def _is_selected(name):
        return any(layer in name for layer in layers_to_prune)

    params_to_prune = [param for name, param in named_params if _is_selected(name)]

    # Nothing to prune: torch.cat would fail on an empty list.
    if not params_to_prune:
        return [torch.ones_like(param) for _, param in named_params]

    flat = torch.cat([param.detach().abs().flatten() for param in params_to_prune], dim=0)
    if flat.numel() == 0:
        return [torch.ones_like(param) for _, param in named_params]

    flat = flat.sort()[0]
    # Clamp the index so that out-of-range pruning rates cannot index past
    # either end of the sorted magnitudes.
    position = max(0, min(int(pruning_rate * flat.shape[0]), flat.shape[0] - 1))
    thresh = flat[position]

    mask = []
    for idx, (name, param) in enumerate(named_params):
        if _is_selected(name) and init_mask[idx] == 1:
            m = torch.where(param.abs() >= thresh, 1, 0)
            mask.append(m)
            # Write through .data so that autograd does not track the pruning.
            param.data *= m
        else:
            mask.append(torch.ones_like(param))

    return mask

#### *Training with pruning* routine

Taken from the provided *Jupyter* notebook, and slightly adapted.

In [None]:
def train_epoch(model, dataloader, loss_fn, optimizer, loss_meter, performance_meter, performance, device, mask, layers_to_prune, params_type_to_prune):
    """Run one pass over ``dataloader``, optionally masking gradients of pruned entries.

    For every batch: forward pass, loss, backward pass, optional gradient
    masking, optimizer step, then both meters are updated with the batch size
    as weight.

    Args:
        model: the network being trained; called as ``model(X)``.
        dataloader: iterable yielding ``(X, y)`` batches of tensors.
        loss_fn: callable ``loss_fn(y_hat, y)`` returning a scalar tensor.
        optimizer: optimizer whose ``zero_grad()``/``step()`` are called per batch.
        loss_meter: object with ``update(val=..., n=...)``; receives the batch loss.
        performance_meter: object with ``update(val=..., n=...)``; receives the
            batch performance.
        performance: callable ``performance(y_hat, y)``.
        device: device the batches are moved to.
        mask: ``None`` (no masking) or a sequence of tensors aligned with
            ``model.named_parameters()``; gradients are multiplied by it.
        layers_to_prune: substrings matched (with ``in``) against parameter
            names; only matching parameters get their gradient masked.
        params_type_to_prune: accepted for signature compatibility; not used
            by this function.
    """
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        y_hat = model(X)

        loss = loss_fn(y_hat, y)

        loss.backward()

        if mask is not None:
            for (name, param), m in zip(model.named_parameters(), mask):
                # Parameters that received no gradient (e.g. frozen ones)
                # have nothing to mask.
                if param.grad is None:
                    continue
                if any(layer in name for layer in layers_to_prune):
                    # The mask may have been built on another device or with
                    # another dtype than the gradient: align it first.
                    param.grad *= m.to(device=param.grad.device, dtype=param.grad.dtype)

        optimizer.step()

        acc = performance(y_hat, y)

        loss_meter.update(val=loss.item(), n=X.shape[0])
        performance_meter.update(val=acc, n=X.shape[0])

In [None]:
def train_model(model, dataloader, loss_fn, optimizer, num_epochs, checkpoint_loc=None, checkpoint_name="checkpoint.pt", performance=accuracy, lr_scheduler=None, device=None, mask=None, params_type_to_prune=("weight", "bias"), layers_to_prune=None):
    """Train ``model`` for ``num_epochs`` epochs, optionally keeping pruned entries frozen.

    Each epoch delegates to ``train_epoch``, prints a summary, optionally
    saves a checkpoint and steps the learning-rate scheduler.

    Args:
        model: the network to train; moved to ``device`` and put in train mode.
        dataloader: iterable of ``(X, y)`` batches, forwarded to ``train_epoch``.
        loss_fn: loss callable, forwarded to ``train_epoch``.
        optimizer: optimizer; its state is also stored in checkpoints.
        num_epochs: number of epochs to run.
        checkpoint_loc: directory for checkpoints (created if missing), or
            ``None`` to disable checkpointing.
        checkpoint_name: checkpoint file name, or ``None`` to disable checkpointing.
        performance: callable ``performance(y_hat, y)`` used as the metric.
        lr_scheduler: optional scheduler stepped once per epoch.
        device: target device; defaults to ``use_gpu_if_possible()``.
        mask: optional gradient mask, forwarded to ``train_epoch``.
        params_type_to_prune: forwarded to ``train_epoch``.
        layers_to_prune: substrings selecting the parameters whose gradients
            are masked, forwarded to ``train_epoch``. When ``None``, a
            module-level ``layers_to_prune`` is used if one exists (this is
            what the function used to rely on implicitly), otherwise the
            default layer names of ``magnitude_pruning``.

    Returns:
        ``(loss_meter.sum, performance_meter.avg)`` of the last epoch; with
        ``num_epochs == 0``, the values of freshly created meters.
    """
    if layers_to_prune is None:
        # Backward compatibility: this function used to read an undeclared
        # global of the same name, and raised NameError when it was missing.
        layers_to_prune = globals().get("layers_to_prune", ("1", "4", "7", "10"))

    if checkpoint_loc is not None:
        os.makedirs(checkpoint_loc, exist_ok=True)

    if device is None:
        device = use_gpu_if_possible()

    model = model.to(device)
    model.train()

    # Created up-front so that the return statement is valid even when the
    # loop body never runs (num_epochs == 0).
    loss_meter = AverageMeter()
    performance_meter = AverageMeter()

    for epoch in range(num_epochs):
        # Fresh meters per epoch: statistics are not accumulated across epochs.
        loss_meter = AverageMeter()
        performance_meter = AverageMeter()

        print(f"Epoch {epoch+1} --- learning rate {optimizer.param_groups[0]['lr']:.5f}")

        train_epoch(model, dataloader, loss_fn, optimizer, loss_meter, performance_meter, performance, device, mask, layers_to_prune, params_type_to_prune)

        print(f"Epoch {epoch+1} completed. Loss - total: {loss_meter.sum} - average: {loss_meter.avg}; Performance: {performance_meter.avg}")

        if checkpoint_name is not None and checkpoint_loc is not None:
            checkpoint_dict = {
                "parameters": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch
            }
            # The same file is overwritten at every epoch.
            torch.save(checkpoint_dict, os.path.join(checkpoint_loc, checkpoint_name))

        if lr_scheduler is not None:
            lr_scheduler.step()

    return loss_meter.sum, performance_meter.avg