In [346]:
import numpy as np
import matplotlib.pyplot as plt
import math
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch.autograd import gradcheck
torch.manual_seed(42)
import random
random.seed(0)
np.random.seed(0)

In [347]:
class HarmonicOscillatorDataset(Dataset):
    def __init__(self, file_path):
        # Load the dataset
        dataset = np.load(file_path)[:, ::1000, :]

        # Convert the dataset to PyTorch tensors
        self.p_values = torch.tensor(dataset[:, :, 0], dtype=torch.float32)
        self.q_values = torch.tensor(dataset[:, :, 1], dtype=torch.float32)
        self.h_values = torch.tensor(dataset[:, :, 2], dtype=torch.float32)

        # Ensure consistent length for all tensors
        assert len(self.p_values) == len(self.q_values) == len(self.h_values)
        self.length = len(self.p_values)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.p_values[idx], self.q_values[idx], self.h_values[idx]

In [348]:
class MLP_General_Hamilt(nn.Module):
    def __init__(self, n_input, n_hidden):
        super(MLP_General_Hamilt, self).__init__()
        self.linear1 = nn.Linear(2*n_input, n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_hidden)
        self.linear3 = nn.Linear(n_hidden, n_hidden)
        self.linear4 = nn.Linear(n_hidden, n_hidden)
        self.linear5 = nn.Linear(n_hidden, n_hidden)
        self.linear6 = nn.Linear(n_hidden, 1)

    def forward(self, p, q):
        pq = torch.cat((p, q), 1)
        h = self.linear1(pq)
        h = h.sigmoid_()
        h = self.linear2(h)
        h = h.sigmoid_()
        h = self.linear3(h)
        h = h.sigmoid_()
        h = self.linear4(h)
        h = h.sigmoid_()
        h = self.linear5(h)
        h = h.sigmoid_()
        h = self.linear6(h)
        return h

class MLP2H_Separable_Hamilt(nn.Module):
    def __init__(self, n_hidden, input_size):
        super(MLP2H_Separable_Hamilt, self).__init__()
        self.linear_K1 = nn.Linear(input_size, n_hidden)
        self.linear_K1B = nn.Linear(n_hidden, n_hidden)
        self.linear_K2 = nn.Linear(n_hidden, 1)
        self.linear_P1 = nn.Linear(input_size, n_hidden)
        self.linear_P1B = nn.Linear(n_hidden, n_hidden)
        self.linear_P2 = nn.Linear(n_hidden, 1)

    def kinetic_energy(self, p):
        h_pre = self.linear_K1(p)
        h = h_pre.tanh_()
        # h = h_pre.sigmoid()
        h_pre_B = self.linear_K1B(h)
        h_B = h_pre_B.tanh_()
        # h_B = h_pre_B.sigmoid()
        return self.linear_K2(h_B)

    def potential_energy(self, q):
        h_pre = self.linear_P1(q)
        h = h_pre.tanh_()
        h_pre_B = self.linear_P1B(h)
        h_B = h_pre_B.tanh_()
        return self.linear_P2(h_B)

    def forward(self, p, q):
        return self.kinetic_energy(p) + self.potential_energy(q)

In [360]:
class BackProp_HNN(nn.Module):
    def __init__(self, f, T, dt, dim, integrator, iter):
        super(BackProp_HNN, self).__init__()
        self.f = f
        self.dt = dt
        self.dim = dim
        self.integrator = integrator
        self.T = T
        self.num_steps = int(1/dt)
        self.iter = iter
        
    def forward(self, p0, q0):
        trajectories = torch.zeros((p0.shape[0], self.T, self.dim * 2)).to(device)
        trajectories[:, 0, 0] = p0
        trajectories[:, 0, 1] = q0
        p = p0
        q = q0
        for timestep in range(1, T):
            for _ in range(self.num_steps):
                if self.integrator == "euler":
                    x = self.euler_step(torch.stack([p, q], dim = 1))
                elif self.integrator == "rk2":
                    x = self.rk2_step(torch.stack([p, q], dim = 1))
                elif self.integrator == "sv":
                    x = self.sv_step(x = torch.stack([p, q], dim = 1), iterations = iter)
                elif self.integrator == "pc":
                    x = self.pc_step(x = torch.stack([p, q], dim = 1), iterations = iter)
                # print(x.shape)
                p = x[:, 0]
                q = x[:, 1]
            trajectories[:, timestep, 0] = p
            trajectories[:, timestep, 1] = q
        return trajectories
    
    def euler_step(self, x):
        return x + self.dt * self.dynamics_fn(x)
    
    def leapfrog_step(self, x):
        p_ = x[:, 0] + 0.5 * self.dt * self.dynamics_fn(x)[:, 0]
        q = x[:, 1] + self.dt * self.dynamics_fn(torch.stack([p_, x[:, 1]], dim = 1))[:, 1]
        p = p_ + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_, q], dim = 1))[:, 0]
        return torch.stack([p, q], dim = 1)

    def rk2_step(self, x):
        p0 = x[:,0]
        q0 = x[:,1]
        out = self.dynamics_fn(x)
        p1 = out[:, 0]
        q1 = out[:, 1]

        out = self.dynamics_fn(torch.stack([p0 + 0.5 * self.dt * p1, q0 + 0.5 * self.dt * q1], dim = 1))
        p2 = out[:, 0]
        q2 = out[:, 1]

        p = p0 + self.dt * (p1 + p2)/2
        q = q0 + self.dt * (q1 + q2)/2
        return torch.stack([p, q], dim = 1)


    def sv_step(self, x, x_init = None, iterations = 1):
        p0 = x[:, 0]
        q0 = x[:, 1]
        if x_init == None:
            p_half = p0
        else:
            p_half = (x_init[:, 0] + p0)/2
        for _ in range(iterations):
            p_half = p0 + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q0], dim = 1))[:, 0]
        q1 = q0 + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q0], dim = 1))[:, 1]
        if x_init == None:
            q2 = q0
        else:
            q2 = (x_init[:, 1] + q0)/2
        for _ in range(iterations):
            q2 = q1 + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q2], dim = 1))[:, 1]
        p1 = p_half + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q2], dim = 1))[:, 0]
        return torch.stack([p1, q2], dim = 1)

    def pc_step(self, x, iterations = 1):
        out = self.rk2_step(x)
        out = self.sv_step(x, out, iterations)
        return out
            
    def dynamics_fn(self, x):
        p = x[:, 0]
        q = x[:, 1]
        p.requires_grad_(True)
        q.requires_grad_(True)
        # h = self.f(p, q)
        h = self.f(p.unsqueeze(-1), q.unsqueeze(-1))
        # print("dynamics_fn", h.shape)
        grad_p, = grad(h.sum(), p, create_graph=True, allow_unused=True)
        grad_q, = grad(h.sum(), q, create_graph=True, allow_unused=True)
        return torch.stack([-grad_q, grad_p], dim = 1)

