# Load Data

In [None]:
# Load MNIST
from torchvision import datasets
from torchvision.transforms import ToTensor 

# Training split, stored under ./data; this is the only call that asks for a download.
train_data = datasets.MNIST(
    root="data",
    train=True,
    transform=ToTensor(),
    download=True,
)

# Held-out split from the same ./data root; no download flag is passed here.
test_data = datasets.MNIST(
    root="data",
    train=False,
    transform=ToTensor(),
)

In [None]:
from torch.utils.data import DataLoader
# Batched iterators over both splits, keyed by split name.
# The DataLoader name imported just above is used directly: `torch` itself is
# only imported in a later cell, so `torch.utils.data.DataLoader` would raise
# NameError when the cells are executed top to bottom.
loaders = {
    'train': DataLoader(train_data,
                        batch_size=100,
                        shuffle=True,
                        num_workers=1),

    # Evaluation only accumulates loss and accuracy over the whole split, so
    # sample order does not matter and shuffling is unnecessary.
    'test': DataLoader(test_data,
                       batch_size=100,
                       shuffle=False,
                       num_workers=1),
}
loaders

# Initialize Model

In [None]:
# define a feed forward network
import torch
import torch.nn as nn

class fc1(nn.Module):
    """Three-layer fully connected classifier (784 -> 300 -> 100 -> num_classes).

    Parameters
    ----------
    num_classes: int
        Width of the final linear layer, i.e. the number of output logits.
        NOTE(review): the default of 100 is larger than the ten MNIST digit
        classes this notebook trains on -- confirm this is intended.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # Layers stay in one Sequential called `classifier`, so parameter
        # names keep the form `classifier.<index>.weight` / `.bias`.
        stack = [
            nn.Linear(28 * 28, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten every dimension after the first and return the logits."""
        flat = x.flatten(start_dim=1)
        return self.classifier(flat)

    

In [None]:
# init weights
import torch.nn.init as init

def init_weights(module):
    """Re-initialise a linear layer in place; leave every other module alone.

    Intended for ``nn.Module.apply``, which calls it on every sub-module
    (containers and activations included), so non-linear modules must be
    skipped silently rather than rejected.

    Parameters
    ----------
    module: nn.Module
        The module to initialise. Only ``nn.Linear`` instances are modified:
        weights get Xavier-normal values, biases get standard-normal values.
    """
    if not isinstance(module, nn.Linear):
        return
    init.xavier_normal_(module.weight.data)
    # A layer built with bias=False has no bias tensor to initialise.
    if module.bias is not None:
        init.normal_(module.bias.data)

# `copy` and `optim` are used below but imported nowhere earlier in the
# notebook; import them here so this cell runs on a fresh kernel.
import copy

import torch.optim as optim

# init feed forward
model = fc1()
model.apply(init_weights)
# saved config: deep copy, so training the model does not alter the snapshot
initial_dict = copy.deepcopy(model.state_dict())
# add optimizer
# optimizer = optim.Adam(model.parameters(), lr = 0.01, weight_decay=1e-4)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# def loss fn
loss_func = nn.CrossEntropyLoss()

# Define Masking and Pruning

In [None]:
# define masking function
def make_mask(model):
    """Initializes a mask for the given model.

    Builds one all-ones tensor per parameter whose name contains "weight",
    in ``model.named_parameters()`` order, each shaped like that parameter.

    Parameters
    ----------
    model: extends(nn.Module)
        The model to create the mask for.

    Returns
    -------
    list[Tensor]
        The mask; empty when the model has no weight parameters.
    """
    return [
        torch.ones_like(param.data)
        for name, param in model.named_parameters()
        if "weight" in name
    ]

# One all-ones tensor per weight parameter of `model`, i.e. nothing pruned yet.
mask = make_mask(model)

