In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import time
from torch.optim.lr_scheduler import StepLR
import sympy as sp

In [3]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [4]:
def loss_func(self, output, target):
        target_range = torch.max(target, dim=-1, keepdim=True)[0] - torch.min(target, dim=-1, keepdim=True)[0]
        target_range = torch.clamp(target_range, min=1e-6).squeeze(-1)
        loss = ((target - output)**2)/target_range
        l2_reg = sum(p.pow(2.0).sum() for p in self.parameters())
        l1_reg = sum(p.abs().sum() for p in self.parameters())
        total_loss = loss + 0.01*l2_reg + 0.01*l1_reg
        return total_loss

def model_evaluate_function(self, params, symbols, formula):
    var_values = {symbols[j]: params[:, j] for j in range(len(symbols))}
    lambda_func = sp.lambdify(symbols, formula, modules=['numpy'])
    np_values = {str(sym): var_values[sym].detach().cpu().numpy() for sym in symbols}
    with np.errstate(all='ignore'): #need to fix this at some point
        evaluated = lambda_func(**np_values)
    evaluated = np.nan_to_num(evaluated, nan=0.0)
    return torch.tensor(evaluated, dtype=torch.float32, requires_grad=True)

In [5]:
class Multi_Func(nn.Module):
    def __init__(self, functions, num_params, symbols, input_channels, device):
        super().__init__()
        self.device = device
        self.functions = functions
        self.num_params = num_params
        self.symbols = symbols
        self.input_channels = input_channels
        self.params = sum(self.num_params)

        self.hidden_x1 = nn.Sequential(
            nn.Conv1d(in_channels=self.input_channels, out_channels=8, kernel_size=1),
            nn.SELU(),
            nn.Conv1d(in_channels=8, out_channels=6, kernel_size=1),
            nn.SELU(),
            nn.Conv1d(in_channels=6, out_channels=4, kernel_size=1),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(64)
        )

        self.hidden_xfc = nn.Sequential(
            nn.Linear(256, 64),
            nn.SELU(),
            nn.Linear(64, 32),
            nn.SELU(),
            nn.Linear(32, 20),
            nn.SELU(),
        )

        self.hidden_x2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=2, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(16),
            nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(8),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(4),
        )

        self.flatten_layer = nn.Flatten()

        self.hidden_embedding = nn.Sequential(
            nn.Linear(28, 128),
            nn.SELU(),
            nn.Linear(128, 64),
            nn.SELU(),
            nn.Linear(64, self.params),
        )

    def loss_func(self, output, target):
        target_range = torch.max(target, dim=-1, keepdim=True)[0] - torch.min(target, dim=-1, keepdim=True)[0]
        target_range = torch.clamp(target_range, min=1e-6).squeeze(-1)
        #print(f"target_range: {target_range}")
        loss = ((target - output)**2)/target_range
        l2_reg = sum(p.pow(2.0).sum() for p in self.parameters())
        l1_reg = sum(p.abs().sum() for p in self.parameters())
        total_loss = loss + 0.01*l2_reg + 0.01*l1_reg
        return total_loss

    def evaluate_function(self, params, symbols, formula):
        var_values = {symbols[j]: params[:, j] for j in range(len(symbols))}
        lambda_func = sp.lambdify(symbols, formula, modules=['numpy'])
        np_values = {str(sym): var_values[sym].detach().cpu().numpy() for sym in symbols}
        with np.errstate(all='ignore'): #need to fix this at some point
            evaluated = lambda_func(**np_values)
        evaluated = np.nan_to_num(evaluated, nan=0.0)
        return torch.tensor(evaluated, dtype=torch.float32, requires_grad=True)

    def forward(self, x, n=-1):
        x = x.unsqueeze(1).unsqueeze(2)
        out = self.hidden_x1(x)
        xfc = torch.reshape(out, (n, 256))
        xfc = self.hidden_xfc(xfc)

        out = torch.reshape(out, (n, 2, 128))
        out = self.hidden_x2(out)
        cnn_flat = self.flatten_layer(out)
        encoded = torch.cat((cnn_flat, xfc), 1)
        embedding = self.hidden_embedding(encoded)
        
        start_index = 0
        losses = []
        outputs = []
        #print(self.params)
        #print(f"whole embedding: {embedding.shape}")
        
        for f in range(len(self.functions)):
            #print(f"embedding: {embedding[:, start_index:start_index+self.num_params[f]].shape}")
            output = self.evaluate_function(
                embedding[:, start_index:start_index+self.num_params[f]],
                self.symbols[f],
                self.functions[f]
            ).to(device)
            outputs.append(output)
            loss = self.loss_func(output, x[:,0,0])
            losses.append(loss)
            start_index += self.num_params[f]  
        
        #print(f"loss len: {len(losses)}")      
        '''best_index = torch.argmin(torch.tensor(losses))
        best_func = self.functions[best_index]
        best_loss, best_out = losses[best_index], outputs[best_index]'''

        stacked_outputs = torch.stack(outputs)
        stacked_losses = torch.stack(losses)
        #print(f"stacked_outputs shape: {stacked_outputs.shape}")
        #print(f"stacked_losses shape: {stacked_losses.shape}")
        best_loss, best_indexes = torch.min(stacked_losses, dim=0) 
        #print(f"best_loss shape: {best_loss.shape}")
        #print(f"best_indexes shape: {best_indexes.shape}")
        #print(f"best indexes 0: {best_indexes[0]}")
        best_out = stacked_outputs[best_indexes, -1]
        best_func = [self.functions[idx] for idx in best_indexes]
        #print(f"best_func for 0: {best_func[0]}")

        return best_out, best_loss, best_func, outputs, losses

