<a href="https://colab.research.google.com/github/marilynbraojos/transformer_demo/blob/main/gps2309.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# calculate_correction

**Author:** Marilyn Braojos Gutierrez\
**Purpose:** This program calculates the clock-bias correction for GPS satellite G21, defined as the final clock bias (GRG) minus the broadcast clock bias.\
**PhD Milestone:** #1: *Leverage deep learning models for GPS satellite clock bias corrections.*\
**Project:** This program is Step (1) in this PhD milestone. Obtaining the data is the first critical step.\
**References:**\
N/A

# Import Libraries

In [1]:
from datetime import datetime, timedelta
import math
import os
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# Calculate Correction Data

In [2]:
# load broadcast clock bias polynomial data npz without relativistic effects considered in the polynomial for G21 satellite only

def load_broadcast_file(day, base_dir='/Volumes/MARI/ssdl_gps/rnx_polynomial', year=2018, prn='G21'):
    """Load one day of broadcast clock-bias polynomial values for one satellite.

    Args:
        day: day of year, zero-padded to three digits in the file name.
        base_dir: directory holding the ``gps_poly_<year><doy>_<prn>.npz`` files.
        year: year embedded in the file name (default 2018).
        prn: satellite identifier embedded in the file name (default 'G21').

    Returns:
        (time_datetimes, poly_values): list of ``datetime`` epochs parsed from the
        ``time_strings`` array (format ``%Y:%m:%d:%H:%M:%S``) and the raw
        ``poly_values`` array stored in the file.
    """
    filename = f'{base_dir}/gps_poly_{year}{day:03d}_{prn}.npz'
    # The context manager closes the NpzFile; arrays are read into memory on access.
    with np.load(filename) as data:
        time_strings = data['time_strings']
        poly_values = data['poly_values']
    time_datetimes = [datetime.strptime(ts, '%Y:%m:%d:%H:%M:%S') for ts in time_strings]
    return time_datetimes, poly_values

In [None]:
# # load broadcast clock bias polynomial data npz with relativistic effects considered in the polynomial for G21 satellite only 

# def load_broadcast_file_w_rel(day):
#     filename = f'/Volumes/MARI/ssdl_gps/rnx_polynomial/gps_poly_rel_2019{day:03d}_G21.npz'
#     data = np.load(filename)
#     time_strings = data['time_strings']
#     poly_values = data['poly_values']
#     time_datetimes = [datetime.strptime(ts, '%Y:%m:%d:%H:%M:%S') for ts in time_strings]
#     return time_datetimes, poly_values

In [3]:
# load final clock bias data npz for G21 satellite only and GRG station only

def load_clk_file(week, base_dir='/Volumes/MARI/ssdl_gps/clk_npz_sat', prn='G21'):
    """Load one GPS week of final (GRG) clock-bias values for one satellite.

    Args:
        week: GPS week number embedded in the file name.
        base_dir: root directory; files live in ``<base_dir>/<prn>/``.
        prn: satellite identifier (default 'G21').

    Returns:
        (epochs, clock_bias): list of ``datetime`` epochs built from the
        ``yyyy``/``mm``/``dd``/``hh``/``mi``/``ss`` arrays and the
        ``clock_bias_vals`` array.
    """
    filename = f'{base_dir}/{prn}/grg_gps_clk_{week}_{prn}.npz'
    # Close the NpzFile once the arrays have been read into memory.
    with np.load(filename) as data:
        epochs = [datetime(year, month, day, hour, minute, second)
                  for year, month, day, hour, minute, second
                  in zip(data['yyyy'], data['mm'], data['dd'], data['hh'], data['mi'], data['ss'])]
        clock_bias = data['clock_bias_vals']
    return epochs, clock_bias

In [4]:
# initialize arrays

# Accumulators for the matched (final vs broadcast) data; each must be a distinct list.
(all_matching_epochs, all_matching_clock_bias,
 all_matching_poly_values, all_correction_vals) = ([] for _ in range(4))

# Accumulators for every broadcast epoch and its polynomial value.
all_time_datetimes, all_poly_values = [], []

In [5]:
%%time

# store all the polynomial values and corresponding times in the time_datetimes and poly_values array - Wall time: 25.2 s 

# Days of year 1..365 (2018 is not a leap year), appended in chronological order.
for day in range(1, 366):
    time_datetimes, poly_values = load_broadcast_file(day)
    all_time_datetimes += time_datetimes
    all_poly_values += list(poly_values)
    print(f'Finished storing all epochs and corresponding polynomial values for day {day}')

Finished storing all epochs and corresponding polynomial values for day 1
Finished storing all epochs and corresponding polynomial values for day 2
Finished storing all epochs and corresponding polynomial values for day 3
Finished storing all epochs and corresponding polynomial values for day 4
Finished storing all epochs and corresponding polynomial values for day 5
Finished storing all epochs and corresponding polynomial values for day 6
Finished storing all epochs and corresponding polynomial values for day 7
Finished storing all epochs and corresponding polynomial values for day 8
Finished storing all epochs and corresponding polynomial values for day 9
Finished storing all epochs and corresponding polynomial values for day 10
Finished storing all epochs and corresponding polynomial values for day 11
Finished storing all epochs and corresponding polynomial values for day 12
Finished storing all epochs and corresponding polynomial values for day 13
Finished storing all epochs and co

In [6]:
%%time 

# previously completed weeks 2034 until 2060 - stopped after 2060 was completed :: 7/29/2024
# previously completed weeks 2034 until 2086 :: 8/2/2024 (Wall Time: 15 hr 26 minutes 25 seconds)
# previously completed weeks 1982 until 2034 :: 8/29/2024 (Wall Time:  10 hr  14 minutes 56 seconds)
# NOTE: those wall times came from `epoch in all_time_datetimes` and
# `all_time_datetimes.index(epoch)`, which scan the whole list per lookup.
# The dict below makes each lookup O(1).

# Broadcast epoch -> polynomial value. setdefault keeps the FIRST occurrence,
# matching list.index() semantics if an epoch appears in more than one daily file.
poly_value_by_epoch = {}
for broadcast_epoch, broadcast_poly in zip(all_time_datetimes, all_poly_values):
    poly_value_by_epoch.setdefault(broadcast_epoch, broadcast_poly)

for week in range(1982, 2035):
    epochs, clock_bias = load_clk_file(week)
    # Keep only final-clock epochs that also have a broadcast polynomial value.
    matching_indices = [i for i, epoch in enumerate(epochs) if epoch in poly_value_by_epoch]
    matching_epochs = [epochs[i] for i in matching_indices]
    matching_clock_bias = clock_bias[matching_indices]
    matching_poly_values = [poly_value_by_epoch[epoch] for epoch in matching_epochs]

    # Correction = final clock bias - broadcast clock bias.
    correction_vals = matching_clock_bias - np.asarray(matching_poly_values)
    all_matching_epochs.extend(matching_epochs)
    all_matching_clock_bias.extend(matching_clock_bias)
    all_matching_poly_values.extend(matching_poly_values)
    all_correction_vals.extend(correction_vals)
    print(f'Completed corrections for week: {week}')

Completed corrections for week: 1982
Completed corrections for week: 1983
Completed corrections for week: 1984
Completed corrections for week: 1985
Completed corrections for week: 1986
Completed corrections for week: 1987
Completed corrections for week: 1988
Completed corrections for week: 1989
Completed corrections for week: 1990
Completed corrections for week: 1991
Completed corrections for week: 1992
Completed corrections for week: 1993
Completed corrections for week: 1994
Completed corrections for week: 1995
Completed corrections for week: 1996
Completed corrections for week: 1997
Completed corrections for week: 1998
Completed corrections for week: 1999
Completed corrections for week: 2000
Completed corrections for week: 2001
Completed corrections for week: 2002
Completed corrections for week: 2003
Completed corrections for week: 2004
Completed corrections for week: 2005
Completed corrections for week: 2006
Completed corrections for week: 2007
Completed corrections for week: 2008
C

In [7]:
# NOTE: `all_matching_epochs` holds datetime objects, so numpy stores it as an
# object-dtype array; reading this file back needs np.load(..., allow_pickle=True).
# The string-epoch file written later avoids that.
correction_payload_2018 = {
    'matching_epochs': all_matching_epochs,
    'matching_clock_bias': all_matching_clock_bias,
    'matching_poly_values': all_matching_poly_values,
    'correction_vals': all_correction_vals,
}
np.savez('/Volumes/MARI/ssdl_gps/correction_data/2018/correction_data_2018.npz', **correction_payload_2018)