In [None]:
class BackProp_HNN(nn.Module):
    def __init__(self, f, T, dt, dim, integrator, iter):
        super(BackProp_HNN, self).__init__()
        self.f = f
        self.dt = dt
        self.dim = dim
        self.integrator = integrator
        self.T = T
        self.num_steps = int(1/dt)
        self.iter = iter

    def forward(self, p0, q0):
        trajectories = torch.zeros((p0.shape[0], self.T, self.dim * 2)).to(device)
        trajectories[:, 0, 0] = p0
        trajectories[:, 0, 1] = q0
        p = p0
        q = q0
        for timestep in range(1, T):
            for _ in range(self.num_steps):
                if self.integrator == "euler":
                    x = self.euler_step(torch.stack([p, q], dim = 1))
                elif self.integrator == "rk2":
                    x = self.rk2_step(torch.stack([p, q], dim = 1))
                elif self.integrator == "sv":
                    x = self.sv_step(torch.stack([p, q], dim = 1), x_init = None, iterations = iter)
                elif self.integrator == "pc":
                    x = self.rk2_step(torch.stack([p, q], dim = 1))
                    x = self.sv_step(torch.stack([p, q], dim = 1), x_init = x, iterations = iter)
                    # x = self.pc_step(torch.stack([p, q], dim = 1), iterations = iter)
                # print(x.shape)
                p = x[:, 0]
                q = x[:, 1]
            trajectories[:, timestep, 0] = p
            trajectories[:, timestep, 1] = q
        return trajectories

    def euler_step(self, x):
        return x + self.dt * self.dynamics_fn(x)

    def rk2_step(self, x):
        p0 = x[:,0]
        q0 = x[:,1]
        out = self.dynamics_fn(x)
        p1 = out[:, 0]
        q1 = out[:, 1]

        out = self.dynamics_fn(torch.stack([p0 + 0.5 * self.dt * p1, q0 + 0.5 * self.dt * q1], dim = 1))
        p2 = out[:, 0]
        q2 = out[:, 1]

        p = p0 + self.dt * (p1 + p2)/2
        q = q0 + self.dt * (q1 + q2)/2
        return torch.stack([p, q], dim = 1)


    def sv_step(self, x, x_init = None, iterations = 1):
        p0 = x[:, 0]
        q0 = x[:, 1]
        if x_init == None:
            p_half = p0
        else:
            p_half = (x_init[:, 0] + p0)/2
        for _ in range(iterations):
            p_half = p0 + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q0], dim = 1))[:, 0]
        q1 = q0 + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q0], dim = 1))[:, 1]
        if x_init == None:
            q2 = q1
        else:
            q2 = (x_init[:, 1] + q1)/2
        for _ in range(iterations):
            q2 = q1 + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q2], dim = 1))[:, 1]
        p1 = p_half + 0.5 * self.dt * self.dynamics_fn(torch.stack([p_half, q2], dim = 1))[:, 0]
        return torch.stack([p1, q2], dim = 1)

    # def pc_step(self, x, iterations = 1):
        # out = self.rk2_step(x)
        # out = self.sv_step(x, out, iterations)
        # return self.rk2_step(x)

    def dynamics_fn(self, x):
        p = x[:, 0]
        q = x[:, 1]
        p.requires_grad_(True)
        q.requires_grad_(True)
        # h = self.f(p, q)
        h = self.f(p.unsqueeze(-1), q.unsqueeze(-1))
        # print("dynamics_fn", h.shape)
        grad_p, = grad(h.sum(), p, create_graph=True, allow_unused=True)
        grad_q, = grad(h.sum(), q, create_graph=True, allow_unused=True)
        return torch.stack([-grad_q, grad_p], dim = 1)

In [None]:
for integrator in ["euler", "rk2", "sv", "pc"]:

    f1 = MLP_General_Hamilt(n_input = 1, n_hidden = 64)
    # f1 = MLP2H_Separable_Hamilt(n_hidden=256, input_size=1).to(device).double()
    T = 2
    dt = 0.1
    dim = 1
    num_epochs = 100
    iter = 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the model, criterion, optimizer, and data loaders
    model = BackProp_HNN(f1, T, dt, dim, integrator, iter).to(device).double()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-1)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.8)
    train_mass_spring = HarmonicOscillatorDataset("./data/mass_spring/train.npy")
    data_loader = DataLoader(train_mass_spring, batch_size=32, shuffle=False)
    val_mass_spring = HarmonicOscillatorDataset("./data/mass_spring/val.npy")
    val_loader = DataLoader(val_mass_spring, batch_size=32, shuffle=False)

    train_loss = []
    val_loss = []

    best_val_loss = float('inf')
    best_model_path = integrator+str(dt)+'_best_model.pt'

    with open(integrator+"_loss_"+str(dt)+".txt", "w") as loss_file:
        for epoch in tqdm(range(num_epochs)):
            model.train()
            loss_epoch = []
            for batch in data_loader:
                optimizer.zero_grad()
                p_batch, q_batch, _ = batch
                p0_batch = p_batch[:, 0].to(device).double()
                q0_batch = q_batch[:, 1].to(device).double()
                simulated_trajectory = model(p0_batch, q0_batch)
                simulated_trajectory = simulated_trajectory.view(simulated_trajectory.size(0), -1)
                trajectory = torch.stack([p_batch, q_batch], axis=2).to(device)
                trajectory = trajectory.view(trajectory.size(0), -1).to(device)
                loss = criterion(trajectory, simulated_trajectory)
                loss.backward()
                optimizer.step()
                scheduler.step()
                loss_epoch.append(loss.item())
            avg_train_loss = sum(loss_epoch)/len(loss_epoch)
            train_loss.append(avg_train_loss)

            model.eval()  # Set model to evaluation mode
            val_loss_epoch = []
            for batch in val_loader:
                optimizer.zero_grad()
                p_batch, q_batch, _ = batch
                p0_batch = p_batch[:, 0].to(device).double()
                q0_batch = q_batch[:, 1].to(device).double()
                simulated_trajectory = model(p0_batch, q0_batch)
                simulated_trajectory = simulated_trajectory.view(simulated_trajectory.size(0), -1)
                trajectory = torch.stack([p_batch, q_batch], axis=2).to(device)
                trajectory = trajectory.view(trajectory.size(0), -1).to(device)
                loss = criterion(trajectory, simulated_trajectory)
                val_loss_epoch.append(loss.item())
            avg_val_loss = sum(val_loss_epoch)/len(val_loss_epoch)
            val_loss.append(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), best_model_path)

            loss_file.write(f"{avg_train_loss}, {avg_val_loss}\n")