In [None]:
def prune_percentile(percent: float, mask: list[torch.Tensor]) -> list[torch.Tensor]:
    """Prunes mask based on percentile.

    For each weight tensor of the module-level ``model``, computes the
    ``percent`` quantile of the magnitudes of its non-zero entries and sets the
    mask to zero wherever the weight's magnitude is below that value. Mask
    entries that are already zero stay zero. Everything is done with torch
    operations, so no numpy round-trip is needed.

    Parameters
    ----------
    percent: float
        Value from 0 to 1
    mask: list[torch.Tensor]
        List of tensors that make up the mask, one per weight parameter in
        ``model.named_parameters()`` order. The list is updated in place.

    Returns
    -------
    list[torch.Tensor]
        The same list object, now holding the pruned mask tensors
    """
    weights = (param for name, param in model.named_parameters() if "weight" in name)
    for layer, param in enumerate(weights):
        tensor = param.detach()
        magnitudes = tensor.abs()
        # Weights that are exactly zero do not take part in the quantile.
        alive = magnitudes[tensor != 0]
        if alive.numel() == 0:
            # Nothing left to rank in this layer; quantile of an empty tensor would raise.
            continue
        percentile_value = torch.quantile(alive, percent).item()
        mask[layer] = torch.where(
            magnitudes < percentile_value,
            torch.zeros_like(mask[layer]),
            mask[layer],
        )
    return mask
            

In [None]:
def total_nodes(model):
    """Counts the non-zero entries across all weight parameters of a model.

    Parameters
    ----------
    model: extends(nn.Module)
        The model whose parameters with "weight" in their name are counted.

    Returns
    -------
    int
        Number of non-zero weight entries. Always a plain Python int, so the
        result prints and divides the same way whether or not the model has
        any weight parameters.
    """
    return sum(
        int(torch.count_nonzero(param.data))
        for name, param in model.named_parameters()
        if "weight" in name
    )
# Non-zero weight count before any pruning; the results cell divides by this
# value to report the fraction of weights left.
original_nodes = total_nodes(model)
print("Total nodes:", original_nodes)

#### Utilities for pruning, masking

In [None]:
# resets the mask
def reset_mask(mask):
    """Turns every entry of the given mask back into an all-ones tensor.

    The list is modified in place: each element is replaced by a new tensor
    of ones with the same shape, dtype and device as the element it replaces.

    Parameters
    ----------
    mask: list[Tensor]
        The mask to reset
    """
    # Slice assignment keeps the caller's list object and swaps its contents.
    mask[:] = [torch.ones_like(entry) for entry in mask]

In [None]:
# reset to original
def reset_to_original_init(model, mask, initial_dict):
    """Resets the given model to the originally initialized parameters.

    Weight parameters are set to their saved initial value multiplied by the
    matching mask tensor; bias parameters are set to their saved initial value.
    Values are copied into the existing parameter storage, so the tensors in
    ``initial_dict`` are never shared with the model and later in-place
    optimizer updates cannot corrupt the saved snapshot.

    Parameters
    ----------
    model: extends(nn.Module)
        The model to reset
    mask: list[torch.Tensor]
        The mask to apply, one tensor per weight parameter in
        ``model.named_parameters()`` order
    initial_dict: Dict[str, Any]
        Dictionary values containing values to reset to
    """
    layer = 0
    # no_grad: overwriting a leaf parameter in place must not be tracked by autograd.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" in name:
                param.copy_(initial_dict[name] * mask[layer])
                layer += 1
            elif "bias" in name:
                param.copy_(initial_dict[name])

In [None]:
# full reset to init
def full_reset(model, mask, initial_dict) :
    """Fully resets the given model and mask.

    First turns the mask back into all ones, then restores the model's
    parameters from ``initial_dict`` using that freshly reset mask.

    Parameters
    ----------
    model: nn.Module
        The model to reset
    mask: list[torch.Tensor]
        The mask to reset (modified in place)
    initial_dict: Dict[str, Any]
        Dictionary values containing values to reset to
    """
    # Order matters: the mask is reset before it is applied to the weights.
    reset_mask(mask)
    reset_to_original_init(model, mask, initial_dict)

