In [1]:
import torch
from torch_geometric.data import HeteroData
import numpy as np
from torch_geometric.loader import HGTLoader, NeighborLoader

In [2]:
cat_props = torch.rand([100, 4])
num_props = torch.rand([100, 5])
des_props = torch.rand([100, 768])

label = torch.rand(100).round()

tweet_tensor = torch.rand([200, 768])

follow_src = torch.tensor(np.random.choice(100, 200, replace=True))
follow_dst = torch.tensor(np.random.choice(100, 200, replace=True))
follow = torch.concat((follow_src, follow_dst)).reshape(-1, 200).long()
friend = torch.concat((follow_dst, follow_src)).reshape(-1, 200).long()

post_src = torch.tensor(np.random.choice(100, 500, replace=True))
post_dst = torch.tensor(np.random.choice(200, 500, replace=True))
post = torch.concat((post_src, post_dst)).reshape(-1, 500).long()

In [3]:
train_idx = torch.zeros(100)
train_idx[range(0, 60)] = 1
train_idx = train_idx.bool()
val_idx = torch.zeros(100)
val_idx[range(60, 80)] = 1
val_idx = val_idx.bool()
test_idx = torch.zeros(100)
test_idx[range(80, 100)] = 1
test_idx = test_idx.bool()
# train_idx = torch.tensor(range(0, 60)).long()
# val_idx = torch.tensor(range(60, 80)).long()
# test_idx = torch.tensor(range(80, 100)).long()

In [4]:
# torch.concat([cat_props, num_props, des_props], dim=1)

In [4]:
device = "cpu"

In [5]:
hetero_twi = HeteroData(
    {
        'user': {
            'x': torch.concat([cat_props, num_props, des_props], dim=1),
            'y': label,
            'train_mask': train_idx,
            'val_mask': val_idx,
            'test_mask': test_idx
        },
        'tweet': {'x': tweet_tensor}
    },
    user__follow__user={'edge_index': follow},
    user__friend__user={'edge_index': friend},
    user__post__tweet={'edge_index': post}
)

In [6]:
hetero_twi.metadata()

(['user', 'tweet'],
 [('user', 'follow', 'user'),
  ('user', 'friend', 'user'),
  ('user', 'post', 'tweet')])

In [8]:
# 存在不明原因报错
# train_loader = HGTLoader(hetero_twi, num_samples={key: [2] for key in hetero_twi.node_types}, input_nodes=('user', hetero_twi['user'].train_mask), batch_size=2, num_workers=1)
# val_loader = HGTLoader(hetero_twi, num_samples={key: [2] for key in hetero_twi.node_types}, input_nodes=('user', hetero_twi['user'].val_mask), batch_size=2, num_workers=1)

In [8]:
train_loader = NeighborLoader(hetero_twi, num_neighbors={key: [2] for key in hetero_twi.edge_types}, input_nodes=('user', hetero_twi['user'].train_mask), batch_size=2, num_workers=1)
val_loader = NeighborLoader(hetero_twi, num_neighbors={key: [2] for key in hetero_twi.edge_types}, input_nodes=('user', hetero_twi['user'].val_mask), batch_size=2, num_workers=1)

In [9]:
from torch_geometric.transforms import ToUndirected
undirected_transform = ToUndirected(merge=True)
undirected_hetero_twi = undirected_transform(hetero_twi)
nei_train_loader = NeighborLoader(undirected_hetero_twi, num_neighbors={key: [10] for key in hetero_twi.edge_types}, shuffle=True, input_nodes=('user', hetero_twi['user'].train_mask), batch_size=32, num_workers=0, persistent_workers=False)
nei_val_loader = NeighborLoader(hetero_twi, num_neighbors={key: [10] for key in hetero_twi.edge_types}, input_nodes=('user', hetero_twi['user'].val_mask), batch_size=32, num_workers=0)

In [10]:
for sampled_hetero_data in nei_train_loader:
    print(sampled_hetero_data)
hetero_twi.edge_types