In [None]:
class Hamiltonian_Adjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, p0, q0, f, T, dt, back_dt, dim, integrator, back_integrator, iter, *adjoint_params):
        with torch.no_grad():
            trajectories = torch.zeros((p0.shape[0], T, dim * 2))
            trajectories[:, 0, 0] = p0
            trajectories[:, 0, 1] = q0
            p = p0
            q = q0
            num_steps = int(1/dt)
            if integrator == "euler":
                integrate_fn = Hamiltonian_Adjoint.euler_step
            elif integrator == "rk2":
                integrate_fn = Hamiltonian_Adjoint.rk2_step
            elif integrator == "sv":
                integrate_fn = Hamiltonian_Adjoint.sv_step                
            # elif integrator == "rk2":
            #     integrate_fn = Hamiltonian_Adjoint.rk2
            # elif integrate_fn == "sv":
            #     integrate_fn = Hamiltonian_Adjoint.sv
            dynamics = Hamiltonian_Adjoint.dynamics_fn
            for timestep in range(1, T):
                for _ in range(num_steps):
                    x = integrate_fn([torch.stack([p, q], dim = 1)], dynamics, f, dt)[0]
                    # if integrator == "predictor_corrector":
                    #     x = integrate_fn([torch.stack([p, q], dim = 1)], dynamics, None, f, dt)[0]
                    # else:
                    #     x = integrate_fn([torch.stack([p, q], dim = 1)], dynamics, f, dt)[0]
                    p = x[:, 0]
                    q = x[:, 1]
                trajectories[:, timestep, 0] = p
                trajectories[:, timestep, 1] = q
            ctx.save_for_backward(trajectories, p0, q0, *adjoint_params)
            ctx.T = T
            ctx.dt = dt
            ctx.dim = dim
            ctx.back_integrator = back_integrator
            ctx.iter = iter
            ctx.back_dt = back_dt
            ctx.f = f
            # print("trajectories p", trajectories[0, : , 0])
            # return trajectories[:, :, 0], trajectories[:, :, 1]
            return trajectories

    @staticmethod
    def backward(ctx, dldz):
        with torch.no_grad():
            dldp = dldz[:, :, 0]
            dldq = dldz[:, :, 1]
            # print("dldp", dldp.shape)
            # print("dldq", dldq.shape)
            trajectories, p0, q0, *adjoint_params = ctx.saved_tensors
            T = ctx.T
            dt = ctx.dt
            dim = ctx.dim
            back_integrator = ctx.back_integrator
            back_dt = ctx.back_dt
            iter = ctx.iter
            if back_integrator == "euler":
                integrate_fn = Hamiltonian_Adjoint.euler_step
            elif back_integrator == "rk2":
                integrate_fn = Hamiltonian_Adjoint.rk2_step
            elif back_integrator == "sv":
                integrate_fn = Hamiltonian_Adjoint.sv_step
            # elif back_integrator == "second_order":
            #     integrate_fn = Hamiltonian_Adjoint.heuns_method
            # elif back_integrator == "predictor_corrector":
            #     integrate_fn = Hamiltonian_Adjoint.predictor_corrector
            num_steps = int(1/back_dt)
            h = T/num_steps
            f = ctx.f

            def augmented_dynamics(y_aug, Func):
                #get p and q at current time step
                p = y_aug[0]
                q = y_aug[1]
                with torch.enable_grad():
                    p = p.detach().requires_grad_(True).double()
                    q = q.detach().requires_grad_(True).double()
                    H = Func(p.unsqueeze(-1), q.unsqueeze(-1))
                    # following three lines are necessary to bypass a bug in pytorch's autograd function, and do not contribute the actual algorithm whatsoever
                    _p = torch.as_strided(p, (), ())  # noqa
                    _q = torch.as_strided(q, (), ())  # noqa
                    _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params)
                    #partial gradient of Hamiltonian with respect to p and q
                    dhdp, = torch.autograd.grad(H.sum(), p, allow_unused = True, retain_graph = True, create_graph = True)
                    dhdq, = torch.autograd.grad(H.sum(), q, allow_unused = True, retain_graph = True, create_graph = True)
                    #partial gradient of Hamiltonian with p and q, the vector jacobian product being -adjoint_q, -lambda_q*dh/dpdq
                    dhdpdq_1 = torch.autograd.grad(dhdp, q, -y_aug[3], allow_unused=True, retain_graph=True)
                    # print("dhdpdq_1", dhdpdq_1[0])
                    #partial gradient of Hamiltonian with p and q, the vector jacobian product being adjoint_p, lambda_p*dh/dpdq
                    dhdpdq_2 = torch.autograd.grad(dhdp, q, y_aug[2], allow_unused=True, retain_graph=True)
                    # print("dhdpdq_2", dhdpdq_2[0])
                    #double partial gradients with respect to p and q adjusted by respective jacobian products as in the formula
                    #-lambda_q*dh/dpdp
                    dhdpdp = torch.autograd.grad(dhdp, p, -y_aug[3], allow_unused=True, retain_graph=True)
                    # print("dhdpdp", dhdpdp[0][:5])
                    #lambda_p*dh/dqdq
                    dhdqdq = torch.autograd.grad(dhdq, q, y_aug[2], allow_unused=True, retain_graph=True)
                    # print("dhdqdq", dhdqdq[0][:5])
                    #partial gradients with respect to parameters, p and q adjusted with respective adjoints as given in the formula
                    # -lambda_q*dh/dpdtheta
                    dhdpdw = torch.autograd.grad(dhdp, adjoint_params, y_aug[3], allow_unused=True, retain_graph=True)
                    #lambda_p*dh/dqdtheta
                    dhdqdw = torch.autograd.grad(dhdq, adjoint_params, -y_aug[2], allow_unused=True, retain_graph=True)
                    dhdpdq_1 = [torch.zeros_like(p) if param is None else param for param in dhdpdq_1]
                    dhdpdq_2 = [torch.zeros_like(q) if param is None else param for param in dhdpdq_2]
                    #setting gradients zero for parameters which may not directly contribute to Hamiltonian calculation at different time steps
                    dhdpdp = torch.zeros_like(p) if dhdpdp is None else dhdpdp
                    dhdqdq = torch.zeros_like(q) if dhdqdq is None else dhdqdq
                    dhdpdw = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, dhdpdw)]
                    dhdqdw = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, dhdqdw)]
                    #final gradient calculation (lambda_p*dh/dqdtheta - lambda_q*dh/dpdtheta)
                    dw = list(dhdp_param + dhdq_param for dhdp_param, dhdq_param in zip(dhdpdw, dhdqdw))
                    f_p = -dhdq
                    f_q = dhdp
                    #lambda_p = -lambda_q*dh/dpdq + lambda_p*dh/dqdq
                    adjoint_p = dhdpdq_2[0] + dhdpdp[0]
                    #lambda_q = -lambda_q*dh/dpdp + lambda_p*dh/dpdq
                    adjoint_q = dhdqdq[0] + dhdpdq_1[0]
                    # return [f_p, f_q, adjoint_p, adjoint_q]
                    return [f_p, f_q, adjoint_p, adjoint_q, *dw]
            adj_p = torch.zeros(dldp.shape[0]).to(dldp)
            adj_q = torch.zeros(dldq.shape[0]).to(dldq)
            adj_params = [torch.zeros_like(param).to(dldp) for param in adjoint_params]
            final_aug = None
            for i in range(T-1, 0, -1):
                p_next = trajectories[:, i, 0]
                q_next = trajectories[:, i, 1]
                adj_p += dldp[:, i]
                adj_q += dldq[:, i]
                aug_state = [p_next, q_next, adj_p, adj_q]
                aug_state.extend([torch.zeros_like(param).to(dldp) for param in adjoint_params])
                for _ in range(num_steps):
                    aug_state = integrate_fn(aug_state, augmented_dynamics, f, back_dt)
                    # if back_integrator == 'predictor_corrector':
                    #     aug_state = integrate_fn(aug_state, augmented_dynamics, f, back_dt, iter)
                    # else:
                    #     aug_state = integrate_fn(aug_state, augmented_dynamics, f, back_dt)
                    if not final_aug:
                        final_aug = aug_state[4:]
                        aug_state = aug_state[:4]
                        aug_state.extend([torch.zeros_like(param).to(dldp) for param in adjoint_params])
                    else:
                        adj_params = [val1 + val2 for val1, val2 in zip(adj_params, aug_state[4:])]
                        aug_state = aug_state[:4]
                        aug_state.extend([torch.zeros_like(param).to(dldp) for param in adjoint_params])
                adj_p = aug_state[2]
                adj_q = aug_state[3]
                # print("param", aug_state[4][:5])
                # adj_params = [val1 + val2 for val1, val2 in zip(adj_params, aug_state[4:])]

            adj_p += dldp[:, 0]
            adj_q += dldq[:, 0]
            p_next = trajectories[:, 0, 0]
            q_next = trajectories[:, 0, 1]
            aug_state = [p_next, q_next, adj_p, adj_q]
            aug_state.extend([torch.zeros_like(param).to(dldp) for param in adjoint_params])
            return_aug = augmented_dynamics(aug_state, f)
            avg = [0.5 * (val1 + val2) for val1, val2 in zip(final_aug, return_aug[4:])]
            adj_params = [val1 + val2 for val1, val2 in zip(avg, adj_params)]
            adj_params = [h * val for param, val in zip(adjoint_params, adj_params)]
            # adj_params = aug_state[4:]
            # print("adj_q", adj_q[:32])

        return adj_p, adj_q, None, None, None, None, None, None, None, None, *adj_params

    @staticmethod
    def euler_step(x, dynamics, f, dt):
        if len(x) > 2:
            func_out = dynamics(x, f)
            out = func_out[:4]
            result = [xi + dt * yi for xi, yi in zip(x, out)]
            return result + func_out[4:]
        out = dynamics(x, f)
        result = [xi + dt * yi for xi, yi in zip(x, out)]
        return result

    @staticmethod
    def sv_step(x, dynamics, f, dt, x_init = None, iterations = 1):
        if len(x) > 2:
            p0 = x[0]
            q0 = x[1]
            r0 = x[2]
            s0 = x[3]
            if x_init == None:
                p_half = p0
            else:
                p_half = (x_init[0] + p0)/2
            for _ in range(iterations):
                p_half = p0 + 0.5 * dt * dynamics([p_half, q0, r0, s0], f)[0]
            r1 = r0 + 0.5 * dynamics([p_half, q0, r0, s0], f)[2]
            if x_init == None:
                r2 = r1
            else:
                r2 = (x_init[2] + r1)/2
            for _ in range(iterations):
                r2 = r1 + 0.5 * dt * dynamics([p_half, q0, r2, s0], f)[2]
            p1 = p_half + 0.5 * dynamics([p_half, q0, r2, s0], f)[0]
            if x_init == None:
                s_half = s0
            else:
                s_half = (x_init[3] + s0)/2
            for _ in range(iterations):
                s_half = s0 + 0.5 * dt * dynamics([p0, q0, r0, s_half], f)[3]
            q1 = q0 + 0.5 * dynamics([p0, q0, r0, s_half], f)[1]
            if x_init == None:
                q2 = q1
            else:
                q2 = (x_init[1] + q1)/2
            for _ in range(iterations):
                q2 = q1 + 0.5 * dt * dynamics([p0, q2, r0, s_half], f)[1]
            s1 = s_half + 0.5 * dt * dynamics([p0, q2, r0, s_half], f)[3]
            return [p1, q2, r2, s1] + x[4:]

        p0 = x[0][:, 0]
        q0 = x[0][:, 1]
        if x_init == None:
            p_half = p0
        else:
            p_half = (x_init[0][:, 0] + p0)/2
        for _ in range(iterations):
            p_half = p0 + 0.5 * dt * dynamics([torch.stack([p_half, q0], dim = 1)], f)[0][:, 0]
        q1 = q0 + 0.5 * dt * dynamics([torch.stack([p_half, q0], dim = 1)], f)[0][:, 1]
        if x_init == None:
            q2 = q1
        else:
            q2 = (x_init[0][:, 1] + q1)/2
        for _ in range(iterations):
            q2 = q1 + 0.5 * dt * dynamics([torch.stack([p_half, q2], dim = 1)], f)[0][:, 1]
        p1 = p_half + 0.5 * dt * dynamics([torch.stack([p_half, q2], dim = 1)], f)[0][:, 0]
        return [torch.stack([p1, q2], dim = 1)]

        

    @staticmethod
    def rk2_step(x, dynamics, f, dt):
        if len(x) > 2:
            p0 = x[0]
            q0 = x[1]
            r0 = x[2]
            s0 = x[3]
            func_out = dynamics(x, f)
            out = func_out[:4]
            p1 = out[0]
            q1 = out[1]
            r1 = out[2]
            s1 = out[3]

            func_out = dynamics([p0 + 0.5 * dt * p1, q0 + 0.5 * dt * q1, r0 + 0.5 * dt * r1, s0 + 0.5 * dt * s1], f)
            out = func_out[:4]
            p2 = out[0]
            q2 = out[1]
            r2 = out[2]
            s2 = out[3]
            
            p = p0 + dt * (p1 + p2)/2
            q = q0 + dt * (q1 + q2)/2
            r = r0 + dt * (r1 + r2)/2
            s = s0 + dt * (s1 + s2)/2

            return [p, q, r, s] + x[4:]        
        
        p0 = x[0][:,0]
        q0 = x[0][:,1]
        out = dynamics(x, f)[0]
        p1 = out[:, 0]
        q1 = out[:, 1]

        out = dynamics([torch.stack([p0 + 0.5 * dt * p1, q0 + 0.5 * dt * q1], dim = 1)], f)[0]
        p2 = out[:, 0]
        q2 = out[:, 1]

        p = p0 + dt * (p1 + p2)/2
        q = q0 + dt * (q1 + q2)/2
        return [torch.stack([p, q], dim = 1)]

    


    @staticmethod
    def heuns_method(x, dynamics, f, dt):
        # Heun's method
        # k1 = func(y)
        # k2 = func(y + 0.5 * h * k1)
        # return y + h * k2
        k1_ = dynamics(x, f)
        k2__ = [x_ + 0.5 * dt * k1 for x_, k1 in zip(x, k1_)]
        k2_ = dynamics(k2__, f)
        result = [x_ + 0.5 * dt * k2 for x_, k2 in zip(x, k2_)]
        return result

    @staticmethod
    def predictor_corrector(x, dynamics, f, dt, iter):
        # print("inside predictor corrector")
        def heuns_method(x, dynamics, f, dt):
            k1_ = dynamics(x, f)
            k2__ = [x_ + 0.5 * dt * k1 for x_, k1 in zip(x, k1_)]
            k2_ = dynamics(k2__, f)
            result = [x_ + 0.5 * dt * k2 for x_, k2 in zip(x, k2_)]
            return result

        #corrector
        def verlet_implicit(x, dynamics, f, dt, predict):
            if len(x) == 1:
                p0 = x[0][:, 0]
                q0 = x[0][:, 1]
                p_half = predict[0][:, 0]
                q = predict[0][:, 1]
                p_half = p0 + 0.5 * dt * dynamics([torch.stack([p_half, q0], dim = 1)], f)[0][:, 0]
                q = q0 + 0.5 * dt * (dynamics([torch.stack([p_half, q], dim = 1)], f)[0][:, 1] + dynamics([torch.stack([p_half, q0], dim = 1)], f)[0][:, 1])
                p = p_half + 0.5 * dt * dynamics([torch.stack([p_half, q], dim = 1)], f)[0][:, 0]
                return [torch.stack([p, q], dim = 1)]
            else:
                # print("should be here")
                p0 = x[0]
                q0 = x[2]
                p_half = predict[0]
                q = predict[2]
                p_half = p0 + 0.5 * dt * dynamics([p_half, x[1], q0, x[3]] + x[4:], f)[0]
                q = q0 + 0.5 * dt * (dynamics([p_half, x[1], q, x[3]] + x[4:], f)[2] + dynamics([p_half, x[1], q0, x[3]] + x[4:], f)[2])
                p = p_half + 0.5 * dt * dynamics([p_half, x[1], q, x[3]] + x[4:], f)[0]
                p0_ = x[3]
                q0_ = x[1]
                p_half_ = predict[3]
                q_ = predict[1]
                p_half_ = p0_ + 0.5 * dt * dynamics([x[0], q0_, x[2], p_half_] + x[4:], f)[3]
                q_ = q0_ + 0.5 * dt * (dynamics([x[0], q_, x[2], p_half_] + x[4:], f)[1] + dynamics([x[0], q0_, x[2], p_half_] + x[4:], f)[1])
                p_ = p_half_ + 0.5 * dt * dynamics([x[0], q_, x[2], p_half_] + x[4:], f)[3]
                state = [p, q_, q, p_] + x[4:]
                # res = euler_step(state, adjoint_param_dynamics, f, dt)
                return [p, q_, q, p_]
        x_ = x
        for _ in range(iter):
            predict = heuns_method(x_[:4], dynamics, f, dt)
            result = verlet_implicit(x_[:4], dynamics, f, dt, predict)
            x_ = result + x_[4:]
        return x_

    @staticmethod
    def dynamics_fn(x, f):
        x_ = x[0]
        p = x_[:, 0]
        q = x_[:, 1]
        with torch.enable_grad():
            p.requires_grad_(True)
            q.requires_grad_(True)
            h = f(p.unsqueeze(-1), q.unsqueeze(-1))
            grad_p, = grad(h.sum(), p, create_graph=True, allow_unused=True)
            grad_q, = grad(h.sum(), q, create_graph=True, allow_unused=True)
            return [torch.stack([-grad_q, grad_p], dim = 1)]

