In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


# Set the dimensions of the vectors and matrices

In [None]:
n = 3  # Dimension of vector p
m = 4  # Dimension of vector v
l = 5  # Dimension of output J(p)v

batch_size = 200  # Batch size for training

# Generate training data

In [None]:
num_samples = 2000
p = torch.randn(num_samples, n)
v = torch.randn(num_samples, m)

# Define the known mapping J'(p) as a function depending nonlinearly on p
def J_prime(p):
    # Replace with your own implementation of J_prime
    # This is just an example
    matrix = torch.zeros(l, m)
    for i in range(l):
        for j in range(m):
            matrix[i, j] = p[i % n] * p[j % n]
    return matrix

def f(p, v):
    return torch.matmul(J_prime(p), v.unsqueeze(-1)).squeeze(-1)

# Generate fpv using J'(p) and v
fpv = torch.stack([f(p[i], v[i]) for i in range(num_samples)], dim=0)
fpv = fpv.unsqueeze(1)  # Add the extra dimension

In [None]:
# Define the architecture of the neural network
class FunctionApproximator(nn.Module):
    def __init__(self):
        super(FunctionApproximator, self).__init__()
        self.create_mat = nn.Sequential(
            nn.Linear(n, n),
            nn.ReLU(),
            nn.Linear(n, l * m),
            nn.ReLU(),
        )
        for module in self.create_mat.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, p, v):
        output = self.create_mat(p)

        # Reshape the output into the desired matrix shape
        self.matrix = output.view(-1, l, m)

        # Multiply the matrix by v
        output = torch.bmm(self.matrix, v.unsqueeze(-1)).squeeze(-1)

        return output

model = FunctionApproximator()

# Define the loss function
criterion = nn.MSELoss()

# Define the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Convert data to batches
num_batches = num_samples // batch_size
p_batches = torch.split(p[:num_batches * batch_size], batch_size)
v_batches = torch.split(v[:num_batches * batch_size], batch_size)
fpv_batches = torch.split(fpv[:num_batches * batch_size], batch_size)

# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in range(num_batches):
        # Forward pass
        outputs = model(p_batches[batch], v_batches[batch])

        # Compute the loss
        loss = criterion(outputs, fpv_batches[batch])
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Print progress
    if (epoch + 1) % 100 == 0:
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}")

Epoch 100/1000, Loss: 6.786809778213501
Epoch 200/1000, Loss: 6.786109733581543
Epoch 300/1000, Loss: 6.785980892181397
Epoch 400/1000, Loss: 6.785937738418579
Epoch 500/1000, Loss: 6.785919952392578
Epoch 600/1000, Loss: 6.785910415649414
Epoch 700/1000, Loss: 6.785904598236084
Epoch 800/1000, Loss: 6.785901021957398
Epoch 900/1000, Loss: 6.7858963966369625
Epoch 1000/1000, Loss: 6.785892152786255


In [None]:
# Evaluation
with torch.no_grad():
    # Generate test data
    test_p = torch.randn(1, n)
    test_v = torch.randn(1, m)

    # Apply the trained model to compute J(p)v
    Jpv = model(test_p, test_v)

    # Print the result
    print("Approximated J(p)v:")
    print(Jpv.squeeze())

    print("Real f(p,v):")
    print(f(test_p[0],test_v[0]))

    # Extract the internal matrix produced for specific p
    print(model.matrix)



Approximated J(p)v:
tensor([0., 0., 0., 0., 0.])
Real f(p,v):
tensor([-0.1103,  0.0423,  0.0523, -0.1103,  0.0423])
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