In [8]:
# Serialise epochs as fixed-format strings so the saved .npz needs no pickling.
matching_epoch_strings = list(map(lambda ep: ep.strftime('%Y:%m:%d:%H:%M:%S'), all_matching_epochs))

In [9]:
# Same payload as the datetime version, but with string epochs (loadable without allow_pickle).
np.savez(
    '/Volumes/MARI/ssdl_gps/correction_data/2018/correction_data_2018_str_update.npz',
    **dict(
        matching_epochs=matching_epoch_strings,
        matching_clock_bias=all_matching_clock_bias,
        matching_poly_values=all_matching_poly_values,
        correction_vals=all_correction_vals,
    ),
)

# Plots

In [10]:
# Reload the string-epoch correction file. The context manager closes the
# underlying .npz file once the arrays have been read into memory.
with np.load('/Volumes/MARI/ssdl_gps/correction_data/2018/correction_data_2018_str_update.npz') as data:
    epochs = data['matching_epochs']
    final_clock_bias = data['matching_clock_bias']
    broadcast_clock_bias = data['matching_poly_values']
    correction_value = data['correction_vals']

# Epoch strings use the same '%Y:%m:%d:%H:%M:%S' format they were saved with.
epoch_datetime = [datetime.strptime(epoch, '%Y:%m:%d:%H:%M:%S') for epoch in epochs]

In [24]:
def _save_monthly_scatter_plot(x_values, series, ylabel, title, out_path):
    """Scatter one or more series against datetimes and save the figure.

    Args:
        x_values: datetimes for the x axis.
        series: iterable of (y_values, label, scatter_kwargs), drawn in order.
        ylabel: y-axis label.
        title: figure title.
        out_path: file path passed to savefig.
    """
    plt.figure(figsize=(50, 10))
    for y_values, label, scatter_kwargs in series:
        plt.scatter(x_values, y_values, label=label, **scatter_kwargs)
    plt.xlabel('Time (YYYY:MM:DD:HH:MI:SS)')
    plt.ylabel(ylabel)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M:%S'))
    plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=5))  # major tick every 5 hours
    plt.subplots_adjust(bottom=0.2)
    plt.grid(color='gray', linestyle='--', linewidth=0.25)
    plt.xticks(rotation=90)
    plt.legend()
    plt.title(title)
    plt.savefig(out_path)
    plt.close()


# One pair of scatter figures per calendar month of 2018.
for month in range(1, 13):
    start_date = datetime(2018, month, 1)
    # First day of the following month (December rolls over to January 2019).
    end_date = datetime(2018 + month // 12, month % 12 + 1, 1)

    monthly_indices = [i for i, epoch in enumerate(epoch_datetime) if start_date <= epoch < end_date]
    monthly_epochs = [epoch_datetime[i] for i in monthly_indices]
    monthly_clock_bias = [final_clock_bias[i] for i in monthly_indices]
    monthly_poly_values = [broadcast_clock_bias[i] for i in monthly_indices]
    monthly_correction_vals = [correction_value[i] for i in monthly_indices]

    # Final clock bias vs broadcast polynomial values
    _save_monthly_scatter_plot(
        monthly_epochs,
        [(monthly_clock_bias, 'Clock Bias (Station: GRG)', {}),
         (monthly_poly_values, 'Broadcast Polynomial Values', {'s': 5})],
        'Bias Values (s)',
        f'Clock Bias and Poly Values for G21 {start_date.strftime("%B %Y")}',
        f'/Volumes/MARI/ssdl_gps/plots_raw/clock_bias_poly_values_{month:02d}_2018.png',
    )

    # Correction values
    _save_monthly_scatter_plot(
        monthly_epochs,
        [(monthly_correction_vals, 'Clock Bias Correction (s)', {})],
        'Bias Correction Values (s)',
        f'Clock Bias Correction Values for G21 {start_date.strftime("%B %Y")}',
        f'/Volumes/MARI/ssdl_gps/plots_raw/correction_values_{month:02d}_2018.png',
    )

In [25]:
def _save_monthly_line_plot(x_values, series, ylabel, title, out_path):
    """Draw one or more line series against datetimes and save the figure.

    Args:
        x_values: datetimes for the x axis.
        series: iterable of (y_values, label, plot_kwargs), drawn in order.
        ylabel: y-axis label.
        title: figure title.
        out_path: file path passed to savefig.
    """
    fig = plt.figure(figsize=(50, 10))
    for y_values, label, plot_kwargs in series:
        plt.plot(x_values, y_values, label=label, **plot_kwargs)
    plt.xlabel('Time (YYYY:MM:DD:HH:MI:SS)')
    plt.ylabel(ylabel)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M:%S'))
    plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=5))  # major tick every 5 hours
    plt.subplots_adjust(bottom=0.2)
    plt.grid(color='gray', linestyle='--', linewidth=0.25)
    plt.xticks(rotation=90)
    plt.legend()
    plt.title(title)
    plt.savefig(out_path)
    plt.close(fig)


# Line-plot versions of the monthly figures for 2018.
for month in range(1, 13):
    start_date = datetime(2018, month, 1)
    end_date = datetime(2019, 1, 1) if month == 12 else datetime(2018, month + 1, 1)

    monthly_indices = [i for i, epoch in enumerate(epoch_datetime) if start_date <= epoch < end_date]
    monthly_epochs = [epoch_datetime[i] for i in monthly_indices]
    # The y series must come from the value arrays; they were all taken from
    # `epoch_datetime`, which plotted timestamps against timestamps.
    monthly_clock_bias = [final_clock_bias[i] for i in monthly_indices]
    monthly_poly_values = [broadcast_clock_bias[i] for i in monthly_indices]
    monthly_correction_vals = [correction_value[i] for i in monthly_indices]

    # Final clock bias vs broadcast polynomial values
    _save_monthly_line_plot(
        monthly_epochs,
        [(monthly_clock_bias, 'Clock Bias (Station: GRG)', {'linewidth': 4}),
         (monthly_poly_values, 'Broadcast Polynomial Values', {'linewidth': 2})],
        'Bias Values (s)',
        f'Clock Bias and Poly Values for G21 {start_date.strftime("%B %Y")}',
        f'/Volumes/MARI/ssdl_gps/plots_raw/plt_clock_bias_poly_values_{month:02d}_2018.png',
    )

    # Correction values
    _save_monthly_line_plot(
        monthly_epochs,
        [(monthly_correction_vals, 'Clock Bias Correction (s)', {})],
        'Bias Correction Values (s)',
        f'Clock Bias Correction Values for G21 {start_date.strftime("%B %Y")}',
        f'/Volumes/MARI/ssdl_gps/plots_raw/plt_correction_values_{month:02d}_2018.png',
    )

## Archive

In [None]:
# broadcast_01082019 = np.load('/Volumes/MARI/ssdl_gps/rnx_polynomial/gps_poly_2019008_G21.npz')
# broadcast_01082019.files

In [None]:
# time_strings = broadcast_01082019['time_strings']
# poly_values = broadcast_01082019['poly_values']

In [None]:
# time_datetimes = [datetime.strptime(ts, '%Y:%m:%d:%H:%M:%S') for ts in time_strings]
# # time_datetimes

In [None]:
# clk = np.load('/Volumes/MARI/ssdl_gps/clk_npz_sat/G21/grg_gps_clk_2035_G21.npz')
# # clk = np.load('gps_2035_G21_test.npz')

# clk.files

In [None]:
# # Extract the arrays
# satellite = clk['satellite']
# yyyy = clk['yyyy']
# mm = clk['mm']
# dd = clk['dd']
# hh = clk['hh']
# mi = clk['mi']
# ss = clk['ss']
# clock_bias = clk['clock_bias_vals']
# ver = clk['vers']
# filepath = clk['filename']
 
# # Combine date and time into datetime objects
# epochs = [datetime(year, month, day, hour, minute, second)
#           for year, month, day, hour, minute, second 
#           in zip(yyyy, mm, dd, hh, mi, ss)]

In [None]:
# epochs

