In [10]:
import torch

from copy import deepcopy
from matplotlib import pyplot as plt

In [5]:
def train(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is switched to training mode. Then, for every ``(images, masks)``
    batch, both tensors are moved to ``device``, the loss
    ``criterion(model(images), masks)`` is back-propagated and ``optimizer``
    takes a step.

    Note: ``device``, ``optimizer`` and ``criterion`` are not parameters. They
    are looked up in the notebook's global namespace, so they must be defined
    before this function is called.

    Args:
        model: the network to train; its weights are updated in place.
        loader: an iterable of ``(images, masks)`` batches that also supports
            ``len()`` (presumably a ``torch.utils.data.DataLoader``).

    Returns:
        float: the sum of per-batch ``loss.item()`` values divided by
        ``len(loader)``. Raises ``ZeroDivisionError`` if ``len(loader) == 0``.
    """
    model.train()
    loss_total = 0.

    for batch_images, batch_masks in loader:
        inputs = batch_images.to(device)
        targets = batch_masks.to(device)

        # Clear gradients left over from the previous batch before backprop.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # .item() detaches to a Python float so the graph is not retained.
        loss_total += batch_loss.item()

    return loss_total / len(loader)

In [7]:
def evaluate(model, loader):
    """Return the mean batch loss of ``model`` over ``loader`` without training it.

    The model is switched to eval mode, and the whole pass runs under
    ``torch.no_grad()`` so no autograd graph is built and no gradients are
    written.

    Note: ``device`` and ``criterion`` are read from the notebook's global
    namespace; they must be defined before this function is called.

    Args:
        model: the network to evaluate.
        loader: an iterable of ``(images, masks)`` batches that also supports
            ``len()`` (presumably a ``torch.utils.data.DataLoader``).

    Returns:
        float: the sum of per-batch ``loss.item()`` values divided by
        ``len(loader)``. Raises ``ZeroDivisionError`` if ``len(loader) == 0``.
    """
    model.eval()
    loss_total = 0.

    with torch.no_grad():
        for batch_images, batch_masks in loader:
            inputs, targets = batch_images.to(device), batch_masks.to(device)
            predictions = model(inputs)
            loss_total += criterion(predictions, targets).item()

    return loss_total / len(loader)

In [3]:
def train_loop_with_validation(model, num_epoch, train_loader, val_loader):
    """Train ``model`` for ``num_epoch`` epochs, validating after each epoch.

    After every epoch the model is scored on ``val_loader``. Whenever the
    validation loss beats the best seen so far, a deep copy of the model is
    kept as the best checkpoint. Per-epoch losses are printed and, at the end,
    both curves are plotted against the 1-based epoch number.

    Relies on the notebook-level ``train`` and ``evaluate`` helpers.

    Args:
        model: the network to train; it is updated in place.
        num_epoch: number of epochs to run.
        train_loader: batches used for optimisation.
        val_loader: batches used for validation.

    Returns:
        tuple: ``(best_model, best_loss)``. ``best_model`` is a deep copy of
        ``model`` taken at the epoch with the lowest validation loss, and
        ``best_loss`` is that loss. If no epoch runs, the result is
        ``(None, float('inf'))``.
    """
    best_model, best_loss = None, float('inf')
    train_losses, val_losses = [], []

    for epoch in range(num_epoch):
        print(f"Epoch {epoch + 1}/{num_epoch} started...")

        train_loss = train(model, train_loader)
        validation_loss = evaluate(model, val_loader)

        # Snapshot (not just reference) the model: training continues in place,
        # so a plain reference would not preserve the best epoch's weights.
        if validation_loss < best_loss:
            best_loss = validation_loss
            best_model = deepcopy(model)

        train_losses.append(train_loss)
        val_losses.append(validation_loss)

        print(f'Training loss: {train_loss}')
        print(f'Validation loss: {validation_loss}')
        print('...........................................')

    # Plot against 1-based epochs so the x-axis matches the printed log.
    epochs = range(1, num_epoch + 1)
    _, ax = plt.subplots()
    ax.plot(epochs, train_losses, label='Train loss')
    ax.plot(epochs, val_losses, label='Validation loss')
    ax.set(xlabel='Epoch', ylabel='Loss', title='Training and validation loss')
    ax.legend()
    plt.show()

    return best_model, best_loss

In [4]:
def train_loop(model, num_epoch, loader):
    """Train ``model`` for ``num_epoch`` epochs on ``loader`` without validation.

    With no held-out set, the "best" checkpoint is the epoch with the lowest
    mean training loss reported by ``train``. Per-epoch losses are printed and,
    at the end, plotted against the 1-based epoch number.

    Relies on the notebook-level ``train`` helper.

    Args:
        model: the network to train; it is updated in place.
        num_epoch: number of epochs to run.
        loader: batches used for optimisation.

    Returns:
        tuple: ``(best_model, best_loss)``. ``best_model`` is a deep copy of
        ``model`` taken at the epoch with the lowest training loss, and
        ``best_loss`` is that loss. If no epoch runs, the result is
        ``(None, float('inf'))``.
    """
    best_model, best_loss = None, float('inf')
    train_losses = []

    for epoch in range(num_epoch):
        print(f"Epoch {epoch + 1}/{num_epoch} started...")

        # Train on the loader passed in. The original read an undefined or
        # notebook-global `train_loader` here.
        train_loss = train(model, loader)

        # Snapshot (not just reference) the model: training continues in place.
        if train_loss < best_loss:
            best_loss = train_loss
            best_model = deepcopy(model)

        train_losses.append(train_loss)

        print(f'Training loss: {train_loss}')
        print('...........................................')

    # Plot against 1-based epochs so the x-axis matches the printed log.
    epochs = range(1, num_epoch + 1)
    _, ax = plt.subplots()
    ax.plot(epochs, train_losses, label='Train loss')
    ax.set(xlabel='Epoch', ylabel='Loss', title='Training loss')
    ax.legend()
    plt.show()

    return best_model, best_loss