In [37]:
%reset

In [38]:
import os
os.system("ml tqdm/4.66.2-GCCcore-13.2.0")
from tqdm import tqdm
os.system("ml PyTorch/2.2.1-foss-2023b-CUDA-12.4.0")
import torch
import torch.nn as nn   
import torch.optim as optim
import numpy as np
from torch.optim.lr_scheduler import StepLR
from torch.autograd.functional import jacobian
from torch.utils.data import DataLoader, TensorDataset
import math
import optuna
import csv
import logging
import sys
import pandas as pd
import time
import argparse
import distutils.util
import tracemalloc

# Set random seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])

In [39]:
class HamiltonianNN(nn.Module):

    def __init__(self, model_specs):
        super(HamiltonianNN, self).__init__()

        # Create a list of linear layers based on layer_sizes
        layer_sizes = model_specs[0]
        self.layers = nn.ModuleList()
        self.dropout_layers = nn.ModuleList()
        self.RANDOM_SEED = 0
        for i in range(len(layer_sizes) - 2):  # All layers except the last one
            self.layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=True))

            self.dropout_layers.append(nn.Dropout(p=0.2))
        
        # Last layer without bias
        self.layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1], bias=False))

    def forward(self, x):
        for i, layer in enumerate(self.layers[:-1]):
            x = layer(x)
            x = torch.tanh(x)
            if i < len(self.dropout_layers):
                x = self.dropout_layers[i](x)
        x = self.layers[-1](x)
        return x
    
    def _apply_xavier_init(self):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

In [40]:
def forward_ode(y_tensor, args, kargs):

    model = args[0]
    i = kargs[0]
        
    with torch.enable_grad():

        #y = y_tensor.clone().detach().requires_grad_(True)

        y = y_tensor.requires_grad_(True)

        h = model(y)

    
        grad_h = torch.autograd.grad(outputs=h.sum(), inputs=y, create_graph=True, allow_unused=True)[0]
        #print("grad h: ", grad_h)
        dq_dt = grad_h[:, 1]
        dp_dt = -grad_h[:, 0]

    return torch.stack((dq_dt, dp_dt), dim=-1)

In [41]:
def adjoint_ode(lam, args, kargs):

    (model, y_values) = args
    batch_size = y_values.shape[0]
    (i,) = kargs
    br_i = y_values.shape[2]//2

    y_tensor = y_values[:,i,:].clone().detach().requires_grad_(True)
    
    h = model(y_tensor)

    # Compute first-order derivatives ∇H = [∂H/∂q, ∂H/∂p]
    grad_h = torch.autograd.grad(outputs=h.sum(), inputs=y_tensor, create_graph=True, retain_graph=True)[0]  # Shape: [batch_size, 2n]

    # Compute second-order derivatives
    J_H = torch.zeros(batch_size, 2*br_i, 2*br_i)  # Initialize Jacobian matrix

    for i in range(2 * br_i):
        if i<br_i:
            grad_i = torch.autograd.grad(outputs=grad_h[:, br_i+i], inputs=y_tensor, grad_outputs=torch.ones_like(grad_h[:, br_i+i]), create_graph=True, retain_graph=True)[0]
        else:
            grad_i = torch.autograd.grad(outputs=-grad_h[:, i-br_i], inputs=y_tensor, grad_outputs=torch.ones_like(-grad_h[:, i-br_i]), create_graph=True, retain_graph=True)[0]
        J_H[:, i, :] = grad_i  # Assign row-wise
    
    lam_tensor = lam.clone().detach().unsqueeze(2)


    lam_dot = - (torch.bmm(J_H, lam_tensor)).squeeze()

    return lam_dot

In [42]:
def reshape_gradients(flattened_gradients, original_shapes):
    reshaped_gradients = []
    start = 0
    for shape in original_shapes:
        size = torch.prod(torch.tensor(shape)).item()  # Calculate the number of elements in this shape
        end = start + size
        reshaped_gradients.append(flattened_gradients[start:end].reshape(shape))
        start = end
    return reshaped_gradients

