In [4]:
# Install required packages.
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

# Anomaly Detection in Graphs using Self-Supervised Learning

![CentraleSupelec Logo](https://www.centralesupelec.fr/sites/all/themes/cs_theme/medias/common/images/intro/logo_nouveau.jpg)

This project represents my end-of-studies project, that was developped for a big company and in association with French Engineering school CentraleSupélec.

Having signed an NDA, we do not have the right to share the company's data. The data we will use is from a re-adapted Kaggle dataset :
```
https://www.kaggle.com/datasets/mkechinov/ecommerce-events-history-in-cosmetics-shop/
```


## Preprocessing the tabular data

In our sample data, we have e-commerce events, and we would like to convert those events into a **heterogeneous graph** with three types of nodes, *product*, *customer* and *user*, and with that also three types of edges that add some logic to these nodes.

In [1]:
import pandas as pd

In [2]:
database_path = "data_out_head_head.csv"

dataframe = pd.read_csv(database_path, index_col=0).fillna("")

In [3]:
dataframe

Unnamed: 0,event_time,event_type,product_id,category_id,category_code,brand,price,user_id,Session_id,Customer_id,Location,License_id,Session_start_datetime,Session_end_datetime,duration,License_start_date,License_end_date
0,2020-01-25 23:46:12,cart,5921712,2115334439910245200,,,5.16,388018099,843d560b-2069-4a0d-68af-f767f5341312,480374496_65446,"[-7.1208, -34.5019]",5921712,2020-01-25 18:23:12,2020-01-26 05:19:12,656.0,2020-01-22 07:56:13,2020-02-27 05:38:51
1,2020-02-15 14:43:37,remove_from_cart,5921712,2115334439910245200,,,5.16,459659126,457cee31-cfd9-4f75-909d-64f17021da9d,552795963_171732,"[55.0342, 6.547499999999999]",5921712,2020-02-15 07:57:37,2020-02-15 15:52:37,475.0,2020-01-22 07:56:13,2020-02-27 05:38:51
2,2020-02-09 20:57:57,remove_from_cart,5921712,2115334439910245200,,,5.16,405986628,a4354a0c-f44a-484c-96b7-b319f81e99de,405986628_283400,"[20.6167, -96.1167]",5921712,2020-02-09 14:17:57,2020-02-10 02:48:57,751.0,2020-01-22 07:56:13,2020-02-27 05:38:51
3,2020-02-05 05:30:46,view,5921712,2115334439910245200,,,5.16,571731968,10ba57c9-187e-454a-b57c-cdc71388cbe5,610461263_109078,"[25.7206, 76.8472]",5921712,2020-02-04 21:41:46,2020-02-05 10:44:46,783.0,2020-01-22 07:56:13,2020-02-27 05:38:51
4,2020-01-28 07:17:14,cart,5921712,2115334439910245200,,,5.16,601508456,201af163-9d3f-45ae-9511-7f64d8e168c1,530951720_21726,"[55.268, 1.476]",5921712,2020-01-27 21:33:14,2020-01-28 14:48:14,1035.0,2020-01-22 07:56:13,2020-02-27 05:38:51
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,2020-02-12 10:30:12,view,4810,1487580006317032337,,,6.67,612193632,a469ba53-4047-4a0a-a22e-05e7c922cb65,561323751_958571,"[51.8844, -94.1464]",4810,2020-02-12 02:29:12,2020-02-12 17:27:12,898.0,2019-09-30 21:07:40,2020-03-01 03:55:29
996,2020-02-15 18:39:37,cart,4810,1487580006317032337,,,6.67,514701786,67748dce-b81f-46a4-b943-2667fc5edb15,514701786_629210,"[13.9833, 125.9]",4810,2020-02-15 16:25:37,2020-02-16 01:54:37,569.0,2019-09-30 21:07:40,2020-03-01 03:55:29
997,2019-10-09 11:13:26,view,4810,1487580006317032337,,,6.67,551061566,41af1dc9-7c2f-4222-9919-18a449341d1b,,"[-11.72, -56.3278]",4810,2019-10-09 10:26:26,2019-10-09 14:38:26,252.0,2019-09-30 21:07:40,2020-03-01 03:55:29
998,2020-02-13 12:17:35,cart,4810,1487580006317032337,,,6.67,465231019,b65bbb4c-769a-45e0-b7db-99cbd6808883,465231019_93796,"[18.822699999999998, 123.3295]",4810,2020-02-13 04:02:35,2020-02-13 16:53:35,771.0,2019-09-30 21:07:40,2020-03-01 03:55:29


In [4]:
print(dataframe.info())

<class 'pandas.core.frame.DataFrame'>
Index: 1000 entries, 0 to 999
Data columns (total 17 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   event_time              1000 non-null   object 
 1   event_type              1000 non-null   object 
 2   product_id              1000 non-null   int64  
 3   category_id             1000 non-null   int64  
 4   category_code           1000 non-null   object 
 5   brand                   1000 non-null   object 
 6   price                   1000 non-null   float64
 7   user_id                 1000 non-null   int64  
 8   Session_id              1000 non-null   object 
 9   Customer_id             1000 non-null   object 
 10  Location                1000 non-null   object 
 11  License_id              1000 non-null   int64  
 12  Session_start_datetime  1000 non-null   object 
 13  Session_end_datetime    1000 non-null   object 
 14  duration                1000 non-null   float6

In [5]:
dataframe['latitude']=dataframe['Location'].apply(lambda x : float(x.split(',')[0][1:]))
dataframe['longitude']=dataframe['Location'].apply(lambda x : float(x.split(',')[1][:-1]))

def hex_to_dec(id):
  if id == "" or id is None: return 0
  else: return int(id, 16)


In [6]:
dataframe.head()

Unnamed: 0,event_time,event_type,product_id,category_id,category_code,brand,price,user_id,Session_id,Customer_id,Location,License_id,Session_start_datetime,Session_end_datetime,duration,License_start_date,License_end_date,latitude,longitude
0,2020-01-25 23:46:12,cart,5921712,2115334439910245200,,,5.16,388018099,843d560b-2069-4a0d-68af-f767f5341312,480374496_65446,"[-7.1208, -34.5019]",5921712,2020-01-25 18:23:12,2020-01-26 05:19:12,656.0,2020-01-22 07:56:13,2020-02-27 05:38:51,-7.1208,-34.5019
1,2020-02-15 14:43:37,remove_from_cart,5921712,2115334439910245200,,,5.16,459659126,457cee31-cfd9-4f75-909d-64f17021da9d,552795963_171732,"[55.0342, 6.547499999999999]",5921712,2020-02-15 07:57:37,2020-02-15 15:52:37,475.0,2020-01-22 07:56:13,2020-02-27 05:38:51,55.0342,6.5475
2,2020-02-09 20:57:57,remove_from_cart,5921712,2115334439910245200,,,5.16,405986628,a4354a0c-f44a-484c-96b7-b319f81e99de,405986628_283400,"[20.6167, -96.1167]",5921712,2020-02-09 14:17:57,2020-02-10 02:48:57,751.0,2020-01-22 07:56:13,2020-02-27 05:38:51,20.6167,-96.1167
3,2020-02-05 05:30:46,view,5921712,2115334439910245200,,,5.16,571731968,10ba57c9-187e-454a-b57c-cdc71388cbe5,610461263_109078,"[25.7206, 76.8472]",5921712,2020-02-04 21:41:46,2020-02-05 10:44:46,783.0,2020-01-22 07:56:13,2020-02-27 05:38:51,25.7206,76.8472
4,2020-01-28 07:17:14,cart,5921712,2115334439910245200,,,5.16,601508456,201af163-9d3f-45ae-9511-7f64d8e168c1,530951720_21726,"[55.268, 1.476]",5921712,2020-01-27 21:33:14,2020-01-28 14:48:14,1035.0,2020-01-22 07:56:13,2020-02-27 05:38:51,55.268,1.476


In [6]:
dataframe

Unnamed: 0.1,Unnamed: 0,event_time,event_type,product_id,category_id,category_code,brand,price,user_id,Session_id,...,duration,License_start_date,License_end_date,latitude,longitude,session_id1,session_id2,session_id3,session_id4,session_id5
0,0,2020-01-25 23:46:12,cart,5921712,2115334439910245200,,,5.16,388018099,843d560b-2069-4a0d-68af-f767f5341312,...,656.0,2020-01-22 07:56:13,2020-02-27 05:38:51,-7.1208,-34.5019,2218612235,8297,18957,26799,272025867522834
1,1,2020-02-15 14:43:37,remove_from_cart,5921712,2115334439910245200,,,5.16,459659126,457cee31-cfd9-4f75-909d-64f17021da9d,...,475.0,2020-01-22 07:56:13,2020-02-27 05:38:51,55.0342,6.5475,1165815345,53209,20341,37021,110988131162781
2,2,2020-02-09 20:57:57,remove_from_cart,5921712,2115334439910245200,,,5.16,405986628,a4354a0c-f44a-484c-96b7-b319f81e99de,...,751.0,2020-01-22 07:56:13,2020-02-27 05:38:51,20.6167,-96.1167,2754955788,62538,18508,38583,196924118309342
3,3,2020-02-05 05:30:46,view,5921712,2115334439910245200,,,5.16,571731968,10ba57c9-187e-454a-b57c-cdc71388cbe5,...,783.0,2020-01-22 07:56:13,2020-02-27 05:38:51,25.7206,76.8472,280647625,6270,17738,46460,226254909918181
4,4,2020-01-28 07:17:14,cart,5921712,2115334439910245200,,,5.16,601508456,201af163-9d3f-45ae-9511-7f64d8e168c1,...,1035.0,2020-01-22 07:56:13,2020-02-27 05:38:51,55.2680,1.4760,538636643,40255,17838,38161,140071112108225
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,995,2020-02-12 10:30:12,view,4810,1487580006317032337,,,6.67,612193632,a469ba53-4047-4a0a-a22e-05e7c922cb65,...,898.0,2019-09-30 21:07:40,2020-03-01 03:55:29,51.8844,-94.1464,2758392403,16455,18954,41518,6493070084965
996,996,2020-02-15 18:39:37,cart,4810,1487580006317032337,,,6.67,514701786,67748dce-b81f-46a4-b943-2667fc5edb15,...,569.0,2019-09-30 21:07:40,2020-03-01 03:55:29,13.9833,125.9000,1735691726,47135,18084,47427,42228057561877
997,997,2019-10-09 11:13:26,view,4810,1487580006317032337,,,6.67,551061566,41af1dc9-7c2f-4222-9919-18a449341d1b,...,252.0,2019-09-30 21:07:40,2020-03-01 03:55:29,-11.7200,-56.3278,1101995465,31791,16930,39193,27093881855259
998,998,2020-02-13 12:17:35,cart,4810,1487580006317032337,,,6.67,465231019,b65bbb4c-769a-45e0-b7db-99cbd6808883,...,771.0,2019-09-30 21:07:40,2020-03-01 03:55:29,18.8227,123.3295,3059465036,30362,17888,47067,169100756158595


# Création du graphe

In [16]:
import torch
from datetime import datetime

def load_node_csv(dataframe, index_col, encoders=None):
    df = dataframe.set_index(index_col)
    mapping = {index: i for i, index in enumerate(df.index.unique())}

    x = None
    if encoders is not None:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    return x, mapping

def load_edge_csv(dataframe, src_index_col, src_mapping, dst_index_col, dst_mapping,
                  encoders=None):
    df = dataframe

    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    edge_index = torch.tensor([src, dst])
    
    edge_attr = None
    if encoders is not None:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr


class IdentityEncoder(object):
    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, df):
        return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)

