# Orthogonal RNN

Train an orthogonal RNN for MNIST classification, based on [this paper](https://arxiv.org/pdf/1901.08428.pdf).

NOTE: this example is still under development. 

## Problem Description

For each element in the input sequence, each layer computes the following function:
$$h_t=\tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})$$

where $h_{t}$ is the hidden state at time $t$, $x_t$ is the input at time $t$, and $h_{t-1}$ is the hidden state of the previous layer at time $t-1$ or the initial hidden state at time $0$.
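
To make the recurrence concrete, here is a minimal sketch in plain Python (no PyTorch needed) that unrolls one layer over a short sequence. The helper names `matvec` and `rnn_forward` and the toy weights are assumptions chosen for illustration; they are not part of PyGRANSO or of the model trained below.

```python
import math

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_forward(xs, W_ih, b_ih, W_hh, b_hh, h0):
    """Apply h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh) for each x_t."""
    h = h0
    states = []
    for x in xs:
        a = matvec(W_ih, x)   # input contribution
        r = matvec(W_hh, h)   # recurrent contribution
        h = [math.tanh(ai + bi + ri + bh)
             for ai, bi, ri, bh in zip(a, b_ih, r, b_hh)]
        states.append(h)
    return states

# Toy sizes (hypothetical): input_size = 2, hidden_size = 2, sequence_length = 3.
W_ih = [[0.5, -0.5], [0.25, 0.75]]
W_hh = [[0.0, -1.0], [1.0, 0.0]]   # a 90-degree rotation, hence orthogonal
b_ih = [0.0, 0.0]
b_hh = [0.0, 0.0]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

states = rnn_forward(xs, W_ih, b_ih, W_hh, b_hh, h0=[0.0, 0.0])
```

`states[-1]` plays the role of the last hidden state, which is what the classifier head consumes in the model defined later.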

For each layer, we impose an orthogonality constraint on the hidden-to-hidden weight matrix:
$$ W_{hh}^T W_{hh} = I $$
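
A matrix satisfying this constraint preserves Euclidean norms, which is one commonly cited motivation for orthogonal recurrent weights. The sketch below, in plain Python with hypothetical helper names (`transpose`, `matmul`, `orth_residual`), measures how far a matrix is from orthogonal and checks norm preservation on a small rotation matrix.

```python
import math

def transpose(W):
    return [list(col) for col in zip(*W)]

def matmul(A, B):
    Bt = transpose(B)
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]

def orth_residual(W):
    """Frobenius norm of W^T W - I; it is zero exactly when W is orthogonal."""
    n = len(W)
    G = matmul(transpose(W), W)
    return math.sqrt(sum((G[i][j] - (1.0 if i == j else 0.0)) ** 2
                         for i in range(n) for j in range(n)))

theta = 0.3
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta),  math.cos(theta)]]
scaled = [[2.0, 0.0], [0.0, 1.0]]

res_rotation = orth_residual(rotation)   # zero up to rounding error
res_scaled = orth_residual(scaled)       # 3.0, since W^T W - I = diag(3, 0)

# An orthogonal matrix leaves the length of a vector unchanged.
v = [3.0, 4.0]
rotated = [sum(w * x for w, x in zip(row, v)) for row in rotation]
norm_before = math.hypot(*v)
norm_after = math.hypot(*rotated)
```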

## Modules Importing
Import all necessary modules and add the PyGRANSO src folder to the system path.

In [3]:
import time
import torch
import sys
## Add the PyGRANSO directory to the path; adjust this to your local checkout
sys.path.append('.')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct 
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL

In [4]:
import torch

# fix the random seed
torch.manual_seed(55272025)

w = 8

## Data Initialization 
Specify the torch device, define the neural network architecture, and load the data.

NOTE: please specify the path for downloading the data.

A GPU is recommended for this problem: set *device = torch.device('cuda')* if a CUDA device is available. The cell below uses *device = torch.device('cpu')*.

In [5]:
device = torch.device('cpu')

sequence_length = 28
input_size = 28
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 1024


double_precision = torch.double

class RNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        pass
    
    def forward(self, x):
        # Treat each image row as one time step: (batch, seq_length, input_size).
        # Using -1 keeps this valid for batches smaller than batch_size.
        x = torch.reshape(x, (-1, sequence_length, input_size))
        # Set the initial hidden state (a plain RNN has no cell state)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)
        out, hidden = self.rnn(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Feed the hidden state of the last time step into the fully connected layer
        out = self.fc(out[:, -1, :])
        return out
    
torch.manual_seed(0)

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

train_data = datasets.MNIST(
    root = './examples/data/mnist',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
) 
test_data = datasets.MNIST(
    root = './examples/data/mnist',
    train = False,                         
    transform = ToTensor(), 
    download = True,            
) 

loaders = {
    'train' : torch.utils.data.DataLoader(train_data, 
                                        batch_size=batch_size, 
                                        shuffle=True, 
                                        num_workers=1),
    'test' : torch.utils.data.DataLoader(test_data, 
                                        batch_size=batch_size, 
                                        shuffle=True, 
                                        num_workers=1),
}

inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)
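
The reshape above turns each 28x28 image into a sequence of 28 time steps with 28 features each, so the RNN reads an image one row at a time. A minimal plain-Python sketch of the same row-major split, with a hypothetical helper `image_to_sequence` working on a flat list of pixels:

```python
def image_to_sequence(flat_pixels, sequence_length=28, input_size=28):
    """Split a flat list of pixels into `sequence_length` rows of `input_size` values."""
    assert len(flat_pixels) == sequence_length * input_size
    return [flat_pixels[t * input_size:(t + 1) * input_size]
            for t in range(sequence_length)]

# Use the pixel index as the pixel value so the layout is easy to inspect.
flat = list(range(28 * 28))
seq = image_to_sequence(flat)
```

Here `seq[t]` is row `t` of the image, which corresponds to $x_t$ in the recurrence.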

## Function Set-Up

Encode the optimization variables, and objective and constraint functions.

Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.

In [6]:
def user_fn(model,inputs,labels):
    # objective function    
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)

    A = list(model.parameters())[1]  # weight_hh_l0: the hidden-to-hidden matrix W_hh

    # inequality constraint
    ci = None

    # equality constraint: W_hh^T W_hh = I
    # (orthogonal group only; the det(A) = 1 condition for the special orthogonal group is disabled below)
    
    ce = pygransoStruct()

    c1_vec = (A.T @ A 
              - torch.eye(hidden_size)
              .to(device=device, dtype=double_precision)
             ).reshape(1,-1)
    
    ce.c1 = torch.linalg.vector_norm(c1_vec,2) # l2 folding to reduce the total number of constraints
    # ce.c2 = torch.det(A) - 1

    # ce = None

    return [f,ci,ce]

comb_fn = lambda model : user_fn(model,inputs,labels)
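
The "l2 folding" used for `ce.c1` replaces the `hidden_size * hidden_size` entrywise equations of $W_{hh}^T W_{hh} = I$ with a single scalar equation: the 2-norm of the residual vector is zero exactly when every entry is zero. A minimal plain-Python sketch of the idea, where `fold_l2` is a hypothetical helper and not a PyGRANSO function:

```python
import math

def fold_l2(residuals):
    """Collapse many equality residuals c_i = 0 into one scalar ||c||_2 = 0."""
    return math.sqrt(sum(c * c for c in residuals))

hidden_size = 30
n_unfolded = hidden_size * hidden_size   # one scalar equation per matrix entry
n_folded = 1                             # a single scalar equation after folding

feasible = [0.0] * n_unfolded
one_violation = [0.0] * (n_unfolded - 1) + [1e-3]

folded_feasible = fold_l2(feasible)        # exactly 0.0
folded_violated = fold_l2(one_violation)   # about 1e-3: a single bad entry is still detected
```

The folded constraint is not differentiable where the residual is zero, so this trick relies on a solver that tolerates nonsmooth constraints.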

## User Options
Specify user-defined options for PyGRANSO 

In [7]:
opts = pygransoStruct()
opts.torch_device = device
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
opts.opt_tol = 1e-3
opts.viol_eq_tol = 1e-4
# opts.maxit = 150
# opts.fvalquit = 1e-6
opts.print_level = 1
opts.print_frequency = 50
# opts.print_ascii = True
# opts.limited_mem_size = 100
opts.double_precision = True

opts.mu0 = 1
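
As a sanity check on the length of `opts.x0`, the parameter count of this architecture can be worked out by hand. The sketch below assumes that `getNvarTorch` returns the total number of scalar parameters (an assumption, not verified against the PyGRANSO source) and uses the fact that a single-layer `nn.RNN` with biases holds two weight matrices and two bias vectors.

```python
input_size, hidden_size, num_classes = 28, 30, 10

rnn_params = (hidden_size * input_size      # W_ih
              + hidden_size * hidden_size   # W_hh (the constrained matrix)
              + hidden_size                 # b_ih
              + hidden_size)                # b_hh
fc_params = num_classes * hidden_size + num_classes   # weight and bias of the linear layer
total_params = rnn_params + fc_params
```

Under these assumptions the optimization variable has 2110 entries, of which 900 belong to the constrained matrix $W_{hh}$.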

## Initial Test 
Check initial accuracy of the RNN model

In [8]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))  

Initial acc = 8.79%
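
The accuracy above is the share of samples whose largest logit matches the label. A minimal plain-Python sketch of the same computation, with a hypothetical helper `accuracy_percent` and made-up toy logits:

```python
def accuracy_percent(logits, labels):
    """Percentage of rows whose arg-max logit equals the label."""
    predicted = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return 100 * correct / len(labels)

toy_logits = [[0.1, 2.0, -1.0],
              [1.5, 0.2, 0.3],
              [0.0, 0.1, 0.9],
              [0.4, 0.3, 0.2]]
toy_labels = [1, 0, 1, 0]
toy_acc = accuracy_percent(toy_logits, toy_labels)   # 3 of 4 predictions match

# With batch_size = 1024, a printed 8.79% corresponds to 90 correct predictions.
initial_acc = "{:.2f}%".format(100 * 90 / 1024)
```

An accuracy near 10% is roughly what random guessing gives on ten classes, as expected for an untrained model.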


## Main Algorithm

In [10]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))



╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║ 
Version 1.2.0                                                                                                    ║ 
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║ 


## Train Accuracy

In [11]:
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))  
print("final feasibility = {}".format(soln.final.tve))

Final acc = 100.00%
final feasibility = 0.0004830002235291609


## Test Accuracy

In [12]:
inputs, labels = next(iter(loaders['test']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)

logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final test acc = {:.2f}%".format((100 * correct/len(inputs))))  

Final test acc = 74.61%


# Exact Penalty with Adam

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot

In [10]:
inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)

In [11]:
val_dataloader = loaders['test']

In [12]:
def f(model, inputs, labels):
    logits = model(inputs)
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn(logits, labels)

def penalty(model):
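    A = list(model.parameters())[1]  # hidden-to-hidden matrix W_hh
    
    # Entrywise l1 norm of the orthogonality residual W_hh^T W_hh - I.
    # Build the identity on the same device and dtype as A so this also works on GPU.
    eye = torch.eye(hidden_size, device=A.device, dtype=A.dtype)
    return torch.norm(A.T @ A - eye, p=1)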

def phi1(model, mu):
    return f(model, inputs, labels) + mu * penalty(model)
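
Written out, the merit function that `phi1` evaluates is, as far as the code above shows, the $\ell_1$ exact penalty function

$$\phi_1(\theta;\mu)=f(\theta)+\mu\sum_{i,j}\left|\left(W_{hh}^T W_{hh}-I\right)_{ij}\right|$$

where $\theta$ (a symbol introduced here for illustration) collects all model parameters, $f$ is the cross-entropy loss on the fixed training batch, and $\mu$ is the penalty parameter. The sum is entrywise because `torch.norm(..., p=1)` without a `dim` argument treats the matrix as a flat vector.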

In [13]:
def train_loop(model, mu, optimizer):
    model.train()
    
    loss = phi1(model, mu)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

def val_loop(dataloader, model):
    model.eval()
    # size = len(dataloader.dataset)
    size = batch_size
    correct = 0

    with torch.no_grad():
        # for inputs, labels in dataloader:
        inputs, labels = next(iter(loaders['test']))
        inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)
        logits = model(inputs)
        _, predicted = torch.max(logits.data, 1)
        correct += (predicted == labels).sum().item()
    
    correct /= size
    print(f"Error: \n Accuracy: {(100*correct):>0.1f}% \n")
    return 100*correct

### 100 inner epochs

In [33]:
def exact_penalty_with_adam(mu_rho, mu_eps):
    # res_model = res_loss = res_accuracy = res_label_smoothing = res_optimizer = None
    global model
    model.train()
    
    mu = torch.tensor([1.], dtype=double_precision)

    for iteration in range(1000):
        print("Iter", iteration)
        
        # Adam
        optimizer = torch.optim.Adam(model.parameters())
        
        # prev_accuracy = None
        for t in range(100):
            # print(f"Epoch {t+1}\n-------------------------------")
            train_loop(model, mu, optimizer)
            # val_accuracy = val_loop(val_dataloader, model)
            # if prev_accuracy is not None and val_accuracy > prev_accuracy:
            #     break
            # prev_accuracy = val_accuracy
        
        # Exact penalty update
        h = penalty(model)
        print("Objective:", f(model, inputs, labels))
        val_accuracy = val_loop(val_dataloader, model)
        print("Val accuracy:", val_accuracy)
        print("Penalty parameter:", mu)
        print("Penalty:", h)
        if h < 1e-3:  # if h(xk ) ≤ τ
            break

        # Choose new penalty parameter µk+1 > µk ;
        if mu * h > mu_eps:
            mu *= mu_rho

        # Choose new starting point (stay as optimal x1, x2)

        print()
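
The outer loop above follows the usual exact penalty pattern: minimize the penalized objective, measure the constraint violation, and increase the penalty parameter while the iterate is still infeasible. The sketch below applies the same update rule to a toy 1-D problem (an assumption for illustration, unrelated to the RNN): minimize $(x-2)^2$ subject to $x=1$. The inner minimization is done in closed form instead of with Adam, and `inner_minimizer` and `exact_penalty_toy` are hypothetical names.

```python
def inner_minimizer(mu):
    """Exact minimizer of (x - 2)^2 + mu * |x - 1| for this toy problem."""
    return 1.0 if mu >= 2.0 else 2.0 - mu / 2.0

def exact_penalty_toy(mu_rho=1.1, mu_eps=1e-5, tol=1e-3, max_outer=1000):
    mu = 1.0
    history = []
    for iteration in range(max_outer):
        x = inner_minimizer(mu)    # stands in for the inner Adam steps
        h = abs(x - 1.0)           # constraint violation
        history.append((iteration, mu, x, h))
        if h < tol:                # feasible enough: stop
            break
        if mu * h > mu_eps:        # still infeasible: increase the penalty
            mu *= mu_rho
    return history

history = exact_penalty_toy()
last_iteration, last_mu, last_x, last_h = history[-1]
```

In this toy problem the iterate becomes exactly feasible once `mu` passes 2, which happens at outer iteration 8 because the penalty parameter grows as `mu_rho ** k`. The same geometric growth (1, 1.1, 1.21, 1.331, ...) is visible in the logs printed by the notebook.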

Accuracy reaches about 64% around outer iteration 20.

In [34]:
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
exact_penalty_with_adam(mu_rho=1.1, mu_eps=1e-5)

Iter 0
Objective: tensor(2.2073, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 19.1% 

Val accuracy: 19.140625
Penalty parameter: tensor([1.], dtype=torch.float64)
Penalty: tensor(27.6102, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

Iter 1
Objective: tensor(2.0059, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 34.5% 

Val accuracy: 34.47265625
Penalty parameter: tensor([1.1000], dtype=torch.float64)
Penalty: tensor(21.4146, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

Iter 2
Objective: tensor(1.8096, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 34.6% 

Val accuracy: 34.5703125
Penalty parameter: tensor([1.2100], dtype=torch.float64)
Penalty: tensor(17.3110, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

Iter 3
Objective: tensor(1.6743, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 38.5% 

Val accuracy: 38.4765625
Penalty parameter: tensor([1.3310], dtype=tor

KeyboardInterrupt: 

In [35]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))  

Final acc = 83.98%


### 1000 inner epochs

In [25]:
def exact_penalty_with_adam(mu_rho, mu_eps):
    # res_model = res_loss = res_accuracy = res_label_smoothing = res_optimizer = None
    global model
    model.train()
    
    mu = torch.tensor([1.], dtype=double_precision)

    for iteration in range(1000):
        print("Iter", iteration)
        
        # Adam
        optimizer = torch.optim.Adam(model.parameters())
        
        # prev_accuracy = None
        for t in range(1000):
            # print(f"Epoch {t+1}\n-------------------------------")
            train_loop(model, mu, optimizer)
            # val_accuracy = val_loop(val_dataloader, model)
            # if prev_accuracy is not None and val_accuracy > prev_accuracy:
            #     break
            # prev_accuracy = val_accuracy
            if t % 200 == 0:
                print("Epoch", t)
        
        # Exact penalty update
        h = penalty(model)
        print("Objective:", f(model, inputs, labels))
        val_accuracy = val_loop(val_dataloader, model)
        print("Val accuracy:", val_accuracy)
        print("Penalty parameter:", mu)
        print("Penalty:", h)
        if h < 1e-3:  # if h(xk ) ≤ τ
            break

        # Choose new penalty parameter µk+1 > µk ;
        if mu * h > mu_eps:
            mu *= mu_rho

        # Choose new starting point (stay as optimal x1, x2)

        print()

In [26]:
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
exact_penalty_with_adam(mu_rho=1.1, mu_eps=1e-5)

Iter 0
Epoch 0
Epoch 200
Epoch 400
Epoch 600
Epoch 800
Objective: tensor(1.1073, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 54.4% 

Val accuracy: 54.39453125
Penalty parameter: tensor([1.], dtype=torch.float64)
Penalty: tensor(2.1130, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

Iter 1
Epoch 0
Epoch 200
Epoch 400
Epoch 600
Epoch 800
Objective: tensor(0.6417, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 59.6% 

Val accuracy: 59.5703125
Penalty parameter: tensor([1.1000], dtype=torch.float64)
Penalty: tensor(0.5604, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

Iter 2
Epoch 0
Epoch 200
Epoch 400
Epoch 600
Epoch 800
Objective: tensor(0.4157, dtype=torch.float64, grad_fn=<NllLossBackward0>)
Error: 
 Accuracy: 61.6% 

Val accuracy: 61.62109375
Penalty parameter: tensor([1.2100], dtype=torch.float64)
Penalty: tensor(0.2139, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

Iter 3
Epoch 0
Epoch 200
Epoch 400
E

KeyboardInterrupt: 

In [29]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))  

Final acc = 100.00%


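With many inner Adam steps per outer iteration, the model overfits heavily: accuracy on the training batch reaches 100%, while the validation accuracy logged before the interrupt stays between roughly 54% and 62%.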