In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from functools import reduce

import matplotlib.pyplot as plt
import numpy as np
import time
from torch.optim.lr_scheduler import StepLR, LambdaLR
import sympy as sp

In [3]:
# GPU ordinal this notebook prefers. The original picked "cuda:3" whenever any
# GPU was visible, which names a non-existent device on hosts with fewer than
# four GPUs and only fails later, at the first tensor/model transfer.
PREFERRED_GPU_INDEX = 3
if torch.cuda.is_available():
    # device_count() > index guarantees that "cuda:<index>" actually exists;
    # otherwise fall back to the first GPU.
    _gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

In [4]:
class Function_Selector(nn.Module):
    def __init__(self, functions, num_params, symbols, x_data, device):
        super().__init__()
        self.device = device
        self.functions = functions
        self.x_data = x_data.to(self.device).requires_grad_(True)
        self.num_params = num_params
        self.max_params = max(num_params)
        self.total_params = sum(self.num_params)
        self.symbols = symbols

        self.hidden_x1 = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=8, kernel_size=7, padding=3),
            nn.LayerNorm([8, 100]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=8, out_channels=6, kernel_size=7, padding=3),
            nn.LayerNorm([6, 100]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=6, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 100]),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool1d(64)
        )

        self.hidden_xfc = nn.Sequential(
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LayerNorm(32),
            nn.LeakyReLU(),
            nn.Linear(32, 20),
            nn.LayerNorm(20),
            nn.LeakyReLU(),
        )

        self.hidden_x2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),           
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),
            nn.LeakyReLU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5, padding=2),
            nn.LayerNorm([4, 32]),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool1d(16),
            nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool1d(8),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool1d(4),
        )

        self.flatten_layer = nn.Flatten()

        self.hidden_embedding = nn.Sequential(
            nn.Linear(28, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, self.total_params),
        )

    def sympy_to_torch(self, expr, symbols):
        """Turn a SymPy expression into a callable that evaluates it with torch ops.

        Args:
            expr: SymPy expression built from the entries of ``symbols``,
                numeric constants and the supported operations listed in
                ``torch_funcs`` below.
            symbols: Ordered list of the SymPy symbols of ``expr``. The
                returned callable expects one positional tensor per symbol,
                in this same order.

        Returns:
            A function ``torch_func(*tensors)`` that walks the expression tree
            and returns the resulting tensor (shape follows torch
            broadcasting of the supplied tensors).

        Raises:
            ValueError: When evaluating an expression that contains an
                operation without a torch counterpart, a non-expression
                node, or a symbol that is not in ``symbols``.
        """
        torch_funcs = {
            sp.Add: lambda *args: reduce(torch.add, args),
            sp.Mul: lambda *args: reduce(torch.mul, args),
            sp.Pow: torch.pow,
            sp.sin: torch.sin,
            sp.cos: torch.cos,
            # Additional elementary functions; these used to raise
            # "Unsupported operation".
            sp.tan: torch.tan,
            sp.exp: torch.exp,
            sp.log: torch.log,
            sp.sinh: torch.sinh,
            sp.cosh: torch.cosh,
            sp.tanh: torch.tanh,
            sp.Abs: torch.abs,
        }

        def torch_func(*args):
            def _eval(ex):
                if isinstance(ex, sp.Symbol):
                    return args[symbols.index(ex)]
                # Symbolic constants such as pi and E are NumberSymbol, not
                # Number, so they must be matched explicitly before the
                # generic Expr branch.
                if isinstance(ex, (sp.Number, sp.NumberSymbol)):
                    if not args:
                        # No tensor to copy dtype/device from.
                        return torch.tensor(float(ex))
                    return torch.full_like(args[0], float(ex))
                if isinstance(ex, sp.Expr):
                    op = type(ex)
                    if op in torch_funcs:
                        return torch_funcs[op](*[_eval(arg) for arg in ex.args])
                    raise ValueError(f"Unsupported operation: {op}")
                raise ValueError(f"Unsupported type: {type(ex)}")

            return _eval(expr)

        return torch_func

    def evaluate(self, params, index):
        symbols = self.symbols[index]
        formula = self.functions[index]
        x = self.x_data
        torch_func = self.sympy_to_torch(formula, symbols)
        var_values = [params[:, j] for j in range(len(symbols)-1)] + [x.unsqueeze(1)]
        results = torch_func(*var_values)
        return results.swapaxes(0, 1)

    def forward(self, inputs):
        inputs = inputs.requires_grad_(True)
        outs = inputs.unsqueeze(1).to(self.device)
        outs = self.hidden_x1(outs)
        xfc = torch.reshape(outs, (-1, 256))
        xfc = self.hidden_xfc(xfc)

        outs = self.hidden_x2(outs)
        cnn_flat = self.flatten_layer(outs)
        encoded = torch.cat((cnn_flat, xfc), 1)
        embedding = self.hidden_embedding(encoded)

        start_index = 0
        losses = []
        outputs = []
        all_params = []
        
        for f in range(len(self.functions)):
            params = embedding[:, start_index:start_index+self.num_params[f]]
            all_params.append(F.pad(params, (0, self.max_params-self.num_params[f])))
            output = self.evaluate(params, f).to(self.device)
            outputs.append(output)
            loss = torch.mean(((inputs - output) ** 2), dim=1)
            losses.append(loss)
            start_index += self.num_params[f]        
        stacked_losses = torch.stack(losses).to(self.device)
        stacked_preds = torch.stack(outputs).to(self.device)
        weights = F.softmax(-stacked_losses, dim=0)
        best_out = torch.sum(weights.unsqueeze(2) * stacked_preds, dim=0)
        best_loss = torch.sum(weights * stacked_losses, dim=0)        
        best_func = weights.t()
        best_params = torch.sum(weights.unsqueeze(2) * torch.stack(all_params), dim=0)
        return best_out, best_loss, best_func, weights, best_params, outputs, losses, all_params