# Import everything here

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import scipy.integrate
import torch
import torch.utils.data
from pytorch_lightning.loggers import TensorBoardLogger

from Euler import *
from task_1 import *

%reload_ext autoreload
%autoreload 2

In [None]:
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Setup Tensorboard

In [None]:
# Load the TensorBoard notebook extension and start it on port 6006,
# reading runs from the ./lightning_logs directory.
%load_ext tensorboard
%tensorboard --logdir lightning_logs --port 6006

In [None]:
# Placeholders for the three data splits. The training and validation splits
# are assigned further down in this notebook; each is meant to pair a value
# with its target, i.e. the same trajectory shifted by one time step.
train_dataset = validation_dataset = test_dataset = None

# Create dataset

In [None]:
# Parameters of the system and of the sampled trajectories
alpha = -1.8
range_start, range_end = -2, 2
t_start, t_end, delta_t = 0, 1, 0.01
num_trajectories = 1000

# Time grid implied by these settings (np.arange excludes t_end); it is only
# printed here, the solver call below receives the raw parameters.
evaluation_times = np.arange(t_start, t_end, delta_t)
print("will solve for times:", evaluation_times)

sols = generate_dataset(
    alpha, delta_t, range_start, range_end, t_start, t_end, num_trajectories
)

print("Result have the shape:", sols.shape)

# Show the first and the last sample of the first and of the last trajectory
for _which, _traj in (("first", sols[0]), ("last", sols[-1])):
    print(f"first datapoint in the {_which} trajectory: ", _traj[0][0], _traj[1][0])
    print(f"last datapoint in the {_which} trajectory: ", _traj[0][-1], _traj[1][-1])

In [None]:
# Plot the first and the last generated trajectory (they have different
# starting points) to get a feeling for what the data looks like.
plot_trajectory(sols[0], t_start, t_end, delta_t)
# Index from the end instead of hard-coding 999, so this cell keeps working
# when the number of trajectories changes.
plot_trajectory(sols[-1], t_start, t_end, delta_t)

# Prepare the dataset

In [None]:
# For training we want an array of shape (num_trajectories * num_datapoints, 2, 2):
# the second axis holds each value together with its corresponding target.
train_dataset = reshape_for_training(sols)

print("train_dataset shape: ", train_dataset.shape)

In [None]:
# NOTE(review): the string literal below is disabled code, not documentation.
# It appears to be an earlier way of pairing values with targets (stack, then
# moveaxis); train_dataset is now produced by reshape_for_training instead.
"""
train_dataset = np.stack((train_dataset_values, train_dataset_targets))
print("train_dataset shape: ", train_dataset.shape)

train_dataset = np.moveaxis(train_dataset, 0, 1)
print("train_dataset shape: ", train_dataset.shape)

# As one can see, we could successfully generate the targets!
#print(train_dataset_values[0, :, :10])
#print(train_dataset_targets[0, :, :10])


#values_x = train_dataset_values[:, 0, :]
#values_y = train_dataset_values[:, 1, :]

#A = np.reshape(train_dataset_values, (1000, 9999, 2))
#print(A[0, :10, :])

#train_dataset_targets =

#(1000, 10000, 10000)

#((10000000, 2), (10000000, 2))
"""

# Define hyperparameters

In [None]:
# Hyperparameters handed to the model; batch_size is also read when the data
# loaders are built.
hparams = dict(
    hidden_layer_1=512,
    hidden_layer_2=512,
    hidden_layer_3=512,
    delta_t=delta_t,
    batch_size=256,
    learning_rate=1e-3,
    num_workers=8,
)

In [None]:
# Build the model from the hyperparameter dict; `Euler` is presumably provided
# by the star import from the local Euler module.
model = Euler(hparams)

# Validation set

In [None]:
# Draw 200 random start positions, uniformly from [-2, 2) in both coordinates
start_positions = np.random.uniform(-2, 2, (200, 2))

# Integrate every start position over t in [0, 10], sampled every delta_t.
# solve_ivp returns the states as (coordinate, time), so the stacked result
# has shape (trajectory, coordinate, time).
t_eval = np.arange(0, 10, delta_t)
sols = [
    scipy.integrate.solve_ivp(
        lambda t, y: get_derivatives(y[0], y[1], alpha),
        (0, 10),
        start_position,
        t_eval=t_eval,
    ).y
    for start_position in start_positions
]

validation_dataset_values = np.array(sols)

print("VAL DATASET SHAPE:", validation_dataset_values.shape)

# Targets are the same trajectories advanced by one sample: drop the first
# sample for the targets and the last one for the values.
validation_dataset_targets = validation_dataset_values[..., 1:]
validation_dataset_values = validation_dataset_values[..., :-1]

