## Import and set constants

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import glob
import csv
from dataclasses import dataclass
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from scipy.interpolate import InterpolatedUnivariateSpline
from sklearn.model_selection import train_test_split

import math

import torch

from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

INPUT_PATH = '/content/drive/MyDrive/ML/sdc2023'

WGS84_SEMI_MAJOR_AXIS = 6378137.0
WGS84_SEMI_MINOR_AXIS = 6356752.314245
WGS84_SQUARED_FIRST_ECCENTRICITY  = 6.69437999013e-3
WGS84_SQUARED_SECOND_ECCENTRICITY = 6.73949674226e-3

HAVERSINE_RADIUS = 6_371_000
SAVE_PATH = "/content/train_models_results.csv"

## Load Device

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

## Class to manage converting between lat/lng BLH and GNSS

In [5]:
@dataclass
class ECEF:
    x: np.array
    y: np.array
    z: np.array

    def to_numpy(self):
        return np.stack([self.x, self.y, self.z], axis=0)

    @staticmethod
    def from_numpy(pos):
        x, y, z = [np.squeeze(w) for w in np.split(pos, 3, axis=-1)]
        return ECEF(x=x, y=y, z=z)

@dataclass
class BLH:
    lat : np.array
    lng : np.array
    hgt : np.array = 0


def ECEF_to_BLH(ecef):
    a = WGS84_SEMI_MAJOR_AXIS
    b = WGS84_SEMI_MINOR_AXIS
    e2  = WGS84_SQUARED_FIRST_ECCENTRICITY
    e2_ = WGS84_SQUARED_SECOND_ECCENTRICITY
    x = ecef.x
    y = ecef.y
    z = ecef.z
    r = np.sqrt(x**2 + y**2)
    t = np.arctan2(z * (a/b), r)
    B = np.arctan2(z + (e2_*b)*np.sin(t)**3, r - (e2*a)*np.cos(t)**3)
    L = np.arctan2(y, x)
    n = a / np.sqrt(1 - e2*np.sin(B)**2)
    H = (r / np.cos(B)) - n
    return BLH(lat=B, lng=L, hgt=H)

def haversine_distance(blh_1, blh_2):
    dlat = blh_2.lat - blh_1.lat
    dlng = blh_2.lng - blh_1.lng
    a = np.sin(dlat/2)**2 + np.cos(blh_1.lat) * np.cos(blh_2.lat) * np.sin(dlng/2)**2
    dist = 2 * HAVERSINE_RADIUS * np.arcsin(np.sqrt(a))
    return dist

In [6]:
def ecef_to_lat_lng(tripID, gnss_df, UnixTimeMillis):
    ecef_columns = ['WlsPositionXEcefMeters', 'WlsPositionYEcefMeters', 'WlsPositionZEcefMeters']
    columns = ['utcTimeMillis'] + ecef_columns
    ecef_df = (gnss_df.drop_duplicates(subset='utcTimeMillis')[columns]
               .dropna().reset_index(drop=True))
    ecef = ECEF.from_numpy(ecef_df[ecef_columns].to_numpy())
    blh  = ECEF_to_BLH(ecef)

    TIME = ecef_df['utcTimeMillis'].to_numpy()
    lat = InterpolatedUnivariateSpline(TIME, blh.lat, ext=3)(UnixTimeMillis)
    lng = InterpolatedUnivariateSpline(TIME, blh.lng, ext=3)(UnixTimeMillis)
    return pd.DataFrame({
#         'tripId' : tripID,
        'utcTimeMillis'   : UnixTimeMillis,
        'LatitudeDegrees'  : np.degrees(lat),
        'LongitudeDegrees' : np.degrees(lng),
    })

def calc_score(pred_df, gt_df):
    d = haversine_distance(pred_df, gt_df)
    mean_d= d.mean()
    score = np.mean([np.quantile(d, 0.50), np.quantile(d, 0.95)])
    return mean_d, score

In [7]:
def print_comparison(lat, lng, gt_lat, gt_lng):
    for lat_val, lng_val, gt_lat_val, gt_lng_val in zip(lat, lng, gt_lat, gt_lng):
        print(f'Pred: ({lat_val:<12.7f}, {lng_val:<12.7f}) Ground Truth: ({gt_lat_val:<12.7f}, {gt_lng_val:<12.7f})')

def print_batch(amnt, lat_arr, lng_arr, gt_lat_arr, gt_lng_arr):
    for batch in range(amnt):
        print(f'Val data {batch}')
        print_comparison(lat_arr[batch], lng_arr[batch], gt_lat_arr[batch], gt_lng_arr[batch])

## Loading Data

In [8]:
%%capture --no-stdout