In [43]:
def rk2_step(dyn, y, dt, dynamics, args, kargs):
    h = dt
    i = kargs[0]
    q, p = y[:, 0], y[:, 1]

    y = torch.stack((q, p), dim=-1)  # Shape: (batch_size, 2)

    #print(q.shape)

    dy1 = dynamics(y, args, kargs)
    q1 = q + 0.5 * dy1[:, 0] * h
    p1 = p + 0.5 * dy1[:, 1] * h

    y1 = torch.stack((q1, p1), dim=-1)  # Shape: (batch_size, 2)
    dy2 = dynamics(y1, args, kargs)

    q_new = q + dy2[:, 0] * h
    p_new = p + dy2[:, 1] * h
    return torch.stack((q_new, p_new), dim=-1)

In [23]:
# def im_step(y, dt, dynamics, iterations, y_init, args, kargs):
#     h = dt
#     br_i = y.shape[1] // 2
#     q, p = y[:, 0:br_i], y[:, br_i:2 * br_i]
    
#     y_init_concat = torch.cat((y_init[:, 0:br_i], y_init[:, br_i:2*br_i]), dim=-1)  # Shape [batch, 2]
#     f_init = dynamics(y_init_concat, args, kargs)  # Compute dynamics at initial point
    
#     q_new = q + 0.5 * h * f_init[:, 0:br_i]  # Shape [batch, 1]
#     p_new = p + 0.5 * h * f_init[:, br_i:2*br_i]  # Shape [batch, 1]

#     for _ in range(iterations):
#         mid_q = 0.5 * (q + q_new)  # Shape [batch, q_shape]
#         mid_p = 0.5 * (p + p_new)  # Shape [batch, p_shape]
        
#         mid_concat = torch.cat((mid_q, mid_p), dim=-1)  # Ensure [batch, 2*q_shape]
#         f_mid = dynamics(mid_concat, args, kargs)  # Compute dynamics at midpoint
        
#         q_new = q + h * f_mid[:, 0:br_i]  # Shape [batch, q_shape]
#         p_new = p + h * f_mid[:, br_i:2*br_i]  # Shape [batch, p_shape]

#     return torch.cat((q_new, p_new), dim=-1)  # Final shape [batch, 2*q_shape]


In [44]:
def sv_step(dyn, y, dt, dynamics, iterations, y_init, args, kargs):
    h = dt
    q, p = y[:, 0], y[:, 1]
    i = kargs[0]

    p_half = p + 0.5 * h * dynamics(torch.stack((q, y_init[:, 1]), dim=-1), args, kargs)[:, 1]
    for _ in range(iterations):
        p_half = p + 0.5 * h * dynamics(torch.stack((q, p_half), dim=-1), args, kargs)[:, 1]

    q_half = q + 0.5 * h * dynamics(torch.stack((y_init[:, 0], p_half), dim=-1), args, kargs)[:, 0]
    for _ in range(iterations):
        q_half = q + 0.5 * h * dynamics(torch.stack((q_half, p), dim=-1), args, kargs)[:, 0]

    q_new = q + h * dynamics(torch.stack((q_half, p_half), dim=-1), args, kargs)[:, 0]
    p_new = p_half + 0.5 * h * dynamics(torch.stack((q_new, p_half), dim=-1), args, kargs)[:, 1]

    return torch.stack((q_new, p_new), dim=-1)

