In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from tqdm import tqdm

In [None]:
# --- Data loading ---
batch_size = 50          # training mini-batch size
test_batch_size = 100    # evaluation mini-batch size

# --- Optimisation ---
epochs = 10
lr = 0.01
momentum = 0.5
lam = 1                  # regularisation strength (reassigned by later cells)

# --- Runtime / logging ---
no_cuda = False
seed = 42
log_interval = 10        # train() prints the loss every `log_interval` batches
log = 'log.txt'
sensitivity = 2

In [None]:
class Net(nn.Module):
    """Two-layer MLP: 784 -> 300 (ReLU) -> 10, returning log-probabilities.

    Parameter registration order is fc1.weight, fc1.bias, fc2.weight,
    fc2.bias; the Fabian optimizer indexes parameters by that position.

    Args:
        mask: accepted but unused.
    """

    def __init__(self, mask=False):
        super().__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 10)

    def forward(self, x):
        # Flatten each input to a 784-vector, then hidden ReLU layer -> log-softmax.
        hidden = torch.relu(self.fc1(x.view(-1, 784)))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [None]:
def prox_L1(x, l):
    """Soft-thresholding, i.e. the proximal operator of ``l * ||x||_1`` (l >= 0).

    Every entry is shrunk towards zero by ``l``; entries whose magnitude is at
    most ``l`` become exactly zero.
    """
    shrunk_magnitude = (x.abs() - l).clamp(min=0)
    return x.sign() * shrunk_magnitude

In [None]:
# Sanity check: with l = 0 the soft-threshold returns its input unchanged.
x = torch.tensor([-1.0, 0.05, 0.0])
prox_L1(x, l=0.)

