# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
import math
import time
import torch.nn.functional as F
from torch.utils.data import DataLoader
from NN_arch import LSTM_sMNIST, LeNet, FCP, Autoencoder
class FourierPathNN(nn.Module):
    """Learnable path between two fixed weight vectors.

    The path is the straight line from ``x1`` (t = 0) to ``x2`` (t = 1) plus a
    truncated sine series whose coefficients ``b`` are the only trainable
    parameters::

        w(t) = (1 - t) * x1 + t * x2 + sum_{n=1..num_terms} b[n-1] * sin(n * pi * t)

    Every sine term is zero at t = 0 and t = 1 (up to floating-point rounding),
    so the end points of the path stay at ``x1`` and ``x2`` whatever ``b`` is.
    """

    def __init__(self, x1, x2, num_terms=25):
        """Store the end points and create zero-initialised sine coefficients.

        Args:
            x1: Start point. A tensor, or anything ``torch.as_tensor`` accepts
                (NumPy array, nested list); it is flattened to 1-D and its
                dtype is preserved.
            x2: End point with the same number of elements as ``x1``.
            num_terms: Number of sine harmonics in the series.

        Raises:
            ValueError: If ``x1`` and ``x2`` differ in number of elements.
        """
        super().__init__()
        # register_buffer only accepts tensors, so convert array-likes first.
        start = torch.as_tensor(x1).flatten()
        end = torch.as_tensor(x2).flatten()
        if start.numel() != end.numel():
            raise ValueError(
                f"x1 and x2 must have the same number of elements, "
                f"got {start.numel()} and {end.numel()}"
            )

        self.num_terms = num_terms
        # Buffers: moved by .to(device) and saved in state_dict, never trained.
        self.register_buffer('x1', start)
        self.register_buffer('x2', end)

        # Row n-1 holds the coefficient vector of the n-th sine harmonic.
        self.b = nn.Parameter(torch.zeros(num_terms, start.numel()))

    def forward(self, t_values):
        """Evaluate the path at every entry of ``t_values``.

        Args:
            t_values: Tensor of path parameters; it is reshaped to a column,
                so any shape is accepted.

        Returns:
            Tensor of shape ``(t_values.numel(), x1.numel())``; row ``i`` is
            the flat weight vector at ``t_values`` entry ``i``.
        """
        t_column = t_values.reshape(-1, 1)

        # Straight-line part of the path.
        base = (1 - t_column) * self.x1 + t_column * self.x2

        # All harmonics at once: (points, terms) @ (terms, weights).
        harmonics = torch.arange(
            1, self.num_terms + 1, device=self.b.device, dtype=self.b.dtype
        )
        sine_basis = torch.sin(torch.pi * t_column * harmonics).to(self.b.dtype)
        weights = base + sine_basis @ self.b

        return weights.reshape(-1, self.x1.numel())

def _seed_everything(seed):
    """Seed the torch (CPU and CUDA) and NumPy random generators."""
    torch.manual_seed(seed)
    np.random.seed(seed)


def _load_flat_weights(model, flat_weights):
    """Write one flat weight vector into ``model``'s parameters, in ``parameters()`` order."""
    offset = 0
    for param in model.parameters():
        count = param.numel()
        param.data = flat_weights[offset:offset + count].reshape(param.shape)
        offset += count


def _flat_gradient(model):
    """Concatenate all parameter gradients into one vector (zeros where a gradient is missing)."""
    return torch.cat([
        (param.grad if param.grad is not None else torch.zeros_like(param)).flatten()
        for param in model.parameters()
    ])


