In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import time
from torch.optim.lr_scheduler import StepLR
import sympy as sp

In [3]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [44]:
loaded_data = torch.load('hold_data.pth')

x_values = loaded_data['x_values'].to(device)
y_values = loaded_data['y_values'].to(device)
derivatives = loaded_data['derivatives'].to(device)
params = loaded_data['param_values'].to(device)
functions = loaded_data['formulas']
symbols = loaded_data['symbols']
num_params = loaded_data['num_params'].to(device)
hessians = torch.load('hold_other.pth')['hessians'].to(device)


  loaded_data = torch.load('hold_data.pth')
  hessians = torch.load('hold_other.pth')['hessians'].to(device)


In [45]:
print(f"x_values: {x_values.shape}")
print(f"y_values: {y_values.shape}")
print(f"derivatives: {derivatives.shape}")
print(f"hessians: {hessians.shape}")
print(f"param_values: {params.shape}")
print(f"formulas: {len(functions)}")
print(f"symbols: {len(symbols)}")
print(f"num_params: {num_params.shape}")

x_values: torch.Size([100])
y_values: torch.Size([10000, 100])
derivatives: torch.Size([10000, 100, 5])
hessians: torch.Size([10000, 100, 25])
param_values: torch.Size([10000, 5])
formulas: 10
symbols: 10
num_params: torch.Size([10])


In [73]:
h = hessians.flatten(1,2)
d = derivatives.flatten(1,2)
d = F.pad(d, (0,h.shape[1]-d.shape[1]))
y = F.pad(y_values, (0,h.shape[1]-y_values.shape[1]))
full_data = torch.stack([h,d,y], dim=2)
dataloader = DataLoader(full_data, batch_size=50, shuffle=True)

In [76]:
print(num_params)

tensor([4, 2, 2, 4, 3, 2, 2, 4, 3, 2], device='cuda:4')