In [None]:
def hnn_adjoint(func, x, T, dt, back_dt, dim, integrator, back_integrator, iter):
    adjoint_params = tuple(list(func.parameters()))
    adjoint_params = tuple(p for p in adjoint_params if p.requires_grad)
    trajectories = Hamiltonian_Adjoint.apply(x[:, 0], x[:, 1], func, T, dt, back_dt, dim, integrator, back_integrator, iter, *adjoint_params)
    return trajectories[:, :, 0], trajectories[:, :, 1]

class Symplectic_HNN(nn.Module):
    def __init__(self, f, T = 30, dt = 0.1, back_dt = 0.1, integrator = "euler", back_integrator = "euler", dim = 1, iter = 1):
        super(Symplectic_HNN, self).__init__()
        self.func = f
        self.T = T
        self.dt = dt
        self.back_dt = back_dt
        self.dim = dim
        self.integrator = integrator
        self.back_integrator = back_integrator
        self.iter = iter

    def forward(self, p0, q0):
        p, q = hnn_adjoint(self.func, torch.stack([p0, q0], dim = 1), self.T, self.dt, self.back_dt, self.dim, self.integrator, self.back_integrator, self.iter)
        p = torch.unsqueeze(p, dim = -1)
        q = torch.unsqueeze(q, dim = -1)
        return torch.cat([p, q], dim = -1)

