In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from datetime import datetime

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

In [None]:
from vad.architectures import STAD
from vad.datasets import TrajectoryDataset, ExactBatchSampler

In [4]:
torch.set_num_threads(8)  # cap PyTorch intra-op CPU threads

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


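# Experiment Parameters

> **Note:** the saved outputs in this notebook come from an earlier run whose settings differ from the cell below (`epochs = 250`, `learning_rate = 1e-5`). Examples are the printed `..._epochs_150_..._lr_3e-05_...` experiment name and the `Epoch x/150` progress lines. Restart the kernel and run all cells to regenerate consistent outputs.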

In [5]:
# --- data / model configuration ---
exp_type = 'baseweather'      # prefix of experiment_name
include_weather = True        # forwarded to both TrajectoryDataset instances
n_weather_vars = 5
embed_dim = 32
dropout = 0.1
n_head_te = 8                 # transformer-encoder attention heads
n_layers_te = 4               # transformer-encoder layers
latent_dim_ae = 32
hidden_dim_gmm = 32
n_components = 30             # number of GMM mixture components
eps_gmm = 1e-7

# --- optimisation configuration ---
epochs = 250
learning_rate = 1e-5          # used as OneCycleLR's max_lr (peak learning rate)
weight_decay = 0.1
patience = epochs             # equal to epochs, so early stopping effectively never fires
eps_loss = 1                  # default epsilon of calculate_gmm_penalty

# Run identifier used for the TensorBoard log dir and the checkpoint directory.
experiment_name = '_'.join([
    f'{exp_type}',
    f'epochs_{epochs}',
    f'pat_{patience}',
    f'embed_{embed_dim}',
    f'wd_{weight_decay}',
    f'lr_{learning_rate}',
    f'hgmm_{hidden_dim_gmm}',
    f'lae_{latent_dim_ae}',
    f'comp_{n_components}',
])
print(experiment_name)

baseweather_epochs_150_pat_150_embed_32_wd_0.1_lr_3e-05_hgmm_32_lae_32_comp_30


# STAD Instantiation

In [6]:
# The n_*_bins vocabulary sizes are expected to match the lat/lon/sog/cog bin
# counts the TrajectoryDataset instances are built with further down.
stad_kwargs = dict(
    n_lat_bins=400,
    n_lon_bins=400,
    n_sog_bins=30,
    n_cog_bins=72,
    max_seq_len=10,
    embed_dim=embed_dim,
    dropout=dropout,
    nhead_te=n_head_te,
    n_layers_te=n_layers_te,
    latent_dim_ae=latent_dim_ae,
    n_weather_vars=n_weather_vars,
    hidden_dim_gmm=hidden_dim_gmm,
    eps_gmm=eps_gmm,
    n_components_gmm=n_components,
)
stad = STAD(**stad_kwargs).to(device)
print(stad)

STAD(
  (embedding): TrajectoryEmbedding(
    (lat_embed): Embedding(400, 32)
    (lon_embed): Embedding(400, 32)
    (sog_embed): Embedding(30, 32)
    (cog_embed): Embedding(72, 32)
  )
  (transenc): TrajectoryTransformerEncoder(
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
         

# STAD Loss Function

In [7]:
def calculate_gmm_penalty(sigma, epsilon=eps_loss):
    """
    Sum of the reciprocals of every diagonal entry of the GMM covariances.

    This is the covariance penalty that ``compute_full_loss`` weights with
    ``lambda_3``. It grows as any diagonal variance shrinks towards zero.

    Args:
        sigma: covariance tensor whose last two dims hold each matrix, e.g.
            [num_components, input_dim, input_dim]. Any leading dims are
            summed over as well.
        epsilon: added to every diagonal entry before taking the reciprocal.
            The default is the notebook-level ``eps_loss``, captured when this
            function is defined (currently 1). That is not a tiny stabiliser:
            for non-negative diagonals it caps each term at 1 / epsilon.

    Returns:
        0-dim tensor holding the penalty.
    """
    # [..., input_dim, input_dim] -> [..., input_dim]
    variances = sigma.diagonal(dim1=-2, dim2=-1)
    return torch.reciprocal(variances + epsilon).sum()

def compute_full_loss(penalty, transformer_loss, energy, d,
                      lambda_1=1, lambda_2=1, lambda_3=5e-3):
    """
    Combine the STAD loss terms into a single weighted objective.

        loss = transformer_loss + lambda_1 * energy + lambda_2 * d + lambda_3 * penalty

    The defaults (λ₁=1, λ₂=1, λ₃=0.005) follow the STAD publication.

    Args:
        penalty: GMM covariance penalty (output of ``calculate_gmm_penalty``).
        transformer_loss: transformer-encoder loss term (``l`` in the loops).
        energy: GMM energy term returned by the model.
        d: the ``d_h`` term returned by the model (presumably a distance /
            reconstruction term; confirm against the STAD implementation).
        lambda_1, lambda_2, lambda_3: weights of energy, d and penalty.

    Returns:
        The weighted sum, with the same type as the inputs (tensor or scalar).
    """
    weighted_energy = lambda_1 * energy
    weighted_d = lambda_2 * d
    weighted_penalty = lambda_3 * penalty
    # Same left-to-right summation order as the reference formula.
    return transformer_loss + weighted_energy + weighted_d + weighted_penalty

# STAD Unbiased Training Loss (Eval-Mode Re-evaluation of the Training Set)

In [8]:
def evaluate_training_set(model, train_dataloader, device):
    """
    Re-evaluate the training set with the model in eval mode.

    Unlike the running loss accumulated during the optimisation pass, this loss
    is computed with ``model.eval()`` (dropout disabled) and under
    ``torch.no_grad()``. It uses the same loss composition as training.

    Args:
        model: STAD model, called as ``model(inputs, targets, weather_stats)``
            and returning ``(l, energy, d_h, sigma)``.
        train_dataloader: yields dicts with ``'src_window'`` and ``'tgt_window'``
            (dicts of tensors) and optionally a ``'weather_stats'`` tensor.
        device: device the batch tensors are moved to.

    Returns:
        float: mean of the per-batch losses.

    Note:
        Leaves ``model`` in eval mode. The training loop switches back with
        ``model.train()`` at the start of each epoch.
    """
    model.eval()
    total_train_loss = 0

    with torch.no_grad():
        for batch in tqdm(train_dataloader, desc="Evaluating Training Set"):
            # Move data to device
            inputs = {k: v.to(device) for k, v in batch.get('src_window').items()}
            targets = {k: v.to(device) for k, v in batch.get('tgt_window').items()}
            # 'weather_stats' is optional; only move it when present, since
            # calling .to() on the None default would raise AttributeError.
            weather_stats = batch.get('weather_stats', None)
            if weather_stats is not None:
                weather_stats = weather_stats.to(device)

            # Forward pass
            l, energy, d_h, sigma = model(inputs, targets, weather_stats)
            l, energy, d_h = l.mean(), energy.mean(), d_h.mean()

            # Loss components, combined exactly as during training
            penalty = calculate_gmm_penalty(sigma).mean()
            train_eval_loss = compute_full_loss(penalty, l, energy, d_h).mean()

            total_train_loss += train_eval_loss.item()

    return total_train_loss / len(train_dataloader)

# STAD Validation Loop

In [9]:
def validate(model, dataloader, device):
    """
    Evaluate ``model`` on ``dataloader`` in eval mode, without tracking gradients.

    Args:
        model: STAD model, called as ``model(inputs, targets, weather_stats)``
            and returning ``(l, energy, d_h, sigma)``.
        dataloader: yields dicts with ``'src_window'`` and ``'tgt_window'``
            (dicts of tensors) and optionally a ``'weather_stats'`` tensor.
        device: device the batch tensors are moved to.

    Returns:
        tuple(float, float, float): per-batch means of
        (full STAD loss, GMM energy, transformer-encoder loss ``l``).

    Note:
        Leaves ``model`` in eval mode.
    """
    total_val_loss = 0
    total_energy = 0
    total_te_loss = 0

    model.eval()

    # Validation needs no gradients. Without no_grad, every batch would build
    # (and then discard) a full autograd graph, wasting memory and time.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            # Move data to device
            inputs = {k: v.to(device) for k, v in batch.get('src_window').items()}
            targets = {k: v.to(device) for k, v in batch.get('tgt_window').items()}
            # 'weather_stats' is optional; only move it when present.
            weather_stats = batch.get('weather_stats', None)
            if weather_stats is not None:
                weather_stats = weather_stats.to(device)

            # Forward pass
            l, energy, d_h, sigma = model(inputs, targets, weather_stats)
            l, energy, d_h = l.mean(), energy.mean(), d_h.mean()

            # Loss components, combined exactly as during training
            penalty = calculate_gmm_penalty(sigma).mean()
            stad_loss = compute_full_loss(penalty, l, energy, d_h).mean()

            total_val_loss += stad_loss.item()
            total_energy += energy.item()
            total_te_loss += l.item()

    n_batches = len(dataloader)
    avg_val_loss = total_val_loss / n_batches
    avg_energy = total_energy / n_batches
    avg_te_loss = total_te_loss / n_batches
    return avg_val_loss, avg_energy, avg_te_loss

# STAD Training Loop

In [None]:
def _save_checkpoint(path, epoch, model, optimizer, train_loss, val_loss):
    """Write a checkpoint (epoch, model/optimizer state dicts, losses) to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss
    }, path)


def train(model,
          train_dataloader,
          valid_dataloader,
          optimizer,
          scheduler,
          num_epochs,
          device,
          patience,
          save_dir='./models'):
    """
    Train STAD with per-batch LR scheduling, checkpointing and early stopping.

    Each epoch this function:
      1. runs one optimisation pass over ``train_dataloader``, calling
         ``scheduler.step()`` after every batch;
      2. re-evaluates the training set in eval mode (``evaluate_training_set``);
      3. evaluates ``valid_dataloader`` (``validate``);
      4. logs metrics to TensorBoard under
         ``./runs/<timestamp>_<experiment_name>`` (``experiment_name`` is read
         from the notebook's global scope);
      5. writes ``STAD_latest.pth`` and, when the validation loss improves,
         ``STAD_best.pth`` into ``save_dir``.

    Args:
        model: STAD model, called as ``model(inputs, targets, weather_stats)``
            and returning ``(l, energy, d_h, sigma)``.
        train_dataloader, valid_dataloader: yield dicts with ``'src_window'`` and
            ``'tgt_window'`` (dicts of tensors) and optionally ``'weather_stats'``.
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: LR scheduler that is stepped once per training batch.
        num_epochs: maximum number of epochs.
        device: device the batch tensors are moved to.
        patience: consecutive non-improving epochs tolerated before stopping.
        save_dir: checkpoint directory (created if missing).

    Returns:
        ``model`` with the best-validation weights from this run loaded. If no
        best checkpoint was written during this run (e.g. ``num_epochs == 0`` or
        the validation loss was never finite), returns ``model`` as it stands
        after the last epoch.
    """
    os.makedirs(save_dir, exist_ok=True)
    latest_model_path = os.path.join(save_dir, 'STAD_latest.pth')
    best_model_path = os.path.join(save_dir, 'STAD_best.pth')
    n_batches = len(train_dataloader)

    # TensorBoard run name relies on the notebook-level `experiment_name`.
    timestamp = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir=f'./runs/{timestamp}_{experiment_name}')

    # Early-stopping state
    best_val_loss = float('inf')
    patience_counter = 0
    # Set once this run writes STAD_best.pth. A file with that name may already
    # exist from an earlier run in the same save_dir, so file existence alone
    # must not decide whether to reload it.
    saved_best = False

    try:
        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0

            for batchidx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")):
                # Move data to device; 'weather_stats' is optional.
                inputs = {k: v.to(device) for k, v in batch.get('src_window').items()}
                targets = {k: v.to(device) for k, v in batch.get('tgt_window').items()}
                weather_stats = batch.get('weather_stats', None)
                if weather_stats is not None:
                    weather_stats = weather_stats.to(device)

                optimizer.zero_grad()

                # Forward pass and loss
                l, energy, d_h, sigma = model(inputs, targets, weather_stats)
                l, energy, d_h = l.mean(), energy.mean(), d_h.mean()
                penalty = calculate_gmm_penalty(sigma).mean()
                stad_loss = compute_full_loss(penalty, l, energy, d_h).mean()

                # Accumulate a detached value: the running sum stays on-device
                # (no per-batch host sync) and keeps no autograd graph alive.
                train_loss += stad_loss.detach()

                if batchidx % 200 == 0:
                    global_step = epoch * n_batches + batchidx
                    writer.add_scalar('Batch/te_loss', l, global_step)
                    writer.add_scalar('Batch/Energy', energy, global_step)
                    writer.add_scalar('Batch/train_loss', stad_loss, global_step)
                    # 0.005 mirrors compute_full_loss's default lambda_3.
                    writer.add_scalar('Batch/Penalty', penalty*0.005, global_step)
                    # tqdm.write keeps the progress bar intact; plain print breaks it.
                    tqdm.write(f'Batch {batchidx}/{n_batches} | Loss: {stad_loss.item():.6f}')

                # Backward pass and optimisation
                stad_loss.backward()
                optimizer.step()
                scheduler.step()

            # Plain Python float, so checkpoints don't pickle a device tensor.
            avg_train_loss = float(train_loss) / n_batches
            true_loss = evaluate_training_set(model, train_dataloader, device)  # already averaged

            val_loss, avg_energy, avg_te_loss = validate(model, valid_dataloader, device)

            writer.add_scalar('Epoch/train_loss', avg_train_loss, epoch)
            writer.add_scalar('Epoch/validation_loss', val_loss, epoch)
            writer.add_scalar('Epoch/avg_energy', avg_energy, epoch)
            writer.add_scalar('Epoch/avg_te_loss', avg_te_loss, epoch)
            writer.add_scalar('Epoch/learning_rate', scheduler.get_last_lr()[0], epoch)
            writer.add_scalar('Epoch/true_loss', true_loss, epoch)

            print(f'Epoch {epoch+1}/{num_epochs} | Average Train Loss: {avg_train_loss:.6f} | Average Validation Loss: {val_loss:.6f}')

            _save_checkpoint(latest_model_path, epoch, model, optimizer, avg_train_loss, val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                _save_checkpoint(best_model_path, epoch, model, optimizer, avg_train_loss, val_loss)
                saved_best = True
                print(f"Saved new best model with validation loss: {val_loss:.6f}")
            else:
                patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs!")
                break

            # Free cached GPU memory between epochs
            torch.cuda.empty_cache()
    finally:
        # Flush and close TensorBoard even when training is interrupted.
        writer.close()

    if not saved_best:
        print("No best checkpoint was saved during this run; returning the model from the last epoch.")
        return model

    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with validation loss: {checkpoint['val_loss']:.6f}")

    return model

# Datasets and DataLoaders

In [None]:
# Settings shared by both splits: only the split name and the pickle file differ.
# The bin counts are expected to agree with the STAD n_*_bins vocabulary sizes.
dataset_common_kwargs = dict(
    lat_bins=400,
    lon_bins=400,
    sog_bins=30,
    cog_bins=72,
    file_directory='../../data',
    include_weather=include_weather,
)

traj_dataset_train = TrajectoryDataset(ds_type='train',
                                       filename='joined-train-stad-weather.pkl',
                                       **dataset_common_kwargs)

traj_dataset_valid = TrajectoryDataset(ds_type='valid',
                                       filename='joined-valid-stad-weather.pkl',
                                       **dataset_common_kwargs)

In [12]:
# Training batch order is shuffled (shuffle_batches=True). Validation keeps a
# fixed order: the validation metric is a mean over batches, so shuffling adds
# nothing except a non-reproducible evaluation order.
train_batch_sampler = ExactBatchSampler(traj_dataset_train.batch_boundaries, shuffle_batches=True)
valid_batch_sampler = ExactBatchSampler(traj_dataset_valid.batch_boundaries, shuffle_batches=False)

In [13]:
# Same worker/pinning setup for both splits; batching is delegated entirely to
# the ExactBatchSampler instances.
loader_kwargs = dict(num_workers=4, pin_memory=True, persistent_workers=True)
data_loader_train = data.DataLoader(traj_dataset_train, batch_sampler=train_batch_sampler, **loader_kwargs)
data_loader_valid = data.DataLoader(traj_dataset_valid, batch_sampler=valid_batch_sampler, **loader_kwargs)

# Training Run

In [14]:
# AdamW's own default lr is not what drives training: OneCycleLR sets the
# learning rate of every parameter group itself and updates it each batch,
# peaking at `learning_rate`.
# Considered but not used: betas=(0.5, 0.999), a lower beta1 because batch
# (trajectory) lengths vary.
optimizer = AdamW(stad.parameters(), weight_decay=weight_decay)

scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate,                      # peak learning rate
    epochs=epochs,
    steps_per_epoch=len(data_loader_train),    # scheduler is stepped once per batch
    anneal_strategy='cos',
)

In [None]:
# Checkpoints go to a per-experiment directory so runs don't overwrite each other.
final_model = train(
    model=stad,
    train_dataloader=data_loader_train,
    valid_dataloader=data_loader_valid,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=epochs,
    device=device,
    patience=patience,
    save_dir=f'./models/{experiment_name}',
)

Epoch 1/150 [Train]:   0%|          | 5/4538 [00:00<06:28, 11.66it/s]

Batch 0/4538 | Loss: 262.495758


Epoch 1/150 [Train]:   5%|▍         | 205/4538 [00:05<01:53, 38.28it/s]

Batch 200/4538 | Loss: 255.066711


Epoch 1/150 [Train]:   9%|▉         | 405/4538 [00:11<01:48, 38.13it/s]

Batch 400/4538 | Loss: 249.060272


Epoch 1/150 [Train]:  13%|█▎        | 605/4538 [00:16<01:43, 38.16it/s]

Batch 600/4538 | Loss: 249.066589


Epoch 1/150 [Train]:  18%|█▊        | 805/4538 [00:21<01:37, 38.23it/s]

Batch 800/4538 | Loss: 252.475296


Epoch 1/150 [Train]:  22%|██▏       | 1005/4538 [00:26<01:32, 38.30it/s]

Batch 1000/4538 | Loss: 259.411865


Epoch 1/150 [Train]:  27%|██▋       | 1205/4538 [00:31<01:27, 38.27it/s]

Batch 1200/4538 | Loss: 243.867645


Epoch 1/150 [Train]:  31%|███       | 1405/4538 [00:37<01:21, 38.33it/s]

Batch 1400/4538 | Loss: 244.791168


Epoch 1/150 [Train]:  35%|███▌      | 1605/4538 [00:42<01:15, 39.06it/s]

Batch 1600/4538 | Loss: 251.855713


Epoch 1/150 [Train]:  40%|███▉      | 1805/4538 [00:47<01:09, 39.16it/s]

Batch 1800/4538 | Loss: 248.831116


Epoch 1/150 [Train]:  44%|████▍     | 2005/4538 [00:52<01:05, 38.93it/s]

Batch 2000/4538 | Loss: 247.924301


Epoch 1/150 [Train]:  49%|████▊     | 2205/4538 [00:57<01:00, 38.57it/s]

Batch 2200/4538 | Loss: 248.797348


Epoch 1/150 [Train]:  53%|█████▎    | 2405/4538 [01:02<00:54, 39.08it/s]

Batch 2400/4538 | Loss: 235.794586


Epoch 1/150 [Train]:  57%|█████▋    | 2605/4538 [01:08<00:49, 38.98it/s]

Batch 2600/4538 | Loss: 246.027588


Epoch 1/150 [Train]:  62%|██████▏   | 2804/4538 [01:12<00:42, 41.20it/s]

Batch 2800/4538 | Loss: 239.652573


Epoch 1/150 [Train]:  66%|██████▌   | 3004/4538 [01:17<00:37, 41.38it/s]

Batch 3000/4538 | Loss: 247.641144


Epoch 1/150 [Train]:  71%|███████   | 3204/4538 [01:22<00:32, 41.35it/s]

Batch 3200/4538 | Loss: 238.311249


Epoch 1/150 [Train]:  75%|███████▌  | 3404/4538 [01:27<00:28, 39.39it/s]

Batch 3400/4538 | Loss: 238.166336


Epoch 1/150 [Train]:  79%|███████▉  | 3604/4538 [01:32<00:22, 41.23it/s]

Batch 3600/4538 | Loss: 229.683350


Epoch 1/150 [Train]:  84%|████████▍ | 3804/4538 [01:37<00:17, 41.04it/s]

Batch 3800/4538 | Loss: 245.392303


Epoch 1/150 [Train]:  88%|████████▊ | 4004/4538 [01:42<00:12, 41.16it/s]

Batch 4000/4538 | Loss: 242.506012


Epoch 1/150 [Train]:  93%|█████████▎| 4204/4538 [01:46<00:08, 41.24it/s]

Batch 4200/4538 | Loss: 239.732239


Epoch 1/150 [Train]:  97%|█████████▋| 4404/4538 [01:51<00:03, 41.23it/s]

Batch 4400/4538 | Loss: 244.842697


Epoch 1/150 [Train]: 100%|██████████| 4538/4538 [01:55<00:00, 39.46it/s]
Evaluating Training Set: 100%|██████████| 4538/4538 [00:34<00:00, 129.89it/s]
Validation: 100%|██████████| 787/787 [00:07<00:00, 99.24it/s] 


Epoch 1/150 | Average Train Loss: 245.467941 | Average Validation Loss: 238.208952
Saved new best model with validation loss: 238.208952


Epoch 2/150 [Train]:   0%|          | 4/4538 [00:00<02:18, 32.63it/s]

Batch 0/4538 | Loss: 247.537109


Epoch 2/150 [Train]:   4%|▍         | 204/4538 [00:04<01:45, 40.91it/s]

Batch 200/4538 | Loss: 241.086487


Epoch 2/150 [Train]:   9%|▉         | 404/4538 [00:09<01:40, 41.13it/s]

Batch 400/4538 | Loss: 244.331940


Epoch 2/150 [Train]:  13%|█▎        | 604/4538 [00:14<01:35, 41.15it/s]

Batch 600/4538 | Loss: 240.171906


Epoch 2/150 [Train]:  18%|█▊        | 804/4538 [00:19<01:33, 39.99it/s]

Batch 800/4538 | Loss: 223.790283


Epoch 2/150 [Train]:  22%|██▏       | 1004/4538 [00:24<01:26, 40.76it/s]

Batch 1000/4538 | Loss: 236.038574


Epoch 2/150 [Train]:  27%|██▋       | 1204/4538 [00:29<01:21, 41.08it/s]

Batch 1200/4538 | Loss: 241.797272


Epoch 2/150 [Train]:  31%|███       | 1404/4538 [00:34<01:15, 41.24it/s]

Batch 1400/4538 | Loss: 252.090958


Epoch 2/150 [Train]:  35%|███▌      | 1604/4538 [00:39<01:11, 40.90it/s]

Batch 1600/4538 | Loss: 247.216934


Epoch 2/150 [Train]:  40%|███▉      | 1804/4538 [00:43<01:06, 41.11it/s]

Batch 1800/4538 | Loss: 223.977783


Epoch 2/150 [Train]:  44%|████▍     | 2004/4538 [00:48<01:01, 41.33it/s]

Batch 2000/4538 | Loss: 240.647217


Epoch 2/150 [Train]:  48%|████▊     | 2199/4538 [00:53<00:56, 41.32it/s]

Batch 2200/4538 | Loss: 228.856964


Epoch 2/150 [Train]:  53%|█████▎    | 2404/4538 [00:58<00:52, 40.94it/s]

Batch 2400/4538 | Loss: 225.495117


Epoch 2/150 [Train]:  57%|█████▋    | 2604/4538 [01:03<00:47, 40.95it/s]

Batch 2600/4538 | Loss: 236.173203


Epoch 2/150 [Train]:  62%|██████▏   | 2804/4538 [01:08<00:42, 41.06it/s]

Batch 2800/4538 | Loss: 228.678635


Epoch 2/150 [Train]:  66%|██████▌   | 3004/4538 [01:13<00:37, 41.16it/s]

Batch 3000/4538 | Loss: 200.891861


Epoch 2/150 [Train]:  71%|███████   | 3208/4538 [01:18<00:34, 38.23it/s]

Batch 3200/4538 | Loss: 211.581421


Epoch 2/150 [Train]:  75%|███████▌  | 3408/4538 [01:23<00:29, 38.36it/s]

Batch 3400/4538 | Loss: 224.133652


Epoch 2/150 [Train]:  80%|███████▉  | 3608/4538 [01:28<00:24, 38.32it/s]

Batch 3600/4538 | Loss: 238.113724


Epoch 2/150 [Train]:  84%|████████▍ | 3808/4538 [01:34<00:19, 38.25it/s]

Batch 3800/4538 | Loss: 206.277634


Epoch 2/150 [Train]:  88%|████████▊ | 4008/4538 [01:39<00:13, 38.27it/s]

Batch 4000/4538 | Loss: 200.961319


Epoch 2/150 [Train]:  93%|█████████▎| 4208/4538 [01:44<00:08, 38.13it/s]

Batch 4200/4538 | Loss: 242.298019


Epoch 2/150 [Train]:  97%|█████████▋| 4408/4538 [01:49<00:03, 38.01it/s]

Batch 4400/4538 | Loss: 225.157837


Epoch 2/150 [Train]: 100%|██████████| 4538/4538 [01:53<00:00, 40.11it/s]
Evaluating Training Set: 100%|██████████| 4538/4538 [00:34<00:00, 131.80it/s]
Validation: 100%|██████████| 787/787 [00:08<00:00, 95.16it/s] 


Epoch 2/150 | Average Train Loss: 233.738708 | Average Validation Loss: 229.052732
Saved new best model with validation loss: 229.052732


Epoch 3/150 [Train]:   0%|          | 0/4538 [00:00<?, ?it/s]

Batch 0/4538 | Loss: 243.278458


Evaluating Training Set: 100%|██████████| 4538/4538 [00:35<00:00, 128.49it/s]
Validation: 100%|██████████| 787/787 [00:07<00:00, 100.98it/s]


Epoch 3/150 | Average Train Loss: 225.924652 | Average Validation Loss: 222.363145
Saved new best model with validation loss: 222.363145


Epoch 4/150 [Train]:   0%|          | 8/4538 [00:00<02:02, 37.07it/s]

Batch 0/4538 | Loss: 224.754150


Epoch 4/150 [Train]:   5%|▍         | 208/4538 [00:05<01:46, 40.69it/s]

Batch 200/4538 | Loss: 211.183182


Epoch 4/150 [Train]:   9%|▉         | 408/4538 [00:10<01:41, 40.71it/s]

Batch 400/4538 | Loss: 213.206833


Epoch 4/150 [Train]:  13%|█▎        | 608/4538 [00:14<01:36, 40.82it/s]

Batch 600/4538 | Loss: 217.351303


Epoch 4/150 [Train]:  18%|█▊        | 808/4538 [00:19<01:30, 41.14it/s]

Batch 800/4538 | Loss: 233.282349


Epoch 4/150 [Train]:  22%|██▏       | 1008/4538 [00:24<01:26, 40.88it/s]

Batch 1000/4538 | Loss: 231.230087


Epoch 4/150 [Train]:  27%|██▋       | 1208/4538 [00:29<01:21, 40.68it/s]

Batch 1200/4538 | Loss: 210.341629


Epoch 4/150 [Train]:  31%|███       | 1408/4538 [00:34<01:16, 40.88it/s]

Batch 1400/4538 | Loss: 222.254761


Epoch 4/150 [Train]:  35%|███▌      | 1608/4538 [00:39<01:11, 41.23it/s]

Batch 1600/4538 | Loss: 181.244263


Epoch 4/150 [Train]:  40%|███▉      | 1808/4538 [00:44<01:06, 40.86it/s]

Batch 1800/4538 | Loss: 190.467758


Epoch 4/150 [Train]:  44%|████▍     | 2008/4538 [00:49<01:01, 40.93it/s]

Batch 2000/4538 | Loss: 201.681763


Epoch 4/150 [Train]:  49%|████▊     | 2208/4538 [00:54<00:56, 40.95it/s]

Batch 2200/4538 | Loss: 228.024780


Epoch 4/150 [Train]:  53%|█████▎    | 2408/4538 [00:58<00:52, 40.66it/s]

Batch 2400/4538 | Loss: 208.051590


Epoch 4/150 [Train]:  57%|█████▋    | 2608/4538 [01:03<00:46, 41.14it/s]

Batch 2600/4538 | Loss: 219.684998


Epoch 4/150 [Train]:  62%|██████▏   | 2808/4538 [01:08<00:42, 40.88it/s]

Batch 2800/4538 | Loss: 243.238953


Epoch 4/150 [Train]:  66%|██████▋   | 3008/4538 [01:13<00:37, 40.94it/s]

Batch 3000/4538 | Loss: 216.299866


Epoch 4/150 [Train]:  71%|███████   | 3208/4538 [01:18<00:32, 40.81it/s]

Batch 3200/4538 | Loss: 225.681839


Epoch 4/150 [Train]:  75%|███████▌  | 3408/4538 [01:23<00:28, 40.28it/s]

Batch 3400/4538 | Loss: 225.128998


Epoch 4/150 [Train]:  80%|███████▉  | 3608/4538 [01:28<00:22, 40.60it/s]

Batch 3600/4538 | Loss: 218.221680


Epoch 4/150 [Train]:  84%|████████▍ | 3808/4538 [01:33<00:17, 40.83it/s]

Batch 3800/4538 | Loss: 222.509491


Epoch 4/150 [Train]:  88%|████████▊ | 4008/4538 [01:38<00:13, 40.12it/s]

Batch 4000/4538 | Loss: 220.146133


Epoch 4/150 [Train]:  93%|█████████▎| 4208/4538 [01:42<00:08, 40.83it/s]

Batch 4200/4538 | Loss: 225.300995


Epoch 4/150 [Train]:  97%|█████████▋| 4408/4538 [01:47<00:03, 40.71it/s]

Batch 4400/4538 | Loss: 177.244751


Epoch 4/150 [Train]: 100%|██████████| 4538/4538 [01:51<00:00, 40.88it/s]
Evaluating Training Set: 100%|██████████| 4538/4538 [00:34<00:00, 130.05it/s]
Validation: 100%|██████████| 787/787 [00:07<00:00, 99.45it/s] 


Epoch 4/150 | Average Train Loss: 219.758621 | Average Validation Loss: 216.655289
Saved new best model with validation loss: 216.655289


Epoch 5/150 [Train]:   0%|          | 0/4538 [00:00<?, ?it/s]

Batch 0/4538 | Loss: 176.053726


Epoch 5/150 [Train]:   5%|▍         | 208/4538 [00:05<01:46, 40.68it/s]

Batch 200/4538 | Loss: 212.309402


Epoch 5/150 [Train]:   9%|▉         | 408/4538 [00:09<01:40, 41.25it/s]

Batch 400/4538 | Loss: 250.735947


Epoch 5/150 [Train]:  13%|█▎        | 608/4538 [00:14<01:35, 41.20it/s]

Batch 600/4538 | Loss: 205.490814


Epoch 5/150 [Train]:  18%|█▊        | 808/4538 [00:19<01:31, 40.62it/s]

Batch 800/4538 | Loss: 218.314987


Epoch 5/150 [Train]:  22%|██▏       | 1008/4538 [00:24<01:26, 40.89it/s]

Batch 1000/4538 | Loss: 246.020584


Epoch 5/150 [Train]:  27%|██▋       | 1208/4538 [00:29<01:20, 41.17it/s]

Batch 1200/4538 | Loss: 225.082123


Epoch 5/150 [Train]:  31%|███       | 1408/4538 [00:34<01:17, 40.21it/s]

Batch 1400/4538 | Loss: 224.887070


Epoch 5/150 [Train]:  35%|███▌      | 1608/4538 [00:39<01:11, 40.92it/s]

Batch 1600/4538 | Loss: 218.640610


Epoch 5/150 [Train]:  40%|███▉      | 1808/4538 [00:44<01:06, 41.08it/s]

Batch 1800/4538 | Loss: 213.982162


Epoch 5/150 [Train]:  44%|████▍     | 2008/4538 [00:48<01:02, 40.80it/s]

Batch 2000/4538 | Loss: 195.032974


Epoch 5/150 [Train]:  49%|████▊     | 2208/4538 [00:53<00:56, 41.05it/s]

Batch 2200/4538 | Loss: 218.859787


Epoch 5/150 [Train]:  53%|█████▎    | 2408/4538 [00:58<00:51, 41.25it/s]

Batch 2400/4538 | Loss: 227.583435


Epoch 5/150 [Train]:  57%|█████▋    | 2608/4538 [01:03<00:46, 41.32it/s]

Batch 2600/4538 | Loss: 220.849731


Epoch 5/150 [Train]:  62%|██████▏   | 2808/4538 [01:08<00:42, 41.18it/s]

Batch 2800/4538 | Loss: 195.304108


Epoch 5/150 [Train]:  66%|██████▋   | 3008/4538 [01:13<00:37, 41.17it/s]

Batch 3000/4538 | Loss: 186.671280


Epoch 5/150 [Train]:  71%|███████   | 3208/4538 [01:18<00:32, 41.11it/s]

Batch 3200/4538 | Loss: 233.786514


Epoch 5/150 [Train]:  75%|███████▌  | 3408/4538 [01:23<00:27, 41.41it/s]

Batch 3400/4538 | Loss: 215.724243


Epoch 5/150 [Train]:  80%|███████▉  | 3608/4538 [01:27<00:22, 40.97it/s]

Batch 3600/4538 | Loss: 210.709641


Epoch 5/150 [Train]:  84%|████████▍ | 3808/4538 [01:32<00:17, 41.15it/s]

Batch 3800/4538 | Loss: 198.597778


Epoch 5/150 [Train]:  88%|████████▊ | 4008/4538 [01:37<00:12, 41.09it/s]

Batch 4000/4538 | Loss: 215.873718


Epoch 5/150 [Train]:  93%|█████████▎| 4208/4538 [01:42<00:08, 40.95it/s]

Batch 4200/4538 | Loss: 203.593384


Epoch 5/150 [Train]:  97%|█████████▋| 4408/4538 [01:47<00:03, 41.18it/s]

Batch 4400/4538 | Loss: 206.408813


Epoch 5/150 [Train]: 100%|██████████| 4538/4538 [01:50<00:00, 41.06it/s]
Evaluating Training Set: 100%|█████████▉| 4537/4538 [00:34<00:00, 127.38it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 118/150 [Train]:   5%|▍         | 207/4538 [00:05<01:51, 38.94it/s]

Batch 200/4538 | Loss: 88.834511


Epoch 118/150 [Train]:   9%|▉         | 407/4538 [00:10<01:46, 38.95it/s]

Batch 400/4538 | Loss: 86.875618


Epoch 118/150 [Train]:  13%|█▎        | 607/4538 [00:15<01:34, 41.75it/s]

Batch 600/4538 | Loss: 75.644554


Epoch 118/150 [Train]:  18%|█▊        | 807/4538 [00:20<01:29, 41.47it/s]

Batch 800/4538 | Loss: 80.727486


Epoch 118/150 [Train]:  22%|██▏       | 1007/4538 [00:25<01:25, 41.37it/s]

Batch 1000/4538 | Loss: 74.778595


Epoch 118/150 [Train]:  27%|██▋       | 1207/4538 [00:29<01:20, 41.48it/s]

Batch 1200/4538 | Loss: 83.501846


Epoch 118/150 [Train]:  31%|███       | 1407/4538 [00:34<01:15, 41.58it/s]

Batch 1400/4538 | Loss: 76.222595


Epoch 118/150 [Train]:  35%|███▌      | 1607/4538 [00:39<01:10, 41.51it/s]

Batch 1600/4538 | Loss: 75.803482


Epoch 118/150 [Train]:  40%|███▉      | 1807/4538 [00:44<01:05, 41.71it/s]

Batch 1800/4538 | Loss: 78.951370


Epoch 118/150 [Train]:  44%|████▍     | 2007/4538 [00:49<01:01, 41.42it/s]

Batch 2000/4538 | Loss: 68.682739


Epoch 118/150 [Train]:  49%|████▊     | 2207/4538 [00:53<00:56, 41.40it/s]

Batch 2200/4538 | Loss: 75.101128


Epoch 118/150 [Train]:  53%|█████▎    | 2407/4538 [00:58<00:51, 41.47it/s]

Batch 2400/4538 | Loss: 71.408676


Epoch 118/150 [Train]:  57%|█████▋    | 2607/4538 [01:03<00:46, 41.34it/s]

Batch 2600/4538 | Loss: 84.737556


Epoch 118/150 [Train]:  62%|██████▏   | 2807/4538 [01:08<00:41, 41.46it/s]

Batch 2800/4538 | Loss: 86.887650


Epoch 118/150 [Train]:  66%|██████▋   | 3007/4538 [01:13<00:37, 41.18it/s]

Batch 3000/4538 | Loss: 76.128128


Epoch 118/150 [Train]:  71%|███████   | 3207/4538 [01:18<00:31, 41.60it/s]

Batch 3200/4538 | Loss: 93.717628


Epoch 118/150 [Train]:  75%|███████▌  | 3405/4538 [01:23<00:29, 38.84it/s]

Batch 3400/4538 | Loss: 105.352951


Epoch 118/150 [Train]:  79%|███████▉  | 3605/4538 [01:28<00:24, 38.42it/s]

Batch 3600/4538 | Loss: 80.305649


Epoch 118/150 [Train]:  84%|████████▍ | 3805/4538 [01:33<00:18, 38.73it/s]

Batch 3800/4538 | Loss: 74.345039


Epoch 118/150 [Train]:  88%|████████▊ | 4005/4538 [01:38<00:13, 38.77it/s]

Batch 4000/4538 | Loss: 98.085579


Epoch 118/150 [Train]:  93%|█████████▎| 4205/4538 [01:43<00:08, 38.70it/s]

Batch 4200/4538 | Loss: 87.673737


Epoch 118/150 [Train]:  97%|█████████▋| 4405/4538 [01:48<00:03, 38.72it/s]

Batch 4400/4538 | Loss: 95.243599


Epoch 118/150 [Train]: 100%|██████████| 4538/4538 [01:52<00:00, 40.42it/s]
Evaluating Training Set: 100%|██████████| 4538/4538 [00:33<00:00, 135.92it/s]
Validation: 100%|██████████| 787/787 [00:07<00:00, 98.57it/s] 


Epoch 118/150 | Average Train Loss: 84.049850 | Average Validation Loss: 87.973023
Validation loss did not improve. Patience: 2/150


Epoch 119/150 [Train]:   0%|          | 0/4538 [00:00<?, ?it/s]

Batch 0/4538 | Loss: 84.514412


Epoch 119/150 [Train]:   4%|▍         | 204/4538 [00:04<01:44, 41.66it/s]

Batch 200/4538 | Loss: 87.315430


Epoch 119/150 [Train]:   9%|▉         | 404/4538 [00:09<01:40, 41.29it/s]

Batch 400/4538 | Loss: 98.388863


Epoch 119/150 [Train]:  13%|█▎        | 604/4538 [00:14<01:34, 41.50it/s]

Batch 600/4538 | Loss: 67.909744


Epoch 119/150 [Train]:  18%|█▊        | 808/4538 [00:19<01:35, 38.94it/s]

Batch 800/4538 | Loss: 98.889496


Epoch 119/150 [Train]:  22%|██▏       | 1008/4538 [00:24<01:30, 38.82it/s]

Batch 1000/4538 | Loss: 67.358582


Epoch 119/150 [Train]:  27%|██▋       | 1208/4538 [00:30<01:25, 38.92it/s]

Batch 1200/4538 | Loss: 80.756119


Epoch 119/150 [Train]:  31%|███       | 1408/4538 [00:35<01:20, 39.05it/s]

Batch 1400/4538 | Loss: 91.967186


Epoch 119/150 [Train]:  35%|███▌      | 1608/4538 [00:40<01:15, 38.83it/s]

Batch 1600/4538 | Loss: 90.132370


Epoch 119/150 [Train]:  40%|███▉      | 1808/4538 [00:45<01:09, 39.05it/s]

Batch 1800/4538 | Loss: 80.174362


Epoch 119/150 [Train]:  44%|████▍     | 2008/4538 [00:50<01:04, 38.95it/s]

Batch 2000/4538 | Loss: 79.741928


Epoch 119/150 [Train]:  49%|████▊     | 2208/4538 [00:55<00:59, 38.96it/s]

Batch 2200/4538 | Loss: 84.326874


Epoch 119/150 [Train]:  53%|█████▎    | 2408/4538 [01:00<00:54, 38.90it/s]

Batch 2400/4538 | Loss: 74.632973


Epoch 119/150 [Train]:  57%|█████▋    | 2608/4538 [01:05<00:49, 38.92it/s]

Batch 2600/4538 | Loss: 89.718842


Epoch 119/150 [Train]:  62%|██████▏   | 2808/4538 [01:11<00:44, 38.81it/s]

Batch 2800/4538 | Loss: 87.861969


Epoch 119/150 [Train]:  66%|██████▋   | 3008/4538 [01:16<00:39, 38.98it/s]

Batch 3000/4538 | Loss: 78.866211


Epoch 119/150 [Train]:  71%|███████   | 3208/4538 [01:21<00:34, 38.93it/s]

Batch 3200/4538 | Loss: 95.901909


Epoch 119/150 [Train]:  75%|███████▌  | 3408/4538 [01:26<00:29, 38.82it/s]

Batch 3400/4538 | Loss: 90.447136


Epoch 119/150 [Train]:  80%|███████▉  | 3608/4538 [01:31<00:23, 39.11it/s]

Batch 3600/4538 | Loss: 70.862015


Epoch 119/150 [Train]:  84%|████████▍ | 3808/4538 [01:36<00:18, 38.90it/s]

Batch 3800/4538 | Loss: 68.060440


Epoch 119/150 [Train]:  88%|████████▊ | 4008/4538 [01:41<00:13, 38.93it/s]

Batch 4000/4538 | Loss: 88.612831


Epoch 119/150 [Train]:  93%|█████████▎| 4208/4538 [01:47<00:08, 38.82it/s]

Batch 4200/4538 | Loss: 98.464737


Epoch 119/150 [Train]:  97%|█████████▋| 4408/4538 [01:52<00:03, 38.88it/s]

Batch 4400/4538 | Loss: 76.593567


Epoch 119/150 [Train]: 100%|██████████| 4538/4538 [01:55<00:00, 39.27it/s]
Evaluating Training Set: 100%|██████████| 4538/4538 [00:33<00:00, 134.62it/s]
Validation: 100%|██████████| 787/787 [00:07<00:00, 98.95it/s] 


Epoch 119/150 | Average Train Loss: 84.028770 | Average Validation Loss: 87.966630
Validation loss did not improve. Patience: 3/150


Epoch 120/150 [Train]:   0%|          | 0/4538 [00:00<?, ?it/s]

Batch 0/4538 | Loss: 77.911392


Epoch 120/150 [Train]:   4%|▍         | 204/4538 [00:04<01:44, 41.54it/s]

Batch 200/4538 | Loss: 71.563339


Epoch 120/150 [Train]:   7%|▋         | 334/4538 [00:08<01:40, 41.94it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

