In [1]:
# This code works in Python 3.10.6
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.datasets import IMDB
from torch_geometric.nn import GCNConv
import time
from torch_geometric.logging import log
import os
from collections import Counter
import random

In [2]:
# Load the PyG IMDB dataset, using ./imdb_data as its root folder
# (presumably downloaded/processed there on first use — TODO confirm).
dataset = IMDB(root='./imdb_data')
# Take the first entry of the dataset: a HeteroData graph with movie,
# director and actor nodes (see the summary printed by the next cell).
hetero_data = dataset[0]

In [3]:
# Display the heterogeneous graph summary: node/edge types and attribute shapes.
hetero_data

HeteroData(
  movie={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278],
  },
  director={ x=[2081, 3066] },
  actor={ x=[5257, 3066] },
  (movie, to, director)={ edge_index=[2, 4278] },
  (movie, to, actor)={ edge_index=[2, 12828] },
  (director, to, movie)={ edge_index=[2, 4278] },
  (actor, to, movie)={ edge_index=[2, 12828] }
)

In [4]:
# This code works in torch-geometric==2.6.0
# Collapse all node and edge types into one homogeneous Data object.
# add_edge_type=False skips attaching an edge_type vector; the per-node
# node_type vector is still added (see the Data summary in the next cell).
data = hetero_data.to_homogeneous(add_edge_type=False)

In [5]:
# Display the merged graph's attributes and their shapes.
data

Data(edge_index=[2, 34212], x=[11616, 3066], y=[11616], train_mask=[11616], val_mask=[11616], test_mask=[11616], node_type=[11616])

In [6]:
# Inspect the integer node-type id stored for each node of the merged graph.
data.node_type

tensor([0, 0, 0,  ..., 2, 2, 2])

In [7]:
# Overwrite the node feature matrix with a one-hot encoding of each node's
# type: one float column per distinct value found in data.node_type.
data.x = F.one_hot(
    data.node_type,
    num_classes=torch.unique(data.node_type).numel(),
).float()

In [8]:
# Check the replaced features: one row per node, one-hot over the node types.
data.x

tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        ...,
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]])

In [9]:
# Persist the homogeneous graph (with its one-hot node-type features) to disk.
# torch.save does not create missing parent directories, so make sure the
# target folder exists first; otherwise this cell fails on a fresh checkout.
processed_path = '../data/imdb/processed/data.pt'
os.makedirs(os.path.dirname(processed_path), exist_ok=True)
torch.save(data, processed_path)

In [10]:
# List the distinct label values present in the merged label vector.
data.y.unique()

tensor([-1,  0,  1,  2])

In [11]:
# Count how many nodes carry each label value.
# In the output below, the -1 count (7338) equals the number of director +
# actor nodes (2081 + 5257) and classes 0-2 sum to the 4278 movie nodes, so
# -1 appears to be the filler label for non-movie nodes — verify before
# using y for training or evaluation.
Counter(data.y.tolist())

Counter({-1: 7338, 1: 1584, 2: 1559, 0: 1135})

In [12]:
# Sanity check: does the merged graph contain any node without an edge?
data.has_isolated_nodes()

False

In [13]:
# Sanity check: does the merged graph contain any edge from a node to itself?
data.has_self_loops()

False