In [None]:
for integrator in ["euler"]:
    for back_integrator in ["euler", "rk2", "sv"]:

        f1 = MLP_General_Hamilt(n_input = 1, n_hidden = 64)
        # f1 = MLP2H_Separable_Hamilt(n_hidden=256, input_size=1).to(device).double()
        T = 2
        dt = 0.1
        back_dt = 0.1
        dim = 1
        num_epochs = 100
        iter = 1
        # integrator = "euler"
        # back_integrator = "euler"
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Create the model, criterion, optimizer, and data loaders
        model = Symplectic_HNN(f1, T, dt, back_dt, integrator, back_integrator, dim, iter).to(device).double()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-1)
        scheduler = StepLR(optimizer, step_size=10, gamma=0.8)
        train_mass_spring = HarmonicOscillatorDataset("./data/mass_spring/train.npy")
        data_loader = DataLoader(train_mass_spring, batch_size=32, shuffle=False)
        val_mass_spring = HarmonicOscillatorDataset("./data/mass_spring/val.npy")
        val_loader = DataLoader(val_mass_spring, batch_size=32, shuffle=False)

        train_loss = []
        val_loss = []

        best_val_loss = float('inf')
        best_model_path = integrator+"_"+back_integrator+'_best_model.pt'

        with open(integrator+"_"+back_integrator+"_loss_adjoint_"+str(dt)+".txt", "w") as loss_file:
            for epoch in tqdm(range(num_epochs)):
                model.train()
                loss_epoch = []
                for batch in data_loader:
                    optimizer.zero_grad()
                    p_batch, q_batch, _ = batch
                    p0_batch = p_batch[:, 0].to(device).double()
                    q0_batch = q_batch[:, 1].to(device).double()
                    simulated_trajectory = model(p0_batch, q0_batch)
                    simulated_trajectory = simulated_trajectory.view(simulated_trajectory.size(0), -1)
                    trajectory = torch.stack([p_batch, q_batch], axis=2).to(device)
                    trajectory = trajectory.view(trajectory.size(0), -1).to(device)
                    loss = criterion(trajectory, simulated_trajectory)
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    loss_epoch.append(loss.item())
                avg_train_loss = sum(loss_epoch)/len(loss_epoch)
                train_loss.append(avg_train_loss)

                model.eval()  # Set model to evaluation mode
                val_loss_epoch = []
                for batch in val_loader:
                    optimizer.zero_grad()
                    p_batch, q_batch, _ = batch
                    p0_batch = p_batch[:, 0].to(device).double()
                    q0_batch = q_batch[:, 1].to(device).double()
                    simulated_trajectory = model(p0_batch, q0_batch)
                    simulated_trajectory = simulated_trajectory.view(simulated_trajectory.size(0), -1)
                    trajectory = torch.stack([p_batch, q_batch], axis=2).to(device)
                    trajectory = trajectory.view(trajectory.size(0), -1).to(device)
                    loss = criterion(trajectory, simulated_trajectory)
                    val_loss_epoch.append(loss.item())
                avg_val_loss = sum(val_loss_epoch)/len(val_loss_epoch)
                val_loss.append(avg_val_loss)

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    torch.save(model.state_dict(), best_model_path)

                loss_file.write(f"{avg_train_loss}, {avg_val_loss}\n")

