In [1]:
import torch
import math

# fix the random seed
torch.manual_seed(55272025)

dtype = torch.float
w = 8

## Rosenbrock

In [2]:
# Rosenbrock with equality constraint
def f(x1, x2):
    return w * torch.abs(x1 ** 2 - x2) + (1 - x1) ** 2

def penalty(x1, x2):
    return torch.abs(math.sqrt(2) * x1 - 1) + torch.abs(2 * x2 - 1)

def phi1(x1, x2, mu):
    return f(x1, x2) + mu * penalty(x1, x2)

In [3]:
def backtracking_line_search(t_rho, c, x_eps, mu_rho, mu_eps):
    x1 = torch.randn(1, 1, dtype=dtype, requires_grad=True)
    x2 = torch.randn(1, 1, dtype=dtype, requires_grad=True)
    mu = torch.tensor([1.], dtype=dtype)

    for iteration in range(1000):
        print("Iter", iteration)

        # Find an approximate minimizer xk of phi1(x; µk), starting at xk^s
        for epoch in range(100):
            phi1_x_mu = phi1(x1, x2, mu)
            phi1_x_mu.backward()

            # if epoch % 100 == 99:
            print("Epoch", epoch)
            print("Objective + penalty:", phi1_x_mu.item())
            print("Vars:", x1, x2, mu)
            print("Grads:", torch.norm(x1.grad), torch.norm(x2.grad))

            with torch.no_grad():
                if torch.norm(x1.grad) ** 2 + torch.norm(x2.grad) ** 2 <= x_eps:
                    break

                t = 1
                while t >= 1e-8 and phi1(x1 - t * x1.grad, x2 - t * x2.grad, mu) - phi1_x_mu >= -c * t * (torch.norm(x1.grad) ** 2 + torch.norm(x2.grad) ** 2):
                    t *= t_rho
                print("t:", t)
                if t < 1e-8:
                    break

                x1 -= t * x1.grad
                x1.grad = None
                x2 -= t * x2.grad
                x2.grad = None

        h = penalty(x1, x2)
        print("Penalty", h)
        if h < 1e-5:  # if h(xk ) ≤ τ
            break

        # Choose new penalty parameter µk+1 > µk ;
        if mu * h > mu_eps:
            mu *= mu_rho

        # Choose new starting point (stay as optimal x1, x2)

        print()

    return x1, x2

In [4]:
backtracking_line_search(t_rho=0.8, c=0, x_eps=1e-5, mu_rho=1.1, mu_eps=1e-5)

Iter 0
Epoch 0
Objective + penalty: 9.709888458251953
Vars: tensor([[-1.0277]], requires_grad=True) tensor([[0.7176]], requires_grad=True) tensor([1.])
Grads: tensor(21.9136) tensor(6.)
t: 0.10737418240000006
Epoch 1
Objective + penalty: 5.858539581298828
Vars: tensor([[1.3252]], requires_grad=True) tensor([[1.3618]], requires_grad=True) tensor([1.])
Grads: tensor(23.2680) tensor(6.)
t: 0.011529215046068483
Epoch 2
Objective + penalty: 4.8707075119018555
Vars: tensor([[1.0570]], requires_grad=True) tensor([[1.4310]], requires_grad=True) tensor([1.])
Grads: tensor(15.3831) tensor(10.)
t: 0.011529215046068483
Epoch 3
Objective + penalty: 4.094383239746094
Vars: tensor([[1.2343]], requires_grad=True) tensor([[1.3157]], requires_grad=True) tensor([1.])
Grads: tensor(21.6317) tensor(6.)
t: 0.00737869762948383
Epoch 4
Objective + penalty: 3.8853907585144043
Vars: tensor([[1.0747]], requires_grad=True) tensor([[1.3600]], requires_grad=True) tensor([1.])
Grads: tensor(15.6315) tensor(10.)
t: 0

(tensor([[0.7071]], requires_grad=True),
 tensor([[0.5000]], requires_grad=True))

## Sphere Manifold

In [5]:
device = torch.device('cpu')
torch.manual_seed(0)
n = 300
# All the user-provided data (vector/matrix/tensor) must be in torch tensor format.
# As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
dtype = torch.double
A = torch.randn((n,n)).to(device=device, dtype=dtype)
A = .5*(A+A.T)

In [6]:
def f(x):
    return -x.T @ A @ x

def penalty(x):
    return torch.abs(x.T @ x - 1)

def phi1(x, mu):
    return f(x) + mu * penalty(x)