In [None]:
# Start from the saved initial parameters with an all-ones mask.
full_reset(model, mask, initial_dict)

In [None]:
from torch.autograd import Variable

def train_prune(model, loaders, loss_func):
    """Trains the given model with pruning.

    Runs one pass over the training loader. After each backward pass the
    gradient of every weight whose magnitude is below ``EPS`` is set to zero,
    so pruned (zeroed) weights receive no gradient signal. The comparison is
    on the absolute value: comparing the signed weight against ``EPS`` would
    also freeze every negative weight.

    Uses the module-level ``optimizer`` for the update step.
    NOTE(review): that optimizer is created with momentum, and zeroing the
    gradient does not clear an existing momentum buffer -- confirm pruned
    weights stay at zero across pruning rounds.

    Parameters
    ----------
    model: extends(nn.Module)
        The model to train with pruning.
    loaders: Dict[str, DataLoader]
        The dictionary containing the training data set under the key "train".
    loss_func: LossFunction
        The loss function to use.
    """
    EPS = 1e-6
    train_loader = loaders["train"]
    size = len(train_loader.dataset)
    for batch_idx, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        pred = model(imgs)
        train_loss = loss_func(pred, targets)
        train_loss.backward()

        for name, param in model.named_parameters():
            if "weight" in name and param.grad is not None:
                # assign 0 to gradient of pruned weight
                pruned = param.data.abs() < EPS
                param.grad.data.masked_fill_(pruned, 0.0)

        optimizer.step()

        if batch_idx % 100 == 0:
            loss, current = train_loss.item(), batch_idx * len(imgs)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
def test(model, loaders, loss_func):
    """Evaluates the given model on the test split and prints the result.

    Iterates over ``loaders["test"]`` with gradients disabled, then prints the
    accuracy (correct predictions divided by dataset size) and the loss
    averaged over the number of batches. Nothing is returned.

    Parameters
    ----------
    model: extends(nn.Module)
        The model to evaluate.
    loaders: Dict[str, DataLoader]
        The dictionary containing the test data set under the key "test".
    loss_func: LossFunction
        The loss function to use.
    """
    loader = loaders["test"]
    n_samples = len(loader.dataset)
    n_batches = len(loader)
    loss_sum = 0
    n_correct = 0

    with torch.no_grad():
        for X, y in loader:
            logits = model(X)
            loss_sum += loss_func(logits, y).item()
            hits = logits.argmax(1) == y
            n_correct += hits.type(torch.float).sum().item()

    test_loss = loss_sum / n_batches
    correct = n_correct / n_samples
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# Baseline Test

In [None]:
# Reference point: evaluate the model before any training or pruning has run,
# and report how many weights are non-zero.
test(model, loaders, loss_func)
nodes = total_nodes(model)
# print(f"Accuracy: {acc:.3f}")
print(f"Number of nodes: {nodes}")

# Iterative Pruning with Resetting

1. Randomly initialize a neural network $f(x; m \odot \theta)$, where $\theta = \theta_0$ and $m = 1^{|\theta|}$ is a mask.
2. Train the network for $j$ iterations, reaching parameters $m \odot \theta_j$.
3. Prune $s\%$ of the parameters, creating an updated mask $m'$ where $P_{m'} = (P_m - s)\%$.
4. Reset the weights of the remaining portion of the network to their values in $\theta_0$. That is, let $\theta = \theta_0$.
5. Let $m = m'$ and repeat steps 2 through 4 until a sufficiently pruned network has been obtained.

In [None]:
# Number of prune-and-reset rounds; read as a global by iterative_prune_train.
epochs = 5
# Quantile (0 to 1) handed to prune_percentile at the end of every round.
prune_percent = 0.5
# Number of train/test passes run inside each round before pruning.
iterations = 6