HeteroData(
  [1muser[0m={
    x=[81, 777],
    y=[81],
    train_mask=[81],
    val_mask=[81],
    test_mask=[81],
    input_id=[32],
    batch_size=32
  },
  [1mtweet[0m={ x=[103, 768] },
  [1m(user, follow, user)[0m={ edge_index=[2, 117] },
  [1m(user, friend, user)[0m={ edge_index=[2, 117] },
  [1m(user, post, tweet)[0m={ edge_index=[2, 0] },
  [1m(tweet, rev_post, user)[0m={ edge_index=[2, 152] }
)
HeteroData(
  [1muser[0m={
    x=[77, 777],
    y=[77],
    train_mask=[77],
    val_mask=[77],
    test_mask=[77],
    input_id=[28],
    batch_size=28
  },
  [1mtweet[0m={ x=[113, 768] },
  [1m(user, follow, user)[0m={ edge_index=[2, 98] },
  [1m(user, friend, user)[0m={ edge_index=[2, 98] },
  [1m(user, post, tweet)[0m={ edge_index=[2, 0] },
  [1m(tweet, rev_post, user)[0m={ edge_index=[2, 163] }
)


[('user', 'follow', 'user'),
 ('user', 'friend', 'user'),
 ('user', 'post', 'tweet'),
 ('tweet', 'rev_post', 'user')]

In [11]:
import torch
from torch import nn
from torch_geometric.nn import HGTConv


class PropertyVector(nn.Module):
    def __init__(self, n_cat_prop=4, n_num_prop=5, des_size=768, embedding_dimension=128, dropout=0.3):
        super(PropertyVector, self).__init__()

        self.n_cat_prop = n_cat_prop
        self.n_num_prop = n_num_prop
        self.des_size = des_size

        self.cat_prop_module = nn.Sequential(
            nn.Linear(n_cat_prop, int(embedding_dimension / 4)),
            nn.LeakyReLU()
        )
        self.num_prop_module = nn.Sequential(
            nn.Linear(n_num_prop, int(embedding_dimension / 4)),
            nn.LeakyReLU()
        )
        self.prop_module = nn.Sequential(
            nn.Linear(int(embedding_dimension / 2), int(embedding_dimension / 2)),
            nn.LeakyReLU()
        )
        self.des_module = nn.Sequential(
            nn.Linear(des_size, int(embedding_dimension / 2)),
            nn.LeakyReLU()
        )
        self.out_layer = nn.Sequential(
            nn.Linear(embedding_dimension, embedding_dimension),
            nn.LeakyReLU()
        )

    def forward(self, user_tensor):
        cat_prop, num_prop, des = torch.split_with_sizes(user_tensor, [self.n_cat_prop, self.n_num_prop, self.des_size], dim=1)
        cat_prop_vec = self.cat_prop_module(cat_prop)
        num_prop_vec = self.num_prop_module(num_prop)
        des_vec = self.des_module(des)
        prop_vec = torch.concat((cat_prop_vec, num_prop_vec, des_vec), dim=1)
        prop_vec = self.out_layer(prop_vec)
        return prop_vec


class TweetVector(nn.Module):
    def __init__(self, tweet_size=768, embedding_dimension=128, dropout=0.3):
        super(TweetVector, self).__init__()
        self.tweet_module = nn.Sequential(
            nn.Linear(tweet_size, embedding_dimension),
            nn.LeakyReLU()
        )

    def forward(self, tweet_tensor):
        tweet_vec = self.tweet_module(tweet_tensor)
        return tweet_vec


class HGTDetector(nn.Module):
    def __init__(self, n_cat_prop=4, n_num_prop=5, des_size=768, tweet_size=768, embedding_dimension=128, dropout=0.3):
        super(HGTDetector, self).__init__()

        meta_node = ["user", "tweet"]
        meta_edge = [("user", "follow", "user"), ("user", "friend", "user"), ("user", "post", "tweet"), ("tweet", "rev_post", "user")]

        self.module_dict = nn.ModuleDict()
        self.module_dict["user"] = PropertyVector(n_cat_prop, n_num_prop, des_size, embedding_dimension, dropout)
        self.module_dict["tweet"] = TweetVector(tweet_size, embedding_dimension, dropout)

        self.HGT_layer1 = HGTConv(in_channels=embedding_dimension, out_channels=embedding_dimension, metadata=(meta_node, meta_edge))
        self.HGT_layer2 = HGTConv(in_channels=embedding_dimension, out_channels=embedding_dimension, metadata=(meta_node, meta_edge))

        self.classify_layer = nn.Sequential(
            nn.Linear(embedding_dimension, embedding_dimension),
            nn.LeakyReLU(),
            nn.Linear(embedding_dimension, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            node_type: self.module_dict[node_type](x)
            for node_type, x in x_dict.items()
        }

        x_dict = self.HGT_layer1(x_dict, edge_index_dict)
        x_dict = self.HGT_layer2(x_dict, edge_index_dict)

        out = self.classify_layer(x_dict["user"])

        return out



In [43]:
a = next(iter(nei_train_loader))
b = next(iter(nei_train_loader))
c = next(iter(nei_train_loader))

In [12]:
from tqdm.notebook import tqdm

model = HGTDetector(n_cat_prop=4, n_num_prop=5, des_size=768, tweet_size=768, embedding_dimension=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

@torch.no_grad()
def init_params():
    batch = next(iter(nei_train_loader))
    batch = batch.to(device, "edge_index")
    model(batch.x_dict, batch.edge_index_dict)

def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(nei_train_loader):
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        batch_size = batch['user'].batch_size
        mask = batch['user'].train_mask
        out = model(batch.x_dict, batch.edge_index_dict)[mask]
        # print(f"out: {out}")
        # print(f"out.argmax(-1): {out.argmax(dim=-1)}")
        # print(f"batch['user'].y[mask]: {batch['user'].y[mask]}")
        loss = nn.functional.cross_entropy(out, batch['user'].y[mask].long())
        loss.backward()
        optimizer.step()

        total_examples += mask.sum()
        total_loss += float(loss) * mask.sum()

    return total_loss / total_examples

@torch.no_grad()
def val(loader):
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device, 'edge_index')
        batch_size = batch['user'].batch_size
        mask = batch['user'].val_mask
        out = model(batch.x_dict, batch.edge_index_dict)
        pred = out.argmax(dim=-1)[mask]
        # print(f"batch_size: {batch_size}")
        # print(f"mask: {mask}")
        # print(f"pred: {pred}")
        # print(f"batch['user'].y: {batch['user'].y}")
        # print(f"pred[mask]: {pred[mask]}")
        # print(f"batch['user'].y[mask]: {batch['user'].y[mask]}")
        total_examples += mask.sum()
        total_correct += int((pred == batch['user'].y[mask]).sum())

    return total_correct / total_examples

init_params()

for epoch in range(1, 41):
    loss = train()
    val_acc = val(nei_val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 01, Loss: 0.6891, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 02, Loss: 0.6868, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 03, Loss: 0.6905, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 04, Loss: 0.6860, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 05, Loss: 0.6839, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 06, Loss: 0.6844, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 07, Loss: 0.6799, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 08, Loss: 0.6813, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 09, Loss: 0.6788, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 10, Loss: 0.6791, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 11, Loss: 0.6768, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 12, Loss: 0.6706, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 13, Loss: 0.6523, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 14, Loss: 0.6270, Val: 0.5500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 15, Loss: 0.5859, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 16, Loss: 0.6137, Val: 0.5000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 17, Loss: 0.5727, Val: 0.2500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 18, Loss: 0.6119, Val: 0.2500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 19, Loss: 0.5486, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 20, Loss: 0.5446, Val: 0.6000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 21, Loss: 0.5212, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 22, Loss: 0.5219, Val: 0.5000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 23, Loss: 0.5012, Val: 0.3500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 24, Loss: 0.4517, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 25, Loss: 0.4711, Val: 0.3000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 26, Loss: 0.4301, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 27, Loss: 0.4283, Val: 0.5000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 28, Loss: 0.4368, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 29, Loss: 0.3984, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 30, Loss: 0.4131, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 31, Loss: 0.4077, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 32, Loss: 0.3925, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 33, Loss: 0.4003, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 34, Loss: 0.3879, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 35, Loss: 0.3854, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 36, Loss: 0.3814, Val: 0.4000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 37, Loss: 0.3654, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 38, Loss: 0.3595, Val: 0.4500


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 39, Loss: 0.3512, Val: 0.5000


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 40, Loss: 0.3508, Val: 0.5000


In [13]:
test_data = hetero_twi.subgraph({'user': hetero_twi['user'].test_mask, 'tweet': torch.ones(hetero_twi['tweet'].num_nodes).bool()})

In [15]:
out = model(test_data.x_dict, test_data.edge_index_dict).argmax(dim=-1)
mask = test_data['user'].test_mask
test_res = (out[mask] == test_data['user'].y[mask]).sum()
print(f"Test: {test_res / mask.sum():.4f}")

Test: 0.4500


In [51]:
torch.save(model, rf'./saved_models/acc{test_res / len(out):.4f}.pickle')


In [13]:
loaded_model = torch.load(rf'./saved_models/acc0.5500.pickle')
out1 = loaded_model(test_data.x_dict, test_data.edge_index_dict).argmax(dim=-1)
test_res1 = (out1 == test_data['user'].y).sum()
print(f"Test: {test_res1 / len(out1):.4f}")

Test: 0.5000


In [16]:
# print(loaded_model)
test_idx.sum()

tensor(20)

In [14]:
from torchviz import make_dot

netvis = make_dot(out1, params=dict(list(loaded_model.named_parameters()) + [('x', test_data.x_dict)]))
netvis.format = 'png'
netvis.directory = "saved_models"
netvis.view()

AssertionError: 