pred_dfs  = []
gt_dfs = []



for dirname in sorted(glob.glob(f'/content/drive/MyDrive/ML/sdc2023/train/*/*')):
    drive, phone = dirname.split('/')[-2:]
    tripID  = f'{drive}/{phone}'
    gnss_df = pd.read_csv(f'{dirname}/device_gnss.csv')
    gt_df   = pd.read_csv(f'{dirname}/ground_truth.csv')

    info_cols = ['IonosphericDelayMeters', 'TroposphericDelayMeters']
    columns = ['utcTimeMillis'] + info_cols
    info_df = (gnss_df.drop_duplicates(subset='utcTimeMillis')[columns].fillna(0).reset_index(drop=True))

    for col in info_cols:
        info_df[col] = info_df[col].fillna((info_df[col].bfill() + info_df[col].ffill()) / 2)

    pred_df = ecef_to_lat_lng(tripID, gnss_df, gt_df['UnixTimeMillis'])
    pred_df = pd.merge(pred_df, info_df, on='utcTimeMillis', how='left')
    gt_df   = gt_df[['LatitudeDegrees', 'LongitudeDegrees']]
    print(tripID)
#     print(pred_df.shape)
#     print(gt_df.shape)
    pred_dfs.append(pred_df)
    gt_dfs.append(gt_df)

2020-06-25-00-34-us-ca-mtv-sb-101/pixel4
2020-06-25-00-34-us-ca-mtv-sb-101/pixel4xl
2020-07-08-22-28-us-ca/pixel4
2020-07-08-22-28-us-ca/pixel4xl
2020-07-17-22-27-us-ca-mtv-sf-280/pixel4
2020-07-17-23-13-us-ca-sf-mtv-280/pixel4
2020-07-17-23-13-us-ca-sf-mtv-280/pixel4xl
2020-08-04-00-19-us-ca-sb-mtv-101/pixel4
2020-08-04-00-20-us-ca-sb-mtv-101/pixel4xl
2020-08-04-00-20-us-ca-sb-mtv-101/pixel5
2020-08-13-21-41-us-ca-mtv-sf-280/pixel4
2020-08-13-21-41-us-ca-mtv-sf-280/pixel4xl
2020-08-13-21-42-us-ca-mtv-sf-280/pixel5
2020-12-10-22-17-us-ca-sjc-c/mi8
2020-12-10-22-52-us-ca-sjc-c/mi8
2020-12-10-22-52-us-ca-sjc-c/pixel4
2020-12-10-22-52-us-ca-sjc-c/pixel4xl
2020-12-10-22-52-us-ca-sjc-c/pixel5
2021-01-04-21-50-us-ca-e1highway280driveroutea/mi8
2021-01-04-21-50-us-ca-e1highway280driveroutea/pixel4
2021-01-04-21-50-us-ca-e1highway280driveroutea/pixel5
2021-01-04-22-40-us-ca-mtv-a/mi8
2021-01-04-22-40-us-ca-mtv-a/pixel4
2021-01-04-22-40-us-ca-mtv-a/pixel5
2021-01-05-21-12-us-ca-mtv-d/mi8
2021-0