In [74]:
class Multi_Func_Channels(nn.Module):
    def __init__(self, functions, num_params, x_data, input_channels, device):
        super().__init__()
        self.device = device
        self.functions = functions
        self.x_data = x_data
        self.input_channels = input_channels
        self.num_params = num_params
        self.max_params = max(num_params)
        self.total_params = sum(self.num_params)
        self.symbols = symbols
        self.epsilon = 1e-4

        self.hidden_x1 = nn.Sequential(
            nn.Conv1d(in_channels=self.input_channels, out_channels=8, kernel_size=7),
            nn.SELU(),
            nn.Conv1d(in_channels=8, out_channels=6, kernel_size=7),
            nn.SELU(),
            nn.Conv1d(in_channels=6, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(64)
        )

        self.hidden_xfc = nn.Sequential(
            nn.Linear(256, 64),
            nn.SELU(),
            nn.Linear(64, 32),
            nn.SELU(),
            nn.Linear(32, 20),
            nn.SELU(),
        )

        self.hidden_x2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=2, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.Conv1d(in_channels=4, out_channels=4, kernel_size=5),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(16),
            nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(8),
            nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3),
            nn.SELU(),
            nn.AdaptiveAvgPool1d(4),
        )

        self.flatten_layer = nn.Flatten()

        self.hidden_embedding = nn.Sequential(
            nn.Linear(28, 128),
            nn.SELU(),
            nn.Linear(128, 64),
            nn.SELU(),
            nn.Linear(64, self.total_params),
        )

    def evaluate(self, params, index):
        symbols = self.symbols[index]
        formula = self.functions[index]
        x = self.x_data
        var_values = {symbols[j]: params[:, j] for j in range(len(symbols)-1)}
        eval_func = sp.lambdify(symbols, formula, modules="numpy")
        results = []
        for xi in x:
            var_values[symbols[-1]] = xi
            np_values = {str(sym): var_values[sym].detach().cpu().numpy() for sym in symbols}
            results.append(eval_func(**np_values))
        tensor_results = [torch.tensor(r, device=device) for r in results]
        return torch.stack(tensor_results, dim=1)
    
    def derivative(self, params, index):
        derivatives = torch.zeros((params.shape[0], params.shape[1], self.max_params))
        print(f"derivatives: {derivatives.shape}")
        params_n = params.clone().detach().requires_grad_(True)
        for p in range(len(symbols[index])):
            plus = params_n.clone()
            minus = params_n.clone()
            plus[:,p] += self.epsilon
            forward_values = self.evaluate(plus, index)
            minus[:, p] -= self.epsilon
            backward_values = self.evaluate(minus, index)
            print(f"forward: {forward_values.shape}")
            derivatives[:, :, p] = (forward_values - backward_values) / (2 * self.epsilon)

    def hessian(self, params, index):
        hessians = torch.zeros(params[0], params[1], self.max_params, self.max_params)
        params_f = params.clone().detach().requires_grad_(True)
        for j in range(len(symbols[index])):
            for k in range(len(symbols[index])):
                plus_plus = params_f.clone()
                plus_minus = params_f.clone()
                minus_plus = params_f.clone()
                minus_minus = params_f.clone()

                plus_plus[:, j] += self.epsilon
                plus_plus[:, k] += self.epsilon

                plus_minus[:, j] += self.epsilon
                plus_minus[:, k] -= self.epsilon

                minus_plus[:, j] -= self.epsilon
                minus_plus[:, k] += self.epsilon

                minus_minus[:, j] -= self.epsilon
                minus_minus[:, k] -= self.epsilon

                forward_forward = self.evaluate(plus_plus, index)
                forward_backward = self.evaluate(plus_minus, index)
                backward_forward = self.evaluate(minus_plus, index)
                backward_backward = self.evaluate(minus_minus,index)
                hessians[:, :, j, k] = (forward_forward - forward_backward - backward_forward + backward_backward) / (4 * epsilon **2)


    def forward(self, inputs):
        target = inputs.squeeze(dim=2)
        outs = torch.swapaxes(inputs, 1, 2).to(self.device)
        outs = self.hidden_x1(outs)
        xfc = torch.reshape(outs, (-1, 256))
        xfc = self.hidden_xfc(xfc)

        outs = torch.reshape(outs, (-1, 2, 128))
        outs = self.hidden_x2(outs)
        cnn_flat = self.flatten_layer(outs)
        encoded = torch.cat((cnn_flat, xfc), 1)
        embedding = self.hidden_embedding(encoded)

        loss_func = nn.MSELoss()
        start_index = 0
        losses = []
        outputs = []

        for f in range(len(self.functions)):
            params = embedding[:, start_index:start_index+self.num_params[f]]
            print(f"params: {params.shape}")
            y_vals = self.evaluate(params, f)
            print(f"y_vals: {y_vals.shape}")
            d_vals = self.derivative(params, f)
            print(f"d_vals: {d_vals.shape}")
            h_vals = self.hessian(params, f)
            print(f"h_vals: {h_vals.shape}")
            # PAUSED HERE NEED TO FIGURE OUT DATA STACKING
            '''outputs.append(output)
            loss = loss_func(output, target)
            losses.append(loss)
            start_index += self.functions[1][f]        
        best_index = torch.argmin(torch.tensor(losses))
        best_func = self.functions[0][best_index]
        best_loss, best_out = losses[best_index], outputs[best_index]

        return best_out, best_loss, best_func, outputs, losses'''

In [75]:
model = Multi_Func_Channels(functions=functions, num_params=num_params, x_data=x_values, input_channels=3, device=device).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
loss_func = nn.MSELoss()

for batch in dataloader:
    print(model(batch))
    break

params: torch.Size([50, 4])
y_vals: torch.Size([50, 100])
derivatives: torch.Size([50, 4, 4])
forward: torch.Size([50, 100])


RuntimeError: The expanded size of the tensor (4) must match the existing size (100) at non-singleton dimension 1.  Target sizes: [50, 4].  Tensor sizes: [50, 100]