In [1]:
import torch
from torch_geometric.data import InMemoryDataset
import numpy as np
from torch_geometric.data import Data
import scipy.io

def f1_calculation(y, patient_results):
    """Turn per-sample correctness flags into an F1 score against ``y``.

    A prediction vector is rebuilt from the flags: where the flag equals 1
    the prediction copies ``y[i]``, otherwise it is ``1 - y[i]``. The rebuilt
    vector is then scored with ``f1_score(y, y_pred)``.

    Args:
        y: Indexable collection of labels, one per sample. Presumably 0/1
            labels, so that ``1 - y[i]`` is the opposite class -- confirm
            against the dataset.
        patient_results: Iterable of per-sample flags; a value equal to 1
            marks the sample as correctly classified.

    Returns:
        The value returned by ``f1_score`` for the rebuilt predictions.
    """
    y_pred = np.zeros(len(y))
    for idx, flag in enumerate(patient_results):
        # Correct sample -> keep the true label; otherwise use 1 - label.
        y_pred[idx] = y[idx] if flag == 1 else 1 - y[idx]
    return f1_score(y, y_pred)

class PairData(Data):
    """A ``Data`` object that carries two graphs (source and target) at once.

    Attributes set by the constructor:
        edge_index_s / x_s: edge index and node features of the source graph.
        edge_index_t / x_t: edge index and node features of the target graph.
        y: label stored alongside the pair.
    """

    def __init__(self, edge_index_s=None, x_s=None, edge_index_t=None, x_t=None, y=None):
        super().__init__()
        # Source graph.
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        # Target graph.
        self.edge_index_t = edge_index_t
        self.x_t = x_t
        # Label of the pair.
        self.y = y

    def __inc__(self, key, value, *args, **kwargs):
        """Return the increment applied to ``key`` when samples are batched.

        Each edge index is incremented by the number of rows of its own
        node-feature matrix (``x_s`` for the source graph, ``x_t`` for the
        target graph); every other key defers to the parent class.
        """
        if key == 'edge_index_s':
            increment = self.x_s.size(0)
        elif key == 'edge_index_t':
            increment = self.x_t.size(0)
        else:
            increment = super().__inc__(key, value, *args, **kwargs)
        return increment
        

class YooChooseBinaryDataset(InMemoryDataset):
    """In-memory dataset backed by a single, already-processed file.

    There are no raw files, and both ``download`` and ``process`` are
    no-ops, so the processed file is expected to exist on disk already.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # The processed file stores a (data, slices) pair.
        loaded = torch.load(self.processed_paths[0])
        self.data, self.slices = loaded

    @property
    def raw_file_names(self):
        """No raw files are declared for this dataset."""
        return []

    @property
    def processed_file_names(self):
        """Single processed file, given as a hard-coded path."""
        return ['C:\\Users\\User\\yoochoose-data\\discrepancy-capa-snet-nw-capa-averaging-smooth-norm1_3-adsci.dataset']

    def download(self):
        """Nothing to download."""

    def process(self):
        """Nothing to process; the processed file is produced elsewhere."""
# Load the pre-processed dataset from its fixed on-disk location.
dataset = YooChooseBinaryDataset(
    root='C:\\Users\\User\\yoochoose-data\\discrepancy-capa-snet-nw-capa-averaging-smooth-norm1_3-adsci.dataset',
)

In [2]:
# Collect the label of every sample, in dataset order.
y = [dataset[idx].y for idx in range(len(dataset))]

In [5]:
from torch.nn import Sequential, Linear, ReLU, Parameter
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import  global_mean_pool, GCNConv
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
from sklearn.metrics import f1_score
import warnings

# Silence every Python warning from here on (no category or module filter is given).
warnings.filterwarnings("ignore")


class Net(torch.nn.Module):
    """Shared two-layer GCN encoder that scores a pair of graphs.

    Both graphs are passed through the same two ``GCNConv`` layers, their
    node embeddings are mean-pooled per graph, and the score of a pair is
    the scaled dot product of the two pooled embeddings.

    Args:
        num_features: Size of each input node-feature vector (default 148).
        dim: Output width of both GCN layers (default 32). The dot product
            is scaled by ``dim ** -0.5``.
    """

    def __init__(self, num_features=148, dim=32):
        super().__init__()
        # Kept so the scale factor in forward() always matches the layer
        # width instead of relying on a separately hard-coded constant.
        self.dim = dim
        self.conv1 = GCNConv(num_features, dim)
        self.conv2 = GCNConv(dim, dim)

    def forward_(self, x, edge_index):
        """Encode the nodes of one graph: two GCN layers, ReLU after each."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return x

    def forward(self, x_s, edge_index_s, x_t, edge_index_t, x_s_batch, x_t_batch):
        """Return one similarity score per (source, target) graph pair.

        Args:
            x_s, edge_index_s: node features and edges of the source graphs.
            x_t, edge_index_t: node features and edges of the target graphs.
            x_s_batch, x_t_batch: node-to-graph assignment vectors used for
                mean pooling.

        Returns:
            Tensor with one scaled dot-product score per pair (the product
            of the pooled embeddings summed over dimension 1).
        """
        x_s = self.forward_(x_s, edge_index_s)
        x_t = self.forward_(x_t, edge_index_t)
        x_s = global_mean_pool(x_s, x_s_batch)
        x_t = global_mean_pool(x_t, x_t_batch)
        scale_factor = self.dim ** -0.5
        # Element-wise product summed over the embedding axis, i.e. a
        # scaled dot product between the two pooled graph embeddings.
        dot = x_s * x_t * scale_factor
        return torch.sum(dot, 1)