In [None]:
# Function to read loss from a file
integrator+"_loss_"+str(dt)+".txt"
import os
def read_loss(file_path):
    val_losses = []
    with open(file_path, "r") as file:
        for line in file:
            train_loss, val_loss = map(float, line.strip().split(','))
            val_losses.append(val_loss)
    return val_losses

# List of file paths for different variants
file_paths = ["euler_loss_0.1.txt", "euler_euler_loss_adjoint_0.1.txt", "euler_rk2_loss_adjoint_0.1.txt"]

# Initialize a list to store validation loss trajectories for each variant
all_val_losses = []

# Read validation losses for each variant
for file_path in file_paths:
    val_losses = read_loss(file_path)
    all_val_losses.append(val_losses)

# Plot the validation loss trajectories for each variant
epochs = range(1, len(all_val_losses[0]) + 1)
for file_path, val_losses in zip(file_paths, all_val_losses):
    variant_name = os.path.splitext(os.path.basename(file_path))[0]
    plt.plot(epochs, val_losses, label=variant_name)

plt.xlabel('Epochs')
plt.ylabel('Validation Loss')
plt.title('Validation Loss Trajectories for Different Variants')
plt.legend()
plt.grid(True)
plt.show()

In [358]:
# f1 = MLP2H_Separable_Hamilt(n_hidden=256, input_size=1).to(device).double()
# # Clone f1 to create f2
# f2 = MLP2H_Separable_Hamilt(n_hidden=256, input_size=1).to(device).double()

