# Implementação da Skip-GCN

Implementa a variante "skip" da arquitetura apresentada no artigo ["Anti-Money Laundering in Bitcoin: Experimenting with Graph Convolutional Networks for Financial Forensics"][1]. Essa variante insere um conjunto de parâmetros novo e realiza uma multiplicação normal entre a primeira e a segunda camada escondida e soma seus resultados com os embeddings da primeira camada escondida. Foi implementado em [PyTorch Geometric][2].

[1]: https://arxiv.org/pdf/1908.02591.pdf
[2]: https://pytorch-geometric.readthedocs.io/en/latest/

## Importando as Bibliotecas

In [1]:
import torch
import torch.nn.functional as F
import torch_geometric.nn as nn
from sklearn.metrics import classification_report

## Carregando os dados

In [2]:
train_data = []
test_data = []

for i in range(1,35):
    train_data.append(torch.load('elliptic_pt/train/' + str(i) + '.pt'))

for i in range(35,50):
    test_data.append(torch.load('elliptic_pt/test/' + str(i) + '.pt'))

## Definindo o modelo

In [3]:
hidden_size = 100
n_classes = 2


class Skip_GCN(torch.nn.Module):
    def __init__(self):
        super(Skip_GCN, self).__init__()
        self.conv1 = nn.GCNConv(166,100, bias=False)
        self.act1 = torch.nn.ReLU()
        self.conv2 = nn.GCNConv(100,2, bias=False)
        # Conexão "skip".
        self.W_skip = torch.nn.Linear(166, 2, bias=False)
        self.act2 = torch.nn.Softmax(dim=1)
    
    def forward(self, x, edge_index, batch_index):
        hidden1 = self.conv1(x, edge_index)
        hidden1 = self.act1(hidden1)
        hidden2 = self.conv2(hidden1, edge_index)
        H_skip = self.W_skip(x)
        skip_output = torch.add(hidden2, H_skip)
        output = self.act2(skip_output)
        
        return hidden2, output

In [4]:
model = Skip_GCN()

In [5]:
model

Skip_GCN(
  (conv1): GCNConv(166, 100)
  (act1): ReLU()
  (conv2): GCNConv(100, 2)
  (W_skip): Linear(in_features=166, out_features=2, bias=False)
  (act2): Softmax(dim=1)
)

In [6]:
## Use a GPU para treinar
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [7]:
loss = torch.nn.CrossEntropyLoss(weight=torch.Tensor([0.7, 0.3]))

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## Treinando

In [9]:
def train():
    model.train()
    # Enumerar sobre os dados.
    for epoch in range(1000):
        for ts, data in enumerate(train_data):
            # Usar GPU
            data.to(device)
            # Resetar Gradientes
            optimizer.zero_grad()
            # Passando as informações do batch e de conexão dos grafos
            hidden, logits = model(data.x.float(), data.edge_index, None)
            #label_pred = pred.max(1)[1]
            # Calculando a perda e os gradientes.
            l = loss(logits, data.y)
            l.backward()
            # Atualizar usando os gradientes.
            optimizer.step()
        if (epoch + 1) % 100 == 0:
            print('ts',ts+1,'epoch =', epoch + 1, 'loss =', l.item())

In [10]:
train()

ts 34 epoch = 100 loss = 0.3446897566318512
ts 34 epoch = 200 loss = 0.3351903259754181
ts 34 epoch = 300 loss = 0.3345642685890198
ts 34 epoch = 400 loss = 0.3300868272781372
ts 34 epoch = 500 loss = 0.33000448346138
ts 34 epoch = 600 loss = 0.329947829246521
ts 34 epoch = 700 loss = 0.32987141609191895
ts 34 epoch = 800 loss = 0.33159339427948
ts 34 epoch = 900 loss = 0.32982337474823
ts 34 epoch = 1000 loss = 0.3298146426677704


## Testando

In [11]:
label_pred_list = []
y_true_list = []

def test():
    model.eval()
    with torch.no_grad():
        global label_pred_list
        global y_true_list
        for data in test_data:
            data.to(device)
            _, logits = model(data.x.float(), data.edge_index, None)
            label_pred = logits.max(1)[1].tolist()
            label_pred_list += label_pred
            y_true_list += data.y.tolist()
    model.train()

In [12]:
test()

## Resultados

In [13]:
print(classification_report(y_true_list,label_pred_list))

              precision    recall  f1-score   support

           0       0.81      0.56      0.66      1083
           1       0.97      0.99      0.98     15587

    accuracy                           0.96     16670
   macro avg       0.89      0.78      0.82     16670
weighted avg       0.96      0.96      0.96     16670

