# Training Trajectory Correction LSTM

In [1]:
import os
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import matplotlib.pyplot as plt
from tqdm import tqdm
from constants.filepath import PROJECT_PATH

# Close any figures still open in this kernel (e.g. left over from an earlier run) before plotting anew.
plt.close('all')

In [2]:
# Use the Apple-silicon GPU when it is available and fall back to the CPU otherwise,
# so the notebook does not crash on machines without MPS. The variable keeps its
# original name because later cells refer to `mps_device`.
mps_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

training_folderpath = os.path.join(PROJECT_PATH, 'dataset/LSTM_training')
validation_folderpath = os.path.join(PROJECT_PATH, 'dataset/LSTM_validation')
testing_folderpath = os.path.join(PROJECT_PATH, 'dataset/LSTM_testing')

plt.rcParams['figure.dpi'] = 600  # For inline display in the notebook
plt.rcParams['savefig.dpi'] = 600 # For saving figures to files

In [3]:
def load_data(data_path):
    """Load every ``.npz`` file in ``data_path`` as an (input_seq, residuals) pair.

    Files are read in sorted filename order; non-``.npz`` files are ignored.
    Each archive must contain the arrays ``command_data``, ``analytical_data``
    and ``residuals``.

    Args:
        data_path: directory containing the ``.npz`` archives.

    Returns:
        list of ``(input_seq, residuals)`` float32 tensor tuples, where
        ``input_seq`` is ``torch.stack((command, analytical), dim=1)``.

    Raises:
        KeyError: if an archive lacks one of the required arrays; the message
            names the file, the missing arrays and the arrays it does contain.
    """
    required_keys = ('command_data', 'analytical_data', 'residuals')
    data_files = sorted(
        os.path.join(data_path, name)
        for name in os.listdir(data_path)
        if name.endswith('.npz')
    )
    sequences = []
    for file in data_files:
        # The context manager closes the NpzFile once the arrays are copied out.
        with np.load(file) as data:
            missing = [key for key in required_keys if key not in data.files]
            if missing:
                raise KeyError(f"{file} is missing {missing}; available arrays: {data.files}")
            command = torch.tensor(data['command_data'], dtype=torch.float32)
            analytical = torch.tensor(data['analytical_data'], dtype=torch.float32)
            residuals = torch.tensor(data['residuals'], dtype=torch.float32)

        # Stack command and analytical data as input features
        input_seq = torch.stack((command, analytical), dim=1)
        sequences.append((input_seq, residuals))
    return sequences

class FlowDataset(torch.utils.data.Dataset):
    def __init__(self, sequences):
        self.sequences = sequences
        self.lengths = [len(seq) for seq in sequences]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.lengths[idx]

def collate_fn(batch):
    sequences, lengths = zip(*batch)
    padded_sequences = pad_sequence(sequences, batch_first=True)
    lengths = torch.tensor(lengths)
    return padded_sequences, lengths

class ResidualLSTM(torch.nn.Module):
    def __init__(self, input_dim=2, hidden_dim=128, output_dim=1, num_layers=2):
        super(ResidualLSTM, self).__init__()
        self.lstm = torch.nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, lengths):
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        output = self.fc(output)
        return output

In [4]:
# Load the data
train_sequences = load_flows(training_folderpath)
val_sequences = load_flows(validation_folderpath)
train_dataset = FlowDataset(train_sequences)
val_dataset = FlowDataset(val_sequences)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

KeyError: 'residuals is not a file in the archive'

In [None]:
# Training setup
model = ResidualLSTM().to(mps_device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 50

In [None]:
model.train()
for epoch in range(epochs):
    epoch_loss = 0
    for batch, lengths in train_loader:
        batch = batch.to(mps_device)
        lengths = lengths.to(mps_device)
        optimizer.zero_grad()
        outputs = model(batch, lengths)
        loss = criterion(outputs.squeeze(-1), batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader)}")

In [None]:
# Save the model
# Persist only the learned parameters (the state_dict), not the pickled module object.
trained_weights = model.state_dict()
torch.save(trained_weights, 'lstm_residual_model.pth')

print("Training complete!")

# Test

In [None]:
# model = ResidualModel([2, 512, 512, 512, 128, 1]).to('cuda:1')
model = ResidualModel([2, 512, 512, 512, 128, 1]).to('mps')

state_dict = torch.load('trajectory_correction_DNN_v0_blastoise_ultimate.pth')
model.load_state_dict(state_dict)

In [None]:
# Example: Run model inference (optional, based on your needs)
model.eval()

##%%
test_residuals = np.load(os.path.join(PROJECT_PATH, 'model_data', 'test_residuals.npz'))['residuals']
test_command = np.load(os.path.join(PROJECT_PATH, 'model_data', 'test_command_data.npz'))['command_data']
test_analytical = np.load(os.path.join(PROJECT_PATH, 'model_data', 'test_analytical_data.npz'))['analytical_data']

