In [1]:
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:


class TrajectoryLSTM(nn.Module):
    def __init__(self, input_dim=2+2+32, hidden_dim=128, output_dim=2, num_layers=1):
        super().__init__()
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.decoder_lstm = nn.LSTM(output_dim, hidden_dim, num_layers, batch_first=True)
        self.output_fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input_seq, target_len=80):
        batch_size = input_seq.size(0)

        # Encode
        _, (h, c) = self.encoder_lstm(input_seq)

        # Decode
        decoder_input = input_seq[:, -1:, :2]  # just x, y of last input
        outputs = []

        for _ in range(target_len):
            out, (h, c) = self.decoder_lstm(decoder_input, (h, c))
            pred = self.output_fc(out)  # predict (x, y)
            outputs.append(pred)
            decoder_input = pred  # feed predicted position

        return torch.cat(outputs, dim=1)  # shape: (B, 80, 2)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import math


class TrajectoryLSTM(nn.Module):
    """
    LSTM model for trajectory prediction using MTR dataset format.
    
    Input: obj_trajs from MTR dataset with shape (num_center_objects, num_objects, num_timestamps, num_features)
    Output: Future trajectory predictions
    """
    
    def __init__(self, 
                 input_dim=29,  # Based on MTR dataset obj_trajs feature dimension
                 hidden_dim=256,
                 num_layers=2,
                 num_modes=6,  # Number of prediction modes
                 future_steps=80,  # Number of future timesteps to predict
                 dropout=0.1):
        super(TrajectoryLSTM, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_modes = num_modes
        self.future_steps = future_steps
        self.dropout = dropout
        
        # Feature encoder for input trajectories
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # LSTM for temporal modeling
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False
        )
        
        # Multi-modal prediction heads
        self.mode_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_modes)
        )
        
        # Trajectory decoder for each mode
        self.traj_decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, future_steps * 4)  # x, y, vx, vy for each timestep
            ) for _ in range(num_modes)
        ])
        
        # Attention mechanism for object interactions
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=dropout,
            batch_first=True
        )
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize model weights"""
        for name, param in self.named_parameters():
            if 'weight' in name:
                if len(param.shape) >= 2:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.uniform_(param, -0.1, 0.1)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
    
    def forward(self, batch_dict):
        """
        Forward pass of the model
        
        Args:
            batch_dict: Dictionary containing:
                - obj_trajs: (batch_size, num_objects, num_timestamps, input_dim)
                - obj_trajs_mask: (batch_size, num_objects, num_timestamps)
                - track_index_to_predict: (batch_size,) indices of center objects
        
        Returns:
            pred_scores: (batch_size, num_modes) - confidence scores for each mode
            pred_trajs: (batch_size, num_modes, future_steps, 4) - predicted trajectories
        """
        input_dict=batch_dict["input_dict"]
        obj_trajs = input_dict['obj_trajs'].to("cuda")  # (batch_size, num_objects, num_timestamps, input_dim)
        obj_trajs_mask = input_dict['obj_trajs_mask'].to("cuda")  # (batch_size, num_objects, num_timestamps)
        track_indices = input_dict['track_index_to_predict'].to("cuda")  # (batch_size,)
        # map_polylines, map_polylines_mask = input_dict['map_polylines'].to("cuda"), input_dict['map_polylines_mask'].to("cuda") # (num_center_objects, num_topk_polylines, num_points_each_polyline, 9): [x, y, z, dir_x, dir_y, dir_z, global_type, pre_x, pre_y]

        

        
        batch_size, num_objects, num_timestamps, input_dim = obj_trajs.shape
        
        # Encode input features
        obj_features = self.feature_encoder(obj_trajs.view(-1, input_dim))
        obj_features = obj_features.view(batch_size, num_objects, num_timestamps, self.hidden_dim)
        
        # Apply mask to features
        mask_expanded = obj_trajs_mask.unsqueeze(-1).expand(-1, -1, -1, self.hidden_dim)
        obj_features = obj_features * mask_expanded.float()
        
        # Process each object's trajectory through LSTM
        all_lstm_outputs = []
        
        for obj_idx in range(num_objects):
            obj_seq = obj_features[:, obj_idx, :, :]  # (batch_size, num_timestamps, hidden_dim)
            lstm_out, _ = self.lstm(obj_seq)  # (batch_size, num_timestamps, hidden_dim)
            
            # Take the last valid output for each sequence
            seq_lengths = obj_trajs_mask[:, obj_idx, :].sum(dim=1)  # (batch_size,)
            last_outputs = []
            for b in range(batch_size):
                if seq_lengths[b] > 0:
                    last_idx = int(seq_lengths[b] - 1)
                    last_outputs.append(lstm_out[b, last_idx, :])
                else:
                    last_outputs.append(torch.zeros(self.hidden_dim, device=obj_seq.device))
            
            last_output = torch.stack(last_outputs, dim=0)  # (batch_size, hidden_dim)
            all_lstm_outputs.append(last_output)
        
        all_lstm_outputs = torch.stack(all_lstm_outputs, dim=1)  # (batch_size, num_objects, hidden_dim)
        
        # Apply attention mechanism for object interactions
        attn_output, _ = self.attention(
            all_lstm_outputs, all_lstm_outputs, all_lstm_outputs,
            key_padding_mask=~(obj_trajs_mask.sum(dim=2) > 0)  # (batch_size, num_objects)
        )
        
        # Extract center object features
        center_features = []
        for b in range(batch_size):
            center_idx = track_indices[b]
            center_features.append(attn_output[b, center_idx, :])
        center_features = torch.stack(center_features, dim=0)  # (batch_size, hidden_dim)
        
        # Predict mode probabilities
        mode_logits = self.mode_predictor(center_features)  # (batch_size, num_modes)
        pred_scores = F.softmax(mode_logits, dim=-1)
        
        # Predict trajectories for each mode
        pred_trajs_list = []
        for mode_idx in range(self.num_modes):
            traj_flat = self.traj_decoders[mode_idx](center_features)  # (batch_size, future_steps * 4)
            traj = traj_flat.view(batch_size, self.future_steps, 4)  # (batch_size, future_steps, 4)
            pred_trajs_list.append(traj)
        
        pred_trajs = torch.stack(pred_trajs_list, dim=1)  # (batch_size, num_modes, future_steps, 4)
        
        return pred_scores, pred_trajs


class TrajectoryLoss(nn.Module):
    """
    Loss function for trajectory prediction
    """
    
    def __init__(self, 
                 regression_loss_weight=1.0,
                 classification_loss_weight=1.0,
                 future_loss_weight=1.0):
        super(TrajectoryLoss, self).__init__()
        self.reg_weight = regression_loss_weight
        self.cls_weight = classification_loss_weight
        self.future_weight = future_loss_weight
    
    def forward(self, pred_scores, pred_trajs, batch_dict):
        """
        Compute loss
        
        Args:
            pred_scores: (batch_size, num_modes)
            pred_trajs: (batch_size, num_modes, future_steps, 4)
            batch_dict: Contains ground truth data
        
        Returns:
            loss_dict: Dictionary containing different loss components
        """
        center_gt_trajs = batch_dict['input_dict']['center_gt_trajs'].to('cuda')  # (batch_size, future_steps, 4)
        center_gt_trajs_mask = batch_dict['input_dict']['center_gt_trajs_mask'].to('cuda')  # (batch_size, future_steps)
        
        batch_size, num_modes, future_steps, _ = pred_trajs.shape
        
        # Compute trajectory regression loss for each mode
        gt_trajs_expanded = center_gt_trajs.unsqueeze(1).expand(-1, num_modes, -1, -1)
        gt_mask_expanded = center_gt_trajs_mask.unsqueeze(1).expand(-1, num_modes, -1)
        
        # L2 loss for position (x, y)
        pos_loss = F.mse_loss(
            pred_trajs[:, :, :, :2] * gt_mask_expanded.unsqueeze(-1),
            gt_trajs_expanded[:, :, :, :2] * gt_mask_expanded.unsqueeze(-1),
            reduction='none'
        ).sum(dim=-1)  # (batch_size, num_modes, future_steps)
        
        # L2 loss for velocity (vx, vy)
        vel_loss = F.mse_loss(
            pred_trajs[:, :, :, 2:4] * gt_mask_expanded.unsqueeze(-1),
            gt_trajs_expanded[:, :, :, 2:4] * gt_mask_expanded.unsqueeze(-1),
            reduction='none'
        ).sum(dim=-1)  # (batch_size, num_modes, future_steps)
        
        # Weighted loss over time (give more weight to near future)
        time_weights = torch.exp(-0.1 * torch.arange(future_steps, device=pred_trajs.device))
        time_weights = time_weights.view(1, 1, -1)
        
        pos_loss = (pos_loss * time_weights * gt_mask_expanded).sum(dim=-1)  # (batch_size, num_modes)
        vel_loss = (vel_loss * time_weights * gt_mask_expanded).sum(dim=-1)  # (batch_size, num_modes)
        
        # Find best mode for each sample
        total_traj_loss = pos_loss + vel_loss  # (batch_size, num_modes)
        best_mode_indices = torch.argmin(total_traj_loss, dim=1)  # (batch_size,)
        
        # Regression loss (best mode)
        best_pos_loss = pos_loss[torch.arange(batch_size), best_mode_indices].mean()
        best_vel_loss = vel_loss[torch.arange(batch_size), best_mode_indices].mean()
        regression_loss = best_pos_loss + best_vel_loss
        
        # Classification loss (encourage higher confidence for best mode)
        target_scores = torch.zeros_like(pred_scores)
        target_scores[torch.arange(batch_size), best_mode_indices] = 1.0
        classification_loss = F.cross_entropy(pred_scores, target_scores)
        
        # Total loss
        total_loss = (self.reg_weight * regression_loss + 
                     self.cls_weight * classification_loss)
        
        loss_dict = {
            'total_loss': total_loss,
            'regression_loss': regression_loss,
            'classification_loss': classification_loss,
            'pos_loss': best_pos_loss,
            'vel_loss': best_vel_loss
        }
        
        return loss_dict


def create_dataloader(dataset, batch_size=32, shuffle=True, num_workers=4):
    """
    Create DataLoader for the MTR dataset
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_batch
    )