In [7]:
def backtracking_line_search(t_rho, c, x_eps, mu_rho, mu_eps):
    x = torch.randn(n, 1, dtype=dtype, requires_grad=True)
    with torch.no_grad():
        x /= torch.norm(x)
    mu = torch.tensor([100.], dtype=dtype)

    for iteration in range(1000):
        print("Iter", iteration)

        # Find an approximate minimizer xk of phi1(x; µk), starting at xk^s
        for epoch in range(1000):
            phi1_x_mu = phi1(x, mu)
            phi1_x_mu.backward()

            # if epoch % 100 == 99:
            # print("Epoch", epoch)
            # print("Objective + penalty:", phi1_x_mu.item())
            # print("Vars:", x, mu)
            # print("Grads:", torch.norm(x.grad))

            with torch.no_grad():
                if torch.norm(x.grad) ** 2 <= x_eps:
                    break

                t = 1
                while t >= 1e-8 and phi1(x - t * x.grad, mu) - phi1_x_mu >= -c * t * torch.norm(x.grad) ** 2:
                    t *= t_rho
                print("t:", t, torch.norm(x.grad))
                # if t < 1e-8:
                #     break

                x -= t * x.grad
                x.grad = None

        h = penalty(x)
        print("Objective", phi1_x_mu)
        print("Penalty", h)
        if h < 1e-5:  # if h(xk ) ≤ τ
            break

        # Choose new penalty parameter µk+1 > µk ;
        if mu * h > mu_eps:
            mu *= mu_rho

        # Choose new starting point (stay as optimal x1, x2)

        print()

    return x

In [8]:
x = backtracking_line_search(t_rho=0.8, c=0, x_eps=1e-5, mu_rho=1.1, mu_eps=1e-5)

