In [None]:
from vgg16 import vgg16
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm

In [None]:
class PruneConfig:
    """Plain settings holder for this notebook; every field is a fixed default."""

    def __init__(self):
        # --- data loading / optimisation ---
        self.batch_size = 64        # samples per DataLoader batch
        self.epochs = 100           # value handed to train()
        self.lr = 0.01              # learning rate given to the optimizer

        # --- runtime ---
        self.cuda = False           # True selects the "cuda" device, False the CPU
        self.seed = 42              # seed passed to torch.manual_seed

        # --- logging ---
        self.log_rate = 10          # refresh the progress description every N batches
        self.log_file = "log.txt"   # not read by the cells shown here; presumably a log path -- TODO confirm

        # --- pruning ---
        self.sensitivity = 2        # not read by the cells shown here; presumably a pruning threshold factor -- TODO confirm
        self.debug = True           # forwarded to vgg16(debug=...)


cfg = PruneConfig()

In [None]:
# Seed PyTorch's random number generator from the config so runs are repeatable.
torch.manual_seed(cfg.seed)

In [None]:
# Announce the selected backend, seed CUDA when it is enabled, and choose the
# extra DataLoader keyword arguments (worker processes and pinned memory are
# only requested for the CUDA path).
if not cfg.cuda:
    print("No CUDA")
    kwargs = {}
else:
    print("Using CUDA")
    torch.cuda.manual_seed(cfg.seed)
    kwargs = {'num_workers': 5, 'pin_memory': True}

### Use Toy MNIST Data
**Pad each 28x28 MNIST image to 224x224x1 (98 pixels on every side), since VGG16 originally takes in images of that size; essentially this is just really, really bad toy data.**

In [None]:
# One preprocessing pipeline shared by both splits, so evaluation sees the same
# 224x224 inputs as training: pad the 28x28 digits by 98 pixels on every side,
# convert to a tensor, then normalize with mean 0.1307 / std 0.3081.
# (Previously only the training split was padded, leaving test images at 28x28.)
mnist_transform = transforms.Compose([
    transforms.Pad(98),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split: reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=mnist_transform),
    batch_size=cfg.batch_size,
    shuffle=True,
    **kwargs)

# Held-out split: fixed order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=mnist_transform),
    batch_size=cfg.batch_size,
    shuffle=False,
    **kwargs)

In [None]:
# Pick the compute device from the config flag, then build the network and
# move it there. The pretrained/mask/debug flags are interpreted by the
# project's own vgg16 factory; in_channels=3 matches the 3-channel batches
# assembled in train().
dev = torch.device("cuda") if cfg.cuda else torch.device("cpu")
model = vgg16(
    pretrained=True,
    mask=True,
    debug=cfg.debug,
    in_channels=3,
).to(dev)

In [None]:
# Adam over the parameters that require gradients only, with a small weight decay.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=cfg.lr, weight_decay=1e-4)
# Snapshot of the freshly created optimizer state.
optim_state_dict = optimizer.state_dict()
# Classification loss used by train().
criterion = nn.CrossEntropyLoss()

In [None]:
def _zero_pruned_grads(module):
    """Zero the gradient of every weight whose current value is exactly 0.

    Parameters whose name contains "mask" are skipped, as are parameters that
    have no gradient (frozen or unused in the last backward pass). The masking
    is done in place on the gradient's own device.
    """
    with torch.no_grad():
        for name, p in module.named_parameters():
            if "mask" in name or p.grad is None:
                continue
            p.grad.masked_fill_(p == 0, 0)


def train(epochs):
    """Train the notebook-level ``model`` while keeping pruned (zero) weights at zero.

    Uses the globals ``model``, ``train_loader``, ``optimizer``, ``criterion``,
    ``cfg`` and ``dev`` defined in earlier cells.

    Args:
        epochs: Number of full passes over ``train_loader``.

    Every ``cfg.log_rate`` batches the progress-bar description is refreshed
    with the mean loss of the batches seen since the previous refresh.
    """
    model.train()
    for epoch_i in range(epochs):
        # Start each epoch with an empty window so the first figure reported
        # for an epoch never mixes in batches from the previous one.
        recent_losses = []
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for batch_i, (x_in, y_in) in pbar:
            x_in, y_in = x_in.to(dev), y_in.to(dev)
            # Repeat the single input channel three times along dim 1; the
            # model is constructed with in_channels=3.
            x_in = torch.cat([x_in, x_in, x_in], dim=1)

            optimizer.zero_grad()
            output = model(x_in)
            loss = criterion(output, y_in)
            loss.backward()
            # CrossEntropyLoss() uses reduction='mean', so the value is already
            # a per-sample average; dividing by the batch size again would
            # under-report it.
            recent_losses.append(loss.item())

            # Zero out gradients of pruned connections so the step leaves them at 0.
            _zero_pruned_grads(model)
            optimizer.step()

            if batch_i % cfg.log_rate == 0:
                done = batch_i * len(x_in)
                percentage = 100. * batch_i / len(train_loader)
                avg_loss = sum(recent_losses) / len(recent_losses)
                recent_losses = []
                pbar.set_description(f"Train Epoch: {epoch_i} [{done:5}/{len(train_loader.dataset)} ({percentage:3.0f}%)] Loss: {avg_loss:.6f}")


In [None]:
# Run training for the configured number of epochs.
train(cfg.epochs)