In [9]:
class PositionalEncoding(torch.nn.Module):
    """
    This module injects some information about the relative or absolute position of the tokens in the sequence.
    The positional encodings have the same dimension as the embeddings, so that the two can be summed. Here, we use
    sine and cosine functions of different frequencies.

    Args:
        d_model (int): The dimension of the model (i.e., the size of the input embeddings).
        max_len (int): The maximum length of the input sequences for which to precompute positional encodings.

    Attributes:

        pe (torch.Tensor): Precomputed positional encodings.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Initializes the PositionalEncoding module.

        Args:
            d_model (int): The dimension of the model (i.e., the size of the input embeddings).
            max_len (int): The maximum length of the input sequences for which to precompute positional encodings.
        """
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Adds positional encodings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (seq_len, batch_size, embedding_dim).

        Returns:
            torch.Tensor: Tensor with positional encodings added.
        """
        return x + self.pe[: x.size(0)]

## Define Baseline Model

In [10]:
class TransformerEncoder(torch.nn.Module):
    """
    Transformer Encoder, which consists of an input linear layer to upscale the input
    dimension to the model dimension, a positional encoding layer, a stack of Transformer encoder layers, and
    a final fully connected (fc) layer for output transformation.

    Args:
        config (object): A configuration object containing the hyperparameters.

    Attributes:
        config (object): The configuration object.
        upscale (torch.nn.Linear): Linear layer to upscale input dimension to model dimension.
        pos_encoder (PositionalEncoding): Positional encoding layer.
        transformer (torch.nn.TransformerEncoder): Stack of Transformer encoder layers.
        fc (torch.nn.Sequential): Fully connected layers for output transformation.
    """

    def __init__(self, config, num_trainable_params):

        super().__init__()
        self.name = "Transformer"
        self.config = config
        self.upscale = torch.nn.Linear(config.input_dim, config.d_model)
        self.pos_encoder = PositionalEncoding(config.d_model, config.max_seq_len)
        transformer_layer = torch.nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            activation=config.activation,
            batch_first=True,
        )
        layer_trainable_params = sum(
            p.numel() for p in transformer_layer.parameters() if p.requires_grad
        )

        self.fc = torch.nn.Sequential()
        for i, num_neurons in enumerate(config.fc_layers[:-1]):
            self.fc.add_module(
                f"fc_{i}", torch.nn.Linear(num_neurons, config.fc_layers[i + 1])
            )
            if i < len(config.fc_layers) - 1:
                self.fc.add_module(f"relu_{i}", torch.nn.ReLU())
        fc_trainable_params = sum(
            p.numel() for p in self.fc.parameters() if p.requires_grad
        )

        num_layers = int(
            (num_trainable_params - fc_trainable_params) / layer_trainable_params
        )
        self.transformer = torch.nn.TransformerEncoder(
            transformer_layer, num_layers=num_layers
        )
    """
        assert (
            abs(
                num_trainable_params
                - sum(p.numel() for p in self.parameters() if p.requires_grad)
            )
            < num_trainable_params / 10
        ), f"Number of trainable parameters of transformer is not equal to {num_trainable_params}"
    """
    def forward(self, x):
        x = self.upscale(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.fc(x)
        return x


In [11]:
class LSTMEncoder(torch.nn.Module):
    """
    LSTM Encoder, which consists of an input linear layer to upscale the input
    dimension to the model dimension, a stack of LSTM layers, and a final fully connected (fc) layer for output transformation.

    Args:
        config (object): A configuration object containing the hyperparameters.

    Attributes:
        config (object): The configuration object.
        upscale (torch.nn.Linear): Linear layer to upscale input dimension to model dimension.
        lstm (torch.nn.LSTM): Stack of LSTM layers.
        fc (torch.nn.Sequential): Fully connected layers for output transformation.
    """

    def __init__(self, config, num_trainable_params):
        """
        Initializes the LSTMEncoder module.

        Args:
            config (object): A configuration object containing the hyperparameters for the LSTM Encoder.
        """
        super().__init__()
        self.name = "LSTM"
        self.config = config

        # Linear layer to upscale the input dimension to the model dimension
        self.upscale = torch.nn.Linear(config.input_dim, config.d_model)

        # Fully connected layers for output transformation
        self.fc = torch.nn.Sequential()
        for i, num_neurons in enumerate(config.fc_layers[:-1]):
            self.fc.add_module(
                f"fc_{i}", torch.nn.Linear(num_neurons, config.fc_layers[i + 1])
            )
            if i < len(config.fc_layers) - 1:
                self.fc.add_module(f"relu_{i}", torch.nn.ReLU())
        fc_trainable_params = sum(
            p.numel() for p in self.fc.parameters() if p.requires_grad
        )

        lstm_layer = torch.nn.LSTM(
            input_size=config.d_model,
            hidden_size=config.d_model,
            num_layers=1,
            batch_first=True,
        )
        layer_trainable_params = sum(
            p.numel() for p in lstm_layer.parameters() if p.requires_grad
        )
        num_layers = int(
            (num_trainable_params - fc_trainable_params) / layer_trainable_params
        )
        # Stack of LSTM layers
        self.lstm = torch.nn.LSTM(
            input_size=config.d_model,
            hidden_size=config.d_model,
            num_layers=num_layers,
            batch_first=True,
        )
        assert (
            abs(
                num_trainable_params
                - sum(p.numel() for p in self.parameters() if p.requires_grad)
            )
            < num_trainable_params / 40
        ), f"Number of trainable parameters of LSTM is not equal to {num_trainable_params}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the LSTMEncoder.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            torch.Tensor: Output tensor after passing through the LSTM Encoder.
        """
        # Upscale the input dimension to the model dimension
        x = self.upscale(x)

        # Pass through the stack of LSTM layers
        x, _ = self.lstm(x)

        # Transform the output through fully connected layers
        x = self.fc(x)

        return x

