In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import time
from torch.optim.lr_scheduler import StepLR
import sympy as sp

In [2]:
# Preferred GPU ordinal on the original multi-GPU host. `torch.cuda.is_available()`
# alone does not guarantee that this ordinal exists, so fall back to GPU 0 on
# machines with fewer devices, and to the CPU when CUDA is unavailable.
PREFERRED_GPU_INDEX = 4

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

In [4]:
# SECURITY NOTE: the checkpoint stores arbitrary Python objects (formulas, symbols)
# next to the tensors, so it has to be unpickled with `weights_only=False`.
# Unpickling can execute arbitrary code -- only load files from a trusted source.
# The flag is passed explicitly so the cell does not depend on the library default.
loaded_data = torch.load('hold_data.pth', weights_only=False)

# Unpack the individual entries of the saved dictionary into notebook-level names.
x_values = loaded_data['x_values']
y_values = loaded_data['y_values']
derivatives = loaded_data['derivatives']
hessians = loaded_data['hessians']
params = loaded_data['param_values']
formulas = loaded_data['formulas']
symbols = loaded_data['symbols']
num_params = loaded_data['num_params']


  loaded_data = torch.load('hold_data.pth')


In [None]:
class Multi_Func_Channels(nn.Module):
    """Convolutional encoder that predicts parameters for several candidate formulas.

    The network encodes an input curve into one flat embedding holding
    ``sum(num_params)`` values. The embedding is split into consecutive slices,
    one per formula; every formula is evaluated on ``x_data`` with its slice and
    compared to the input curve with an MSE loss.

    NOTE(review): ``evaluate`` runs the formulas through NumPy on detached
    values, so the losses produced by ``forward`` do not carry gradients back
    to the network. Confirm this is intended before training on them.
    """

    def __init__(self, functions, num_params, x_data, input_channels, device, symbols=None):
        """Build the encoder layers and store the formula metadata.

        Args:
            functions: Indexable collection of formulas; entry ``i`` is passed
                to ``sp.lambdify`` together with ``symbols[i]``.
            num_params: Number of embedding columns consumed by each formula.
            x_data: Points at which the formulas are evaluated (iterated along
                its first dimension).
            input_channels: Number of channels of the first convolution.
            device: Device on which inputs and evaluation results are placed.
            symbols: Per-formula symbol sequences; the last symbol of each
                sequence is the x variable, the others are the parameters.
                When omitted, a module-level ``symbols`` variable is used so
                existing five-argument calls keep working.

        Raises:
            ValueError: If ``symbols`` is neither passed nor defined globally.
        """
        super().__init__()
        if symbols is None:
            # Backward compatibility with callers relying on the notebook global.
            symbols = globals().get("symbols")
            if symbols is None:
                raise ValueError(
                    "symbols must be passed explicitly or defined as a module-level variable"
                )

        self.device = device
        self.formulas = functions
        self.x_data = x_data
        self.input_channels = input_channels
        self.num_params = num_params
        # int() keeps these usable as layer / tensor sizes even if the counts
        # are stored as 0-d tensors.
        self.max_params = int(max(num_params))
        self.symbols = symbols
        self.params = int(sum(self.num_params))
        # Step size of the central finite differences in derivative()/hessian().
        self.epsilon = 1e-4
        # Cache of lambdified formulas, keyed by formula index.
        self._lambdified = {}

        # First convolutional stage; pooled to 4 channels x 64 positions (256 values).
        self.hidden_x1 = nn.Sequential(
            nn.Conv1d(in_channels=self.input_channels, out_channels=8, kernel_size=7),
            nn.SELU(),
            nn.Conv1d(in_channels=8, out_channels=6, kernel_size=7),
            nn.SELU(),
            nn.Conv1d(in_channels=6, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(64)
        )

        # Fully connected branch on the flattened 256 features -> 20 features.
        self.hidden_xfc = nn.Sequential(
            nn.Linear(256, 64),
            nn.SELU(),
            nn.Linear(64, 32),
            nn.SELU(),
            nn.Linear(32, 20),
            nn.SELU(),
        )

        # Convolutional branch on the same features viewed as (2, 128);
        # ends with 2 channels x 4 positions (8 values after flattening).
        self.hidden_x2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=2, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(16),
            nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(8),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(4),
        )

        self.flatten_layer = nn.Flatten()

        # Maps the concatenated 8 + 20 = 28 features to one value per formula parameter.
        self.hidden_embedding = nn.Sequential(
            nn.Linear(28, 128),
            nn.SELU(),
            nn.Linear(128, 64),
            nn.SELU(),
            nn.Linear(64, self.params),
        )

    def _lambdified_formula(self, index):
        """Return the NumPy callable of formula ``index``, building it once."""
        compiled = self._lambdified.get(index)
        if compiled is None:
            compiled = sp.lambdify(self.symbols[index], self.formulas[index], modules="numpy")
            self._lambdified[index] = compiled
        return compiled

    @staticmethod
    def _shifted(base, offsets):
        """Return a copy of ``base`` with ``delta`` added to each listed column.

        ``offsets`` is a sequence of ``(column, delta)`` pairs applied in order,
        so a repeated column accumulates its shifts.
        """
        shifted = base.clone()
        for column, delta in offsets:
            shifted[:, column] += delta
        return shifted

    def evaluate(self, params, index):
        """Evaluate formula ``index`` at every point of ``x_data``.

        Args:
            params: Tensor whose column ``j`` holds the values of the ``j``-th
                parameter symbol, one row per sample.
            index: Position of the formula in ``self.formulas``.

        Returns:
            Tensor on ``self.device`` with one entry along dim 1 per point of
            ``x_data``. The values are computed with NumPy on detached data,
            so the result is not connected to the autograd graph.
        """
        symbols = self.symbols[index]
        eval_func = self._lambdified_formula(index)

        # The last symbol is the x variable; all preceding ones are parameters.
        # The parameter columns do not change across x points, so convert once.
        param_arrays = [
            params[:, j].detach().cpu().numpy() for j in range(len(symbols) - 1)
        ]

        columns = []
        for xi in torch.as_tensor(self.x_data):
            x_array = xi.detach().cpu().numpy()
            # Positional call in symbol order; does not depend on symbol names
            # being valid Python keyword identifiers.
            values = eval_func(*param_arrays, x_array)
            columns.append(torch.tensor(np.asarray(values), device=self.device))
        return torch.stack(columns, dim=1)

    def derivative(self, params, index):
        """Central finite-difference derivative of formula ``index`` w.r.t. its parameters.

        Args:
            params: Tensor with one row per sample and one column per parameter.
            index: Position of the formula in ``self.formulas``.

        Returns:
            Tensor of shape ``(batch, len(x_data), max_params)``. Slots beyond
            the number of parameters of this formula stay zero.
        """
        n_free = len(self.symbols[index]) - 1  # last symbol is x, not a parameter
        step = self.epsilon
        # Differences are taken in float64 to limit cancellation error.
        base = params.detach().to(torch.float64)
        out_dtype = params.dtype if params.is_floating_point() else torch.float64

        derivatives = torch.zeros(
            params.shape[0], len(self.x_data), self.max_params,
            dtype=out_dtype, device=self.device,
        )
        for p in range(n_free):
            forward_values = self.evaluate(self._shifted(base, [(p, step)]), index)
            backward_values = self.evaluate(self._shifted(base, [(p, -step)]), index)
            derivatives[:, :, p] = (forward_values - backward_values) / (2 * step)
        return derivatives

    def hessian(self, params, index):
        """Central finite-difference Hessian of formula ``index`` w.r.t. its parameters.

        Args:
            params: Tensor with one row per sample and one column per parameter.
            index: Position of the formula in ``self.formulas``.

        Returns:
            Tensor of shape ``(batch, len(x_data), max_params, max_params)``.
            Entries outside the parameters of this formula stay zero.
        """
        n_free = len(self.symbols[index]) - 1  # last symbol is x, not a parameter
        step = self.epsilon
        # Dividing by 4 * step**2 amplifies rounding noise; float32 is not
        # accurate enough for this step size, so work in float64.
        base = params.detach().to(torch.float64)
        out_dtype = params.dtype if params.is_floating_point() else torch.float64

        hessians = torch.zeros(
            params.shape[0], len(self.x_data), self.max_params, self.max_params,
            dtype=out_dtype, device=self.device,
        )
        for j in range(n_free):
            # The difference stencil is symmetric in (j, k): compute the upper
            # triangle only and mirror it.
            for k in range(j, n_free):
                forward_forward = self.evaluate(
                    self._shifted(base, [(j, step), (k, step)]), index)
                forward_backward = self.evaluate(
                    self._shifted(base, [(j, step), (k, -step)]), index)
                backward_forward = self.evaluate(
                    self._shifted(base, [(j, -step), (k, step)]), index)
                backward_backward = self.evaluate(
                    self._shifted(base, [(j, -step), (k, -step)]), index)
                second = (
                    forward_forward - forward_backward - backward_forward + backward_backward
                ) / (4 * step ** 2)
                hessians[:, :, j, k] = second
                hessians[:, :, k, j] = second
        return hessians

    def forward(self, x, n=-1):
        """Encode ``x``, evaluate every formula and pick the best-fitting one.

        Args:
            x: Input tensor; dims 1 and 2 are swapped before the convolutions,
                and dim 2 is squeezed to build the regression target
                (assumes a single-channel input of shape (batch, length, 1) --
                TODO confirm for input_channels > 1).
            n: Batch size used when reshaping the features; ``-1`` infers it.

        Returns:
            Tuple ``(best_out, best_loss, best_func, outputs, losses)`` where
            ``outputs`` / ``losses`` hold one entry per formula and the
            ``best_*`` values belong to the formula with the smallest MSE.
        """
        # Move the input first so the target lives on the same device as the outputs.
        x = x.to(self.device)
        target = x.squeeze(dim=2)

        features = self.hidden_x1(torch.swapaxes(x, 1, 2))
        xfc = self.hidden_xfc(torch.reshape(features, (n, 256)))
        cnn_out = self.hidden_x2(torch.reshape(features, (n, 2, 128)))
        cnn_flat = self.flatten_layer(cnn_out)
        encoded = torch.cat((cnn_flat, xfc), 1)
        embedding = self.hidden_embedding(encoded)

        loss_func = nn.MSELoss()
        start_index = 0
        losses = []
        outputs = []

        for f in range(len(self.formulas)):
            width = int(self.num_params[f])
            # Each formula consumes its own consecutive slice of the embedding.
            output = self.evaluate(embedding[:, start_index:start_index + width], f)
            output = output.to(device=self.device, dtype=target.dtype)
            outputs.append(output)
            losses.append(loss_func(output, target))
            start_index += width

        best_index = int(torch.argmin(torch.stack(losses)))
        best_func = self.formulas[best_index]
        best_loss, best_out = losses[best_index], outputs[best_index]

        return best_out, best_loss, best_func, outputs, losses