## Import and set constants

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import glob
import csv
from dataclasses import dataclass
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from scipy.interpolate import InterpolatedUnivariateSpline
from sklearn.model_selection import train_test_split

import math

import torch

from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

INPUT_PATH = '/content/drive/MyDrive/ML/sdc2023'

WGS84_SEMI_MAJOR_AXIS = 6378137.0
WGS84_SEMI_MINOR_AXIS = 6356752.314245
WGS84_SQUARED_FIRST_ECCENTRICITY  = 6.69437999013e-3
WGS84_SQUARED_SECOND_ECCENTRICITY = 6.73949674226e-3

HAVERSINE_RADIUS = 6_371_000
SAVE_PATH = "/content/train_models_results.csv"

## Load Device

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

## Class to manage converting between lat/lng BLH and GNSS

In [5]:
@dataclass
class ECEF:
    x: np.array
    y: np.array
    z: np.array

    def to_numpy(self):
        return np.stack([self.x, self.y, self.z], axis=0)

    @staticmethod
    def from_numpy(pos):
        x, y, z = [np.squeeze(w) for w in np.split(pos, 3, axis=-1)]
        return ECEF(x=x, y=y, z=z)

@dataclass
class BLH:
    lat : np.array
    lng : np.array
    hgt : np.array = 0


def ECEF_to_BLH(ecef):
    a = WGS84_SEMI_MAJOR_AXIS
    b = WGS84_SEMI_MINOR_AXIS
    e2  = WGS84_SQUARED_FIRST_ECCENTRICITY
    e2_ = WGS84_SQUARED_SECOND_ECCENTRICITY
    x = ecef.x
    y = ecef.y
    z = ecef.z
    r = np.sqrt(x**2 + y**2)
    t = np.arctan2(z * (a/b), r)
    B = np.arctan2(z + (e2_*b)*np.sin(t)**3, r - (e2*a)*np.cos(t)**3)
    L = np.arctan2(y, x)
    n = a / np.sqrt(1 - e2*np.sin(B)**2)
    H = (r / np.cos(B)) - n
    return BLH(lat=B, lng=L, hgt=H)

def haversine_distance(blh_1, blh_2):
    dlat = blh_2.lat - blh_1.lat
    dlng = blh_2.lng - blh_1.lng
    a = np.sin(dlat/2)**2 + np.cos(blh_1.lat) * np.cos(blh_2.lat) * np.sin(dlng/2)**2
    dist = 2 * HAVERSINE_RADIUS * np.arcsin(np.sqrt(a))
    return dist

In [74]:
def ecef_to_lat_lng(tripID, gnss_df, UnixTimeMillis):
    ecef_columns = ['WlsPositionXEcefMeters', 'WlsPositionYEcefMeters', 'WlsPositionZEcefMeters']
    columns = ['utcTimeMillis'] + ecef_columns
    ecef_df = (gnss_df.drop_duplicates(subset='utcTimeMillis')[columns]
               .dropna().reset_index(drop=True))
    ecef = ECEF.from_numpy(ecef_df[ecef_columns].to_numpy())
    blh  = ECEF_to_BLH(ecef)

    TIME = ecef_df['utcTimeMillis'].to_numpy()
    lat = InterpolatedUnivariateSpline(TIME, blh.lat, ext=3)(UnixTimeMillis)
    lng = InterpolatedUnivariateSpline(TIME, blh.lng, ext=3)(UnixTimeMillis)
    return pd.DataFrame({
#         'tripId' : tripID,
        'utcTimeMillis'   : UnixTimeMillis,
        'LatitudeDegrees'  : np.degrees(lat),
        'LongitudeDegrees' : np.degrees(lng),
    })

def calc_score(pred_df, gt_df):
    d = haversine_distance(pred_df, gt_df)
    mean_d = d.mean()
    score = np.mean([np.quantile(d, 0.50), np.quantile(d, 0.95)])
    return mean_d, score

In [7]:
def print_comparison(lat, lng, gt_lat, gt_lng):
    for lat_val, lng_val, gt_lat_val, gt_lng_val in zip(lat, lng, gt_lat, gt_lng):
        print(f'Pred: ({lat_val:<12.7f}, {lng_val:<12.7f}) Ground Truth: ({gt_lat_val:<12.7f}, {gt_lng_val:<12.7f})')

def print_batch(amnt, lat_arr, lng_arr, gt_lat_arr, gt_lng_arr):
    for batch in range(amnt):
        print(f'Val data {batch}')
        print_comparison(lat_arr[batch], lng_arr[batch], gt_lat_arr[batch], gt_lng_arr[batch])

## Loading Data

In [8]:
%%capture --no-stdout

pred_dfs  = []
gt_dfs = []



for dirname in sorted(glob.glob(f'/content/drive/MyDrive/ML/sdc2023/train/*/*')):
    drive, phone = dirname.split('/')[-2:]
    tripID  = f'{drive}/{phone}'
    gnss_df = pd.read_csv(f'{dirname}/device_gnss.csv')
    gt_df   = pd.read_csv(f'{dirname}/ground_truth.csv')

    info_cols = ['IonosphericDelayMeters', 'TroposphericDelayMeters']
    columns = ['utcTimeMillis'] + info_cols
    info_df = (gnss_df.drop_duplicates(subset='utcTimeMillis')[columns].fillna(0).reset_index(drop=True))

    for col in info_cols:
        info_df[col] = info_df[col].fillna((info_df[col].bfill() + info_df[col].ffill()) / 2)

    pred_df = ecef_to_lat_lng(tripID, gnss_df, gt_df['UnixTimeMillis'])
    pred_df = pd.merge(pred_df, info_df, on='utcTimeMillis', how='left')
    gt_df   = gt_df[['LatitudeDegrees', 'LongitudeDegrees']]
    print(tripID)
#     print(pred_df.shape)
#     print(gt_df.shape)
    pred_dfs.append(pred_df)
    gt_dfs.append(gt_df)

2020-06-25-00-34-us-ca-mtv-sb-101/pixel4
2020-06-25-00-34-us-ca-mtv-sb-101/pixel4xl
2020-07-08-22-28-us-ca/pixel4
2020-07-08-22-28-us-ca/pixel4xl
2020-07-17-22-27-us-ca-mtv-sf-280/pixel4
2020-07-17-23-13-us-ca-sf-mtv-280/pixel4
2020-07-17-23-13-us-ca-sf-mtv-280/pixel4xl
2020-08-04-00-19-us-ca-sb-mtv-101/pixel4
2020-08-04-00-20-us-ca-sb-mtv-101/pixel4xl
2020-08-04-00-20-us-ca-sb-mtv-101/pixel5
2020-08-13-21-41-us-ca-mtv-sf-280/pixel4
2020-08-13-21-41-us-ca-mtv-sf-280/pixel4xl
2020-08-13-21-42-us-ca-mtv-sf-280/pixel5
2020-12-10-22-17-us-ca-sjc-c/mi8
2020-12-10-22-52-us-ca-sjc-c/mi8
2020-12-10-22-52-us-ca-sjc-c/pixel4
2020-12-10-22-52-us-ca-sjc-c/pixel4xl
2020-12-10-22-52-us-ca-sjc-c/pixel5
2021-01-04-21-50-us-ca-e1highway280driveroutea/mi8
2021-01-04-21-50-us-ca-e1highway280driveroutea/pixel4
2021-01-04-21-50-us-ca-e1highway280driveroutea/pixel5
2021-01-04-22-40-us-ca-mtv-a/mi8
2021-01-04-22-40-us-ca-mtv-a/pixel4
2021-01-04-22-40-us-ca-mtv-a/pixel5
2021-01-05-21-12-us-ca-mtv-d/mi8
2021-0