f1 = MLP_General_Hamilt(n_input = 1, n_hidden = 64)
f2 = MLP_General_Hamilt(n_input = 1, n_hidden = 64)
# Copy parameters from f1 to f2
f2.load_state_dict(f1.state_dict())

params_equal = all(torch.allclose(p1, p2) for p1, p2 in zip(f1.parameters(), f2.parameters()))
print("Are the parameters equal between f1 and f2?", params_equal)

Are the parameters equal between f1 and f2? True


In [359]:
T = 2
dt = 0.1
dim = 1
num_epochs = 10
iter = 1
integrator = "rk2"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# f = MLP2H_Separable_Hamilt(n_hidden = 256, input_size = 1).to(device)
model = BackProp_HNN(f1, T, dt, dim, integrator, iter).to(device).double()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_mass_spring = HarmonicOscillatorDataset("./data/mass_spring/train.npy")
data_loader = DataLoader(train_mass_spring, batch_size=32, shuffle=False)
val_mass_spring = HarmonicOscillatorDataset("./data/mass_spring/val.npy")
val_loader = DataLoader(val_mass_spring, batch_size=32, shuffle=False)

train_loss = []
val_loss = []

best_val_loss = float('inf')
best_model_path = integrator+'_best_model.pt'

for epoch in tqdm(range(num_epochs)):
    model.train()
    loss_epoch = []
    for batch in data_loader:
        optimizer.zero_grad()
        p_batch, q_batch, _ = batch
        p0_batch = p_batch[:, 0].to(device).double()
        q0_batch = q_batch[:, 1].to(device).double()
        simulated_trajectory = model(p0_batch, q0_batch)
        # print("simulated trajectory", simulated_trajectory[0, :, 0])
        simulated_trajectory = simulated_trajectory.view(simulated_trajectory.size(0), -1)
        trajectory = torch.stack([p_batch, q_batch], axis=2).to(device)
        trajectory = trajectory.view(trajectory.size(0), -1).to(device)
        loss = criterion(trajectory, simulated_trajectory)
        # print(loss.item())
        loss.backward()
        optimizer.step()
        # for name, param in f1.named_parameters():
        #     # if param.grad
        #     print(name)
        #     print(param.grad)
        loss_epoch.append(loss.item())
    avg_train_loss = sum(loss_epoch)/len(loss_epoch)
    # print("avg train loss", avg_train_loss)
    train_loss.append(avg_train_loss)
    
    # Validation phase
    model.eval()  # Set model to evaluation mode
    val_loss_epoch = []
    # with torch.no_grad():
    for batch in val_loader:
        optimizer.zero_grad()
        p_batch, q_batch, _ = batch
        p0_batch = p_batch[:, 0].to(device).double()
        q0_batch = q_batch[:, 1].to(device).double()
        simulated_trajectory = model(p0_batch, q0_batch)
        simulated_trajectory = simulated_trajectory.view(simulated_trajectory.size(0), -1)
        trajectory = torch.stack([p_batch, q_batch], axis=2).to(device)
        trajectory = trajectory.view(trajectory.size(0), -1).to(device)
        loss = criterion(trajectory, simulated_trajectory)
        val_loss_epoch.append(loss.item())
    avg_val_loss = sum(val_loss_epoch)/len(val_loss_epoch)
    val_loss.append(avg_val_loss)
    # Check if the validation loss is the lowest seen so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Save the model
        torch.save(model.state_dict(), best_model_path)
    print(f"Epoch {epoch+1}, Average Train Loss: {avg_train_loss}, Average Val Loss: {avg_val_loss}")
    

 10%|████▍                                       | 1/10 [00:07<01:04,  7.22s/it]

Epoch 1, Average Train Loss: 7.482466161251068, Average Val Loss: 6.898046697889056


 20%|████████▊                                   | 2/10 [00:13<00:54,  6.84s/it]

Epoch 2, Average Train Loss: 7.462766423821449, Average Val Loss: 6.818754468645368


 30%|█████████████▏                              | 3/10 [00:20<00:46,  6.71s/it]

Epoch 3, Average Train Loss: 6.979049831628799, Average Val Loss: 5.834603990827288


 40%|█████████████████▌                          | 4/10 [00:26<00:39,  6.53s/it]

Epoch 4, Average Train Loss: 6.016539961099625, Average Val Loss: 5.577141012464251


 50%|██████████████████████                      | 5/10 [00:32<00:32,  6.44s/it]

Epoch 5, Average Train Loss: 5.93646177649498, Average Val Loss: 5.548423971448626


 60%|██████████████████████████▍                 | 6/10 [00:39<00:26,  6.57s/it]

Epoch 6, Average Train Loss: 5.892290145158768, Average Val Loss: 5.542686189923968


 70%|██████████████████████████████▊             | 7/10 [00:46<00:20,  6.81s/it]

Epoch 7, Average Train Loss: 5.883467882871628, Average Val Loss: 5.540743282863072


 80%|███████████████████████████████████▏        | 8/10 [00:53<00:13,  6.79s/it]

Epoch 8, Average Train Loss: 5.873905703425407, Average Val Loss: 5.537671497889927


 90%|███████████████████████████████████████▌    | 9/10 [01:01<00:07,  7.08s/it]

Epoch 9, Average Train Loss: 5.858141779899597, Average Val Loss: 5.517237731388637


100%|███████████████████████████████████████████| 10/10 [01:10<00:00,  7.04s/it]

Epoch 10, Average Train Loss: 5.837831005454063, Average Val Loss: 5.494555677686419





In [197]:
'''
Add solvers
'''
def euler(func, p0, q0, dt):
    dp_dt, dq_dt = func(p0, q0)
    p = p0 + dt * dp_dt
    q = q0 + dt * dq_dt
    return p, q

def rk2(func, p0, q0, dt):
    pass

def sv(func, p0, q0, dt, iterations = 1, p_init = None, q_init = None):
    pass

def pc(func, p0, q0, dt, iterations = 1):
    p_init, q_init = rk2(func, p0, q0, dt)
    p, q = sv(func, p0, q0, T, dt, iterations, p_init, q_init)
    return p, q
    

