In [1]:
import pandas as pd
content = pd.read_csv("cora.content", sep="\t")
cites = pd.read_csv("cora.cites", sep="\t")
cites

Unnamed: 0,35,1033
0,35,103482
1,35,103515
2,35,1050679
3,35,1103960
4,35,1103985
...,...,...
5423,853116,19621
5424,853116,853155
5425,853118,1140289
5426,853155,853118


In [2]:
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
graph = dataset[0]

### Boost

In [3]:
import torch.nn.functional as F
from torch_geometric.nn import GATConv, to_hetero, SAGEConv, GCNConv, Linear
from torchmetrics import AUROC
from sklearn import metrics
import matplotlib.pyplot as plt
import random

sepsis_cases = torch.count_nonzero(training_graph.y)
control_cases = training_graph.y.size(dim=0) - sepsis_cases
sepsis_cases_test = torch.count_nonzero(testing_graph.y)
control_cases_test = testing_graph.y.size(dim=0) - sepsis_cases

# boost_lr = 0.01
# number_of_iter = 1000
# epochs = range(100)
# gnn_lr = 3e-4
# hidden_dim = 64


# boost_lr = 0.01
# number_of_iter = 2000
# epochs = range(90)
# gnn_lr = 8e-4
# num_used_features = 4
# hidden_dim = 32

boost_lr = 0.1
number_of_iter = 600
epochs = range(30)
gnn_lr = 8e-4
num_used_features = 7
hidden_dim = 16

weight = 1 * control_cases / sepsis_cases
weight_tensor = torch.ones(training_graph.y.shape[0], device = device)
weight_tensor[(training_graph.y == 1).nonzero(as_tuple=True)[0]] = weight
print(weight_tensor.unique(return_counts=True))

models = []
feature_indices_list = []
residuals = []

init_logit = torch.log(sepsis_cases / control_cases)
init_probs = torch.ones((training_graph.y.shape[0]), device = device) * torch.sigmoid(init_logit)
pseudo_residuals = (training_graph.y - init_probs) * weight_tensor

test_init_logit = torch.log(sepsis_cases_test / control_cases_test)
test_init_probs = torch.ones((testing_graph.y.shape[0]), device = device) * torch.sigmoid(test_init_logit)

eps = 1e-10

def recurse(last_probs, last_pseudo_residuals, last_log_odds):
    global number_of_iter, models, test_init_probs, residuals, feature_indices_list
    if number_of_iter == 0:
        return
    
    class WeakGNN(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.conv1 = SAGEConv(num_used_features, hidden_dim, normalize=True, project= True)
            self.conv_end = SAGEConv((-1, -1), 1)

        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            x = self.conv1(x, edge_index)
            x = torch.relu(x)
            x = self.conv_end(x, edge_index)
            return x

    model = WeakGNN()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=gnn_lr )
    
    
    def frag_operation(tensor_1, tensor_2, lambda_fn):
        tensor_1 = tensor_1.view((56393, 18))
        tensor_2 = tensor_2.view((56393, 18))
        return lambda_fn(tensor_1, tensor_2).view(-1)
    
    def frag_test_operation(tensor_1, tensor_2, lambda_fn):
        tensor_1 = tensor_1.view((91571, 4))
        tensor_2 = tensor_2.view((91571, 4))
        return lambda_fn(tensor_1, tensor_2).view(-1)
            

    def train(data, last_pseudo_residuals):
        weight = control_cases / sepsis_cases
        for epoch in epochs:
            model.train()
            logits = model(data)
            loss_fn = torch.nn.MSELoss()
            loss = loss_fn(torch.squeeze(logits), last_pseudo_residuals)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def make_pred(data, last_probs):
        model.eval()
        with torch.inference_mode():
            logits = model(data)
            new_probs = frag_operation(logits, last_probs, lambda logits, last_probs: logits / ((last_probs * (1 - last_probs))+eps))
            return new_probs
        
    def test():
        auroc_metric = AUROC(task="binary")
        
        test_log_odds = torch.clone(test_init_probs)
        last_probs = test_init_probs
        for i, model in enumerate(models):
            model.eval()
            with torch.inference_mode():
                feature_indices = feature_indices_list[i]
                data = Data(x = torch.clone(testing_graph.x), edge_index = testing_graph.edge_index, y = testing_graph.y)
                data.x = data.x[: ,feature_indices]
                
                logits = model(data)
                new_probs = frag_test_operation(logits, last_probs, lambda logits, last_probs: logits / ((last_probs * (1 - last_probs))+eps))
                test_log_odds += boost_lr * new_probs
                last_probs = torch.sigmoid(test_log_odds)
        test_probs = torch.sigmoid(test_log_odds)
        auroc = auroc_metric(test_probs, testing_graph.y)
        print(f"AUROC {auroc.item()}")
