<a href="https://colab.research.google.com/github/cedamusk/Astrophysics/blob/main/PINNs_Implementation_for_planetary_motion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
!pip install torch numpy matplotlib seaborn



In [38]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import seaborn as sns

In [39]:
torch.manual_seed(42)
np.random.seed(42)

In [40]:
class ImprovedPINN(nn.Module):
  def __init__(self):
    super().__init__()
    #Wider network with better initialization
    self.network=nn.Sequential(
        nn.Linear(1, 128),
        nn.Tanh(),
        nn.Linear(128, 128),
        nn.Tanh(),
        nn.Linear(128, 128),
        nn.Tanh(),
        nn.Linear(128, 64),
        nn.Tanh(),
        nn.Linear(64, 2)
    )

    #Xavier initialization
    for layer in self.network:
      if isinstance(layer, nn.Linear):
        nn.init.xavier_normal_(layer.weight)
        nn.init.zeros_(layer.bias)

  def forward(self, t):
    return self.network(t)

  def compute_derivatives(self, t):
    t.requires_grad_(True)
    xy=self.forward(t)

    #Compute first derivaticves for both x and y components
    dxy_dt=torch.autograd.grad(
        xy, t,
        grad_outputs=torch.ones_like(xy),
        create_graph=True,
        allow_unused=True
    )[0]


    #Reshape derivatives to match the expected dimensions
    dxy_dt=dxy_dt.reshape(-1, 2)

    d2xy_dt2_x=torch.autograd.grad(
        dxy_dt[:, 0], t,
        grad_outputs=torch.ones_like(dxy_dt[:, 0]),
        create_graph=True,
        allow_unused=True
    )[0]

    d2xy_dt2_y=torch.autograd.grad(
        dxy_dt[:, 1], t,
        grad_outputs=torch.ones_like(dxy_dt[:, 1]),
        create_graph=True,
        allow_unused=True
    )

    #Combine x and y components of second derivatives
    #Check if d2xy_dt2_x or d2xy_dt2_y is a tuple (None, Tensor)
    #and extract the Tensor elements if necessary

  # FIX: Access the gradient value using index 1 if it's a tuple
  #and has the necessary gradient at index 1, otherwise use the original values
    d2xy_dt2=d2xy_dt2_x[1] if isinstance(d2xy_dt2_x, tuple) else d2xy_dt2_x
    d2xy_dt2_y=d2xy_dt2_y[1] if isinstance(d2xy_dt2_y, tuple) else d2xy_dt2_y

    d2xy_dt2=torch.stack([d2xy_dt2_x, d2xy_dt2_y], dim=1)

    return xy, dxy_dt, d2xy_dt2

In [41]:
def generate_orbital_data(n_points=1000, noise_level=0.005): #Reduced noise
  """Generate cleaner orbital data"""
  t=np.linspace(0, 10, n_points)

  #Orbital parameters (more stable orbit)
  r=1.0
  omega=2*np.pi/5

  #True solution
  x=r*np.cos(omega*t)
  y=r*np.sin(omega*t)

  #Add reduced noise
  x+=noise_level*np.random.randn(n_points)
  y+=noise_level*np.random.randn(n_points)

  return t, x, y

In [42]:
def physics_loss(model, t, normalize=True):
  """Improved physics loss with normalization"""
  xy, dxy_dt, d2xy_dt2=model.compute_derivatives(t)

  #Gravitational parameter
  k=4*np.pi**2

  #Position vectors
  r=torch.sqrt(xy[:, 0]**2 +xy[:, 1]**2)

  #physics residuals
  residual_x=d2xy_dt2[:, 0]+k*xy[:, 0]/ (r**3)
  residual_y=d2xy_dt2[:, 1]+k*xy[:, 1]/ (r**3)

  if normalize:
    #Normalize residuals by the magnitude of terms
    scale=torch.mean(torch.abs(k*xy/ (r**3).unsqueeze(1)))
    residual_x=residual_x/ scale
    residual_y=residual_y /scale

  return torch.mean(residual_x**2+residual_y**2)



In [43]:
def train_pinn(model, t_data, xy_data, n_epochs=10000):
   #More epochs
   """Improved training process"""
   optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
   scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=100, factor=0.5)

   t_torch=torch.FloatTensor(t_data).reshape(-1, 1)
   xy_torch=torch.FloatTensor(xy_data)

   #Normalize data
   xy_scale=torch.max(torch.abs(xy_torch))
   xy_torch=xy_torch/ xy_scale

   best_loss=float('inf')
   best_state=None
   losses=[]

   for epoch in range(n_epochs):
    optimizer.zero_grad()

    xy_pred=model(t_torch)
    data_loss=torch.mean((xy_pred-xy_torch)**2)
    phys_loss=physics_loss(model, t_torch)

    #Dynamic weighting of losses
    physics_weight=0.01*(1-np.exp(-epoch/1000)) #Gradually increase physics weight
    total_loss=data_loss+physics_weight* phys_loss

    total_loss.backward()
    optimizer.step()
    scheduler.step(total_loss)

    losses.append([total_loss.item(), data_loss.item(), phys_loss.item()])

    #save best model
    if total_loss.item() < best_loss:
      best_loss=total_loss.item()
      best_state=model.state_dict().copy()


    if (epoch+1)% 500==0:
      print(f'Epoch [{epoch+1}/{n_epochs}],'
            f'Loss: {total_loss.item():.4f},'
            f'Data Loss: {data_loss.item():.4f},'
            f'Physics Loss: {phys_loss.item():.4f}')


  #Restore best model
    model.load_state_dict(best_state)
    return np.array(losses)



In [44]:
def main():
  #Generate data
  t_data, x_data, y_data=generate_orbital_data()
  xy_data=np.stack([x_data, y_data], axis=1)

  #Create and train the model
  model=ImprovedPINN()
  losses=train_pinn(model, t_data, xy_data)

  plot_results(model, t_data, xy_data, losses)

def plot_results(model, t_data, xy_data, losses):

  #plot loss
  plt.figure(figsize=(10, 4))
  plt.semilogy(losses[:, 0], label='Total Loss')
  plt.semilogy(losses[:, 1], label='Data Loss')
  plt.semilogy(losses[:, 2], label='Physics Loss')
  plt.xlabel('Epoch')
  plt.ylabel("Loss (log Scale)")
  plt.legend()
  plt.grid(True)
  plt.show()

  #Evaluate predictions
  model.eval()
  with torch.no_grad():
    t_torch=torch.FloatTensor(t_data).reshape(-1,1)
    xy_pred=model(t_torch).numpy()

  #plot results
  plt.figure(figsize=(12, 5))
  plt.subplot(121)
  plt.plot(xy_data[:, 0], xy_data[:, 1], 'b.', label="Data")
  plt.plot(xy_pred[:, 0], xy_pred[:, 1], 'r-', label='PINN')
  plt.plot(0, 0, 'y*', markersize=15)
  plt.xlabel('X Position')
  plt.ylabel('Y Position')
  plt.legend()
  plt.axis('equal')
  plt.grid(True)

  plt.subplot(122)
  plt.plot(t_data, xy_data[:, 0], 'b.', label='Data X')
  plt.plot(t_data, xy_pred[:, 0], 'r-', label='PINN X')
  plt.xlabel('Time')
  plt.ylabel('Position')
  plt.legend()
  plt.grid(True)
  plt.show()

In [45]:
if __name__=="__main__":
  main()

IndexError: tuple index out of range