class DateTimeEncoder(object):
    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, df):
        return torch.from_numpy(pd.to_datetime(df).apply(lambda x: datetime.timestamp(x)).values).view(-1, 1).to(self.dtype)

class SessionIdEncoder(object):
    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, df):
        return torch.from_numpy(df.str.split('-',expand=True)[0].apply(hex_to_dec).values).view(-1, 1).to(self.dtype)

In [17]:
customer_encodings = {
        "latitude": IdentityEncoder(dtype=torch.float),
        "longitude":IdentityEncoder(dtype=torch.float)
}
product_encodings = {
        "price": IdentityEncoder(dtype=torch.float),
        "category_id": IdentityEncoder(dtype=torch.float),
}
session_encodings = {
        "Session_id": SessionIdEncoder(dtype=torch.float),
        "Session_start_datetime": DateTimeEncoder(dtype=torch.float),
        "Session_end_datetime": DateTimeEncoder(dtype=torch.float),
        "user_id": IdentityEncoder(dtype=torch.float),
}
licence_encodings = {
        "License_id": IdentityEncoder(dtype=torch.float),
        "License_start_date": DateTimeEncoder(dtype=torch.float),
        "License_end_date": DateTimeEncoder(dtype=torch.float),
}

In [18]:
from torch_geometric.data import HeteroData

