In [1]:
import csv
import time 
import datetime
from statistics import mean, stdev
from dateutil import parser
import numpy as np

In [2]:
class Node():
    
    def __init__(self, addr, name, acc_type, contract_type, entity, label, tags = []):
        self.addr = addr
        self.name = name
        self.acc_type = acc_type
        self.contract_type = contract_type
        self.entity = entity
        self.label = label
        self.tags = tags      
        self.txs_in = []
        self.txs_out = []
        
    def add_tx_in(self, tx):
        self.txs_in.append(tx)
        
    def add_tx_out(self, tx):
        self.txs_out.append(tx)
        
    def feature_vec(self):
        vec = [0] * 8
        type_indexes = {'Wallet': 0, 'Smart Contract': 1}
        
        vec[type_indexes[self.acc_type]] = 1
        
        if self.contract_type == 'Token':
            vec[2] = 1
            
        if self.entity:
            entity_indexes = {'Exchange': 3, 'Dex': 4, 'ICO Wallets': 5, 'Mining': 6, 'DeFi': 7}
            vec[entity_indexes[self.entity]] = 1
        
        return vec
            
        
    def __str__(self):
        return f"{self.addr} - {self.name} \n{self.acc_type}, {self.entity}, {self.contract_type}"

In [3]:
class Tx():
    #A transaction is a transfer of ether, a contract invocation, or contract creation
    def __init__(self, tx_type, block_num, timestamp, sender, receiver, value, gasLim, gasUsed, gasPrice, isError, code):
        self.tx_type = tx_type
        self.block_num = block_num
        self.timestamp = timestamp
        self.sender = sender
        self.receiver = receiver
        self.value = value
        self.gasLim = gasLim
        self.gasUsed = gasUsed
        self.gasPrice = gasPrice
        self.isError = isError
        self.code = code
    
    def edge_vec(self):
        vec = [0] * 8
        
        type_indexes = {'Transfer': 0, 'Contract Invocation': 1, 'Contract Creation': 2}
        vec[type_indexes[self.tx_type]] = 1
        
        vec[3] = self.value
        vec[4] = self.gasLim
        vec[5] = self.gasUsed
        vec[6] = self.gasPrice
        vec[7] = self.isError
        
        return vec
    
    def __str__(self):
        return f"{self.value} ETH from {self.sender} to {self.receiver} at {self.timestamp}"
        

In [4]:
def parse_node_row(row):
    return row[0], row[1], row[2], row[3], row[4], row[5], [x for x in row[6:] if x]

def parse_tx_row(row):
    code = row[13]
    
    if code == '0x' or int(code, 0) == 0:
        tx_type = 'Transfer'
        
    else:
        tx_type = 'Contract Invocation'
    
    sender, receiver = row[6], row[7]
    if receiver == '':
        tx_type = 'Contract Creation'
        receiver = row[14]
    
    d = datetime.datetime.fromtimestamp(int(row[1]))
    
    return tx_type, row[0], d, sender, receiver, float(row[8]), int(row[9]), int(row[16]), int(row[10]), int(row[11]), code

In [5]:
#Fetch reference nodes
reference_nodes = {}
with open("final.csv") as csv_file:
    csv_reader = csv.reader(csv_file, delimiter = ',')
    line_count = 0
    
    for row in csv_reader:
        
        if line_count == 0:
            line_count += 1 
            continue
        else:
            addr, name, acc_type, contract_type, entity, label, tags = parse_node_row(row)
            reference_nodes[row[0]] = Node(addr, name, acc_type, contract_type, entity, label, tags)

In [6]:
#Fetch transactions
txs = []
with open('2019-2-11-to-2019-4-11-txs-byN.csv') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter = ',')
    line_count = 0
    
    for row in csv_reader:
        
        if line_count == 0:
            line_count += 1
            continue
        else:
            tx_type, block_num, timestamp, sender, receiver, value, gasSent, gasUsed, gasPrice, isError, code = parse_tx_row(row)
            txs.append(Tx(tx_type, block_num, timestamp, sender, receiver, value, gasSent, gasUsed, gasPrice, isError, code))

        
txs.sort(key = lambda x: x.timestamp)

In [7]:
#Only external transactions retrieved (from blockchain), so senders must be EOAs
senders = set()
for tx in txs:
    senders.add(tx.sender)
    
for tx in txs:
    if tx.tx_type == 'Contract Invocation' and tx.receiver in senders:
        tx.tx_type = 'Transfer'
        