#         print(torch.round(test_probs).unique(return_counts=True))
#         print(testing_graph.y.unique(return_counts=True))
#         confusion_matrix = metrics.confusion_matrix(testing_graph.y.cpu(),
#                                                     torch.round(test_probs).type(torch.int).cpu())
#         cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=[False, True])
#         cm_display.plot()
        

    data = Data(x= torch.clone(training_graph.x),  edge_index = training_graph.edge_index, y = training_graph.y)
    feature_indices = random.sample(range(7), num_used_features)
    feature_indices_list.append(feature_indices)
    data.x = data.x[:, feature_indices]
    train(data, last_pseudo_residuals)
    models.append(model)
    
    last_log_odds = last_log_odds + boost_lr* make_pred(data, last_probs)
    new_probs = torch.sigmoid(last_log_odds)
    pseudo_residuals = (training_graph.y - new_probs)*weight_tensor
    residuals.append(torch.abs(pseudo_residuals).sum().cpu().item())
    
    number_of_iter -= 1
    print(f"{number_of_iter} Iterations coming")
    
    test()
    recurse(new_probs, pseudo_residuals, last_log_odds)

    
recurse(init_probs, pseudo_residuals, init_probs)
plt.plot(range(0, len(residuals)), residuals, 'g', label='Pseudo residuals')
plt.xlabel('Models')
plt.ylabel('Residuals')
plt.legend()
plt.show()

In [7]:
import torch
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
graph = graph.to(device)

In [98]:
from torch_geometric.loader import NeighborLoader, ImbalancedSampler, DataLoader
from hgnn.GNN import GNN
from hgnn.HGraph import PATIENT_NAME
import torch
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, to_hetero, SAGEConv, GCNConv, Linear
from torchmetrics import AUROC, Accuracy
from torchmetrics.classification import MulticlassAccuracy
import matplotlib.pyplot as plt
from sklearn import metrics

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training, p = .6)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
            

VALIDATION = "val"
TRAINING = "train"
TEST ="test"
EPOCHS = 20001
LEARNING_RATE = 3e-4

dataset = Planetoid(root='data/CORA', name='CORA')
data = dataset[0]
graph = data.to(device)
class HomoGNN:
    def __init__(self, graph):
        self.graph = graph
        self.model = GCN()
        self.model = self.model.to(device)
        

        self.loss_dict = {
            TRAINING:[],
            VALIDATION:[],
            TEST:[],
        }  
        self.aurocs_dict = {
            VALIDATION:[],
            TEST:[],
        }
        self.increased_loss = 0
        self.last_loss = 0
        self.epochs = range(EPOCHS)