In [6]:
class Encoder(nn.Module):
    def __init__(self, functions, num_params, symbols, input_channels, device):
        super().__init__()
        self.device = device
        self.functions = functions
        self.num_params = num_params
        self.symbols = symbols
        self.input_channels = input_channels
        self.params = sum(self.num_params)
        self.sequence_length=96
        self.input_channel=2
        self.cov1d_size=128

        self.cov1d = nn.Conv1d(self.input_channel, self.cov1d_size, 3, stride=1, padding=1)
        self.flattened_size = self.cov1d_size * self.sequence_length
        self.dense = nn.Linear(self.flattened_size, self.num_params)
        
        self.selu_1 = nn.SELU()

    def forward(self, x):
        print(f"out initial: {x.shape}")
        out = self.cov1d(x)
        print(f"out after cov1d: {out.shape}")
        out = out.reshape(out.size(0), -1)
        print(f"out before dense: {out.shape}")
        out = self.dense(out)
        out = self.selu_1(out)
        
        start_index = 0
        losses = []
        outputs = []
        
        for f in range(len(self.functions)):
            output = model_evaluate_function(
                out[:, start_index:start_index+self.num_params[f]],
                self.symbols[f],
                self.functions[f]
            ).to(device)
            outputs.append(output)
            loss = loss_func(output, x[:,0,0])
            losses.append(loss)
            start_index += self.num_params[f]  
        
        stacked_outputs = torch.stack(outputs)
        stacked_losses = torch.stack(losses)
        best_loss, best_indexes = torch.min(stacked_losses, dim=0) 
        best_out = stacked_outputs[best_indexes, -1]
        best_func = [self.functions[idx] for idx in best_indexes]

        return best_out, best_loss, best_func, outputs, losses

