In [1]:
# Physics modules
from metric import metric
from hybrid_eos import hybrid_eos

# Numpy and matplotlib
import numpy as np 
import matplotlib.pyplot as plt 

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init 
import torch.nn.functional as F

from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import matplotlib.pyplot as plt

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class c2p_NN(nn.Module):
    """
    This class defines a neural network model for the conservative-to-primitive variable transformation.
    The network consists of several fully connected layers with ReLU activation functions and batch normalization.
    The input to the network is a tensor of conservative variables, and the output is a tensor of primitive variables.
    """
    def __init__(self, hidden_layers=3, neurons=50):
        super(c2p_NN, self).__init__()
        self.activation = nn.Tanh 
        input_dim = 3 
        output_dim = 1 
        layers = []
        layers.append(nn.BatchNorm1d(input_dim))
        layers.append(nn.Linear(input_dim, neurons))  # Input layer
        layers.append(self.activation())  # Activation
        for _ in range(hidden_layers):
            layers.append(nn.Linear(neurons, neurons))  # Hidden layers
            layers.append(nn.BatchNorm1d(neurons)) # Normalization
            layers.append(self.activation())  # Activation
            
        layers.append(nn.Linear(neurons, output_dim))  # Output layer
        # Last relu because z >= 0 
        layers.append(nn.ReLU()                     )  # Activation
        self.network = nn.Sequential(*layers)
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize weights using Kaiming initialization """
        if isinstance(module, nn.Linear):
            #init.kaiming_normal_(module.weight, nonlinearity="relu")
            init.xavier_normal(module.weight)
            init.zeros_(module.bias)
    
    def forward(self, x):
        return self.network(x)


In [8]:
# Minkowski metric 
eta = metric(
torch.eye(3,device=device), torch.zeros(3,device=device), torch.ones(1,device=device)
)
# Gamma = 2 EOS with ideal gas thermal contrib 
eos = hybrid_eos(100,2,1.8)

In [9]:
def setup_initial_state_random(metric,eos,N,device,lrhomin=-12,lrhomax=-2.8,ltempmin=-1,ltempmax=2.3,Wmin=1,Wmax=2):
    # Get W, rho and T 
    W = Wmin + (Wmax-Wmin) * torch.rand(N,device=device)
    rho = 10**( lrhomin + (lrhomax-lrhomin) * torch.rand(N,device=device) )
    T = 10**( ltempmin + (ltempmax-ltempmin) * torch.rand(N,device=device) )
    # Call EOS to get press and eps 
    press,eps = eos.press_eps__temp_rho(T,rho)
    # Compute z 
    Z = torch.sqrt(1 - 1/W**2) * W 
    
    # Compute conserved vars 
    sqrtg = metric.sqrtg 
    u0 = W / sqrtg 
    dens = sqrtg * W * rho 
    
    rho0_h = rho * ( 1 + eps ) + press 
    g4uptt = - 1/metric.alp**2 
    Tuptt = rho0_h * u0**2 + press * g4uptt 
    tau = metric.alp**2 * sqrtg * Tuptt - dens 
    
    S = torch.sqrt((W**2-1)) * rho0_h * W
    # Assemble output 
    C = torch.cat((dens.view(-1,1)/metric.sqrtg,tau.view(-1,1)/dens.view(-1,1),S.view(-1,1)/dens.view(-1,1)),dim=1)
    return C, Z.view(-1,1)

def setup_initial_state_meshgrid(metric,eos,N,device,lrhomin=-12,lrhomax=-2.8,ltempmin=-1,ltempmax=2.3,Wmin=1,Wmax=2):
    # Get W, rho and T 
    W = torch.linspace(Wmin,Wmax,N,device=device)
    rho = 10**( torch.linspace(lrhomin,lrhomax,N,device=device) )
    T = 10**( torch.linspace(ltempmin,ltempmax,N,device=device) )
    W, rho, T = torch.meshgrid(W,rho,T, indexing='ij')
    
    W = W.flatten() 
    rho = rho.flatten()
    T = T.flatten() 
    
    # Call EOS to get press and eps 
    press,eps = eos.press_eps__temp_rho(T,rho)
    # Compute z 
    Z = torch.sqrt(1 - 1/W**2) * W 
    
    # Compute conserved vars 
    sqrtg = metric.sqrtg 
    u0 = W / sqrtg 
    dens = sqrtg * W * rho 
    
    rho0_h = rho * ( 1 + eps ) + press 
    g4uptt = - 1/metric.alp**2 
    Tuptt = rho0_h * u0**2 + press * g4uptt 
    tau = metric.alp**2 * sqrtg * Tuptt - dens 
    
    S = torch.sqrt((W**2-1)) * rho0_h * W
    # Assemble output 
    C = torch.cat((dens.view(-1,1)/metric.sqrtg,tau.view(-1,1)/dens.view(-1,1),S.view(-1,1)/dens.view(-1,1)),dim=1)
    return C, Z.view(-1,1)
    
def sanity_check(Z,C, metric, eos):
    t,q,r = torch.split(C,[1,1,1], dim=1)
    htilde = h__z(Z,C,eos)
    
    return torch.mean((Z - r/htilde)**2)

def W__z(z):
    return torch.sqrt(1 + z**2)

def rho__z(z,C):
    return C[:,0].view(-1,1) / W__z(z)

def eps__z(z,C):
    q = C[:,1].view(-1,1)
    r = C[:,2].view(-1,1)
    W = W__z(z)
    return W * q - z * r + z**2/(1+W)

def a__z(z,C,eos):
    eps = eps__z(z,C)
    rho = rho__z(z,C)
    press = eos.press__eps_rho(eps,rho)
    return press/(rho*(1+eps))

def h__z(z,C,eos):
    eps = eps__z(z,C)
    a = a__z(z,C,eos)
    return (1 + eps)*(1+a)

In [11]:
C, Z = setup_initial_state_meshgrid(eta,eos,20,device)
err = sanity_check(Z,C,eta,eos)

print(err)

tensor(1.8528e-13, device='cuda:0')


In [12]:
def compute_loss(model, C, eos, metric):
    '''
    Eq (C3) of https://arxiv.org/pdf/1306.4953.pdf
    '''
    Z_pred = model(C)
    htilde = h__z(Z_pred,C,eos)
    return F.mse_loss(Z_pred,C[:,2].view(-1,1)/htilde)

In [13]:
def train_c2p_model(model,optimizer,scheduler,num_epochs,C,eos,metric):
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        
        loss = compute_loss(model,C,eos,metric)
        loss.backward()
        optimizer.step()
        
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss)
        else:
            scheduler.step()
        
        # Print progress
        if epoch % 1000 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}, LR: {optimizer.param_groups[0]['lr']:.6e}")

In [14]:
neurons = 50 
layers  = 4 

net = c2p_NN(hidden_layers=layers,neurons=neurons).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=500)

train_c2p_model(net,optimizer,scheduler,10000,C,eos,eta)

# Do one more step with LBFGS
optimizer = torch.optim.LBFGS(net.parameters(), lr=1.0, max_iter=500)

def closure():
    optimizer.zero_grad()
    loss = compute_loss(net, C, eos, eta)
    loss.backward()
    return loss

optimizer.step(closure)

loss_value = closure()  # Compute the loss
print(f"Final loss: {loss_value.item()}")  

torch.save(net.state_dict(), f"model_L{layers}N{neurons}.pt")

  return F.linear(input, self.weight, self.bias)


Epoch 0, Loss: 0.168570, LR: 1.000000e-03
Epoch 1000, Loss: 0.002842, LR: 1.000000e-03
Epoch 2000, Loss: 0.002714, LR: 1.000000e-03
Epoch 3000, Loss: 0.002533, LR: 1.000000e-03
Epoch 4000, Loss: 0.002506, LR: 1.000000e-03
Epoch 5000, Loss: 0.002234, LR: 1.000000e-03
Epoch 6000, Loss: 0.002238, LR: 1.000000e-03
Epoch 7000, Loss: 0.002077, LR: 1.000000e-03
Epoch 8000, Loss: 0.001993, LR: 1.000000e-03
Epoch 9000, Loss: 0.001962, LR: 5.000000e-04
Final loss: 0.0003262293175794184


In [None]:
C_test, Z_test = setup_initial_state_random(eta,eos,1000)
Z_pred = net(C_test)
print("Sanity check: ",sanity_check(Z_test,C_test,eta,eos))
print("Network error: ", F.mse_loss(Z_pred,Z_test) )

Sanity check:  tensor(2.1997e-13)
Network error:  tensor(0.0071, grad_fn=<MseLossBackward0>)


In [None]:
x = torch.linspace(1,2,100)
y = torch.linspace(1,2,100)
z = torch.linspace(1,2,100)
X,Y,Z = torch.meshgrid(x,y,z)