Iter 0
t: 9.046256971665371e-09 tensor(200.7635, dtype=torch.float64)
t: 1.7668470647783922e-08 tensor(202.1281, dtype=torch.float64)
t: 1.7668470647783922e-08 tensor(200.7638, dtype=torch.float64)
t: 1.7668470647783922e-08 tensor(202.1281, dtype=torch.float64)
t: 1.7668470647783922e-08 tensor(200.7638, dtype=torch.float64)
t: 1.7668470647783922e-08 tensor(202.1282, dtype=torch.float64)
t: 1.7668470647783922e-08 tensor(200.7637, dtype=torch.float64)
t: 1.4134776518227139e-08 tensor(202.1282, dtype=torch.float64)
t: 9.046256971665371e-09 tensor(200.7635, dtype=torch.float64)
t: 9.046256971665371e-09 tensor(202.1284, dtype=torch.float64)
t: 1.1307821214581712e-08 tensor(200.7635, dtype=torch.float64)
t: 1.1307821214581712e-08 tensor(202.1284, dtype=torch.float64)
t: 1.1307821214581712e-08 tensor(200.7635, dtype=torch.float64)
t: 1.1307821214581712e-08 tensor(202.1284, dtype=torch.float64)
t: 1.1307821214581712e-08 tensor(200.7635, dtype=torch.float64)
t: 9.046256971665371e-09 tensor(202.

In [9]:
f(x)

tensor([[-0.3491]], dtype=torch.float64, grad_fn=<MmBackward0>)

## Sphere Manifold with Exact Penalty and PyGranSO

In [10]:
import time
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct

In [11]:
def exact_penalty_with_pygranso(mu_rho, mu_eps):
    x = torch.randn(n, 1, dtype=dtype)
    # with torch.no_grad():
    #     x /= torch.norm(x)
    mu = torch.tensor([10.], dtype=dtype)

    for iteration in range(1000):
        print("Iter", iteration)
        
        # PyGRANSO
        device = torch.device('cpu')
        
        # variables and corresponding dimensions.
        var_in = {"x": [n, 1]}
        
        def comb_fn(X_struct):
            x = X_struct.x
            
            # objective function
            phi1_x_mu = phi1(x, mu)
        
            # inequality constraint, matrix form
            ci = None
        
            # equality constraint 
            ce = None
        
            return [phi1_x_mu,ci,ce]
        
        opts = pygransoStruct()
        # option for switching QP solver. We only have osqp as the only qp solver in current version. Default is osqp
        # opts.QPsolver = 'osqp'
        
        # set an intial point
        # All the user-provided data (vector/matrix/tensor) must be in torch tensor format. 
        # As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
        # Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
        opts.x0 = x
        opts.torch_device = device
        opts.print_level = 0
        
        start = time.time()
        soln = pygranso(var_spec = var_in,combined_fn = comb_fn, user_opts = opts)
        end = time.time()
        print("Total Wall Time: {}s".format(end - start))
        x = soln.final.x
        
        # Exact penalty update
        
        h = penalty(x)
        print("Objective", f(x))
        print("Penalty", h)
        if h < 1e-5:  # if h(xk ) ≤ τ
            break

        # Choose new penalty parameter µk+1 > µk ;
        if mu * h > mu_eps:
            mu *= mu_rho

        # Choose new starting point (stay as optimal x1, x2)

        print()

    return x

In [12]:
x = exact_penalty_with_pygranso(mu_rho=1.1, mu_eps=1e-5)

Iter 0
Total Wall Time: 0.02765798568725586s
Objective tensor([[-3854.7148]], dtype=torch.float64)
Penalty tensor([[251.5613]], dtype=torch.float64)

Iter 1
Total Wall Time: 0.006993770599365234s
Objective tensor([[-3854.7148]], dtype=torch.float64)
Penalty tensor([[251.5613]], dtype=torch.float64)

Iter 2
Total Wall Time: 0.00447392463684082s
Objective tensor([[-3854.7148]], dtype=torch.float64)
Penalty tensor([[251.5613]], dtype=torch.float64)

Iter 3
Total Wall Time: 0.012392997741699219s
Objective tensor([[-44344.1748]], dtype=torch.float64)
Penalty tensor([[2970.1424]], dtype=torch.float64)

Iter 4
Total Wall Time: 0.04290604591369629s
Objective tensor([[-239787.0472]], dtype=torch.float64)
Penalty tensor([[11985.8505]], dtype=torch.float64)

Iter 5
Total Wall Time: 0.013826131820678711s
Objective tensor([[-427883.4310]], dtype=torch.float64)
Penalty tensor([[20964.0460]], dtype=torch.float64)

Iter 6
Total Wall Time: 0.028512954711914062s
Objective tensor([[-877667.0923]], dtype=

An interesting note: starting with `mu = 1` instead of `mu = 100` yields:
```
Iter 33
Total Wall Time: 0.006292819976806641s
Objective tensor([[-1.1599e+50]], dtype=torch.float64)
Penalty tensor([[4.9290e+48]], dtype=torch.float64)

Iter 34
Total Wall Time: 0.9548580646514893s
Objective tensor([[-24.2859]], dtype=torch.float64)
Penalty tensor([[4.8850e-14]], dtype=torch.float64)
```
which results in large objectives before the mu value is sufficient.

It also seems that starting with large `mu` may lead to a slightly worse objective (e.g. `mu=100` vs `mu=10000`). This may be because PyGRANSO is better at finding global optima to `f` if the penalty term contributes too much. See https://www.youtube.com/watch?v=7Dvz2HbSyM8 and the textbook about ill conditioning.

## Orthogonal RNN

In [13]:
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL

In [19]:
device = torch.device('cpu')

torch.set_default_device(device)
torch.set_default_dtype(dtype)

sequence_length = 28
input_size = 28
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 100


double_precision = torch.double

class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        pass

    def forward(self, x):
        x = torch.reshape(x,(batch_size,sequence_length,input_size))
        # Set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)
        out, hidden = self.rnn(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        #Reshaping the outputs such that it can be fit into the fully connected layer
        out = self.fc(out[:, -1, :])
        return out

torch.manual_seed(0)

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

train_data, test_data = torch.utils.data.random_split(datasets.MNIST(
    root = './examples/data/mnist',
    train = True,
    transform = ToTensor(),
    download = True,
), [0.6, 0.4])

loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                        batch_size=100,
                                        shuffle=True,
                                        num_workers=1),
    'test' : torch.utils.data.DataLoader(test_data,
                                        batch_size=100,
                                        shuffle=True,
                                        num_workers=1)
}

inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)

In [20]:
def f(model, inputs, labels):
    logits = model(inputs)
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn(logits, labels)

def penalty(model):
    A = list(model.parameters())[1]
    
    # print(A)
    
    return torch.norm(A.T @ A - torch.eye(hidden_size), p=1)

def phi1(model, mu):
    return f(model, inputs, labels) + mu * penalty(model)