In [12]:
class MLP(torch.nn.Module):
    def __init__(self, config, num_trainable_params, max_layers=50):
        super(MLP, self).__init__()
        self.name = "MLP"

        input_dim = config.input_dim
        d_model = config.d_model
        output_dim = config.output_dim

        seq_model = nn.Sequential()
        i = 0
        total_params = 0

        # Initial input layer
        seq_model.add_module(f"fc_{i}", nn.Linear(input_dim, d_model))
        seq_model.add_module(f"relu_{i}", nn.ReLU())
        i += 1

        while i < max_layers:
            # Add hidden layers
            seq_model.add_module(f"fc_{i}", nn.Linear(d_model, d_model))
            seq_model.add_module(f"relu_{i}", nn.ReLU())

            # Calculate total parameters only when necessary
            total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

            # Check if adding an output layer would reach num_trainable_params
            output_layer = nn.Linear(d_model, output_dim)
            output_params = output_layer.weight.numel() + output_layer.bias.numel()

            if abs(num_trainable_params - (total_params + output_params)) < num_trainable_params / 40:
                seq_model.add_module(f"fc_{i}", nn.Linear(d_model, output_dim))
                self.layers = seq_model
                break

            i += 1

        assert (
            abs(
                num_trainable_params
                - sum(p.numel() for p in self.parameters() if p.requires_grad)
            )
            < num_trainable_params / 40
        ), f"Number of trainable parameters of MLP is not equal to {num_trainable_params}"

    def forward(self, x):
        return self.layers(x)


## Declare Model

In [13]:
class Config:
    """
    Configuration class to hold the hyperparameters and other settings.

    Args:
        config_dict (dict): A dictionary containing the hyperparameters and other settings.

    Attributes:
        input_dim (int): The dimension of the input features.
        d_model (int): The dimension of the model (i.e., the size of the input embeddings).
        nhead (int): The number of heads in the multiheadattention models.
        dim_feedforward (int): The dimension of the feedforward network model.
        activation (str): The activation function of intermediate layer, relu or gelu.
        num_layers (int): The number of sub-encoder-layers in the encoder.
        fc_layers (list[int]): The number of neurons in the fully connected layers.
        output_dim (int): The dimension of the output features.
        max_seq_len (int): The maximum sequence length.
        val_split (float): The validation split ratio.
    """

    def __init__(self, config_dict):
        self.input_dim = config_dict.get("input_dim")
        self.d_model = config_dict.get("d_model")
        self.nhead = config_dict.get("nhead")
        self.dim_feedforward = config_dict.get("dim_feedforward")
        self.activation = config_dict.get("activation")
        self.num_layers = config_dict.get("num_layers")
        self.fc_layers = config_dict.get("fc_layers")
        self.output_dim = config_dict.get("output_dim")
        self.max_seq_len = config_dict.get("max_seq_len")
        self.val_split = config_dict.get("val_split")

config_dict = {
        "input_dim": 4,
        "d_model": 8,
        "nhead": 4,
        "dim_feedforward": 128,
        "activation": "relu",
        "num_layers": 6,
        "output_dim": 2,
        "max_seq_len": 3500,
        "val_split": 0.2,
    }
    # For fc layers divide the d_model by 2 until it reaches the output_dim
