# Entrenamiento

In [None]:
import torch

## Entrenamiento sin validación

### Descenso del gradiente manual

In [None]:
def GradientDescent(model, params, data, loss_fn, lr, epochs=50):
    """Fit ``params`` with a hand-written, full-batch gradient descent loop.

    Each iteration reproduces by hand what ``optimizer.zero_grad()`` and
    ``optimizer.step()`` would do for plain gradient descent.

    Args:
        model: callable invoked as ``model(x, *params)``.
        params: tensor updated in place; its ``.grad`` is read after
            ``loss.backward()``.
        data: pair ``(x, y)`` with the inputs and the targets.
        loss_fn: callable invoked as ``loss_fn(output, y)``; ``backward()``
            and ``item()`` are called on its result.
        lr: step size multiplying the gradient.
        epochs: number of update steps to perform.

    Returns:
        None. Progress is printed every 5 iterations, starting with the first.
    """
    for step in range(epochs):

        # Manual optimizer.zero_grad(): gradients accumulate across
        # backward() calls, so the previous ones are cleared first.
        previous_grad = params.grad
        if previous_grad is not None:
            previous_grad.zero_()

        inputs, targets = data
        prediction = model(inputs, *params)
        loss = loss_fn(prediction, targets)
        loss.backward()

        # Manual optimizer.step(): the in-place update is not tracked by autograd.
        with torch.no_grad():
            params.sub_(lr * params.grad)

        if step % 5 != 0:
            continue
        print(f'Epoch: {step+1} - Loss: {loss.item()}')

## Entrenamiento con validación

### Entrenamiento con data directa

In [None]:
def train_model(model, data, optimizer, loss_fn, epochs=50):
    """Train ``model`` on full tensors, reporting the validation loss as it goes.

    Args:
        model: callable invoked as ``model(x)``.
        data: 4-tuple ``(x_train, x_val, y_train, y_val)``.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        loss_fn: callable invoked as ``loss_fn(output, target)``.
        epochs: number of optimisation steps (one per epoch).

    Returns:
        None. Losses are printed for the first 3 epochs and then every
        500th epoch.
    """
    x_train, x_val, y_train, y_val = data

    epoch = 0
    while epoch < epochs:
        epoch += 1

        # Forward pass on the training split (tracked by autograd).
        train_loss = loss_fn(model(x_train), y_train)

        # The validation pass only monitors the model, so no graph is built.
        with torch.no_grad():
            val_loss = loss_fn(model(x_val), y_val)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        should_report = epoch <= 3 or epoch % 500 == 0
        if should_report:
            report = (f'Epoch {epoch}',
                      f'Training loss: {train_loss.item():.4f}',
                      f'Validation loss: {val_loss.item():.4f}')
            print(*report)

### Entrenamiento simple con `DataLoader`