#Need to use etherscan if want to fetch internal transactions

In [8]:
types= {
    'Transfer': 0,
    'Contract Creation': 0,
    'Contract Invocation': 0
}
for tx in txs:
    types[tx.tx_type] +=1
    
print(types)

{'Transfer': 404904, 'Contract Creation': 2225, 'Contract Invocation': 330602}


In [9]:
#Turn transaction list into a nodelist

In [10]:
nodes = {}
for tx in txs:
    
    sender, receiver = tx.sender, tx.receiver
    
    if sender in nodes.keys():
        nodes[sender].add_tx_out(tx)
    else:
        if sender in reference_nodes.keys():
            nodes[sender] = reference_nodes[sender]
            nodes[sender].add_tx_out(tx)
        else:
            nodes[sender] = Node(sender, '', 'Wallet', '', '', 'Unlabelled')
            nodes[sender].add_tx_out(tx)
        
        
    if receiver in nodes.keys():
        nodes[receiver].add_tx_in(tx)
    else:
        if receiver in reference_nodes.keys():
            nodes[receiver] = reference_nodes[receiver]
            nodes[receiver].add_tx_in(tx)
        else:
            nodes[receiver] = Node(receiver, '', 'Wallet' if tx.tx_type == 'Transfer' else 'Smart Contract', '', '', 'Unlabelled')
            nodes[receiver].add_tx_in(tx)

In [11]:
#Shuffle the nodes and then assign node numbers - so graphs are roughly same size (means v sparse)
import random
keys = list(nodes.keys())
random.shuffle(keys)
for i, addr in enumerate(keys):
    nodes[addr].node_id = i

In [12]:
lim = 0
# self.addr = addr
#         self.name = name
#         self.acc_type = acc_type
#         self.contract_type = contract_type
#         self.entity = entity
#         self.label = label
#         self.tags = tags 
ls = {
    'Dodgy': 0,
    'Legit': 0,
    'Unlabelled': 0
}
tos = {
    'Token': 0,
    '': 0
}
ts = {
    'Wallet': 0,
    'Smart Contract': 0
}
es = {
    'Dex': 0,
    'Exchange': 0,
    'ICO Wallets': 0,
    'Mining': 0,
    'DeFi': 0,
    '': 0
}

for node in nodes.values():
    if lim == 100:
        print(node)
        print(node.node_id)
        print(node.feature_vec())
        lim += 1
    
    ls[node.label] += 1
    tos[node.contract_type] += 1
    es[node.entity] += 1
    ts[node.acc_type] += 1
    
    
print(ls)
print(ts)
print(es)
print(tos)

{'Dodgy': 126, 'Legit': 1045, 'Unlabelled': 273002}
{'Wallet': 262305, 'Smart Contract': 11868}
{'Dex': 16, 'Exchange': 112, 'ICO Wallets': 8, 'Mining': 26, 'DeFi': 27, '': 273984}
{'Token': 709, '': 273464}


In [13]:
#Need to normalize edge features before splitting into sub graphs, for consistency across graphs
vec_length = len(txs[0].edge_vec())
maxs = txs[0].edge_vec()
mins = txs[0].edge_vec()

for tx in txs:
    vec = tx.edge_vec()
    
    for i in range(vec_length):
        maxs[i] = vec[i] if vec[i] > maxs[i] else maxs[i]
        mins[i] = vec[i] if vec[i] < mins[i] else mins[i]
        
maxs = np.array(maxs)
mins = np.array(mins)

In [14]:
def split_txs(txs, splitter):
    return splitter(txs)

In [15]:
def split_by_day(txs):
    
    txs_by_day = {}
    
    for tx in txs:
        date_str = str(tx.timestamp.month) + '/' + str(tx.timestamp.day)
        
        if date_str in txs_by_day.keys():
            txs_by_day[date_str].append(tx)
        else:
            txs_by_day[date_str] = [tx]
        
    return txs_by_day
    

In [16]:
txs_dict = split_txs(txs, split_by_day)
del txs_dict['2/10'] #first day is not a full day

In [17]:
print(mins)
print(maxs)

[    0.     0.     0.     0. 21000. 14150.     0.     0.]
[1.000000e+00 1.000000e+00 1.000000e+00 5.000000e+22 8.000000e+06
 7.614876e+06 3.999900e+13 1.000000e+00]


In [18]:
import torch
from torch_geometric.data import Data

In [19]:
demo_node = next(iter(nodes.values()))
feat_length = len(demo_node.feature_vec())

