# Training Loop Example

This notebook demonstrates how to implement a custom trainer class to simplify model training setup.

In [1]:
# Load the TensorBoard notebook extension, then start TensorBoard reading
# event files from the local `logs` directory on port 6006.
# NOTE(review): --bind_all presumably makes the server reachable from other
# hosts, not just localhost -- confirm that is intended on shared machines.
%load_ext tensorboard
%tensorboard --logdir logs --port 6006 --bind_all

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import time
import numpy as np
import datetime

# --- Setup (Hyperparameters, Model, Data, etc.) ---

# Hyperparameters
num_features = 10      # width of each input vector fed to the network
num_classes = 5        # number of target classes / output scores
batch_size = 64        # samples per DataLoader batch
learning_rate = 1e-3   # step size handed to the optimizer
# NOTE(review): the later `trainer.train(30)` call passes its own epoch count,
# so this value is currently not what drives the run -- confirm which is intended.
epochs = 10

# Define a simple neural network
class SimpleNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_features: Size of each input sample. When ``None`` (default) the
            notebook-level ``num_features`` is used.
        hidden_size: Width of the hidden layer (128 by default).
        out_features: Number of scores produced per sample. When ``None``
            (default) the notebook-level ``num_classes`` is used.
    """

    def __init__(self, in_features=None, hidden_size=128, out_features=None):
        super().__init__()
        # Resolve the fallbacks at call time so that `SimpleNet()` keeps
        # behaving exactly as before, while explicit sizes need no globals.
        if in_features is None:
            in_features = num_features
        if out_features is None:
            out_features = num_classes

        # Attribute names are kept so existing state_dict keys stay valid.
        self.layer1 = nn.Linear(in_features, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """Return the output of ``layer2`` (no softmax is applied) for batch ``x``."""
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)

# Seed PyTorch's global RNG first so the dummy tensors, the initial model
# weights and the DataLoader shuffle order are reproducible across
# Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Create dummy data and DataLoaders
# In a real scenario, you would load your actual dataset here
X_train = torch.randn(1000, num_features)
y_train = torch.randint(0, num_classes, (1000,))
X_val = torch.randn(200, num_features)
y_val = torch.randint(0, num_classes, (200,))

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

# Instantiate model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

2025-10-13 03:39:34.710479: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760326774.722870   14534 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760326774.726738   14534 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-13 03:39:34.738700: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class Trainer:
    """Small training/validation loop for a classifier with TensorBoard logging.

    The trainer works exclusively with the objects passed to the constructor;
    it does not read notebook-level globals. The model is used as given (it is
    not moved to ``device`` here), only the batches are.

    Args:
        model: Module called on a batch of inputs; its output is passed to
            ``criterion`` and arg-maxed over dim 1 to obtain predictions.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning a loss tensor.
        train_dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        device: Device every batch is moved to before the forward pass.
        val_dataloader: Optional iterable yielding ``(inputs, labels)``
            batches. Validation is skipped when it is ``None`` or empty.
        log_dir: Directory passed to ``SummaryWriter``. When ``None``, a
            timestamped directory under ``logs/fit/`` is used.
    """

    def __init__(
        self, model, optimizer, criterion, train_dataloader, device, val_dataloader=None, log_dir=None
    ):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader
        self._device = device
        self._log_dir = log_dir

    def _has_validation(self):
        """Return True when a validation loader was supplied and is not known to be empty."""
        loader = self._val_dataloader
        if loader is None:
            return False
        try:
            return len(loader) > 0
        except TypeError:
            # Loader has no usable length; assume it yields batches.
            return True

    def _train_step(self, epoch, epochs, writer):
        """Run one training epoch and log batch- and epoch-level metrics.

        Args:
            epoch: Zero-based index of the current epoch.
            epochs: Total number of epochs (used for the progress-bar label).
            writer: Object exposing ``add_scalar(tag, value, step)``.
        """
        self._model.train()
        steps_per_epoch = len(self._train_dataloader)
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        num_batches = 0

        train_pbar = tqdm(self._train_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for i, (inputs, labels) in enumerate(train_pbar):
            inputs, labels = inputs.to(self._device), labels.to(self._device)

            # Forward pass
            outputs = self._model(inputs)
            loss = self._criterion(outputs, labels)

            # Backward and optimize, using the optimizer owned by this trainer
            self._optimizer.zero_grad()  # clear previous gradients
            loss.backward()              # compute the gradient of the loss
            self._optimizer.step()       # update the model parameters

            # Read the scalar loss once and reuse it for all bookkeeping below.
            batch_loss = loss.item()
            train_loss += batch_loss
            predicted = outputs.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            num_batches += 1

            # Update tqdm progress bar description
            train_pbar.set_postfix({'loss': batch_loss})

            # Log batch loss to TensorBoard on a step axis that keeps growing across epochs
            global_step = epoch * steps_per_epoch + i
            writer.add_scalar('Loss/train_batch', batch_loss, global_step)

        # max(..., 1) avoids ZeroDivisionError when the loader yields nothing.
        avg_train_loss = train_loss / max(num_batches, 1)
        train_accuracy = 100 * train_correct / max(train_total, 1)
        print(f"Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_accuracy:.2f}%")

        writer.add_scalar('Loss/train_epoch', avg_train_loss, epoch)
        writer.add_scalar('Accuracy/train_epoch', train_accuracy, epoch)

    def _validation_step(self, epoch, epochs, writer):
        """Evaluate the model on the validation loader and log epoch-level metrics.

        Args:
            epoch: Zero-based index of the current epoch.
            epochs: Total number of epochs (used for the progress-bar label).
            writer: Object exposing ``add_scalar(tag, value, step)``.
        """
        self._model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        num_batches = 0

        with torch.no_grad():
            val_pbar = tqdm(self._val_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
            for inputs, labels in val_pbar:
                inputs, labels = inputs.to(self._device), labels.to(self._device)
                outputs = self._model(inputs)
                loss = self._criterion(outputs, labels)

                batch_loss = loss.item()
                val_loss += batch_loss
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                num_batches += 1
                val_pbar.set_postfix({'loss': batch_loss})

        # max(..., 1) avoids ZeroDivisionError when the loader yields nothing.
        avg_val_loss = val_loss / max(num_batches, 1)
        val_accuracy = 100 * val_correct / max(val_total, 1)
        print(f"Validation Loss: {avg_val_loss:.4f} | Val Accuracy: {val_accuracy:.2f}%")

        writer.add_scalar('Loss/validation_epoch', avg_val_loss, epoch)
        writer.add_scalar('Accuracy/validation_epoch', val_accuracy, epoch)

    def train(self, epochs):
        """Main method to start the model training.

        Runs ``epochs`` training epochs, each followed by a validation pass
        when a validation loader is available, and logs timings and the
        first parameter group's learning rate per epoch.

        Args:
            epochs: Number of epochs to run.
        """
        print(f"Starting training on {self._device}...")
        log_dir = self._log_dir
        if log_dir is None:
            # Build the timestamp outside the f-string: reusing the same quote
            # character inside an f-string only parses on Python 3.12+.
            timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
            log_dir = f"logs/fit/{timestamp}"
        writer = SummaryWriter(log_dir)
        print(f"Saving logs to {log_dir}")

        run_validation = self._has_validation()
        train_start_t = time.time()
        try:
            for epoch in range(epochs):
                # --- Training Phase ---
                epoch_start_t = time.time()
                self._train_step(epoch, epochs, writer)
                train_epoch_stop_t = time.time()

                # --- Validation Phase ---
                if run_validation:
                    self._validation_step(epoch, epochs, writer)
                epoch_end_t = time.time()

                # --- End-of-Epoch Logging ---
                train_duration = train_epoch_stop_t - epoch_start_t
                val_duration = epoch_end_t - train_epoch_stop_t
                epoch_duration = epoch_end_t - epoch_start_t
                val_log = f"Validation Time: {val_duration:.2f}s | " if run_validation else ""
                print(
                    f"Epoch {epoch+1}/{epochs} | "
                    f"Train Time: {train_duration:.2f}s | {val_log}"
                    f"Epoch Time: {epoch_duration:.2f}s")
                writer.add_scalar('Learning_Rate', self._optimizer.param_groups[0]['lr'], epoch)

            total_train_duration = time.time() - train_start_t
            print(f"Total Training Time: {str(datetime.timedelta(seconds=total_train_duration))} | {total_train_duration:.2f}s")
            print("Training complete.")
        finally:
            # Close the writer even if an epoch raises.
            writer.close()

In [4]:
# Wire the objects built in the setup cell into a Trainer and run 30 epochs.
# NOTE(review): the epoch count is passed literally here; the module-level
# `epochs` hyperparameter is not used -- confirm which value is intended.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_loader,
    device=device,
    val_dataloader=val_loader,
)
trainer.train(epochs=30)

Starting training on cuda...
Saving logs to logs/fit/20251013-033937


Epoch 1/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 147.56it/s, loss=1.61]


Train Loss: 1.6274 | Train Accuracy: 21.20%


Epoch 1/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 796.45it/s, loss=1.65]


Validation Loss: 1.6196 | Val Accuracy: 20.00%
Epoch 1/30 | Train Time: 0.11s | Validation Time: 0.01s | Epoch Time: 0.12s


Epoch 2/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 640.08it/s, loss=1.6]


Train Loss: 1.6068 | Train Accuracy: 23.10%


Epoch 2/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1150.78it/s, loss=1.62]


Validation Loss: 1.6142 | Val Accuracy: 22.50%
Epoch 2/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 3/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 650.83it/s, loss=1.58]


Train Loss: 1.5972 | Train Accuracy: 23.40%


Epoch 3/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1158.49it/s, loss=1.64]


Validation Loss: 1.6166 | Val Accuracy: 22.00%
Epoch 3/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 4/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 570.34it/s, loss=1.58]


Train Loss: 1.5895 | Train Accuracy: 27.00%


Epoch 4/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 933.57it/s, loss=1.63]


Validation Loss: 1.6192 | Val Accuracy: 22.50%
Epoch 4/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.04s


Epoch 5/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 646.25it/s, loss=1.58]


Train Loss: 1.5842 | Train Accuracy: 27.30%


Epoch 5/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1134.75it/s, loss=1.63]


Validation Loss: 1.6215 | Val Accuracy: 22.00%
Epoch 5/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 6/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 642.20it/s, loss=1.58]


Train Loss: 1.5785 | Train Accuracy: 28.60%


Epoch 6/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1162.90it/s, loss=1.64]


Validation Loss: 1.6256 | Val Accuracy: 19.50%
Epoch 6/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 7/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 657.95it/s, loss=1.56]


Train Loss: 1.5734 | Train Accuracy: 28.00%


Epoch 7/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1164.60it/s, loss=1.63]


Validation Loss: 1.6212 | Val Accuracy: 21.00%
Epoch 7/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 8/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 630.09it/s, loss=1.58]


Train Loss: 1.5704 | Train Accuracy: 28.20%


Epoch 8/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1169.14it/s, loss=1.62]


Validation Loss: 1.6257 | Val Accuracy: 21.00%
Epoch 8/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 9/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 620.96it/s, loss=1.56]


Train Loss: 1.5660 | Train Accuracy: 29.10%


Epoch 9/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1020.14it/s, loss=1.63]


Validation Loss: 1.6279 | Val Accuracy: 21.00%
Epoch 9/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 10/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 564.11it/s, loss=1.54]


Train Loss: 1.5611 | Train Accuracy: 29.10%


Epoch 10/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1141.85it/s, loss=1.63]


Validation Loss: 1.6320 | Val Accuracy: 22.00%
Epoch 10/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.04s


Epoch 11/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 637.17it/s, loss=1.59]


Train Loss: 1.5593 | Train Accuracy: 29.20%


Epoch 11/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1161.94it/s, loss=1.63]


Validation Loss: 1.6310 | Val Accuracy: 21.00%
Epoch 11/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 12/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 660.16it/s, loss=1.57]


Train Loss: 1.5538 | Train Accuracy: 30.10%


Epoch 12/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1149.20it/s, loss=1.63]


Validation Loss: 1.6345 | Val Accuracy: 20.50%
Epoch 12/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 13/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 640.90it/s, loss=1.48]


Train Loss: 1.5488 | Train Accuracy: 31.40%


Epoch 13/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1161.13it/s, loss=1.62]


Validation Loss: 1.6314 | Val Accuracy: 19.00%
Epoch 13/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 14/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 644.62it/s, loss=1.51]


Train Loss: 1.5457 | Train Accuracy: 30.50%


Epoch 14/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1165.49it/s, loss=1.63]


Validation Loss: 1.6321 | Val Accuracy: 20.00%
Epoch 14/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 15/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 661.05it/s, loss=1.53]


Train Loss: 1.5423 | Train Accuracy: 31.00%


Epoch 15/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1147.87it/s, loss=1.62]


Validation Loss: 1.6312 | Val Accuracy: 20.50%
Epoch 15/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 16/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 638.99it/s, loss=1.56]


Train Loss: 1.5386 | Train Accuracy: 31.60%


Epoch 16/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1109.09it/s, loss=1.61]


Validation Loss: 1.6305 | Val Accuracy: 21.50%
Epoch 16/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 17/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 662.30it/s, loss=1.5]


Train Loss: 1.5344 | Train Accuracy: 32.10%


Epoch 17/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1169.63it/s, loss=1.61]


Validation Loss: 1.6330 | Val Accuracy: 20.00%
Epoch 17/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 18/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 635.61it/s, loss=1.58]


Train Loss: 1.5333 | Train Accuracy: 32.30%


Epoch 18/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1164.28it/s, loss=1.61]


Validation Loss: 1.6330 | Val Accuracy: 21.50%
Epoch 18/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 19/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 662.46it/s, loss=1.56]


Train Loss: 1.5284 | Train Accuracy: 32.60%


Epoch 19/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1161.05it/s, loss=1.61]


Validation Loss: 1.6331 | Val Accuracy: 21.50%
Epoch 19/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 20/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 633.54it/s, loss=1.49]


Train Loss: 1.5231 | Train Accuracy: 33.20%


Epoch 20/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1167.76it/s, loss=1.59]


Validation Loss: 1.6295 | Val Accuracy: 18.50%
Epoch 20/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 21/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 665.15it/s, loss=1.56]


Train Loss: 1.5221 | Train Accuracy: 33.60%


Epoch 21/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1112.33it/s, loss=1.6]


Validation Loss: 1.6324 | Val Accuracy: 20.00%
Epoch 21/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 22/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 599.42it/s, loss=1.52]


Train Loss: 1.5165 | Train Accuracy: 33.90%


Epoch 22/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1161.29it/s, loss=1.59]


Validation Loss: 1.6326 | Val Accuracy: 20.00%
Epoch 22/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 23/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 643.53it/s, loss=1.45]


Train Loss: 1.5112 | Train Accuracy: 34.10%


Epoch 23/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1172.00it/s, loss=1.6]


Validation Loss: 1.6327 | Val Accuracy: 22.00%
Epoch 23/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 24/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 632.93it/s, loss=1.5]


Train Loss: 1.5089 | Train Accuracy: 35.10%


Epoch 24/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1157.93it/s, loss=1.59]


Validation Loss: 1.6355 | Val Accuracy: 19.00%
Epoch 24/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 25/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 545.15it/s, loss=1.48]


Train Loss: 1.5057 | Train Accuracy: 35.30%


Epoch 25/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1154.66it/s, loss=1.59]


Validation Loss: 1.6319 | Val Accuracy: 22.00%
Epoch 25/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.04s


Epoch 26/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 649.94it/s, loss=1.53]


Train Loss: 1.5038 | Train Accuracy: 34.70%


Epoch 26/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1158.41it/s, loss=1.58]


Validation Loss: 1.6331 | Val Accuracy: 21.50%
Epoch 26/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s


Epoch 27/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 634.04it/s, loss=1.5]


Train Loss: 1.4991 | Train Accuracy: 36.00%


Epoch 27/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1156.97it/s, loss=1.59]


Validation Loss: 1.6402 | Val Accuracy: 21.00%
Epoch 27/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 28/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 595.12it/s, loss=1.49]


Train Loss: 1.4959 | Train Accuracy: 36.00%


Epoch 28/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1056.57it/s, loss=1.59]


Validation Loss: 1.6348 | Val Accuracy: 20.00%
Epoch 28/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 29/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 653.90it/s, loss=1.47]


Train Loss: 1.4906 | Train Accuracy: 35.60%


Epoch 29/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1139.83it/s, loss=1.59]


Validation Loss: 1.6396 | Val Accuracy: 21.00%
Epoch 29/30 | Train Time: 0.03s | Validation Time: 0.01s | Epoch Time: 0.03s


Epoch 30/30 [Train]: 100%|██████████| 16/16 [00:00<00:00, 580.44it/s, loss=1.41]


Train Loss: 1.4857 | Train Accuracy: 35.30%


Epoch 30/30 [Val]: 100%|██████████| 4/4 [00:00<00:00, 1314.11it/s, loss=1.58]

Validation Loss: 1.6383 | Val Accuracy: 20.00%
Epoch 30/30 | Train Time: 0.03s | Validation Time: 0.00s | Epoch Time: 0.03s
Total Training Time: 0:00:01.045402 | 1.05s
Training complete.



