In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm
from torchvision import datasets, transforms
import pandas as pd
from sklearn import preprocessing
from sklearn.metrics import *
import numpy as np
from torchvision import datasets, transforms
from torch.utils import *
import matplotlib.pyplot as plt

In [24]:
class Predictor(nn.Module):
    """Logistic-regression predictor: a single linear layer followed by a sigmoid.

    Args:
        num_features: size of the last dimension of the input passed to
            ``forward``.
    """

    def __init__(self, num_features):
        super(Predictor, self).__init__()
        self.linear = torch.nn.Linear(num_features, 1)

    def forward(self, x):
        """Return ``(y_logits, y_pred)`` for input ``x``.

        ``y_logits`` is the raw output of the linear layer (last dimension 1)
        and ``y_pred`` is its element-wise sigmoid.
        """
        y_logits = self.linear(x)
        # torch.sigmoid replaces the deprecated F.sigmoid; same result.
        y_pred = torch.sigmoid(y_logits)
        return y_logits, y_pred


class Adversary(nn.Module):
    """Adversary that predicts the protected attribute from the predictor's logits.

    The predictor logits are squashed to ``s = sigmoid((1 + |c|) * y_logits)``
    with a learnable scalar ``c``, then passed through one linear layer and a
    sigmoid.

    Args:
        num_features: input width of the linear layer.
            * ``1`` -- the adversary sees only ``s`` (this is the only
              configuration that worked before the fix below).
            * ``3`` -- the adversary sees ``[s, s*y, s*(1-y)]``, i.e. it is also
              conditioned on the true label ``y``.
    """

    def __init__(self, num_features):
        super(Adversary, self).__init__()
        self.c = nn.Parameter(torch.ones(1))
        self.linear = nn.Linear(num_features, 1)

    def forward(self, y_logits, y):
        """Return ``(z_logits, z_pred)``.

        Args:
            y_logits: predictor logits; the concatenation along dim 1 requires
                a 2-D tensor (``(batch, 1)`` for the predictor in this file).
            y: true labels with the same number of elements as ``y_logits``.
        """
        s = torch.sigmoid((1 + torch.abs(self.c)) * y_logits)

        if self.linear.in_features == 1:
            # Backward-compatible path: label-agnostic adversary.
            adversary_input = s
        else:
            # Previously this tensor was built but never used (the layer was
            # always fed ``s``), so any num_features != 1 raised a shape error.
            y = y.view_as(s).to(dtype=s.dtype)
            adversary_input = torch.cat([s, s * y, s * (1 - y)], 1)

        z_logits = self.linear(adversary_input)
        z_pred = torch.sigmoid(z_logits)
        return z_logits, z_pred