In [21]:
def exact_penalty_with_pygranso(mu_rho, mu_eps):
    mu = torch.tensor([1.], dtype=dtype)

    for iteration in range(1000):
        print("Iter", iteration)
        
        # PyGRANSO
        device = torch.device('cpu')
        
        def comb_fn(model):
            # objective function
            phi1_x_mu = phi1(model, mu)
        
            # inequality constraint, matrix form
            ci = None
        
            # equality constraint 
            ce = None
        
            return [phi1_x_mu,ci,ce]
        
        opts = pygransoStruct()
        # option for switching QP solver. We only have osqp as the only qp solver in current version. Default is osqp
        # opts.QPsolver = 'osqp'
        
        # set an intial point
        # All the user-provided data (vector/matrix/tensor) must be in torch tensor format. 
        # As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
        # Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
        nvar = getNvarTorch(model.parameters())
        opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
        opts.torch_device = device
        opts.opt_tol = 1e-5
        opts.viol_eq_tol = 1e-5
        opts.print_level = 1
        opts.print_frequency = 50
        
        # opts.maxit = 1000 yields 80% acc but seems far from reaching stationarity
        #  900 ║  - │   -   ║  0.76603445858 ║   -  │   -  ║ QN │     6 │ 0.031250 ║     1 │ 0.246480   ║ 
        #  950 ║  - │   -   ║  0.70312720877 ║   -  │   -  ║ QN │     6 │ 0.031250 ║     1 │ 0.471985   ║ 
        # ═════╬════════════╬════════════════╬═════════════╬═══════════════════════╬════════════════════╣
        #      ║ Penalty Fn ║                ║  Violation  ║ <--- Line Search ---> ║ <- Stationarity -> ║ 
        # Iter ║ Mu │ Value ║    Objective   ║ Ineq │  Eq  ║ SD │ Evals │     t    ║ Grads │    Value   ║ 
        # ═════╬════════════╬════════════════╬═════════════╬═══════════════════════╬════════════════════╣
        # 1000 ║  - │   -   ║  0.62800096715 ║   -  │   -  ║ QN │     5 │ 0.062500 ║     1 │ 0.430504   ║ 
        opts.maxit = 3000
        
        start = time.time()
        soln = pygranso(var_spec = model,combined_fn = comb_fn, user_opts = opts)
        end = time.time()
        print("Total Wall Time: {}s".format(end - start))
        torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
        
        # Exact penalty update
        
        h = penalty(model)
        print("Objective:", f(model, inputs, labels))
        print("Penalty:", h)
        if h < 1e-3:  # if h(xk ) ≤ τ
            break

        # Choose new penalty parameter µk+1 > µk ;
        if mu * h > mu_eps:
            mu *= mu_rho

        # Choose new starting point (stay as optimal x1, x2)

        print()

    return model

In [22]:
model = exact_penalty_with_pygranso(mu_rho=1.1, mu_eps=1e-5)

Iter 0


[33m╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
[0m[33m║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
[0m[33m║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
[0m[33m║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
[0m[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
[0m══════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                          ║ 
Version 1.2.0                                                                                 ║ 
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang               ║ 
═════════════════════════════════════════════════════════════════════

In [23]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))

Final acc = 100.00%


In [25]:
model.eval()

inputs, labels = next(iter(loaders['test']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)

logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))

Final acc = 25.00%


## Orthogonal RNN with PyGRANSO

In [27]:
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

RNN(
  (rnn): RNN(28, 30, batch_first=True)
  (fc): Linear(in_features=30, out_features=10, bias=True)
)

In [28]:
def user_fn(model,inputs,labels):
    # objective function    
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)

    A = list(model.parameters())[1]

    # inequality constraint
    ci = None

    # equality constraint 
    # special orthogonal group
    
    ce = pygransoStruct()

    c1_vec = (A.T @ A 
              - torch.eye(hidden_size)
              .to(device=device, dtype=double_precision)
             ).reshape(1,-1)
    
    ce.c1 = torch.linalg.vector_norm(c1_vec,2) # l2 folding to reduce the total number of constraints
    # ce.c2 = torch.det(A) - 1

    # ce = None

    return [f,ci,ce]

comb_fn = lambda model : user_fn(model,inputs,labels)

In [29]:
opts = pygransoStruct()
opts.torch_device = device
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
opts.opt_tol = 1e-3
opts.viol_eq_tol = 1e-4
# opts.maxit = 150
# opts.fvalquit = 1e-6
opts.print_level = 1
opts.print_frequency = 50
# opts.print_ascii = True
# opts.limited_mem_size = 100
opts.double_precision = True

opts.mu0 = 1

In [30]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))  

Initial acc = 6.00%


In [31]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))



[33m╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
[0m[33m║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
[0m[33m║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
[0m[33m║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
[0m[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
[0m═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║ 
Version 1.2.0                                                                                                    ║ 
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║ 


In [32]:
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))  
print("final feasibility = {}".format(soln.final.tve))

Final acc = 100.00%
final feasibility = 4.1614077503424446e-05


In [33]:
model.eval()

inputs, labels = next(iter(loaders['test']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)

logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))

Final acc = 43.00%
