In [1]:
import torch

In [2]:
import torch.nn.functional as F
import torch_geometric.nn as nn

In [3]:
train_data = []
test_data = []

for i in range(1,35):
    train_data.append(torch.load('separed_timestamps/without_unknowns/train/' + str(i) + '.pt'))

for i in range(35,50):
    test_data.append(torch.load('separed_timestamps/without_unknowns/test/' + str(i) + '.pt'))

In [4]:
hidden_size = 100
n_classes = 2


class Skip_GCN(torch.nn.Module):
    def __init__(self):
        super(Skip_GCN, self).__init__()
        self.conv1 = nn.GCNConv(166,100, bias=False)
        self.act1 = torch.nn.ReLU()
        self.conv2 = nn.GCNConv(100,2, bias=False)
        self.W_skip = torch.nn.Parameter(torch.Tensor(166, 2))
        torch.nn.init.xavier_uniform_(self.W_skip)
        self.act2 = torch.nn.Softmax(dim=1)
    
    def forward(self, x, edge_index, batch_index):
        hidden1 = self.conv1(x, edge_index)
        hidden1 = self.act1(hidden1)
        hidden2 = self.conv2(hidden1, edge_index)
        H_skip = torch.mm(x, self.W_skip)
        skip_output = torch.add(hidden2, H_skip)
        output = self.act2(skip_output)
        
        return hidden2, output

In [5]:
model = Skip_GCN()

In [6]:
model

Skip_GCN(
  (conv1): GCNConv(166, 100)
  (act1): ReLU()
  (conv2): GCNConv(100, 2)
  (act2): Softmax(dim=1)
)

In [7]:
## Use a GPU para treinar
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [8]:
loss = torch.nn.CrossEntropyLoss(weight=torch.Tensor([0.7, 0.3]))

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [10]:
def train():
    model.train()
    # Enumerar sobre os dados.
    for ts, data in enumerate(train_data):
        for epoch in range(1000):
            # Usar GPU
            data.to(device)
            # Resetar Gradientes
            optimizer.zero_grad()
            # Passando as informações do batch e de conexão dos grafos
            hidden, logits = model(data.x.float(), data.edge_index, None)
            #label_pred = pred.max(1)[1]
            # Calculando a perda e os gradientes.
            l = loss(logits, data.y)
            l.backward()
            # Atualizar usando os gradientes.
            optimizer.step()
        if (epoch + 1) % 100 == 0:
            print('ts',ts+1,'epoch =', epoch + 1, 'loss =', l.item())

In [11]:
train()

ts 1 epoch = 1000 loss = 0.32516300678253174
ts 2 epoch = 1000 loss = 0.31331831216812134
ts 3 epoch = 1000 loss = 0.32409191131591797
ts 4 epoch = 1000 loss = 0.32431960105895996
ts 5 epoch = 1000 loss = 0.31696781516075134
ts 6 epoch = 1000 loss = 0.3274991810321808
ts 7 epoch = 1000 loss = 0.3394516408443451
ts 8 epoch = 1000 loss = 0.33933112025260925
ts 9 epoch = 1000 loss = 0.34163084626197815
ts 10 epoch = 1000 loss = 0.3213188052177429
ts 11 epoch = 1000 loss = 0.3236527144908905
ts 12 epoch = 1000 loss = 0.32653698325157166
ts 13 epoch = 1000 loss = 0.341155081987381
ts 14 epoch = 1000 loss = 0.342074990272522
ts 15 epoch = 1000 loss = 0.33416494727134705
ts 16 epoch = 1000 loss = 0.3413386940956116
ts 17 epoch = 1000 loss = 0.3556835353374481
ts 18 epoch = 1000 loss = 0.34744542837142944
ts 19 epoch = 1000 loss = 0.3410555422306061
ts 20 epoch = 1000 loss = 0.4012643098831177
ts 21 epoch = 1000 loss = 0.3903186321258545
ts 22 epoch = 1000 loss = 0.3484022617340088
ts 23 epoch

In [21]:
label_pred_list = []
y_true_list = []

def test():
    model.eval()
    with torch.no_grad():
        global label_pred_list
        global y_true_list
        for data in test_data:
            data.to(device)
            _, logits = model(data.x.float(), data.edge_index, None)
            label_pred = logits.max(1)[1].tolist()
            label_pred_list += label_pred
            y_true_list += data.y.tolist()
    model.train()

In [22]:
test()

In [23]:
from sklearn.metrics import classification_report

In [25]:
print(classification_report(y_true_list,label_pred_list))

              precision    recall  f1-score   support

           0       0.80      0.50      0.62      1083
           1       0.97      0.99      0.98     15587

    accuracy                           0.96     16670
   macro avg       0.88      0.75      0.80     16670
weighted avg       0.96      0.96      0.96     16670