In [9]:
class PositionalEncoding(torch.nn.Module):
    """
    This module injects some information about the relative or absolute position of the tokens in the sequence.
    The positional encodings have the same dimension as the embeddings, so that the two can be summed. Here, we use
    sine and cosine functions of different frequencies.

    Args:
        d_model (int): The dimension of the model (i.e., the size of the input embeddings).
        max_len (int): The maximum length of the input sequences for which to precompute positional encodings.

    Attributes:

        pe (torch.Tensor): Precomputed positional encodings.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Initializes the PositionalEncoding module.

        Args:
            d_model (int): The dimension of the model (i.e., the size of the input embeddings).
            max_len (int): The maximum length of the input sequences for which to precompute positional encodings.
        """
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Adds positional encodings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (seq_len, batch_size, embedding_dim).

        Returns:
            torch.Tensor: Tensor with positional encodings added.
        """
        return x + self.pe[: x.size(0)]

## Define Baseline Model

In [10]:
class TransformerEncoder(torch.nn.Module):
    """
    Transformer Encoder, which consists of an input linear layer to upscale the input
    dimension to the model dimension, a positional encoding layer, a stack of Transformer encoder layers, and
    a final fully connected (fc) layer for output transformation.

    Args:
        config (object): A configuration object containing the hyperparameters.

    Attributes:
        config (object): The configuration object.
        upscale (torch.nn.Linear): Linear layer to upscale input dimension to model dimension.
        pos_encoder (PositionalEncoding): Positional encoding layer.
        transformer (torch.nn.TransformerEncoder): Stack of Transformer encoder layers.
        fc (torch.nn.Sequential): Fully connected layers for output transformation.
    """

    def __init__(self, config, num_trainable_params):

        super().__init__()
        self.name = "Transformer"
        self.config = config
        self.upscale = torch.nn.Linear(config.input_dim, config.d_model)
        self.pos_encoder = PositionalEncoding(config.d_model, config.max_seq_len)
        transformer_layer = torch.nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            activation=config.activation,
            batch_first=True,
        )
        layer_trainable_params = sum(
            p.numel() for p in transformer_layer.parameters() if p.requires_grad
        )

        self.fc = torch.nn.Sequential()
        for i, num_neurons in enumerate(config.fc_layers[:-1]):
            self.fc.add_module(
                f"fc_{i}", torch.nn.Linear(num_neurons, config.fc_layers[i + 1])
            )
            if i < len(config.fc_layers) - 1:
                self.fc.add_module(f"relu_{i}", torch.nn.ReLU())
        fc_trainable_params = sum(
            p.numel() for p in self.fc.parameters() if p.requires_grad
        )

        num_layers = int(
            (num_trainable_params - fc_trainable_params) / layer_trainable_params
        )
        self.transformer = torch.nn.TransformerEncoder(
            transformer_layer, num_layers=num_layers
        )
    """
        assert (
            abs(
                num_trainable_params
                - sum(p.numel() for p in self.parameters() if p.requires_grad)
            )
            < num_trainable_params / 10
        ), f"Number of trainable parameters of transformer is not equal to {num_trainable_params}"
    """
    def forward(self, x):
        x = self.upscale(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.fc(x)
        return x


In [11]:
class LSTMEncoder(torch.nn.Module):
    """
    LSTM Encoder, which consists of an input linear layer to upscale the input
    dimension to the model dimension, a stack of LSTM layers, and a final fully connected (fc) layer for output transformation.

    Args:
        config (object): A configuration object containing the hyperparameters.

    Attributes:
        config (object): The configuration object.
        upscale (torch.nn.Linear): Linear layer to upscale input dimension to model dimension.
        lstm (torch.nn.LSTM): Stack of LSTM layers.
        fc (torch.nn.Sequential): Fully connected layers for output transformation.
    """

    def __init__(self, config, num_trainable_params):
        """
        Initializes the LSTMEncoder module.

        Args:
            config (object): A configuration object containing the hyperparameters for the LSTM Encoder.
        """
        super().__init__()
        self.name = "LSTM"
        self.config = config

        # Linear layer to upscale the input dimension to the model dimension
        self.upscale = torch.nn.Linear(config.input_dim, config.d_model)

        # Fully connected layers for output transformation
        self.fc = torch.nn.Sequential()
        for i, num_neurons in enumerate(config.fc_layers[:-1]):
            self.fc.add_module(
                f"fc_{i}", torch.nn.Linear(num_neurons, config.fc_layers[i + 1])
            )
            if i < len(config.fc_layers) - 1:
                self.fc.add_module(f"relu_{i}", torch.nn.ReLU())
        fc_trainable_params = sum(
            p.numel() for p in self.fc.parameters() if p.requires_grad
        )

        lstm_layer = torch.nn.LSTM(
            input_size=config.d_model,
            hidden_size=config.d_model,
            num_layers=1,
            batch_first=True,
        )
        layer_trainable_params = sum(
            p.numel() for p in lstm_layer.parameters() if p.requires_grad
        )
        num_layers = int(
            (num_trainable_params - fc_trainable_params) / layer_trainable_params
        )
        # Stack of LSTM layers
        self.lstm = torch.nn.LSTM(
            input_size=config.d_model,
            hidden_size=config.d_model,
            num_layers=num_layers,
            batch_first=True,
        )
        assert (
            abs(
                num_trainable_params
                - sum(p.numel() for p in self.parameters() if p.requires_grad)
            )
            < num_trainable_params / 40
        ), f"Number of trainable parameters of LSTM is not equal to {num_trainable_params}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the LSTMEncoder.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            torch.Tensor: Output tensor after passing through the LSTM Encoder.
        """
        # Upscale the input dimension to the model dimension
        x = self.upscale(x)

        # Pass through the stack of LSTM layers
        x, _ = self.lstm(x)

        # Transform the output through fully connected layers
        x = self.fc(x)

        return x

In [94]:
class MLP(nn.Module):
    def __init__(self, config, num_trainable_params):
        super(MLP, self).__init__()
        self.name = "MLP"

        max_hidden_layers = 10
        input_dim = config.input_dim
        d_model = config.d_model
        output_dim = config.output_dim

        seq_model = nn.Sequential()

        # Initial input layer
        seq_model.add_module("fc_0", nn.Linear(input_dim, d_model))
        seq_model.add_module("relu_0", nn.ReLU())

        # Add hidden layers
        for i in range(1, max_hidden_layers + 1):
            seq_model.add_module(f"fc_{i}", nn.Linear(d_model, d_model))
            seq_model.add_module(f"relu_{i}", nn.ReLU())

        # Output layer
        seq_model.add_module(f"fc_output", nn.Linear(d_model, output_dim))
        self.layers = seq_model

    def forward(self, x):
        return self.layers(x)


## Declare Model

In [13]:
class Config:
    """
    Configuration class to hold the hyperparameters and other settings.

    Args:
        config_dict (dict): A dictionary containing the hyperparameters and other settings.

    Attributes:
        input_dim (int): The dimension of the input features.
        d_model (int): The dimension of the model (i.e., the size of the input embeddings).
        nhead (int): The number of heads in the multiheadattention models.
        dim_feedforward (int): The dimension of the feedforward network model.
        activation (str): The activation function of intermediate layer, relu or gelu.
        num_layers (int): The number of sub-encoder-layers in the encoder.
        fc_layers (list[int]): The number of neurons in the fully connected layers.
        output_dim (int): The dimension of the output features.
        max_seq_len (int): The maximum sequence length.
        val_split (float): The validation split ratio.
    """

    def __init__(self, config_dict):
        self.input_dim = config_dict.get("input_dim")
        self.d_model = config_dict.get("d_model")
        self.nhead = config_dict.get("nhead")
        self.dim_feedforward = config_dict.get("dim_feedforward")
        self.activation = config_dict.get("activation")
        self.num_layers = config_dict.get("num_layers")
        self.fc_layers = config_dict.get("fc_layers")
        self.output_dim = config_dict.get("output_dim")
        self.max_seq_len = config_dict.get("max_seq_len")
        self.val_split = config_dict.get("val_split")

config_dict = {
        "input_dim": 4,
        "d_model": 8,
        "nhead": 4,
        "dim_feedforward": 128,
        "activation": "relu",
        "num_layers": 6,
        "output_dim": 2,
        "max_seq_len": 3500,
        "val_split": 0.2,
    }
    # For fc layers divide the d_model by 2 until it reaches the output_dim
