In [None]:
import torch
import sys
sys.path.append('..')
import random
import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from datasets import ExperimentDataset, getSplittedDataset
from model import bank_net
import copy
from vfl import Client, Server, VFLNN
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# Fix every RNG the notebook relies on (Python's `random`, PyTorch, NumPy)
# to the same seed so that a Restart & Run All reproduces the same run.
manualseed = 47
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(manualseed)

In [3]:
# --- Experiment configuration -------------------------------------------

# Dimensions handed to bank_net() when the bottom/top models are built.
input_dim, output_dim = 10, 2

# Dataset identifier (also passed to Server) and the CSV it is loaded from.
dataset = 'bank'
data_path = 'data/bank_cleaned.csv'

# Fractions of the full dataset used for the train and test splits.
train_portion, test_portion = 0.7, 0.3

In [4]:
def split_data(data, feature_dim=10):
    """Split a 2-D batch column-wise into two consecutive feature blocks.

    The first block is what the notebook feeds to the first client and the
    second block is what it feeds to the second client.

    Args:
        data: 2-D array-like that supports ``data[:, a:b]`` slicing
            (e.g. a ``torch.Tensor`` or ``numpy.ndarray``), laid out as
            (samples, features).
        feature_dim: Number of columns in each block. Defaults to 10, which
            reproduces the original fixed ``0:10`` / ``10:20`` split.

    Returns:
        Tuple ``(x_a, x_b)`` where ``x_a`` holds columns
        ``[0, feature_dim)`` and ``x_b`` holds columns
        ``[feature_dim, 2 * feature_dim)``. Any columns past
        ``2 * feature_dim`` are dropped.
    """
    x_a = data[:, 0:feature_dim]
    x_b = data[:, feature_dim:2 * feature_dim]
    return x_a, x_b

In [5]:
# Load the full dataset from the CSV configured in `data_path`, then show
# the dataset object and its number of samples, one per line.
expset = ExperimentDataset(datafilepath=data_path)
print(expset, len(expset), sep="\n")

CRITICAL:root:Dataset column permutation is: 
 range(0, 20)
CRITICAL:root:Creating dataset, len(samples): 30488; positive labels sum: 3859


<datasets.ExperimentDataset object at 0x000002476F6E42E0>
30488


In [6]:
# Sanity check: take the first ten samples (features only, labels ignored)
# and confirm that split_data() yields two equally wide feature blocks.
x, _ = expset[slice(0, 10)]
x_a, x_b = split_data(x)
print(x.shape)
print(x_a.shape, x_b.shape)

torch.Size([10, 20])
torch.Size([10, 10]) torch.Size([10, 10])


In [7]:
# Concatenating the two blocks along the feature axis (dim=1) should give
# back a batch as wide as the one that was split.
xx = torch.cat([x_a, x_b], dim=1)
print(xx.shape)

torch.Size([10, 20])


In [8]:
# Split the full dataset into train/test subsets using the configured
# portions, then wrap each subset in a mini-batch loader.
trainset, testset = getSplittedDataset(train_portion, test_portion, expset)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=True)

# len() of a DataLoader is its number of batches, not its number of samples.
# The value must be applied with the % operator: passing it as a second
# argument to print() would emit the literal "%d" followed by the number.
print("len(trainloader): %d" % len(trainloader))
print("len(testloader): %d" % len(testloader))


CRITICAL:root:
[FUNCTION]: Splitting dataset by getSplittedDataset()......
CRITICAL:root:len(trainset): 21341
CRITICAL:root:len(testset): 9147


len(trainloader): %d 334
len(testloader): %d 143


In [9]:
# bank_net() returns one bottom model and the top model. The second bottom
# model is a deep copy of the first, so both start from the same weights
# while owning separate parameter tensors.
bottom1, top_model = bank_net(input_dim, output_dim)
bottom2 = copy.deepcopy(bottom1)

# Move all three networks to the GPU (a CUDA device is required).
bottom1 = bottom1.cuda()
bottom2 = bottom2.cuda()
top_model = top_model.cuda()

# Wrap each bottom model in a Client and the top model in the Server.
client1, client2 = Client(bottom1), Client(bottom2)
server = Server(top_model, dataset)

In [12]:
# Hyperparameters for this training run.
NUM_EPOCHS = 30
LEARNING_RATE = 0.01

# One Adam optimizer per trainable part: the two bottom models and the server.
client1_optimizer = optim.Adam(bottom1.parameters(), lr=LEARNING_RATE)
client2_optimizer = optim.Adam(bottom2.parameters(), lr=LEARNING_RATE)
server_optimizer = optim.Adam(server.parameters(), lr=LEARNING_RATE)
client_optimizer = [client1_optimizer, client2_optimizer]
target_vflnn = VFLNN(client1, client2, server, client_optimizer, server_optimizer)

for epoch in range(NUM_EPOCHS):
    print("Epoch: ", epoch)

    # ---- Training pass over the whole training set ----
    target_vflnn.train()
    train_loss = 0
    for data, target in trainloader:
        data, target_label = data.cuda(), target.cuda()
        target_vflnn.zero_grads()
        x_a, x_b = split_data(data)
        target_vflNN_output = target_vflnn(x_a, x_b)
        # Compute the loss; cross_entropy is given the labels cast to int64.
        target_vflNN_loss = F.cross_entropy(target_vflNN_output, target_label.long())

        # Backpropagate the loss through autograd.
        target_vflNN_loss.backward()
        # Backward pass of the VFLNN as a whole (project-defined in vfl.py;
        # presumably hands gradients from the server to the clients -- confirm).
        target_vflnn.backward()

        # Weight the batch-mean loss by the batch size so that the epoch
        # figure below is a true per-sample average.
        train_loss += target_vflNN_loss.item() * data.size(0)
        # Update the model parameters.
        target_vflnn.step()
    train_loss = train_loss / len(trainloader.dataset)
    print("======loss=======")
    print(train_loss)

    # ---- Evaluation pass over the test set ----
    target_vflnn.eval()

    print("---------------------------testtesttest---------------------------")
    correct = 0
    total = 0
    with torch.no_grad():
        for test_data, test_target in testloader:
            test_data, test_target = test_data.cuda(), test_target.cuda()
            x_a, x_b = split_data(test_data)
            outputs = target_vflnn(x_a, x_b)
            # Predicted class = index of the largest output per sample.
            # No `.data` is needed: gradients are already off under no_grad().
            predicted = outputs.argmax(dim=1)
            total += test_target.size(0)
            correct += (predicted == test_target).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')

Epoch:  0
0.21613871915361316
---------------------------testtesttest---------------------------
Accuracy of the network on the test images: 89.56%
Epoch:  1
0.21235084777447338
---------------------------testtesttest---------------------------
Accuracy of the network on the test images: 89.50%
Epoch:  2
0.21016390812351757
---------------------------testtesttest---------------------------
Accuracy of the network on the test images: 90.25%
Epoch:  3
0.2078761103858061
---------------------------testtesttest---------------------------
Accuracy of the network on the test images: 90.12%
Epoch:  4
0.2082147679164804
---------------------------testtesttest---------------------------
Accuracy of the network on the test images: 90.10%
Epoch:  5
0.21042747287681313
---------------------------testtesttest---------------------------
Accuracy of the network on the test images: 90.16%
Epoch:  6
0.20737225825675654
---------------------------testtesttest---------------------------
Accuracy of the n