# In [None]:  (Jupyter notebook cell marker, kept as a comment so the file parses as Python)
import copy
import numpy as np
from torch.optim.optimizer import Optimizer
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import math
import time
import torch.nn.functional as F
from torch.utils.data import DataLoader
from NN_arch import LSTM_sMNIST, LeNet, FCP, Autoencoder
class CNN(nn.Module):
    """Bias-free, three-layer fully connected network with softplus activations.

    Despite the name, the model contains no convolution: the input is
    flattened to rows of 28*28 = 784 values and passed through
    Linear(784, 50) -> softplus -> Linear(50, 50) -> softplus -> Linear(50, 10).
    The output of the last layer is returned without any activation.
    """

    def __init__(self):
        super().__init__()
        n_inputs, n_hidden, n_outputs = 28 * 28, 50, 10
        # Layers are created in fc1, fc2, fc3 order so that random weight
        # initialisation consumes the RNG in a fixed sequence.
        self.fc1 = nn.Linear(n_inputs, n_hidden, bias=False)
        self.fc2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.fc3 = nn.Linear(n_hidden, n_outputs, bias=False)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return the 10 outputs per row."""
        activations = x.view(-1, self.fc1.in_features)
        for hidden_layer in (self.fc1, self.fc2):
            activations = F.softplus(hidden_layer(activations))
        return self.fc3(activations)
def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and CUDA) and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed passed to every random number generator below.
    """
    np.random.seed(seed)
    # CPU generator, current CUDA device, then every CUDA device (multi-GPU).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
