In [1]:
import torch
import torch.nn as nn

In [2]:
class MyLayer(nn.Module):
    """Linear projection followed by a separate function for each output channel.

    The input is multiplied by a learnable ``(input_size, n_functions)`` weight
    matrix (initialised to ones); channel ``i`` of the product is then passed
    through ``functions[i]``.

    Args:
        input_size: Size of the last dimension of the input tensor.
        output_size: Not used by this layer; the number of output channels is
            always ``len(functions)``. Kept so existing callers keep working.
        functions: Indexable collection of callables. ``functions[i]`` is
            applied to channel ``i`` of the projection; all results must have
            the same shape so they can be stacked.
    """

    def __init__(self, input_size, output_size, functions) -> None:
        super().__init__()
        self.functions = functions
        self.n_functions = len(functions)
        self.weights = nn.Parameter(torch.ones(input_size, self.n_functions))

        # List of per-channel results from the most recent forward pass
        # (``None`` until ``forward`` has been called).
        self.output = None

    def forward(self, x):
        """Project ``x`` and apply one function per output channel.

        Args:
            x: Tensor of shape ``(input_size,)`` or ``(..., input_size)``.

        Returns:
            Tensor of shape ``(n_functions,)`` or ``(..., n_functions)`` whose
            last dimension holds ``functions[i]`` applied to channel ``i``.
        """
        projected = torch.matmul(x, self.weights)

        # Select channels along the LAST dimension. Indexing the first
        # dimension (``projected[i]``) is only equivalent for 1-D input; with a
        # leading batch dimension it would pick batch rows instead of channels.
        self.output = [
            self.functions[idx](projected[..., idx])
            for idx in range(self.n_functions)
        ]

        # Stack on the last dimension so the result lines up with the input's
        # batch layout; for 1-D input this is the same as stacking on dim 0.
        return torch.stack(self.output, dim=-1)

In [3]:
class MyNetwork(nn.Module):
    """Network consisting of a single ``MyLayer`` stored as ``self.layer``."""

    def __init__(self, input_size, output_size, functions) -> None:
        super().__init__()
        # All three arguments are handed to the wrapped layer unchanged.
        self.layer = MyLayer(
            input_size=input_size,
            output_size=output_size,
            functions=functions,
        )

    def forward(self, x):
        """Return the wrapped layer's output for ``x``."""
        layer_out = self.layer(x)
        return layer_out

    def get_weights(self):
        """Return the wrapped layer's ``weights`` parameter (the object itself, not a copy)."""
        wrapped = self.layer
        return wrapped.weights

In [4]:
# Train the network with plain SGD on one fixed example:
# an all-ones input vector with an all-ones target.
INPUT_SIZE = 10
N_STEPS = 1000
LEARNING_RATE = 0.01

functions = [torch.sqrt, torch.sin]
net = MyNetwork(INPUT_SIZE, len(functions), functions)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)

# The training pair never changes, so build it once instead of every step.
inputs = torch.ones(INPUT_SIZE)
targets = torch.ones(len(functions))

for i in range(N_STEPS):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    y_hat = net(inputs)
    loss = criterion(y_hat, targets)
    loss.backward()
    optimizer.step()

In [5]:
# Print every learnable parameter of the network after training.
trained_params = list(net.parameters())
print(trained_params)

[Parameter containing:
tensor([[0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955]], requires_grad=True)]
