Notebook for OTO: https://colab.research.google.com/drive/1Q6zuORrGQkyIp1IWYSiWcHCv7kZ5OA8u?usp=sharing

In [13]:
!pip install only_train_once



In [14]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging
import tensorflow as tf
import torch
from torchvision import datasets, transforms
from only_train_once import OTO

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.functional as F

import numpy as np
import pandas as pd

In [28]:
class LeNet5BN(nn.Module):

    def __init__(self):
        super(LeNet5BN, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None),
            nn.AvgPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(400,120),  #in_features = 16 x5x5
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10),
            nn.Softmax()

        )

    def forward(self,x):
        a1=self.feature_extractor(x)
        #print(a1.shape)
        a1 = torch.flatten(a1,1)
        a2=self.classifier(a1)
        return a2


class LeNet5(nn.Module):

    def __init__(self):
        super(LeNet5, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            #nn.BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None),
            nn.AvgPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            #nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(400,120),  #in_features = 16 x5x5
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10),
            nn.Softmax()

        )

    def forward(self,x):
        a1=self.feature_extractor(x)
        #print(a1.shape)
        a1 = torch.flatten(a1,1)
        a2=self.classifier(a1)
        return a2

In [16]:
def get_loaders(batch_size, test_batch_size):
    train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST('./data.fashionMNIST', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.Pad(2),
                        #transforms.RandomCrop(32),
                        #transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ])),
    batch_size=batch_size, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST('./data.fashionMNIST', train=False, transform=transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])),
    batch_size=test_batch_size, shuffle=True)

    return train_loader, test_loader

def accuracy_topk(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).view(-1).float().sum(0, keepdim=True)
        res.append(correct_k)
    return res


def check_accuracy(model, testloader, two_input=False):
    correct1 = 0
    correct5 = 0
    total = 0
    model = model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            if two_input:
                y_pred = model.forward(X, X)
            else:
                y_pred = model.forward(X)
            total += y.size(0)

            prec1, prec5 = accuracy_topk(y_pred.data, y, topk=(1, 5))

            correct1 += prec1.item()
            correct5 += prec5.item()

    model = model.train()
    accuracy1 = correct1 / total
    accuracy5 = correct5 / total
    return accuracy1, accuracy5


In [24]:
import copy

def fit_model(batch_size, test_batch_size, ts, bn = True):
    train_loader, test_loader = get_loaders(batch_size, test_batch_size)

    if bn:
      model = LeNet5BN()
    else:
      model = LeNet5()
    dummy_input = torch.rand(1, 1, 32, 32)
    input_shape = (batch_size, 1, 32, 32)
    oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

    optimizer = oto.hesso(
        variant='sgd',
        lr=0.15,
        first_momentum = 0.9,
        weight_decay=0,
        target_group_sparsity=ts,
        start_pruning_step=0,
        pruning_periods=1,
        pruning_steps=1
    )

    max_epoch = 50
    model.cuda()
    criterion = torch.nn.CrossEntropyLoss()
    # Every 50 epochs, decay lr by 10.0
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 1000)
    patience = 10
    counter = 0
    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(max_epoch):
        f_avg_val = 0.0
        model.train()
        lr_scheduler.step()
        for X, y in train_loader:
            X = X.cuda()
            y = y.cuda()
            y_pred = model.forward(X)
            f = criterion(y_pred, y)
            optimizer.zero_grad()
            f.backward()
            f_avg_val += f
            optimizer.step()
        group_sparsity, param_norm, _ = optimizer.compute_group_sparsity_param_norm()
        norm_important, norm_redundant, num_grps_important, num_grps_redundant = optimizer.compute_norm_groups()
        accuracy1, accuracy5 = check_accuracy(model, test_loader)
        f_avg_val = f_avg_val.cpu().item() / len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X, y in test_loader:
                X = X.cuda()
                y = y.cuda()
                y_pred = model.forward(X)
                val_loss += criterion(y_pred, y).item()

        val_loss = val_loss / len(test_loader)

        if val_loss < best_val_loss:
          best_val_loss = val_loss
          best_model_state = copy.deepcopy(model.state_dict())
          counter = 0
        else:
            counter += 1

        print("Ep: {ep}, loss: {f:.2f}, norm_all:{param_norm:.2f}, grp_sparsity: {gs:.2f}, acc1: {acc1:.4f}, norm_import: {norm_import:.2f}, norm_redund: {norm_redund:.2f}, num_grp_import: {num_grps_import}, num_grp_redund: {num_grps_redund}"\
            .format(ep=epoch, f=f_avg_val, param_norm=param_norm, gs=group_sparsity, acc1=accuracy1,\
            norm_import=norm_important, norm_redund=norm_redundant, num_grps_import=num_grps_important, num_grps_redund=num_grps_redundant
            ))

        if counter > patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)


    baseline_flops = oto.compute_flops()['total']
    # oto = OTO(torch.load(ckpt_path), dummy_input)
    oto.construct_subnet(out_dir='./cache')
    pruned_flops = oto.compute_flops()['total']
    return accuracy1, baseline_flops, pruned_flops

In [26]:
lambda_min = 0.1
lambda_max = 1
lambda_seq_len = 11
lambda_seq = np.linspace(lambda_max, lambda_min, lambda_seq_len)
lambda_seq = np.concatenate([lambda_seq, [0]])

In [27]:
results = []

for ts in lambda_seq:
  print(ts)
  acc, baseline_flops, remaining_flops = fit_model(256, 256, ts)
  res = {
      'ts': ts,
      'acc' : acc,
      'baseline_flops' : baseline_flops,
      'remaining_flops' : remaining_flops
  }
  results.append(res)



1
OTO graph constructor
graph build
Setup HESSO
Target redundant groups per period:  [225]
Ep: 0, loss: 2.30, norm_all:3.50, grp_sparsity: 0.98, acc1: 0.1000, norm_import: 3.50, norm_redund: 0.00, num_grp_import: 4, num_grp_redund: 222
Ep: 1, loss: 2.30, norm_all:3.50, grp_sparsity: 0.98, acc1: 0.1000, norm_import: 3.50, norm_redund: 0.00, num_grp_import: 4, num_grp_redund: 222
Ep: 2, loss: 2.30, norm_all:3.51, grp_sparsity: 0.98, acc1: 0.1000, norm_import: 3.51, norm_redund: 0.00, num_grp_import: 4, num_grp_redund: 222
Ep: 3, loss: 2.26, norm_all:9.40, grp_sparsity: 0.98, acc1: 0.2224, norm_import: 9.40, norm_redund: 0.00, num_grp_import: 4, num_grp_redund: 222
Ep: 4, loss: 2.19, norm_all:10.83, grp_sparsity: 0.98, acc1: 0.2640, norm_import: 10.83, norm_redund: 0.00, num_grp_import: 4, num_grp_redund: 222
Parameter containing:
tensor([[[[-0.8504, -0.6980, -0.4923, -0.5247, -0.5386],
          [-0.6806, -0.7957, -0.8564, -0.5661, -0.4919],
          [-0.8030, -0.7573, -0.7812, -0.6017,

  non_zero_weights = count_nonzero_weights(torch.load(oto.compressed_model_path))


In [None]:
res = pd.DataFrame(results)

In [None]:
print(res)

In [None]:
res.to_csv('results_oto.csv')

In [None]:
from google.colab import files

files.download('results_oto.csv')