# Graph dataset preparation for PyTorch Geometric
---

## Imports

In [14]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url, Data

## Read the embedded data

In [9]:
# Load the node table (products with their embeddings) and the edge table
# (product-to-product links) from the prepared parquet files.
nodes, edges = map(
    pd.read_parquet,
    (
        'data/amazon_product_data_word2vec.parquet',
        'data/amazon_product_edges_filtered.parquet',
    ),
)

In [10]:
# Quick size check: (rows, columns) of the node table, then of the edge table.
print(nodes.shape, edges.shape, sep='\n')

(729819, 6)
(680548, 2)


In [11]:
# Convert ASINs to strings in both DataFrames so the membership checks below
# compare values of the same type
nodes['asin'] = nodes['asin'].astype(str)
edges['from_asin'] = edges['from_asin'].astype(str)
edges['to_asin'] = edges['to_asin'].astype(str)

# Build the node ASIN lookup once; both endpoint checks reuse it instead of
# re-hashing the whole node table for each check
node_asins = set(nodes['asin'])

# Identify missing ASINs in 'from_asin'
missing_from_asins = set(edges['from_asin']) - node_asins
print(f"Number of 'from_asin' ASINs not in nodes: {len(missing_from_asins)}")

# Identify missing ASINs in 'to_asin'
missing_to_asins = set(edges['to_asin']) - node_asins
print(f"Number of 'to_asin' ASINs not in nodes: {len(missing_to_asins)}")

Number of 'from_asin' ASINs not in nodes: 0
Number of 'to_asin' ASINs not in nodes: 0


In [12]:
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder
import pandas as pd

def create_node_features(nodes, method='concat'):
    """Fuse the four per-node embedding columns into one feature matrix.

    Reads the columns 'title_embedding', 'brand_embedding',
    'description_embedding' and 'categories_embedding' from ``nodes`` and
    combines them row by row.

    Args:
        nodes: DataFrame holding the four embedding columns. Each cell is
            assumed to hold one numeric vector (list or array) whose length
            is constant within a column -- TODO confirm against the parquet
            schema.
        method: 'concat' joins the four vectors end to end (in the column
            order listed above); 'sum' adds them element-wise, which
            requires all four columns to share one dimensionality.

    Returns:
        A float32 tensor with one row per row of ``nodes``. For an empty
        ``nodes`` frame the embedding width is unknown, so an empty tensor
        of shape (0, 0) is returned.

    Raises:
        ValueError: if ``method`` is neither 'concat' nor 'sum'. This is
            checked before any data is touched, so it is raised even for an
            empty frame.
    """
    if method not in ('concat', 'sum'):
        raise ValueError(f"Unknown fusion method: {method}")

    embedding_columns = (
        'title_embedding',
        'brand_embedding',
        'description_embedding',
        'categories_embedding',
    )

    if len(nodes) == 0:
        return torch.empty((0, 0), dtype=torch.float)

    # Convert each column to a single (num_nodes, dim) block in one shot
    # rather than doing four `.iloc` lookups per row, which dominates the
    # runtime on large node tables.
    blocks = [
        torch.from_numpy(
            np.stack(list(nodes[column])).astype(np.float32, copy=False)
        )
        for column in embedding_columns
    ]

    if method == 'concat':
        # dim=1 on the stacked blocks is the per-row dim=0 concatenation.
        return torch.cat(blocks, dim=1)

    # 'sum': accumulate left to right in float32, same order as summing the
    # four per-row tensors.
    return sum(blocks)

def create_edge_index(nodes, edges):
    """Translate ASIN-labelled edges into an integer edge-index tensor.

    Each ASIN is mapped to the positional index of its row in ``nodes``
    (0-based, in iteration order of ``nodes['asin']``). If an ASIN occurs
    more than once in ``nodes``, the last occurrence wins.

    Args:
        nodes: DataFrame with an 'asin' column.
        edges: DataFrame with 'from_asin' and 'to_asin' columns.

    Returns:
        A contiguous ``torch.long`` tensor of shape (2, num_edges): row 0
        holds source indices, row 1 target indices. An empty ``edges``
        frame yields shape (2, 0).

    Raises:
        KeyError: if an edge endpoint is not present in ``nodes['asin']``;
            the exception argument is the first offending ASIN.
    """
    asin_to_index = {asin: i for i, asin in enumerate(nodes['asin'])}

    # Vectorised lookup instead of iterrows(); endpoints missing from the
    # mapping come back as NaN.
    sources = edges['from_asin'].map(asin_to_index)
    targets = edges['to_asin'].map(asin_to_index)

    unmapped = sources.isna() | targets.isna()
    if unmapped.any():
        # Report the first bad endpoint, source before target, matching the
        # order a row-by-row dictionary lookup would fail in.
        first_bad = edges.loc[unmapped].iloc[0]
        missing = first_bad['from_asin']
        if missing in asin_to_index:
            missing = first_bad['to_asin']
        raise KeyError(missing)

    # Stacking the two index rows directly always gives shape (2, E), even
    # when there are no edges at all.
    stacked = np.stack([
        sources.to_numpy(dtype=np.int64),
        targets.to_numpy(dtype=np.int64),
    ])
    return torch.from_numpy(stacked).contiguous()