feats = torch.zeros(len(nodes), feat_length, dtype=torch.float)

In [20]:
for node in nodes.values():
    feats[node.node_id] = torch.FloatTensor(node.feature_vec())

In [21]:
def normalize(maxs, mins, vec):
    return (vec - mins) / (maxs - mins)

def denormalize(maxs, mins, vec):
    return vec * (maxs - mins) + mins

In [22]:
def make_subgraph(tx_list, feats):
    l = len(tx_list)
    u = torch.zeros([l,1], dtype=torch.long)
    v = torch.zeros([l,1], dtype=torch.long)
    e = torch.zeros([l,vec_length], dtype=torch.float32)
    
    edges = {}
    count = 0
    for i, tx in enumerate(tx_list):
        if nodes[tx.sender].node_id not in edges.keys():
            edges[nodes[tx.sender].node_id] = count
            count += 1 
            
        if nodes[tx.receiver].node_id not in edges.keys():
            edges[nodes[tx.receiver].node_id] = count
            count += 1 
            
        u[i] = edges[nodes[tx.sender].node_id]
        v[i] = edges[nodes[tx.receiver].node_id]
        e[i] = torch.FloatTensor(tx.edge_vec())
        
    e = normalize(maxs, mins, e)    
    edge_index = torch.stack((u.t(), v.t()), dim=0)
    edge_index = edge_index.squeeze()
    
    mask = list(edges.keys())
    
    g = Data(x=feats[mask], edge_index=edge_index, edge_attr=e)
    
    return g, mask

In [23]:
graphs = {}
for incr in txs_dict.keys():
    graphs[incr], _ = make_subgraph(txs_dict[incr], feats)

In [24]:
graphs

{'2/11': Data(edge_attr=[55943, 8], edge_index=[2, 55943], x=[34914, 8]),
 '2/12': Data(edge_attr=[39643, 8], edge_index=[2, 39643], x=[22388, 8]),
 '2/13': Data(edge_attr=[35093, 8], edge_index=[2, 35093], x=[20215, 8]),
 '2/14': Data(edge_attr=[31703, 8], edge_index=[2, 31703], x=[19179, 8]),
 '2/15': Data(edge_attr=[32583, 8], edge_index=[2, 32583], x=[18286, 8]),
 '2/16': Data(edge_attr=[26403, 8], edge_index=[2, 26403], x=[13754, 8]),
 '2/17': Data(edge_attr=[21063, 8], edge_index=[2, 21063], x=[10902, 8]),
 '2/18': Data(edge_attr=[21593, 8], edge_index=[2, 21593], x=[11609, 8]),
 '2/19': Data(edge_attr=[21727, 8], edge_index=[2, 21727], x=[11974, 8]),
 '2/20': Data(edge_attr=[18630, 8], edge_index=[2, 18630], x=[10285, 8]),
 '2/21': Data(edge_attr=[14732, 8], edge_index=[2, 14732], x=[8005, 8]),
 '2/22': Data(edge_attr=[13271, 8], edge_index=[2, 13271], x=[7839, 8]),
 '2/23': Data(edge_attr=[12373, 8], edge_index=[2, 12373], x=[7375, 8]),
 '2/24': Data(edge_attr=[12805, 8], edge_

In [25]:
def node_to_address(node_num):
    for node in nodes.values():
        if node.node_id == node_num:
            return node.addr
        
    return ''

In [26]:
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE
from torch_geometric.utils import train_test_split_edges, remove_isolated_nodes, to_dense_adj

In [27]:
class Encoder(torch.nn.Module):
    def __init__(self, model_type, in_channels, out_channels):
        super(Encoder, self).__init__()
        self.model_type = model_type
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=False)
        if self.model_type in ['GAE']:
            self.conv2 = GCNConv(2 * out_channels, out_channels, cached=False)
        elif self.model_type in ['VGAE']:
            self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=False)
            self.conv_logstd = GCNConv(2 * out_channels, out_channels,
                                       cached=False)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        if self.model_type in ['GAE']:
            return self.conv2(x, edge_index)
        elif self.model_type in ['VGAE']:
            return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [28]:
def setup_ae_model(model_type, in_features, embed_dims):
    model_options = {'GAE': GAE, 'VGAE': VGAE}
    dev = torch.device('cuda:0')
    model = model_options[model_type](Encoder(model_type, in_features, embed_dims)).to(dev)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    return model, optimizer