In [134]:
class Multi_Func_ThreeChannels(nn.Module):
    def __init__(self, functions, num_params, symbols, device):
        super().__init__()
        self.device = device
        self.formulas = functions
        self.num_params = num_params
        self.symbols = symbols
        self.params = sum(self.num_params)

        self.hidden_x1 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=8, kernel_size=1),
            nn.SELU(),
            nn.Conv1d(in_channels=8, out_channels=6, kernel_size=1),
            nn.SELU(),
            nn.Conv1d(in_channels=6, out_channels=4, kernel_size=1),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(64)
        )

        self.hidden_xfc = nn.Sequential(
            nn.Linear(256, 64),
            nn.SELU(),
            nn.Linear(64, 32),
            nn.SELU(),
            nn.Linear(32, 20),
            nn.SELU(),
        )

        self.hidden_x2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=4),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(16),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(8),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(4),
        )

        self.flatten_layer = nn.Flatten()

        self.hidden_embedding = nn.Sequential(
            nn.Linear(28, 128),
            nn.SELU(),
            nn.Linear(128, 64),
            nn.SELU(),
            nn.Linear(64, self.params),
        )

    def loss_func(self, outputs, targets):
        losses = []
        for output, target in zip(outputs, targets):
            loss = (torch.abs(target - output) ** 2)
            losses.append(loss.mean(dim=1))

        total_losses = torch.stack(losses).mean(dim=1)
        l2_reg = sum(p.pow(2.0).sum() for p in self.parameters())
        l1_reg = sum(p.abs().sum() for p in self.parameters())
        total_losses += 0.01 * l2_reg + 0.01 * l1_reg
        return total_losses

    def evaluate_function(self, params, symbols, formula):
        var_values = {symbols[j]: params[:, j] for j in range(len(symbols))}
        lambda_func = sp.lambdify(symbols, formula, modules=['numpy'])
        np_values = {str(sym): var_values[sym].detach().cpu().numpy() for sym in symbols}
        with np.errstate(all='ignore'):
            evaluated = lambda_func(**np_values)
        evaluated = np.nan_to_num(evaluated, nan=10000)
        return torch.tensor(evaluated, dtype=torch.float32, requires_grad=True)

    def compute_derivative(self, params, index, epsilon=1e-6):
        batch_size = params.shape[0]
        num_params = params.shape[1]
        gradients = torch.zeros(batch_size, num_params)
        param_tensor = params.clone().detach().requires_grad_(True)

        for j in range(num_params):
            perturbed_params_pos = param_tensor.clone().to(self.device)
            perturbed_params_neg = param_tensor.clone().to(self.device)
            perturbed_params_pos[:, j] += epsilon
            forward_value = self.evaluate_function(perturbed_params_pos, self.symbols[index], self.formulas[index])
            perturbed_params_neg[:, j] -= epsilon
            backward_value = self.evaluate_function(perturbed_params_neg, self.symbols[index], self.formulas[index])
            gradients[:, j] = (forward_value - backward_value) / (2 * epsilon)
        return gradients

    def compute_hessians(self, params, index, epsilon=1e-6):
        batch_size = params.shape[0]
        num_params = params.shape[1]
        hessians = torch.zeros((batch_size, num_params, num_params))
        param_tensor = params.clone().detach().requires_grad_(True)
        for j in range(num_params):
            for k in range(num_params):
                perturbed_params = param_tensor.clone().to(self.device)
                perturbed_params[:, j] += epsilon
                perturbed_params[:, k] += epsilon
                f_plus_plus = self.evaluate_function(perturbed_params, self.symbols[index], self.formulas[index])
                perturbed_params[:, k] -= 2 * epsilon
                f_plus_minus = self.evaluate_function(perturbed_params, self.symbols[index], self.formulas[index])
                perturbed_params[:, j] -= 2 * epsilon
                perturbed_params[:, k] += 2 * epsilon
                f_minus_plus = self.evaluate_function(perturbed_params, self.symbols[index], self.formulas[index])
                perturbed_params[:, k] -= 2 * epsilon
                f_minus_minus = self.evaluate_function(perturbed_params, self.symbols[index], self.formulas[index])
                hessians[:, j, k] = (f_plus_plus - f_plus_minus - f_minus_plus + f_minus_minus) / (4 * epsilon**2)
        return hessians

    def forward(self, x):
        out = self.hidden_x1(x.swapaxes(1, 2))
        xfc = torch.reshape(out, (-1, 256))
        xfc = self.hidden_xfc(xfc)

        out = torch.reshape(out, (-1, 2, 128))
        out = self.hidden_x2(out)
        cnn_flat = self.flatten_layer(out)
        encoded = torch.cat((cnn_flat, xfc), 1)
        embedding = self.hidden_embedding(encoded)
        
        start_index = 0
        para = embedding[:, self.num_params[0]].to(self.device)
        losses = []
        outputs = []
        preds = []

        for f in range(len(self.formulas)):
            para = embedding[:, start_index:start_index+self.num_params[f]]
            start_index += self.num_params[f]
            output = self.evaluate_function(para, self.symbols[f], self.formulas[f]).to(self.device)
            outputs.append(output)
            der = self.compute_derivative(para, f).to(self.device)
            hess = self.compute_hessians(para, f).to(self.device)
            
            hess = torch.flatten(hess, start_dim=1)
            output = output.unsqueeze(1)
            output = F.pad(output, (0, hess.size(1) - output.size(1)))
            der = F.pad(der, (0, hess.size(-1) - der.size(-1)))

            pred = torch.stack([output, der, hess], dim=2).to(self.device)
            pred = F.pad(pred, (0, 0, 0, (x.size(1) - pred.size(1))))
            preds.append(pred)

            loss = self.loss_func(pred, x)
            losses.append(loss)  

        stacked_outputs = torch.stack(outputs).to(self.device)
        stacked_losses = torch.stack(losses).to(self.device)
        stacked_preds = torch.stack(preds).to(self.device)
        best_loss, best_indexes = torch.min(stacked_losses, dim=0)
        best_out = stacked_outputs[best_indexes, -1]
        best_func = [self.formulas[idx] for idx in best_indexes]

        return best_out, best_loss, best_func, stacked_outputs, stacked_losses, stacked_preds