def loss_fn_simplified(weights,model,model_name, device, trainloader,b_gradients,num_terms,t_values):
    """Evaluate the network loss at every point of a path in weight space.

    Each row of ``weights`` is loaded into ``model``, the loss on one batch is
    computed and back-propagated, and the resulting weight gradient is
    projected onto the sine basis ``sin(n * pi * t)`` to accumulate the
    gradient of the summed loss with respect to the Fourier coefficients.

    Args:
        weights: Path points, one flat weight vector per row (a 2-D tensor or
            a sequence of 1-D tensors).
        model: Network whose parameters are overwritten with each path point.
            On return it holds the last point of the path.
        model_name: ``"AE"`` selects MSE reconstruction loss against the
            inputs; any other value selects cross-entropy against the labels.
        device: Device the batch is moved to; ``None`` uses ``weights.device``.
        trainloader: Iterable yielding ``(images, labels)`` batches (a
            ``DataLoader`` or a list of batches). Only the first batch is used.
        b_gradients: ``(num_terms, num_weights)`` tensor, updated IN PLACE.
        num_terms: Number of sine harmonics to project onto.
        t_values: Path parameter of each point, one entry per row of ``weights``.

    Returns:
        Tuple ``(total_loss, Data_array, Length, b_gradients, smoothness)``:
        summed loss over the path (detached), NumPy array of per-point losses,
        path length as a float, the updated ``b_gradients`` (same object that
        was passed in), and ``0.001 *`` the sum of squared segment lengths
        (still attached to the autograd graph of ``weights``).

    Raises:
        ValueError: If a path point does not have exactly as many entries as
            ``model`` has parameters.
    """
    if not torch.is_tensor(weights):
        weights = torch.stack(tuple(weights))
    device = torch.device(device) if device is not None else weights.device

    _seed_everything(42)
    model.train()
    # MSE loss for the autoencoder, cross-entropy for every other model.
    criterion = nn.MSELoss() if model_name == "AE" else nn.CrossEntropyLoss()

    images, labels = next(iter(trainloader))
    images, labels = images.to(device), labels.to(device)

    # Path geometry: differences between consecutive points.
    segments = (weights[1:] - weights[:-1]).flatten(1)
    smoothness_loss = segments.pow(2).sum()
    Length = segments.detach().norm(dim=1).sum().item()

    path_points = weights.detach().flatten(1)
    num_points = path_points.shape[0]
    num_params = sum(param.numel() for param in model.parameters())
    if path_points.shape[1] != num_params:
        raise ValueError(
            f"each path point has {path_points.shape[1]} entries but the model "
            f"has {num_params} parameters"
        )

    # sine_basis[i, n] = sin((n + 1) * pi * t_i), computed once for the whole path.
    harmonics = torch.arange(1, num_terms + 1).to(b_gradients)
    t_column = t_values.detach().reshape(-1, 1).to(b_gradients)
    sine_basis = torch.sin(torch.pi * t_column * harmonics)

    Data_array = np.zeros(num_points)
    total_loss = 0.0
    for i in range(num_points):
        _load_flat_weights(model, path_points[i])
        model.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, images if model_name == "AE" else labels)
        loss.backward()

        # Chain rule: d(loss_i)/d(b_n) = sin(n * pi * t_i) * d(loss_i)/d(w).
        point_gradient = _flat_gradient(model).to(b_gradients)
        b_gradients[:num_terms].add_(
            sine_basis[i].unsqueeze(1) * point_gradient.unsqueeze(0)
        )

        # Detach so the running sum does not keep every graph alive.
        total_loss += loss.detach()
        Data_array[i] = loss.item()
    return total_loss,Data_array,Length,b_gradients,0.001*smoothness_loss