def iterative_prune_train(model, mask, loss_func, iterations, percent):
    """Alternates training with pruning and resetting to the initial weights.

    Each of the ``epochs`` rounds (a module-level value) trains and evaluates
    the model ``iterations`` times, prunes the mask with ``percent``, and then
    restores the masked initial parameters. Also reads the module-level
    ``loaders`` and ``initial_dict``.

    Parameters
    ----------
    model: extends(nn.Module)
        The model to train and prune.
    mask: list[torch.Tensor]
        The current mask, one tensor per weight parameter.
    loss_func: LossFunction
        The loss function to use.
    iterations: int
        Number of train/test passes per round.
    percent: float
        Quantile passed to ``prune_percentile`` after each round.
    """
    for epoch in range(epochs):
        print("epoch:", epoch)

        # Train
        model.train()
        for step in range(1, iterations + 1):
            print(f"Iteration {step}\n-------------------------------")
            train_prune(model, loaders, loss_func)
            test(model, loaders, loss_func)

        # Prune, then rewind the surviving weights to their initial values
        mask = prune_percentile(percent, mask)
        reset_to_original_init(model, mask, initial_dict)
        print(f"\n--- Pruning Level [{epoch}/{epochs}]: ---")

# Run the prune/reset rounds using the hyper-parameters defined at the top of this cell.
iterative_prune_train(model, mask, loss_func, iterations, prune_percent)

### Train, Test
Train and test the fully pruned model, and evaluate the accuracy.

In [None]:
# Train the pruned model, keeping pruned weights' gradients at zero via
# train_prune, and evaluate after every pass over the training data.
# Note: this rebinds the global `epochs` that iterative_prune_train also reads.
epochs = 10
model.train()
for t in range(epochs) :
    print(f"Epoch {t+1}\n-------------------------------")
    train_prune(model, loaders, loss_func)
    test(model, loaders, loss_func)
print("Done.")

# Pruning Results

In [None]:
# Evaluate the pruned, trained model and report how many weights are still
# non-zero, both as a count and as a fraction of the pre-pruning count.
# loss_func = nn.CrossEntropyLoss()
# full_reset(model, mask, initial_dict)
test(model, loaders, loss_func)
nodes = total_nodes(model)
# print(f"Accuracy: {acc:.3f}")
print(f"Number of nodes: {nodes}")
print(f"Percent of nodes left: {(nodes / original_nodes):.3f}")

# Normal Training, Testing

Fully reset the model along with the mask, and train normally without any pruning.

In [None]:
# Back to an all-ones mask and the saved initial parameters before the unpruned run.
# NOTE(review): the global `optimizer` is not re-created here, so any state it
# accumulated during the pruning run carries over -- confirm that is acceptable.
full_reset(model, mask, initial_dict)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Runs one pass of ordinary (unpruned) training over ``dataloader``.

    For every batch: computes the loss, clears old gradients, backpropagates
    and takes one optimizer step. Progress is printed every 100th batch.

    Parameters
    ----------
    dataloader: DataLoader
        The batches to train on; its ``dataset`` length is used for progress output.
    model: extends(nn.Module)
        The model to train.
    loss_fn: LossFunction
        The loss function to use.
    optimizer: Optimizer
        The optimizer that updates the model's parameters.
    """
    total = len(dataloader.dataset)
    for step, (inputs, labels) in enumerate(dataloader):
        # Forward pass
        batch_loss = loss_fn(model(inputs), labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100:
            continue
        seen = step * len(inputs)
        print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{total:>5d}]")

In [None]:
# Comparison run: ordinary training with no gradient masking, evaluated after
# every pass over the training data.
epochs = 10
model.train()
for t in range(epochs) :
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(loaders["train"], model, loss_func, optimizer)
    test(model, loaders, loss_func)
print("Done.")