<a href="https://colab.research.google.com/github/barakmam/ECG-Arrhythmia-Classification/blob/master/main2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import numpy as np
import pickle
import time
import pandas as pd

def checkDevice():
    """Select the torch device for computation, print it and return it.

    Returns:
        ``torch.device("cuda:0")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        # The return value is discarded; presumably called to touch the CUDA
        # runtime early -- TODO confirm it is still needed.
        torch.cuda.current_device()
        chosen = torch.device("cuda:0")
    else:
        chosen = torch.device("cpu")

    print("running calculations on: ", chosen)
    return chosen

In [None]:
!wget -r -N -c -np https://physionet.org/files/mitdb/1.0.0/

In [None]:
def trainNet(net, batch_size, learning_rate, step, patience, valCalcFreq, train_loader, val_loader, device):
    """Train ``net`` over a schedule of increasing numbers of deleted cells.

    The schedule is ``1, 2, 3, 4, 6, 8, 10`` (epoch budget 5 each) followed by
    ``np.arange(12, 62, step)`` (epoch budget 500 each).  For every level a new
    Adam optimizer is created and the network is trained until either

    * the validation accuracy did not improve on ``patience`` consecutive
      validation passes, or
    * the level's epoch budget is reached.

    The "previous" validation accuracy is simply the last recorded entry, so
    the first validation pass of a level is compared against the last pass of
    the preceding level (or against the initial ``0`` sentinel).

    Args:
        net: Model to train; called as ``net(quizzes)``.
        batch_size: Only printed here; batching is done by the loaders.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        step: Stride of the ``np.arange(12, 62, step)`` part of the schedule.
        patience: Number of consecutive non-improving validation passes after
            which the current level is abandoned.
        valCalcFreq: A validation pass is run every ``valCalcFreq`` epochs.
        train_loader: Iterable yielding ``(_, solutions)`` training batches.
        val_loader: Iterable yielding ``(_, solutions)`` validation batches.
        device: Torch device the batches are moved to.

    Returns:
        Tuple ``(net, numDeleted, train_acc_total, train_loss_total,
        val_acc_total, val_loss_total)``.  ``train_*`` arrays hold one entry
        per training epoch; ``numDeleted`` and ``val_*`` hold one entry per
        validation pass (the leading sentinel of ``val_acc_total`` is removed).
    """
    # Print all of the hyperparameters of the training iteration:
    print("===== HYPERPARAMETERS =====")
    print("batch_size=", batch_size)
    print("learning_rate=", learning_rate)
    print("step=", step)
    print("patience=", patience)
    print("=" * 30)

    # Time for printing
    training_start_time = time.time()

    # Statistics:
    numDeleted = np.array([])
    train_acc_total = np.array([])
    # The leading 0 is a sentinel so the very first validation pass has a
    # previous value to compare against; it is stripped before returning.
    val_acc_total = np.array([0])
    train_loss_total = np.array([])
    val_loss_total = np.array([])

    delJumps = np.arange(12, 62, step)
    n_epochJumps = 500 * np.ones(delJumps.shape)
    delete_schedule = np.append([1, 2, 3, 4, 6, 8, 10], delJumps)
    epoch_budgets = np.append([5, 5, 5, 5, 5, 5, 5], n_epochJumps)
    for n_delete, n_epochs in zip(delete_schedule, epoch_budgets):
        print(f'Delete: {n_delete}.')

        epoch = 0
        p = 0  # consecutive validation passes without improvement
        # A fresh optimizer per level, so Adam's running moments are reset.
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        while p < patience:
            epoch += 1
            train_running_loss = 0.0
            train_running_hits = 0.0
            runningDeletedNumber = 0
            start_time = time.time()

            net.train()
            for _, solutions in train_loader:
                # Get inputs
                solutions = solutions.float().to(device)
                quizzes = delete_cells_improved_complexity(solutions, n_delete, device)
                # Cells whose arg-max along dim 1 is index 0 are treated as the
                # deleted ones (presumably channel 0 encodes "empty" -- TODO confirm).
                mask_of_deleted_cells = (quizzes.argmax(1) == 0).float()

                # Set the parameter gradients to zero
                optimizer.zero_grad()

                # Forward pass, backward pass, optimize
                outputs = net(quizzes)
                loss_matrix, _, _ = loss_func(outputs, solutions)
                # Average the loss over the deleted cells only.
                loss_size = (mask_of_deleted_cells * loss_matrix).sum() / mask_of_deleted_cells.sum()
                loss_size.backward()
                optimizer.step()

                # Update statistics (detached so no autograd graph is kept alive):
                train_running_loss += loss_size.detach()
                train_running_hits += (
                    (outputs.argmax(1) == solutions.argmax(1)).float() * mask_of_deleted_cells).sum().double()
                runningDeletedNumber += mask_of_deleted_cells.sum()

            train_acc_total = np.append(train_acc_total,
                                        (train_running_hits / runningDeletedNumber).cpu().numpy())
            train_loss_total = np.append(train_loss_total, (train_running_loss / len(train_loader)).cpu().numpy())
            print("Delete {}  Epoch {}:\tTook {:.2f}s. \t Train: loss = {:.3f} Acc = {:.3f}"
                  .format(n_delete, epoch, time.time() - start_time, train_loss_total[-1], train_acc_total[-1]))

            if epoch % valCalcFreq == 0:
                # Do a pass on the validation set
                val_start_time = time.time()
                total_val_loss = 0
                val_hits = 0
                val_runningDeletedNumber = 0
                val_samples_checked = 0
                net.eval()
                with torch.no_grad():
                    for _, val_solutions in val_loader:
                        val_solutions = val_solutions.float().to(device)
                        val_quizzes = delete_cells_improved_complexity(val_solutions, n_delete, device=device)
                        val_mask_of_deleted_cells = (val_quizzes.argmax(1) == 0).float()
                        # Validation uses the blank-filling solver rather than a
                        # single forward pass; its result is compared against
                        # ``argmax + 1``, so it presumably holds values 1..9.
                        val_solved_boards = fillBlank_imporved_complexity(net, val_quizzes, n_delete, device)
                        # Presumably re-encodes the solved boards as 9-class
                        # one-hot tensors for the loss -- TODO confirm.
                        val_iterative_outputs = torch_categorical(val_solved_boards - 1, 9, device)
                        val_loss_matrix, _, _ = loss_func(val_iterative_outputs, val_solutions)
                        total_val_loss += (
                            (val_mask_of_deleted_cells * val_loss_matrix).sum() / val_mask_of_deleted_cells.sum())
                        val_hits += ((val_solved_boards == val_solutions.argmax(
                            1) + 1).float() * val_mask_of_deleted_cells).sum().double()
                        val_runningDeletedNumber += val_mask_of_deleted_cells.sum()
                        val_samples_checked += len(val_solutions)

                val_acc_total = np.append(val_acc_total, (val_hits / val_runningDeletedNumber).cpu().numpy())
                val_loss_total = np.append(val_loss_total, (total_val_loss / len(val_loader)).cpu().numpy())
                numDeleted = np.append(numDeleted, (val_runningDeletedNumber / val_samples_checked).cpu().numpy())
                print("Delete {}  Epoch {}:\tTook {:.2f}s. \t Validation: loss = {:.3f} Acc = {:.3f}"
                      .format(n_delete, epoch, time.time() - val_start_time, val_loss_total[-1], val_acc_total[-1]))
                if val_acc_total[-1] <= val_acc_total[-2]:
                    p += 1
                else:
                    p = 0

            # Enforce the epoch budget on every epoch.  Checking it only after a
            # validation pass would skip the cap whenever n_epochs is not a
            # multiple of valCalcFreq.
            if epoch >= n_epochs:
                break

    print("Training finished, took {:.3f}s".format(time.time() - training_start_time))
    return net, numDeleted, train_acc_total, train_loss_total, val_acc_total[1:], val_loss_total