<a href="https://colab.research.google.com/github/amulyagarimella/242finalproject/blob/main/Aim1_Toy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title PA2 - unstructured pruning

def _make_pair(x):
    """Return ``x`` unchanged if it is sized, otherwise duplicate it as ``(x, x)``.

    Lets callers give one scalar (e.g. ``stride=1``) or an explicit pair
    (e.g. ``stride=(1, 2)``) interchangeably.
    """
    return x if hasattr(x, '__len__') else (x, x)

class SparseConv2d(nn.Module):
    """Bias-free 2-D convolution whose filters can be disabled by a mask.

    The layer keeps a dense weight tensor ``_weight`` of shape
    ``[out_channels, in_channels, kernel_size, kernel_size]`` and a buffer
    ``_mask`` with one entry per output filter.  The convolution always uses
    ``weight`` (mask times dense weight), so setting a mask entry to zero
    zeroes the corresponding filter as a whole.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = _make_pair(stride)
        self.padding = _make_pair(padding)

        # Dense weights, re-drawn uniformly from [-bound, bound] where
        # bound = 1 / sqrt(in_channels).
        weight_shape = [out_channels, in_channels, kernel_size, kernel_size]
        self._weight = nn.Parameter(torch.randn(weight_shape))
        bound = 1. / math.sqrt(in_channels)
        self._weight.data.uniform_(-bound, bound)

        # One mask entry per filter; registered as a buffer so it is not a
        # trainable parameter.  All filters start enabled.
        self.register_buffer('_mask', torch.ones(out_channels))

    @property
    def weight(self):
        """Dense weights with every masked-out filter zeroed."""
        # Give the 1-D mask trailing singleton axes so it broadcasts over the
        # (in_channels, kernel_h, kernel_w) axes of the weight tensor; see
        # https://pytorch.org/docs/stable/notes/broadcasting.html
        per_filter_mask = self._mask[:, None, None, None]
        return per_filter_mask * self._weight

    def forward(self, x):
        """Convolve ``x`` with the masked weights (no bias term)."""
        effective_weight = self.weight
        return F.conv2d(x, effective_weight, stride=self.stride,
                        padding=self.padding)

In [None]:
#@title Code Cell 3.4

# unstructured - remove the smallest-magnitude WEIGHTS in each layer
# Technically, train-prune-retrain leaves open how much training to do between pruning steps
class SparseConv2d(nn.Module):
    """Bias-free 2-D convolution with an element-wise pruning mask.

    The layer keeps a dense weight tensor ``_weight`` of shape
    ``[out_channels, in_channels, kernel_size, kernel_size]`` and a buffer
    ``_mask`` of exactly the same shape.  The convolution always uses
    ``weight`` (mask times dense weight), so individual weights can be
    switched off by zeroing the matching mask entries.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = _make_pair(stride)
        self.padding = _make_pair(padding)

        # Dense weights, re-drawn uniformly from [-bound, bound] where
        # bound = 1 / sqrt(in_channels).
        weight_shape = [out_channels, in_channels, kernel_size, kernel_size]
        self._weight = nn.Parameter(torch.randn(weight_shape))
        bound = 1. / math.sqrt(in_channels)
        self._weight.data.uniform_(-bound, bound)

        # One mask entry per individual weight; registered as a buffer so it
        # is not a trainable parameter.  Every weight starts enabled.
        self.register_buffer('_mask', torch.ones_like(self._weight))

    @property
    def weight(self):
        """Dense weights with every masked-out element zeroed."""
        return torch.mul(self._mask, self._weight)

    def forward(self, x):
        """Convolve ``x`` with the masked weights (no bias term)."""
        effective_weight = self.weight
        return F.conv2d(x, effective_weight, stride=self.stride,
                        padding=self.padding)


def unnstructured_pruning(net, prune_percent):
    """Mask out the smallest-magnitude surviving weights of each sparse layer.

    For every layer returned by ``get_sparse_conv2d_layers(net)``, zeroes up
    to ``round(numel * prune_percent)`` additional entries of ``layer._mask``.
    The entries chosen are the currently unmasked positions whose
    ``layer._weight`` has the smallest absolute value.  Masks are modified in
    place; ``layer._weight`` itself is not changed.

    Args:
        net: Module containing the layers to prune.  Each layer's ``_mask``
            must have the same shape as its ``_weight`` (element-wise mask).
        prune_percent: Fraction in ``[0, 1]`` (e.g. ``0.1`` for 10%) of each
            layer's *total* number of weights to prune in this call.

    Raises:
        ValueError: If ``prune_percent`` is outside ``[0, 1]`` or a layer's
            mask shape differs from its weight shape.
    """
    if not 0 <= prune_percent <= 1:
        raise ValueError(
            f"prune_percent must be a fraction in [0, 1], got {prune_percent}")

    for layer in get_sparse_conv2d_layers(net):
        if layer._mask.shape != layer._weight.shape:
            raise ValueError(
                "unstructured pruning needs an element-wise mask: mask shape "
                f"{tuple(layer._mask.shape)} != weight shape "
                f"{tuple(layer._weight.shape)}")

        num_nonzero = layer._mask.sum().item()
        num_total = layer._mask.numel()
        num_prune = round(num_total * prune_percent)
        sparsity = 100.0 * (1 - (num_nonzero / num_total))
        print("Pruning: ", num_prune, num_total, prune_percent)
        print(f"Sparsity before pruning: {sparsity}")

        # Work on flattened views so a flat index can be written straight
        # into the mask; detach() keeps this bookkeeping out of autograd.
        flat_mask = layer._mask.view(-1)
        flat_abs_weight = torch.abs(layer._weight.detach()).reshape(-1)

        # Only rank positions that are still unmasked, so weights pruned in
        # an earlier call are never selected again.
        alive = torch.nonzero(flat_mask).view(-1)
        order = torch.argsort(flat_abs_weight[alive])

        # The slice simply shortens if fewer than num_prune weights survive.
        to_prune = alive[order[:num_prune]]
        flat_mask[to_prune] = 0

