# CNN Training, Testing and Evaluation

In [None]:
# Function to train CNN Network.
def train_network(model, epochs=30, lr=0.00001):
    """Train ``model`` with Adam + cross-entropy and plot per-epoch losses.

    Training batches come from the notebook-global ``train_loader`` and
    validation batches from the notebook-global ``validation_loader``.

    Args:
        model: The ``nn.Module`` to train. It is moved to the GPU when CUDA
            is available, otherwise it stays on the CPU.
        epochs: Number of passes over ``train_loader`` (default 30).
        lr: Adam learning rate (default 1e-5).

    Returns:
        The trained model, on the device it was trained on.
    """
    # Decide the device once instead of re-checking CUDA on every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Choose Loss Function and Optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Per-epoch mean loss per sample, for training and validation.
    train_losses = []
    valid_losses = []

    for _ in range(epochs):
        # The validation pass below switches the model to eval mode
        # (dropout off, batch-norm stats frozen), so training mode must be
        # restored at the start of every epoch.
        model.train()
        train_loss = 0.0
        train_count = 0
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(data), labels)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss returns the batch *mean*; weight it by the
            # actual batch size so the epoch average is exact even when the
            # final batch is smaller (no hard-coded batch size).
            batch_size = labels.size(0)
            train_loss += loss.item() * batch_size
            train_count += batch_size

        train_losses.append(train_loss / max(train_count, 1))

        # Validation: eval mode, and no autograd graph is needed.
        model.eval()
        valid_loss = 0.0
        valid_count = 0
        with torch.no_grad():
            for data, labels in validation_loader:
                data, labels = data.to(device), labels.to(device)
                loss = criterion(model(data), labels)
                batch_size = labels.size(0)
                valid_loss += loss.item() * batch_size
                valid_count += batch_size

        valid_losses.append(valid_loss / max(valid_count, 1))

    # Plot the loss graphs of training and validation.
    plt.figure(figsize=(12, 4))
    plt.plot(range(1, epochs + 1), train_losses, label="train loss")
    plt.plot(range(1, epochs + 1), valid_losses, label="validation loss")
    plt.title('Model Loss Graph')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    return model

In [None]:
# Function to test and evaluate the CNN Model.
def test_network(
    model,
    class_names=("MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented"),
):
    """Evaluate ``model`` on the notebook-global ``test_loader``.

    Prints the accuracy, shows a confusion-matrix heatmap and prints a
    scikit-learn classification report.

    Args:
        model: Trained classifier. Test batches are moved to whatever device
            the model's parameters live on, so this works on CPU or GPU.
        class_names: Display names for class indices ``0..K-1``, in index
            order.

    Returns:
        Test accuracy as a float (correct predictions / dataset size).
    """
    model.eval()
    # Follow the model's device rather than hard-coding .cuda(), which
    # crashes on CPU-only machines.
    device = next(model.parameters()).device

    correct = 0
    test_predictions = []
    test_truths = []
    with torch.no_grad():
        for batch in test_loader:
            images = batch[0].to(device)
            labels = batch[1].to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            test_predictions.extend(preds.cpu().numpy())
            test_truths.extend(labels.cpu().numpy())

    # Show Accuracy, Confusion Matrix and Classification Report of CNN Model.
    accuracy = correct / len(test_loader.dataset)
    print(f'Accuracy: {accuracy}')

    names = list(class_names)
    # Passing labels explicitly keeps the matrix K x K even when a class is
    # absent from both the truths and the predictions; otherwise sklearn
    # returns a smaller matrix and the DataFrame construction fails.
    label_ids = list(range(len(names)))
    model_df = pd.DataFrame(
        confusion_matrix(test_truths, test_predictions, labels=label_ids),
        index=names,
        columns=names,
    )
    plt.figure(figsize=(10, 7))
    sns.heatmap(model_df, annot=True, fmt="d")
    plt.show()
    print(classification_report(test_truths, test_predictions,
                                labels=label_ids, target_names=names))

    return accuracy
        