def train(predictor, adversary, optimizer_P, optimizer_A, device, train_loader, criterion, epoch, alpha=0.1, verbose=True):
    """Run one epoch of adversarial-debiasing training.

    For every batch the predictor gradient is set, per parameter tensor, to

        g_P - proj_{g_A}(g_P) - alpha * g_A

    where ``g_P`` is the gradient of the predictor loss and ``g_A`` the
    gradient of the adversary loss, both w.r.t. that predictor parameter.
    The adversary is updated with the plain gradient of the adversary loss.

    Args:
        predictor: module whose ``forward(X)`` returns ``(y_logits, y_pred)``.
        adversary: module whose ``forward(y_logits, y)`` returns
            ``(z_logits, z_pred)``.
        optimizer_P: optimizer stepping the predictor's parameters.
        optimizer_A: optimizer stepping the adversary's parameters.
        device: device the batches are moved to.
        train_loader: iterable yielding ``(X, y, z)`` batches; ``len()`` is
            only needed when ``verbose`` is true.
        criterion: loss called as ``criterion(prediction, target)``. The
            per-sample averaging below assumes it returns the batch mean
            (e.g. ``nn.BCELoss()``) -- confirm against the caller.
        epoch: epoch number, used only in the progress-bar text.
        alpha: weight of the adversary-gradient term.
        verbose: show a tqdm progress bar when true.

    Returns:
        ``(predictor_loss, adversary_loss, predictor_accuracy,
        adversary_accuracy)`` averaged over the samples seen in this epoch.
    """
    predictor.train()
    adversary.train()

    predictor_params = [p for p in predictor.parameters() if p.requires_grad]
    adversary_params = [p for p in adversary.parameters() if p.requires_grad]
    num_predictor_params = len(predictor_params)
    # Guards the normalisation against a zero adversary gradient.
    eps = float(np.finfo(np.float32).tiny)

    sum_num_pred_correct, sum_num_adv_correct = 0, 0
    sum_pred_loss, sum_adv_loss = 0.0, 0.0
    num_samples_seen = 0

    if verbose:
        batches = tqdm(enumerate(train_loader), total=len(train_loader))
        batches.set_description("Epoch NA: Loss (NA) Accuracy (NA %)")
    else:
        batches = enumerate(train_loader)

    for batch_idx, (X, y, z) in batches:
        X = X.to(device, dtype=torch.float)
        y = y.to(device, dtype=torch.float)
        # Cast z to float like y (and like the evaluation routine does) so
        # BCE-style criteria accept integer-encoded protected attributes.
        z = z.to(device, dtype=torch.float)
        batch_len = X.size(0)

        optimizer_P.zero_grad()
        optimizer_A.zero_grad()

        ### Predictor
        y_logits, y_pred = predictor(X)
        predictor_loss = criterion(y_pred, y.view_as(y_pred))

        ### Adversary
        z_logits, z_pred = adversary(y_logits, y)
        adversary_loss = criterion(z_pred, z.view_as(z_pred))

        # Adversary-loss gradients w.r.t. BOTH parameter sets, obtained
        # without touching any .grad field. Calling adversary_loss.backward()
        # instead would also add g_A into the predictor's .grad before the
        # projection below.
        adversary_loss_grads = torch.autograd.grad(
            adversary_loss,
            predictor_params + adversary_params,
            retain_graph=True,  # predictor_loss shares the predictor graph
            allow_unused=True,
        )
        grads_wrt_predictor = adversary_loss_grads[:num_predictor_params]
        grads_wrt_adversary = adversary_loss_grads[num_predictor_params:]

        # Populates predictor .grad with g_P only.
        predictor_loss.backward()

        with torch.no_grad():
            for param, adv_grad in zip(predictor_params, grads_wrt_predictor):
                if param.grad is None or adv_grad is None:
                    continue
                # L2 normalisation: an orthogonal projection needs a unit
                # vector in the Euclidean norm (L1 only coincides for scalars).
                unit_adv_grad = adv_grad / (torch.norm(adv_grad, 2) + eps)
                param.grad -= torch.sum(param.grad * unit_adv_grad) * unit_adv_grad
                param.grad -= alpha * adv_grad

            for param, adv_grad in zip(adversary_params, grads_wrt_adversary):
                param.grad = adv_grad

        optimizer_P.step()
        optimizer_A.step()

        pred = (y_pred > 0.5) * 1
        sum_num_pred_correct += pred.eq(y.view_as(pred)).sum().item()

        adv_pred = (z_pred > 0.5) * 1
        sum_num_adv_correct += adv_pred.eq(z.view_as(adv_pred)).sum().item()

        # Weight by the real batch length: the last batch may be shorter than
        # train_loader.batch_size (which can also be None).
        sum_pred_loss += predictor_loss.item() * batch_len
        sum_adv_loss += adversary_loss.item() * batch_len
        num_samples_seen += batch_len

        if verbose:
            batches.set_description(
              "Epoch {:d}: Predictor Loss ({:.2e}), Adversary Loss ({:.2e}), Accuracy ({:02.0f}%)".format(
                epoch, predictor_loss.item(), adversary_loss.item(), 100. * sum_num_pred_correct / max(num_samples_seen, 1))
            )

    # Average over what was actually processed (robust to drop_last / samplers
    # and to an empty loader).
    denominator = max(num_samples_seen, 1)
    sum_pred_loss /= denominator
    sum_adv_loss /= denominator
    predictor_accuracy = sum_num_pred_correct / denominator
    adversary_accuracy = sum_num_adv_correct / denominator

    return sum_pred_loss, sum_adv_loss, predictor_accuracy, adversary_accuracy




In [None]:
def test(predictor, adversary, device, test_loader, criterion):
    predictor.eval()
    adversary.eval()
    test_pred_loss, test_adv_loss = 0, 0
    correct = 0
    adv_correct = 0
    test_pred = torch.zeros(0, 1, dtype=torch.torch.int64)
    test_adv_pred = torch.zeros(0, 1, dtype=torch.torch.int64)
    with torch.no_grad():
        for X, y, z in test_loader:
            X, y, z = X.to(device, dtype=torch.float), y.to(device, dtype=torch.float), z.to(device, dtype=torch.float) 
            
            y_logit, y_pred = predictor(X)
            z_logit, z_pred = adversary(y_logit, y)
           
            pred_loss = criterion(y_pred, y.view_as(y_logit))
            adv_loss = criterion(z_pred, z.view_as(z_logit))
            
            test_pred_loss += pred_loss.item()*test_loader.batch_size 
            test_adv_loss += adv_loss.item()*test_loader.batch_size
            
            pred = (y_pred > 0.5)*1
            test_pred = torch.cat([test_pred, pred], 0)
            correct += pred.eq(y.view_as(pred)).sum().item()
            
            adv_pred = (z_pred > 0.5)*1
            test_adv_pred = torch.cat([test_adv_pred, adv_pred], 0)
            adv_correct += adv_pred.eq(z.view_as(adv_pred)).sum().item()
            
    test_pred_loss /= len(test_loader.dataset)
    test_adv_loss  /= len(test_loader.dataset)
    test_pred_acc = correct / len(test_loader.dataset)
    test_adv_acc = adv_correct/len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.2e}, Predictor Accuracy: ({:.0f}%), Adversary Accuracy: ({:.0f}%)\n'.format(
        test_pred_loss, 100. * test_pred_acc, 100* test_adv_acc))
    return test_pred, test_adv_pred, test_pred_loss, test_adv_loss, test_pred_acc, test_adv_acc