# Flatten (trajectory, coordinate, time) into one (x_1, x_2) row per sample
validation_dataset_values = validation_dataset_values.transpose(0, 2, 1).reshape(-1, 2)
validation_dataset_targets = validation_dataset_targets.transpose(0, 2, 1).reshape(-1, 2)

# Pair every value with its target along axis 1 -> (num_samples, 2, 2)
validation_dataset = np.stack(
    (validation_dataset_values, validation_dataset_targets), axis=1
)

print("VAL DATASET SHAPE:", validation_dataset.shape)

def evaluate_model(model, dataset, batch_size=None, target_device=None):
    """
    Score a model on a dataset of (value, target) pairs.

    The model is switched to eval mode, moved to the evaluation device and
    applied to every batch without gradient tracking. The score is
    1 / (2 * mean batch MSE), so a higher score means a better fit.

    Args:
        model: torch module mapping a batch of positions to predictions.
        dataset: indexable collection whose items hold the input at index 0
            and the target at index 1 (batches are sliced as batch[:, 0]
            and batch[:, 1]).
        batch_size: evaluation batch size; defaults to hparams["batch_size"].
        target_device: device to evaluate on; defaults to the notebook-level
            `device`.

    Returns:
        The score as a float; float("inf") when the mean loss is exactly zero.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    # Fall back to the notebook-level settings when no override is given.
    if batch_size is None:
        batch_size = hparams["batch_size"]
    if target_device is None:
        target_device = device

    model.eval()
    model.to(target_device)
    criterion = torch.nn.MSELoss()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    if len(dataloader) == 0:
        raise ValueError("cannot evaluate a model on an empty dataset")

    total_loss = 0.0
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for batch in dataloader:
            pos = batch[:, 0].float().to(target_device)
            pos_target = batch[:, 1].float().to(target_device)

            pred = model(pos)

            total_loss += criterion(pred.float(), pos_target).item()

    mean_loss = total_loss / len(dataloader)
    if mean_loss == 0:
        # A perfect fit would otherwise divide by zero.
        return float("inf")
    return 1.0 / (2 * mean_loss)


def recreate_trajectory(model, start_position, t_start, t_end, delta_t, target_device=None):
    """
    Roll a model forward from a start position with explicit Euler steps.

    Samples are produced for the times t_start, t_start + delta_t, ... that
    lie strictly before t_end (the same grid as
    np.arange(t_start, t_end, delta_t)); the start position is always the
    first sample.

    NOTE(review): the update below treats the model output as a time
    derivative, whereas evaluate_model compares the raw model output with the
    next position. Confirm which convention the model implements.

    Args:
        model: callable mapping a float tensor position to a tensor of the
            same shape.
        start_position: initial position (array-like).
        t_start: time of the first sample.
        t_end: exclusive end time of the rollout.
        delta_t: step size; must be positive.
        target_device: device for the rollout state; defaults to the
            notebook-level `device`.

    Returns:
        numpy array with one row per sample, shape (num_samples, dim).

    Raises:
        ValueError: if delta_t is not positive.
    """
    if delta_t <= 0:
        raise ValueError("delta_t must be positive")
    if target_device is None:
        target_device = device

    # Derive the number of samples from the interval length instead of
    # accumulating `t += delta_t`: the accumulated rounding error could add
    # or drop a step. The small tolerance absorbs division round-off.
    num_points = max(int(np.ceil((t_end - t_start) / delta_t - 1e-9)), 1)

    state = torch.as_tensor(start_position).to(target_device)
    trajectory = [state.detach().cpu().numpy()]

    # No gradients are needed for a rollout; this also keeps the autograd
    # graph from growing with every step.
    with torch.no_grad():
        for _ in range(num_points - 1):
            state = state + delta_t * model(state.float())
            trajectory.append(state.detach().cpu().numpy())

    return np.stack(trajectory)


print("Score of the Model before training:", evaluate_model(model, validation_dataset))

# Roll the (still untrained) model out from the first validation start position
recreation = recreate_trajectory(model, start_positions[0], t_eval[0], t_eval[-1], delta_t)

# Split the (num_samples, 2) rollout into one row per coordinate:
# recreation[0] holds all x_1 values, recreation[1] all x_2 values.
recreation = np.asarray(recreation, dtype=float).T
# Build the time axis from the rollout's own length so that it always matches
# the number of samples being plotted.
recreation_times = t_eval[0] + delta_t * np.arange(recreation.shape[1])

fig = plt.figure(figsize=(10, 10))
ax0 = plt.axes(projection="3d")
# Label the curve with the start position that was actually used
start_x1, start_x2 = start_positions[0]
ax0.plot(recreation_times, recreation[0], recreation[1],
         label=rf"Trajectory with starting point $({start_x1:.2f}, {start_x2:.2f})$", color="r")

ax0.set_xlabel(r"$t$")
ax0.set_ylabel(r"$x_1$")
ax0.set_zlabel(r"$x_2$")
ax0.legend()

# First and last (x_1, x_2) sample of the rollout
print("FIRST VALUE: ", recreation[0][0], recreation[1][0])
print("LAST VALUE: ", recreation[0][-1], recreation[1][-1])

# Evaluate the model on a 40 x 40 grid over [-1, 1]^2 in a single batch.
# The input is placed on the model's own device, because evaluate_model above
# may have moved the model to the GPU.
model.eval()
grid_axis = np.linspace(-1, 1, 40)
positions = np.stack(np.meshgrid(grid_axis, grid_axis, indexing="ij"), axis=-1)
model_device = next(model.parameters()).device
with torch.no_grad():
    grid_input = torch.as_tensor(positions.reshape(-1, 2)).float().to(model_device)
    next_positions = model(grid_input).cpu().numpy().astype(float).reshape(positions.shape)
# The model output is interpreted as the next position, so the finite
# difference approximates the vector field.
derivatives = (next_positions - positions) / delta_t
plt.figure(figsize=(25, 25))
plt.quiver(positions[:, :, 0], positions[:, :, 1], derivatives[:, :, 0], derivatives[:, :, 1], units="xy", scale=10);

In [None]:
# Shuffle the training samples every epoch: the rows are presumably ordered
# trajectory by trajectory, so unshuffled batches would be strongly correlated.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=hparams["batch_size"],
    num_workers=hparams["num_workers"],
)
# Validation order does not matter, so it stays deterministic.
validation_dataloader = torch.utils.data.DataLoader(
    validation_dataset,
    shuffle=False,
    batch_size=hparams["batch_size"],
    num_workers=hparams["num_workers"],
)
# Keep the original (misspelled) name as an alias for code that still uses it.
validation__dataloader = validation_dataloader

In [None]:
# Train for 5 epochs, logging at every step; request one GPU when CUDA is
# available and none otherwise.
# NOTE(review): newer pytorch_lightning releases replace `gpus` with
# `accelerator`/`devices`. Confirm against the installed version.
trainer = pl.Trainer(
    max_epochs=5,
    log_every_n_steps=1,
    gpus=1 if torch.cuda.is_available() else None
)

# NOTE(review): only the training loader is passed; the validation loader
# built earlier is not used by this fit call.
trainer.fit(model, train_dataloader)

In [None]:
print("Score of the Model after training:", evaluate_model(model, validation_dataset))

# Roll the trained model out from the first and from the last validation
# start position and plot each rollout against time.
for start_position in (start_positions[0], start_positions[-1]):
    recreation = recreate_trajectory(model, start_position, t_eval[0], t_eval[-1], delta_t)

    # Split the (num_samples, 2) rollout into one row per coordinate:
    # recreation[0] holds all x_1 values, recreation[1] all x_2 values.
    recreation = np.asarray(recreation, dtype=float).T
    # Build the time axis from the rollout's own length so that it always
    # matches the number of samples being plotted.
    recreation_times = t_eval[0] + delta_t * np.arange(recreation.shape[1])

    # First and last (x_1, x_2) sample of the rollout
    print("FIRST VALUE: ", recreation[0][0], recreation[1][0])
    print("LAST VALUE: ", recreation[0][-1], recreation[1][-1])

    fig = plt.figure(figsize=(10, 10))
    ax0 = plt.axes(projection="3d")
    # Label the curve with the start position that was actually used
    start_x1, start_x2 = start_position
    ax0.plot(recreation_times, recreation[0], recreation[1],
             label=rf"Trajectory with starting point $({start_x1:.2f}, {start_x2:.2f})$", color="r")

    ax0.set_xlabel(r"$t$")
    ax0.set_ylabel(r"$x_1$")
    ax0.set_zlabel(r"$x_2$")
    ax0.legend()

In [None]:
# Evaluate the model on a 40 x 40 grid over [-1, 1]^2 in a single batch.
# The input is placed on the model's own device (the model may live on the
# GPU at this point) and the result is copied back to the CPU for plotting.
model.eval()
grid_axis = np.linspace(-1, 1, 40)
positions = np.stack(np.meshgrid(grid_axis, grid_axis, indexing="ij"), axis=-1)
model_device = next(model.parameters()).device
with torch.no_grad():
    grid_input = torch.as_tensor(positions.reshape(-1, 2)).float().to(model_device)
    next_positions = model(grid_input).cpu().numpy().astype(float).reshape(positions.shape)
# The model output is interpreted as the next position, so the finite
# difference approximates the vector field.
derivatives = (next_positions - positions) / delta_t
plt.figure(figsize=(25, 25))
plt.quiver(positions[:, :, 0], positions[:, :, 1], derivatives[:, :, 0], derivatives[:, :, 1], units="xy", scale=10);