In [45]:
def solve_ivp_custom(dynamics, dyn, y0_batch, t_span, dt, args, iters):
    #t = torch.arange(0, T, dt)
    batch_size = y0_batch.shape[0]
    t0, t1 = t_span
    if t0 > t1:
        dt = -dt
    num_steps = int((t1 - t0) / dt) + 1
    #y0_batch = noisy_obs[:, 0, :]
    #ys_batch = torch.zeros(batch_size, num_steps, 2)
    ys_batch = [y0_batch]
    #print(y0_batch.shape)
    #ys_batch[:, 0, :] = y0_batch.clone()


    for i in range(1, num_steps):
        #y = noisy_obs[:, i-1, :]  # Use the noisy observation at the current step
        y = ys_batch[-1]
        y_ = rk2_step(dyn, y, dt, dynamics, args, kargs=(i,))
        y_next = sv_step(dyn, y, dt, dynamics, iters, y_, args, kargs=(i,))
        #print(y_next.requires_grad)
        ys_batch.append(y_next)
        #print(y_next.shape)
    ys_batch = torch.stack(ys_batch, dim=1)
    #print(ys_batch.requires_grad)
    return ys_batch

In [46]:
def calculate_grad(model, y_t, lambda_t, batch_size):

    y_tensor = y_t.clone().detach().requires_grad_(True).to(y_t.device)
    
    # Perform forward pass
    h = model(y_tensor)
    
    # Compute gradients of model output w.r.t y_tensor
    grad_h = torch.autograd.grad(outputs=h.sum(), inputs=y_tensor,
                                 create_graph=True, retain_graph=True, allow_unused=True)[0]
    
    grad_w_p = torch.autograd.grad(outputs=grad_h[:, 1], inputs=model.parameters(), 
                                   grad_outputs=lambda_t[:, 0],
                                   create_graph=True, retain_graph=True, allow_unused=True)
    
    grad_w_q = torch.autograd.grad(outputs=grad_h[:, 0], inputs=model.parameters(), 
                                   grad_outputs=lambda_t[:, 1],
                                   create_graph=True, retain_graph=True, allow_unused=True)
    
    if grad_w_p is not None:
        grad_w_p = torch.cat([p_grad.flatten() for p_grad in grad_w_p]).unsqueeze(0)
        grad_w_p = grad_w_p.expand(batch_size, -1) / batch_size
    
    if grad_w_q is not None:
        grad_w_q = torch.cat([p_grad.flatten() for p_grad in grad_w_q]).unsqueeze(0)
        grad_w_q = grad_w_q.expand(batch_size, -1) / batch_size
    
    grad_w_combined = grad_w_p - grad_w_q

    model.zero_grad()
    
    return grad_w_combined.mean(dim=0)  # avg over batch

In [47]:
def backward(dynamics, dyn, lambdaT_batch, t_span, dt, args, iters):
    #t = torch.arange(0, T, dt)
    batch_size = lambdaT_batch.shape[0]
    y_batch = args[1]
    t0, t1 = t_span
    model = args[0]
    if t0 > t1:
        dt = -dt
    num_steps = int((t1 - t0) / dt) + 1
    #y0_batch = noisy_obs[:, 0, :]
    #ys_batch = torch.zeros(batch_size, num_steps, 2)
    lambda_batch = lambdaT_batch
    #print(y0_batch.shape)
    #ys_batch[:, 0, :] = y0_batch.clone()
    num_params = sum(p.numel() for p in model.parameters())
    grad_result = torch.zeros(num_params, device=lambdaT_batch.device)

    grad_result += (-dt / 2) * calculate_grad(model, y_batch[:,-1,:], lambdaT_batch, batch_size)
    #print("Current grad: ", grad_result)


    for i in range(1, num_steps):
        #y = noisy_obs[:, i-1, :]  # Use the noisy observation at the current step
        #y = ys_batch[-1]
        lambda_ = rk2_step(dyn, lambda_batch, dt, dynamics, args, kargs=(i,))
        lambda_next = sv_step(dyn, lambda_batch, dt, dynamics, iters, lambda_, args, kargs=(i,))
        #print(y_next.requires_grad)
        # Compute new gradient
        current_grad = calculate_grad(model, y_batch[:,num_steps-i-1,:], lambda_next, batch_size)
        #print("Current grad: ", current_grad)
        
        if i==num_steps-1:
            grad_result += (-dt / 2) * (current_grad)
        else:
            grad_result += (-dt) * (current_grad)

        lambda_batch = lambda_next
        #print(y_next.shape)
    #ys_batch = torch.stack(ys_batch, dim=1)
    #print(ys_batch.requires_grad)
    return grad_result

