In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


In [2]:
class DEMNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4):
        super(DEMNet, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.Tanh()]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Tanh()]
        layers += [nn.Linear(hidden_dim, output_dim)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [3]:
def energy_functional(u_pred, x):
    # Assuming a 1D example of linear elasticity, we can define a simple strain energy
    # as an example; this function must be adapted to your specific application.
    u_x = torch.autograd.grad(u_pred, x, grad_outputs=torch.ones_like(u_pred), create_graph=True)[0]
    energy = torch.sum(0.5 * (u_x ** 2))  # Strain energy example (for linear elasticity)
    return energy



In [4]:
def train(model, optimizer, x, epochs=1000):
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        u_pred = model(x)
        energy = energy_functional(u_pred, x)
        energy.backward()
        optimizer.step()
        
        if epoch % 100 == 0:
            print(f'Epoch {epoch}: Energy = {energy.item()}')

In [5]:
x = torch.linspace(0, 1, 100, requires_grad=True).view(-1, 1)  # 1D spatial domain

# Initialize the model, optimizer, and start training
input_dim = 1
hidden_dim = 20
output_dim = 1
model = DEMNet(input_dim, hidden_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=0.001)

train(model, optimizer, x)

Epoch 0: Energy = 0.06570553779602051
Epoch 100: Energy = 4.560104571282864e-06
Epoch 200: Energy = 1.842128881435201e-06
Epoch 300: Energy = 7.893568181316368e-07
Epoch 400: Energy = 3.622546103088098e-07
Epoch 500: Energy = 2.2807618904607807e-07
Epoch 600: Energy = 1.9321160493745992e-07
Epoch 700: Energy = 1.8379013511093945e-07
Epoch 800: Energy = 1.7925103179550206e-07
Epoch 900: Energy = 1.7527065665490227e-07
