In [10]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging

# Make the `filter_sparsity` directory, located two levels above the current
# working directory, importable by adding it to sys.path (once).
_cwd_grandparent = os.path.dirname(os.path.dirname(os.getcwd()))
PROJECT_ROOT = os.path.join(_cwd_grandparent, 'filter_sparsity')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


import tensorflow as tf
import torch
from torchvision import datasets, transforms
from oto_modelling.pytorch_lenet5 import LeNet5
from only_train_once.only_train_once import OTO

In [12]:
def get_loaders(batch_size, test_batch_size):
    """Build the FashionMNIST training and test data loaders.

    Both splits use the same preprocessing pipeline: pad 2 pixels on every
    side, convert the PIL image to a tensor, then normalize with mean 0.5 and
    std 0.5.

    Args:
        batch_size: batch size of the training loader.
        test_batch_size: batch size of the test loader.

    Returns:
        Tuple ``(train_loader, test_loader)`` of ``torch.utils.data.DataLoader``.
    """
    # One pipeline for both splits so evaluation sees inputs with the same
    # spatial size and scaling as training.
    #  - Pad(2) enlarges the 28x28 FashionMNIST images to 32x32.
    #  - ToTensor must precede Normalize, which only accepts tensors.
    preprocess = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('./data.fashionMNIST', train=True, download=True,
                              transform=preprocess),
        batch_size=batch_size, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('./data.fashionMNIST', train=False,
                              transform=preprocess),
        batch_size=test_batch_size, shuffle=True)

    return train_loader, test_loader

def check_accuracy(model, testloader, two_input=False):
    """Evaluate ``model`` on ``testloader`` and return top-1 / top-5 accuracy.

    The model is switched to eval mode for the duration of the evaluation and
    put back into whatever mode (train or eval) it was in on entry, even if
    evaluation raises.

    Args:
        model: a ``torch.nn.Module`` with at least one parameter; batches are
            moved to the device of its first parameter.
        testloader: iterable yielding ``(inputs, labels)`` batches.
        two_input: when True the batch is passed twice, as ``model(X, X)``.

    Returns:
        Tuple ``(accuracy1, accuracy5)``: the per-batch values returned by
        ``accuracy_topk`` for k=1 and k=5, summed and divided by the number of
        evaluated samples. ``(0.0, 0.0)`` when ``testloader`` yields no samples.

    NOTE(review): this assumes ``accuracy_topk`` (defined outside this block)
    returns per-batch *counts* of correct predictions as 0-dim tensors; if it
    returns percentages the division below is wrong -- confirm.
    """
    was_training = model.training
    device = next(model.parameters()).device

    correct1 = 0
    correct5 = 0
    total = 0

    model.eval()
    try:
        with torch.no_grad():
            for X, y in testloader:
                X = X.to(device)
                y = y.to(device)
                # Call the module rather than .forward() so registered hooks run.
                y_pred = model(X, X) if two_input else model(X)
                total += y.size(0)

                prec1, prec5 = accuracy_topk(y_pred, y, topk=(1, 5))
                correct1 += prec1.item()
                correct5 += prec5.item()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, 0.0

    return correct1 / total, correct5 / total


In [13]:
def fit_model(batch_size, test_batch_size, max_epoch=100):
    """Train LeNet5 on FashionMNIST with OTO's HESSO optimizer.

    After every epoch the average training loss, the group-sparsity and norm
    statistics reported by the optimizer, and the top-1 test accuracy are
    printed.

    Args:
        batch_size: training batch size, forwarded to ``get_loaders``.
        test_batch_size: evaluation batch size, forwarded to ``get_loaders``.
        max_epoch: number of training epochs (default 100, the value that was
            previously hard-coded).

    Requires a CUDA device: the model and every batch are moved with ``.cuda()``.
    """
    train_loader, test_loader = get_loaders(batch_size, test_batch_size)

    # Module.cuda() moves the parameters in place, so one call is enough.
    model = LeNet5().cuda()
    dummy_input = torch.rand(1, 1, 32, 32).cuda()
    oto = OTO(model=model, dummy_input=dummy_input)

    optimizer = oto.hesso(
        variant='sgd',
        lr=0.1,
        weight_decay=1e-4,
        target_group_sparsity=0.7,
        start_pruning_step=10 * len(train_loader),
        pruning_periods=10,
        pruning_steps=10 * len(train_loader)
    )

    criterion = torch.nn.CrossEntropyLoss()
    # Every 50 epochs, decay lr by 10.0
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    for epoch in range(max_epoch):
        f_avg_val = 0.0
        model.train()
        for X, y in train_loader:
            X = X.cuda()
            y = y.cuda()
            # Call the module rather than .forward() so registered hooks run.
            y_pred = model(X)
            f = criterion(y_pred, y)
            optimizer.zero_grad()
            f.backward()
            # Accumulate a detached value: summing `f` itself would keep every
            # batch's autograd graph reachable until the end of the epoch.
            f_avg_val += f.detach()
            optimizer.step()
        # Step the scheduler after the epoch's optimizer updates. Stepping it
        # first skips the initial value of the schedule, so the decay would
        # happen after 49 epochs instead of 50.
        lr_scheduler.step()
        group_sparsity, param_norm, _ = optimizer.compute_group_sparsity_param_norm()
        norm_important, norm_redundant, num_grps_important, num_grps_redundant = optimizer.compute_norm_groups()
        # TODO: change accuracy metric or check that it does the correct thing
        accuracy1, accuracy5 = check_accuracy(model, test_loader)
        # float() handles both the accumulated 0-dim tensor and the initial
        # 0.0 of an empty loader; max(..., 1) avoids dividing by zero there.
        f_avg_val = float(f_avg_val) / max(len(train_loader), 1)

        print("Ep: {ep}, loss: {f:.2f}, norm_all:{param_norm:.2f}, grp_sparsity: {gs:.2f}, acc1: {acc1:.4f}, norm_import: {norm_import:.2f}, norm_redund: {norm_redund:.2f}, num_grp_import: {num_grps_import}, num_grp_redund: {num_grps_redund}"
            .format(ep=epoch, f=f_avg_val, param_norm=param_norm, gs=group_sparsity, acc1=accuracy1,
            norm_import=norm_important, norm_redund=norm_redundant, num_grps_import=num_grps_important, num_grps_redund=num_grps_redundant
            ))