In [48]:
def downsample_gt(gt_data, dt_solve, dt_gt):
    downsample_factor = int(dt_solve / dt_gt)
    return gt_data[:, ::downsample_factor, :]

In [49]:
def downsample_gt(gt_data, dt_solve, dt_gt):
    downsample_factor = int(dt_solve / dt_gt)
    return gt_data[:, ::downsample_factor, :]


def load_data(datafolder, dynamics, dt_solve, dt_gt):

    noisy_train_path = "../data/"+str(datafolder)+"/noisy_"+str(dynamics)+"_train.pt"
    noisy_val_path = "../data/"+str(datafolder)+"/noisy_"+str(dynamics)+"_val.pt"
    noisy_test_path = "../data/"+str(datafolder)+"/noisy_"+str(dynamics)+"_test.pt"

    train_path = "../data/"+str(datafolder)+"/"+str(dynamics)+"_train.pt"
    val_path = "../data/"+str(datafolder)+"/"+str(dynamics)+"_val.pt"
    test_path = "../data/"+str(datafolder)+"/"+str(dynamics)+"_test.pt"

    noisy_train_trajectories = torch.load(noisy_train_path).to(device)
    noisy_val_trajectories = torch.load(noisy_val_path).to(device)

    true_train_trajectories = torch.load(train_path).to(device)
    true_val_trajectories = torch.load(val_path).to(device)


    # Downsample ground truth data according to dt_solve
    noisy_train_trajectories = downsample_gt(noisy_train_trajectories, dt_solve, dt_gt)
    true_train_trajectories = downsample_gt(true_train_trajectories, dt_solve, dt_gt)


    noisy_val_trajectories = downsample_gt(noisy_val_trajectories, dt_solve, dt_gt)
    true_val_trajectories = downsample_gt(true_val_trajectories, dt_solve, dt_gt)


    return noisy_train_trajectories, noisy_val_trajectories, true_train_trajectories, true_val_trajectories

In [None]:
def objective(model, noisy_train_traj, noisy_val_traj, true_train_traj, true_val_traj, dt_gt, dt_solve, param_vals):

    start_time = time.time()
    
    num_epochs = 1

    learning_rate = param_vals["lr"]
    
    train_batch_size = param_vals["train_batch_size"]
    val_batch_size = param_vals["val_batch_size"]
    sim_len = param_vals["sim_len"]
    sims = param_vals["sims"]
    learning_rate = param_vals["lr"]
    train_batch_size = param_vals["train_batch_size"]
    val_batch_size = param_vals["val_batch_size"]
    T = param_vals["t_final"]
    
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

    train_dataset = TensorDataset(noisy_train_traj, true_train_traj)
    val_dataset = TensorDataset(noisy_val_traj, true_val_traj)

    train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    val_data_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=True)

    
    train_losses = []
    val_losses = []
    alpha = 1e-4
    # Training loop
    #print("Julian Time: ", )
    print(f"Params: Train size {noisy_train_traj.shape}, Val size {noisy_val_traj.shape}, Sim length {T} sec")

    #memory_usage_list = []

    for epoch in range(num_epochs):
        total_loss = 0.0
        
        logging.info(f"Progress: Step {epoch+1}")

        for batch in train_data_loader:
            y_noisy_batch, y_true_batch = batch
            y_noisy_batch = y_noisy_batch.to(device)
            y_true_batch = y_true_batch.to(device)

            ## Load the batch for the initial values of q and p
            pq0_batch = torch.tensor(y_true_batch[:, 0, :], dtype=torch.float32)
            
            tracemalloc.start()
            
            ## Solve the forward ODE
            y_pred_batch = solve_ivp_custom(forward_ode, "forward", pq0_batch, (0, T), dt_solve, args=(model,), iters=5)

            y_pred_batch = y_pred_batch.requires_grad_(True)

            loss = criterion(y_pred_batch[:, :, 0], y_noisy_batch[:,:,0]) + criterion(y_pred_batch[:, :, 1], y_noisy_batch[:,:,1])

            lamb = torch.autograd.grad(loss, y_pred_batch, retain_graph=True)[0]
            lamb_0 = lamb[:,-1,:]

            y_pred_batch.detach()
            # Track memory usage during backward pass

            grads = backward(adjoint_ode, "adjoint", lamb_0, (T, 0), dt_solve, args=(model, y_pred_batch), iters=5)

            _, peak_memory = tracemalloc.get_traced_memory()
            tracemalloc.stop()
 
            #memory_usage_list.append({"epoch": epoch, "peak_memory_MB": peak_memory / (1024**2)})
            #lambda_pred_batch = lambda_pred_batch.flip(1)

            #grads = calculate_integral(model, y_pred_batch, T, lambda_pred_batch)

            #Reshape the gradients to match the model parameters
            start_idx = 0
            for param in model.parameters():
                param_shape = param.shape
                param_size = param.numel()
                param_grad = grads[start_idx:start_idx + param_size].reshape(param_shape)
                param.grad = param_grad.clone().detach()
                start_idx += param_size

            # # Update the model parameters using the optimizer
            optimizer.step()
            optimizer.zero_grad()

            # Compute loss
            
            total_loss += loss.item()
        
        average_train_loss = total_loss / (train_batch_size)
        train_losses.append(average_train_loss)

        print(f'Epoch {epoch}/{num_epochs}, Train Loss: {total_loss/len(train_data_loader)}')
        scheduler.step(total_loss/len(train_data_loader))


    end_time = time.time()
    
    # Log the time taken
    elapsed_time = end_time - start_time
    print(f"Objective function took {elapsed_time:.2f} seconds to complete")
    
    return elapsed_time, peak_memory, model