fc_layers = [config_dict["d_model"]]
while fc_layers[-1] > config_dict["output_dim"]:
  fc_layers.append(fc_layers[-1] // 2)

fc_layers[-1] = config_dict["output_dim"]
config_dict["fc_layers"] = fc_layers

config = Config(config_dict)

## Define Dataset Class

In [14]:
class GNSSDataset(torch.utils.data.Dataset):
    """
    This class represents a custom dataset for Global Navigation Satellite System (GNSS) data.
    It processes prediction and ground truth dataframes, pads sequences to a maximum length,
    and computes the residuals between predictions and ground truth positions.
    Args:
        pred_dfs (list of pandas.DataFrame): List of dataframes containing prediction data.
        gt_dfs (list of pandas.DataFrame): List of dataframes containing ground truth data.

    Attributes:
        pred_dfs (list of pandas.DataFrame): List of dataframes containing prediction data.
        gt_dfs (list of pandas.DataFrame): List of dataframes containing ground truth data.
        sequences (numpy.ndarray): Numpy array of padded prediction sequences.
        labels (numpy.ndarray): Numpy array of residuals (ground truth - prediction).
    """

    def __init__(self, pred_dfs, gt_dfs):
        """
        Initializes the GNSSDataset.
        Args:
            pred_dfs (list of pandas.DataFrame): List of dataframes containing prediction data.
            gt_dfs (list of pandas.DataFrame): List of dataframes containing ground truth data.
        """
        self.pred_dfs = pred_dfs
        self.gt_dfs = gt_dfs
        self.sequences = []
        self.labels = []

        for pred_df in self.pred_dfs:
            x_np = pred_df[
                [
                    "LatitudeDegrees",
                    "LongitudeDegrees",
                    "IonosphericDelayMeters",
                    "TroposphericDelayMeters",
                ]
            ].to_numpy()
            ## pad to max sequence length
            pad = np.zeros((config.max_seq_len - x_np.shape[0], x_np.shape[1]))
            x_np = np.vstack([x_np, pad])
            # x_np = x_np/180
            self.sequences.append(x_np)

        for gt_df in self.gt_dfs:
            y_np = gt_df[["LatitudeDegrees", "LongitudeDegrees"]].to_numpy()
            ## pad to max sequence length
            pad = np.zeros((config.max_seq_len - y_np.shape[0], y_np.shape[1]))
            y_np = np.vstack([y_np, pad])
            # y_np = y_np/180
            self.labels.append(y_np)

        self.sequences = np.array(self.sequences, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)

        self.labels = self.labels - self.sequences[:, :, :2]  # just the residuals

        print("seq and label shapes")
        print(self.sequences.shape)
        print(self.labels.shape)

    def __getitem__(self, i):
        """
        Retrieves the sequence and label at index i.
        Args:
            i (int): Index of the data to retrieve.
        Returns:
            tuple: (sequence, label) at index i.
        """
        return self.sequences[i], self.labels[i]

    def __len__(self):
        """
        Returns the number of sequences in the dataset.
        Returns:
            int: Number of sequences in the dataset.
        """
        return self.sequences.shape[0]

## Initialize Dataset

In [15]:
import sklearn
from sklearn.model_selection import train_test_split

In [16]:
print(len(pred_dfs)*config.val_split)

31.200000000000003


In [17]:
def is_converged(val_losses):
    """
    Check if the last 10 val losses have a standard deviation of less than.
    Args:
        val_losses (list): List of validation losses.
    Returns:
        bool: True if the validation loss has converged, False otherwise.
    """
    if len(val_losses) < 10:
        return False
    return np.std(val_losses[-10:]) < 1.0  # TODO: maybe needs readjusting

In [18]:
def save_results(
    save_path,
    model_type,
    num_params,
    training_loss,
    best_loss,
    val_loss,
    test_loss,
    inf_time,
    kaggle_score,
    kaggle_test_score,
    epochs,
):
    """
    Add the training results to a CSV file.
    Args:
        save_path (str): Path to save the CSV file.
        val_losses (list): List of validation losses.

    """
    row = [
        model_type,
        num_params,
        training_loss,
        best_loss,
        val_loss,
        test_loss,
        inf_time,
        kaggle_score,
        kaggle_test_score,
        epochs,
    ]
    with open(save_path, "a") as file:
        writer = csv.writer(file)
        writer.writerow(row)


In [56]:
def val_model(model, loader, loss_fn):
    mean_dist = 0
    mean_score = 0
    count = 0
    losses = []
    inf_time = 0.0
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)

    for features, labels in loader:
        #print(features.shape)
        #print(features[0, :20])
        features = features.to(device)
        labels = labels.to(device)

        starter.record()
        pred = model(features)
        ender.record()
        torch.cuda.synchronize()
        inf_time += starter.elapsed_time(ender) / 1000  # seconds

        loss = loss_fn(pred, labels)
        losses.append(float(loss.cpu()))

        features = features.detach().cpu()  # * 180
        pred = pred.detach().cpu()  # * 180
        labels = labels.detach().cpu()  # * 180
        #print(pred.shape)
        pred_lats = pred[:, :, 0] + features[:, :, 0]
        pred_lngs = pred[:, :, 1] + features[:, :, 1]
        gt_lats = labels[:, :, 0] + features[:, :, 0]
        gt_lns = labels[:, :, 1] + features[:, :, 1]

        # Calculate score according to kaggle, height not necessary for distance
        blh1 = BLH(np.deg2rad(pred_lats), np.deg2rad(pred_lngs), hgt = 0)
        blh2 = BLH(np.deg2rad(gt_lats), np.deg2rad(gt_lns), hgt = 0)

        mean_dist, mean_score=calc_score(blh1, blh2)
        features = features.cpu()
        labels = labels.cpu()
        count += 1
    return (
        np.array(losses).mean(),
        mean_dist / count,
        mean_score / count,
        inf_time / count,
    )


In [80]:
def train_model(
    model_type,
    train_loader,
    val_loader,
    test_loader,
    config,
    device,
    epochs,
    lr,
    es=False,
):
    n_eval = 10  # evaluate every n_eval epochs

    if es:
      early_stopping_threshold = 1e-6
      early_stopping_patience = 20
      early_stopping_counter = 0

    PATH = "model.pt"

    best_loss = 99999999  # high number

    if model_type == "LSTM":
        model = LSTMEncoder(config, num_trainable_params)
    elif model_type == "MLP":
        model = MLP(config, num_trainable_params)
    elif model_type == "Transformer":
        model = TransformerEncoder(config, num_trainable_params)

    model.to(device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()

    val_losses = []
    converged = False
    for epoch in range(epochs):
      print(f"Epoch {epoch + 1} of {epochs}")
      if es and early_stopping_counter >= early_stopping_patience:
        print("Early stopping triggered due to consecutive low val_loss")
        print(f"Stopped at epoch {epoch + 1}")
        break
      # Loop over each batch in the dataset
      for batch in tqdm(train_loader):
        optimizer.zero_grad()  # If not, the gradients would sum up over each iteration

        # Unpack the data and labels
        features, labels = batch
        features = features.to(device)
        labels = labels.to(device)

        # Forward propagate
        outputs = model(features)

        # Backpropagation and gradient descent
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        print(f"Loss/train {loss}")

        # Periodically evaluate our model + log to Tensorboard
        if epoch % n_eval == 0:
          model.eval()
          val_loss, mean_dist, mean_score, _ = val_model(
                        model, val_loader, loss_fn
                    )
          val_losses.append(val_loss)

          if es:
                if val_loss < early_stopping_threshold:
                    early_stopping_counter += 1
                else:
                    early_stopping_counter = 0

          if val_loss < best_loss:
            best_loss = val_loss

            torch.save(
                {
                  "epoch": epoch,
                  "model_state_dict": model.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "loss": val_loss,
                },
                PATH,
                  )

          print(f"Val mean dist {mean_dist}")
          print(f"Val mean score {mean_score}")
          print(f"Loss/val {val_loss}")

          converged = is_converged(val_losses)
          print(converged)
          # turn on training, evaluate turns off training
          model.train()

        if converged:
          break

        # Get test loss and inference time
    model.eval()
    test_loss, mean_dist, mean_test_score, inf_time = val_model(
            model, test_loader, loss_fn
        )

    save_results(
            SAVE_PATH,
            model.name,
            num_trainable_params,
            float(loss.cpu()),
            best_loss,
            val_loss,
            test_loss,
            inf_time,
            mean_score,
            mean_test_score,
            epoch,
        )

In [21]:
# train test split into training and validation
X_train, X_val, y_train, y_val = train_test_split(
        pred_dfs, gt_dfs, test_size=config.val_split, random_state=2
    )
    # split into validation and test
X_val, X_test, y_val, y_test = train_test_split(
        X_val, y_val, test_size=0.5, random_state=2
    )

train_dataset = GNSSDataset(X_train, y_train)
val_dataset = GNSSDataset(X_val, y_val)
test_dataset = GNSSDataset(X_test, y_test)

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, shuffle=True
    )
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)

