# Graph dataset preparation for PyTorch Geometric
---

## Imports

In [9]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url, Data

## Read the embedded data

In [10]:
# Node table. The functions defined further down read its 'asin',
# '*_embedding' and 'main_category' columns.
nodes = pd.read_parquet('data/amazon_product_data_word2vec.parquet')
# Edge table. create_edge_index (defined further down) reads its 'from_asin'
# and 'to_asin' columns.
edges = pd.read_parquet('data/amazon_product_edges_filtered.parquet')

In [11]:
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder

def create_node_features(nodes, method='concat'):
    """
    Creates node features by fusing the four per-node embeddings.

    Rows are read positionally rather than by index label, so ``nodes`` may be
    a pandas DataFrame with any index (for example one left non-contiguous by
    filtering) or a plain mapping of column name to equal-length sequences.

    :param nodes: Column-indexable container providing 'title_embedding',
        'brand_embedding', 'description_embedding' and 'categories_embedding'.
    :param method: The method to combine embeddings: 'concat' joins the four
        vectors end to end, 'sum' adds them element-wise (which requires them
        to have the same length).
    :return: A float tensor with one row per node. An input without rows
        yields an empty tensor of shape (0, 0).
    :raises ValueError: If ``method`` is neither 'concat' nor 'sum'.
    """
    # Validate once up front instead of rediscovering the problem per row.
    if method not in ('concat', 'sum'):
        raise ValueError(f"Unknown fusion method: {method}")

    embedding_columns = (
        'title_embedding',
        'brand_embedding',
        'description_embedding',
        'categories_embedding',
    )

    node_features = []
    # zip() walks the columns in lockstep by position; indexing a pandas
    # Series with an integer is a *label* lookup and breaks on a custom index.
    for row in zip(*(nodes[column] for column in embedding_columns)):
        embeddings = [torch.tensor(value, dtype=torch.float) for value in row]
        if method == 'concat':
            features = torch.cat(embeddings, dim=0)
        else:
            features = sum(embeddings)
        node_features.append(features)

    # torch.stack rejects an empty list, so handle the no-node case explicitly.
    if not node_features:
        return torch.empty((0, 0), dtype=torch.float)
    return torch.stack(node_features, dim=0)

def create_edge_index(nodes, edges):
    """
    Creates the edge index tensor for the graph.

    Each ASIN is mapped to its position in ``nodes['asin']`` and every edge is
    translated into a (source position, target position) pair. Rows are read
    positionally, so both inputs may be pandas DataFrames with any index or
    plain mappings of column name to sequences.

    :param nodes: Column-indexable container providing 'asin'.
    :param edges: Column-indexable container providing 'from_asin' and
        'to_asin'.
    :return: A long tensor of shape (2, number of edges); row 0 holds source
        positions and row 1 holds target positions. With no edges the shape
        is (2, 0).
    :raises KeyError: If an edge references an ASIN absent from ``nodes``.
    """
    # Create node index mapping
    asin_to_index = {asin: i for i, asin in enumerate(nodes['asin'])}

    # zip() pairs the two columns by position; edges['from_asin'][i] would be
    # a label lookup on a pandas Series and fail on a non-default index.
    edge_pairs = [
        (asin_to_index[source_asin], asin_to_index[target_asin])
        for source_asin, target_asin in zip(edges['from_asin'], edges['to_asin'])
    ]

    # An empty list would become a 1-D tensor, so build the (2, 0) layout
    # explicitly to keep the two-row shape.
    if not edge_pairs:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()

def create_labels(nodes):
    """
    Builds integer class labels from the 'main_category' column.

    :param nodes: Column-indexable container providing 'main_category'.
    :return: A long tensor holding the LabelEncoder codes for that column.
    """
    # The encoder itself is not kept; only the integer codes are returned.
    category_codes = LabelEncoder().fit_transform(nodes['main_category'])
    return torch.tensor(category_codes, dtype=torch.long)

def create_and_save_graph_data(nodes, edges, fusion_method, filename):
    """
    Assembles a graph Data object and writes it to disk with torch.save.

    :param nodes: Node table, passed on to create_node_features,
        create_edge_index and create_labels.
    :param edges: Edge table, passed on to create_edge_index.
    :param fusion_method: The method to combine embeddings ('concat' or 'sum').
    :param filename: Destination path handed to torch.save.
    """
    # Features, connectivity and labels are built in that order and handed
    # straight to the Data constructor.
    data = Data(
        x=create_node_features(nodes, method=fusion_method),
        edge_index=create_edge_index(nodes, edges),
        y=create_labels(nodes),
    )

    torch.save(data, filename)

    # Report where the graph went and the feature size of the first node.
    print(f"Data saved to {filename}")
    print(f"Node feature size for method '{fusion_method}':", data.x[0].size())

  node_features = torch.tensor(node_features, dtype=torch.float)


Data(x=[863130, 100], edge_index=[2, 815222])


In [17]:
# Graph whose node features are the concatenated embeddings
create_and_save_graph_data(
    nodes=nodes,
    edges=edges,
    fusion_method='concat',
    filename='data/amazon_product_data_concat.pt',
)

# Graph whose node features are the element-wise sum of the embeddings
create_and_save_graph_data(
    nodes=nodes,
    edges=edges,
    fusion_method='sum',
    filename='data/amazon_product_data_sum.pt',
)