In [None]:
# if __name__ == "__main__":

    # # Parse command line arguments
    # parser = argparse.ArgumentParser(description="Enter the simulation parameters")
    # parser.add_argument("--dynamics_name", type=str, required=True, choices=["mass_spring", "double_well", "coupled_ho", "henon_heiles"], help="The name of the dynamics function.")
    # parser.add_argument("--data_folder", type=str, required=True, help="the ground truth data folder")
    # parser.add_argument("--gt_res", type=float, required=True, help="the ground truth resolution/stepsize")
    # parser.add_argument("--hid_layers", type=parse_hidden_layers, required=True,
    #                     help="Hidden layers as a list of integers, e.g., [16,32,16]")
    # parser.add_argument("--solver_res", type=float, required=True, help="The time step length for our solver(= k*gt_res where k is an integer)")
    # parser.add_argument("--noise_level", type=float, required=False, default=0.0,
    #                 help="The noise level (a float number from data_gen). Default is 0.0.")
    # parser.add_argument("--pred", type=lambda x: bool(distutils.util.strtobool(x)), required=False, default=False, 
    #                 help="Boolean flag: True if you need a predictor step, False if you use GT (default: False)")
    # parser.add_argument("--num_sims", type=int, required=False, default=1, help="The number of multi-shooting trajectories (default= 1 single shooting)")
    # parser.add_argument("--sim_len", type=int, required=True, help="The forward simulation length of each trajectory for training.")
    # parser.add_argument("--solver", type=str, required=False, default="im", choices=["im","sv"])

    # args = parser.parse_args()
    
dynamics_name = "mass_spring"
data_folder = "mass_spring_10"
dt_gt = 0.01
hidden_layer_sizes = [16,32,16]
dt_solve = 0.01
noise_level = 0
pred = True
num_sims = 1
sim_len = 7
solver = "im"

noisy_train, noisy_val, true_train, true_val = load_data(data_folder, dynamics_name, dt_gt, dt_solve)

input_size = noisy_train.shape[2]
output_size = 1

train_set_len = int(noisy_train.shape[0])
val_set_len = int(noisy_val.shape[0])

layer_sizes = [input_size] + hidden_layer_sizes + [output_size]