#         self.plot_confusion_matrix()
    
    
    def train(self, mask):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        data = self.graph
        for epoch in self.epochs:
            print(epoch)
            self.model.train()
            optimizer.zero_grad()
            out = self.model(data.x.to(device), data.edge_index.to(device))
            loss = F.nll_loss(out[mask], data.y[mask])
            loss.backward()
            optimizer.step()

            self.loss_dict[TRAINING].append(loss.item())
            if epoch % 1 == 0:
                self.test(TEST, graph.test_mask)
                self.test(VALIDATION, graph.val_mask)
            if self.increased_loss == 10:
                print("Breaked")
                print(self.increased_loss)
                self.epochs = range(epoch+1)
                break
        
        
    def plot_loss(self):
        plt.plot(self.epochs, self.loss_dict[TRAINING], 'g', label='Training loss')
        plt.plot(self.epochs, self.loss_dict[TEST_LEIPZIG], 'b', label='Testing loss')
        plt.plot(self.epochs, self.loss_dict[TEST_GW], 'r', label='Testing GW loss')
        plt.plot(self.epochs, self.loss_dict[VALIDATION], 'y', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()
        
        
    def update_stopping_criteria(self, loss_item):
        if loss_item > self.last_loss:
            self.increased_loss += 1
        else:
            self.increased_loss = 0
        self.last_loss = loss_item
        
        
    def test(self, set_type, mask):
        data = self.graph
        with torch.inference_mode():
            self.model.eval()
            out = self.model(data.x.to(device), data.edge_index.to(device))
            loss = F.nll_loss(out[mask], data.y[mask])
            self.loss_dict[set_type].append(loss.item())
            pred_patient =  out[mask].max(dim=1)
            metric = MulticlassAccuracy(num_classes=dataset.num_classes)
            metric = metric.to(device)
            acc = metric(out[mask], self.graph.y[mask])
            print(set_type + " " + str(acc.item()))
            if set_type != VALIDATION:
                return
            print(loss.item())
            self.update_stopping_criteria(loss.cpu().item())
    
    def get_model(self):
        return self.model
		

In [99]:
hgnn = HomoGNN(graph)
hgnn.train(graph.train_mask)

0
test 0.1406092792749405
val 0.13419727981090546
1.9315857887268066
1
test 0.14749112725257874
val 0.13906969130039215
1.9298434257507324
2
test 0.15506511926651
val 0.15097445249557495
1.9280602931976318
3
test 0.16879230737686157
val 0.16162344813346863
1.926284670829773
4
test 0.17661775648593903
val 0.16343176364898682
1.9244422912597656
5
test 0.18027541041374207
val 0.17680227756500244
1.9225893020629883
6
test 0.1889784187078476
val 0.18127840757369995
1.9207252264022827
7
test 0.1981804519891739
val 0.19408732652664185
1.9188425540924072
8
test 0.2010306715965271
val 0.19765931367874146
1.9169464111328125
9
test 0.20787924528121948
val 0.20037180185317993
1.915057897567749
10
test 0.21994657814502716
val 0.20394378900527954
1.9131522178649902
11
test 0.2269917130470276
val 0.20661160349845886
1.9112390279769897
12
test 0.2323586642742157
val 0.21635642647743225
1.909297227859497
13
test 0.2411913424730301
val 0.22475619614124298
1.9073166847229004
14
test 0.251955509185791
val

test 0.7138475775718689
val 0.6820381879806519
1.5740751028060913
120
test 0.7173255085945129
val 0.6820381879806519
1.5709996223449707
121
test 0.7173255085945129
val 0.6829423308372498
1.5679343938827515
122
test 0.7184244394302368
val 0.6847738027572632
1.5649049282073975
123
test 0.720803439617157
val 0.6838696599006653
1.5618826150894165
124
test 0.7234722375869751
val 0.6838696599006653
1.5588529109954834
125
test 0.7241232991218567
val 0.6829655170440674
1.5558282136917114
126
test 0.7241232991218567
val 0.6918598413467407
1.552819013595581
127
test 0.7236754894256592
val 0.6918598413467407
1.5498296022415161
128
test 0.7241232991218567
val 0.6918598413467407
1.5468194484710693
129
test 0.7236754894256592
val 0.6918598413467407
1.5437911748886108
130
test 0.7268996834754944
val 0.6918598413467407
1.5407826900482178
131
test 0.7268996834754944
val 0.6900961995124817
1.537756085395813
132
test 0.7290974855422974
val 0.6900961995124817
1.5347572565078735
133
test 0.7300562858581543

val 0.7648534774780273
1.2495421171188354
244
test 0.7938064336776733
val 0.7648534774780273
1.247362494468689
245
test 0.7938064336776733
val 0.7648534774780273
1.2451941967010498
246
test 0.7960042953491211
val 0.7648534774780273
1.2429940700531006
247
test 0.7960042953491211
val 0.7648534774780273
1.2407909631729126
248
test 0.7960042953491211
val 0.7648534774780273
1.238620638847351
249
test 0.7960042953491211
val 0.7648534774780273
1.2364779710769653
250
test 0.7976473569869995
val 0.767195463180542
1.2343568801879883
251
test 0.7976473569869995
val 0.767195463180542
1.2322285175323486
252
test 0.7966886162757874
val 0.767195463180542
1.2301548719406128
253
test 0.7966886162757874
val 0.7770476341247559
1.2281148433685303
254
test 0.7977875471115112
val 0.7770476341247559
1.2261302471160889
255
test 0.7977875471115112
val 0.7770476341247559
1.2241706848144531
256
test 0.7987266778945923
val 0.7819737195968628
1.2221604585647583
257
test 0.7987266778945923
val 0.7802100777626038
1.

test 0.8148733973503113
val 0.7817955017089844
1.039492130279541
364
test 0.8148733973503113
val 0.7817955017089844
1.0380985736846924
365
test 0.8148733973503113
val 0.7817955017089844
1.0367463827133179
366
test 0.8148733973503113
val 0.7817955017089844
1.0354399681091309
367
test 0.8148733973503113
val 0.7792892456054688
1.034140706062317
368
test 0.8148733973503113
val 0.7792892456054688
1.0328551530838013
369
test 0.8148733973503113
val 0.7792892456054688
1.0315505266189575
370
test 0.8148733973503113
val 0.7792892456054688
1.030257225036621
371
test 0.8159723281860352
val 0.7792892456054688
1.0289170742034912
372
test 0.8159723281860352
val 0.781120777130127
1.0275585651397705
373
test 0.8144024610519409
val 0.781120777130127
1.0261852741241455
374
test 0.8144024610519409
val 0.781120777130127
1.0247935056686401
375
test 0.8144024610519409
val 0.781120777130127
1.02342689037323
376
test 0.8144024610519409
val 0.781120777130127
1.022120714187622
377
test 0.8144024610519409
val 0.7

test 0.8128188848495483
val 0.7829754948616028
0.9075433015823364
485
test 0.8128188848495483
val 0.7829754948616028
0.9067938327789307
486
test 0.8128188848495483
val 0.7829754948616028
0.906105101108551
487
test 0.8128188848495483
val 0.7829754948616028
0.9054109454154968
488
test 0.8128188848495483
val 0.7829754948616028
0.9047120809555054
489
test 0.8128188848495483
val 0.7829754948616028
0.9040293097496033
490
test 0.8128188848495483
val 0.7829754948616028
0.9033569097518921
491
test 0.8128188848495483
val 0.7829754948616028
0.9026412963867188
492
test 0.8128188848495483
val 0.7829754948616028
0.9018443822860718
493
test 0.8128188848495483
val 0.7829754948616028
0.9010311961174011
494
test 0.8128188848495483
val 0.7829754948616028
0.9002460241317749
495
test 0.8128188848495483
val 0.7829754948616028
0.899425208568573
496
test 0.812371015548706
val 0.7829754948616028
0.8986268639564514
497
test 0.812371015548706
val 0.7820713520050049
0.8978186845779419
498
test 0.812371015548706
v

test 0.8096609711647034
val 0.7919235229492188
0.8156272768974304
629
test 0.8096609711647034
val 0.7919235229492188
0.8152053952217102
630
test 0.8096609711647034
val 0.7919235229492188
0.8148390054702759
631
test 0.8096609711647034
val 0.7919235229492188
0.8144190907478333
632
test 0.8096609711647034
val 0.7919235229492188
0.8139779567718506
633
test 0.8096609711647034
val 0.7919235229492188
0.8134591579437256
634
test 0.8096609711647034
val 0.7919235229492188
0.8129290342330933
635
test 0.8096609711647034
val 0.7919235229492188
0.8124026656150818
636
test 0.8096609711647034
val 0.7919235229492188
0.8118522763252258
637
test 0.8096609711647034
val 0.7919235229492188
0.8112488985061646
638
test 0.8096609711647034
val 0.7919235229492188
0.810630738735199
639
test 0.8096609711647034
val 0.7919235229492188
0.810005784034729
640
test 0.8096609711647034
val 0.7919235229492188
0.8093938827514648
641
test 0.8096609711647034
val 0.7919235229492188
0.8088406920433044
642
test 0.809660971164703

val 0.7936871647834778
0.7628977298736572
774
test 0.8120880126953125
val 0.7936871647834778
0.7626107335090637
775
test 0.8120880126953125
val 0.7936871647834778
0.7623611688613892
776
test 0.8120880126953125
val 0.7936871647834778
0.7620670199394226
777
test 0.8120880126953125
val 0.7936871647834778
0.7617642879486084
778
test 0.8120880126953125
val 0.7936871647834778
0.7614912986755371
779
test 0.8120880126953125
val 0.7936871647834778
0.7611978650093079
780
test 0.8120880126953125
val 0.7936871647834778
0.7609235644340515
781
test 0.8120880126953125
val 0.7936871647834778
0.7606974244117737
782
test 0.8120880126953125
val 0.7936871647834778
0.760497510433197
783
test 0.8120880126953125
val 0.7936871647834778
0.7602744102478027
784
test 0.8120880126953125
val 0.7936871647834778
0.7600811719894409
785
test 0.8120880126953125
val 0.7936871647834778
0.7598759531974792
786
test 0.8120880126953125
val 0.7936871647834778
0.7596555352210999
787
test 0.811640202999115
val 0.7936871647834778

test 0.8101038932800293
val 0.7936871647834778
0.7364775538444519
895
test 0.8101038932800293
val 0.7936871647834778
0.7363483905792236
896
test 0.8101038932800293
val 0.7936871647834778
0.7362872362136841
897
test 0.8101038932800293
val 0.7927830219268799
0.7362309098243713
898
test 0.8101038932800293
val 0.7927830219268799
0.7361881732940674
899
test 0.8101038932800293
val 0.7927830219268799
0.7361658215522766
900
test 0.8101038932800293
val 0.7927830219268799
0.7361193299293518
901
test 0.8101038932800293
val 0.7927830219268799
0.7360885739326477
902
test 0.8120880126953125
val 0.7927830219268799
0.7360109090805054
903
test 0.8120880126953125
val 0.7927830219268799
0.7359392046928406
904
test 0.8120880126953125
val 0.7927830219268799
0.7358591556549072
905
test 0.8120880126953125
val 0.7927830219268799
0.7356938719749451
906
test 0.8120880126953125
val 0.7927830219268799
0.735515832901001
907
test 0.8101038932800293
val 0.7927830219268799
0.7352954149246216
908
test 0.81010389328002

val 0.7945467233657837
0.7228823304176331
1025
test 0.8065913915634155
val 0.7945467233657837
0.7228574752807617
1026
test 0.8065913915634155
val 0.7945467233657837
0.7228680849075317
1027
test 0.8065913915634155
val 0.7945467233657837
0.7228659391403198
1028
test 0.8065913915634155
val 0.7945467233657837
0.722896933555603
1029
test 0.8065913915634155
val 0.7945467233657837
0.7229475378990173
1030
test 0.806143581867218
val 0.7945467233657837
0.7229828834533691
1031
test 0.806143581867218
val 0.7945467233657837
0.7230337858200073
1032
test 0.806143581867218
val 0.7945467233657837
0.7230455279350281
1033
test 0.806143581867218
val 0.7945467233657837
0.723002552986145
1034
test 0.806143581867218
val 0.7945467233657837
0.7229763269424438
1035
test 0.8065913915634155
val 0.7945467233657837
0.7229480743408203
1036
test 0.8065913915634155
val 0.7945467233657837
0.7229191064834595
1037
test 0.8065913915634155
val 0.7945467233657837
0.722775936126709
1038
test 0.8065913915634155
val 0.79454672

test 0.8067728281021118
val 0.7945467233657837
0.7118070125579834
1167
test 0.8067728281021118
val 0.7945467233657837
0.7117397785186768
1168
test 0.8067728281021118
val 0.7945467233657837
0.711625874042511
1169
test 0.8067728281021118
val 0.7945467233657837
0.7115375995635986
1170
test 0.8067728281021118
val 0.7945467233657837
0.7114216089248657
1171
test 0.8067728281021118
val 0.7945467233657837
0.7113375663757324
1172
test 0.8067728281021118
val 0.7945467233657837
0.7112518548965454
1173
test 0.8067728281021118
val 0.7945467233657837
0.7111399173736572
1174
test 0.8067728281021118
val 0.7945467233657837
0.711039125919342
1175
test 0.8067728281021118
val 0.7945467233657837
0.7108790874481201
1176
test 0.8067728281021118
val 0.7945467233657837
0.710719645023346
1177
test 0.8067728281021118
val 0.7945467233657837
0.7105882167816162
1178
test 0.8067728281021118
val 0.7945467233657837
0.7104648351669312
1179
test 0.8067728281021118
val 0.7945467233657837
0.7103336453437805
1180
test 0.80

In [39]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# Laden des CORA-Datensatzes
dataset = Planetoid(root='data/CORA', name='CORA')
data = dataset[0]
data = data.to(device)

# Definieren des Modells
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Initialisieren des Modells und definieren des Optimierers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Trainieren des Modells
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluieren des Modells
model.eval()
_, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('Accuracy: {:.4f}'.format(acc))


Accuracy: 0.8100
