In [2]:
import time
import torch
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('./PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL

In [12]:
# Runtime configuration: the model and data are placed on the CPU in float64.
device = torch.device('cpu')
double_precision = torch.double

# Each 28x28 MNIST image is fed to the RNN as 28 time steps of 28 features.
sequence_length, input_size = 28, 28

# Model and data-loading sizes.
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 100

class RNN(nn.Module):
    """Recurrent classifier: an ``nn.RNN`` followed by a linear layer on the last time step.

    Args:
        input_size: number of features per time step.
        hidden_size: dimension of the recurrent hidden state.
        num_layers: number of stacked ``nn.RNN`` layers.
        num_classes: number of output logits produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is reshaped to ``(batch, seq_len, input_size)``: ``batch`` is taken from
        ``x.size(0)`` and ``seq_len`` is inferred, so the batch size is no longer tied to
        a global constant (e.g. a ``(B, 1, 28, 28)`` MNIST batch and an already reshaped
        ``(B, 28, 28)`` batch both work for any ``B``).
        """
        batch = x.size(0)
        x = x.reshape(batch, -1, self.input_size)
        # Zero initial hidden state, created on the input's device and in its dtype
        h0 = torch.zeros(self.num_layers, batch, self.hidden_size, device=x.device, dtype=x.dtype)
        out, _ = self.rnn(x, h0)  # out: (batch, seq_len, hidden_size)
        # Classify from the output of the last time step only.
        return self.fc(out[:, -1, :])

# Seed before building the model and the shuffled DataLoader so the initial weights and
# (via torch's default generator) the sampled batch are reproducible.
torch.manual_seed(0)

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

train_data = datasets.MNIST(
    root = './data/mnist',
    train = True,
    transform = ToTensor(),
    download = True,
)

loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1),
}

# Only this single mini-batch is used as the (fixed) training set for PyGRANSO.
inputs, labels = next(iter(loaders['train']))
inputs = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision)
labels = labels.to(device=device)

In [11]:
def user_fn(model, inputs, labels, mu=10):
    """Combined objective/constraint function in the ``[f, ci, ce]`` form passed to PyGRANSO.

    The objective is the cross-entropy loss of ``model`` on ``(inputs, labels)`` plus an
    L1 penalty pushing the RNN's hidden-to-hidden weight matrix ``A`` towards orthogonality:

        f = CE + mu * ||A^T A - I||_1      (entry-wise L1 norm)

    Orthogonality is handled inside the objective rather than as a PyGRANSO equality
    constraint, so no constraints are returned.

    Args:
        model: network exposing ``model.rnn.weight_hh_l0`` (such as the ``RNN`` class).
        inputs: batch fed to ``model``.
        labels: integer class targets for ``inputs``.
        mu: weight of the orthogonality penalty (default 10, the previous hard-coded value).

    Returns:
        ``[f, ci, ce]`` with ``ci`` and ``ce`` both ``None``.
    """
    logits = model(inputs)
    f = nn.CrossEntropyLoss()(logits, labels)

    # Hidden-to-hidden weight of the first RNN layer, accessed by name rather than by
    # position in model.parameters() (index 1 for this architecture).
    A = model.rnn.weight_hh_l0

    # Deviation from orthogonality, A^T A - I, built on A's own device and dtype.
    identity = torch.eye(A.shape[1], device=A.device, dtype=A.dtype)
    constraint_violation = torch.linalg.vector_norm(A.T @ A - identity, 1)
    # TODO: normalize the violation by dimension; mu is fixed rather than adapted.
    # TODO: det(A) = 1 (special orthogonal group) is not enforced.
    # Out-of-place add so the loss tensor produced by autograd is not modified in place.
    f = f + mu * constraint_violation

    ci = None  # no inequality constraints
    ce = None  # no equality constraints (orthogonality is penalized in f instead)
    return [f, ci, ce]

def comb_fn(model):
    """Evaluate ``user_fn`` on the fixed training batch; PyGRANSO calls this with the model."""
    return user_fn(model, inputs, labels)

In [13]:
# PyGRANSO options: computations on `device`.
opts = pygransoStruct()
opts.torch_device = device

# Starting point x0: every model parameter flattened into one (nvar, 1) column.
nvar = getNvarTorch(model.parameters())
flat_params = torch.nn.utils.parameters_to_vector(model.parameters())
opts.x0 = flat_params.detach().reshape(nvar, 1)

# Termination tolerances.
opts.opt_tol = 1e-3
opts.viol_eq_tol = 1e-4

# Logging: print level 1, reporting every 50 iterations.
# Options tried earlier and now left at PyGRANSO defaults:
# maxit, fvalquit, print_ascii, limited_mem_size.
opts.print_level = 1
opts.print_frequency = 50

opts.double_precision = True

# Presumably PyGRANSO's initial penalty parameter (mu0) — confirm against the PyGRANSO docs.
opts.mu0 = 1

In [14]:
# Accuracy of the untrained network on the training batch (no autograd graph needed).
with torch.no_grad():
    logits = model(inputs)
predicted = logits.argmax(dim=1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format(100 * correct / len(inputs)))

Initial acc = 10.00%


In [15]:
# Run PyGRANSO on the model; elapsed time measured with a monotonic, high-resolution clock.
start = time.perf_counter()
soln = pygranso(var_spec=model, combined_fn=comb_fn, user_opts=opts)
end = time.perf_counter()
print("Total Wall Time: {}s".format(end - start))



[33m╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
[0m[33m║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
[0m[33m║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
[0m[33m║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
[0m[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
[0m══════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                          ║ 
Version 1.2.0                                                                                 ║ 
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang               ║ 
════════════════════════════════════════════════════════════════════════════

In [16]:
# Accuracy on the same training batch after the PyGRANSO run (no autograd graph needed).
with torch.no_grad():
    logits = model(inputs)
predicted = logits.argmax(dim=1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format(100 * correct / len(inputs)))

Finl acc = 79.00%