model_specs = (layer_sizes,)

model = HamiltonianNN(model_specs).to(device) 

params_list = [{"sim_len":4, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.03, "sims": num_sims}
                ,{"sim_len":8, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.07, "sims": num_sims}
                  ,{"sim_len":12, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.11, "sims": num_sims},
                    {"sim_len":16, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.15, "sims": num_sims},
                    {"sim_len":20, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.19, "sims": num_sims},
                    {"sim_len":24, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.23, "sims": num_sims},
                    {"sim_len":28, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.27, "sims": num_sims},
                    {"sim_len":32, "lr":0.01, "pred":pred, "solver":solver,
                    "train_batch_size":512, "val_batch_size":512, "t_final":0.31, "sims": num_sims}
                    ]


train_losses = []
val_losses = []
models = []
memory_usage_list = []
time_list = []

for i in range(len(params_list)):
    
    #start_ind = int(params_list[i]["t_start"]/dt_solve)
    end_ind = params_list[i]["sim_len"]
    train_batch_size = params_list[i]["train_batch_size"]
    val_batch_size = params_list[i]["val_batch_size"]


    print("Trial: ", str(i))
    
    elapsed_time, peak_memory, model = objective(model, noisy_train[0:train_batch_size, 0:end_ind, :],
                                                noisy_val[0:val_batch_size, 0:end_ind, :], 
                                                true_train[0:train_batch_size, 0:end_ind, :], 
                                                true_val[0:val_batch_size, 0:end_ind, :], 
                                                dt_gt, dt_solve, params_list[i])
    
    memory_usage_list.append({"peak_memory_MB": peak_memory / (1024**2)})
    time_list.append({"run_time":elapsed_time})


# Save memory tracking results
df = pd.DataFrame(zip(memory_usage_list, time_list))
df.to_csv("adjoint_memory_usage_results_2.csv", index=False)
print("Memory usage data saved to memory_usage_results.csv")


    #torch.save(model, f'../models/model_{i}_{dynamics_name}_{noise_level}_adjoint_{num_sims}_{sim_len}_{solver}.pt')
    # df = pd.DataFrame({
    # "train_loss": train_loss,
    # "val_loss": val_loss
    # })
    # df.to_csv(f'output_{dynamics_name}_{noise_level}_adjoint_{num_sims}_{sim_len}_{solver}.csv', index=False)

Trial:  0
Params: Train size torch.Size([512, 4, 2]), Val size torch.Size([512, 4, 2]), Sim length 0.03 sec
2025-03-18 23:00:40,592 - Progress: Step 1


  pq0_batch = torch.tensor(y_true_batch[:, 0, :], dtype=torch.float32)


Epoch 0/1, Train Loss: 0.0038560437969863415
Objective function took 0.56 seconds to complete
Trial:  1
Params: Train size torch.Size([512, 8, 2]), Val size torch.Size([512, 8, 2]), Sim length 0.07 sec
2025-03-18 23:00:41,154 - Progress: Step 1
Epoch 0/1, Train Loss: 0.018968800082802773
Objective function took 1.16 seconds to complete
Trial:  2
Params: Train size torch.Size([512, 12, 2]), Val size torch.Size([512, 12, 2]), Sim length 0.11 sec
2025-03-18 23:00:42,330 - Progress: Step 1
Epoch 0/1, Train Loss: 0.045060932636260986
Objective function took 1.80 seconds to complete
Trial:  3
Params: Train size torch.Size([512, 16, 2]), Val size torch.Size([512, 16, 2]), Sim length 0.15 sec
2025-03-18 23:00:44,160 - Progress: Step 1
Epoch 0/1, Train Loss: 0.08154118806123734
Objective function took 2.51 seconds to complete
Trial:  4
Params: Train size torch.Size([512, 20, 2]), Val size torch.Size([512, 20, 2]), Sim length 0.19 sec
2025-03-18 23:00:46,718 - Progress: Step 1
Epoch 0/1, Train L