def generate_graph(dataframe):
    data = HeteroData()

    # Loading nodes into graph
    data['customer'].x, customer_mapping = load_node_csv(dataframe, "Customer_id",customer_encodings)
    data['product'].x, product_mapping = load_node_csv(dataframe, "product_id", product_encodings)
    _, user_mapping = load_node_csv(dataframe, "user_id")
    data['user'].num_nodes = len(user_mapping)  # user has no features

    # Loading edges into graph
    data['customer', 'has', 'user'].edge_index, _ = load_edge_csv(
        dataframe,
        src_index_col='Customer_id',
        src_mapping=customer_mapping,
        dst_index_col='user_id',
        dst_mapping=user_mapping,
    )

    data['product', 'license','customer'].edge_index, data[
        'product', 'license','customer'].edge_attr = load_edge_csv(
        dataframe,
        src_index_col='product_id',
        src_mapping=product_mapping,
        dst_index_col='Customer_id',
        dst_mapping=customer_mapping,
        encoders=licence_encodings
    )
    data['user', 'opened_session', 'product'].edge_index, data[
        'user', 'opened_session', 'product'].edge_attr = load_edge_csv(
        dataframe,
        src_index_col='user_id',
        src_mapping=user_mapping,
        dst_index_col='product_id',
        dst_mapping=product_mapping,
        encoders=session_encodings
    )
    return data