seq and label shapes
(124, 3500, 4)
(124, 3500, 2)
seq and label shapes
(16, 3500, 4)
(16, 3500, 2)
seq and label shapes
(16, 3500, 4)
(16, 3500, 2)


In [78]:
num_trainable_params=3000
epochs = 1000  # number of epochs
lr = 0.01  # learning rate

In [81]:
train_model(
        "Transformer",
        train_loader,
        val_loader,
        test_loader,
        config,
        device,
        epochs,
        lr,
        es=True,
    )

Epoch 1 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.04971472918987274
Val mean dist 24853.00390625
Val mean score 33125.8525390625
Loss/val 0.028049560263752937
False
Loss/train 0.022323420271277428
Val mean dist 18552.40625
Val mean score 25652.015625
Loss/val 0.020032117143273354
False
Loss/train 0.015209498815238476
Val mean dist 13008.294921875
Val mean score 18796.255859375
Loss/val 0.01406748965382576
False
Loss/train 0.008780363947153091
Val mean dist 6887.3193359375
Val mean score 10789.692220330238
Loss/val 0.0049915979616343975
False
Epoch 2 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0025373948737978935
Loss/train 0.0007453588768839836
Loss/train 0.00020048351143486798
Loss/train 6.95360722602345e-05
Epoch 3 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9186065148678608e-05
Loss/train 8.627092938695569e-06
Loss/train 5.259465979179367e-06
Loss/train 3.821306108875433e-06
Epoch 4 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.971256986318622e-06
Loss/train 1.6454538354082615e-06
Loss/train 1.2925281680509215e-06
Loss/train 3.9745881963426655e-07
Epoch 5 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.8940896729400265e-07
Loss/train 3.9338459600912756e-07
Loss/train 9.09877780941315e-06
Loss/train 1.6226499610638712e-07
Epoch 6 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.6037148498071474e-07
Loss/train 8.90519550011959e-06
Loss/train 4.510519957534598e-08
Loss/train 3.083340800458245e-07
Epoch 7 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1316462433796914e-08
Loss/train 8.38086907606339e-06
Loss/train 4.022051598440157e-07
Loss/train 3.587634296309261e-07
Epoch 8 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.677888346624968e-07
Loss/train 8.657942089485005e-06
Loss/train 1.3662138087511266e-07
Loss/train 4.3298256646728817e-10
Epoch 9 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.635218869315395e-08
Loss/train 3.3346839245496085e-07
Loss/train 8.673270713188685e-06
Loss/train 5.4563848550515104e-08
Epoch 10 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.3796210153223e-08
Loss/train 5.445929787128989e-07
Loss/train 8.47095998324221e-06
Loss/train 5.471808961488023e-08
Epoch 11 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1841182729076536e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
False
Loss/train 8.691384209669195e-06
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.67046094804391e-07
False
Loss/train 9.60625357038225e-08
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Loss/train 4.413366028188648e-08
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Epoch 12 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.357541446457617e-06
Loss/train 1.2708571262010082e-07
Loss/train 3.1546056789011345e-07
Loss/train 5.791417265754717e-07
Epoch 13 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.630824595456943e-06
Loss/train 1.3550702249176538e-07
Loss/train 2.1481248779764428e-07
Loss/train 9.388925548137195e-08
Epoch 14 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0973521114010509e-07
Loss/train 8.606098162999842e-06
Loss/train 3.188525283803756e-07
Loss/train 4.550968935035371e-09
Epoch 15 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.121943619761481e-10
Loss/train 8.36082108435221e-06
Loss/train 1.5268614106389578e-07
Loss/train 6.132047474238789e-07
Epoch 16 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2407962180514005e-07
Loss/train 9.677203394176104e-08
Loss/train 4.519411789249972e-10
Loss/train 9.888693966786377e-06
Epoch 17 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.893774251333525e-07
Loss/train 5.8467474417511767e-08
Loss/train 8.594343853474129e-06
Loss/train 1.3092980566398182e-07
Epoch 18 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.354731107829139e-06
Loss/train 4.2032461067265103e-08
Loss/train 6.253395099520276e-07
Loss/train 3.027095729635221e-08
Epoch 19 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.0022835062482045e-07
Loss/train 3.582827616810391e-07
Loss/train 8.435649760940578e-06
Loss/train 1.0327988242408992e-08
Epoch 20 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.088839454155277e-07
Loss/train 3.1026598890093737e-07
Loss/train 3.346815447002882e-07
Loss/train 9.513589247944765e-06
Epoch 21 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.3258300946909e-06
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Loss/train 3.9878719348962477e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 22 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.442352808884607e-07
Epoch 23 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.612869351054542e-06
Epoch 24 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.033962941463187e-09
Epoch 25 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.831533206266613e-07
Epoch 26 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1172623948805267e-07
Epoch 27 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.128430930450122e-10
Epoch 28 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.37290735944407e-06
Epoch 29 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.322391295223497e-06
Epoch 30 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.955285414278478e-08
Epoch 31 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.579491804994177e-06
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 32 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2883174349553883e-07
Epoch 33 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.893510497870011e-07
Epoch 34 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.313195406917657e-07
Epoch 35 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.611210432718508e-06
Epoch 36 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2716335479344707e-07
Epoch 37 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.223188909691089e-08
Epoch 38 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5598258563140917e-08
Epoch 39 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.162305179354007e-07
Epoch 40 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.467663974442985e-06
Epoch 41 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.103057224507211e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 42 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4957398636615835e-07
Epoch 43 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.644378965272921e-10
Epoch 44 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.434027222392615e-06
Epoch 45 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.771952695683865e-09
Epoch 46 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5469088882346114e-07
Epoch 47 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.55646521813469e-06
Epoch 48 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.5801521158209653e-07
Epoch 49 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0817621642900122e-07
Epoch 50 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.572081242164131e-06
Epoch 51 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.105989740021414e-09
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 52 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.296126286997605e-08
Epoch 53 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.38133393191265e-09
Epoch 54 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.519752014990445e-08
Epoch 55 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.537088433513418e-06
Epoch 56 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2297364782607474e-07
Epoch 57 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.9032006899096814e-08
Epoch 58 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.5526744442177e-06
Epoch 59 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.653304576000664e-06
Epoch 60 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.387346497329418e-06
Epoch 61 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.148591420336743e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 62 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.150636788655902e-07
Epoch 63 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.895390593948832e-07
Epoch 64 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3685850553411e-07
Epoch 65 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.7429161447424235e-09
Epoch 66 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.419207695922523e-07
Epoch 67 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.7455713530754053e-10
Epoch 68 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1645291781169362e-08
Epoch 69 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.3589036857174506e-07
Epoch 70 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.68450752022909e-06
Epoch 71 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.69624454935547e-06
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
True
Epoch 72 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.479768140302042e-10
Epoch 73 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.345896963466657e-08
Epoch 74 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.947223490285978e-07
Epoch 75 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.24272244003987e-08
Epoch 76 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.36987146865431e-08
Epoch 77 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.53315285098688e-08
Epoch 78 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1055688509695756e-07
Epoch 79 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.159456175514606e-07
Epoch 80 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.53550409374293e-06
Epoch 81 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.159246918573501e-10
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 82 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.376587175007444e-06
Epoch 83 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.609938049630728e-06
Epoch 84 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1492320456673042e-07
Epoch 85 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.6575110584635695e-07
Epoch 86 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.286135345590083e-08
Epoch 87 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1103767150189015e-07
Epoch 88 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.694149073562585e-06
Epoch 89 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.396105840802193e-06
Epoch 90 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.094464228250217e-08
Epoch 91 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.24386356559603e-09
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 92 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.8919750977584044e-07
Epoch 93 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.250404283036914e-08
Epoch 94 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.461295878219971e-07
Epoch 95 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.589625394961331e-06
Epoch 96 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.50070052945739e-07
Epoch 97 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.219612381144543e-07
Epoch 98 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1446869763176437e-07
Epoch 99 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.941252716937015e-07
Epoch 100 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1623945656301657e-07
Epoch 101 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.383548447454814e-06
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670461516478099e-07
True
Epoch 102 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.335011443705298e-06
Epoch 103 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.32807228839738e-08
Epoch 104 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.576813343097456e-06
Epoch 105 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.644980660690635e-07
Epoch 106 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.758758778843912e-08
Epoch 107 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.546162284986167e-08
Epoch 108 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.9248355798140437e-08
Epoch 109 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.586354852013756e-06
Epoch 110 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9849317684238486e-08
Epoch 111 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.446458367115326e-10
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 112 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.47010323923314e-06
Epoch 113 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.442420039500576e-06
Epoch 114 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2235929602866236e-07
Epoch 115 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.386427907680627e-06
Epoch 116 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.0519776689507125e-07
Epoch 117 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.2931534949275374e-07
Epoch 118 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1532062532969576e-07
Epoch 119 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.908648752963927e-07
Epoch 120 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.351948135195016e-08
Epoch 121 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.410909409420128e-08
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
True
Epoch 122 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2713051944501785e-07
Epoch 123 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.228820709746174e-10
Epoch 124 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.64969570102403e-06
Epoch 125 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.2532205252475705e-09
Epoch 126 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.17880232312018e-07
Epoch 127 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.542213965421979e-07
Epoch 128 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.858947993900074e-07
Epoch 129 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.911282024342654e-07
Epoch 130 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.158138526364155e-09
Epoch 131 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.13327923606721e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670461516478099e-07
True
Epoch 132 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5188026004107087e-07
Epoch 133 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.601884474046528e-06
Epoch 134 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.891569863550103e-07
Epoch 135 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.102651646713639e-08
Epoch 136 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.33477133710403e-06
Epoch 137 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0626698809801383e-07
Epoch 138 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4550695937032287e-07
Epoch 139 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.8130901341683057e-07
Epoch 140 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5471135245425103e-07
Epoch 141 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1678965822502505e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670461516478099e-07
True
Epoch 142 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1485696777290286e-07
Epoch 143 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.839706318511162e-07
Epoch 144 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.686094588483684e-06
Epoch 145 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1947469974747946e-07
Epoch 146 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.621901604830782e-08
Epoch 147 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.321045243064873e-06
Epoch 148 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.686028195370454e-06
Epoch 149 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.815036961896112e-07
Epoch 150 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.4556444695208484e-08
Epoch 151 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0717005949345548e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 152 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1497036211949307e-07
Epoch 153 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.421670661893586e-08
Epoch 154 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5552867555234116e-07
Epoch 155 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0586536802748014e-07
Epoch 156 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4447211899314425e-07
Epoch 157 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.725680800125701e-08
Epoch 158 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5241279761066835e-07
Epoch 159 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.484142199449707e-06
Epoch 160 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.448626438919746e-07
Epoch 161 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.631943800679437e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 162 of 1000
Early stopping triggered due to consecutive low val_loss
Stopped at epoch 162


