In [143]:
import pandas as pd
import numpy as np
import torch
from torch.nn import Embedding
from torch_geometric.data import HeteroData
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import LabelEncoder

In [144]:
# Length of the vector produced for each categorical attribute by WordEncoder.
EMBEDDING_SIZE = 16
# The Embedding tables created below are randomly initialised; seeding torch's
# global RNG makes the node features reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

incident_path = 'data/incidents.csv'
support_org_path = 'data/support_orgs.csv'
customer_path = 'data/customers.csv'
vendor_path = 'data/vendors.csv'
target_path = 'data/target.pt'

In [145]:
# Raw incident table. Used below to size embedding vocabularies and to choose
# which incident columns become node features.
df_incident = pd.read_csv(filepath_or_buffer=incident_path)

`SimpleEncoder` numeričke vrednosti pretvara u tenzore paketa PyTorch oblika `(n, 1)`, podrazumevano tipa `float` (npr. 11 -> `tensor([11.])`). Ova transformacija je neophodna za kasnije formiranje grafa.

In [146]:
class SimpleEncoder:
    """Turn a numeric column into a single-column torch tensor.

    Args:
        dtype: torch dtype the resulting tensor is cast to (default ``torch.float``).
    """

    def __init__(self, dtype=torch.float):
        self.dtype = dtype

    def __call__(self, df):
        """Return ``df``'s values as a column tensor of shape ``(n, 1)`` cast to ``self.dtype``."""
        as_tensor = torch.from_numpy(df.values)
        as_column = as_tensor.view(-1, 1)
        return as_column.to(self.dtype)

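Kategorički atributi se mogu transformisati u vektore dužine `EMBEDDING_SIZE` pomoću klase [Embedding](https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html) iz paketa PyTorch. Vrednosti atributa se najpre pomoću `LabelEncoder`-a preslikavaju u cele brojeve, koji se zatim koriste kao indeksi u tabeli ugrađivanja. Klasi Embedding se, osim dužine ciljnog vektora, prosleđuje i veličina rečnika, tj. broj mogućih vrednosti kategoričkog atributa. Težine tabele su nasumično inicijalizovane i ovde se ne treniraju.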


In [147]:
class WordEncoder:
    """Map a categorical column to fixed, randomly initialised embedding vectors.

    Values are label-encoded to integers (``LabelEncoder`` assigns codes in sorted
    order of the distinct values) and then looked up in an
    ``Embedding(vocab_size, embedding_dim)`` table. Lookups run under
    ``torch.no_grad``, so the vectors are not trained here.

    Args:
        vocab_size: number of rows in the embedding table. It must be at least the
            number of distinct values in any column passed to ``__call__``;
            otherwise the lookup raises an index error.
        embedding_dim: length of each output vector (default ``EMBEDDING_SIZE``).
    """

    def __init__(self, vocab_size, embedding_dim=EMBEDDING_SIZE):
        self.embeddings = Embedding(vocab_size, embedding_dim)

    @torch.no_grad()
    def __call__(self, df):
        """Return a float tensor of shape ``(len(df), embedding_dim)``."""
        # The label encoder is refit on every call, so codes depend only on the
        # values in this particular column.
        label_encoder = LabelEncoder()
        codes = label_encoder.fit_transform(df)
        embeds = self.embeddings(torch.tensor(codes, dtype=torch.long))
        # Reshape with this instance's dimension rather than the global
        # EMBEDDING_SIZE, so a non-default embedding_dim is shaped correctly.
        return embeds.view((-1, self.embeddings.embedding_dim)).to(torch.float)


In [148]:
def load_node_csv(path, index_col, encoders=None, **kwargs):
    """Read a node table from CSV and build its feature tensor and id mapping.

    Args:
        path: CSV path or buffer, passed to ``pd.read_csv``.
        index_col: column holding the node identifier; used as the frame index.
        encoders: optional ``{column: callable}``. Each callable maps that column
            (a pandas Series) to a 2-D tensor with one row per CSV row; the
            results are concatenated along the last dimension.
        **kwargs: extra arguments forwarded to ``pd.read_csv`` (mirrors
            ``load_edge_csv``).

    Returns:
        ``(x, mapping)``. ``x`` is the concatenated feature tensor, or ``None``
        when no encoders are given. ``mapping`` maps each *unique* identifier to
        a contiguous integer id, in order of first appearance.

    NOTE(review): ``x`` has one row per CSV row, but ``mapping`` has one entry per
    unique identifier. If ``index_col`` contains duplicates, row ``i`` of ``x`` is
    not the features of node id ``i``. Dedupe or aggregate the CSV first.
    """
    df = pd.read_csv(path, index_col=index_col, **kwargs)
    mapping = {node_id: i for i, node_id in enumerate(df.index.unique())}

    x = None
    # Truthiness rather than `is not None`: an empty dict would otherwise reach
    # torch.cat([]), which raises on an empty list.
    if encoders:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    return x, mapping

In [149]:
def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,
                  encoders=None, **kwargs):
    """Read an edge list from CSV and convert node identifiers to node indices.

    Args:
        path: CSV path or buffer, passed to ``pd.read_csv``.
        src_index_col / dst_index_col: columns holding source / destination ids.
        src_mapping / dst_mapping: ``{identifier: node index}`` dicts, e.g. as
            returned by ``load_node_csv``. An unknown id raises ``KeyError``.
        encoders: optional ``{column: callable}`` producing per-edge feature
            tensors (one row per CSV row), concatenated on the last dimension.
        **kwargs: forwarded to ``pd.read_csv``.

    Returns:
        ``(edge_index, edge_attr)``. ``edge_index`` is a long tensor of shape
        ``(2, n_edges)`` with row 0 = source indices and row 1 = destination
        indices. ``edge_attr`` is ``None`` when no encoders are given.
    """
    df = pd.read_csv(path, **kwargs)

    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    # Explicit long dtype: PyG expects integer edge indices, and with no rows
    # torch.tensor([[], []]) would infer float32.
    edge_index = torch.tensor([src, dst], dtype=torch.long)

    edge_attr = None
    # Truthiness so an empty encoders dict does not reach torch.cat([]).
    if encoders:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)
    return edge_index, edge_attr

In [150]:
# Support-org nodes keyed by assignment group. Features: `assigned_to` treated as
# numeric, plus an embedding of the incident `number`. The embedding vocabulary is
# the row count of incidents.csv (assumed >= distinct `number` values here).
support_org_encoders = {
    'assigned_to': SimpleEncoder(),
    'number': WordEncoder(len(df_incident['number']), EMBEDDING_SIZE),
}
support_org_x, support_org_map = load_node_csv(support_org_path,
                                               'assignment_group',
                                               encoders=support_org_encoders)

In [151]:
# Vendor nodes keyed by `vendor`; the only feature is an embedding of the incident `number`.
vendor_x, vendor_map = load_node_csv(
    vendor_path,
    'vendor',
    encoders={'number': WordEncoder(len(df_incident['number']), EMBEDDING_SIZE)},
)

In [152]:
# Customer nodes keyed by `opened_by`; the only feature is an embedding of the incident `number`.
customer_x, customer_map = load_node_csv(
    customer_path,
    'opened_by',
    encoders={'number': WordEncoder(len(df_incident['number']), EMBEDDING_SIZE)},
)

In [153]:
# One encoder per incident feature column. The two categorical columns get an
# embedding sized by their number of distinct values (NaN counts as a value);
# every other column is wrapped as a numeric (n, 1) tensor.
# NOTE(review): columns[2:] skips the first two CSV columns, presumably an
# id/index column and `number` (the node key). Confirm against incidents.csv.
incident_encoders = {
    col: (WordEncoder(len(df_incident[col].unique()), EMBEDDING_SIZE)
          if col in ('incident_state', 'contact_type')
          else SimpleEncoder())
    for col in df_incident.columns[2:]
}

In [154]:
# Incident nodes keyed by `number`, featurised with the per-column encoders.
incident_x, incident_map = load_node_csv(
    incident_path,
    'number',
    encoders=incident_encoders,
)

In [155]:
# Edges for type ('incident', 'assigned', 'support_org'). In PyG, row 0 of
# edge_index indexes the first node type and row 1 the last, so the incident
# `number` is the source and `assignment_group` the destination.
# No encoders are passed, so the second return value is None.
incident_support_org, incident_support_org_label = load_edge_csv(support_org_path,
                                                                 src_index_col='number',
                                                                 src_mapping=incident_map,
                                                                 dst_index_col='assignment_group',
                                                                 dst_mapping=support_org_map
                                                                )

In [156]:
# Edges for type ('incident', 'reported', 'customer'). Row 0 of edge_index must
# index incidents and row 1 customers, so `number` is the source and
# `opened_by` the destination. No encoders are passed, so the second value is None.
incident_customer, incident_customer_label = load_edge_csv(customer_path,
                                                           src_index_col='number',
                                                           src_mapping=incident_map,
                                                           dst_index_col='opened_by',
                                                           dst_mapping=customer_map
                                                           )

In [157]:
# Edges for type ('incident', 'assigned', 'vendor'). Row 0 of edge_index must
# index incidents and row 1 vendors, so `number` is the source and `vendor`
# the destination. No encoders are passed, so the second value is None.
incident_vendor, incident_vendor_label = load_edge_csv(vendor_path,
                                                       src_index_col='number',
                                                       src_mapping=incident_map,
                                                       dst_index_col='vendor',
                                                       dst_mapping=vendor_map
                                                       )

In [169]:
# Empty heterogeneous graph container; node and edge stores are filled in by the cells below.
data = HeteroData()

In [170]:
# Target tensor, later attached as the incident node labels.
# NOTE(review): torch.load unpickles. On torch < 2.6 (where weights_only defaults
# to False), consider weights_only=True if this file could come from an untrusted source.
target = torch.load(f=target_path)

In [171]:
# Inline sanity check: the target must provide exactly one label per incident
# feature row, otherwise `y` silently misaligns with `x`.
assert target.shape[0] == incident_x.shape[0], (
    f"target has {target.shape[0]} rows but incident_x has {incident_x.shape[0]}"
)
data['incident'].x = incident_x
data['incident'].y = target
data['support_org'].x = support_org_x
data['customer'].x = customer_x
data['vendor'].x = vendor_x

In [172]:
# Attach each edge type's index and label, in the same order as the types are
# declared. The labels are None here (load_edge_csv was given no encoders), and
# PyG does not store None attributes; this is why the repr shows no edge_label.
edge_tensors = {
    ('incident', 'assigned', 'support_org'): (incident_support_org, incident_support_org_label),
    ('incident', 'assigned', 'vendor'): (incident_vendor, incident_vendor_label),
    ('incident', 'reported', 'customer'): (incident_customer, incident_customer_label),
}
for edge_type, (edge_index, edge_label) in edge_tensors.items():
    data[edge_type].edge_index = edge_index
    data[edge_type].edge_label = edge_label

In [173]:
# Display the assembled graph: feature shapes per node type and edge counts per edge type.
data

HeteroData(
  incident={
    x=[141707, 51],
    y=[141707],
  },
  support_org={ x=[141707, 17] },
  customer={ x=[141707, 16] },
  vendor={ x=[141707, 16] },
  (incident, assigned, support_org)={ edge_index=[2, 141707] },
  (incident, assigned, vendor)={ edge_index=[2, 141707] },
  (incident, reported, customer)={ edge_index=[2, 141707] }
)

In [174]:
# Persist the assembled graph so later steps can reload it with torch.load.
hetero_data_path = 'data/hetero_data.pt'
torch.save(data, hetero_data_path)