In [29]:
def graph_split(g_list): 
    for g in g_list.values():
        g = train_test_split_edges(g)
    
def data_to_device(g, dev):
    x, train_pos_edge_index = g.x.to(dev), g.train_pos_edge_index.to(dev)
    return x, train_pos_edge_index

In [30]:
def train_gae(model, model_type, optimizer, graph_list, dev):
    model.train()
    optimizer.zero_grad()
    

    keys = list(graph_list.keys())
    random.shuffle(keys)
    
    for key in keys:
        graph = graph_list[key]
        x, train_pos_edge_index = data_to_device(graph, dev)
        z = model.encode(x, train_pos_edge_index)
        loss = model.recon_loss(z, train_pos_edge_index)
        if model_type in ['VGAE']:
            loss = loss + (1 / data.num_nodes) * model.kl_loss()
            
    loss.backward()
    optimizer.step()

In [31]:
def test_gae(model, graph_list, dev):
    dev = torch.device('cuda:0')
    AUC = 0
    AP = 0
    model.eval()
    with torch.no_grad():
        for graph in graph_list.values():
            z = model.encode(graph.x.to(dev), graph.train_pos_edge_index.to(dev))
            auc, ap = model.test(z, graph.test_pos_edge_index, graph.test_neg_edge_index)
            AUC += auc
            AP += ap
           
    N = len(graph_list)
    return AUC/N, AP/N

In [32]:
t = 'GAE'
model, optimizer = setup_ae_model(t, 8, 16)
graph_split(graphs)
for epoch in range(1001):
    print(f"Epoch {epoch}")
    train_gae(model, 'GAE', optimizer, graphs, 'cuda:0')
    auc, ap = test_gae(model, graphs, 'cuda')
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

Epoch 0
Epoch: 000, AUC: 0.8898, AP: 0.9231
Epoch 1
Epoch: 001, AUC: 0.9410, AP: 0.9539
Epoch 2
Epoch: 002, AUC: 0.9592, AP: 0.9662
Epoch 3
Epoch: 003, AUC: 0.9660, AP: 0.9701
Epoch 4
Epoch: 004, AUC: 0.9711, AP: 0.9728
Epoch 5
Epoch: 005, AUC: 0.9748, AP: 0.9751
Epoch 6
Epoch: 006, AUC: 0.9756, AP: 0.9755
Epoch 7
Epoch: 007, AUC: 0.9754, AP: 0.9760
Epoch 8
Epoch: 008, AUC: 0.9751, AP: 0.9757
Epoch 9
Epoch: 009, AUC: 0.9757, AP: 0.9750
Epoch 10
Epoch: 010, AUC: 0.9742, AP: 0.9737
Epoch 11
Epoch: 011, AUC: 0.9722, AP: 0.9727
Epoch 12
Epoch: 012, AUC: 0.9719, AP: 0.9729
Epoch 13
Epoch: 013, AUC: 0.9731, AP: 0.9736
Epoch 14
Epoch: 014, AUC: 0.9753, AP: 0.9741
Epoch 15
Epoch: 015, AUC: 0.9746, AP: 0.9749
Epoch 16
Epoch: 016, AUC: 0.9757, AP: 0.9747
Epoch 17
Epoch: 017, AUC: 0.9728, AP: 0.9741
Epoch 18
Epoch: 018, AUC: 0.9696, AP: 0.9717
Epoch 19
Epoch: 019, AUC: 0.9663, AP: 0.9709
Epoch 20
Epoch: 020, AUC: 0.9649, AP: 0.9702
Epoch 21
Epoch: 021, AUC: 0.9653, AP: 0.9696
Epoch 22
Epoch: 022,