In [82]:
train_model(
        "LSTM",
        train_loader,
        val_loader,
        test_loader,
        config,
        device,
        epochs,
        lr,
        es=True,
    )

Epoch 1 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.010983339510858059
Val mean dist 12622.3681640625
Val mean score 13839.2705078125
Loss/val 0.007744121830910444
False
Loss/train 0.007743177004158497
Val mean dist 10826.5810546875
Val mean score 11867.6953125
Loss/val 0.005697975866496563
False
Loss/train 0.005696800071746111
Val mean dist 9832.3544921875
Val mean score 10778.166015625
Loss/val 0.004699450917541981
False
Loss/train 0.004708063788712025
Val mean dist 8849.61328125
Val mean score 9701.2041015625
Loss/val 0.0038071826566010714
False
Epoch 2 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0038065044209361076
Loss/train 0.003026395570486784
Loss/train 0.0023323085624724627
Loss/train 0.001746037625707686
Epoch 3 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0012561115436255932
Loss/train 0.0008583212620578706
Loss/train 0.0005544504383578897
Loss/train 0.0003140019252896309
Epoch 4 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0001524914987385273
Loss/train 5.370936924009584e-05
Loss/train 1.559174961585086e-05
Loss/train 4.1941535755540826e-07
Epoch 5 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.996585793534905e-08
Loss/train 6.290579790402262e-08
Loss/train 8.833166248223279e-06
Loss/train 7.149080971657895e-08
Epoch 6 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.13918854746953e-08
Loss/train 9.609973261603955e-08
Loss/train 8.686257388035301e-06
Loss/train 2.455347498653282e-07
Epoch 7 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3353224182519625e-08
Loss/train 1.0758090240869933e-07
Loss/train 8.39398035168415e-06
Loss/train 5.756317591476545e-07
Epoch 8 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5769698314425113e-10
Loss/train 8.684522072144318e-06
Loss/train 2.756447088358982e-07
Loss/train 8.922097549657337e-08
Epoch 9 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.655249543044192e-07
Loss/train 4.223505811751238e-08
Loss/train 8.557864020986017e-06
Loss/train 8.33928410770568e-08
Epoch 10 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.547464858565945e-06
Loss/train 2.15132622827241e-08
Loss/train 3.2934582350208075e-07
Loss/train 1.6030691085688886e-07
Epoch 11 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.522214673372218e-08
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Loss/train 8.397128112846985e-06
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
False
Loss/train 3.104005372733809e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
False
Loss/train 2.923895294770773e-07
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Epoch 12 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.854732089081381e-10
Loss/train 8.36207800603006e-06
Loss/train 2.740289914981986e-07
Loss/train 4.5954320171404106e-07
Epoch 13 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.166463985986411e-08
Loss/train 8.864215487847105e-06
Loss/train 1.3781906282872569e-08
Loss/train 1.587772402444898e-07
Epoch 14 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.87992609932553e-06
Loss/train 3.5939279330321483e-10
Loss/train 9.193649397332138e-09
Loss/train 1.7041564603914594e-07
Epoch 15 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.268821930329068e-08
Loss/train 8.333320693054702e-06
Loss/train 5.662135436068638e-07
Loss/train 1.3299455758897238e-07
Epoch 16 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5686733212969557e-07
Loss/train 5.860558149883843e-10
Loss/train 8.650944437249564e-06
Loss/train 3.45077708630015e-08
Epoch 17 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.5304846101571457e-07
Loss/train 2.982591240652255e-07
Loss/train 8.322246685565915e-06
Loss/train 3.028995081422181e-07
Epoch 18 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1619257495331112e-07
Loss/train 8.329941010742914e-06
Loss/train 2.4791580344185604e-08
Loss/train 5.344762712411466e-07
Epoch 19 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.031265252839148e-08
Loss/train 6.123718776507303e-08
Loss/train 8.329634511028416e-06
Loss/train 6.484676191576e-07
Epoch 20 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.75042678671889e-06
Loss/train 2.4818814381433185e-07
Loss/train 2.956155853439668e-08
Loss/train 1.1903398267065768e-08
Epoch 21 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2364709195699106e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
False
Loss/train 2.1446078690701142e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 22 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.357377309498588e-08
Epoch 23 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.894719712351559e-10
Epoch 24 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.275936463382095e-09
Epoch 25 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1871087458057445e-07
Epoch 26 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.290942345302028e-07
Epoch 27 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.905664757690829e-07
Epoch 28 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2369178509507037e-07
Epoch 29 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.665417954654913e-07
Epoch 30 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.650898962514475e-06
Epoch 31 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.028639086887779e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
True
Epoch 32 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0584852105921527e-07
Epoch 33 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4417101940343855e-07
Epoch 34 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.756650019364315e-07
Epoch 35 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.5959194611432395e-08
Epoch 36 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.576444997743238e-06
Epoch 37 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.322357643919531e-06
Epoch 38 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.632239769212902e-06
Epoch 39 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.92745542146622e-08
Epoch 40 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.5077969806043257e-07
Epoch 41 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.495666535651253e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 42 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5466070496804605e-07
Epoch 43 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5578555096217315e-07
Epoch 44 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.332073556019168e-07
Epoch 45 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1864263405623205e-07
Epoch 46 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.249840918011614e-07
Epoch 47 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.874538082181971e-08
Epoch 48 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.631064702058211e-06
Epoch 49 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.32831537675338e-08
Epoch 50 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3793477882682055e-07
Epoch 51 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.8964058174096863e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 52 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.088484786279878e-07
Epoch 53 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.5421039734064834e-07
Epoch 54 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.946341812863466e-08
Epoch 55 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.036558878960932e-08
Epoch 56 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.990631445409235e-08
Epoch 57 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4453558467030234e-07
Epoch 58 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.236398254420237e-08
Epoch 59 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.568656994611956e-06
Epoch 60 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.735081133271478e-09
Epoch 61 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.728476940523251e-07
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 62 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.249200398793619e-07
Epoch 63 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.762590735817867e-08
Epoch 64 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.646560672787018e-06
Epoch 65 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.838800980953238e-07
Epoch 66 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.645614630291675e-08
Epoch 67 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4072554677113658e-07
Epoch 68 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.016924368803302e-08
Epoch 69 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.471838093740189e-08
Epoch 70 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.365540107187371e-08
Epoch 71 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.164213898618982e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 72 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.125858137944306e-07
Epoch 73 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.287125131620996e-08
Epoch 74 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.35490026309526e-07
Epoch 75 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.71386600920232e-06
Epoch 76 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.6871931629580672e-09
Epoch 77 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.320991582877468e-06
Epoch 78 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.339266640978167e-07
Epoch 79 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.009971684761695e-07
Epoch 80 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.431829836448742e-07
Epoch 81 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.136241050720855e-08
Val mean dist 1.8092882633209229
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 82 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.651823918626178e-06
Epoch 83 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.010505444464798e-07
Epoch 84 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.388192327402066e-06
Epoch 85 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9756206387787643e-08
Epoch 86 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.360363608517218e-06
Epoch 87 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.9027561748430344e-09
Epoch 88 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.198239503405148e-09
Epoch 89 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.430261914327275e-06
Epoch 90 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.40592292661313e-06
Epoch 91 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.371286988520296e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 92 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.374492768491109e-07
Epoch 93 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.0421013264003705e-08
Epoch 94 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.696349141246174e-06
Epoch 95 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4949576154776878e-07
Epoch 96 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.536740097042639e-06
Epoch 97 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1458677562359298e-07
Epoch 98 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.355014870176092e-06
Epoch 99 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1577027925777656e-07
Epoch 100 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.53548408485949e-06
Epoch 101 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.939148657787882e-07
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 102 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.369448551093228e-06
Epoch 103 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.8937341767232283e-07
Epoch 104 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.023416266818458e-08
Epoch 105 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.321065251948312e-06
Epoch 106 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.345602509507444e-06
Epoch 107 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.0433923825512466e-07
Epoch 108 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.670824172440916e-06
Epoch 109 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.689676178619266e-06
Epoch 110 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.564736162952613e-06
Epoch 111 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.1605421330359604e-08
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670463221780665e-07
True
Epoch 112 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.9484334024564305e-07
Epoch 113 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.747069640297923e-08
Epoch 114 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.14396615408441e-08
Epoch 115 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.7515994815985323e-07
Epoch 116 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.414947044206201e-09
Epoch 117 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.102853085579227e-08
Epoch 118 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.445670573564712e-06
Epoch 119 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.226757598895347e-07
Epoch 120 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4812027277221205e-07
Epoch 121 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.765892627910944e-08
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462084912288e-07
True
Epoch 122 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.1809491152223757e-10
Epoch 123 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.589086974097881e-06
Epoch 124 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.613020327175036e-06
Epoch 125 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9975086590638966e-07
Epoch 126 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.354164492629934e-06
Epoch 127 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.67082053446211e-06
Epoch 128 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.72605960466899e-06
Epoch 129 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.231309785358462e-08
Epoch 130 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2391545801147004e-07
Epoch 131 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.596336556365713e-06
Val mean dist 1.8092879056930542
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 132 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.320963388541713e-06
Epoch 133 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.169368470589689e-08
Epoch 134 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.34609457708757e-08
Epoch 135 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.091715683467555e-08
Epoch 136 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.940228114312049e-07
Epoch 137 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0733317168387657e-07
Epoch 138 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9281650881785026e-07
Epoch 139 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.967852630320067e-08
Epoch 140 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.3072586568796396e-07
Epoch 141 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.8776520366591285e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670464358649042e-07
True
Epoch 142 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.122153450860424e-08
Epoch 143 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.060498881575114e-10
Epoch 144 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.770264183662221e-10
Epoch 145 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.321051609527785e-06
Epoch 146 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.36210892884992e-06
Epoch 147 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.6495539378880153e-09
Epoch 148 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.814223159537278e-09
Epoch 149 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.250541903478734e-07
Epoch 150 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.649150371818564e-09
Epoch 151 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4420506861133617e-07
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 152 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.893360715461313e-07
Epoch 153 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.373612263932955e-07
Epoch 154 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1823520057751011e-07
Epoch 155 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.536608220310882e-06
Epoch 156 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.6228456679054943e-07
Epoch 157 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2258728310807783e-07
Epoch 158 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.8654297895845957e-07
Epoch 159 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.3414395129181e-10
Epoch 160 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.248693923931569e-07
Epoch 161 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.354407327715307e-06
Val mean dist 1.8092881441116333
Val mean score 2.4904342293739314
Loss/val 6.670462653346476e-07
True
Epoch 162 of 1000
Early stopping triggered due to consecutive low val_loss
Stopped at epoch 162