def create_labels(nodes):
    """Encode each node's 'main_category' as an integer class label.

    A fresh LabelEncoder is fitted on the column and applied to it in the
    same step. The fitted encoder is not returned, so the mapping from
    integer label back to category name is not available to the caller.

    Args:
        nodes: DataFrame with a 'main_category' column.

    Returns:
        A 1-D ``torch.long`` tensor with one label per row of ``nodes``.
    """
    category_codes = LabelEncoder().fit_transform(nodes['main_category'])
    return torch.tensor(category_codes, dtype=torch.long)

def create_and_save_graph_data(nodes, edges, fusion_method, filename):
    """Build a PyTorch Geometric ``Data`` object and save it with torch.save.

    Steps: normalise ASINs to strings, drop edges whose endpoints are not in
    ``nodes``, build node features, the edge index and labels, wrap them in
    ``Data(x=..., edge_index=..., y=...)`` and write the object to
    ``filename``. Two summary lines are printed afterwards.

    The caller's ``nodes`` and ``edges`` DataFrames are left unmodified;
    the string conversion is applied to local copies.

    Args:
        nodes: DataFrame with an 'asin' column plus whatever columns
            ``create_node_features`` and ``create_labels`` read.
        edges: DataFrame with 'from_asin' and 'to_asin' columns.
        fusion_method: forwarded to ``create_node_features`` as ``method``.
        filename: destination for ``torch.save``. For a path (str or
            os.PathLike) the parent directory is created if needed and the
            file is written atomically; any other object (e.g. an open
            binary file) is handed to ``torch.save`` unchanged.

    Returns:
        None.
    """
    import os

    # Ensure ASINs are strings (on copies, so the caller's frames keep
    # their original dtypes)
    nodes = nodes.assign(asin=nodes['asin'].astype(str))
    edges = edges.assign(
        from_asin=edges['from_asin'].astype(str),
        to_asin=edges['to_asin'].astype(str),
    )

    # Filter edges to only include ASINs present in nodes
    valid_edges = edges[
        edges['from_asin'].isin(nodes['asin']) & edges['to_asin'].isin(nodes['asin'])
    ].reset_index(drop=True)

    # Create node features
    node_features = create_node_features(nodes, method=fusion_method)
    # Create edge index
    edge_index = create_edge_index(nodes, valid_edges)
    # Create labels
    labels = create_labels(nodes)
    # Create the graph data object
    data = Data(x=node_features, edge_index=edge_index, y=labels)

    # Save the data to a file
    if isinstance(filename, (str, os.PathLike)):
        target = os.fsdecode(filename)
        os.makedirs(os.path.dirname(os.path.abspath(target)), exist_ok=True)
        # Write to a sibling temp file and rename on success, so a write
        # that dies part-way never leaves a truncated file at `target`.
        # NOTE(review): the "unexpected pos ... vs ..." RuntimeError from
        # torch.save is, as far as we know, usually an incomplete write
        # (typically a full disk or quota) -- check free space if it recurs.
        partial_path = target + '.partial'
        try:
            torch.save(data, partial_path)
            os.replace(partial_path, target)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    else:
        torch.save(data, filename)

    # Print information
    print(f"Data saved to {filename}")
    # shape[1:] is the per-node feature size and, unlike x[0], does not fail
    # when there are no nodes.
    print(f"Node feature size for method '{fusion_method}':", data.x.shape[1:])

In [13]:
# Build the graph with concatenated embeddings and write it to disk.
create_and_save_graph_data(
    nodes,
    edges,
    fusion_method='concat',
    filename='data/amazon_product_data_concat.pt',
)

# Summed-embedding variant (currently disabled); uncomment to generate it too.
# create_and_save_graph_data(
#     nodes, edges, fusion_method='sum', filename='data/amazon_product_data_sum.pt'
# )

RuntimeError: [enforce fail at inline_container.cc:337] . unexpected pos 832 vs 754