In [None]:
# matching_indices = [i for i, epoch in enumerate(epochs) if epoch in time_datetimes]
# matching_epochs = [epochs[i] for i in matching_indices]
# matching_clock_bias = clock_bias[matching_indices]
# # matching_epochs

In [None]:
# matching_poly_values = [poly_values[time_datetimes.index(epoch)] for epoch in epochs if epoch in time_datetimes]
# # matching_poly_values

In [None]:
# # Plot clock_bias and poly_values on the same plot
# plt.figure(figsize=(10, 5))
# plt.scatter(matching_epochs, matching_clock_bias, label='Clock Bias (Station: GRG)')
# plt.scatter(matching_epochs, matching_poly_values, label='Broadcast Polynomial Values', s = 5)
# plt.xlabel('Time (YYYY:MM:DD:HH:MI:SS)')
# plt.ylabel('Bias Values (s)')
# plt.ylim(-0.0002734,-0.0002722)
# # plt.xlim(datetime.strptime('2019:01:08:18:00:00', '%Y:%m:%d:%H:%M:%S'),datetime.strptime('2019:01:08:21:00:00', '%Y:%m:%d:%H:%M:%S'))
# plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M:%S'))
# plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=1))  # Set major ticks every hour
# plt.tight_layout()
# plt.grid(color='gray', linestyle='--', linewidth=0.25)
# plt.xticks(rotation=90)

# plt.legend()
# plt.title('Clock Bias and Poly Values for G21 Jan 8, 2019')
# plt.show()

In [None]:
# correction_val = matching_clock_bias - matching_poly_values
# correction_val

# e1 = datetime.strptime('2019:01:08:00:00:00', '%Y:%m:%d:%H:%M:%S')
# e2 = datetime.strptime('2019:01:09:00:30:00', '%Y:%m:%d:%H:%M:%S')

In [None]:
# # Plot clock_bias and poly_values on the same plot
# plt.figure(figsize=(10, 5))
# plt.scatter(matching_epochs, correction_val, label='Clock Bias Correction (s)')
# plt.xlabel('Time (YYYY:MM:DD:HH:MI:SS)')
# plt.ylabel('Bias Correction Values (s)')
# # plt.xlim(e1,e2)
# plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M:%S'))
# plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=1))  # Set major ticks every hour
# plt.tight_layout()
# plt.grid(color='gray', linestyle='--', linewidth=0.25)
# plt.xticks(rotation=90)

# plt.legend()
# plt.title('Clock Bias Correction Value for G21 Jan 8, 2019')
# plt.show()

In [None]:
# import numpy as np
# from tabulate import tabulate

# # Load the data
# data_g21 = np.load('gps_rnx_2019008_G21.npz')

# # Get the file keys
# keys = data_g21.files

# # Prepare the data for display
# table_data = [keys]  # Column headers

# # Get the length of the data (assuming all arrays have the same length)
# num_rows = len(data_g21[keys[0]])

# # Collect rows of data
# for i in range(num_rows):
#     row = [data_g21[key][i] for key in keys]
#     table_data.append(row)

# # Print the table
# print(tabulate(table_data, headers="firstrow", tablefmt="grid"))

# (Archive) Transformer

In [None]:
# Open the 2019 string-epoch correction archive and show which arrays it holds.
data = np.load(file='/Volumes/MARI/ssdl_gps//correction_data/correction_data_2019_str_update.npz')
data.files

In [None]:
# Extract the arrays from the .npz file
# Pull the four stored arrays out of the .npz archive, in file-key order.
(upd_matching_epochs, upd_matching_clock_bias,
 upd_matching_poly_values, upd_correction_vals) = (
    data[key]
    for key in ('matching_epochs', 'matching_clock_bias', 'matching_poly_values', 'correction_vals')
)

In [None]:
# Zero-padded 'YYYY:MM:DD:HH:MM:SS' strings compare chronologically (assuming the
# 2019 file was written with the same format), so max() gives the latest epoch.
print(max(upd_matching_epochs, key=str))

In [None]:
# Plot correction values
match_epochs_obj = [datetime.strptime(ts, '%Y:%m:%d:%H:%M:%S') for ts in upd_matching_epochs]

plt.figure(figsize=(50, 10))
plt.scatter(match_epochs_obj, upd_correction_vals, label='Clock Bias Correction (s)')
plt.xlabel('Time (YYYY:MM:DD:HH:MI:SS)')
plt.ylabel('Bias Correction Values (s)')
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M:%S'))
plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=5))  # Set major ticks every hour
# plt.tight_layout()
plt.grid(color='gray', linestyle='--', linewidth=0.25)
plt.xticks(rotation=90)
plt.legend()
plt.title(f'<test>')

In [None]:
# Final clock bias against its sample index.
plt.plot(range(len(upd_matching_clock_bias)), upd_matching_clock_bias)

In [None]:
# Parse the stored epoch strings; keep data up to the last second of January 2019.
upd_epoch_dates = list(map(lambda t: datetime.strptime(t, '%Y:%m:%d:%H:%M:%S'), upd_matching_epochs))
end_date = datetime(2019, 1, 31, 23, 59, 59)

In [None]:
# True for January epochs that fall on a :00 or :30 second boundary.
mask = [date <= end_date and date.second in (0, 30) for date in upd_epoch_dates]

In [None]:
# Boolean index array usable on the numpy arrays loaded above.
mask = np.asarray(mask)
mask

In [None]:
# Apply the mask to filter the arrays
# Apply the same boolean mask to every aligned array.
(filtered_epochs, filtered_clock_bias,
 filtered_poly_values, filtered_correction_vals) = (
    arr[mask]
    for arr in (upd_matching_epochs, upd_matching_clock_bias, upd_matching_poly_values, upd_correction_vals)
)

In [None]:
# Show the retained epoch strings.
print(str(filtered_epochs))

In [None]:
# Save the filtered data into a new .npz file
# Persist the January (:00/:30-second) subset.
filtered_payload = {
    'matching_epochs': filtered_epochs,
    'matching_clock_bias': filtered_clock_bias,
    'matching_poly_values': filtered_poly_values,
    'correction_vals': filtered_correction_vals,
}
np.savez('/Volumes/MARI/ssdl_gps/correction_data_2019_mo1_upd.npz', **filtered_payload)

In [None]:
# NOTE(review): this loads `correction_data_2019_mo1.npz`, not the
# `correction_data_2019_mo1_upd.npz` written by the previous cell — confirm intent.
data = np.load(file='/Volumes/MARI/ssdl_gps/correction_data_2019_mo1.npz')

In [None]:
# Model input (broadcast polynomial values) and target (correction values).
matching_poly_values, correction_vals = (
    data[key] for key in ('matching_poly_values', 'correction_vals')
)

In [None]:
# Ensure data is in the correct shape [num_samples, num_features]
# Assuming each sample has one feature (i.e., 1D data)
# Insert a feature axis after the sample axis: [num_samples] -> [num_samples, 1].
# NOTE: not idempotent — re-running this cell adds another axis.
matching_poly_values = np.expand_dims(matching_poly_values, axis=1)
correction_vals = np.expand_dims(correction_vals, axis=1)

In [None]:
# 70 % of the samples (in time order) for training, the remainder for testing.
train_size = int(len(matching_poly_values) * 0.7)
test_size = len(matching_poly_values) - train_size

In [None]:
# Leading (earliest) samples form the training set.
train_poly_values, train_correction_vals = matching_poly_values[0:train_size], correction_vals[0:train_size]

In [None]:
# Trailing samples form the test set.
test_poly_values, test_correction_vals = matching_poly_values[train_size:], correction_vals[train_size:]

In [None]:
# Convert the numpy splits to float32 tensors (not re-runnable: inputs become tensors).
train_poly_values, train_correction_vals, test_poly_values, test_correction_vals = (
    torch.from_numpy(arr).to(torch.float32)
    for arr in (train_poly_values, train_correction_vals, test_poly_values, test_correction_vals)
)

In [None]:
# Pair inputs with targets for each split.
train_dataset, test_dataset = (
    TensorDataset(inputs, targets)
    for inputs, targets in ((train_poly_values, train_correction_vals),
                            (test_poly_values, test_correction_vals))
)

In [None]:
# Names of the arrays stored in the loaded archive (used as the numeric covariate list).
numeric_covariates = list(data.files)