def collate_batch(batch_list):
    """
    Custom collate function to handle variable number of objects per scene
    """
    batch_dict = {}
    
    # Stack all the tensors
    for key in batch_list[0].keys():
        if isinstance(batch_list[0][key], np.ndarray):
            batch_dict[key] = torch.from_numpy(np.stack([item[key] for item in batch_list], axis=0))
        elif isinstance(batch_list[0][key], torch.Tensor):
            batch_dict[key] = torch.stack([item[key] for item in batch_list], axis=0)
        else:
            batch_dict[key] = [item[key] for item in batch_list]
    
    # Add batch size info
    batch_dict['batch_size'] = len(batch_list)
    batch_dict['batch_sample_count'] = [1] * len(batch_list)  # Each sample is one center object
    
    return batch_dict


def train_model(model, train_dataloader, val_dataloader, num_epochs=100, lr=1e-3):
    """
    Training loop for the LSTM model
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = TrajectoryLoss()
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_losses = []
        
        for batch_idx, batch_dict in enumerate(train_dataloader):
            # Move data to device
            for key, value in batch_dict.items():
                if isinstance(value, torch.Tensor):
                    batch_dict[key] = value.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            pred_scores, pred_trajs = model(batch_dict)
            
            # Compute loss
            loss_dict = criterion(pred_scores, pred_trajs, batch_dict)
            loss = loss_dict['total_loss']
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_losses.append(loss.item())
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Validation phase
        model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch_dict in val_dataloader:
                # Move data to device
                for key, value in batch_dict.items():
                    if isinstance(value, torch.Tensor):
                        batch_dict[key] = value.to(device)
                
                pred_scores, pred_trajs = model(batch_dict)
                loss_dict = criterion(pred_scores, pred_trajs, batch_dict)
                val_losses.append(loss_dict['total_loss'].item())
        
        scheduler.step()
        
        avg_train_loss = np.mean(train_losses)
        avg_val_loss = np.mean(val_losses)
        
        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), '/code/jjiang23/csc587/KimchiVision/best_trajectory_lstm.pth')
    
    return model



In [4]:
from mtr.datasets import build_dataloader
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils

In [5]:
cfg_from_yaml_file("/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml", cfg)
logger = common_utils.create_logger("/files/waymo/log.txt", rank=0)
from easydict import EasyDict as edict
args = edict({
    "batch_size": 32,
    "workers": 4,
    "merge_all_iters_to_one_epoch": False,
    "epochs": 10,
    "add_worker_init_fn": False,
})
train_set, train_loader, train_sampler = build_dataloader(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False, workers=args.workers,
    logger=logger,
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
)
test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        batch_size=args.batch_size,
        dist=False, workers=args.workers, logger=logger, training=False
    )

2025-06-06 11:04:34,145   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl


2025-06-06 11:04:39,030   INFO  Total scenes before filters: 243401


2025-06-06 11:04:44,976   INFO  Total scenes after filter_info_by_object_type: 243401


2025-06-06 11:04:44,991   INFO  Total scenes after filters: 243401


2025-06-06 11:04:44,993   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_val_infos.pkl


2025-06-06 11:04:46,866   INFO  Total scenes before filters: 22089


2025-06-06 11:04:47,492   INFO  Total scenes after filter_info_by_object_type: 22089


2025-06-06 11:04:47,494   INFO  Total scenes after filters: 22089


In [6]:

# Initialize model
model = TrajectoryLSTM(input_dim=29, hidden_dim=256, num_modes=6)

# Train the model
trained_model = train_model(model, train_loader, test_loader)

Epoch 0, Batch 0, Loss: 1799.0417


Epoch 0, Batch 100, Loss: 164.5122


Epoch 0, Batch 200, Loss: 138.6299


Epoch 0, Batch 300, Loss: 92.2056


Epoch 0, Batch 400, Loss: 63.5781


Epoch 0, Batch 500, Loss: 39.8303


Epoch 0, Batch 600, Loss: 47.2432


Epoch 0, Batch 700, Loss: 40.4891


Epoch 0, Batch 800, Loss: 45.0766


Epoch 0, Batch 900, Loss: 40.1546


Epoch 0, Batch 1000, Loss: 38.8022


Epoch 0, Batch 1100, Loss: 50.8402


Epoch 0, Batch 1200, Loss: 33.0714


Epoch 0, Batch 1300, Loss: 32.9971


Epoch 0, Batch 1400, Loss: 42.0355


Epoch 0, Batch 1500, Loss: 33.4903


Epoch 0, Batch 1600, Loss: 44.3411


Epoch 0, Batch 1700, Loss: 39.1717


Epoch 0, Batch 1800, Loss: 45.1885


Epoch 0, Batch 1900, Loss: 43.2191


Epoch 0, Batch 2000, Loss: 35.3747


Epoch 0, Batch 2100, Loss: 30.8936


Epoch 0, Batch 2200, Loss: 33.0085


Epoch 0, Batch 2300, Loss: 35.9082


Epoch 0, Batch 2400, Loss: 31.8212


Epoch 0, Batch 2500, Loss: 39.3983


Epoch 0, Batch 2600, Loss: 31.1740


Epoch 0, Batch 2700, Loss: 33.5954


Epoch 0, Batch 2800, Loss: 29.2644


Epoch 0, Batch 2900, Loss: 55.5808


Epoch 0, Batch 3000, Loss: 34.4237


Epoch 0, Batch 3100, Loss: 36.8041


Epoch 0, Batch 3200, Loss: 24.9021


Epoch 0, Batch 3300, Loss: 34.4186


Epoch 0, Batch 3400, Loss: 30.0354


Epoch 0, Batch 3500, Loss: 28.4190


Epoch 0, Batch 3600, Loss: 26.5481


Epoch 0, Batch 3700, Loss: 18.9490


Epoch 0, Batch 3800, Loss: 30.5802


Epoch 0, Batch 3900, Loss: 40.5250


Epoch 0, Batch 4000, Loss: 27.9064


Epoch 0, Batch 4100, Loss: 52.1555


Epoch 0, Batch 4200, Loss: 56.8571


Epoch 0, Batch 4300, Loss: 23.4741


Epoch 0, Batch 4400, Loss: 24.4867


Epoch 0, Batch 4500, Loss: 26.2242


Epoch 0, Batch 4600, Loss: 34.5945


Epoch 0, Batch 4700, Loss: 26.7738


Epoch 0, Batch 4800, Loss: 21.4997


Epoch 0, Batch 4900, Loss: 27.7283


Epoch 0, Batch 5000, Loss: 21.0283


Epoch 0, Batch 5100, Loss: 32.2981


Epoch 0, Batch 5200, Loss: 26.1633


Epoch 0, Batch 5300, Loss: 25.8782


Epoch 0, Batch 5400, Loss: 20.6735


Epoch 0, Batch 5500, Loss: 24.5748


Epoch 0, Batch 5600, Loss: 27.9224


Epoch 0, Batch 5700, Loss: 26.7416


Epoch 0, Batch 5800, Loss: 22.2228


Epoch 0, Batch 5900, Loss: 32.3010


Epoch 0, Batch 6000, Loss: 26.7145


Epoch 0, Batch 6100, Loss: 19.5875


Epoch 0, Batch 6200, Loss: 26.6679


Epoch 0, Batch 6300, Loss: 20.6131


Epoch 0, Batch 6400, Loss: 24.1470


Epoch 0, Batch 6500, Loss: 43.0464


Epoch 0, Batch 6600, Loss: 23.6138


Epoch 0, Batch 6700, Loss: 33.2187


Epoch 0, Batch 6800, Loss: 31.4542


Epoch 0, Batch 6900, Loss: 48.9673


Epoch 0, Batch 7000, Loss: 18.6418


Epoch 0, Batch 7100, Loss: 19.9253


Epoch 0, Batch 7200, Loss: 20.0101


Epoch 0, Batch 7300, Loss: 24.2168


Epoch 0, Batch 7400, Loss: 24.3222


Epoch 0, Batch 7500, Loss: 23.3536


Epoch 0, Batch 7600, Loss: 15.1390


  return torch._native_multi_head_attention(


Epoch 0: Train Loss: 40.2791, Val Loss: 21.1765


Epoch 1, Batch 0, Loss: 17.1732


Epoch 1, Batch 100, Loss: 21.8002


Epoch 1, Batch 200, Loss: 24.5861


Epoch 1, Batch 300, Loss: 23.8503


Epoch 1, Batch 400, Loss: 20.5135


Epoch 1, Batch 500, Loss: 26.7370


Epoch 1, Batch 600, Loss: 22.5683


Epoch 1, Batch 700, Loss: 18.0010


Epoch 1, Batch 800, Loss: 23.5256


Epoch 1, Batch 900, Loss: 16.9863


Epoch 1, Batch 1000, Loss: 19.5289


Epoch 1, Batch 1100, Loss: 26.8444


Epoch 1, Batch 1200, Loss: 19.9961


Epoch 1, Batch 1300, Loss: 22.4784


Epoch 1, Batch 1400, Loss: 16.8311


Epoch 1, Batch 1500, Loss: 19.0872


Epoch 1, Batch 1600, Loss: 14.7055


Epoch 1, Batch 1700, Loss: 21.1102


Epoch 1, Batch 1800, Loss: 21.1353


Epoch 1, Batch 1900, Loss: 23.1818


Epoch 1, Batch 2000, Loss: 19.8985


Epoch 1, Batch 2100, Loss: 21.7862


Epoch 1, Batch 2200, Loss: 18.3147


Epoch 1, Batch 2300, Loss: 15.7734


Epoch 1, Batch 2400, Loss: 22.0062


Epoch 1, Batch 2500, Loss: 17.3161


Epoch 1, Batch 2600, Loss: 16.4670


Epoch 1, Batch 2700, Loss: 16.7988


Epoch 1, Batch 2800, Loss: 16.7807


Epoch 1, Batch 2900, Loss: 17.7512


Epoch 1, Batch 3000, Loss: 22.3167


Epoch 1, Batch 3100, Loss: 20.8188


Epoch 1, Batch 3200, Loss: 20.8859


Epoch 1, Batch 3300, Loss: 27.3661


Epoch 1, Batch 3400, Loss: 21.3691


Epoch 1, Batch 3500, Loss: 18.6868


Epoch 1, Batch 3600, Loss: 23.1170


Epoch 1, Batch 3700, Loss: 16.9425


Epoch 1, Batch 3800, Loss: 14.3871


Epoch 1, Batch 3900, Loss: 17.1178


Epoch 1, Batch 4000, Loss: 21.5461


Epoch 1, Batch 4100, Loss: 23.1548


Epoch 1, Batch 4200, Loss: 41.5235


Epoch 1, Batch 4300, Loss: 18.6489


Epoch 1, Batch 4400, Loss: 16.7808


Epoch 1, Batch 4500, Loss: 18.7303


Epoch 1, Batch 4600, Loss: 16.4518


Epoch 1, Batch 4700, Loss: 15.8878


Epoch 1, Batch 4800, Loss: 24.2032


Epoch 1, Batch 4900, Loss: 14.1714


Epoch 1, Batch 5000, Loss: 16.1331


Epoch 1, Batch 5100, Loss: 22.5586


Epoch 1, Batch 5200, Loss: 13.4151


Epoch 1, Batch 5300, Loss: 17.3466


Epoch 1, Batch 5400, Loss: 20.1315


Epoch 1, Batch 5500, Loss: 19.0750


Epoch 1, Batch 5600, Loss: 15.6237


Epoch 1, Batch 5700, Loss: 21.8624


Epoch 1, Batch 5800, Loss: 14.1650


Epoch 1, Batch 5900, Loss: 30.3432


Epoch 1, Batch 6000, Loss: 13.9581


Epoch 1, Batch 6100, Loss: 17.5412


Epoch 1, Batch 6200, Loss: 22.5188


Epoch 1, Batch 6300, Loss: 26.3464


Epoch 1, Batch 6400, Loss: 17.1010


Epoch 1, Batch 6500, Loss: 28.7525


Epoch 1, Batch 6600, Loss: 17.4764


Epoch 1, Batch 6700, Loss: 13.4077


Epoch 1, Batch 6800, Loss: 15.5014


Epoch 1, Batch 6900, Loss: 16.6863


Epoch 1, Batch 7000, Loss: 16.9892


Epoch 1, Batch 7100, Loss: 13.5321


Epoch 1, Batch 7200, Loss: 15.5099


Epoch 1, Batch 7300, Loss: 19.5060


Epoch 1, Batch 7400, Loss: 15.5160


Epoch 1, Batch 7500, Loss: 34.2877


Epoch 1, Batch 7600, Loss: 15.0499


Epoch 1: Train Loss: 20.4784, Val Loss: 17.5051


Epoch 2, Batch 0, Loss: 15.4753


Epoch 2, Batch 100, Loss: 14.1039


Epoch 2, Batch 200, Loss: 16.3207


Epoch 2, Batch 300, Loss: 17.6846


Epoch 2, Batch 400, Loss: 18.3656


Epoch 2, Batch 500, Loss: 16.2046


Epoch 2, Batch 600, Loss: 34.2905


Epoch 2, Batch 700, Loss: 19.7415


Epoch 2, Batch 800, Loss: 15.0691


Epoch 2, Batch 900, Loss: 12.8788


Epoch 2, Batch 1000, Loss: 14.7682


Epoch 2, Batch 1100, Loss: 16.4245


Epoch 2, Batch 1200, Loss: 14.8838


Epoch 2, Batch 1300, Loss: 13.3493


Epoch 2, Batch 1400, Loss: 14.5781


Epoch 2, Batch 1500, Loss: 13.1205


Epoch 2, Batch 1600, Loss: 22.3030


Epoch 2, Batch 1700, Loss: 17.7059


Epoch 2, Batch 1800, Loss: 15.0538


Epoch 2, Batch 1900, Loss: 17.1693


Epoch 2, Batch 2000, Loss: 13.9185


Epoch 2, Batch 2100, Loss: 20.2833


Epoch 2, Batch 2200, Loss: 14.8597


Epoch 2, Batch 2300, Loss: 16.7934


Epoch 2, Batch 2400, Loss: 18.2269


Epoch 2, Batch 2500, Loss: 27.1190


Epoch 2, Batch 2600, Loss: 15.4038


Epoch 2, Batch 2700, Loss: 14.5219


Epoch 2, Batch 2800, Loss: 23.4234


Epoch 2, Batch 2900, Loss: 15.0773


Epoch 2, Batch 3000, Loss: 17.5386


Epoch 2, Batch 3100, Loss: 16.3736


Epoch 2, Batch 3200, Loss: 13.4703


Epoch 2, Batch 3300, Loss: 20.9156


Epoch 2, Batch 3400, Loss: 14.7607


Epoch 2, Batch 3500, Loss: 12.5646


Epoch 2, Batch 3600, Loss: 27.0691


Epoch 2, Batch 3700, Loss: 13.4031


Epoch 2, Batch 3800, Loss: 25.1575


Epoch 2, Batch 3900, Loss: 25.4373


Epoch 2, Batch 4000, Loss: 17.7567


Epoch 2, Batch 4100, Loss: 12.9538


Epoch 2, Batch 4200, Loss: 16.0089


Epoch 2, Batch 4300, Loss: 16.5772


Epoch 2, Batch 4400, Loss: 14.0587


Epoch 2, Batch 4500, Loss: 32.7517


Epoch 2, Batch 4600, Loss: 13.2709


Epoch 2, Batch 4700, Loss: 14.3588


Epoch 2, Batch 4800, Loss: 16.5189


Epoch 2, Batch 4900, Loss: 14.4474


Epoch 2, Batch 5000, Loss: 19.2570


Epoch 2, Batch 5100, Loss: 18.0766


Epoch 2, Batch 5200, Loss: 13.6793


Epoch 2, Batch 5300, Loss: 12.4590


Epoch 2, Batch 5400, Loss: 16.4694


Epoch 2, Batch 5500, Loss: 10.8353


Epoch 2, Batch 5600, Loss: 14.8191


Epoch 2, Batch 5700, Loss: 17.5382


Epoch 2, Batch 5800, Loss: 15.1715


Epoch 2, Batch 5900, Loss: 13.0190


Epoch 2, Batch 6000, Loss: 14.3031


Epoch 2, Batch 6100, Loss: 15.9221


Epoch 2, Batch 6200, Loss: 11.3108


Epoch 2, Batch 6300, Loss: 13.6919


Epoch 2, Batch 6400, Loss: 45.6488


Epoch 2, Batch 6500, Loss: 12.7503


Epoch 2, Batch 6600, Loss: 14.9779


Epoch 2, Batch 6700, Loss: 12.8438


Epoch 2, Batch 6800, Loss: 13.5827


Epoch 2, Batch 6900, Loss: 16.2012


Epoch 2, Batch 7000, Loss: 18.0965


Epoch 2, Batch 7100, Loss: 12.6521


Epoch 2, Batch 7200, Loss: 16.7924


Epoch 2, Batch 7300, Loss: 21.0279


Epoch 2, Batch 7400, Loss: 17.5599


Epoch 2, Batch 7500, Loss: 16.4047


Epoch 2, Batch 7600, Loss: 13.1378


Epoch 2: Train Loss: 17.8612, Val Loss: 16.8254


Epoch 3, Batch 0, Loss: 17.2613


Epoch 3, Batch 100, Loss: 13.8637


Epoch 3, Batch 200, Loss: 41.9596


Epoch 3, Batch 300, Loss: 17.1019


Epoch 3, Batch 400, Loss: 22.9104


Epoch 3, Batch 500, Loss: 15.4420


Epoch 3, Batch 600, Loss: 18.0288


Epoch 3, Batch 700, Loss: 13.2786


Epoch 3, Batch 800, Loss: 17.4502


Epoch 3, Batch 900, Loss: 14.8604


Epoch 3, Batch 1000, Loss: 14.8848


Epoch 3, Batch 1100, Loss: 12.6338


Epoch 3, Batch 1200, Loss: 14.3903


Epoch 3, Batch 1300, Loss: 13.0701


Epoch 3, Batch 1400, Loss: 11.0658


Epoch 3, Batch 1500, Loss: 11.8913


Epoch 3, Batch 1600, Loss: 15.6665


Epoch 3, Batch 1700, Loss: 14.9468


Epoch 3, Batch 1800, Loss: 14.5236


Epoch 3, Batch 1900, Loss: 14.2184


Epoch 3, Batch 2000, Loss: 13.3806


Epoch 3, Batch 2100, Loss: 13.1135


Epoch 3, Batch 2200, Loss: 19.3192


Epoch 3, Batch 2300, Loss: 16.7609


Epoch 3, Batch 2400, Loss: 13.7104


Epoch 3, Batch 2500, Loss: 18.4701


Epoch 3, Batch 2600, Loss: 15.7973


Epoch 3, Batch 2700, Loss: 12.8249


Epoch 3, Batch 2800, Loss: 16.7564


Epoch 3, Batch 2900, Loss: 14.9889


Epoch 3, Batch 3000, Loss: 16.6362


Epoch 3, Batch 3100, Loss: 14.2015


Epoch 3, Batch 3200, Loss: 18.5524


Epoch 3, Batch 3300, Loss: 37.0325


Epoch 3, Batch 3400, Loss: 13.4351


Epoch 3, Batch 3500, Loss: 15.2981


Epoch 3, Batch 3600, Loss: 17.6919


Epoch 3, Batch 3700, Loss: 14.2128


Epoch 3, Batch 3800, Loss: 11.3127


Epoch 3, Batch 3900, Loss: 13.3254


Epoch 3, Batch 4000, Loss: 17.0329


Epoch 3, Batch 4100, Loss: 14.5835


Epoch 3, Batch 4200, Loss: 13.4548


Epoch 3, Batch 4300, Loss: 11.6700


Epoch 3, Batch 4400, Loss: 14.6481


Epoch 3, Batch 4500, Loss: 14.5698


Epoch 3, Batch 4600, Loss: 14.6056


Epoch 3, Batch 4700, Loss: 19.4614


Epoch 3, Batch 4800, Loss: 12.0908


Epoch 3, Batch 4900, Loss: 17.2701


Epoch 3, Batch 5000, Loss: 12.0242


Epoch 3, Batch 5100, Loss: 13.3648


Epoch 3, Batch 5200, Loss: 13.2970


Epoch 3, Batch 5300, Loss: 12.7031


Epoch 3, Batch 5400, Loss: 11.6245


Epoch 3, Batch 5500, Loss: 14.0384


Epoch 3, Batch 5600, Loss: 12.8646


Epoch 3, Batch 5700, Loss: 11.9054


Epoch 3, Batch 5800, Loss: 14.5173


Epoch 3, Batch 5900, Loss: 15.7527


Epoch 3, Batch 6000, Loss: 17.7814


Epoch 3, Batch 6100, Loss: 12.3482


Epoch 3, Batch 6200, Loss: 12.8048


Epoch 3, Batch 6300, Loss: 13.2478


Epoch 3, Batch 6400, Loss: 13.9211


Epoch 3, Batch 6500, Loss: 13.1343


Epoch 3, Batch 6600, Loss: 11.1831


Epoch 3, Batch 6700, Loss: 15.8975


Epoch 3, Batch 6800, Loss: 12.7651


Epoch 3, Batch 6900, Loss: 12.3752


Epoch 3, Batch 7000, Loss: 12.1223


Epoch 3, Batch 7100, Loss: 14.9092


Epoch 3, Batch 7200, Loss: 13.3436


Epoch 3, Batch 7300, Loss: 15.8404


Epoch 3, Batch 7400, Loss: 11.8205


Epoch 3, Batch 7500, Loss: 13.3713


Epoch 3, Batch 7600, Loss: 15.6929


Epoch 3: Train Loss: 16.8377, Val Loss: 15.2892


Epoch 4, Batch 0, Loss: 15.7331


Epoch 4, Batch 100, Loss: 12.0217


Epoch 4, Batch 200, Loss: 17.6510


Epoch 4, Batch 300, Loss: 31.3657


Epoch 4, Batch 400, Loss: 12.8446


Epoch 4, Batch 500, Loss: 11.6847


Epoch 4, Batch 600, Loss: 14.6237


Epoch 4, Batch 700, Loss: 15.3939


Epoch 4, Batch 800, Loss: 12.8523


Epoch 4, Batch 900, Loss: 13.2384


Epoch 4, Batch 1000, Loss: 24.2960


Epoch 4, Batch 1100, Loss: 13.0594


Epoch 4, Batch 1200, Loss: 11.8682


Epoch 4, Batch 1300, Loss: 16.5752


Epoch 4, Batch 1400, Loss: 14.2798


Epoch 4, Batch 1500, Loss: 11.3614


Epoch 4, Batch 1600, Loss: 12.5872


Epoch 4, Batch 1700, Loss: 13.4081


Epoch 4, Batch 1800, Loss: 14.6948


Epoch 4, Batch 1900, Loss: 13.3989


Epoch 4, Batch 2000, Loss: 14.9995


Epoch 4, Batch 2100, Loss: 16.8604


Epoch 4, Batch 2200, Loss: 17.8864


Epoch 4, Batch 2300, Loss: 14.9007


Epoch 4, Batch 2400, Loss: 17.6734


Epoch 4, Batch 2500, Loss: 13.9022


Epoch 4, Batch 2600, Loss: 14.3293


Epoch 4, Batch 2700, Loss: 14.8316


Epoch 4, Batch 2800, Loss: 15.4855


Epoch 4, Batch 2900, Loss: 13.2258


Epoch 4, Batch 3000, Loss: 13.8927


Epoch 4, Batch 3100, Loss: 14.7005


Epoch 4, Batch 3200, Loss: 11.8047


Epoch 4, Batch 3300, Loss: 12.3235


Epoch 4, Batch 3400, Loss: 24.5513


Epoch 4, Batch 3500, Loss: 12.0089


Epoch 4, Batch 3600, Loss: 16.6032


Epoch 4, Batch 3700, Loss: 15.2886


Epoch 4, Batch 3800, Loss: 11.1569


Epoch 4, Batch 3900, Loss: 14.8612


Epoch 4, Batch 4000, Loss: 18.1743


Epoch 4, Batch 4100, Loss: 35.2616


Epoch 4, Batch 4200, Loss: 19.0604


Epoch 4, Batch 4300, Loss: 14.6015


Epoch 4, Batch 4400, Loss: 18.6071


Epoch 4, Batch 4500, Loss: 16.7397


Epoch 4, Batch 4600, Loss: 13.6383


Epoch 4, Batch 4700, Loss: 13.4227


Epoch 4, Batch 4800, Loss: 12.9350


Epoch 4, Batch 4900, Loss: 14.5204


Epoch 4, Batch 5000, Loss: 11.9586


Epoch 4, Batch 5100, Loss: 14.3951


Epoch 4, Batch 5200, Loss: 16.5009


Epoch 4, Batch 5300, Loss: 12.5217


Epoch 4, Batch 5400, Loss: 21.7938


Epoch 4, Batch 5500, Loss: 37.3140


Epoch 4, Batch 5600, Loss: 15.0349


Epoch 4, Batch 5700, Loss: 13.5817


Epoch 4, Batch 5800, Loss: 12.3202


Epoch 4, Batch 5900, Loss: 11.1569


Epoch 4, Batch 6000, Loss: 12.5821


Epoch 4, Batch 6100, Loss: 12.4626


Epoch 4, Batch 6200, Loss: 14.2471


Epoch 4, Batch 6300, Loss: 14.9636


Epoch 4, Batch 6400, Loss: 11.5480


Epoch 4, Batch 6500, Loss: 11.6109
