In [None]:
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

def _plot_metric_curves(position, epochs, train_values, val_values, metric, ylabel):
    """Draw the train (and optional validation) curve of one metric.

    The curves go into subplot ``position`` of a 1x2 grid on the current
    matplotlib figure. When ``val_values`` is ``None`` only the training
    curve is drawn.
    """
    plt.subplot(1, 2, position)
    plt.plot(epochs, train_values, label=f'Train {metric}')
    if val_values is not None:
        plt.plot(epochs, val_values, label=f'Val {metric}')
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.title(f'{metric} vs. Epochs')
    plt.legend()


def train_0(model, train_data_loader, val_data_loader, loss_fn, optimizer, num_epochs, scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs, then plot loss and accuracy curves.

    After every epoch the model is evaluated with ``validate`` on the training
    loader and, if one is given, on the validation loader; the resulting
    metrics are printed. Once training finishes, a two-panel figure with loss
    and accuracy per epoch is shown via ``plt.show()``.

    Args:
        model: Module to train; it is moved to CUDA when available, else CPU.
        train_data_loader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        val_data_loader: Loader passed to ``validate`` after each epoch, or
            ``None`` to skip validation.
        loss_fn: Callable invoked as ``loss_fn(outputs, targets.long())``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        num_epochs: Number of passes over ``train_data_loader``.
        scheduler: Optional learning-rate scheduler. It is stepped after every
            batch with a fractional epoch value, so it must accept an epoch
            argument in ``step()``.

    Returns:
        None. Metrics are printed and plotted rather than returned.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    iters = len(train_data_loader)
    has_val = val_data_loader is not None

    # Per-epoch metric history used for plotting.
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch.
        # NOTE(review): presumably validate() switches the model to eval mode -- confirm.
        model.train()

        # enumerate() wraps tqdm (not the reverse) so the progress bar still
        # knows the loader length.
        for batch_idx, (inputs, targets) in enumerate(tqdm(train_data_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass and loss; targets are cast to int64 for the loss.
            outputs = model(inputs)
            loss = loss_fn(outputs, targets.long())

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            if scheduler is not None:
                # Fractional epoch: integer epoch plus progress through it.
                scheduler.step(epoch + batch_idx / iters)

        # Validate on train set
        train_loss, train_accuracy, train_iou, train_f1, train_precision, train_recall = validate(model, train_data_loader, loss_fn)
        print(f"Epoch {epoch + 1} | Train Loss:   {train_loss:.4f} | Train Accuracy:   {train_accuracy:.4f}% | Train mIOU:   {train_iou:.4f} | Train mF1:   {train_f1:.4f} | Train Precision:   {train_precision:.4f} | Train Recall:   {train_recall:.4f}")

        # Validate on validation set
        if has_val:
            val_loss, val_accuracy, val_iou, val_f1, val_precision, val_recall = validate(model, val_data_loader, loss_fn)
            print(f"Epoch {epoch + 1} | Val Loss:   {val_loss:.4f} | Val Accuracy:   {val_accuracy:.4f}% | Val mIOU:   {val_iou:.4f} | Val mF1:   {val_f1:.4f} | Val Precision:   {val_precision:.4f} | Val Recall:   {val_recall:.4f}")
        else:
            # Keep the history lists aligned with the epoch count.
            val_loss, val_accuracy = None, None

        # Store metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)

    # Plotting: loss on the left panel, accuracy on the right.
    epochs = range(1, num_epochs + 1)
    plt.figure(figsize=(12, 5))
    _plot_metric_curves(1, epochs, train_losses, val_losses if has_val else None, 'Loss', 'Loss')
    _plot_metric_curves(2, epochs, train_accuracies, val_accuracies if has_val else None, 'Accuracy', 'Accuracy (%)')

    plt.tight_layout()
    plt.show()