Epoch: 181, AUC: 0.5121, AP: 0.7032
Epoch 182
Epoch: 182, AUC: 0.5130, AP: 0.7033
Epoch 183
Epoch: 183, AUC: 0.5137, AP: 0.7035
Epoch 184
Epoch: 184, AUC: 0.5151, AP: 0.7039
Epoch 185
Epoch: 185, AUC: 0.5181, AP: 0.7049
Epoch 186
Epoch: 186, AUC: 0.5149, AP: 0.7041
Epoch 187
Epoch: 187, AUC: 0.5121, AP: 0.7020
Epoch 188
Epoch: 188, AUC: 0.5121, AP: 0.7021
Epoch 189
Epoch: 189, AUC: 0.5133, AP: 0.7023
Epoch 190
Epoch: 190, AUC: 0.5120, AP: 0.7017
Epoch 191
Epoch: 191, AUC: 0.5112, AP: 0.7011
Epoch 192
Epoch: 192, AUC: 0.5115, AP: 0.7016
Epoch 193
Epoch: 193, AUC: 0.5138, AP: 0.7026
Epoch 194
Epoch: 194, AUC: 0.5162, AP: 0.7037
Epoch 195
Epoch: 195, AUC: 0.5177, AP: 0.7040
Epoch 196
Epoch: 196, AUC: 0.5176, AP: 0.7043
Epoch 197
Epoch: 197, AUC: 0.5156, AP: 0.7040
Epoch 198
Epoch: 198, AUC: 0.5142, AP: 0.7037
Epoch 199
Epoch: 199, AUC: 0.5130, AP: 0.7040
Epoch 200
Epoch: 200, AUC: 0.5136, AP: 0.7034
Epoch 201
Epoch: 201, AUC: 0.5128, AP: 0.7023
Epoch 202
Epoch: 202, AUC: 0.5105, AP: 0.701

Epoch: 360, AUC: 0.5106, AP: 0.7009
Epoch 361
Epoch: 361, AUC: 0.5102, AP: 0.7002
Epoch 362
Epoch: 362, AUC: 0.5100, AP: 0.7011
Epoch 363
Epoch: 363, AUC: 0.5125, AP: 0.7023
Epoch 364
Epoch: 364, AUC: 0.5126, AP: 0.7030
Epoch 365
Epoch: 365, AUC: 0.5124, AP: 0.7038
Epoch 366
Epoch: 366, AUC: 0.5130, AP: 0.7044
Epoch 367
Epoch: 367, AUC: 0.5149, AP: 0.7046
Epoch 368
Epoch: 368, AUC: 0.5147, AP: 0.7047
Epoch 369
Epoch: 369, AUC: 0.5147, AP: 0.7043
Epoch 370
Epoch: 370, AUC: 0.5140, AP: 0.7037
Epoch 371
Epoch: 371, AUC: 0.5138, AP: 0.7034
Epoch 372
Epoch: 372, AUC: 0.5136, AP: 0.7032
Epoch 373
Epoch: 373, AUC: 0.5127, AP: 0.7030
Epoch 374
Epoch: 374, AUC: 0.5130, AP: 0.7031
Epoch 375
Epoch: 375, AUC: 0.5131, AP: 0.7034
Epoch 376
Epoch: 376, AUC: 0.5124, AP: 0.7039
Epoch 377
Epoch: 377, AUC: 0.5159, AP: 0.7044
Epoch 378
Epoch: 378, AUC: 0.5169, AP: 0.7042
Epoch 379
Epoch: 379, AUC: 0.5165, AP: 0.7037
Epoch 380
Epoch: 380, AUC: 0.5146, AP: 0.7027
Epoch 381
Epoch: 381, AUC: 0.5115, AP: 0.701

Epoch: 539, AUC: 0.5159, AP: 0.7044
Epoch 540
Epoch: 540, AUC: 0.5151, AP: 0.7046
Epoch 541
Epoch: 541, AUC: 0.5149, AP: 0.7043
Epoch 542
Epoch: 542, AUC: 0.5149, AP: 0.7043
Epoch 543
Epoch: 543, AUC: 0.5145, AP: 0.7041
Epoch 544
Epoch: 544, AUC: 0.5135, AP: 0.7038
Epoch 545
Epoch: 545, AUC: 0.5130, AP: 0.7039
Epoch 546
Epoch: 546, AUC: 0.5159, AP: 0.7051
Epoch 547
Epoch: 547, AUC: 0.5180, AP: 0.7058
Epoch 548
Epoch: 548, AUC: 0.5182, AP: 0.7058
Epoch 549
Epoch: 549, AUC: 0.5178, AP: 0.7053
Epoch 550
Epoch: 550, AUC: 0.5161, AP: 0.7041
Epoch 551
Epoch: 551, AUC: 0.5140, AP: 0.7028
Epoch 552
Epoch: 552, AUC: 0.5126, AP: 0.7015
Epoch 553
Epoch: 553, AUC: 0.5126, AP: 0.7016
Epoch 554
Epoch: 554, AUC: 0.5124, AP: 0.7019
Epoch 555
Epoch: 555, AUC: 0.5120, AP: 0.7023
Epoch 556
Epoch: 556, AUC: 0.5130, AP: 0.7032
Epoch 557
Epoch: 557, AUC: 0.5147, AP: 0.7041
Epoch 558
Epoch: 558, AUC: 0.5156, AP: 0.7046
Epoch 559
Epoch: 559, AUC: 0.5171, AP: 0.7071
Epoch 560
Epoch: 560, AUC: 0.5154, AP: 0.704