In [95]:
loaded_data = torch.load('hold_data.pth')

results = loaded_data['results']
formulas = loaded_data['formulas']
symbols = loaded_data['symbols']
params = loaded_data['params']
num_params = loaded_data['num_params']
flat_data = loaded_data['flattened_data']

dataloader = DataLoader(flat_data, batch_size=1000, shuffle=True)

  loaded_data = torch.load('hold_data.pth')


In [70]:
r = np.random.randint(10000)
print(flat_data[r, :, 1])
print(flat_data[r, :, 2])

tensor([ 0.0168, -0.0379,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000])
tensor([-0.0405,  0.0218,  0.0218,  0.0067,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.00

In [13]:
def numerical_derivative(tensor, epsilon=1e-5):
    shape = tensor.shape
    flat_tensor = tensor.view(-1)
    grad = torch.zeros_like(flat_tensor)
    
    for i in range(flat_tensor.numel()):
        original_value = flat_tensor[i].item()
        
        flat_tensor[i] = original_value + epsilon
        perturbed_tensor = flat_tensor.view(shape)
        f_plus = perturbed_tensor
        
        flat_tensor[i] = original_value - epsilon
        perturbed_tensor = flat_tensor.view(shape)
        f_minus = perturbed_tensor
        
        grad[i] = (f_plus - f_minus).mean() / (2 * epsilon)
        flat_tensor[i] = original_value
    
    return grad.view(shape)

In [11]:
def training_loss_func(model, output, target):
    target_max = torch.max(target, dim=-1, keepdim=True)[0]
    target_min = torch.min(target, dim=-1, keepdim=True)[0]
    target_range = torch.clamp(target_max - target_min, min=1e-6).squeeze(-1)
    
    mse_loss = torch.mean((output - target) ** 2, dim=-1)
    normalized_loss = mse_loss / target_range
    
    l2_reg = sum(p.pow(2.0).sum() for p in model.parameters())
    l1_reg = sum(p.abs().sum() for p in model.parameters())
    
    output_derivative = numerical_derivative(output)
    target_derivative = numerical_derivative(target)
    
    derivative_diff = torch.mean((output_derivative - target_derivative) ** 2, dim=-1)
    
    total_loss = torch.mean(normalized_loss) + 0.01 * l2_reg + 0.01 * l1_reg + 0.1 * torch.mean(derivative_diff)
    return total_loss


In [137]:
model = Multi_Func_ThreeChannels(functions=formulas, num_params=num_params, symbols=symbols, device=device).to(device)

#train_loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

epochs = 50
for epoch in range(epochs):
    start_time = time.time()
    train_loss = 0.0
    total_num = 1
    model.train()
    
    for train_batch in dataloader:
        train_batch = train_batch.to(device)
        optimizer.zero_grad()
        best_out, best_loss, best_func, outputs, losses, preds = model(train_batch)
        loss = training_loss_func(model, best_out, train_batch[:, 0, 0]) #train_loss_func(best_out, train_batch[:,0,0])
        #print(f"best_func: {best_func[0]}")
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * best_out.shape[0]
        total_num += best_out.shape[0]
    scheduler.step()
    train_loss /= total_num
    print(f"epoch : {epoch}/{epochs}, loss = {train_loss:.8f}")
    print(f"--- {time.time() - start_time} seconds ---")

epoch : 0/50, loss = 10.31608098
--- 31.61694312095642 seconds ---
epoch : 1/50, loss = 4.49397094
--- 26.172062635421753 seconds ---
epoch : 2/50, loss = 3.22246590
--- 27.048153162002563 seconds ---
epoch : 3/50, loss = 2.59375951
--- 26.515844345092773 seconds ---


In [97]:
model.eval()
best_out, best_loss, best_func, stacked_outputs, stacked_losses, stacked_preds = model(flat_data[0:500].to(device))
r = np.random.randint(500)
print(best_out)
print(flat_data[r, 0, 0])


tensor([0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989, 0.3989,
        0.3989, 0.3989, 0.3989, 0.3989, 