In [None]:
import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import TensorDataset

# Load data (the context manager closes the .npz once the arrays are in memory)
with np.load('/Volumes/MARI/ssdl_gps/correction_data_2019_mo1.npz') as data:
    matching_poly_values = data['matching_poly_values']
    correction_vals = data['correction_vals']

# Ensure data is in the correct shape [num_samples, num_features]
# Assuming each sample has one feature (i.e., 1D data)
matching_poly_values = matching_poly_values[:, np.newaxis]  # Shape: [num_samples, 1]
correction_vals = correction_vals[:, np.newaxis]  # Shape: [num_samples, 1]

# Chronological 70/30 split, done BEFORE scaling so test statistics
# do not leak into the scaler.
train_size = int(0.7 * len(matching_poly_values))
test_size = len(matching_poly_values) - train_size

# Min-max scaling fitted on the training slice only; the test slice is
# transformed with the training min/max (so it may fall outside [0, 1]).
scaler_poly = MinMaxScaler()
scaler_correction = MinMaxScaler()

train_poly_values = scaler_poly.fit_transform(matching_poly_values[:train_size])
train_correction_vals = scaler_correction.fit_transform(correction_vals[:train_size])

test_poly_values = scaler_poly.transform(matching_poly_values[train_size:])
test_correction_vals = scaler_correction.transform(correction_vals[train_size:])

# Full scaled series, kept for any later cells that use these names.
matching_poly_values_scaled = scaler_poly.transform(matching_poly_values)
correction_vals_scaled = scaler_correction.transform(correction_vals)

# Convert to torch tensors
train_poly_values = torch.from_numpy(train_poly_values).float()
train_correction_vals = torch.from_numpy(train_correction_vals).float()
test_poly_values = torch.from_numpy(test_poly_values).float()
test_correction_vals = torch.from_numpy(test_correction_vals).float()

# Create datasets
train_dataset = TensorDataset(train_poly_values, train_correction_vals)
test_dataset = TensorDataset(test_poly_values, test_correction_vals)

In [None]:
class transformer_block(torch.nn.Module):
    """Post-norm transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is applied as ``x = LayerNorm(x + Dropout(sublayer(x)))``.
    Attention is batch-first, so inputs are [batch, seq_len, embed_size].

    Args:
        embed_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        drop_prob: dropout probability. When omitted, the notebook-level
            ``drop_prob`` value is used if defined, otherwise 0.1.
    """

    def __init__(self, embed_size, num_heads, drop_prob=None):
        super(transformer_block, self).__init__()

        if drop_prob is None:
            # Preserve the original behaviour of reading the notebook config value.
            drop_prob = globals().get('drop_prob', 0.1)

        self.attention = torch.nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        # `nn` is not imported in this notebook, so use fully-qualified torch.nn names.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(embed_size, 4 * embed_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(4 * embed_size, embed_size),
        )
        self.dropout = torch.nn.Dropout(drop_prob)
        self.ln1 = torch.nn.LayerNorm(embed_size, eps=1e-6)
        self.ln2 = torch.nn.LayerNorm(embed_size, eps=1e-6)

    def forward(self, x):
        """Apply attention and MLP sub-layers; output has the same shape as ``x``."""
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        x = self.ln1(x + self.dropout(attn_out))

        fc_out = self.fc(x)
        x = self.ln2(x + self.dropout(fc_out))

        return x

class transformer_forecaster(torch.nn.Module):
    """Stack of transformer_blocks, mean-pooled over time, mapped to a forecast vector.

    NOTE(review): this class cannot run as written:
      * `nn` is not imported in this notebook (only `torch`); `nn.Linear` etc.
        raise NameError — use `torch.nn`.
      * `drop_prob` and `forecast_length` are read from module globals;
        `forecast_length` is not defined anywhere visible in this notebook.
      * forward() iterates `self.embedding_static` and `self.embedding_cov`,
        which __init__ never creates.
    It appears to be adapted from a model with categorical/static embeddings —
    TODO confirm the intended inputs before reviving it.
    """

    def __init__(self,embed_size,num_heads,num_blocks):
        super(transformer_forecaster, self).__init__()

        # Unused below; presumably intended to size a numeric-input projection.
        num_len = len(numeric_covariates)

        self.blocks = torch.nn.ModuleList([transformer_block(embed_size,num_heads) for n in range(num_blocks)])

        # MLP head: embed_size -> 2x -> 4x -> forecast_length. The final ReLU
        # makes every forecast value non-negative.
        self.forecast_head = torch.nn.Sequential(nn.Linear(embed_size, embed_size*2),
                                           nn.LeakyReLU(),
                                           nn.Dropout(drop_prob),
                                           nn.Linear(embed_size*2, embed_size*4),
                                           nn.LeakyReLU(),
                                           nn.Linear(embed_size*4, forecast_length),
                                           nn.ReLU())

    def forward(self, x_numeric, x_category, x_static):
        """Forecast from numeric, categorical-covariate and static inputs.

        Presumed shapes (TODO confirm): x_numeric [bs, T, n_num],
        x_category [bs, T, n_cat] (indexed as x_category[:,:,i]),
        x_static [bs, n_static] (indexed as x_static[:,i]).
        Returns [bs, forecast_length].
        """

        # Average the static embeddings, then add a time axis: [bs, 1, emb].
        tmp_list = []
        for i,embed_layer in enumerate(self.embedding_static):
            tmp_list.append(embed_layer(x_static[:,i]))
        categroical_static_embeddings = torch.stack(tmp_list).mean(dim=0).unsqueeze(1)

        # Average the per-time-step categorical covariate embeddings: [bs, T, emb].
        tmp_list = []
        for i,embed_layer in enumerate(self.embedding_cov):
            tmp_list.append(embed_layer(x_category[:,:,i]))
        categroical_covariates_embeddings = torch.stack(tmp_list).mean(dim=0)
        T = categroical_covariates_embeddings.shape[1]

        # Blend static (broadcast over T) and covariate embeddings, then append
        # them to the numeric features. NOTE(review): the concatenated width must
        # equal embed_size for the transformer blocks to accept it.
        embed_out = (categroical_covariates_embeddings + categroical_static_embeddings.repeat(1,T,1))/2
        x = torch.concat((x_numeric,embed_out),dim=-1)

        for block in self.blocks:
            x = block(x)

        # Mean-pool over the sequence axis (blocks are batch-first) -> [bs, width].
        x = x.mean(dim=1)
        x = self.forecast_head(x)

        return x

In [None]:
class RMSLELoss(torch.nn.Module):
    """Root-mean-squared logarithmic error: sqrt(mean((log(1+pred) - log(1+actual))^2)).

    Both ``pred`` and ``actual`` must be > -1 elementwise, otherwise the log is NaN.
    """

    def __init__(self):
        super().__init__()
        self.mse = torch.nn.MSELoss()

    def forward(self, pred, actual):
        # log1p(x) is more accurate than log(x + 1) for the small values produced
        # by min-max scaling.
        return torch.sqrt(self.mse(torch.log1p(pred), torch.log1p(actual)))

In [None]:
# Training configuration
num_epoch = 1000
min_val_loss = 999

num_blocks = 1
embed_size = 500
num_heads = 50          # MultiheadAttention requires embed_size % num_heads == 0 (500 / 50 = 10)
batch_size = 128
learning_rate = 3e-4
time_shuffle = False
drop_prob = 0.1

# `device` is not defined by any earlier cell; pick the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = transformer_forecaster(embed_size,num_heads,num_blocks).to(device)
criterion = RMSLELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
# Halve the learning rate every 50 scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

In [None]:
# Batch size for the loaders. NOTE: differs from `batch_size = 128` in the config cell.
BS = 100
# Shuffle the training samples each epoch; keep the test order fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=BS, shuffle=shuffle_flag)
    for dataset, shuffle_flag in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Sanity-check the tensor shapes of every split.
for _label, _tensor in (
    ("Train Poly Values Shape:", train_poly_values),
    ("Train Correction Vals Shape:", train_correction_vals),
    ("Test Poly Values Shape:", test_poly_values),
    ("Test Correction Vals Shape:", test_correction_vals),
):
    print(_label, _tensor.shape)