In [19]:
data = generate_graph(dataframe)
print(data)

HeteroData(
  [1mcustomer[0m={ x=[1000, 2] },
  [1mproduct[0m={ x=[1000, 2] },
  [1muser[0m={ num_nodes=581 },
  [1m(customer, has, user)[0m={ edge_index=[2, 1000] },
  [1m(product, license, customer)[0m={
    edge_index=[2, 1000],
    edge_attr=[1000, 3]
  },
  [1m(user, opened_session, product)[0m={
    edge_index=[2, 1000],
    edge_attr=[1000, 4]
  }
)


In [20]:
data.num_nodes_dict

{'customer': 1000, 'product': 1000, 'user': 581}

# Preprocessing du graphe


In [21]:
from torch_geometric.nn import MetaPath2Vec
from torch_sparse import SparseTensor

metapath = [
    ('customer', 'has', 'user'),
    ('user', 'opened_session', 'product'),
    ('product', 'license','customer')  
]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MetaPath2Vec(data.edge_index_dict, embedding_dim=2,
                     metapath=metapath, walk_length=5, context_size=3,
                     walks_per_node=3, num_negative_samples=1,
                     sparse=True).to(device)

loader = model.loader(batch_size=32, shuffle=True, num_workers=3)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)


def train(epoch, log_steps=100, eval_steps=2000):
    model.train()

    total_loss = 0
    for i, (pos_rw, neg_rw) in enumerate(loader):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (i + 1) % log_steps == 0:
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Loss: {total_loss / log_steps:.4f}'))
            total_loss = 0

        



In [22]:
for epoch in range(1, 10):
  train(epoch)

In [23]:
data['user'].x=model('user')

print(data)

HeteroData(
  [1mcustomer[0m={ x=[1000, 2] },
  [1mproduct[0m={ x=[1000, 2] },
  [1muser[0m={
    num_nodes=581,
    x=[581, 2]
  },
  [1m(customer, has, user)[0m={ edge_index=[2, 1000] },
  [1m(product, license, customer)[0m={
    edge_index=[2, 1000],
    edge_attr=[1000, 3]
  },
  [1m(user, opened_session, product)[0m={
    edge_index=[2, 1000],
    edge_attr=[1000, 4]
  }
)


Saving the graph for further testing

In [40]:
torch.save(data, "graph.pt")

# DOMINANT

![The proposed framework](framework.png)

In [24]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [25]:
data.num_nodes_dict

{'customer': 1000, 'product': 1000, 'user': 581}

In [26]:
import numpy as np

def dense_adj(data):
  adj_dict = {}
  for node_i in data.num_nodes_dict.keys():
    adj_dict[node_i] = {}
    for node_j in data.num_nodes_dict.keys():
      adj_dict[node_i][node_j] = torch.from_numpy(np.zeros((data.num_nodes_dict[node_i], data.num_nodes_dict[node_j])))
  for key in data.edge_index_dict.keys():
    a,_,b = key
    for (i,j) in data.edge_index_dict[key].numpy().transpose():
      adj_dict[a][b][i][j] = 1
    return adj_dict
  

On rajoute des transformations à notre graphe, propre à torch_geometric, pour que les performances d'apprentissage de nos réseaux de neurones soient plus hautes.

In [27]:
# data = torch.load("graph.pt")
#
print(data)

HeteroData(
  [1mcustomer[0m={ x=[1000, 2] },
  [1mproduct[0m={ x=[1000, 2] },
  [1muser[0m={
    num_nodes=581,
    x=[581, 2]
  },
  [1m(customer, has, user)[0m={ edge_index=[2, 1000] },
  [1m(product, license, customer)[0m={
    edge_index=[2, 1000],
    edge_attr=[1000, 3]
  },
  [1m(user, opened_session, product)[0m={
    edge_index=[2, 1000],
    edge_attr=[1000, 4]
  }
)


In [28]:
import torch_geometric.transforms as T

data = T.ToUndirected()(data)
data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)
data = T.ToDevice(device)(data)

On introduit notre modèle DOMINANT, basé sur des couches de convolutions SAGEConv

In [29]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch_geometric.nn import SAGEConv, to_hetero

class Encoder(nn.Module):
    def __init__(self, num_features, hidden_channels, dropout):
        super().__init__()
        self.conv1 = SAGEConv((num_features, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.conv2(x, edge_index).relu()
        return x

class Attribute_Decoder(nn.Module):
    def __init__(self, num_features, hidden_channels, dropout):
        super().__init__()

        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), num_features)
        self.dropout = dropout

    def forward(self, x, adj):

        x = F.relu(self.conv1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.conv2(x, adj))

        return x

class Structure_Decoder(nn.Module):
    def __init__(self, num_nodes, hidden_channels, dropout):
        super().__init__()

        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), num_nodes)
        self.dropout = dropout

    def forward(self, x, adj):

        x = F.relu(self.conv1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.conv2(x, adj))

        return x.T


class Dominant(nn.Module):
    def __init__(self, feat_size, hidden_size, num_nodes_dict, dropout, metadata):
        super().__init__()
        
        self.shared_encoder = to_hetero(Encoder(feat_size, hidden_size, dropout), metadata, aggr='sum')
        self.attr_decoder = to_hetero(Attribute_Decoder(feat_size, hidden_size, dropout), metadata, aggr='sum')
        self.struct_decoder_dict = {}
        for key in num_nodes_dict.keys():
          self.struct_decoder_dict[key] = to_hetero(Structure_Decoder(num_nodes_dict[key], hidden_size, dropout), metadata, aggr='sum')
    
    def forward(self, x_dict, adj_dict):

        # encode
        x_dict = self.shared_encoder(x_dict, adj_dict)
        # decode feature matrix
        x_hat_dict = self.attr_decoder(x_dict, adj_dict)
        # decode adjacency matrix
        struct_reconstructed_dict={}
        for key in self.struct_decoder_dict.keys():
          struct_reconstructed_dict[key] = self.struct_decoder_dict[key](x_dict, adj_dict)
        # return reconstructed matrices
        return struct_reconstructed_dict, x_hat_dict

We train our model now

In [30]:
def loss_func_train(attrs, X_hat, adj, A_hat, alpha=0.8):
    # Attribute reconstruction loss
    diff_attribute = []
    for key in attrs.keys():
        diff_attribute.append(torch.pow(X_hat[key] - attrs[key], 2))
    diff_attribute = torch.cat(tuple(diff_attribute), 0)
    attribute_reconstruction_errors = torch.sqrt(torch.sum(diff_attribute, 1))
    attribute_cost = torch.mean(attribute_reconstruction_errors)

    # structure reconstruction loss
    diff_structure_all = []
    for key1 in adj.keys():
        structure = []
        for key2 in adj.keys():
            structure.append(torch.pow(A_hat[key1][key2] - adj[key1][key2], 2))
        diff_structure_all.append(torch.cat(tuple(structure), 1))
    diff_structure = torch.cat(tuple(diff_structure_all), 0)
    structure_reconstruction_errors = torch.sqrt(torch.sum(diff_structure, 1))
    structure_cost = torch.mean(structure_reconstruction_errors)

    cost =  alpha * attribute_reconstruction_errors + (1-alpha) * structure_reconstruction_errors

    return cost, structure_cost, attribute_cost

def loss_func_test(attrs, X_hat):
    # Attribute reconstruction loss
    diff_attribute = torch.pow(X_hat - attrs, 2)
    attribute_reconstruction_errors = torch.sqrt(torch.sum(diff_attribute, 1))
    attribute_cost = torch.mean(attribute_reconstruction_errors)

    # structure reconstruction loss
    # diff_structure = torch.pow(A_hat - adj, 2)
    # structure_reconstruction_errors = torch.sqrt(torch.sum(diff_structure, 1))
    # structure_cost = torch.mean(structure_reconstruction_errors)
    structure_cost = 0

    cost =  attribute_reconstruction_errors


    return cost, structure_cost, attribute_cost

model = Dominant(feat_size=2, hidden_size=64, num_nodes_dict = data.num_nodes_dict, dropout=0.3, metadata=data.metadata()).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = 5e-3)

In [31]:
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations:
    num_neighbors=[15] * 2,
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=32,
    input_nodes=('customer'),
)

In [32]:
epochs = 1
X = data.x_dict
adj = dense_adj(data)

for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        A_hat, X_hat = model(data.x_dict, data.edge_index_dict)
        loss, struct_loss, feat_loss = loss_func_train(X, X_hat, adj, A_hat)
        l = torch.mean(loss)
        l.backward(retain_graph=True)
        optimizer.step()        
        print("Epoch:", '%04d' % (epoch), "train_loss=", "{:.5f}".format(l.item()),"train/feat_loss=", "{:.5f}".format(feat_loss.item()))

        # if epoch == epochs - 1:
        #     model.eval()
        #     A_hat, X_hat = model(data.x_dict, data.edge_index_dict)
        #     loss, struct_loss, feat_loss = loss_func(X['customer'], X_hat['customer'])
        #     score = loss.detach().cpu().numpy()
        #     print("Score = ", score)

Epoch: 0000 train_loss= 4.18770 train/feat_loss= 0.94360


In [33]:
model.eval()
A_hat, X_hat = model(data.x_dict, data.edge_index_dict)
loss, struct_loss, feat_loss = loss_func_test(X['customer'], X_hat['customer'])
score = loss.detach().cpu().numpy()
print("Score = ", score)

Score =  [0.5659822  0.28648496 0.7314963  0.09177812 0.590038   0.1597658
 0.20797169 0.48248687 0.46836755 0.40577725 0.33270958 0.88923615
 0.5588084  0.41940075 0.7800286  0.53098845 0.5683027  0.45636603
 0.32742548 0.5085881  0.45910108 0.5542878  0.49751875 0.44268945
 0.3712099  0.37015218 0.29598954 0.4802004  0.3242657  0.20246767
 0.467867   0.41449004 0.6109383  0.4811692  0.5683265  0.64990586
 0.6987262  0.6489911  0.14058189 0.22675322 0.63819367 0.47391045
 0.10255742 0.62589633 0.46262652 0.3240517  0.5314261  0.5410406
 0.25101015 0.41099566 0.5332941  0.7324476  0.5142531  0.5273354
 0.31936792 0.5731665  0.18088584 0.1663604  0.24202348 0.53941065
 0.45194614 0.78933185 0.5386242  0.37200925 0.5217056  0.5366045
 0.37198684 0.60543674 0.1610556  0.73887086 0.4865959  0.68965745
 0.20644377 0.3374082  0.45536947 0.7339434  0.6301728  0.62935024
 0.39496055 0.29494184 0.3106367  0.13364904 0.7052696  0.12636732
 0.3948563  0.36973676 0.22259101 0.2604181  0.34809792 0

In [26]:
len(score)

1000