class LBFGS_Brent(Optimizer):
    """L-BFGS optimizer with a derivative-free bracketing line search.

    Each call to :meth:`step` builds a quasi-Newton direction with the
    two-loop recursion over the stored (s, y) correction pairs and then picks
    the step length with :meth:`_brent_line_search`, which brackets a minimum
    along the direction and refines it by golden-section style shrinking.

    All optimizer state (history lists, previous gradient, last step) is
    stored under string keys in ``self.state`` and is shared by every
    parameter group; hyperparameters are read from ``param_groups[0]``.

    Args:
        params (iterable): iterable of parameters to optimize
        lr (float, optional): learning rate (not used, kept for compatibility)
        max_history (int, optional): maximum number of correction pairs to store (default: 10)
        line_search_budget (int, optional): maximum function evaluations for line search (default: 6)
        tolerance (float, optional): convergence tolerance for line search (default: 1e-4)
    """

    def __init__(self, params, lr=1.0, max_history=10, line_search_budget=6, tolerance=1e-4):
        defaults = dict(lr=lr, max_history=max_history,
                        line_search_budget=line_search_budget,
                        tolerance=tolerance)
        super(LBFGS_Brent, self).__init__(params, defaults)

        # Initialize storage for s and y vectors
        self.state['s_list'] = []
        self.state['y_list'] = []
        self.state['prev_grad'] = None
        self.state['initialized'] = False

    def _gather_flat_params(self):
        """Return a copy of all parameters concatenated into one flat tensor.

        Every parameter contributes its actual values, whether or not it
        currently has a gradient, so that writing the vector back with
        :meth:`_set_flat_params` never overwrites a gradient-less (frozen or
        unused) parameter with zeros.
        """
        views = [p.data.reshape(-1)
                 for group in self.param_groups
                 for p in group['params']]
        return torch.cat(views, 0)

    def _gather_flat_grad(self):
        """Return a copy of all gradients concatenated into one flat tensor.

        Parameters without a gradient contribute zeros, which keeps the layout
        aligned with :meth:`_gather_flat_params`.
        """
        views = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    view = p.data.new_zeros(p.data.numel())
                else:
                    view = p.grad.data.reshape(-1)
                views.append(view)
        return torch.cat(views, 0)

    def _set_flat_params(self, flat_params):
        """Copy a flat vector (laid out as in `_gather_flat_params`) into the parameters."""
        offset = 0
        for group in self.param_groups:
            for p in group['params']:
                numel = p.numel()
                p.data.copy_(flat_params[offset:offset + numel].view_as(p.data))
                offset += numel

    def _clone_params(self):
        """Return a list with a contiguous clone of every parameter."""
        return [p.clone(memory_format=torch.contiguous_format)
                for group in self.param_groups
                for p in group['params']]

    def _set_params(self, params_data):
        """Copy a list of tensors (ordered as in `_clone_params`) into the parameters."""
        all_params = [p for group in self.param_groups for p in group['params']]
        for p, pdata in zip(all_params, params_data):
            p.copy_(pdata)

    def _brent_line_search(self, closure, weights, grad, cur_loss, budget):
        """Derivative-free line search along ``-grad / ||grad||``.

        The search first brackets a minimum: if the initial trial step lowers
        the loss the step is grown geometrically until the loss rises again;
        if it raises the loss the step is shrunk towards zero until the loss
        drops below the starting loss. The bracket is then refined by
        golden-section style shrinking. Parameters are left at the last
        evaluated trial point; the caller is expected to set the final ones.

        Args:
            closure: Callable returning the loss as a Python float at the
                current parameters.
            weights: Current flat parameter vector (start of the line).
            grad: Search direction (note: we minimize in direction -grad).
            cur_loss: Loss at `weights` (step length 0).
            budget: Maximum number of counted function evaluations. The extra
                evaluation made when two losses tie exactly is not counted.

        Returns:
            step_size: Selected step length along the normalized direction.
            new_loss: Loss at that step length. When every trial is worse than
                `cur_loss` this can be larger than `cur_loss`.
        """
        tolerance = self.param_groups[0]['tolerance']

        mag_grad = torch.norm(grad, p=2).item()
        if mag_grad == 0.0:
            # Zero direction: nothing to search along, and normalizing would
            # divide by zero and turn every parameter into NaN.
            return 0.0, cur_loss
        normalized_grad = grad / mag_grad

        def evaluate(step_len):
            """Move the parameters to `step_len` along the line and return the loss there."""
            self._set_flat_params(weights - step_len * normalized_grad)
            with torch.no_grad():
                return closure()

        count_eval = 0
        k = 0
        a, u, b = 0, None, None
        f_a, f_u, f_b = cur_loss, None, None

        # Initial step: 0.2 * ||grad|| along the unit direction, i.e. 0.2 * grad.
        new_step = 0.2 * 2**k * mag_grad
        new_loss = evaluate(new_step)
        count_eval += 1
        if count_eval >= budget:
            # Budgets of one (or less) are exhausted by the first trial.
            return new_step, new_loss

        # Bounding phase
        if new_loss < f_a:
            # NOTE(review): k is incremented here and again at the top of the
            # loop, so the second trial uses 0.2 * 2**2 rather than 0.2 * 2**1.
            # Kept as in the original; confirm whether this is intended.
            k += 1
            u, f_u = new_step, new_loss

            # Grow the step until the loss stops decreasing (upper bound b found).
            while f_b is None:
                k += 1
                new_step = 0.2 * 2**k * mag_grad
                new_loss = evaluate(new_step)
                count_eval += 1

                if new_loss < f_u:
                    a, f_a = u, f_u
                    u, f_u = new_step, new_loss
                    if count_eval >= budget:
                        return u, f_u
                elif new_loss > f_u:
                    b, f_b = new_step, new_loss
                    if count_eval >= budget:
                        return u, f_u
                else:
                    # Exact tie (also reached when a loss is NaN): take the
                    # midpoint between the lower bound and the new step.
                    u = (a + new_step) / 2
                    return u, evaluate(u)

        elif new_loss > f_a:
            b, f_b = new_step, new_loss

            # Shrink the step towards a until a point better than f_a is found.
            while f_u is None:
                if abs(b - a) < tolerance or abs(f_a - f_b) < tolerance:
                    return b, f_b

                new_step = a + 0.382 * (b - a)
                new_loss = evaluate(new_step)
                count_eval += 1

                if new_loss < f_a:
                    u, f_u = new_step, new_loss
                    if count_eval >= budget:
                        return u, f_u
                elif new_loss > f_a:
                    b, f_b = new_step, new_loss
                    if count_eval >= budget:
                        return b, f_b
                else:
                    u = (a + new_step) / 2
                    return u, evaluate(u)
        else:
            u = (a + new_step) / 2
            return u, evaluate(u)

        # Golden section search phase: x is the best point inside [a, b].
        x, f_x = u, f_u

        while count_eval < budget:
            if abs(b - a) < tolerance or abs(f_a - f_b) < tolerance:
                return x, f_x

            # Place the trial point 0.382 of the bracket width from the end
            # that is closer to x.
            d_golden = (b - a) * 0.382
            if (b - x) > (x - a):
                u = a + d_golden
            else:
                u = b - d_golden

            f_u = evaluate(u)
            count_eval += 1

            if f_u < f_x:
                # New best point: the old best becomes a bracket end.
                if u >= x:
                    a, f_a = x, f_x
                elif u < x:
                    b, f_b = x, f_x
                x, f_x = u, f_u
            elif f_u > f_x:
                # Trial is worse: it becomes the bracket end on its side of x.
                if u < x:
                    a, f_a = u, f_u
                elif u > x:
                    b, f_b = u, f_u
            else:
                return x, f_x

        return x, f_x

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Args:
            closure (callable): A closure that reevaluates the model and returns the loss.
                               Must call backward() to compute gradients.

        Returns:
            loss: 0-dim CPU tensor with the loss at the parameters chosen by
                the line search. If the search direction is exactly zero the
                parameters are left untouched and the current loss is returned.
        """
        assert closure is not None, "L-BFGS requires a closure function"

        # Evaluate loss and compute gradients
        with torch.enable_grad():
            loss = closure()

        # Get hyperparameters
        max_history = self.param_groups[0]['max_history']
        budget = self.param_groups[0]['line_search_budget']

        # Gather flat gradients and weights (both are copies, so later closure
        # calls during the line search cannot modify them).
        flat_grad = self._gather_flat_grad()
        flat_weights = self._gather_flat_params()

        # Update s and y lists with the pair produced by the previous step
        if self.state['initialized']:
            y_k = flat_grad - self.state['prev_grad']
            s_k = self.state['s_k']

            if max_history > 0:
                # Maintain history size (oldest pair first)
                while len(self.state['s_list']) >= max_history:
                    self.state['s_list'].pop(0)
                    self.state['y_list'].pop(0)

                self.state['s_list'].append(s_k)
                self.state['y_list'].append(y_k)

        s_list = self.state['s_list']
        y_list = self.state['y_list']

        # Two-loop recursion, first loop: newest pair to oldest
        q = flat_grad.clone()
        alphas = []      # newest first
        rhos = []        # newest first
        cosines = []     # cos(angle) between y_i and s_i, used as a curvature check
        reset = False

        for s_i, y_i in zip(reversed(s_list), reversed(y_list)):
            dot_product = torch.dot(y_i, s_i)
            rho_i = 1.0 / (dot_product + 1e-6)

            # Check for numerical issues
            mag_den = torch.norm(y_i).item() * torch.norm(s_i).item()
            if mag_den == 0.0:
                reset = True
            else:
                cosines.append(dot_product.item() / mag_den)

            alpha_i = rho_i * torch.dot(s_i, q)
            alphas.append(alpha_i)
            rhos.append(rho_i)
            q = q - alpha_i * y_i

        # Initial Hessian approximation: gamma = s.y / y.y of the newest pair
        if len(s_list) > 0:
            gamma = torch.dot(s_list[-1], y_list[-1]) / torch.dot(y_list[-1], y_list[-1])
        else:
            gamma = 1.0

        r = gamma * q

        # Second loop: oldest pair to newest
        for s_i, y_i, rho_i, alpha_i in zip(s_list, y_list, reversed(rhos), reversed(alphas)):
            beta_i = rho_i * torch.dot(y_i, r)
            r = r + s_i * (alpha_i - beta_i)

        # Fall back to steepest descent and drop the history when the direction
        # is NaN, a pair is degenerate, or y and s are nearly orthogonal/opposed.
        if torch.isnan(r).any() or any(c < 1e-4 for c in cosines) or reset:
            r = flat_grad
            self.state['s_list'] = []
            self.state['y_list'] = []
            print("Warning: L-BFGS reset due to numerical issues")

        mag_r = torch.norm(r, p=2).item()
        if mag_r == 0.0:
            # Stationary point: dividing by the norm would produce NaN
            # parameters. No step is taken, so no correction pair may be formed
            # on the next call.
            self.state['prev_grad'] = None
            self.state['initialized'] = False
            return torch.tensor(loss.item())

        # The user closure calls backward(), so gradients must be enabled even
        # though the line search itself only uses the loss value.
        def line_search_closure():
            with torch.enable_grad():
                loss_val = closure()
            return loss_val.item()

        step_size, new_loss = self._brent_line_search(
            line_search_closure,
            flat_weights,
            r,
            loss.item(),
            budget
        )

        # Update parameters
        s_k = -step_size * r / mag_r
        self._set_flat_params(flat_weights + s_k)

        # Store for next iteration
        self.state['s_k'] = s_k
        self.state['prev_grad'] = flat_grad.clone()
        self.state['initialized'] = True

        return torch.tensor(new_loss)
# ---------------------------------------------------------------------------
# Training script: full-batch MNIST training of one architecture with
# LBFGS_Brent for 100 optimizer steps.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name_array = ["FCP", "LN", "AE", "LSTM"]

# `model_name` selects the architecture but is never assigned in this cell.
# Keep a value defined elsewhere (e.g. an earlier notebook cell); otherwise
# fall back to the first supported name instead of failing with a NameError.
if "model_name" not in globals():
    model_name = model_name_array[0]

# Constructors for each supported architecture; only the selected one is built.
model_factories = {
    "LN": LeNet,
    "FCP": FCP,
    "AE": Autoencoder,
    "LSTM": LSTM_sMNIST,
}
if model_name not in model_factories:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {model_name_array}"
    )

transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# One batch spans the whole split, so the first batch is the full dataset.
train_loader = DataLoader(trainset, batch_size=len(trainset), shuffle=True)
test_loader = DataLoader(testset, batch_size=len(testset), shuffle=True)

images, labels = next(iter(train_loader))
images_train, labels_train = images.to(device), labels.to(device)
images, labels = next(iter(test_loader))
images_test, labels_test = images.to(device), labels.to(device)

net = model_factories[model_name]().to(device)

# The autoencoder is trained to reconstruct its input; every other model is a
# classifier trained on the labels.
criterion = nn.MSELoss() if model_name == "AE" else nn.CrossEntropyLoss()

optimizer = LBFGS_Brent(net.parameters(), max_history=10, line_search_budget=6)
net.train()


def closure():
    """Full-batch classification loss; clears old gradients and runs backward."""
    optimizer.zero_grad()
    output = net(images_train)
    loss = criterion(output, labels_train)
    loss.backward()
    return loss


def closure_AE():
    """Full-batch reconstruction loss against the input images; runs backward."""
    optimizer.zero_grad()
    output = net(images_train)
    loss = criterion(output, images_train)
    loss.backward()
    return loss


train_closure = closure_AE if model_name == "AE" else closure
for epoch in range(100):
    loss = optimizer.step(train_closure)
    print(f"Epoch {epoch}, Loss: {loss.item():.6f}")