In [None]:
def split_sequence(
    sequence: np.ndarray, ratio: float = 0.7
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split batched sequences into encoder input, decoder input and decoder target.

    For a sequence of length L, let N = int(L * ratio). Then:
      * src   = positions [0, N): encoder input;
      * tgt_y = positions [N, L): what the decoder must predict;
      * tgt   = positions [N - 1, L - 1): decoder input, i.e. tgt_y delayed by
        one step so it starts with the last element of src (for N >= 1).

    Args:
        sequence: batched sequences indexed as [bs, seq_len, num_features];
            any NumPy array or torch tensor slicing on axis 1 works.
        ratio: fraction of each sequence used as src.

    Returns:
        (src, tgt, tgt_y), each a slice of ``sequence`` along axis 1.
    """
    cut = int(sequence.shape[1] * ratio)
    src_span = slice(None, cut)
    decoder_in_span = slice(cut - 1, -1)
    decoder_out_span = slice(cut, None)
    return (
        sequence[:, src_span],
        sequence[:, decoder_in_span],
        sequence[:, decoder_out_span],
    )

In [None]:
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html,
# only modified to account for "batch first"

class PositionalEncoding(torch.nn.Module):
    """Batch-first sinusoidal positional encoding followed by dropout.

    Adds a fixed table with pe[pos, 2i] = sin(pos / 10000^(2i/d_model)) and
    pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model)) to the input.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        """
        Args:
            d_model: embedding dimension (odd values are supported).
            dropout: dropout probability applied after adding the encoding.
            max_len: maximum supported sequence length.

        Attributes:
            pe (torch.Tensor): registered buffer of shape (1, max_len, d_model).
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        # [max_len, 1] column of positions 0 .. max_len - 1.
        position = torch.arange(max_len).unsqueeze(1)
        # [ceil(d_model / 2)] inverse frequencies 10000^(-2i/d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # Odd slots hold only floor(d_model / 2) columns; trim so odd d_model works.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to ``x``.

        Args:
            x: tensor [bs, seq_len, embed_dim] with seq_len <= max_len.

        Returns:
            torch.Tensor: tensor with PE added and dropout applied, same shape as ``x``.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

In [None]:
class TransformerWithPE(torch.nn.Module):
    """Encoder-decoder ``torch.nn.Transformer`` with linear input/output projections
    and sinusoidal positional encoding. All tensors are batch-first."""

    def __init__(
        self, in_dim: int, out_dim: int, embed_dim: int, num_heads: int, num_layers: int
    ) -> None:
        """
        Initializes a transformer model with positional encoding.

        Args:
            in_dim: number of input (encoder) features
            out_dim: number of features to predict (also the decoder input width)
            embed_dim: embed features to this dimension; must be divisible by num_heads
            num_heads: number of transformer heads
            num_layers: number of encoder and decoder layers
        """
        super().__init__()

        self.positional_encoding = PositionalEncoding(embed_dim)

        # Project raw features into the embedding space.
        self.encoder_embedding = torch.nn.Linear(
            in_features=in_dim, out_features=embed_dim
        )
        self.decoder_embedding = torch.nn.Linear(
            in_features=out_dim, out_features=embed_dim
        )

        # Map transformer output back to the predicted features.
        self.output_layer = torch.nn.Linear(in_features=embed_dim, out_features=out_dim)

        self.transformer = torch.nn.Transformer(
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            d_model=embed_dim,
            batch_first=True,
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Forward function of the model.

        Args:
            src: input sequence to the encoder [bs, src_seq_len, in_dim]
            tgt: input sequence to the decoder [bs, tgt_seq_len, out_dim]

        Returns:
            torch.Tensor: predicted sequence [bs, tgt_seq_len, out_dim]
        """
        # [bs, src_seq_len, embed_dim]
        src = self.encoder_embedding(src)
        src = self.positional_encoding(src)

        # Causal mask [tgt_seq_len, tgt_seq_len] so position i only attends to <= i.
        # generate_square_subsequent_mask builds it on the default device, so move
        # it to the input's device (otherwise a GPU model fails on a CPU mask).
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)

        # [bs, tgt_seq_len, embed_dim]
        tgt = self.decoder_embedding(tgt)
        tgt = self.positional_encoding(tgt)

        # [bs, tgt_seq_len, embed_dim] -> [bs, tgt_seq_len, out_dim]
        pred = self.transformer(src, tgt, tgt_mask=tgt_mask)
        pred = self.output_layer(pred)

        return pred

    def infer(self, src: torch.Tensor, tgt_len: int) -> torch.Tensor:
        """Autoregressively predict ``tgt_len`` future steps.

        The decoder input starts with the last element of ``src``; each step's
        prediction is written into the next decoder input position. Because
        the decoder input has ``src.shape[2]`` features, this assumes
        in_dim == out_dim. Wrap the call in ``torch.no_grad()`` for pure inference.

        Args:
            src: input to the encoder [bs, src_seq_len, num_features]
            tgt_len: desired length of the output

        Returns:
            torch.Tensor: inferred sequence [bs, tgt_len, num_features]
        """
        output = torch.zeros((src.shape[0], tgt_len + 1, src.shape[2])).to(src.device)
        output[:, 0] = src[:, -1]
        for i in range(tgt_len):
            output[:, i + 1] = self.forward(src, output)[:, i]

        return output[:, 1:]

In [None]:
def load_and_partition_data(
    data_path: Path, seq_length: int = 100
) -> tuple[np.ndarray, int]:
    """Load an .npz file and partition it into non-overlapping sequences of equal length.

    Every array in the file is treated as one feature. Trailing samples that do
    not fill a whole sequence are dropped.

    Args:
        data_path: path to the .npz dataset
        seq_length: length of the generated sequences

    Returns:
        tuple[np.ndarray, int]: float64 array [num_sequences, seq_length,
            num_features] and the number of features in the dataset

    Raises:
        AssertionError: if the arrays in the file have different lengths.
    """
    # Read each array exactly once: indexing an NpzFile re-reads the member from
    # disk on every access, which the per-sequence loop used to do repeatedly.
    with np.load(data_path) as data:
        columns = [data[key] for key in data.files]
    num_features = len(columns)

    # Check that each feature provides the same number of data points
    data_lens = [len(column) for column in columns]
    assert len(set(data_lens)) == 1, f"features must have equal lengths, got {data_lens}"

    num_sequences = data_lens[0] // seq_length
    usable = num_sequences * seq_length
    # [usable, num_features] -> [num_sequences, seq_length, num_features]
    stacked = np.stack([column[:usable] for column in columns], axis=-1)
    sequences = stacked.reshape(num_sequences, seq_length, num_features).astype(np.float64)

    return sequences, num_features


def make_datasets(
    sequences: np.ndarray, test_size: float = 0.2, random_state=None
) -> tuple[TensorDataset, TensorDataset]:
    """Create train and test datasets by randomly splitting whole sequences.

    Args:
        sequences: sequences to use [num_sequences, sequence_length, num_features]
        test_size: fraction of sequences assigned to the test set (default 0.2)
        random_state: seed forwarded to ``train_test_split``; pass an int for a
            reproducible split (default None keeps the previous random behaviour)

    Returns:
        tuple[TensorDataset, TensorDataset]: train and test datasets of float32 tensors
    """
    train, test = train_test_split(sequences, test_size=test_size, random_state=random_state)
    return (
        TensorDataset(torch.as_tensor(train, dtype=torch.float32)),
        TensorDataset(torch.as_tensor(test, dtype=torch.float32)),
    )


def visualize(
    src: torch.Tensor,
    tgt: torch.Tensor,
    pred_infer: torch.Tensor,
    idx=0,
) -> None:
    """Plots one sample of a batch: source, target and inferred prediction.

    Values are plotted as stored in the tensors (no inverse scaling). The
    source occupies time steps ``[0, src_seq_len)``; target and prediction
    occupy the following ``tgt_seq_len`` steps.

    Args:
        src: source sequence [bs, src_seq_len, num_features]
        tgt: target sequence [bs, tgt_seq_len, num_features]; pass the
            ground truth (``tgt_y``) so it lines up with ``pred_infer``
        pred_infer: prediction obtained by running inference
            [bs, tgt_seq_len, num_features]
        idx: batch index to visualize
    """
    src_len = src.shape[1]
    steps = np.arange(src_len + tgt.shape[1])

    fig, ax = plt.subplots()
    ax.plot(steps[:src_len], src[idx].detach().cpu(), "bo-", label="src")
    ax.plot(steps[src_len:], tgt[idx].detach().cpu(), "go-", label="tgt")
    ax.plot(steps[src_len:], pred_infer[idx].detach().cpu(), "yo-", label="pred_infer")
    ax.set_xlabel("Time step")
    ax.set_ylabel("Value")
    ax.legend()
    plt.show()
    # Close this figure explicitly; calling plt.clf() after show() can
    # create and leave behind a fresh, empty "current" figure.
    plt.close(fig)