fc_layers = [config_dict["d_model"]]
while fc_layers[-1] > config_dict["output_dim"]:
  fc_layers.append(fc_layers[-1] // 2)

fc_layers[-1] = config_dict["output_dim"]
config_dict["fc_layers"] = fc_layers

config = Config(config_dict)

## Define Dataset Class

In [14]:
class GNSSDataset(torch.utils.data.Dataset):
    """
    This class represents a custom dataset for Global Navigation Satellite System (GNSS) data.
    It processes prediction and ground truth dataframes, pads sequences to a maximum length,
    and computes the residuals between predictions and ground truth positions.
    Args:
        pred_dfs (list of pandas.DataFrame): List of dataframes containing prediction data.
        gt_dfs (list of pandas.DataFrame): List of dataframes containing ground truth data.

    Attributes:
        pred_dfs (list of pandas.DataFrame): List of dataframes containing prediction data.
        gt_dfs (list of pandas.DataFrame): List of dataframes containing ground truth data.
        sequences (numpy.ndarray): Numpy array of padded prediction sequences.
        labels (numpy.ndarray): Numpy array of residuals (ground truth - prediction).
    """

    def __init__(self, pred_dfs, gt_dfs):
        """
        Initializes the GNSSDataset.
        Args:
            pred_dfs (list of pandas.DataFrame): List of dataframes containing prediction data.
            gt_dfs (list of pandas.DataFrame): List of dataframes containing ground truth data.
        """
        self.pred_dfs = pred_dfs
        self.gt_dfs = gt_dfs
        self.sequences = []
        self.labels = []

        for pred_df in self.pred_dfs:
            x_np = pred_df[
                [
                    "LatitudeDegrees",
                    "LongitudeDegrees",
                    "IonosphericDelayMeters",
                    "TroposphericDelayMeters",
                ]
            ].to_numpy()
            ## pad to max sequence length
            pad = np.zeros((config.max_seq_len - x_np.shape[0], x_np.shape[1]))
            x_np = np.vstack([x_np, pad])
            # x_np = x_np/180
            self.sequences.append(x_np)

        for gt_df in self.gt_dfs:
            y_np = gt_df[["LatitudeDegrees", "LongitudeDegrees"]].to_numpy()
            ## pad to max sequence length
            pad = np.zeros((config.max_seq_len - y_np.shape[0], y_np.shape[1]))
            y_np = np.vstack([y_np, pad])
            # y_np = y_np/180
            self.labels.append(y_np)

        self.sequences = np.array(self.sequences, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)

        self.labels = self.labels - self.sequences[:, :, :2]  # just the residuals

        print("seq and label shapes")
        print(self.sequences.shape)
        print(self.labels.shape)

    def __getitem__(self, i):
        """
        Retrieves the sequence and label at index i.
        Args:
            i (int): Index of the data to retrieve.
        Returns:
            tuple: (sequence, label) at index i.
        """
        return self.sequences[i], self.labels[i]

    def __len__(self):
        """
        Returns the number of sequences in the dataset.
        Returns:
            int: Number of sequences in the dataset.
        """
        return self.sequences.shape[0]

## Initialize Dataset

In [15]:
import sklearn
from sklearn.model_selection import train_test_split

In [16]:
print(len(pred_dfs)*config.val_split)

31.200000000000003


In [17]:
def is_converged(val_losses):
    """
    Check if the last 10 val losses have a standard deviation of less than.
    Args:
        val_losses (list): List of validation losses.
    Returns:
        bool: True if the validation loss has converged, False otherwise.
    """
    if len(val_losses) < 10:
        return False
    return np.std(val_losses[-10:]) < 1.0  # TODO: maybe needs readjusting

In [31]:
def save_results(
    save_path,
    model_type,
    num_params,
    training_loss,
    best_loss,
    val_loss,
    test_loss,
    inf_time,
    kaggle_score,
    kaggle_test_score,
    epochs,
):
    """
    Add the training results to a CSV file.
    Args:
        save_path (str): Path to save the CSV file.
        val_losses (list): List of validation losses.

    """
    row = [
        model_type,
        num_params,
        training_loss,
        best_loss,
        val_loss,
        test_loss,
        inf_time,
        kaggle_score,
        kaggle_test_score,
        epochs,
    ]
    with open(save_path, "a") as file:
        writer = csv.writer(file)
        writer.writerow(row)


In [19]:
def val_model(model, loader, loss_fn):
    mean_dist = 0
    mean_score = 0
    count = 0
    losses = []
    inf_time = 0.0
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)

    for features, labels in loader:
        #print(features.shape)
        #print(features[0, :20])
        features = features.to(device)
        labels = labels.to(device)

        starter.record()
        pred = model(features)
        ender.record()
        torch.cuda.synchronize()
        inf_time += starter.elapsed_time(ender) / 1000  # seconds

        loss = loss_fn(pred, labels)
        losses.append(float(loss.cpu()))

        features = features.detach().cpu()  # * 180
        pred = pred.detach().cpu()  # * 180
        labels = labels.detach().cpu()  # * 180
        #print(pred.shape)
        pred_lats = pred[:, :, 0] + features[:, :, 0]
        pred_lngs = pred[:, :, 1] + features[:, :, 1]
        gt_lats = labels[:, :, 0] + features[:, :, 0]
        gt_lns = labels[:, :, 1] + features[:, :, 1]

        # Calculate score according to kaggle, height not necessary for distance
        blh1 = BLH(np.deg2rad(pred_lats), np.deg2rad(pred_lngs), hgt = 0)
        blh2 = BLH(np.deg2rad(gt_lats), np.deg2rad(gt_lns), hgt = 0)

        mean_dist, mean_score=calc_score(blh1, blh2)

        features = features.cpu()
        labels = labels.cpu()
        count += 1

    return (
        np.array(losses).mean(),
        mean_dist / count,
        mean_score / count,
        inf_time / count,
    )


