# In[1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm
from torchvision import datasets, transforms
from sklearn import preprocessing
from sklearn.metrics import *
from sklearn.utils import resample
import numpy as np
from torch.utils import *
import matplotlib.pyplot as plt


# In[4]:
class Predictor(nn.Module):
    """Logistic-regression style predictor: one linear layer plus a sigmoid.

    Args:
        num_predictor_features: Number of input features per sample.
    """

    def __init__(self, num_predictor_features):
        super().__init__()
        # A single output unit: one logit per sample.
        self.linear = nn.Linear(num_predictor_features, 1)

    def forward(self, x):
        """Return ``(y_logits, y_pred)`` for the input batch ``x``.

        ``y_logits`` is the raw output of the linear layer and ``y_pred`` is
        its element-wise sigmoid, so both have a trailing dimension of 1.
        """
        y_logits = self.linear(x)
        # torch.sigmoid is the canonical op; the F.sigmoid alias warns as
        # deprecated on older PyTorch releases.
        y_pred = torch.sigmoid(y_logits)
        return y_logits, y_pred

def train(model, device, train_loader, optimizer, epoch, verbose=True):
    """Train ``model`` for one epoch using binary cross-entropy loss.

    Args:
        model: Module whose forward pass returns ``(logits, probabilities)``.
        device: Device the input and target batches are moved to.
        train_loader: Iterable yielding ``(data, target, protect)`` batches.
            It must support ``len()`` when ``verbose`` is true. The
            ``protect`` element is not used during training.
        optimizer: Optimizer that updates ``model``'s parameters.
        epoch: Epoch number; only used in the progress-bar description.
        verbose: When true, show a tqdm progress bar with the current batch
            loss and the running accuracy.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples processed in this
        epoch, or ``(0.0, 0.0)`` if the loader yielded no samples.
    """
    model.train()
    # Loop-invariant, so build it once. BCELoss averages over the batch by
    # default, hence the re-weighting by batch length below.
    criterion = torch.nn.BCELoss()
    sum_num_correct = 0
    sum_loss = 0.0
    num_samples = 0

    if verbose:
        batches = tqdm(enumerate(train_loader), total=len(train_loader))
        batches.set_description("Epoch NA: Loss (NA) Accuracy (NA %)")
    else:
        batches = enumerate(train_loader)

    for _batch_idx, (data, target, _protect) in batches:
        data = data.to(device, dtype=torch.float)
        target = target.to(device, dtype=torch.float)

        optimizer.zero_grad()
        _logits, output = model(data)
        loss = criterion(output, target.view_as(output))
        loss.backward()
        optimizer.step()

        # Use the real batch length: the last batch may be shorter than
        # train_loader.batch_size, and batch_size is None when the loader is
        # built from a batch_sampler.
        batch_len = data.size(0)
        batch_loss = loss.item()
        pred = (output > 0.5) * 1
        sum_num_correct += pred.eq(target.view_as(pred)).sum().item()
        sum_loss += batch_loss * batch_len
        num_samples += batch_len

        if verbose:
            batches.set_description(
                "Epoch {:d}: Loss ({:.2e}), Accuracy ({:02.0f}%)".format(
                    epoch, batch_loss, 100. * sum_num_correct / max(num_samples, 1))
            )

    if num_samples == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0

    return sum_loss / num_samples, sum_num_correct / num_samples

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print a summary line.

    Args:
        model: Module whose forward pass returns ``(logits, probabilities)``.
        device: Device the input and target batches are moved to.
        test_loader: Iterable yielding ``(data, target, protect)`` batches.
            The ``protect`` element is not used during evaluation.

    Returns:
        ``(test_pred, test_loss, test_accuracy)`` where ``test_pred`` is a CPU
        integer tensor of 0/1 predictions (probability > 0.5) for every
        sample in loader order, and the loss and accuracy are averaged over
        the samples processed (``0.0`` if the loader yielded none).
    """
    model.eval()
    # BCELoss averages over the batch by default, hence the re-weighting by
    # batch length below.
    criterion = torch.nn.BCELoss()
    sum_loss = 0.0
    correct = 0
    num_samples = 0
    batch_preds = []

    with torch.no_grad():
        for data, target, _protect in test_loader:
            data = data.to(device, dtype=torch.float)
            target = target.to(device, dtype=torch.float)
            _logit, output = model(data)
            loss = criterion(output, target.view_as(output))

            # Weight by the real batch length: the last batch may be shorter
            # than test_loader.batch_size (which can also be None).
            batch_len = data.size(0)
            sum_loss += loss.item() * batch_len
            num_samples += batch_len

            pred = (output > 0.5) * 1
            correct += pred.eq(target.view_as(pred)).sum().item()
            # Collect on the CPU so predictions from any device can be
            # concatenated once after the loop.
            batch_preds.append(pred.cpu())

    if batch_preds:
        test_pred = torch.cat(batch_preds, 0)
    else:
        test_pred = torch.zeros(0, 1, dtype=torch.int64)

    test_loss = sum_loss / num_samples if num_samples else 0.0
    test_accuracy = correct / num_samples if num_samples else 0.0
    print('\nTest set: Average loss: {:.2e}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * test_accuracy))
    return test_pred, test_loss, test_accuracy


# Despite its name this is an evaluation helper, not a pytest test case; keep
# pytest from collecting it when this module is imported into a test file.
test.__test__ = False