In [213]:
class LinearBlock(nn.Module):
    def __init__(self, inchannel, outchannel):
        super(LinearBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Linear(inchannel, outchannel),
            #nn.Tanh(),
            nn.Sigmoid(),
            #nn.ReLU(inplace=True),
        )

    def forward(self, x):
        out = self.left(x)
        return out

class NN_Backprop(nn.Module):
    def __init__(self, N, hidden_dim):
        super(NN_Backprop, self).__init__()
        self.N = N
        self.f = nn.Sequential(LinearBlock(2 * self.N, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    nn.Linear(hidden_dim, 2*self.N))
        self.b = nn.Parameter(torch.zeros(1,1,2*self.N) , requires_grad=True)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.uniform_(-math.sqrt(6. / m.in_features), math.sqrt(6. / m.in_features))
                
    def hamiltonian(self, p, q):
        # with torch.enable_grad():
        x = torch.cat((p.unsqueeze(1), q.unsqueeze(1)), dim=1)
        x = x.requires_grad_(True)
        K = self.f(x)+self.b
        return K[:, :, :self.N] + K[:, :, self.N:self.N * 2]
    
    def dynamics(self, p, q):
        p.requires_grad_(True)
        q.requires_grad_(True)
        with torch.enable_grad():
            h = self.hamiltonian(p, q)
            print(h.requires_grad_)
            grad_p, = grad(h.sum(), p, create_graph=True, allow_unused=True)
            grad_q, = grad(h.sum(), q, create_graph=True, allow_unused=True)
        return -grad_q, grad_p
        
    def forward(self, p0, q0, solver, T, dt):
        trajectory = torch.zeros((p0.shape[0], T, self.N * 2)).to(device)
        n_steps = int(np.ceil(T/dt))
        trajectory[:, 0, 0] = p0
        trajectory[:, 0, 1] = q0
        p = p0
        q = q0
        for i in range(1, T):
            for _ in range(n_steps):
                p, q = solver(self.dynamics, p, q, dt)
            trajectory[:, i, 0] = p
            trajectory[:, i, 1] = q
        return trajectory[:, :, 0], trajectory[:, :, 1]

In [214]:
class Hamiltonian_Data(Dataset):
    def __init__(self, file_path):
        # Load the dataset
        dataset = np.load(file_path)[:, ::1000, :]
        
        # Convert the dataset to PyTorch tensors
        self.p_values = torch.tensor(dataset[:, :, 0], dtype=torch.float32)
        self.q_values = torch.tensor(dataset[:, :, 1], dtype=torch.float32)
        self.h_values = torch.tensor(dataset[:, :, 2], dtype=torch.float32)
        
        # Ensure consistent length for all tensors
        assert len(self.p_values) == len(self.q_values) == len(self.h_values)
        self.length = len(self.p_values)
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        return self.p_values[idx], self.q_values[idx], self.h_values[idx]

In [215]:
train_mass_spring = Hamiltonian_Data("./data/mass_spring/train.npy")
train_loader = DataLoader(train_mass_spring, batch_size=32, shuffle=True)
val_mass_spring = Hamiltonian_Data("./data/mass_spring/val.npy")
val_loader = DataLoader(val_mass_spring, batch_size=32, shuffle=False)

In [216]:
for batch_idx, (p_values, q_values, h_values) in enumerate(train_loader):
    # Perform operations on the batch
    print(f"Batch {batch_idx}:")
    print("p_values:", p_values.shape)
    print("q_values:", q_values.shape)
    print("h_values:", h_values.shape)
    break

Batch 0:
p_values: torch.Size([32, 10])
q_values: torch.Size([32, 10])
h_values: torch.Size([32, 10])


In [218]:
# Define the Hamiltonian model
# (same as before)

# Define the training loop with validation and loss logging
def train_and_validate_with_logging(model, train_loader, val_loader, optimizer, criterion, epochs=10, checkpoint_path='checkpoint.pt', log_file='loss.log', T = 10, dt = 0.1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    best_val_loss = float('inf')
    losses = []

    for epoch in tqdm(range(epochs)):
        # Training phase
        model.train()
        epoch_train_loss = 0.0
        for batch_idx, (p_values, q_values, h_values) in enumerate(train_loader):
            p_values, q_values, h_values = p_values.to(device), q_values.to(device), h_values.to(device)
            optimizer.zero_grad()
            # Forward pass
            p_pred, q_pred = model(p_values[:, 0], q_values[:, 0], euler, T, dt)

            # Compute loss
            loss = criterion(p_pred, p_values) + criterion(q_pred, q_values)
            epoch_train_loss += loss.item()

            # Backward pass
            loss.backward()
            optimizer.step()

#         epoch_train_loss /= len(train_loader)

#         # Validation phase
#         model.eval()
#         epoch_val_loss = 0.0

#         with torch.no_grad():
#             for p_values, q_values, h_values in val_loader:
#                 p_values, q_values, h_values = p_values.to(device), q_values.to(device), h_values.to(device)
#                 p_pred, q_pred = model(p_values[:, 0], q_values[:, 0], euler, T, dt)
#                 val_loss = criterion(p_pred, p_values) + criterion(q_pred, q_values)
#                 epoch_val_loss += val_loss.item()

#         epoch_val_loss /= len(val_loader)

#         print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

#         # Save losses to list
#         losses.append((epoch_train_loss, epoch_val_loss))

#         # Save checkpoint if validation loss has decreased
#         if epoch_val_loss < best_val_loss:
#             best_val_loss = epoch_val_loss
#             torch.save(model.state_dict(), checkpoint_path)
#             print("Checkpoint saved.")

#     # Write losses to file
#     with open(log_file, 'w') as f:
#         f.write("Epoch\tTrain Loss\tVal Loss\n")
#         for i, (train_loss, val_loss) in enumerate(losses):
#             f.write(f"{i+1}\t{train_loss:.6f}\t{val_loss:.6f}\n")

# Define hyperparameters and instantiate the model and datasets
# (same as before)

# Train, validate, and log the model
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()
epochs = 1
model = NN_Backprop(1, 64)
train_and_validate_with_logging(model, train_loader, val_loader, optimizer, criterion, epochs, 'checkpoint.pt', 'loss.log', 1, 1)


  0%|                                                     | 0/1 [00:00<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn