In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import numpy as np
import time
from torch.optim.lr_scheduler import StepLR, LambdaLR
import sympy as sp

In [2]:
# Preferred GPU on multi-GPU hosts; falls back to the default GPU (or CPU) when that
# index does not exist, instead of failing later when tensors are moved to it.
CUDA_DEVICE_INDEX = 3
if torch.cuda.is_available() and CUDA_DEVICE_INDEX < torch.cuda.device_count():
    device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
class Multi_Func_Channels(nn.Module):
    """Encode a sampled signal and fit it with each of several symbolic candidate functions.

    A 1-D CNN + MLP encoder maps the input to one flat embedding of ``sum(num_params)``
    values, which is split into consecutive per-function parameter slices. Each candidate
    formula is evaluated with its slice on ``x_data``; for every sample the candidate with
    the lowest loss is reported as the best fit.

    Args:
        functions: sequence of formulas accepted by ``sympy.lambdify`` (presumably sympy
            expressions), one per candidate function.
        num_params: number of parameters of each function (ints or an integer tensor).
        x_data: tensor of sample positions at which every formula is evaluated.
        input_channels: number of channels of the input signal fed to the encoder.
        device: device on which ``x_data`` and intermediate results are placed.
        symbols: per-function sequence of sympy symbols. All but the last are bound to the
            parameter columns in order; the last is bound to ``x_data``. If omitted, the
            module-global ``symbols`` is used (backward compatibility with older versions
            that always read it implicitly).
    """

    def __init__(self, functions, num_params, x_data, input_channels, device, symbols=None):
        super().__init__()
        self.device = device
        self.functions = functions
        self.x_data = x_data.to(self.device)
        self.input_channels = input_channels
        # Plain Python ints: only used for slicing and layer sizes, so avoid 0-d
        # (possibly GPU-resident) tensors here.
        self.num_params = [int(n) for n in num_params]
        self.max_params = max(self.num_params)
        self.total_params = sum(self.num_params)
        if symbols is None:
            # Older versions silently read the notebook-global `symbols`; keep that as a
            # fallback, but prefer passing it explicitly to avoid hidden kernel state.
            symbols = globals().get("symbols")
            if symbols is None:
                raise ValueError("`symbols` was not given and no global `symbols` is defined")
        self.symbols = symbols
        self.epsilon = 1e-4  # NOTE(review): not used anywhere in this class
        # Compiled sympy.lambdify callables, built lazily per function index.
        self._eval_funcs = {}

        # Encoder branch 1: (batch, input_channels, n_points) -> (batch, 4, 64).
        self.hidden_x1 = nn.Sequential(
            nn.Conv1d(in_channels=self.input_channels, out_channels=8, kernel_size=7),
            nn.SELU(),
            nn.Conv1d(in_channels=8, out_channels=6, kernel_size=7),
            nn.SELU(),
            nn.Conv1d(in_channels=6, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(64)
        )

        # Dense head on the flattened branch-1 features (4 * 64 = 256) -> 20 features.
        self.hidden_xfc = nn.Sequential(
            nn.Linear(256, 64),
            nn.SELU(),
            nn.Linear(64, 32),
            nn.SELU(),
            nn.Linear(32, 20),
            nn.SELU(),
        )

        # Second conv stack on branch-1 features reshaped to (batch, 2, 128) -> (batch, 2, 4).
        self.hidden_x2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=2, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(16),
            nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(8),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(4),
        )

        self.flatten_layer = nn.Flatten()

        # 28 = 8 flattened conv features (2 channels * 4) + 20 dense features.
        self.hidden_embedding = nn.Sequential(
            nn.Linear(28, 128),
            nn.SELU(),
            nn.Linear(128, 64),
            nn.SELU(),
            nn.Linear(64, self.total_params),
        )

    def evaluate(self, params, index):
        """Evaluate candidate function ``index`` on ``self.x_data`` for a batch of parameters.

        Args:
            params: 2-D tensor; column ``j`` is bound to ``self.symbols[index][j]`` for every
                symbol except the last.
            index: position of the function in ``self.functions``.

        Returns:
            The formula's values with NaNs replaced by 0 and the first two axes swapped, so
            a result broadcast to (n_points, batch) comes back as (batch, n_points).
        """
        symbols = self.symbols[index]
        eval_func = self._eval_funcs.get(index)
        if eval_func is None:
            # lambdify generates and compiles Python source; do it once per function.
            eval_func = sp.lambdify(symbols, self.functions[index], modules="numpy")
            self._eval_funcs[index] = eval_func
        var_values = {str(symbols[j]): params[:, j] for j in range(len(symbols) - 1)}
        # The last symbol is the independent variable. Shape (n_points, 1) broadcasts
        # against the (batch,) parameter columns to give (n_points, batch).
        var_values[str(symbols[-1])] = self.x_data.unsqueeze(1)
        # NOTE(review): with modules="numpy", sympy functions (sin, exp, ...) become numpy
        # calls, which presumably reject CUDA tensors / tensors requiring grad; only pure
        # arithmetic formulas would stay in torch. Confirm against the formula set.
        results = eval_func(**var_values)
        results = torch.nan_to_num(results, 0)
        return results.swapaxes(0, 1)

    def derivative(self, params, index):
        """Compute the derivative channel for candidate function ``index``.

        ``forward`` stacks this result with ``evaluate(params, index)`` along a new last
        axis, so an implementation must return a tensor of the same shape.

        NOTE(review): no implementation exists in this class. It is presumably supplied
        elsewhere (assigned onto the class or an instance); confirm what it should compute.
        """
        raise NotImplementedError(
            "Multi_Func_Channels.derivative is not implemented; provide a function returning "
            "a tensor shaped like evaluate(params, index)"
        )

    def loss_func(self, outputs, targets):
        """Per-sample MSE: mean over points within each channel, then mean over channels.

        Args:
            outputs: (batch, n_points, channels) tensor.
            targets: tensor of the same shape as ``outputs``.

        Returns:
            (batch,) tensor of losses.
        """
        per_channel = torch.mean((targets - outputs) ** 2, dim=1)  # (batch, channels)
        return torch.mean(per_channel, dim=1)

    def forward(self, inputs):
        """Encode ``inputs``, fit every candidate function and pick the best one per sample.

        Args:
            inputs: (batch, n_points, input_channels) tensor. It is also the fitting target,
                compared channel-wise against ``stack([derivative, value], dim=2)``, so the
                loss presumes ``input_channels == 2`` (channel 0 derivative, 1 value).

        Returns:
            Tuple ``(best_out, best_loss, best_func, best_indexes, best_params,
            stacked_preds, stacked_losses, pred_params)``:
            ``best_out[i]`` is sample ``i``'s predicted values under its lowest-loss
            function; ``best_loss`` / ``best_indexes`` are the per-sample minimum loss and
            its function index; ``best_func`` / ``best_params`` are the matching formulas
            and parameter vectors; ``stacked_preds`` / ``stacked_losses`` are stacked over
            functions; ``pred_params`` lists each function's parameter slice.
        """
        # squeeze(dim=2) only removes a size-1 channel axis; kept for compatibility.
        target = inputs.squeeze(dim=2).to(self.device)
        outs = torch.swapaxes(inputs, 1, 2).to(self.device)  # (batch, channels, n_points)
        outs = self.hidden_x1(outs)  # (batch, 4, 64)
        xfc = self.hidden_xfc(torch.reshape(outs, (-1, 256)))
        outs = self.hidden_x2(torch.reshape(outs, (-1, 2, 128)))
        encoded = torch.cat((self.flatten_layer(outs), xfc), 1)
        embedding = self.hidden_embedding(encoded)

        losses = []
        preds = []
        pred_params = []
        start_index = 0
        for f in range(len(self.functions)):
            n_params = self.num_params[f]
            params = embedding[:, start_index:start_index + n_params]
            pred_params.append(params)
            y_vals = self.evaluate(params, f).to(self.device)
            d_vals = self.derivative(params, f).to(self.device)
            output = torch.stack([d_vals, y_vals], dim=2).to(self.device)
            preds.append(y_vals)
            losses.append(self.loss_func(output, target))
            start_index += n_params

        stacked_losses = torch.stack(losses).to(self.device)  # (n_functions, batch)
        # Assumes evaluate returns (batch, n_points): (n_functions, batch, n_points).
        stacked_preds = torch.stack(preds).to(self.device)
        best_loss, best_indexes = torch.min(stacked_losses, dim=0)
        # Advanced indexing pairs sample i with its own best function best_indexes[i].
        sample_positions = torch.arange(best_indexes.shape[0], device=stacked_preds.device)
        best_out = stacked_preds[best_indexes, sample_positions]
        best_index_list = best_indexes.tolist()
        best_func = [self.functions[idx] for idx in best_index_list]
        best_params = [pred_params[f][i] for i, f in enumerate(best_index_list)]
        return best_out, best_loss, best_func, best_indexes, best_params, stacked_preds, stacked_losses, pred_params

In [4]:
DATA_PATH = 'hold_data.pth'

# The file holds non-tensor Python objects (formulas and symbols, presumably sympy
# objects) that the restricted weights-only unpickler rejects; it is the default from
# PyTorch 2.6, so opt out explicitly. This unpickles arbitrary objects: only load
# trusted files.
loaded_data = torch.load(DATA_PATH, weights_only=False)

x_values = loaded_data['x_values'].to(device)
y_values = loaded_data['y_values'].to(device)
derivatives = loaded_data['derivatives'].to(device)
params = loaded_data['param_values'].to(device)
functions = loaded_data['formulas']
symbols = loaded_data['symbols']
num_params = loaded_data['num_params'].to(device)
function_labels = loaded_data['function_labels'].to(device)

  loaded_data = torch.load('hold_data.pth')


In [5]:
# Report the size of every loaded item: tensor shapes, and list lengths for formulas/symbols.
for label, value in (
    ("x_values", x_values.shape),
    ("y_values", y_values.shape),
    ("derivatives", derivatives.shape),
    ("param_values", params.shape),
    ("formulas", len(functions)),
    ("symbols", len(symbols)),
    ("num_params", num_params.shape),
    ("function_labels", function_labels.shape),
):
    print(f"{label}: {value}")

x_values: torch.Size([100])
y_values: torch.Size([10000, 100])
derivatives: torch.Size([10000, 100, 5])
param_values: torch.Size([10000, 5])
formulas: 10
symbols: 10
num_params: torch.Size([10])
function_labels: torch.Size([10000])


In [None]:
# Constructor signature: (functions, num_params, x_data, input_channels, device).
# num_params becomes plain ints because it is only used for slicing and layer sizes;
# 2 is the number of input channels fed to the encoder.
model = Multi_Func_Channels(functions, num_params.tolist(), x_values, 2, device).to(device)