In [None]:
def prox_Fabian_numpy(vbar, wbar, l, verbose=False):
    """Proximal step for the coupled penalty ``l * ||v||_1 * ||w||_1`` (NumPy).

    Searches for::

        argmin_{v, w}  0.5*||v - vbar||^2 + 0.5*||w - wbar||^2 + l*||v||_1*||w||_1

    Working on magnitudes sorted in ascending order, it enumerates support
    sizes (s_v, s_w) of the largest entries (only pairs with s_v*s_w <= t).
    For each pair it solves the stationarity conditions in closed form and
    keeps the feasible (non-negative) candidate with the lowest objective.
    It then also compares against the two candidates with one side zeroed:
    (0, |wbar|) and (|vbar|, 0). Signs of ``vbar`` / ``wbar`` are restored on
    the way out.

    Args:
        vbar, wbar: 1-D arrays (anything ``np.asarray`` accepts).
        l: penalty weight (callers pass ``lr * lam``). With ``l == 0``, or an
            empty input, the penalty vanishes and copies of the inputs are
            returned.
        verbose: if True, print search diagnostics.

    Returns:
        ``(v, w, sorting_permutation_v, sorting_permutation_w)``. The
        permutations are ``np.argsort(|vbar|)`` and ``np.argsort(|wbar|)``.
    """
    def f(v, w, vbar, wbar, l):
        # sum(outer(v, w)) == sum(v) * sum(w); avoid materialising the outer product.
        return (0.5 * np.linalg.norm(v - vbar)**2 + 0.5 * np.linalg.norm(w - wbar)**2
                + l * np.sum(v) * np.sum(w))

    vbar = np.asarray(vbar)
    wbar = np.asarray(wbar)
    sorting_permutation_v = np.argsort(np.abs(vbar))
    sorting_permutation_w = np.argsort(np.abs(wbar))
    if l == 0 or vbar.size == 0 or wbar.size == 0:
        # Penalty is identically zero, so the prox is the identity
        # (and 1 / l**2 below would divide by zero).
        return vbar.copy(), wbar.copy(), sorting_permutation_v, sorting_permutation_w

    signs_v = np.sign(vbar)
    signs_w = np.sign(wbar)
    vbar = np.abs(vbar)[sorting_permutation_v]
    wbar = np.abs(wbar)[sorting_permutation_w]
    n_v, n_w = len(vbar), len(wbar)

    t = 1 + 1 / l**2  # sparsity threshold: only supports with s_v * s_w <= t are tried
    f_best = f(0, 0, vbar, wbar, l)
    v_best = np.zeros(n_v)
    w_best = np.zeros(n_w)
    sum_w = 0
    w = np.zeros(n_w)
    s_v = s_w = 0
    for s_w in range(1, n_w + 1):
        sum_w += wbar[-s_w]
        s_v_max = min(n_v, int(t // s_w))  # must be an int for range()
        if s_v_max < 1:
            break  # s_v * s_w <= t is impossible for this and every larger s_w
        sum_v = 0
        v = np.zeros(n_v)
        stop_w_search = False
        for s_v in range(1, s_v_max + 1):
            sum_v += vbar[-s_v]
            denom = 1 - l**2 * s_v * s_w
            if denom == 0:
                continue  # singular linear system for this support pair; skip it
            # Closed-form stationary point on the current supports.
            v[-1] = (denom * vbar[-1] + l**2 * s_w * sum_v - l * sum_w) / denom
            v[-s_v:-1] = vbar[-s_v:-1] - vbar[-1] + v[-1]
            w[-s_w:] = wbar[-s_w:] - l * np.sum(v)
            # The smallest entry on each support must remain non-negative.
            if w[-s_w] >= 0 and v[-s_v] >= 0:
                fval = f(v, w, vbar, wbar, l)
                if fval < f_best:
                    # Copy: v and w are overwritten in place on later iterations.
                    v_best = v.copy()
                    w_best = w.copy()
                    f_best = fval
            else:
                # Heuristic: if even s_v == 1 is infeasible, stop growing s_w.
                stop_w_search = s_v == 1
                break
        if stop_w_search:
            break
        if verbose:
            print("Final s_v: {}".format(s_v))
    if verbose:
        print("Final s_w: {}".format(s_w))
        print("f_best: {}".format(f_best))

    # Candidate v = 0, w = |wbar| (objective 0.5*||vbar||^2).
    f_v_zero = f(0, wbar, vbar, wbar, l)
    if f_v_zero < f_best:
        v_best = np.zeros(n_v)
        w_best = wbar.copy()
        f_best = f_v_zero
    # Candidate v = |vbar|, w = 0 (objective 0.5*||wbar||^2).
    f_w_zero = f(vbar, 0, vbar, wbar, l)
    if f_w_zero < f_best:
        v_best = vbar.copy()
        w_best = np.zeros(n_w)
        f_best = f_w_zero
    if verbose:
        print("f_best: {}".format(f_best))

    # argsort of a permutation is its inverse: undo the sort, then restore signs.
    v_best = signs_v * v_best[np.argsort(sorting_permutation_v)]
    w_best = signs_w * w_best[np.argsort(sorting_permutation_w)]
    return v_best, w_best, sorting_permutation_v, sorting_permutation_w

In [None]:
def prox_Fabian_1d(vbar, wbar, l): # tensor version
    """Torch version of the coupled prox for one (v, w) pair of 1-D tensors.

    Searches for::

        argmin_{v, w}  0.5*||v - vbar||^2 + 0.5*||w - wbar||^2 + l*||v||_1*||w||_1

    It enumerates support sizes of the largest-magnitude entries (only pairs
    with s_v*s_w <= t) and solves the stationarity conditions in closed form.
    It keeps the best non-negative candidate, then compares against the two
    one-side-zero candidates. Signs are restored at the end.

    Args:
        vbar, wbar: 1-D tensors.
        l: penalty weight; ``l == 0`` (or an empty input) returns clones of the inputs.

    Returns:
        ``(v, w)`` with the shapes, dtypes and devices of ``vbar`` / ``wbar``.
    """
    def f(v, w, vbar, wbar, l):
        # sum(outer(v, w)) == sum(v) * sum(w) for 1-D tensors.
        return 0.5*torch.norm(v - vbar)**2 + 0.5*torch.norm(w - wbar)**2 + l*torch.sum(v)*torch.sum(w)

    if l == 0 or vbar.numel() == 0 or wbar.numel() == 0:
        return vbar.clone(), wbar.clone()

    signs_v = vbar.sign()
    signs_w = wbar.sign()
    sorting_permutation_v = torch.argsort(vbar.abs())
    sorting_permutation_w = torch.argsort(wbar.abs())
    vbar = vbar.abs()[sorting_permutation_v]
    wbar = wbar.abs()[sorting_permutation_w]
    n_v, n_w = len(vbar), len(wbar)

    t = 1 + 1 / l**2  # sparsity threshold: only supports with s_v * s_w <= t are tried
    zeros_v = torch.zeros_like(vbar)
    zeros_w = torch.zeros_like(wbar)
    f_best = f(zeros_v, zeros_w, vbar, wbar, l)
    v_best, w_best = zeros_v.clone(), zeros_w.clone()
    sum_w = 0
    w = torch.zeros_like(wbar)
    for s_w in range(1, n_w + 1):
        sum_w += wbar[-s_w]
        s_v_max = min(n_v, int(t // s_w))
        if s_v_max < 1:
            break  # no admissible s_v for this or any larger s_w
        sum_v = 0
        v = torch.zeros_like(vbar)
        stop_w_search = False
        for s_v in range(1, s_v_max + 1):
            sum_v += vbar[-s_v]
            denom = 1 - l**2 * s_v * s_w
            if denom == 0:
                continue  # singular linear system for this support pair
            # Closed-form stationary point on the current supports.
            v[-1] = (denom * vbar[-1] + l**2 * s_w * sum_v - l * sum_w) / denom
            v[-s_v:-1] = vbar[-s_v:-1] - vbar[-1] + v[-1]
            w[-s_w:] = wbar[-s_w:] - l * v.sum()
            if w[-s_w] >= 0 and v[-s_v] >= 0:
                fval = f(v, w, vbar, wbar, l)
                if fval < f_best:
                    # Clone: v and w are modified in place on later iterations.
                    v_best, w_best, f_best = v.clone(), w.clone(), fval
            else:
                # Heuristic: if even s_v == 1 is infeasible, stop growing s_w.
                stop_w_search = s_v == 1
                break
        if stop_w_search:
            break

    # Candidate v = 0, w = |wbar|.
    f_v_zero = f(zeros_v, wbar, vbar, wbar, l)
    if f_v_zero < f_best:
        v_best, w_best, f_best = zeros_v.clone(), wbar.clone(), f_v_zero
    # Candidate v = |vbar|, w = 0.
    f_w_zero = f(vbar, zeros_w, vbar, wbar, l)
    if f_w_zero < f_best:
        v_best, w_best, f_best = vbar.clone(), zeros_w.clone(), f_w_zero

    # argsort of the sorting permutation is its inverse: undo the sort, restore signs.
    v_best = signs_v * v_best[torch.argsort(sorting_permutation_v)]
    w_best = signs_w * w_best[torch.argsort(sorting_permutation_w)]
    return v_best, w_best

In [None]:
def prox_Fabian(vbar, wbar, l):
    """Apply ``prox_Fabian_numpy`` independently to every hidden unit.

    Hidden unit ``i`` couples column ``i`` of ``vbar`` with row ``i`` of
    ``wbar``. ``SGD_Fabian_pen`` passes fc2.weight and fc1.weight, so
    ``vbar.shape[1]`` is expected to equal ``wbar.shape[0]``.

    Args:
        vbar, wbar: 2-D tensors (any device).
        l: penalty weight forwarded to ``prox_Fabian_numpy``.

    Returns:
        ``(v, w)``: tensors with the shape, dtype and device of ``vbar`` / ``wbar``.
    """
    v = torch.zeros_like(vbar)
    w = torch.zeros_like(wbar)
    for i in range(vbar.shape[1]):
        # NumPy needs detached CPU data. The helper returns 4 values; keep the
        # (v, w) pair and write it back into column / row i.
        v_i, w_i, _, _ = prox_Fabian_numpy(vbar[:, i].detach().cpu().numpy(),
                                           wbar[i, :].detach().cpu().numpy(), l)
        v[:, i] = torch.as_tensor(v_i, dtype=v.dtype, device=v.device)
        w[i, :] = torch.as_tensor(w_i, dtype=w.dtype, device=w.device)
    return v, w

In [None]:
from torch.optim import Optimizer


class SGD_L1(Optimizer):
    r"""Proximal SGD with an L1 penalty (ISTA-style update).

    For every parameter ``p`` that has a gradient, one step computes::

        p <- prox_L1(p - lr * grad, lr * lam)

    That is, a plain gradient step followed by soft-thresholding, which sets
    small weights exactly to zero.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        lam (float, optional): L1 regularization strength (default: 0)

    Example:
        >>> optimizer = SGD_L1(model.parameters(), lr=0.1, lam=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self, params, lr, lam=0):
        # Kept as attributes for backward compatibility; step() reads the
        # per-group values so LR schedulers and per-group options take effect.
        self.lr = lr
        self.lam = lam

        defaults = dict(lr=lr, lam=lam)
        super(SGD_L1, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            # The closure calls backward(), so gradients must be enabled for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, lam = group['lr'], group['lam']
            for p in group['params']:
                if p.grad is None:
                    continue
                p.copy_(prox_L1(p - lr * p.grad, lr * lam))

        return loss


In [None]:
class SGD_Fabian_pen(Optimizer):
    r"""SGD followed by the coupled ("Fabian") proximal step on one weight pair.

    Each step does a plain gradient step ``p <- p - lr * grad`` on every
    parameter with a gradient. It then applies ``prox_Fabian`` with strength
    ``lr * lam`` to parameters 2 and 0 of the first parameter group, which are
    fc2.weight and fc1.weight for ``Net``'s registration order.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups (the first group must contain at least 3 params)
        lr (float): learning rate
        lam (float, optional): regularization strength (default: 0)

    Example:
        >>> optimizer = SGD_Fabian_pen(model.parameters(), lr=0.1, lam=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self, params, lr, lam=0):
        # Kept as attributes for backward compatibility; step() reads group values.
        self.lr = lr
        self.lam = lam

        defaults = dict(lr=lr, lam=lam)
        super(SGD_Fabian_pen, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # 1) Gradient step on every parameter.
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-lr)

        # 2) Proximal step, once, on this optimizer's own (v, w) pair; the
        #    previous code referenced a global `optimizer` here.
        group = self.param_groups[0]
        v_param, w_param = group['params'][2], group['params'][0]
        v, w = prox_Fabian(v_param.detach(), w_param.detach(), group['lr'] * group['lam'])
        v_param.copy_(v)
        w_param.copy_(w)

        return loss



In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch with NLL loss, logging every ``log_interval`` batches.

    ``model`` should return log-probabilities (as ``Net`` does). ``log_interval``
    is read from the notebook's global configuration.
    """
    model.train()
    for step_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if step_idx % log_interval != 0:
            continue
        seen = step_idx * len(inputs)
        pct = 100. * step_idx / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), pct, loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
# Control Seed
torch.manual_seed(seed)

# Select Device
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else 'cpu')
if use_cuda:
    print("Using CUDA!")
    torch.cuda.manual_seed(seed)
else:
    print('Not using CUDA!!!')

# Loader
kwargs = {'num_workers': 5, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=False, **kwargs)


# Define which model to use
model = Net().to(device)

print(model)

In [None]:
# Baseline: plain SGD without a sparsity penalty. Swap in one of the
# commented optimizers to compare against the proximal variants.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)
#optimizer = SGD_L1(model.parameters(), lr=lr, lam=lam)
#optimizer = SGD_Fabian_pen(model.parameters(), lr=lr, lam=lam)

for epoch in range(1, epochs + 1):
    # Sparsity report before each epoch. NOTE(review): the "v" line counts
    # fc1.bias, not the fc2 weights that SGD_Fabian_pen treats as v.
    print("Number of non-zero elements in W1: {}".format(int(torch.count_nonzero(model.fc1.weight.data))))
    print("Number of non-zero elements in v: {}".format(int(torch.count_nonzero(model.fc1.bias.data))))
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

In [None]:
# Inspect the coupled prox on one hidden unit: column `i` of fc2.weight
# (parameter 2, "v") paired with row `i` of fc1.weight (parameter 0, "w").
i = 2
lam_inspect = 0.2  # separate name so this check does not overwrite the global `lam`
v, w, p_v, p_w = prox_Fabian_numpy(
    optimizer.param_groups[0]['params'][2].detach()[:, i].cpu().numpy(),  # .cpu(): works on CUDA too
    optimizer.param_groups[0]['params'][0].detach()[i, :].cpu().numpy(),
    lam_inspect * lr)
print(v)
print(w)

In [None]:
# Proximal L1 training: soft-thresholding after every gradient step.
lam = 0.001
model = Net().to(device)
#optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
optimizer = SGD_L1(model.parameters(), lr=lr, lam=lam)

for epoch in range(1, epochs + 1):
    # Sparsity report before each epoch. NOTE(review): "v" here is fc1.bias.
    print("Number of non-zero elements in W1: {}".format(int(torch.count_nonzero(model.fc1.weight.data))))
    print("Number of non-zero elements in v: {}".format(int(torch.count_nonzero(model.fc1.bias.data))))
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

In [None]:
# For every hidden unit whose fc1 bias is exactly zero, print the L1 norm of
# its incoming weights (row of fc1.weight) if that norm is not ~0.
fc1_weight = model.fc1.weight.data
fc1_bias = model.fc1.bias.data
for i in range(fc1_weight.shape[0]):  # number of hidden units, instead of a hard-coded 300
    if fc1_bias[i] == 0:  # replaces the O(n*k) `i not in bias.nonzero()` membership test
        row_l1 = float(fc1_weight[i, :].abs().sum())
        if row_l1 > 1e-8:
            print(row_l1)
            