In [99]:
train_model(
        "MLP",
        train_loader,
        val_loader,
        test_loader,
        config,
        device,
        epochs,
        lr,
        es=True,
    )

Epoch 1 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.051767747849226
Val mean dist 28072.21484375
Val mean score 30764.533203125
Loss/val 0.03829119727015495
False
Loss/train 0.03829053416848183
Val mean dist 23923.138671875
Val mean score 26211.12890625
Loss/val 0.02772573009133339
False
Loss/train 0.027721110731363297
Val mean dist 19763.18359375
Val mean score 21680.4296875
Loss/val 0.01884654350578785
False
Loss/train 0.018829572945833206
Val mean dist 15447.9287109375
Val mean score 17057.5234375
Loss/val 0.011438899673521519
False
Epoch 2 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.011454313062131405
Loss/train 0.005659590940922499
Loss/train 0.0018419650150462985
Loss/train 0.0003636281471699476
Epoch 3 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0010663756402209401
Loss/train 0.002594118705019355
Loss/train 0.0037820565048605204
Loss/train 0.00420510396361351
Epoch 4 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.003867689287289977
Loss/train 0.0030723935924470425
Loss/train 0.002193158259615302
Loss/train 0.0013303119922056794
Epoch 5 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0006490186788141727
Loss/train 0.00021185362129472196
Loss/train 3.103512426605448e-05
Loss/train 5.601807788480073e-05
Epoch 6 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0001799445308279246
Loss/train 0.0003686489653773606
Loss/train 0.0005313875153660774
Loss/train 0.0006525064236484468
Epoch 7 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0007041314383968711
Loss/train 0.000683439546264708
Loss/train 0.0006021700100973248
Loss/train 0.0004895385354757309
Epoch 8 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.0003414746606722474
Loss/train 0.00020891227177344263
Loss/train 0.00011036600335501134
Loss/train 3.268801447120495e-05
Epoch 9 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.135147603141377e-06
Loss/train 1.5006118701421656e-05
Loss/train 6.030965596437454e-05
Loss/train 9.829730697674677e-05
Epoch 10 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 0.00013692821084987372
Loss/train 0.0001656869426369667
Loss/train 0.000155668705701828
Loss/train 0.0001316890266025439
Epoch 11 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.468592907069251e-05
Val mean dist 1072.2684326171875
Val mean score 1169.328857421875
Loss/val 5.584332029684447e-05
False
Loss/train 6.37133271084167e-05
Val mean dist 702.4847412109375
Val mean score 763.5003051757812
Loss/val 2.3828488338040188e-05
False
Loss/train 2.3221788069349714e-05
Val mean dist 365.319580078125
Val mean score 378.1036071777344
Loss/val 6.349996056087548e-06
False
Loss/train 6.265186129894573e-06
Val mean dist 330.24163818359375
Val mean score 335.26309356689455
Loss/val 5.443122518045129e-06
False
Epoch 12 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.148591299075633e-06
Loss/train 2.3964374122442678e-05
Loss/train 2.8853277399321087e-05
Loss/train 3.754700446734205e-05
Epoch 13 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.8089928895933554e-05
Loss/train 3.86157953471411e-05
Loss/train 3.3374617487424985e-05
Loss/train 2.6726833311840892e-05
Epoch 14 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.805696047085803e-05
Loss/train 9.723436960484833e-06
Loss/train 1.2141448678448796e-05
Loss/train 1.4754610901945853e-06
Epoch 15 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2027735414885683e-06
Loss/train 7.220769475679845e-06
Loss/train 2.0650195438065566e-05
Loss/train 1.4754004041606095e-05
Epoch 16 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4428136637434363e-05
Loss/train 1.1501108019729145e-05
Loss/train 7.449486474797595e-06
Loss/train 1.3358448995859362e-05
Epoch 17 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.581990773047437e-06
Loss/train 1.6624746876914287e-06
Loss/train 1.0138130164705217e-05
Loss/train 2.785329797916347e-06
Epoch 18 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.6667670428869314e-06
Loss/train 3.822266990027856e-06
Loss/train 4.078727215528488e-06
Loss/train 1.3198105989431497e-05
Epoch 19 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3616970540606417e-06
Loss/train 1.1477638508949894e-05
Loss/train 1.8402394061922678e-06
Loss/train 9.427155305274937e-07
Epoch 20 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.603044653274992e-07
Loss/train 2.782354044938984e-07
Loss/train 8.815927685645875e-06
Loss/train 1.3460801255860133e-06
Epoch 21 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.928761755872983e-06
Val mean dist 187.1905059814453
Val mean score 206.76075744628906
Loss/val 2.3352758944383822e-06
False
Loss/train 1.6709113879187498e-06
Val mean dist 173.68789672851562
Val mean score 192.13497924804688
Loss/val 2.1046644178568386e-06
True
Epoch 22 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.7669176486379001e-06
Epoch 23 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.33314731810242e-06
Epoch 24 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.880707355274353e-06
Epoch 25 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.179881306574316e-07
Epoch 26 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.956471901910845e-06
Epoch 27 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.020155526835879e-07
Epoch 28 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.433539745856251e-07
Epoch 29 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.095107897228445e-07
Epoch 30 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.757051202541334e-07
Epoch 31 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.397865265877044e-07
Val mean dist 95.375732421875
Val mean score 100.33055877685547
Loss/val 1.0947301234409679e-06
True
Epoch 32 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0359678981330944e-06
Epoch 33 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.992631592263933e-06
Epoch 34 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.877180334820878e-06
Epoch 35 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.489796528010629e-06
Epoch 36 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.839525675175537e-07
Epoch 37 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.489869974757312e-08
Epoch 38 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5561448069311155e-07
Epoch 39 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3378995567545644e-07
Epoch 40 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.267218969289388e-07
Epoch 41 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2908236019247852e-07
Val mean dist 66.28312683105469
Val mean score 68.43365478515625
Loss/val 8.526926649210509e-07
True
Epoch 42 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.213302074480453e-07
Epoch 43 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.195799027049361e-07
Epoch 44 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.492800589010585e-06
Epoch 45 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.121995746369066e-07
Epoch 46 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1316463854882386e-07
Epoch 47 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.419054211117327e-06
Epoch 48 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.984678815209918e-07
Epoch 49 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.574076394372241e-08
Epoch 50 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.56579647586841e-07
Epoch 51 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.476418770442251e-06
Val mean dist 33.429222106933594
Val mean score 34.249240970611574
Loss/val 7.217774395940069e-07
True
Epoch 52 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.3547575861139194e-07
Epoch 53 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.3364366320020054e-07
Epoch 54 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.609305950812995e-06
Epoch 55 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9465599027389544e-07
Epoch 56 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.859733497956768e-06
Epoch 57 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.56171845953213e-06
Epoch 58 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1639951225106415e-07
Epoch 59 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.8897968579476583e-07
Epoch 60 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.573442755732685e-06
Epoch 61 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.42807912704302e-06
Val mean dist 20.676433563232422
Val mean score 24.512914657592773
Loss/val 6.85901056840521e-07
True
Epoch 62 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.7550396453079884e-07
Epoch 63 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4826369465481548e-07
Epoch 64 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.5147051019303035e-07
Epoch 65 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.41834237360672e-07
Epoch 66 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4048090690721438e-08
Epoch 67 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.597419666017231e-07
Epoch 68 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.11196631691746e-08
Epoch 69 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.692491974215955e-06
Epoch 70 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.608915216726018e-07
Epoch 71 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.129524538460828e-07
Val mean dist 12.717111587524414
Val mean score 13.251779103279109
Loss/val 6.740189064657898e-07
True
Epoch 72 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.483590032577922e-08
Epoch 73 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.643777618999593e-06
Epoch 74 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.323881957039703e-06
Epoch 75 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.604858799301837e-08
Epoch 76 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.374172466574237e-06
Epoch 77 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1848847236706206e-07
Epoch 78 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.4667132808863244e-07
Epoch 79 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.932708794873179e-08
Epoch 80 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.6032284833045196e-08
Epoch 81 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.6056428004703776e-08
Val mean dist 6.197725296020508
Val mean score 6.851119756698608
Loss/val 6.681383410978015e-07
True
Epoch 82 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.916573009770218e-07
Epoch 83 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.367121668266918e-08
Epoch 84 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.32491696201032e-06
Epoch 85 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.539662303519435e-06
Epoch 86 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.685592547408305e-06
Epoch 87 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0333877753510023e-07
Epoch 88 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.605868373479097e-08
Epoch 89 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3639389052098068e-09
Epoch 90 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0601211464233984e-08
Epoch 91 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3029942869688966e-07
Val mean dist 4.2697930335998535
Val mean score 5.370612823963166
Loss/val 6.68162215333723e-07
True
Epoch 92 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.168379940163504e-08
Epoch 93 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.662876638254602e-08
Epoch 94 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.576168753876118e-07
Epoch 95 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.6868993668595067e-07
Epoch 96 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.385060937143862e-06
Epoch 97 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.227629861157766e-07
Epoch 98 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.606049633868679e-07
Epoch 99 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.132797132820997e-08
Epoch 100 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2137955940925167e-07
Epoch 101 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.325063390657306e-06
Val mean dist 4.015377044677734
Val mean score 4.8118625521659855
Loss/val 6.676647217318532e-07
True
Epoch 102 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.130450742602989e-07
Epoch 103 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.937001266280276e-08
Epoch 104 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.8704988014615083e-07
Epoch 105 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.6972399869151786e-09
Epoch 106 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.569033755316923e-08
Epoch 107 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.966462429616513e-07
Epoch 108 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.339473154615689e-08
Epoch 109 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.363562301383354e-06
Epoch 110 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.866882422215895e-09
Epoch 111 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.586308467783965e-06
Val mean dist 4.998304843902588
Val mean score 5.5629496574401855
Loss/val 6.67818369493034e-07
True
Epoch 112 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.3587141129155498e-09
Epoch 113 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.4258953007793025e-08
Epoch 114 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9011789592914283e-07
Epoch 115 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.363484994333703e-06
Epoch 116 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.016221476556893e-08
Epoch 117 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.359822459169663e-06
Epoch 118 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.5065546849798466e-09
Epoch 119 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.185290857274595e-08
Epoch 120 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.44738315208815e-06
Epoch 121 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.326326678798068e-06
Val mean dist 4.921855926513672
Val mean score 5.397918367385865
Loss/val 6.678885142719082e-07
True
Epoch 122 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2542255351254425e-07
Epoch 123 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.75547732700943e-08
Epoch 124 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9120442945895775e-07
Epoch 125 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.3458362496458e-06
Epoch 126 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.89165748515552e-08
Epoch 127 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.399410933800027e-07
Epoch 128 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.72778855409706e-06
Epoch 129 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0710098052868489e-07
Epoch 130 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.038707394713128e-08
Epoch 131 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5899483918910846e-07
Val mean dist 4.358687400817871
Val mean score 4.88235856294632
Loss/val 6.675765575892001e-07
True
Epoch 132 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.398929576154842e-08
Epoch 133 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1907328456572941e-07
Epoch 134 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.0616192724910434e-07
Epoch 135 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.503261573698182e-08
Epoch 136 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1832102586595283e-07
Epoch 137 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.3272875182556163e-08
Epoch 138 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.188441132171647e-07
Epoch 139 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.692091796547174e-06
Epoch 140 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.4959087479837763e-07
Epoch 141 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.660667845106218e-06
Val mean dist 4.4935994148254395
Val mean score 4.969092607498169
Loss/val 6.676730777144257e-07
True
Epoch 142 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.445935236522928e-06
Epoch 143 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.358034392585978e-06
Epoch 144 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.123476233213296e-07
Epoch 145 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.598935679998249e-07
Epoch 146 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.631368473288603e-06
Epoch 147 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.879005690687336e-06
Epoch 148 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.900480353673629e-07
Epoch 149 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.723784278823587e-08
Epoch 150 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2060626153906924e-07
Epoch 151 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.1871676381124416e-07
Val mean dist 3.9509403705596924
Val mean score 4.594687056541443
Loss/val 6.674947030660405e-07
True
Epoch 152 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.261384844994609e-08
Epoch 153 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1634138274985162e-07
Epoch 154 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.519188591065813e-08
Epoch 155 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.910647864013299e-08
Epoch 156 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1012342326921498e-07
Epoch 157 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.321370842168108e-06
Epoch 158 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.666038411320187e-06
Epoch 159 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.487807896997765e-08
Epoch 160 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.351970791409258e-06
Epoch 161 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1492175389093973e-07
Val mean dist 4.294656753540039
Val mean score 4.845853590965271
Loss/val 6.676417569906334e-07
True
Epoch 162 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.351734322786797e-06
Epoch 163 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.714997032806423e-07
Epoch 164 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1906403329930981e-07
Epoch 165 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.611746125097852e-06
Epoch 166 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.270407316653291e-08
Epoch 167 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.739667738983826e-09
Epoch 168 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.5268904241547716e-08
Epoch 169 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.346140020876192e-06
Epoch 170 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.3155333767354023e-07
Epoch 171 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.15312032437032e-08
Val mean dist 4.11307430267334
Val mean score 4.606063175201416
Loss/val 6.674792984995292e-07
True
Epoch 172 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.370603162191401e-07
Epoch 173 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.4398939874336065e-07
Epoch 174 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.677491081243716e-08
Epoch 175 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.458260708721355e-06
Epoch 176 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.078133883922419e-07
Epoch 177 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.326210263476241e-06
Epoch 178 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2406085520287888e-07
Epoch 179 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.355077625310514e-06
Epoch 180 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.6947901094208646e-07
Epoch 181 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2083714412701738e-08
Val mean dist 3.9587504863739014
Val mean score 4.564544320106506
Loss/val 6.675204531347845e-07
True
Epoch 182 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.333363439305685e-06
Epoch 183 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.410516562027624e-07
Epoch 184 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.039825623498473e-07
Epoch 185 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.3948158539278666e-07
Epoch 186 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.650107702123933e-06
Epoch 187 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.324846930918284e-06
Epoch 188 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.8986485745008395e-07
Epoch 189 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.6614670761991874e-07
Epoch 190 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.967822147567858e-08
Epoch 191 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.3957665007401374e-07
Val mean dist 3.6985127925872803
Val mean score 4.363161325454708
Loss/val 6.674166570519446e-07
True
Epoch 192 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 6.013748929944995e-07
Epoch 193 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.311270641541796e-08
Epoch 194 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.0863455407749143e-09
Epoch 195 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.458467164018657e-06
Epoch 196 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 7.451251349266386e-08
Epoch 197 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.160911532518185e-08
Epoch 198 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.4624685579510697e-07
Epoch 199 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.460644274535298e-07
Epoch 200 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 9.795397204470646e-08
Epoch 201 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.0867477341398626e-08
Val mean dist 3.5345330238342285
Val mean score 4.222840368747711
Loss/val 6.673864163531107e-07
True
Epoch 202 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.650016752653755e-06
Epoch 203 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.670898751006462e-06
Epoch 204 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.925440583112504e-07
Epoch 205 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.2748176320419589e-07
Epoch 206 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1173906955264101e-07
Epoch 207 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.684228305355646e-06
Epoch 208 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.485379013705824e-07
Epoch 209 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.363016449180577e-09
Epoch 210 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.2017007356444083e-07
Epoch 211 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.549084570091509e-07
Val mean dist 3.366563081741333
Val mean score 4.135035371780395
Loss/val 6.674005135209882e-07
True
Epoch 212 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.79455343590962e-07
Epoch 213 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.409871043113526e-06
Epoch 214 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.451772373518907e-06
Epoch 215 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.3822587839677e-06
Epoch 216 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.430007255810779e-06
Epoch 217 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1200675231748392e-07
Epoch 218 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.942219409391328e-08
Epoch 219 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.98683271846312e-08
Epoch 220 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 8.773857196331392e-10
Epoch 221 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 5.041576969233574e-07
Val mean dist 3.748234987258911
Val mean score 4.735935688018799
Loss/val 6.676665407212568e-07
True
Epoch 222 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.389822443798039e-07
Epoch 223 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.495783928997298e-08
Epoch 224 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.2663235793297645e-07
Epoch 225 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.898849231769418e-07
Epoch 226 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.94799917810451e-07
Epoch 227 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.016607053041298e-08
Epoch 228 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 4.0595136141519106e-08
Epoch 229 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 1.1135258404237902e-07
Epoch 230 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 3.228586535897193e-07
Epoch 231 of 1000


  0%|          | 0/4 [00:00<?, ?it/s]

Loss/train 2.9833921644240036e-07
Val mean dist 3.2063496112823486
Val mean score 4.136018854379653
Loss/val 6.674716814814019e-07
True
Epoch 232 of 1000
Early stopping triggered due to consecutive low val_loss
Stopped at epoch 232