def split_sequence(
    sequence: np.ndarray, ratio: float = 0.7
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Splits a sequence into 2 (3) parts, as is required by our transformer
    model.

    Assume our sequence length is L, we then split this into src of length N
    and tgt_y of length M, with N + M = L.
    src, the first part of the input sequence, is the input to the encoder, and we
    expect the decoder to predict tgt_y, the second part of the input sequence.
    In addition we generate tgt, which is tgt_y delayed by one step (shifted
    right): it starts with the last token of src and ends with the second-last
    token of tgt_y. This sequence is the teacher-forcing input to the decoder.

    Args:
        sequence: batched input sequences to split [bs, seq_len, num_features];
            a torch.Tensor or np.ndarray (only slicing is used)
        ratio: split ratio, N = int(ratio * L)

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: src, tgt, tgt_y
            (same type as ``sequence``)

    Raises:
        ValueError: if the split leaves src or tgt_y empty
    """
    seq_len = sequence.shape[1]
    src_end = int(seq_len * ratio)
    # src_end == 0 would give empty src and tgt; src_end == seq_len would give
    # empty tgt and tgt_y. Neither is a usable training sample.
    if not 0 < src_end < seq_len:
        raise ValueError(
            f"ratio={ratio} splits a length-{seq_len} sequence into {src_end} "
            f"source and {seq_len - src_end} target steps; both must be >= 1"
        )
    # [bs, src_seq_len, num_features]
    src = sequence[:, :src_end]
    # [bs, tgt_seq_len, num_features]
    tgt = sequence[:, src_end - 1 : -1]
    # [bs, tgt_seq_len, num_features]
    tgt_y = sequence[:, src_end:]

    return src, tgt, tgt_y


def move_to_device(device: torch.device, *tensors: torch.Tensor) -> list[torch.Tensor]:
    """Move all given tensors to the given device.

    Non-tensor arguments are passed through unchanged, so mixed argument
    lists can be forwarded as-is.

    Args:
        device: device to move tensors to (a ``torch.device`` or a device
            string such as ``"cpu"``)
        tensors: tensors to move

    Returns:
        list[torch.Tensor]: the arguments in their original order, with every
            tensor moved to ``device``
    """
    return [t.to(device) if isinstance(t, torch.Tensor) else t for t in tensors]

In [None]:
BS = 100              # batch size
FEATURE_DIM = 128     # transformer embedding dimension (embed_dim / d_model)
NUM_HEADS = 16        # number of attention heads in the multi-head attention mechanism
NUM_EPOCHS = 5        # number of times the entire training set is passed through the model
NUM_VIS_EXAMPLES = 1  # number of test batches plotted during evaluation
NUM_LAYERS = 2        # number of encoder and decoder layers in the model
LR = 0.001            # Adam learning rate
SEED = 0              # seed for numpy (unseeded train/test split) and torch (init, dropout, shuffling)

# Multi-head attention splits the embedding evenly across heads.
assert FEATURE_DIM % NUM_HEADS == 0, "FEATURE_DIM must be divisible by NUM_HEADS"

# Seed the global RNGs so the data split and training run are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load data and generate train and test datasets / dataloaders
DATA_PATH = "/Volumes/MARI/ssdl_gps/correction_data_2019_mo1_upd.npz"  # change data file name
SEQ_LEN = 500  # time steps per sequence (src + tgt_y after split_sequence)

sequences, num_features = load_and_partition_data(DATA_PATH, SEQ_LEN)
assert len(sequences) > 0, f"{DATA_PATH} holds fewer than {SEQ_LEN} samples per feature"

train_set, test_set = make_datasets(sequences)
train_loader = DataLoader(train_set, batch_size=BS, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BS, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the transformer on the selected device, then its optimizer and loss.
model = TransformerWithPE(
    num_features,  # in_dim
    num_features,  # out_dim
    FEATURE_DIM,   # embed_dim
    NUM_HEADS,
    NUM_LAYERS,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.L1Loss()  # mean absolute error; torch.nn.MSELoss() is the squared-error alternative

In [None]:
%%time
losses = []  # mean training loss of each epoch

# Train loop
for epoch in range(NUM_EPOCHS):
    model.train()  # re-enable dropout in case an evaluation cell left the model in eval mode
    epoch_loss = 0.0  # sum of batch losses in this epoch
    for batch in train_loader:
        optimizer.zero_grad()

        src, tgt, tgt_y = split_sequence(batch[0])
        src, tgt, tgt_y = move_to_device(device, src, tgt, tgt_y)

        # [bs, tgt_seq_len, num_features]
        pred = model(src, tgt)
        loss = criterion(pred, tgt_y)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Average once per epoch, after all batches, so `losses` has exactly one
    # entry per epoch (matching the per-epoch loss plot).
    avg_epoch_loss = epoch_loss / len(train_loader)
    losses.append(avg_epoch_loss)

    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {avg_epoch_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# One point per recorded epoch. The x-axis is derived from `losses` so the plot
# still works if training was interrupted or NUM_EPOCHS was changed afterwards.
epochs = range(1, len(losses) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, losses, marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_ylim(bottom=0)  # anchor at zero but let the top autoscale instead of clipping at 1
ax.set_title("Training Loss per Epoch")
plt.show()

In [None]:
%%time
# Evaluate model
model.eval()  # disable dropout for deterministic predictions
eval_loss = 0.0   # teacher-forced loss (decoder sees the ground-truth history)
infer_loss = 0.0  # autoregressive loss (decoder sees its own predictions)

with torch.no_grad():
    for idx, batch in enumerate(test_loader):
        src, tgt, tgt_y = split_sequence(batch[0])
        src, tgt, tgt_y = move_to_device(device, src, tgt, tgt_y)

        # [bs, tgt_seq_len, num_features]
        pred = model(src, tgt)
        loss = criterion(pred, tgt_y)
        eval_loss += loss.item()

        # Run inference with model
        pred_infer = model.infer(src, tgt.shape[1])
        loss_infer = criterion(pred_infer, tgt_y)
        infer_loss += loss_infer.item()

        if idx < NUM_VIS_EXAMPLES:
            # Plot the ground truth tgt_y that pred_infer is scored against;
            # tgt is the decoder input and lags tgt_y by one step.
            visualize(src, tgt_y, pred_infer)

avg_eval_loss = eval_loss / len(test_loader)
avg_infer_loss = infer_loss / len(test_loader)

print(f"Eval Loss on test set: {avg_eval_loss:.4f}")
print(f"Inference Loss on test set: {avg_infer_loss:.4f}")

In [None]:
def split_sequence(
    sequence: np.ndarray, ratio: float = 0.7
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Splits a sequence into 2 (3) parts, as is required by our transformer
    model.

    Assume our sequence length is L, we then split this into src of length N
    and tgt_y of length M, with N + M = L.
    src, the first part of the input sequence, is the input to the encoder, and we
    expect the decoder to predict tgt_y, the second part of the input sequence.
    In addition we generate tgt, which is tgt_y delayed by one step (shifted
    right): it starts with the last token of src and ends with the second-last
    token of tgt_y. This sequence is the teacher-forcing input to the decoder.

    Args:
        sequence: batched input sequences to split [bs, seq_len, num_features];
            a torch.Tensor or np.ndarray (only slicing is used)
        ratio: split ratio, N = int(ratio * L)

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: src, tgt, tgt_y
            (same type as ``sequence``)

    Raises:
        ValueError: if the split leaves src or tgt_y empty
    """
    seq_len = sequence.shape[1]
    src_end = int(seq_len * ratio)
    # src_end == 0 would give empty src and tgt; src_end == seq_len would give
    # empty tgt and tgt_y. Neither is a usable training sample.
    if not 0 < src_end < seq_len:
        raise ValueError(
            f"ratio={ratio} splits a length-{seq_len} sequence into {src_end} "
            f"source and {seq_len - src_end} target steps; both must be >= 1"
        )
    # [bs, src_seq_len, num_features]
    src = sequence[:, :src_end]
    # [bs, tgt_seq_len, num_features]
    tgt = sequence[:, src_end - 1 : -1]
    # [bs, tgt_seq_len, num_features]
    tgt_y = sequence[:, src_end:]

    return src, tgt, tgt_y

In [None]:
# Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html:
# modified to account for "batch first" and to support odd d_model.

class PositionalEncoding(torch.nn.Module):
    """Adds fixed sinusoidal positional encodings to batch-first inputs."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        """
        Args:
            d_model (int): Dimension of the model (embedding dimension)
            dropout (float, optional): Dropout probability. Default is 0.1.
            max_len (int, optional): Maximum length of input sequences. Default is 5000.

        Attributes:
            pe (torch.Tensor): Positional encoding buffer. Shape: (1, max_len, d_model);
                even channels hold sines, odd channels cosines.
        """
        # Local import so the class does not rely on a notebook-level `import math`.
        import math

        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        # [max_len, 1] column of positions 0 .. max_len - 1
        position = torch.arange(max_len).unsqueeze(1)
        # One frequency per sine channel: 10000^(-2i / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine channel than sine channel.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Adds positional encoding to the given tensor.

        Args:
            x: tensor to add PE to [bs, seq_len, embed_dim]

        Returns:
            torch.Tensor: tensor with PE, after dropout [bs, seq_len, embed_dim]

        Raises:
            ValueError: if seq_len exceeds the max_len the encoding was built for
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

In [None]:
class TransformerWithPE(torch.nn.Module):
    """Encoder-decoder transformer for sequence-to-sequence forecasting.

    Inputs are linearly embedded, given positional encodings and passed
    through a batch-first ``torch.nn.Transformer``; the decoder output is
    projected back to ``out_dim`` features.
    """

    def __init__(
        self, in_dim: int, out_dim: int, embed_dim: int, num_heads: int, num_layers: int
    ) -> None:
        """
        Initializes a transformer model with positional encoding.

        Args:
            in_dim: number of input features
            out_dim: number of features to predict
            embed_dim: embed features to this dimension (must be divisible
                by num_heads)
            num_heads: number of transformer heads
            num_layers: number of encoder and decoder layers
        """
        super().__init__()

        self.positional_encoding = PositionalEncoding(embed_dim)

        # transform input features into embedded features
        self.encoder_embedding = torch.nn.Linear(
            in_features=in_dim, out_features=embed_dim
        )
        self.decoder_embedding = torch.nn.Linear(
            in_features=out_dim, out_features=embed_dim
        )

        # map output into output dimension
        self.output_layer = torch.nn.Linear(in_features=embed_dim, out_features=out_dim)

        self.transformer = torch.nn.Transformer(
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            d_model=embed_dim,
            batch_first=True,
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Forward function of the model (teacher forcing).

        Args:
            src: input sequence to the encoder [bs, src_seq_len, in_dim]
            tgt: input sequence to the decoder [bs, tgt_seq_len, out_dim]

        Returns:
            torch.Tensor: predicted sequence [bs, tgt_seq_len, out_dim]
        """
        # Causal mask so decoder position i only attends to positions <= i.
        # It must live on tgt's device; otherwise it is created on the default
        # (CPU) device, which mismatches CUDA activations.
        # [tgt_seq_len, tgt_seq_len]
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(
            tgt.shape[1]
        ).to(tgt.device)

        # Embed encoder input and add positional encoding.
        # [bs, src_seq_len, embed_dim]
        src = self.positional_encoding(self.encoder_embedding(src))

        # Embed decoder input and add positional encoding.
        # [bs, tgt_seq_len, embed_dim]
        tgt = self.positional_encoding(self.decoder_embedding(tgt))

        # Get prediction from transformer and map to output dimension.
        # [bs, tgt_seq_len, embed_dim] -> [bs, tgt_seq_len, out_dim]
        pred = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.output_layer(pred)

    def infer(self, src: torch.Tensor, tgt_len: int) -> torch.Tensor:
        """Runs inference with the model, meaning: predicts future values
        for an unknown sequence.

        The decoder is seeded with the last source step; each new prediction
        is appended to the decoder input for the next step. Call after
        ``model.eval()`` (dropout off), typically under ``torch.no_grad()``.
        Feeding the last source step to the decoder assumes in_dim == out_dim.

        Args:
            src: input to the encoder [bs, src_seq_len, num_features]
            tgt_len: desired length of the output

        Returns:
            torch.Tensor: inferred sequence [bs, tgt_len, num_features]
        """
        output = src.new_zeros((src.shape[0], tgt_len + 1, src.shape[2]))
        output[:, 0] = src[:, -1]
        for i in range(tgt_len):
            # Under the causal mask, position i only sees output[:, : i + 1],
            # so feeding just that prefix yields the same prediction without
            # computing over the not-yet-generated (zero) positions.
            output[:, i + 1] = self.forward(src, output[:, : i + 1])[:, -1]

        return output[:, 1:]

In [None]:
def load_and_partition_data(
    data_path: "str | Path", seq_length: int = 100
) -> tuple[np.ndarray, int]:
    """Loads the given data and partitions it into sequences of equal length.

    Every array stored in the ``.npz`` file is treated as one feature. The
    arrays are cut into consecutive, non-overlapping windows of ``seq_length``
    samples; trailing samples that do not fill a whole window are dropped.

    Args:
        data_path: path to the ``.npz`` dataset
        seq_length: length of the generated sequences

    Returns:
        tuple[np.ndarray, int]: float64 array of sequences with shape
            [num_sequences, seq_length, num_features] and the number of
            features (arrays) in the dataset

    Raises:
        ValueError: if ``seq_length`` is not positive, the file holds no
            arrays, or the arrays do not all have the same length
    """
    if seq_length <= 0:
        raise ValueError(f"seq_length must be positive, got {seq_length}")

    # Read every array exactly once: indexing an NpzFile re-reads (and
    # decompresses) that member from the zip archive on each access. The
    # context manager closes the underlying file handle.
    with np.load(data_path) as data:
        features = [data[key] for key in data.keys()]

    num_features = len(features)
    if num_features == 0:
        raise ValueError(f"no arrays found in {data_path}")

    # Check that each feature provides the same number of data points
    data_lens = {len(feature) for feature in features}
    if len(data_lens) != 1:
        raise ValueError(
            f"all features must have the same length, got {sorted(data_lens)}"
        )

    num_sequences = data_lens.pop() // seq_length
    usable = num_sequences * seq_length

    # [usable, num_features] -> [num_sequences, seq_length, num_features];
    # row-major reshape keeps window i = samples [i * seq_length, (i + 1) * seq_length).
    stacked = np.stack([feature[:usable] for feature in features], axis=-1)
    sequences = stacked.reshape(num_sequences, seq_length, num_features).astype(
        np.float64
    )
    return sequences, num_features


def make_datasets(
    sequences: np.ndarray,
    test_size: float = 0.2,
    random_state: "int | None" = None,
) -> tuple[TensorDataset, TensorDataset]:
    """Create train and test datasets from pre-partitioned sequences.

    Whole sequences are assigned to one split or the other (they are shuffled
    by ``train_test_split`` first); time steps within a sequence are never
    separated.

    Args:
        sequences: sequences to use [num_sequences, sequence_length, num_features]
        test_size: fraction of sequences placed in the test set
        random_state: seed for the shuffle; pass an int for a reproducible
            split (None keeps the unseeded behaviour)

    Returns:
        tuple[TensorDataset, TensorDataset]: train and test dataset (float32)
    """
    train, test = train_test_split(
        sequences, test_size=test_size, random_state=random_state
    )
    return TensorDataset(torch.Tensor(train)), TensorDataset(torch.Tensor(test))


def visualize(
    src: torch.Tensor,
    tgt: torch.Tensor,
    pred_infer: torch.Tensor,
    idx=0,
    data_scaler=None,
) -> None:
    """Plots one sample of a batch mapped back to unscaled units.

    All three series are passed through ``inverse_transform`` of a fitted
    scaler before plotting. The source occupies time steps
    ``[0, src_seq_len)``; target and prediction occupy the following
    ``tgt_seq_len`` steps.

    Args:
        src: source sequence [bs, src_seq_len, num_features]
        tgt: target sequence [bs, tgt_seq_len, num_features]; pass the
            ground truth (``tgt_y``) so it lines up with ``pred_infer``
        pred_infer: prediction obtained by running inference
            [bs, tgt_seq_len, num_features]
        idx: batch index to visualize
        data_scaler: object with an ``inverse_transform`` method (e.g. a
            fitted sklearn ``MinMaxScaler``); defaults to the notebook-level
            ``scaler`` variable
    """
    if data_scaler is None:
        data_scaler = scaler  # fall back to the notebook-level scaler

    src_len = src.shape[1]
    steps = np.arange(src_len + tgt.shape[1])

    def unscale(series: torch.Tensor) -> np.ndarray:
        # [seq_len, num_features] sample of the batch, mapped back to original units
        return data_scaler.inverse_transform(series[idx].detach().cpu().numpy())

    fig, ax = plt.subplots()
    ax.plot(steps[:src_len], unscale(src), "bo-", label="src")
    ax.plot(steps[src_len:], unscale(tgt), "go-", label="tgt")
    ax.plot(steps[src_len:], unscale(pred_infer), "yo-", label="pred_infer")
    ax.set_xlabel("Time step")
    ax.set_ylabel("Value")
    ax.legend()
    plt.show()
    # Close this figure explicitly; calling plt.clf() after show() can
    # create and leave behind a fresh, empty "current" figure.
    plt.close(fig)


def split_sequence(
    sequence: np.ndarray, ratio: float = 0.7
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Splits a sequence into 2 (3) parts, as is required by our transformer
    model.

    Assume our sequence length is L, we then split this into src of length N
    and tgt_y of length M, with N + M = L.
    src, the first part of the input sequence, is the input to the encoder, and we
    expect the decoder to predict tgt_y, the second part of the input sequence.
    In addition we generate tgt, which is tgt_y delayed by one step (shifted
    right): it starts with the last token of src and ends with the second-last
    token of tgt_y. This sequence is the teacher-forcing input to the decoder.

    Args:
        sequence: batched input sequences to split [bs, seq_len, num_features];
            a torch.Tensor or np.ndarray (only slicing is used)
        ratio: split ratio, N = int(ratio * L)

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: src, tgt, tgt_y
            (same type as ``sequence``)

    Raises:
        ValueError: if the split leaves src or tgt_y empty
    """
    seq_len = sequence.shape[1]
    src_end = int(seq_len * ratio)
    # src_end == 0 would give empty src and tgt; src_end == seq_len would give
    # empty tgt and tgt_y. Neither is a usable training sample.
    if not 0 < src_end < seq_len:
        raise ValueError(
            f"ratio={ratio} splits a length-{seq_len} sequence into {src_end} "
            f"source and {seq_len - src_end} target steps; both must be >= 1"
        )
    # [bs, src_seq_len, num_features]
    src = sequence[:, :src_end]
    # [bs, tgt_seq_len, num_features]
    tgt = sequence[:, src_end - 1 : -1]
    # [bs, tgt_seq_len, num_features]
    tgt_y = sequence[:, src_end:]

    return src, tgt, tgt_y


def move_to_device(device: torch.device, *tensors: torch.Tensor) -> list[torch.Tensor]:
    """Move all given tensors to the given device.

    Non-tensor arguments are passed through unchanged, so mixed argument
    lists can be forwarded as-is.

    Args:
        device: device to move tensors to (a ``torch.device`` or a device
            string such as ``"cpu"``)
        tensors: tensors to move

    Returns:
        list[torch.Tensor]: the arguments in their original order, with every
            tensor moved to ``device``
    """
    return [t.to(device) if isinstance(t, torch.Tensor) else t for t in tensors]

In [None]:
BS = 10               # batch size
FEATURE_DIM = 10      # transformer embedding dimension (embed_dim / d_model)
NUM_HEADS = 10        # number of attention heads in the multi-head attention mechanism
NUM_EPOCHS = 50       # number of times the entire training set is passed through the model
NUM_VIS_EXAMPLES = 1  # number of test batches plotted during evaluation
NUM_LAYERS = 2        # number of encoder and decoder layers in the model
LR = 0.001            # Adam learning rate
SEED = 0              # seed for numpy (unseeded train/test split) and torch (init, dropout, shuffling)

# Multi-head attention splits the embedding evenly across heads
# (here 10 / 10 = 1 dimension per head).
assert FEATURE_DIM % NUM_HEADS == 0, "FEATURE_DIM must be divisible by NUM_HEADS"

# Seed the global RNGs so the data split and training run are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Loaders are not rebuilt in this cell: the training and evaluation cells below
# use whatever `train_loader` / `test_loader` an earlier data-loading cell made
# (presumably the correction-data loaders). To switch datasets, rebuild them:
#     sequences, num_features = load_and_partition_data("short_file.npz", 500)
#     train_set, test_set = make_datasets(sequences)
#     train_loader = DataLoader(train_set, batch_size=BS, shuffle=True)
#     test_loader = DataLoader(test_set, batch_size=BS, shuffle=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the transformer on the selected device, then its optimizer and loss.
# NOTE(review): this overrides the num_features returned by the data loader;
# confirm the dataset in use really has a single feature.
num_features = 1
model = TransformerWithPE(
    num_features,  # in_dim
    num_features,  # out_dim
    FEATURE_DIM,   # embed_dim
    NUM_HEADS,
    NUM_LAYERS,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.L1Loss()  # mean absolute error; torch.nn.MSELoss() is the squared-error alternative

In [None]:
%%time
losses = []  # mean training loss of each epoch

# Train loop. Batches come from a single-tensor TensorDataset (as in the
# evaluation cell), so each sequence is split into encoder input, teacher-forced
# decoder input and ground truth; the model requires both src and tgt.
for epoch in range(NUM_EPOCHS):
    model.train()  # re-enable dropout in case an evaluation cell left the model in eval mode
    epoch_loss = 0.0  # sum of batch losses in this epoch
    for batch in train_loader:
        optimizer.zero_grad()

        src, tgt, tgt_y = split_sequence(batch[0])
        src, tgt, tgt_y = move_to_device(device, src, tgt, tgt_y)

        # [bs, tgt_seq_len, num_features]
        pred = model(src, tgt)
        loss = criterion(pred, tgt_y)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Average once per epoch, after all batches, so `losses` has exactly one
    # entry per epoch (matching the per-epoch loss plot).
    avg_epoch_loss = epoch_loss / len(train_loader)
    losses.append(avg_epoch_loss)

    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {avg_epoch_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# One point per recorded epoch. The x-axis is derived from `losses` so the plot
# still works if training was interrupted or NUM_EPOCHS was changed afterwards.
epochs = range(1, len(losses) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, losses, marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_ylim(bottom=0)  # anchor at zero but let the top autoscale instead of clipping at 1
ax.set_title("Training Loss per Epoch")
plt.show()

In [None]:
%%time
# Evaluate model
model.eval()  # disable dropout for deterministic predictions
eval_loss = 0.0   # teacher-forced loss (decoder sees the ground-truth history)
infer_loss = 0.0  # autoregressive loss (decoder sees its own predictions)

with torch.no_grad():
    for idx, batch in enumerate(test_loader):
        src, tgt, tgt_y = split_sequence(batch[0])
        src, tgt, tgt_y = move_to_device(device, src, tgt, tgt_y)

        # [bs, tgt_seq_len, num_features]
        pred = model(src, tgt)
        loss = criterion(pred, tgt_y)
        eval_loss += loss.item()

        # Run inference with model
        pred_infer = model.infer(src, tgt.shape[1])
        loss_infer = criterion(pred_infer, tgt_y)
        infer_loss += loss_infer.item()

        if idx < NUM_VIS_EXAMPLES:
            # Plot the ground truth tgt_y that pred_infer is scored against;
            # tgt is the decoder input and lags tgt_y by one step.
            visualize(src, tgt_y, pred_infer)

avg_eval_loss = eval_loss / len(test_loader)
avg_infer_loss = infer_loss / len(test_loader)

print(f"Eval Loss on test set: {avg_eval_loss:.4f}")
print(f"Inference Loss on test set: {avg_infer_loss:.4f}")