In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import torch
import joblib
from model import *

import matplotlib.pyplot as plt
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore")


In [2]:
from sklearn import datasets
breast_cancer = datasets.load_breast_cancer()
data, labels = breast_cancer['data'], breast_cancer['target']
mu = np.mean(data, axis = 0)
std = np.std(data, axis = 0)
data = (data - mu) / std
print(data.shape)

(569, 30)


In [3]:
np.random.seed(2)
num_train = 469
num_noisy = 100
num_val = 50
num_test = 50
train_data, train_labels = data[:num_train,:], labels[:num_train]
noisy_index = np.random.randint(0, num_train, num_noisy)
train_labels[noisy_index] = 1 - train_labels[noisy_index]

val_data, val_labels = data[num_train:num_train+num_val,:], labels[num_train:num_train+num_val]
test_data, test_labels = data[num_train+num_val:, :], labels[num_train+num_val:]

In [8]:
def train_approx(weights, x_train, y_train, x_val, y_val, x_test, y_test):
    
    learning_rate = 0.01
    weight_list = []
    train_loss = []
    val_loss = []
    val_acc = [] 
    test_acc = []
    
    x_var = Variable(torch.FloatTensor(x_train))
    y_var = Variable(torch.FloatTensor(y_train))
    
    x_val_var = Variable(torch.FloatTensor(x_val))
    y_val_var = Variable(torch.FloatTensor(y_val))
    
    test_var = Variable(torch.FloatTensor(x_test), requires_grad=True)
    
    LR = net(x_train.shape[1], 1)
    optimizer = torch.optim.SGD(LR.params(), lr=learning_rate)
    LR.train()
    
    weight_list.append(weights.clone().detach().numpy())
    for i in range(100):
        #得到当前的损失函数值，对应于算法第4-5行
        y_f_hat = LR(x_var)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, y_var, reduce = False)
        l_f = torch.sum(cost * weights)
        train_loss.append(l_f / torch.sum(weights))

            
        LR.zero_grad()
        grads = torch.autograd.grad(l_f, (LR.params()), create_graph=True)
        LR.update_params(learning_rate, source_params=grads)
            
        if(i % 1 == 0):
            LR.eval()
            result = LR(test_var)
            pred_proba = torch.sigmoid(result)
            predicted = (pred_proba > 0.5).int()
            acc = accuracy_score(y_test, predicted.int())
            auc = roc_auc_score(y_test, pred_proba.data.numpy())
            print("Outer_epochs: %d, Accuracy: %.5f, AUC: %.5f" % (i, acc, auc))
            
            test_acc.append(acc)
            result = LR(x_val_var)
            pred_proba = torch.sigmoid(result)
            predicted = (pred_proba > 0.5).int()
            acc = accuracy_score(y_val, predicted.int())
            auc = roc_auc_score(y_val, pred_proba.data.numpy())
            print("Outer_epochs: %d, Accuracy: %.5f, AUC: %.5f" % (i, acc, auc))
            val_acc.append(acc)
            
        
        #计算验证集上的损失函数
        y_g_hat = LR(x_val_var)
        l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat, y_val_var)
        val_loss.append(l_g_meta)
                                                         
        #print("l_g_meta: ", l_g_meta)
        grad_eps = torch.autograd.grad(l_g_meta, weights, only_inputs = True)[0]
        
        #print(grad_eps)
        weights = weights - 500 * grad_eps
        weights = torch.clamp(weights, 0.0, 1.0)
        
        weight_list.append(weights.clone().detach().numpy())
        
    return weight_list, train_loss, val_loss, val_acc, test_acc

In [9]:
weights = Variable(0.5*torch.ones(train_data.shape[0]), requires_grad = True)
weight_list, train_loss, val_loss, val_acc, test_acc = train_approx(weights, train_data, train_labels, val_data, val_labels, test_data, test_labels)

Outer_epochs: 0, Accuracy: 0.96000, AUC: 0.99750
Outer_epochs: 0, Accuracy: 0.88000, AUC: 0.97089
Outer_epochs: 1, Accuracy: 0.66000, AUC: 0.81500
Outer_epochs: 1, Accuracy: 0.54000, AUC: 0.68815
Outer_epochs: 2, Accuracy: 0.88000, AUC: 0.98750
Outer_epochs: 2, Accuracy: 0.74000, AUC: 0.84823
Outer_epochs: 3, Accuracy: 0.96000, AUC: 0.99000
Outer_epochs: 3, Accuracy: 0.84000, AUC: 0.95634
Outer_epochs: 4, Accuracy: 0.96000, AUC: 1.00000
Outer_epochs: 4, Accuracy: 1.00000, AUC: 1.00000
Outer_epochs: 5, Accuracy: 0.98000, AUC: 1.00000
Outer_epochs: 5, Accuracy: 0.98000, AUC: 1.00000
Outer_epochs: 6, Accuracy: 0.98000, AUC: 0.99750
Outer_epochs: 6, Accuracy: 0.98000, AUC: 1.00000
Outer_epochs: 7, Accuracy: 0.98000, AUC: 0.99750
Outer_epochs: 7, Accuracy: 0.98000, AUC: 1.00000
Outer_epochs: 8, Accuracy: 0.98000, AUC: 0.99750
Outer_epochs: 8, Accuracy: 0.98000, AUC: 1.00000
Outer_epochs: 9, Accuracy: 0.98000, AUC: 0.99750
Outer_epochs: 9, Accuracy: 0.98000, AUC: 1.00000
Outer_epochs: 10, Ac

Outer_epochs: 96, Accuracy: 0.96000, AUC: 1.00000
Outer_epochs: 96, Accuracy: 1.00000, AUC: 1.00000
Outer_epochs: 97, Accuracy: 0.96000, AUC: 1.00000
Outer_epochs: 97, Accuracy: 1.00000, AUC: 1.00000
Outer_epochs: 98, Accuracy: 0.96000, AUC: 1.00000
Outer_epochs: 98, Accuracy: 1.00000, AUC: 1.00000
Outer_epochs: 99, Accuracy: 0.96000, AUC: 1.00000
Outer_epochs: 99, Accuracy: 1.00000, AUC: 1.00000