test_combined = np.stack((test_command, test_analytical), axis=1)
test_input = torch.tensor(test_combined, dtype=torch.float32, device=mps_device)

with torch.no_grad():
    output = model(test_input) * 1e-9
    test_result = (output).cpu().numpy().reshape(-1)

# Ground-truth total = analytical prior + true residual.
test_sim = test_residuals + test_analytical

test_t = np.arange(test_command.shape[0])
plt.figure()
plt.plot(test_t, test_command / 1e9, label='Input', color='black', linestyle='--')
plt.plot(test_t, test_analytical / 1e9, label = 'Analytical Data', color = 'blue')
plt.plot(test_t, test_result, label='Residual', color='magenta')
plt.plot(test_t, (test_analytical / 1e9) + test_result, label='Total', color='green')
plt.plot(test_t, test_sim / 1e9, label='GT Total', color='r')
plt.ylabel('Flow rate [m^3/s]')
# Leave room on the right for the legend, which is placed outside the axes.
plt.subplots_adjust(right=0.8)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.85))
# plt.savefig('slides/flowrate_altogether', bbox_inches='tight')

In [None]:
# Commanded flow vs. simulated (ground-truth) flow.
plt.figure()
plt.plot(test_t, test_command / 1e9, label='Input', color='black', linestyle='--')
plt.plot(test_t, test_sim / 1e9, label='Sim', color='r')
plt.ylabel('Flow rate [m^3/s]')
# One legend call only: a second plt.legend() would just replace the first.
plt.legend(loc='center left', bbox_to_anchor=(1, 0.90))
# plt.savefig('slides/flowrate_sim', bbox_inches='tight')

In [None]:
# Commanded flow vs. the analytical prior plus the learned residual correction.
fig, ax = plt.subplots()
corrected_flow = (test_analytical / 1e9) + test_result
ax.plot(test_t, test_command / 1e9, label='Input', color='black', linestyle='--')
ax.plot(test_t, corrected_flow, label='DNN + Prior', color='green')
ax.set_ylabel('Flow rate [m^3/s]')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.90))
# plt.savefig('slides/flowrate_dnn_prior', bbox_inches='tight')

In [None]:
# Commanded flow vs. the analytical prior on its own.
fig, ax = plt.subplots()
ax.plot(test_t, test_command / 1e9, label='Input', color='black', linestyle='--')
ax.plot(test_t, test_analytical / 1e9, label='Prior', color='blue')
ax.set_ylabel('Flow rate [m^3/s]')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.90))
# plt.savefig('slides/flowrate_prior', bbox_inches='tight')

In [None]:
fig, ax = plt.subplots()
sim_flow = test_sim / 1e9
prior_flow = test_analytical / 1e9

# Root-mean-square error against the simulated ground truth, without and with the learned correction.
prior_error = np.sqrt(np.mean((sim_flow - prior_flow) ** 2))
dnn_prior_error = np.sqrt(np.mean((sim_flow - (prior_flow + test_result)) ** 2))

print('Prior Error:', prior_error)
print("DNN + Prior Error:", dnn_prior_error)

# Pointwise error of the prior alone.
ax.plot(test_t, sim_flow - prior_flow, label='Prior', color='blue')
ax.set_ylabel('Flow rate [m^3/s]')
ax.legend()
# plt.savefig('slides/error_prior')

In [None]:
# Pointwise error of the prior plus learned correction against the simulated ground truth.
plt.figure()
plt.plot(test_t, (test_sim / 1e9) - ((test_analytical / 1e9) + test_result), label='DNN + Prior', color='green')
plt.ylabel('Flow rate [m^3/s]')
plt.legend()
# plt.savefig('slides/error_dnn_prior')

In [None]:
# Overlay the prior-only error and the corrected error on a fresh figure
# (every other plotting cell starts its own figure as well).
plt.figure()
plt.plot(test_t, (test_sim / 1e9) - (test_analytical / 1e9), label='Prior', color='blue')
plt.plot(test_t, (test_sim / 1e9) - ((test_analytical / 1e9) + test_result), label='DNN + Prior', color='green')
plt.ylabel('Flow rate [m^3/s]')
plt.legend()
# plt.savefig('slides/error_both')

In [None]:
# Zero-error reference line on the same time axis as the error plots.
plt.figure()
plt.plot(test_t, np.zeros_like(test_sim / 1e9), label='Zero', color='black')
plt.ylabel('Flow rate [m^3/s]')
plt.legend()

In [None]:
# torch.save(model.state_dict(), 'trajectory_correction_DNN_v0_blastoise_ultimate.pth')