In [None]:
#@title PA2 - unstructured pruning - run

# Seed PyTorch's global random number generator so repeated runs of this
# cell draw the same random values.
torch.manual_seed(43) # to give stable randomness

def get_sparse_conv2d_layers(net):
    '''
    Collect every SparseConv2d layer nested anywhere below ``net``.

    Walks ``net.children()`` depth-first.  A child that is a SparseConv2d is
    added to the result and not searched further; any other child is
    searched recursively.  ``net`` itself is never included.  Layers are
    returned in traversal order.
    '''
    found = []
    for child in net.children():
        if not isinstance(child, SparseConv2d):
            found.extend(get_sparse_conv2d_layers(child))
            continue
        found.append(child)
    return found

# Prefer the GPU, but fall back to the CPU so this cell also runs on
# runtimes without CUDA instead of failing in net.to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = SparseConvNet()
net = net.to(device)

# Set these parameters based on PART 1.2
lr = 0.1 # best learning rate from 1.2
epochs = 100
milestones = list(range(25, epochs, 25)) # milestones from 1.2

# PART 3.3: Prune an additional 10% of each layer's weights every 10 epochs,
#           starting at epoch 10 and ending at epoch 50. By the end, each
#           convolution layer in the CNN should reach 50% sparsity.
prune_percentage = 0.1
# Yields the values 10, 20, 30, 40, 50 (as floats); the training loop tests
# membership with `epoch in prune_epochs`, which matches the integer epochs.
prune_epochs = np.linspace(10,50,5)


def train_and_test_SparseConvNet_unstructured_pruning(lr=0.1, milestones=None, epochs=5, prune_percentage=0, prune_epochs=None):
    """Train and evaluate the global ``net``, pruning it at selected epochs.

    Each epoch runs ``train``, then (if the epoch index is in
    ``prune_epochs``) ``unnstructured_pruning``, then ``test``, and finally
    steps the learning-rate scheduler.  Pruning happens before the test pass,
    so that epoch's test result is measured on the pruned network.

    This function relies on names defined elsewhere in the notebook: ``net``,
    ``trainloader``, ``testloader``, ``train`` and ``test``.

    Args:
        lr: Initial learning rate for SGD (momentum 0.9, weight decay 5e-4).
        milestones: Epoch indices at which MultiStepLR multiplies the
            learning rate by 0.1.  ``None`` means no milestones.
        epochs: Number of epochs to run; epoch indices start at 0.
        prune_percentage: Fraction passed to ``unnstructured_pruning`` on
            every pruning epoch.
        prune_epochs: Collection of epoch indices after whose training pass
            the network is pruned.  ``None`` means never prune.

    Returns:
        Tuple ``(train_loss_tracker, train_acc_tracker, test_loss_tracker,
        test_acc_tracker)`` of the lists handed to ``train`` / ``test``;
        presumably those helpers append to them -- confirm against their
        definitions.
    """
    # Use None sentinels instead of mutable list defaults.  Compare with
    # `is None` (not truthiness) because prune_epochs may be a numpy array.
    if milestones is None:
        milestones = []
    if prune_epochs is None:
        prune_epochs = []

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.1)

    train_loss_tracker, train_acc_tracker = [], []
    test_loss_tracker, test_acc_tracker = [], []

    print('Training for {} epochs, with learning rate {} and milestones {}'
          .format(epochs, lr, milestones))

    print('Unstructured pruning percentage {} and epochs {}'.format(prune_percentage, prune_epochs))

    start_time = time.time()
    for epoch in range(0, epochs):
        train(net=net, epoch=epoch, loader=trainloader, criterion=criterion,
              optimizer=optimizer, loss_tracker=train_loss_tracker,
              acc_tracker=train_acc_tracker)

        if epoch in prune_epochs:
            print('\nUnstructured pruning at epoch {}'.format(epoch))
            unnstructured_pruning(net, prune_percentage)

        test(net=net, epoch=epoch, loader=testloader, criterion=criterion,
             loss_tracker=test_loss_tracker, acc_tracker=test_acc_tracker)
        scheduler.step()

    total_time = time.time() - start_time
    print('Total training time: {} seconds'.format(total_time))
    return train_loss_tracker, train_acc_tracker, test_loss_tracker, test_acc_tracker