In [32]:
def train_model(
    model_type,
    train_loader,
    val_loader,
    test_loader,
    config,
    device,
    epochs,
    lr,
    es=False,
):
    n_eval = 10  # evaluate every n_eval epochs

    if es:
      early_stopping_threshold = 1e-6
      early_stopping_patience = 20
      early_stopping_counter = 0

    PATH = "model.pt"

    best_loss = 99999999  # high number

    if model_type == "LSTM":
        model = LSTMEncoder(config, num_trainable_params)
    elif model_type == "MLP":
        model = MLP(config, num_trainable_params)
    elif model_type == "Transformer":
        model = TransformerEncoder(config, num_trainable_params)

    model.to(device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()

    val_losses = []
    converged = False
    for epoch in range(epochs):
      print(f"Epoch {epoch + 1} of {epochs}")
      if es and early_stopping_counter >= early_stopping_patience:
        print("Early stopping triggered due to consecutive low val_loss")
        print(f"Stopped at epoch {epoch + 1}")
        break
      # Loop over each batch in the dataset
      for batch in tqdm(train_loader):
        optimizer.zero_grad()  # If not, the gradients would sum up over each iteration

        # Unpack the data and labels
        features, labels = batch
        features = features.to(device)
        labels = labels.to(device)

        # Forward propagate
        outputs = model(features)

        # Backpropagation and gradient descent
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        print(f"Loss/train {loss}")

        # Periodically evaluate our model + log to Tensorboard
        if epoch % n_eval == 0:
          model.eval()
          val_loss, mean_dist, mean_score, _ = val_model(
                        model, val_loader, loss_fn
                    )
          val_losses.append(val_loss)

          if es:
                if val_loss < early_stopping_threshold:
                    early_stopping_counter += 1
                else:
                    early_stopping_counter = 0

          if val_loss < best_loss:
            best_loss = val_loss

            torch.save(
                {
                  "epoch": epoch,
                  "model_state_dict": model.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "loss": val_loss,
                },
                PATH,
                  )

          print(f"Val mean dist {mean_dist}")
          print(f"Val mean score {mean_score}")
          print(f"Loss/val {val_loss}")

          converged = is_converged(val_losses)
          print(converged)
          # turn on training, evaluate turns off training
          model.train()

        if converged:
          break

        # Get test loss and inference time
    model.eval()
    test_loss, mean_dist, mean_test_score, inf_time = val_model(
            model, test_loader, loss_fn
        )

    save_results(
            SAVE_PATH,
            model.name,
            num_trainable_params,
            float(loss.cpu()),
            best_loss,
            val_loss,
            test_loss,
            inf_time,
            mean_score,
            mean_test_score,
            epoch,
        )

In [21]:
# train test split into training and validation
X_train, X_val, y_train, y_val = train_test_split(
        pred_dfs, gt_dfs, test_size=config.val_split, random_state=2
    )
    # split into validation and test
X_val, X_test, y_val, y_test = train_test_split(
        X_val, y_val, test_size=0.5, random_state=2
    )

train_dataset = GNSSDataset(X_train, y_train)
val_dataset = GNSSDataset(X_val, y_val)
test_dataset = GNSSDataset(X_test, y_test)

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, shuffle=True
    )
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)

seq and label shapes
(124, 3500, 4)
(124, 3500, 2)
seq and label shapes
(16, 3500, 4)
(16, 3500, 2)
seq and label shapes
(16, 3500, 4)
(16, 3500, 2)


In [29]:
num_trainable_params=10000
epochs = 1000  # number of epochs
lr = 0.01  # learning rate

In [34]:
train_model(
        "Transformer",
        train_loader,
        val_loader,
        test_loader,
        config,
        device,
        epochs,
        lr,
        es=True,
    )

