# In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

# Example of a small fully connected network with a 2-2-2 layout (2 inputs, 2 hidden units, 2 outputs)
class SimpleModel(nn.Module):
  """Two-layer 2 -> 2 -> 2 network with a sigmoid after each linear layer.

  The default random initialization of both ``nn.Linear`` layers is
  overwritten with fixed, hand-picked weights and biases so the example
  produces reproducible numbers.
  """

  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(2, 2, bias=True)
    self.sigmoid = nn.Sigmoid()
    self.linear2 = nn.Linear(2, 2, bias=True)

    # (layer, weight rows, bias) triples applied in place below.
    preset = (
        (self.linear1, [[0.2, 0.4], [-0.3, -0.5]], [0.1, -0.2]),
        (self.linear2, [[0.1, -0.6], [0.3, -0.2]], [0.05, -0.1]),
    )
    # no_grad: these are initial values, not operations to record in autograd.
    with torch.no_grad():
      for layer, w_rows, b_vals in preset:
        layer.weight.copy_(torch.tensor(w_rows))
        layer.bias.copy_(torch.tensor(b_vals))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply linear1 -> sigmoid -> linear2 -> sigmoid to ``x``."""
    hidden = self.sigmoid(self.linear1(x))
    return self.sigmoid(self.linear2(hidden))

# Regularization term based on gradient uniformity: the variance of each parameter's gradient along its last dimension
def gradient_uniformity_penalty(model: nn.Module,
                                loss: "torch.Tensor | None" = None) -> torch.Tensor:
  """Return a penalty measuring how non-uniform the model's gradients are.

  For every trainable parameter, the gradient's variance is taken along its
  last dimension (per output row for a ``Linear`` weight, over all entries of
  a 1-D bias) and averaged. The per-parameter values are then summed. With
  the default unbiased estimator, a last dimension of size 1 yields NaN.

  Args:
    model: Module whose parameters' gradients are penalized.
    loss: Optional scalar loss. When given, gradients are computed with
      ``torch.autograd.grad(..., create_graph=True)``. The returned penalty is
      then differentiable w.r.t. the parameters, and ``param.grad`` is not
      touched. When omitted, the existing ``param.grad`` tensors are used.
      Those are only differentiable if they came from
      ``backward(create_graph=True)``. After a plain ``backward()`` the
      penalty is a constant and contributes no gradient to training.

  Returns:
    A 0-dim tensor. Parameters without a gradient are skipped. A zero
    tensor is returned when nothing contributes.
  """
  params = [p for p in model.parameters() if p.requires_grad]
  if not params:
    return torch.zeros(())

  if loss is not None:
    # create_graph=True records the gradient computation so the penalty can be
    # backpropagated (second-order term). It also retains loss's graph, so
    # loss can still be used in a later backward() call.
    grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
  else:
    grads = [p.grad for p in params]

  penalty = torch.zeros(())
  for grad in grads:
    if grad is None:  # parameter unused by loss / not reached by backward()
      continue
    penalty = penalty + grad.var(dim=-1).mean()
  return penalty

# Instantiate the model
model = SimpleModel()

# Example input and target
input_tensor = torch.tensor([[0.5, 0.1]])
target = torch.tensor([[0.7, 0.3]])  # desired model output for input_tensor (same shape as the output)

# Example loss (e.g., Mean Squared Error)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Forward pass
output = model(input_tensor)

# Compute the original loss
loss: torch.Tensor = criterion(output, target)

# Backward pass with create_graph=True: the resulting param.grad tensors stay
# attached to the autograd graph, so a penalty built from them can itself be
# differentiated. After a plain backward() the grads would be constants and the
# penalty below would contribute nothing to the parameter update.
# (create_graph=True also retains loss's graph for the second backward pass.)
loss.backward(create_graph=True)

# Compute the gradient penalty based on gradient variance
grad_penalty = gradient_uniformity_penalty(model)

# Drop the first-order grads before the real backward pass. Otherwise
# total_loss.backward() would accumulate d(loss)/d(param) on top of them and
# count the loss gradient twice. set_to_none=True (rather than zeroing in
# place) leaves the grad tensors that grad_penalty depends on intact. It also
# breaks the param -> grad -> graph reference cycle created by create_graph=True.
optimizer.zero_grad(set_to_none=True)

# Add the gradient penalty to the original loss
lambda_penalty = 0.1
total_loss: torch.Tensor = loss + lambda_penalty * grad_penalty

# Backpropagation: param.grad now holds d(loss + lambda_penalty * grad_penalty)/d(param)
total_loss.backward()
optimizer.step()

# Clear the gradients
optimizer.zero_grad()