In [None]:
def train_model(net, dataloaders, optimizer, loss_fn, epochs=10, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    """Train and validate a classifier, one pass over each loader per epoch.

    Args:
        net: module called as ``net(x)``; moved to ``device`` and switched
            between train/eval mode with ``net.train(...)``.
        dataloaders: mapping with the keys ``'train'`` and ``'val'``; each
            value is iterated to obtain ``(x, y)`` batches.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        loss_fn: callable invoked as ``loss_fn(output, y)``. The per-batch
            value is weighted by the batch size, which assumes the loss is a
            mean over the batch.
        epochs: number of epochs to run.
        device: device on which the network and the batches are placed.

    Returns:
        None. Per-epoch loss and accuracy are printed for both modes.
    """
    net.to(device)

    for epoch in range(1, epochs+1):
        print(f'Epoch {epoch}/{epochs}.')

        for mode in ('train', 'val'):
            is_training = mode == 'train'
            net.train(is_training)

            accumulated_loss, corrects, seen = 0.0, 0, 0
            with torch.set_grad_enabled(is_training):
                for x, y in dataloaders[mode]:
                    x, y = x.to(device), y.to(device)
                    output = net(x)
                    loss = loss_fn(output, y)
                    if is_training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                    batch_size = x.size(0)
                    accumulated_loss += loss.item() * batch_size
                    label_pred = torch.argmax(output, dim=1)
                    corrects += torch.sum(label_pred == y).item()
                    seen += batch_size

            # Normalise by the samples actually processed: the dataset length
            # is wrong whenever the loader skips samples (drop_last=True,
            # subset samplers) and is undefined for iterable datasets.
            if seen:
                epoch_loss = accumulated_loss / seen
                epoch_accuracy = corrects / seen * 100
            else:
                epoch_loss = epoch_accuracy = float('nan')
            print(f'- {mode:5} | loss: {epoch_loss:.4f} -  accuracy: {epoch_accuracy:.2f}%')

### Clase de entrenamiento

In [None]:
import torch
import torch.nn as nn
from torch.optim import Optimizer, lr_scheduler
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from sklearn import metrics

import time
import os
import copy

class Trainer:
    """Train/validation loop with history, best-state tracking and checkpoints.

    The default configuration (loss, metric, device, ...) is stored in plain
    attributes that can be reassigned after construction. Predictions are
    obtained with ``argmax`` over dimension 1 of the network output.
    """

    # Attributes persisted through their ``state_dict`` instead of being pickled whole.
    _STATEFUL = ('net', 'optimizer', 'scheduler')

    def __init__(self,
                 net: nn.Module,
                 optimizer: Optimizer,
                 scheduler: lr_scheduler._LRScheduler = None):
        """Store the PyTorch objects and initialise configuration and stats.

        Args:
            net: network to train.
            optimizer: optimizer stepped after every training batch.
            scheduler: optional scheduler stepped once per training epoch.
        """
        # PyTorch objects:
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Default configuration:
        self.loss_fn = nn.CrossEntropyLoss()
        self.metric = metrics.accuracy_score
        self.metric_criterion = 'max'  # better models have higher metric value.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.least_improvement = 0.01  # the new metric value must be at least 1% better to update parameters.
        self.verbose_step = 1

        # Training stats:
        self.losses = {'train': [], 'val': []}
        self.metric_values = {'train': [], 'val': []}
        self.best_states = {
            'net': None,
            'optimizer': None,
            'scheduler': None,
            'epoch': 0,
            'metric_value': float('-inf')
            }
        self.training_time = 0

    def save_training(self, file: str) -> None:
        """Save every instance attribute to ``file`` with ``torch.save``.

        For the network, optimizer and scheduler only the ``state_dict`` is
        stored; an undefined (``None``) object is stored as ``None``.
        """
        states = {}
        for attr_key, attr_value in vars(self).items():
            # `is not None` instead of truthiness: container modules define
            # __len__, so an empty one would wrongly count as "not defined".
            if attr_key in self._STATEFUL and attr_value is not None:
                states[attr_key] = attr_value.state_dict()
            else:
                states[attr_key] = attr_value
        torch.save(states, file)

    def load_training(self, file: str) -> None:
        """Restore the attributes written by :meth:`save_training`.

        Network, optimizer and scheduler are updated in place through
        ``load_state_dict`` (only when both the live object and the saved
        state exist); every other attribute is replaced by the saved value.
        Attributes missing from the file keep their current value.
        """
        # NOTE(security): the checkpoint holds arbitrary pickled objects (the
        # metric function, the loss module, ...), so it cannot be read with
        # `weights_only=True`, the default in recent PyTorch releases.
        # Unpickling can execute code: only load files you trust.
        states = torch.load(file, weights_only=False)
        for attr_key in list(vars(self)):
            if attr_key not in states:
                continue
            saved_value = states[attr_key]
            if attr_key in self._STATEFUL:
                target = getattr(self, attr_key)
                if target is not None and saved_value is not None:
                    target.load_state_dict(saved_value)
            else:
                setattr(self, attr_key, saved_value)

    def last_epoch(self) -> int:
        """Return the number of epochs recorded in the training history."""
        return len(self.losses['train'])

    def plot_training(self) -> None:
        """Plot loss and metric per epoch, marking the best recorded epoch."""
        if self.best_states['epoch'] == 0:
            print('No training record.')
            return

        metric_name = self._metric_name()
        total_epochs = self.last_epoch()
        stats = {'Loss': self.losses, metric_name.capitalize(): self.metric_values}
        epochs = list(range(1, total_epochs + 1))

        # Epoch ticks: count back from the last epoch so it is always labelled.
        plt_num_ticks = 15  # max number of epochs ticks on the plot.
        epoch_ticks = list(range(total_epochs, 0, total_epochs // -plt_num_ticks))[::-1]

        fig, ax = plt.subplots(1, len(stats), figsize=(15, 5))
        for idx, (key, history) in enumerate(stats.items()):
            ax[idx].plot(epochs, history['train'], label='train')
            ax[idx].plot(epochs, history['val'], label='validation')
            ax[idx].axvline(x=self.best_states['epoch'], color='r', linestyle='--',
                            label=f'best {metric_name}', alpha=0.8)
            ax[idx].set_title(f'Epochs vs. {key}')
            ax[idx].set_xlabel('epoch')
            ax[idx].set_ylabel(key)
            ax[idx].legend()
            # Grid:
            ax[idx].grid()
            ax[idx].set_xticks(epoch_ticks, minor=False)
            ax[idx].xaxis.grid(True, which='major', linestyle='--')
            if key == metric_name.capitalize():
                # Fixed axis for the metric panel: 0 to 1.1 with ticks every 0.1.
                ax[idx].set_ylim([0, 1.1])
                ax[idx].set_yticks([i/10 for i in range(0, 11)], minor=False)
                ax[idx].yaxis.grid(True, which='major', linestyle='--')
        fig.tight_layout()

    def train(self, dataloaders: dict, epochs: int, early_stopping: int = None) -> None:
        """Run ``epochs`` more epochs of training followed by validation.

        Args:
            dataloaders: mapping with the keys ``'train'`` and ``'val'``; each
                value is iterated to obtain ``(x, y)`` batches.
            epochs: number of epochs to run in this call. They are appended
                to the history left by previous calls.
            early_stopping: stop after this many consecutive epochs without
                a sufficient validation improvement; ``None`` disables it.

        A ``KeyboardInterrupt`` stops the training cleanly and discards the
        stats of the unfinished epoch.
        """
        initial_epoch = self.last_epoch()
        final_epoch = initial_epoch + epochs
        metric_name = self._metric_name()

        # Training:
        self.net.to(self.device)
        initial_training_time = time.time()
        early_stopping_patience = 0
        try:
            for epoch in range(1, epochs + 1):
                global_epoch = initial_epoch + epoch

                for mode in ('train', 'val'):
                    dataloader = self._progress_bar(dataloaders[mode], mode,
                                                    global_epoch, final_epoch)
                    epoch_loss, epoch_metric = self._run_epoch(dataloader, mode)

                    # Metrics backup:
                    self.losses[mode].append(epoch_loss)
                    self.metric_values[mode].append(epoch_metric)

                    # Verbose:
                    if self.verbose_step and epoch % self.verbose_step == 0:
                        print(f'- {mode:5} | loss: {epoch_loss:.4f}',
                              f'- {metric_name}: {epoch_metric:.4f}')

                # Best model update ('val' runs last, so `epoch_metric` is
                # the validation metric at this point):
                if self._is_improvement(epoch_metric):
                    early_stopping_patience = 0
                    # Stored as a global epoch so it stays aligned with the
                    # history when training is resumed with another call.
                    self._snapshot_best(global_epoch, epoch_metric)
                else:
                    early_stopping_patience += 1
                    if early_stopping_patience == early_stopping:
                        print(f'\nEarly stopping criterion reached.',
                              f'The model improved less than {self.least_improvement * 100}%',
                              f'in the last {early_stopping} epochs.')
                        break

        except KeyboardInterrupt:
            print('Interrupted training. Last epoch stats will not be saved.')
            # Clip last stats to the shortest one:
            last_valid_epoch = len(self.metric_values['val'])
            for mode in ('train', 'val'):
                self.losses[mode] = self.losses[mode][:last_valid_epoch]
                self.metric_values[mode] = self.metric_values[mode][:last_valid_epoch]

        # Post-training:

        if self.last_epoch() == 0:
            return  # nothing will print.

        current_training_time = time.time() - initial_training_time
        self.training_time += current_training_time

        if self.verbose_step:
            print(f'\nBest {metric_name} value reached at epoch {self.best_states["epoch"]}',
                  f'with a value of {self.best_states["metric_value"]}.')
            print(f'Training time: {self._format_time(current_training_time)}.')
            if self.training_time > current_training_time:
                print(f'Accumulated training time: {self._format_time(self.training_time)}.')

        # Notification:
        title = f'Finished training'
        text = f'Final {metric_name}: {self.best_states["metric_value"]:.2f}' + \
               f'\nLast epoch: {self.best_states["epoch"]}'
        self._notify(title, text)

    # ----------------------------------------------------------- helpers --

    def _metric_name(self) -> str:
        """Return the metric function name without its ``_score`` suffix."""
        return self.metric.__name__.removesuffix('_score')

    def _progress_bar(self, dataloader, mode: str, current_epoch: int, final_epoch: int):
        """Wrap the training loader in a tqdm bar; other modes are returned as-is."""
        if mode != 'train':
            return dataloader
        bar_format = f'Epoch {current_epoch}/{final_epoch} ' \
                     '{l_bar}{bar}| batch {n_fmt} of {total_fmt} ({rate_fmt}) ' \
                     '- training time: {elapsed}'
        return tqdm(dataloader, unit=' batches', bar_format=bar_format)

    def _run_epoch(self, dataloader, mode: str):
        """Run one pass over ``dataloader`` and return ``(mean loss, metric)``."""
        is_training = mode == 'train'
        self.net.train(is_training)

        running_loss, n_samples = 0.0, 0
        labels = {'true': [], 'pred': []}
        with torch.set_grad_enabled(is_training):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                output = self.net(x)
                loss = self.loss_fn(output, y)

                if is_training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # `.item()` detaches the value: summing the loss tensor
                # itself would keep autograd history for the whole epoch.
                batch_size = x.size(0)
                running_loss += loss.item() * batch_size
                n_samples += batch_size

                # Labels for metric evaluation:
                labels['true'] += y.tolist()
                labels['pred'] += torch.argmax(output, dim=1).tolist()

        if is_training and self.scheduler is not None:
            self.scheduler.step()

        # Average over the samples actually seen (correct with drop_last or
        # subset samplers, where len(dataset) would overcount).
        epoch_loss = running_loss / n_samples if n_samples else float('nan')
        epoch_metric = self.metric(labels['true'], labels['pred'])
        return epoch_loss, epoch_metric

    def _is_improvement(self, metric_value) -> bool:
        """Tell whether ``metric_value`` beats the best one by the required margin."""
        # The first evaluated epoch is always the best so far. Comparing
        # against the initial -inf would never succeed when minimising.
        if self.best_states['epoch'] == 0:
            return True
        sgn = 1 if self.metric_criterion == 'max' else -1  # maximize (+1) or minimize (-1) the metric.
        best = self.best_states['metric_value']
        # abs() keeps the margin pointing in the improving direction even
        # when the metric value is negative.
        improvement_threshold = best + sgn * self.least_improvement * abs(best)
        return sgn * metric_value >= sgn * improvement_threshold

    def _snapshot_best(self, epoch: int, metric_value) -> None:
        """Deep-copy the current states into ``best_states``."""
        for key in self._STATEFUL:
            value = getattr(self, key)
            state = value.state_dict() if value is not None else None
            # Deep copy: state_dict() returns references to the live tensors.
            self.best_states[key] = copy.deepcopy(state)
        self.best_states['epoch'] = epoch
        self.best_states['metric_value'] = metric_value

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format a duration as ``'H h, M min, S sec'`` without zero padding."""
        # Computed by hand: the '%-H' strftime flags are a glibc extension
        # and are rejected on other platforms.
        hours, remainder = divmod(int(seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        return f'{hours} h, {minutes} min, {secs} sec'

    @staticmethod
    def _notify(title: str, text: str) -> None:
        """Show a desktop notification through ``osascript`` when available."""
        import shutil
        import subprocess

        if shutil.which('osascript') is None:
            return  # no osascript on this system: skip silently.
        script = f'display notification "{text}" with title "{title}"'
        # Argument list instead of a shell string: no shell quoting issues.
        subprocess.run(['osascript', '-e', script], check=False)