In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
class Trainer:
    """Train, validate and test a depth-estimation model.

    Every batch is run through ``processor.preprocess`` -> ``model`` ->
    ``processor.postprocess`` and the result is compared with the ground-truth
    depths using ``criterion``. Per-epoch losses are stored in
    ``train_losses`` / ``val_losses`` and plotted live during :meth:`train`.
    """

    def __init__(self, model, processor, data_loaders, criterion, optimizer, device=None):
        """
        Initialize the Trainer.

        Args:
            model: PyTorch model to train; it is moved to ``device``.
            processor: Object exposing ``preprocess(imgs)`` (whose result must
                support ``.to(device)``) and ``postprocess(inputs, outputs)``
                (whose result is passed to ``criterion`` with the depths).
            data_loaders: Tuple ``(train_dl, val_dl, test_dl)``. Each must be
                iterable, yield ``(imgs, depths)`` pairs and support ``len``.
            criterion: Loss function called as ``criterion(preds, depths)``.
            optimizer: Optimizer for the model.
            device: Device for computation (default: CUDA if available, else CPU).
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.processor = processor
        self.train_dl, self.val_dl, self.test_dl = data_loaders
        self.criterion = criterion
        self.optimizer = optimizer

        # One entry per epoch, appended by train().
        self.train_losses = []
        self.val_losses = []

        # Live learning-curve figure/axes; created by train().
        self.figure = None
        self.ax = None

    def _predict(self, imgs):
        """Run raw images through preprocess -> model -> postprocess."""
        inputs = self.processor.preprocess(imgs).to(self.device)
        outputs = self.model(inputs)
        # postprocess receives the model inputs too (e.g. to resize outputs).
        return self.processor.postprocess(inputs, outputs)

    def train_one_epoch(self):
        """Train the model for one epoch.

        Returns:
            float: mean of the per-batch training losses.
        """
        self.model.train()
        running_loss = 0.0

        for imgs, depths in tqdm(self.train_dl, desc="Training"):
            depths = depths.to(self.device)

            self.optimizer.zero_grad()
            preds = self._predict(imgs)

            loss = self.criterion(preds, depths)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        # Averaged per batch, not per sample: a smaller last batch weighs the same.
        return running_loss / len(self.train_dl)

    def evaluate(self, dataloader):
        """Evaluate the model on ``dataloader`` without updating weights.

        Returns:
            float: mean of the per-batch losses.
        """
        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for imgs, depths in dataloader:
                depths = depths.to(self.device)
                preds = self._predict(imgs)

                loss = self.criterion(preds, depths)
                val_loss += loss.item()

        return val_loss / len(dataloader)

    def abs_rel_difference(self, preds, depths):
        """Calculate Absolute Relative Difference (AbsRel).

        AbsRel = mean(|pred - gt| / gt), taken only over elements whose
        ground-truth depth is strictly positive. Elements with ``depth <= 0``
        would otherwise contribute ``inf``/``nan`` and poison the mean.

        Returns:
            float: the AbsRel value, or ``nan`` if no element has a positive depth.
        """
        ratio = torch.abs(preds - depths) / depths
        # Broadcast the mask to the (possibly broadcast) shape of ``ratio``.
        valid = (depths > 0).expand_as(ratio)
        if not valid.any():
            return float("nan")
        return ratio[valid].mean().item()

    def train(self, epochs):
        """Train for ``epochs`` epochs, validating and live-plotting after each."""
        plt.ion()
        self.figure, self.ax = plt.subplots(figsize=(10, 6))

        for epoch in range(epochs):
            train_loss = self.train_one_epoch()
            val_loss = self.evaluate(self.val_dl)

            # Record each epoch exactly once so the curve's x-axis counts epochs.
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.update_learning_curve(epoch)

            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    def test(self):
        """Compute, print and return the mean AbsRel over the test set.

        The value is averaged per batch. A batch with no positive depths
        yields ``nan`` and makes the result ``nan``.
        """
        self.model.eval()
        total = 0.0

        with torch.no_grad():
            for imgs, depths in self.test_dl:
                depths = depths.to(self.device)
                preds = self._predict(imgs)
                total += self.abs_rel_difference(preds, depths)

        abs_rel = total / len(self.test_dl)
        print(f"Test AbsRel: {abs_rel:.4f}")
        return abs_rel

    def plot_learning_curve(self):
        """Final plot of the learning curve."""
        plt.ioff()
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Learning Curve')
        plt.legend()
        plt.grid(True)
        plt.show()

    def update_learning_curve(self, epoch):
        """Redraw the live learning curve on the axes created by train()."""
        self.ax.clear()
        self.ax.plot(self.train_losses, label='Train Loss')
        self.ax.plot(self.val_losses, label='Validation Loss')
        self.ax.set_xlabel('Epochs')
        self.ax.set_ylabel('Loss')
        self.ax.set_title(f'Learning Curve (Epoch {epoch + 1})')
        self.ax.legend()
        self.ax.grid(True)
        # Short pause lets the interactive backend repaint the figure.
        plt.pause(0.1)

    def visualize(self, num_samples=3):
        """Visualize predictions on the test set.

        NOTE: not implemented yet. It only switches the model to eval mode,
        ignores ``num_samples`` and returns ``None``.
        """
        self.model.eval()
        return


In [None]:
# Training hyper-parameters.
epochs = 10
learning_rate = 1e-3

# Pixel-wise mean squared error between predicted and ground-truth depth,
# optimized with Adam over every parameter of the model.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Wire the model, processor, loaders, loss and optimizer into a Trainer, then train.
trainer = Trainer(
    model=model,
    processor=processor,
    data_loaders=(train_dl, val_dl, test_dl),
    criterion=criterion,
    optimizer=optimizer,
)
trainer.train(epochs=epochs)