def train_fourier_nn_simplified(path_model,model, model_name, images, labels, num_terms, num_steps=100, lr=0.001, num_points=51):
    """Optimise the Fourier coefficients of a path to lower the loss along it.

    Each step evaluates the path at ``num_points`` evenly spaced values of t
    in [0, 1], obtains the loss gradient with respect to ``path_model.b`` from
    ``loss_fn_simplified``, adds the gradient of the smoothness penalty, and
    takes one Adam step. Losses are recorded BEFORE the step is applied.

    Args:
        path_model: Module with a parameter ``b``; calling it with a column of
            t values returns one flat weight vector per row.
        model: Network evaluated along the path (its parameters are overwritten).
        model_name: Forwarded to ``loss_fn_simplified`` to pick the criterion.
        images: Input batch used for every loss evaluation.
        labels: Targets matching ``images``.
        num_terms: Number of sine harmonics in ``path_model``.
        num_steps: Number of optimiser steps; must be at least 1.
        lr: Adam learning rate.
        num_points: Number of points sampled along the path, end points included.

    Returns:
        Tuple ``(Min_array, Min_Length)``: per-point losses and path length of
        the step with the lowest summed loss.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")

    # Work on whatever device the path coefficients live on.
    device = path_model.b.device
    optimizer = optim.Adam(path_model.parameters(), lr=lr) #Can also be SGD or rmsprop
    t_values = torch.linspace(0, 1, num_points).unsqueeze(1).to(device)

    # A one-element list behaves like a loader that yields this single batch.
    batches = [(images.to(device), labels.to(device))]

    Min_Loss = None
    Min_array = None
    Min_Length = None
    for step in range(num_steps):
        optimizer.zero_grad()
        path_weights = path_model(t_values)  # Weights at each t along the path.
        b_gradients = torch.zeros_like(path_model.b)
        loss,Data_array,L,b_gradients,smoothness_loss = loss_fn_simplified(
            path_weights, model, model_name, device, batches,
            b_gradients, num_terms, t_values)

        # Autograd supplies the smoothness gradient; the loss gradient is added by hand.
        smoothness_loss.backward()
        if path_model.b.grad is None:
            path_model.b.grad = b_gradients.clone()
        else:
            path_model.b.grad.add_(b_gradients)
        optimizer.step()

        loss = float(loss)
        if Min_Loss is None or Min_Loss > loss:
            Min_array=Data_array
            Min_Loss=loss
            Min_Length=L
    print(f"Loss: {loss:.6f}","Length:", L)
    return Min_array,Min_Length

def _build_model(model_name):
    """Instantiate the architecture selected by ``model_name``; raise ValueError if unknown."""
    if model_name == "LN":
        return LeNet()
    if model_name == "FCP":
        return FCP()
    if model_name == "AE":
        return Autoencoder()
    if model_name == "LSTM":
        return LSTM_sMNIST()
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of 'LN', 'FCP', 'AE', 'LSTM'"
    )


def main_simplified(x1,x2,model_name,data_set, num_terms=10):
    """Optimise a Fourier path between two weight vectors on MNIST.

    Args:
        x1: Flat weight vector at the start of the path (tensor or array-like).
        x2: Flat weight vector at the end of the path.
        model_name: Architecture whose loss landscape is used: ``"LN"``,
            ``"FCP"``, ``"AE"`` or ``"LSTM"``.
        data_set: ``"Test"`` uses the MNIST test split; any other value uses
            the training split.
        num_terms: Number of terms in the truncated Fourier series.

    Returns:
        Tuple ``(Min_path, L)`` as returned by ``train_fourier_nn_simplified``.

    Raises:
        ValueError: If ``model_name`` is not one of the known architectures.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the model first so a bad model_name fails before any data is loaded.
    model = _build_model(model_name).to(device)

    # Dataset preparation; the autoencoder takes flattened images.
    pipeline = [transforms.ToTensor()]
    if model_name == "AE":
        pipeline.append(transforms.Lambda(lambda x: x.view(-1)))
    transform = transforms.Compose(pipeline)

    trainset = datasets.MNIST(root='./data', train=(data_set != "Test"), download=True, transform=transform)
    # One batch spanning the whole split.
    trainloader = DataLoader(trainset, batch_size=len(trainset), shuffle=True)
    images, labels = next(iter(trainloader))
    images, labels = images.to(device), labels.to(device)

    # The end points are loaded into the model's parameters, so convert
    # them (e.g. NumPy float64 arrays) to float32 tensors.
    x1 = torch.as_tensor(x1, dtype=torch.float32)
    x2 = torch.as_tensor(x2, dtype=torch.float32)

    # Generate the initial path (a straight line: all coefficients start at zero).
    fourier_nn = FourierPathNN(x1, x2, num_terms=num_terms).to(device)

    # Train the FourierPathNN
    Min_path,L=train_fourier_nn_simplified(fourier_nn,model,model_name,images, labels,num_terms)
    return Min_path,L

model_name_array=["FCP","LN","AE","LSTM"] #Select which architecture
data_set_array=["Test","Train"]

# Pick ONE architecture by index; assigning the whole list would match none
# of the weight files below.
model_name=model_name_array[0]

# Saved weight sets, one file per architecture.
weight_files = {
    "FCP": "FC_BFGS_Training_best48_weights.npy",
    "LN": "LN_BFGS_Training_best48_weights.npy",
    "AE": "AE_BFGS_Training_best48_weights.npy",
    "LSTM": "LSTM_BFGS_Training_best48_weights.npy",
}

if __name__ == "__main__":
    x_array=np.load(weight_files[model_name])

    # Pick start and end points; converted to float32 tensors because the
    # path model stores them as buffers.
    x1=torch.as_tensor(x_array[1], dtype=torch.float32)
    x2=torch.as_tensor(x_array[2], dtype=torch.float32)
    min_path, path_length = main_simplified(x1,x2,model_name,data_set_array[1])