def train(epoch):
    """Run one training epoch over the global ``data_loader``.

    Uses the module-level ``model``, ``optimizer``, ``device`` and
    ``data_loader``. For each batch the loss is

        0.1 * var(scores | y == 1) + 0.1 * var(scores | y == 0)
        + relu(0.5 - (mean(scores | y == 1) - mean(scores | y == 0)))

    i.e. it shrinks the within-class spread of the scores and pushes the
    mean score of class 1 at least 0.5 above that of class 0.

    Batches that contain only one of the two classes are skipped: the mean
    and variance of an empty selection are NaN, and back-propagating a NaN
    loss would corrupt the model weights.

    Args:
        epoch: Current epoch number; not used by the computation.

    Returns:
        Mean loss per graph over the batches that produced an update, or
        0.0 when no batch did.
    """
    model.train()

    loss_all = 0.0
    graphs_seen = 0
    for data in data_loader:
        data = data.to(device)
        output = model(data.x_s, data.edge_index_s, data.x_t, data.edge_index_t, data.x_s_batch, data.x_t_batch)
        pos_scores = output[torch.where(data.y == 1)]
        neg_scores = output[torch.where(data.y == 0)]
        if pos_scores.numel() == 0 or neg_scores.numel() == 0:
            # Single-class batch: the loss would be NaN, so skip it.
            continue

        optimizer.zero_grad()
        var_1 = torch.var(pos_scores, unbiased=False)
        var_2 = torch.var(neg_scores, unbiased=False)
        mean_1 = torch.mean(pos_scores)
        mean_2 = torch.mean(neg_scores)
        loss = 0.1 * var_1 + 0.1 * var_2 + F.relu(0.5 - (mean_1 - mean_2))
        loss.backward()
        optimizer.step()

        loss_all += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs

    # Normalise by the graphs actually trained on, not by the size of the
    # full dataset (which also contains the held-out samples).
    return loss_all / graphs_seen if graphs_seen else 0.0


def _collect_scores(loader):
    """Run the global ``model`` over ``loader``; return (scores, labels) arrays.

    Scores are returned as a column vector (one feature per sample) so they
    can be passed directly to a scikit-learn estimator.
    """
    scores, labels = [], []
    with torch.no_grad():  # inference only: no autograd graph needed
        for data in loader:
            data = data.to(device)
            output = model(data.x_s, data.edge_index_s, data.x_t, data.edge_index_t, data.x_s_batch, data.x_t_batch)
            scores.append(output.cpu().numpy().reshape(-1, 1))
            labels.append(data.y.cpu().numpy())
    if not scores:
        raise ValueError("loader yielded no batches")
    return np.concatenate(scores), np.concatenate(labels)


def test(loader1, loader2):
    """Fit a 7-NN classifier on model scores and report held-out accuracy.

    Uses the module-level ``model`` and ``device``. The scalar scores the
    model produces for ``loader1`` are the (1-D) features a
    ``KNeighborsClassifier(n_neighbors=7)`` is fitted on; the classifier is
    then scored on the model scores for ``loader2``.

    All batches of each loader are used: the classifier is fitted once on
    every sample of ``loader1`` and scored on every sample of ``loader2``.

    Args:
        loader1: Loader over the reference (training) samples.
        loader2: Loader over the samples to evaluate.

    Returns:
        Mean accuracy of the classifier on the samples of ``loader2``.

    Raises:
        ValueError: If either loader yields no batches.
    """
    model.eval()
    ref_scores, ref_labels = _collect_scores(loader1)
    eval_scores, eval_labels = _collect_scores(loader2)

    model_1 = KNeighborsClassifier(n_neighbors=7)
    model_1.fit(ref_scores, ref_labels)
    return model_1.score(eval_scores, eval_labels)

# Repeated leave-one-out evaluation.
#
# For each of ``test_num`` repetitions, every sample is held out once: a
# fresh model is trained for 100 epochs on the remaining samples and then
# evaluated on the held-out one via test(). Per repetition, the mean
# held-out accuracy goes into ``results`` and the F1 score derived from the
# per-sample outcomes goes into ``f1_arr``.
batch_size = 21

test_num = 20
results = np.zeros([test_num])
f1_arr = np.zeros([test_num])
# patient_results[i, t] is the held-out accuracy of sample i in repetition t.
patient_results = np.zeros([len(dataset), test_num])

# The device does not change between folds, so resolve it once.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

for t in range(test_num):

    acc = 0
    for i in range(len(dataset)):
        # Leave-one-out split: sample i is the test set, the rest is training.
        train_indices = np.delete(np.arange(len(dataset)), i).tolist()
        test_indices = [i]
        dataset_train = torch.utils.data.Subset(dataset, train_indices)
        dataset_test = torch.utils.data.Subset(dataset, test_indices)

        # Fresh model and optimizer for every fold.
        model = Net().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

        # Build the loaders once per fold; train() reads ``data_loader``
        # from module scope on every epoch.
        data_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, follow_batch=['x_s', 'x_t'])
        data_loader_test = DataLoader(dataset_test, batch_size=1, follow_batch=['x_s', 'x_t'])
        for epoch in range(1, 101):
            train_loss = train(epoch)

        # One batch holding the whole training split, used as the
        # reference set inside test().
        data_all = DataLoader(dataset_train, batch_size=len(dataset_train), shuffle=True, follow_batch=['x_s', 'x_t'])
        test_acc = test(data_all, data_loader_test)
        patient_results[i, t] = test_acc
        acc = acc + test_acc
        del model

    # Compute F1 only once the whole column for this repetition is filled.
    f1_arr[t] = f1_calculation(y, patient_results[:, t])
    acc = acc / len(dataset)
    print(acc)
    results[t] = acc
