In [None]:
"""
Functions used for the training, validation, and testing
of the hybrid model on the TUH Abnormal dataset.

Each method has its own description in its header section.

The methods defined in this file are:
    - train_model
    - validate_model
    - test_model
"""

In [None]:

# import packages
import torch
from datetime import datetime

In [None]:

# Training function

def train_model(model, device, criterion, optimizer, train_loader, valid_loader, n_epochs=5, early_stop_patience=2):
    """
    Train the hybrid model and validate it after every epoch.

    Early stopping: a counter is incremented every time the validation loss
    is higher than the validation loss of the previous epoch and is reset to
    zero otherwise. Training stops as soon as the counter reaches
    ``early_stop_patience``.

    Parameters
    ----------
    model : model to be trained
    device : device the training should take place (cpu/gpu)
    criterion : loss function to be used during training
    optimizer : optimizer to be used during training
    train_loader : DataLoader training set, yielding (inputs, labels, inds);
        must yield at least one batch per epoch
    valid_loader : DataLoader validation set
    n_epochs : maximum number of epochs
    early_stop_patience : patience for Early Stopping

    Returns
    ----------
    The same 5-tuple is returned whether training runs to completion or is
    stopped early; every list has one entry per completed epoch.

    model : trained model
    train_losses_per_epoch : list containing the mean per-batch training loss of every epoch
    train_accs_per_epoch : list containing the training accuracy (correct / total) of every epoch
    valid_losses_per_epoch : list containing the validation losses computed after every epoch
    valid_accs_per_epoch : list containing the validation accuracies computed after every epoch
    """
    n_batches = len(train_loader)
    # Log progress about 20 times per epoch: the smallest interval i >= 1
    # with n_batches / i <= 20, i.e. ceil(n_batches / 20) in integer arithmetic.
    print_every = max(1, -(-n_batches // 20))

    model.train()
    # There is no previous epoch yet, so the first validation loss must never
    # count as a degradation, whatever the scale of the loss function is.
    the_last_loss = float('inf')
    trigger_times = 0
    train_losses_per_epoch, valid_losses_per_epoch = [], []
    train_accs_per_epoch, valid_accs_per_epoch = [], []

    for epoch in range(n_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        total = 0
        correct = 0

        for i, (inputs, labels, _) in enumerate(train_loader):
            # cast to float and send the batch to GPU/CPU
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = torch.squeeze(model(inputs), 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Predictions are the outputs rounded to the nearest integer
            # (presumably the model emits probabilities in [0, 1] -- TODO confirm).
            predicted = torch.round(outputs)
            total += labels.numel()
            correct += predicted.eq(labels).sum().item()

            # Statistics
            running_loss += loss.item()

            if i % print_every == print_every - 1 or i + 1 == n_batches:
                current_time = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
                print('Time: {}. . [{}/{}, {}/{}] train_loss: {:.6}, train_acc: {:.3}'.format(
                    current_time, epoch + 1, n_epochs, i + 1, n_batches,
                    running_loss / (i + 1), correct / total))

        train_losses_per_epoch.append(running_loss / n_batches)
        # Accuracy over all samples seen in this epoch, the same definition
        # validate_model uses for the validation accuracy.
        train_accs_per_epoch.append(correct / total)

        current_loss, current_acc = validate_model(model, device, valid_loader, criterion)
        model.train()  # validate_model leaves the model in eval mode
        print('Current Validation Loss: {:.6}, Accuracy: {:.3}'.format(current_loss, current_acc))

        # Record the validation metrics before the early-stopping check so the
        # final epoch is not lost when training stops early.
        valid_losses_per_epoch.append(current_loss)
        valid_accs_per_epoch.append(current_acc)

        if current_loss > the_last_loss:
            trigger_times += 1
            print('Trigger Times:', trigger_times)

            if trigger_times >= early_stop_patience:
                print('Early Stopping!\nStart the test process.')
                return model, train_losses_per_epoch, train_accs_per_epoch, valid_losses_per_epoch, valid_accs_per_epoch
        else:
            print('Trigger Times: 0')
            trigger_times = 0

        the_last_loss = current_loss

    print('Finished Training')
    return model, train_losses_per_epoch, train_accs_per_epoch, valid_losses_per_epoch, valid_accs_per_epoch

In [None]:

def validate_model(model, device, valid_loader, loss_function):
    """
    Evaluate the hybrid model on the validation set.

    The model is switched to evaluation mode (and left in it) and all
    batches are processed with gradient tracking disabled.

    Parameters
    ----------
    model : model to be evaluated
    device : device the evaluation should take place (cpu/gpu)
    valid_loader : DataLoader validation set, yielding (inputs, labels, inds)
    loss_function : loss function applied to (outputs, labels)

    Returns
    ----------
    mean_loss : sum of the per-batch losses divided by len(valid_loader)
    accuracy : fraction of rounded outputs that are equal to their label
    """
    model.eval()

    summed_loss = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels, _ in valid_loader:
            # Cast to float, then move the batch to GPU/CPU
            batch_inputs = batch_inputs.float().to(device)
            batch_labels = batch_labels.float().to(device)

            batch_outputs = torch.squeeze(model(batch_inputs), 1)
            summed_loss += loss_function(batch_outputs, batch_labels).item()

            # Rounded outputs serve as predictions for the accuracy
            batch_predictions = torch.round(batch_outputs)
            n_seen += batch_labels.numel()
            n_correct += batch_predictions.eq(batch_labels).sum().item()

    accuracy = n_correct / n_seen
    mean_loss = summed_loss / len(valid_loader)
    return mean_loss, accuracy

    

def test_model(device, model, test_loader):
    """
    Function for the testing of the hybrid model. 
    
    Parameters
    ----------
    model : model to be trained
    device : device the training should take place (cpu/gpu)
    test_loader : DataLoader test set
    
    Returns
    ----------
    inputs_all : input variables in the test set
    labels_all : targets in the test set
    predicted_all : predictions 
    accuracy : accuracy calculated over the test set
    """
    model.eval()
    total = 0
    correct = 0
    accuracy = 0
    
    inputs_all = torch.tensor([])
    labels_all = torch.tensor([])
    predicted_all = torch.tensor([])
    inputs_all = inputs_all.to(device)
    labels_all = labels_all.to(device)
    predicted_all = predicted_all.to(device)
    
    with torch.no_grad():
        for data in test_loader:
            inputs, labels, inds = data
            inputs = inputs.float()
            labels = torch.clone(labels.long())
            
            # sending input to GPU/CPU
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            outputs = torch.squeeze(outputs, 1)
            predicted = torch.round(outputs)
            
            total += labels.numel()
            correct += (predicted == labels).sum().item()
            
            inputs_all = torch.cat((inputs_all, inputs) ,dim=0)
            predicted_all = torch.cat((predicted_all, predicted) ,dim=0)
            labels_all = torch.cat((labels_all, labels) ,dim=0)
            
    accuracy = correct/total
    
    print('Accuracy: {:.3}'.format(accuracy))
    return inputs_all, labels_all, predicted_all, accuracy