Epoch: 718, AUC: 0.5101, AP: 0.7010
Epoch 719
Epoch: 719, AUC: 0.5081, AP: 0.7000
Epoch 720
Epoch: 720, AUC: 0.5098, AP: 0.7011
Epoch 721
Epoch: 721, AUC: 0.5101, AP: 0.7015
Epoch 722
Epoch: 722, AUC: 0.5126, AP: 0.7028
Epoch 723
Epoch: 723, AUC: 0.5146, AP: 0.7039
Epoch 724
Epoch: 724, AUC: 0.5159, AP: 0.7050
Epoch 725
Epoch: 725, AUC: 0.5165, AP: 0.7064
Epoch 726
Epoch: 726, AUC: 0.5161, AP: 0.7049
Epoch 727
Epoch: 727, AUC: 0.5154, AP: 0.7036
Epoch 728
Epoch: 728, AUC: 0.5131, AP: 0.7034
Epoch 729
Epoch: 729, AUC: 0.5115, AP: 0.7030
Epoch 730
Epoch: 730, AUC: 0.5111, AP: 0.7036
Epoch 731
Epoch: 731, AUC: 0.5124, AP: 0.7035
Epoch 732
Epoch: 732, AUC: 0.5128, AP: 0.7037
Epoch 733
Epoch: 733, AUC: 0.5164, AP: 0.7060
Epoch 734
Epoch: 734, AUC: 0.5165, AP: 0.7053
Epoch 735
Epoch: 735, AUC: 0.5162, AP: 0.7051
Epoch 736
Epoch: 736, AUC: 0.5133, AP: 0.7040
Epoch 737
Epoch: 737, AUC: 0.5122, AP: 0.7034
Epoch 738
Epoch: 738, AUC: 0.5112, AP: 0.7027
Epoch 739
Epoch: 739, AUC: 0.5114, AP: 0.702

Epoch: 897, AUC: 0.5124, AP: 0.7029
Epoch 898
Epoch: 898, AUC: 0.5109, AP: 0.7026
Epoch 899
Epoch: 899, AUC: 0.5119, AP: 0.7031
Epoch 900
Epoch: 900, AUC: 0.5121, AP: 0.7032
Epoch 901
Epoch: 901, AUC: 0.5117, AP: 0.7029
Epoch 902
Epoch: 902, AUC: 0.5115, AP: 0.7031
Epoch 903
Epoch: 903, AUC: 0.5128, AP: 0.7024
Epoch 904
Epoch: 904, AUC: 0.5158, AP: 0.7049
Epoch 905
Epoch: 905, AUC: 0.5156, AP: 0.7041
Epoch 906
Epoch: 906, AUC: 0.5144, AP: 0.7055
Epoch 907
Epoch: 907, AUC: 0.5108, AP: 0.7023
Epoch 908
Epoch: 908, AUC: 0.5106, AP: 0.7021
Epoch 909
Epoch: 909, AUC: 0.5106, AP: 0.7017
Epoch 910
Epoch: 910, AUC: 0.5097, AP: 0.7006
Epoch 911
Epoch: 911, AUC: 0.5068, AP: 0.6985
Epoch 912
Epoch: 912, AUC: 0.5101, AP: 0.7029
Epoch 913
Epoch: 913, AUC: 0.5113, AP: 0.7022
Epoch 914
Epoch: 914, AUC: 0.5133, AP: 0.7039
Epoch 915
Epoch: 915, AUC: 0.5167, AP: 0.7065
Epoch 916
Epoch: 916, AUC: 0.5169, AP: 0.7060
Epoch 917
Epoch: 917, AUC: 0.5153, AP: 0.7054
Epoch 918
Epoch: 918, AUC: 0.5134, AP: 0.704

In [33]:
#RESET GRAPHS
graphs = {}
masks = {}
for incr in txs_dict.keys():
    graphs[incr], masks[incr] = make_subgraph(txs_dict[incr], feats)
     

In [34]:
#Stitch embeddings and feats into main tensor

In [35]:
# # model.eval()
# dev = torch.device('cuda:0')
# f = model.encode(graphs['2/14'].x.to(dev), graphs['2/13'].edge_index.to(dev))

In [36]:
len(nodes)

274173