## Notebook: differentiating a neural network with respect to its inputs

In [None]:
import torch
import torch.nn as nn

# Define the sub-networks for your architecture
class PositionEncoder(nn.Module):
    """Project position coordinates into the shared embedding space.

    A single affine layer (``nn.Linear``) mapping ``input_dim`` features to
    ``embedding_dim`` features; no activation is applied.
    """

    def __init__(self, input_dim: int, embedding_dim: int):
        super().__init__()
        # The attribute name ``linear`` fixes the state_dict keys
        # (``linear.weight`` / ``linear.bias``); keep it so checkpoints load.
        self.linear = nn.Linear(in_features=input_dim, out_features=embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the affine projection of ``x`` (last dim input_dim -> embedding_dim)."""
        projected = self.linear(x)
        return projected

class PhysicsEncoder(nn.Module):
    """Project physics parameters into the shared embedding space.

    A single affine layer (``nn.Linear``) mapping ``input_dim`` features to
    ``embedding_dim`` features; no activation is applied.
    """

    def __init__(self, input_dim: int, embedding_dim: int):
        super().__init__()
        # ``linear`` is the state_dict key prefix; keep the name stable.
        self.linear = nn.Linear(in_features=input_dim, out_features=embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the affine projection of ``x`` (last dim input_dim -> embedding_dim)."""
        projected = self.linear(x)
        return projected

class FinalNet(nn.Module):
    """Map a fused embedding to the network output.

    A single affine layer (``nn.Linear``) mapping ``embedding_dim`` features to
    ``output_dim`` features; no activation is applied.
    """

    def __init__(self, embedding_dim: int, output_dim: int):
        super().__init__()
        # ``linear`` is the state_dict key prefix; keep the name stable.
        self.linear = nn.Linear(in_features=embedding_dim, out_features=output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the affine projection of ``x`` (last dim embedding_dim -> output_dim)."""
        projected = self.linear(x)
        return projected

# Full network combining the sub-networks
class MyNetwork(nn.Module):
    """Fuse a position embedding and a physics embedding into a prediction.

    Both inputs are embedded to ``embedding_dim`` features, combined by an
    element-wise product, and mapped to ``output_dim`` features by ``FinalNet``.

    NOTE(review): with the sub-networks as defined in this notebook (each a
    single ``nn.Linear``, no activation), the output is affine in ``pos`` for
    fixed ``physics_params``. d(output)/d(pos) therefore does not depend on
    ``pos``, and every second derivative w.r.t. ``pos`` is zero. Add a
    non-linearity if higher-order input derivatives matter (e.g. a PINN
    residual).
    """

    def __init__(self, pos_dim: int, physics_dim: int, embedding_dim: int, output_dim: int):
        super().__init__()
        # Keep this construction order (pos -> physics -> final). It fixes how
        # the RNG is consumed for weight init, the order of model.parameters(),
        # and the state_dict key prefixes.
        self.pos_encoder = PositionEncoder(pos_dim, embedding_dim)
        self.physics_encoder = PhysicsEncoder(physics_dim, embedding_dim)
        self.final_net = FinalNet(embedding_dim, output_dim)

    def forward(self, pos: torch.Tensor, physics_params: torch.Tensor) -> torch.Tensor:
        """Return the prediction for ``pos`` under ``physics_params``.

        ``pos`` has last dim ``pos_dim`` and ``physics_params`` has last dim
        ``physics_dim``. Their leading dims must broadcast for the element-wise
        product. The result has last dim ``output_dim``.
        """
        # The left operand is evaluated first, so the position encoder still
        # runs before the physics encoder.
        fused = self.pos_encoder(pos) * self.physics_encoder(physics_params)
        return self.final_net(fused)

# --- Jacobian Calculation ---
# Goal: for each sample i in the batch, compute d(output_i)/d(pos_i) with the
# physics parameters held fixed, using torch.autograd.functional.jacobian.

# Instantiate your trained network
pos_dim = 2
physics_dim = 5
embedding_dim = 64
output_dim = 1
model = MyNetwork(pos_dim, physics_dim, embedding_dim, output_dim)
# Without the load below, the weights are freshly (randomly) initialised, so the
# derivatives printed at the end are illustrative only.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# model.load_state_dict(torch.load('your_model.pth')) # Load your trained weights
# The sub-networks contain only nn.Linear layers (no dropout/batch-norm), so
# eval() does not change the outputs here; it is kept for when such layers are added.
model.eval() # Set the model to evaluation mode

# Example input data
# Create a batch of 10 position vectors, shape (batch_size, pos_dim).
# NOTE(review): requires_grad=True is probably redundant, because jacobian()
# prepares its own differentiable copy of its inputs. Confirm before removing.
pos_input = torch.randn(10, pos_dim, requires_grad=True)
# Physics parameters, shape (batch_size, physics_dim); treated as constants.
physics_input = torch.randn(10, physics_dim)

# Define a function that takes only the variable we want to differentiate with
# respect to. physics_input is read from the enclosing (global) scope at call time.
def model_forward_for_jacobian(p: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on positions ``p`` with the fixed global ``physics_input``."""
    return model(p, physics_input)

# Compute the Jacobian of the output with respect to the position input.
# This is the *full* Jacobian, including the cross-sample blocks
# d(output_i)/d(pos_j) for i != j. Those blocks are zero here because every
# layer acts on each row independently (Linear layers plus an element-wise
# product). They still cost memory that grows as batch_size**2, so for large
# batches prefer a per-sample method.
jacobian_matrix = torch.autograd.functional.jacobian(model_forward_for_jacobian, pos_input)

# The resulting jacobian_matrix has shape (batch_size, output_dim, batch_size, input_dim).
# Keep only the i == j entries. torch.diagonal over dims 0 and 2 removes those
# dims and appends the diagonal as the last dim -> (output_dim, input_dim, batch_size).
# permute(2, 0, 1) then moves the batch axis to the front.
jacobian_per_sample = torch.diagonal(jacobian_matrix, dim1=0, dim2=2).permute(2, 0, 1)


print("Shape of the Jacobian matrix per sample:", jacobian_per_sample.shape)
# jacobian_per_sample has shape (batch_size, output_dim, input_dim).
# Assuming the two position columns are (r, z): [i, 0, 0] is d(output_i)/dr_i
# and [i, 0, 1] is d(output_i)/dz_i.

# To get the derivatives for the first sample in the batch:
# d(output)/dr and d(output)/dz, shape (output_dim, input_dim)
derivatives_first_sample = jacobian_per_sample[0]
print("Derivatives for the first sample (d(output)/dr, d(output)/dz):", derivatives_first_sample)

### Test: backpropagating through an input derivative (double backward) for a PINN-style loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# --- Same network definitions as in the previous cell, repeated so this cell runs on its own ---
class PositionEncoder(nn.Module):
    """Project position coordinates into the shared embedding space.

    A single affine layer (``nn.Linear``) mapping ``input_dim`` features to
    ``embedding_dim`` features; no activation is applied.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        # ``linear`` is the state_dict key prefix; keep the name stable.
        self.linear = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        """Return the affine projection of ``x`` (last dim input_dim -> embedding_dim)."""
        return self.linear(x)

class PhysicsEncoder(nn.Module):
    """Project physics parameters into the shared embedding space.

    A single affine layer (``nn.Linear``) mapping ``input_dim`` features to
    ``embedding_dim`` features; no activation is applied.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        # ``linear`` is the state_dict key prefix; keep the name stable.
        self.linear = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        """Return the affine projection of ``x`` (last dim input_dim -> embedding_dim)."""
        return self.linear(x)

class FinalNet(nn.Module):
    """Map a fused embedding to the network output.

    A single affine layer (``nn.Linear``) mapping ``embedding_dim`` features to
    ``output_dim`` features; no activation is applied.
    """

    def __init__(self, embedding_dim, output_dim):
        super().__init__()
        # ``linear`` is the state_dict key prefix; keep the name stable.
        self.linear = nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        """Return the affine projection of ``x`` (last dim embedding_dim -> output_dim)."""
        return self.linear(x)

class MyNetwork(nn.Module):
    """Fuse a position embedding and a physics embedding into a prediction.

    Positions and physics parameters are each embedded to ``embedding_dim``
    features. The two embeddings are multiplied element-wise and passed through
    ``FinalNet`` to produce ``output_dim`` features.

    NOTE(review): with the sub-networks as defined in this cell (each a single
    ``nn.Linear``, no activation), the output is affine in ``pos`` for fixed
    ``physics_params``. The derivative w.r.t. ``pos`` is therefore constant in
    ``pos``, and second derivatives are identically zero.
    """

    def __init__(self, pos_dim: int, physics_dim: int, embedding_dim: int, output_dim: int):
        super().__init__()
        # Keep this order: it fixes weight-init RNG consumption and parameter order.
        self.pos_encoder = PositionEncoder(pos_dim, embedding_dim)
        self.physics_encoder = PhysicsEncoder(physics_dim, embedding_dim)
        self.final_net = FinalNet(embedding_dim, output_dim)

    def forward(self, pos: torch.Tensor, physics_params: torch.Tensor) -> torch.Tensor:
        """Return the prediction for ``pos`` under ``physics_params``.

        ``pos`` has last dim ``pos_dim`` and ``physics_params`` has last dim
        ``physics_dim``. Their leading dims must broadcast for the element-wise
        product.
        """
        # Left operand first: the position encoder runs before the physics encoder.
        fused = self.pos_encoder(pos) * self.physics_encoder(physics_params)
        return self.final_net(fused)

# --- Setup for Training ---
# Demonstrates one optimisation step on a loss that mixes the function value
# and the network's derivative w.r.t. its position input (double backward).
model = MyNetwork(pos_dim=2, physics_dim=5, embedding_dim=64, output_dim=1)
model.train() # Set to training mode

optimizer = optim.Adam(model.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()

# The weight for the derivative loss term
lambda_deriv = 1.0

# --- SIMULATED TRAINING BATCH ---
# In your real code, this would come from your DataLoader.
# Batch of 2 samples: positions (2, pos_dim=2), physics parameters (2, physics_dim=5).
pos_input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
physics_input = torch.randn(2, 5)

# Your dataset must now provide these "true" target values.
# Function targets have shape (2, output_dim=1); derivative targets (2, pos_dim=2).
true_function_values = torch.tensor([[0.5], [-0.2]])
# Column order assumed to be (r, z) — must match the column order of pos_input.
true_derivatives = torch.tensor([[-1.5, 0.8], [0.1, -0.4]]) # True (∂f/∂r, ∂f/∂z)

# --- A SINGLE TRAINING STEP ---

# Step 1: Tell PyTorch to track gradients for the input position.
# This must happen before the forward pass. It is allowed here because
# pos_input is a leaf tensor created directly by torch.tensor above.
pos_input.requires_grad = True

# Step 2: Clear gradients accumulated on the model parameters managed by the optimizer.
# NOTE: this does not clear pos_input.grad, which total_loss.backward() below
# also populates, because pos_input is a leaf that requires grad.
optimizer.zero_grad()

# Step 3: Forward pass to get the predicted function value, shape (2, 1)
predicted_output = model(pos_input, physics_input)

# Step 4: Calculate the loss on the function value
loss_f = mse_loss(predicted_output, true_function_values)

# Step 5: Calculate the network's derivative w.r.t. the position input.
# We use torch.autograd.grad, which is a more fundamental way to do this
# and shows the need for create_graph=True clearly.
# With grad_outputs=ones, this returns the gradient of sum(predicted_output)
# w.r.t. pos_input. Each output row depends only on its own pos row (Linear
# layers plus an element-wise product), so row i is exactly d(output_i)/d(pos_i).
# NOTE(review): this per-sample reading holds only for output_dim == 1. With
# more outputs, the derivatives would be summed over output components.
# create_graph=True records this derivative computation in the autograd graph,
# so loss_deriv can be backpropagated into the model weights in Step 8.
predicted_derivatives = torch.autograd.grad(
    outputs=predicted_output,
    inputs=pos_input,
    grad_outputs=torch.ones_like(predicted_output),
    create_graph=True # <-- THIS IS THE CRITICAL PART!
)[0]

# Step 6: Calculate the loss on the derivatives (shapes (2, 2) vs (2, 2))
loss_deriv = mse_loss(predicted_derivatives, true_derivatives)

# Step 7: Combine the losses
total_loss = loss_f + lambda_deriv * loss_deriv

# Step 8: Backpropagate the total loss to compute gradients for the network's weights
total_loss.backward()

# Step 9: Update the network weights
optimizer.step()

# --- End of Training Step ---

# The printed losses are the values computed before optimizer.step() updated the weights.
print(f"Loss on function value: {loss_f.item():.4f}")
print(f"Loss on derivative value: {loss_deriv.item():.4f}")
print(f"Total Combined Loss: {total_loss.item():.4f}")