Epoch 1 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.06754066050052643
Val mean dist 29935.447265625
Val mean score 41499.578125
Loss/val 0.04555682837963104
False
Loss/train 0.04602718725800514
Val mean dist 24777.23828125
Val mean score 40397.71875
Loss/val 0.037983156740665436
False
Loss/train 0.03579578548669815
Val mean dist 22199.931640625
Val mean score 39185.318359375
Loss/val 0.03511188551783562
False
Loss/train 0.029974237084388733
Val mean dist 4131.44189453125
Val mean score 7694.352294921875
Loss/val 0.0013236397644504905
False
Epoch 2 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.002381169470027089
Loss/train 4.130066827201517e-06
Loss/train 3.161697748055303e-07
Loss/train 3.8235924648688524e-08
Epoch 3 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.39725908008404e-06
Loss/train 2.8216237524247845e-07
Loss/train 3.290657275556441e-07
Loss/train 3.440620233163827e-08
Epoch 4 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.612089257871048e-08
Loss/train 8.362213520740625e-06
Loss/train 7.655614808754763e-08
Loss/train 5.756583050242625e-07
Epoch 5 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.355149475391954e-06
Loss/train 7.345354902099643e-08
Loss/train 3.7417589737742674e-07
Loss/train 2.695005605346523e-07
Epoch 6 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.426966815022752e-06
Loss/train 5.369931272980466e-07
Loss/train 7.41841077456229e-08
Loss/train 5.125174729059268e-10
Epoch 7 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.756671560495306e-07
Loss/train 3.979131468767605e-10
Loss/train 2.928154003711825e-07
Loss/train 9.565384971210733e-06
Epoch 8 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.620675119251018e-10
Loss/train 6.625940329740843e-08
Loss/train 8.576539585192222e-06
Loss/train 4.5157875661061553e-07
Epoch 9 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.4352964917161444e-07
Loss/train 2.580101410298852e-10
Loss/train 8.585358045820612e-06
Loss/train 1.2508192526183848e-07
Epoch 10 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1480720135969023e-07
Loss/train 1.1970158197982528e-07
Loss/train 8.639301086077467e-06
Loss/train 7.403729540556014e-08
Epoch 11 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.362887456314638e-06
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
False
Loss/train 2.1952637041522394e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
False
Loss/train 4.556473811589967e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Loss/train 6.0580912508712e-10
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Epoch 12 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.324403097503819e-06
Loss/train 2.157196021812524e-08
Loss/train 1.115068002377484e-07
Loss/train 6.641265031248622e-07
Epoch 13 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.102562828871669e-08
Loss/train 1.1971471280958212e-07
Loss/train 8.362852895515971e-06
Loss/train 5.999989411975548e-07
Epoch 14 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.321056157001294e-06
Loss/train 6.939580288189973e-08
Loss/train 6.429208951885812e-07
Loss/train 5.9645755001724865e-09
Epoch 15 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.36992407312664e-08
Loss/train 8.865768904797733e-06
Loss/train 9.903726549964631e-08
Loss/train 4.5812438287384794e-08
Epoch 16 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.976865509651816e-08
Loss/train 7.331934170906607e-08
Loss/train 3.2416571116300474e-07
Loss/train 9.841529390541837e-06
Epoch 17 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0097869562741835e-07
Loss/train 3.1284653090324355e-08
Loss/train 2.979954558668396e-07
Loss/train 9.838095138547942e-06
Epoch 18 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2372680064108863e-07
Loss/train 8.364223504031543e-06
Loss/train 3.5194304359720263e-07
Loss/train 1.1279849587708668e-07
Epoch 19 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.048455722241215e-08
Loss/train 8.354118108400144e-06
Loss/train 3.328170521399443e-07
Loss/train 3.8991106521280017e-07
Epoch 20 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.4382946978439577e-07
Loss/train 4.149752896864811e-08
Loss/train 8.434971277893055e-06
Loss/train 2.4947956944743055e-07
Epoch 21 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.35174068924971e-06
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Loss/train 2.146756514775916e-07
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670461516478099e-07
True
Epoch 22 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.666895741436747e-07
Epoch 23 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.666938728334571e-08
Epoch 24 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.53678648127243e-06
Epoch 25 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.341996363014914e-06
Epoch 26 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.565294592699502e-06
Epoch 27 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.324548616656102e-06
Epoch 28 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.694544048471471e-08
Epoch 29 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.375149263883941e-06
Epoch 30 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.9746639313165133e-10
Epoch 31 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0118674964587626e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 32 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.676139005368782e-09
Epoch 33 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0423671703563286e-08
Epoch 34 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.350402822543401e-06
Epoch 35 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1611812428545818e-07
Epoch 36 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.580114539796341e-07
Epoch 37 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1335335276262413e-08
Epoch 38 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0274278139377202e-07
Epoch 39 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.9359852144780234e-08
Epoch 40 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.4007779150852e-06
Epoch 41 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.3878369031772309e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 42 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1830122420851694e-07
Epoch 43 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.027645297976505e-09
Epoch 44 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.400663318752777e-06
Epoch 45 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.354315468750428e-06
Epoch 46 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9217363817224395e-07
Epoch 47 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.7866790269494004e-09
Epoch 48 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.894313695378514e-07
Epoch 49 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.915590411907033e-07
Epoch 50 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.46740909107757e-08
Epoch 51 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.301094646055105e-10
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 52 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.517678237585642e-07
Epoch 53 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.426418389717583e-06
Epoch 54 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.730253178626299e-07
Epoch 55 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.6133024633168134e-08
Epoch 56 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2692117579149453e-08
Epoch 57 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.643578214810987e-08
Epoch 58 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.353674646992658e-08
Epoch 59 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2072574640551466e-07
Epoch 60 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.499056049349747e-07
Epoch 61 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.325860108016059e-06
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 62 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.942627095440912e-08
Epoch 63 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.025655897521574e-08
Epoch 64 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.909895613356639e-07
Epoch 65 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.419756341027096e-06
Epoch 66 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2792987053321667e-08
Epoch 67 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1118196141524095e-07
Epoch 68 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.161053345162145e-07
Epoch 69 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0647811166109022e-07
Epoch 70 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.845381671562791e-06
Epoch 71 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.619903383078054e-06
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 72 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1634198432707308e-08
Epoch 73 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.993397743011883e-07
Epoch 74 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9430543690978084e-07
Epoch 75 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2544039640924893e-07
Epoch 76 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.375076124077168e-07
Epoch 77 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.950623176431691e-07
Epoch 78 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.928369185501879e-09
Epoch 79 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1880810524853587e-07
Epoch 80 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.371370313398074e-06
Epoch 81 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.204189079677235e-09
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 82 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.2309375203331e-07
Epoch 83 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.544045158487279e-06
Epoch 84 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.355032150575425e-06
Epoch 85 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.52897698058041e-07
Epoch 86 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.086707861006289e-07
Epoch 87 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.726243621964386e-07
Epoch 88 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.404390428040642e-06
Epoch 89 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.324626833200455e-06
Epoch 90 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.324366717715748e-06
Epoch 91 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.955332392839409e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670464927083231e-07
True
Epoch 92 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.639316547487397e-06
Epoch 93 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.198029118247405e-09
Epoch 94 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.371369317051176e-08
Epoch 95 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.1580486254133575e-07
Epoch 96 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1995332960832457e-07
Epoch 97 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.613642421551049e-06
Epoch 98 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.160272849025205e-07
Epoch 99 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.3833990841760624e-08
Epoch 100 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.643065484648105e-06
Epoch 101 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.8901835469905564e-08
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 102 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.320872439071536e-06
Epoch 103 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.581036255833169e-07
Epoch 104 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.355672434845474e-06
Epoch 105 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.574766070523765e-06
Epoch 106 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.22155643775568e-08
Epoch 107 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.690974937053397e-06
Epoch 108 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.618638275947887e-06
Epoch 109 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.358999366464559e-06
Epoch 110 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.3901436091146024e-07
Epoch 111 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.152846244664943e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
True
Epoch 112 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.719683761005399e-08
Epoch 113 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.409224405004352e-07
Epoch 114 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.140316901488859e-08
Epoch 115 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.671243860026152e-07
Epoch 116 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2547145534245203e-10
Epoch 117 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.9007063296357956e-08
Epoch 118 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.018605058777382e-10
Epoch 119 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.322215762746055e-06
Epoch 120 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.3590885689372953e-07
Epoch 121 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.065101760237667e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 122 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2273715078190435e-07
Epoch 123 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.649686750710316e-08
Epoch 124 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2596171217182928e-08
Epoch 125 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.132982832696143e-08
Epoch 126 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4804575105008553e-07
Epoch 127 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2102346381179814e-07
Epoch 128 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.284644626930458e-08
Epoch 129 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2365445317973354e-07
Epoch 130 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.053204053088848e-08
Epoch 131 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.493654622208851e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 132 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3794002263221046e-08
Epoch 133 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.496536583545094e-07
Epoch 134 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.428625439116928e-10
Epoch 135 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.051273154142109e-07
Epoch 136 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.405795597354881e-06
Epoch 137 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.320958841068204e-06
Epoch 138 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.345479727722704e-06
Epoch 139 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.436975804215763e-06
Epoch 140 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1486764012479398e-07
Epoch 141 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.674439927995991e-08
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 142 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.611344128439669e-06
Epoch 143 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2502261742829432e-07
Epoch 144 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.792090185721463e-07
Epoch 145 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.466350664093625e-06
Epoch 146 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.440544346261959e-08
Epoch 147 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.977962959034187e-10
Epoch 148 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.84667952050222e-06
Epoch 149 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.410483133047819e-06
Epoch 150 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.350280950253364e-06
Epoch 151 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2496381884830043e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670461516478099e-07
True
Epoch 152 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9696007430857208e-08
Epoch 153 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.175162318797447e-08
Epoch 154 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.407046283929958e-07
Epoch 155 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2205559585254377e-07
Epoch 156 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.865876951891778e-09
Epoch 157 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.566919761193276e-07
Epoch 158 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4399818698839226e-07
Epoch 159 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.196568165250937e-07
Epoch 160 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2678406768884543e-08
Epoch 161 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0588367871378068e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 162 of 1000
Early stopping triggered due to consecutive low val_loss
Stopped at epoch 162


In [None]:
train_model(
        "LSTM",
        train_loader,
        val_loader,
        test_loader,
        config,
        device,
        epochs,
        lr,
    )

In [None]:
"""
train_model(
        "MLP",
        train_loader,
        val_loader,
        test_loader,
        config,
        device,
        epochs,
        lr,
    )
"""

AssertionError: Number of trainable parameters of MLP is not equal to 10000