In [151]:
import torch
import torch.nn as nn

In [266]:
class MyLayer(nn.Module):
    """Bias-free linear projection followed by a different activation per output unit.

    The input is multiplied by a learnable ``(input_size, len(functions))``
    weight matrix; output unit ``j`` of the projection is then passed through
    ``functions[j]``.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Currently unused; the output width is always
            ``len(functions)``. Kept for interface compatibility.
        functions: Sequence of tensor callables (e.g. ``torch.sqrt``,
            ``torch.sin``), one per output unit. NOTE(review): stored as a plain
            list, so any ``nn.Module`` entries would not be registered as
            submodules.
    """

    def __init__(self, input_size, output_size, functions) -> None:
        super().__init__()
        self.functions = functions
        self.n_functions = len(functions)
        # All-ones init: deterministic, but every column starts identical.
        self.weights = nn.Parameter(torch.ones(input_size, self.n_functions))

        # Per-function outputs from the most recent forward pass (list of
        # tensors). NOTE(review): holding these keeps the autograd graph of the
        # last call alive until the next forward.
        self.output = None

    def forward(self, x):
        """Apply the projection and the per-unit activations.

        Args:
            x: Tensor of shape ``(..., input_size)`` -- a single 1-D sample or
                a batch with any number of leading dimensions.

        Returns:
            Tensor of shape ``(..., n_functions)``.
        """
        projected = torch.matmul(x, self.weights)

        # Index the *last* (feature) axis. Plain ``projected[i]`` would select
        # sample ``i`` of a batch instead of output unit ``i``.
        self.output = [
            self.functions[i](projected[..., i]) for i in range(self.n_functions)
        ]

        # Stack along the feature axis so batched results keep the (..., n)
        # layout; for 1-D input this equals stacking along dim 0.
        return torch.stack(self.output, dim=-1)

In [267]:
class MyNetwork(nn.Module):
    """Thin wrapper that owns a single ``MyLayer`` and delegates to it.

    Args:
        input_size: Forwarded unchanged to ``MyLayer``.
        output_size: Forwarded unchanged to ``MyLayer``.
        functions: Forwarded unchanged to ``MyLayer``.
    """

    def __init__(self, input_size, output_size, functions) -> None:
        super().__init__()
        inner = MyLayer(input_size, output_size, functions)
        self.layer = inner

    def forward(self, x):
        """Return the wrapped layer's output for ``x``."""
        result = self.layer(x)
        return result

    def get_weights(self):
        """Return the wrapped layer's ``weights`` parameter."""
        inner = self.layer
        return inner.weights

In [268]:
# Train the network to map a constant all-ones input to an all-ones target.
INPUT_SIZE = 10
N_STEPS = 1000
LEARNING_RATE = 0.01

functions = [torch.sqrt, torch.sin]
# Output width follows the number of activation functions (2 here).
net = MyNetwork(INPUT_SIZE, len(functions), functions)
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Input and target never change, so build them once instead of every step.
x_train = torch.ones(INPUT_SIZE)
y_train = torch.ones(len(functions))

# NOTE: torch.sqrt yields NaN if its projected input ever goes negative.
for _ in range(N_STEPS):
    optimizer.zero_grad()
    y_hat = net(x_train)
    loss = criterion(y_hat, y_train)
    loss.backward()
    optimizer.step()

# Show the final training loss as the cell's output.
loss.item()

tensor([ 3.1623, -0.5440], grad_fn=<StackBackward0>)
tensor([ 3.1569, -0.4311], grad_fn=<StackBackward0>)
tensor([ 3.1515, -0.3113], grad_fn=<StackBackward0>)
tensor([ 3.1460, -0.1908], grad_fn=<StackBackward0>)
tensor([ 3.1406, -0.0750], grad_fn=<StackBackward0>)
tensor([3.1352, 0.0321], grad_fn=<StackBackward0>)
tensor([3.1297, 0.1285], grad_fn=<StackBackward0>)
tensor([3.1243, 0.2137], grad_fn=<StackBackward0>)
tensor([3.1189, 0.2880], grad_fn=<StackBackward0>)
tensor([3.1134, 0.3526], grad_fn=<StackBackward0>)
tensor([3.1079, 0.4086], grad_fn=<StackBackward0>)
tensor([3.1025, 0.4572], grad_fn=<StackBackward0>)
tensor([3.0970, 0.4996], grad_fn=<StackBackward0>)
tensor([3.0916, 0.5367], grad_fn=<StackBackward0>)
tensor([3.0861, 0.5692], grad_fn=<StackBackward0>)
tensor([3.0806, 0.5980], grad_fn=<StackBackward0>)
tensor([3.0751, 0.6235], grad_fn=<StackBackward0>)
tensor([3.0696, 0.6463], grad_fn=<StackBackward0>)
tensor([3.0641, 0.6666], grad_fn=<StackBackward0>)
tensor([3.0586, 0.684

In [269]:
print([*net.parameters()])

[Parameter containing:
tensor([[0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955],
        [0.1000, 0.7955]], requires_grad=True)]
