In [2]:
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F

import torch_scatter
import torch_sparse
import torch_cluster

import numpy as np
import torch.optim as optim
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import HeteroData

from myutils import constructPyGHeteroData,LabelBalancedSampler

from torch_geometric.loader import DataLoader

from typing import List
from tqdm import tqdm
from collections import defaultdict
from operator import itemgetter



数据集信息：

https://jmcauley.ucsd.edu/data/amazon/

In [2]:
%load_ext autoreload
%autoreload 2

In [26]:
class GraphSage(MessagePassing):
    """
    GraphSAGE convolution with mean neighbour aggregation.

    out_i = lin_l(x_i) + lin_r(mean_{j -> i} x_j), optionally L2-normalised per row.

    :param in_channels: input feature dimensionality
    :param out_channels: output embedding dimensionality
    :param normalize: if True, L2-normalise every output row
    :param bias: whether the two linear maps carry a bias term
    """

    def __init__(self, in_channels, out_channels, normalize=True,
                 bias=False, **kwargs):
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        # lin_l transforms the central node's own embedding;
        # lin_r transforms the mean-aggregated message from its neighbours.
        self.lin_l = nn.Linear(in_channels, out_channels, bias=bias)
        self.lin_r = nn.Linear(in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear maps."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size=None):
        """
        :param x: node feature matrix
        :param edge_index: COO edge index passed straight to ``propagate``
        :param size: optional bipartite size forwarded to ``propagate``
        :return: updated node embeddings
        """
        out = self.propagate(edge_index=edge_index, x=(x, x), size=size)
        out = out + self.lin_l(x)
        if self.normalize:
            out = F.normalize(out, p=2)
        return out

    def message(self, x_j):
        # Neighbour features are passed through unchanged; lin_r is applied after pooling.
        return x_j

    def aggregate(self, inputs, index, dim_size=None):
        """Mean-pool messages per target node, then apply ``lin_r``.

        ``dim_size`` (supplied by ``propagate``) guarantees one output row per
        node even when the highest-indexed nodes receive no messages; without
        it the result could be shorter than ``x`` and the residual add in
        ``forward`` would fail.
        """
        pooled = torch_scatter.scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce='mean')
        return self.lin_r(pooled)


Graph(num_nodes=716847, num_edges=13954819,
      ndata_schemes={'feat': Scheme(shape=(300,), dtype=torch.float32), 'label': Scheme(shape=(100,), dtype=torch.int64), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [3]:
class objectview(object):
    """
    Read a dict's keys as attributes: ``objectview(d).key`` is ``d['key']``.

    The dict itself becomes the instance ``__dict__`` (it is not copied), so
    later changes made through either the dict or the view are visible in both.
    """

    def __init__(self, d):
        object.__setattr__(self, '__dict__', d)


关系内部交互层

In [5]:
class IntraAgg(nn.Module):
    """
    Intra-relation aggregator: message aggregation inside ONE relation.

    For every centre node of the batch it:
      1. under-samples the centre's neighbours in this relation, keeping the
         ``n // 2 + 1`` neighbours whose class-0 logit is closest to the centre's;
      2. if training and the centre is labelled 1 (the minority class),
         over-samples the ``int(avg_half_pos_neigh) + 1`` positive training nodes
         whose class-0 logit is closest to the centre's;
      3. mean-pools the features of all chosen nodes, concatenates the result
         to the centre's own feature and applies ``relu(proj(.))``.
    """

    def __init__(
        self,
        feature_dim: int,
        output_dim: int,
        device: torch.device
        ) -> None:
        """
        :param feature_dim: dimensionality of the input node features
        :param output_dim: dimensionality of this layer's output embedding
        :param device: target device (stored; not used for computation here)
        """
        super(IntraAgg, self).__init__()
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.device = device
        # Input is [centre feature || aggregated neighbour feature], hence 2 * feature_dim.
        self.proj = nn.Linear(2 * feature_dim, output_dim)

        # Distance thresholds. rho_neg is refreshed to the median neighbour
        # distance on every under-sampling call but is not used to filter yet.
        self.rho_neg = 0.5
        self.rho_pos = 0.5

    def _undersample(self, center_logits, neighbor_ids, batch_all_logits, trainIdx2OrderIdx=None):
        """Return the ``len(neighbor_ids) // 2 + 1`` neighbour ids closest to the centre.

        Distance is ``|logit_0(neighbour) - logit_0(centre)|``. ``trainIdx2OrderIdx``
        maps a node id to its row in ``batch_all_logits``; ``None`` means the ids
        already are row indices. Also updates ``self.rho_neg``.
        """
        neighbor_ids = list(neighbor_ids)
        if not neighbor_ids:
            return []
        if trainIdx2OrderIdx is None:
            rows = [int(n) for n in neighbor_ids]
        else:
            rows = [trainIdx2OrderIdx[int(n)] for n in neighbor_ids]
        distance = torch.abs(batch_all_logits[rows] - center_logits)[:, 0]
        ranking = distance.argsort()
        half = len(neighbor_ids) // 2
        # Distance of the median neighbour (kept for an adaptive threshold).
        self.rho_neg = distance[ranking[half]].item()
        # argsort yields positions inside neighbor_ids, not node ids: map back.
        return [neighbor_ids[i] for i in ranking[: half + 1].tolist()]

    def _oversample(self, center_logits, train_pos_logits, k):
        """Return row positions in ``train_pos_logits`` of the ``int(k) + 1``
        positive training nodes closest (by class-0 logit) to the centre."""
        if train_pos_logits.shape[0] == 0:
            return []
        distance = torch.abs(train_pos_logits - center_logits)[:, 0]
        # k may be a float average neighbourhood size; slices need an int.
        return distance.argsort()[: int(k) + 1].tolist()

    def forward(
        self,
        features: torch.Tensor,
        batch_center_mask: list,
        batch_center_labels: list,
        train_pos_mask: list,
        rx_list: List[list],
        batch_center_logits: torch.Tensor, # (|B|,2)
        batch_all_logits: torch.Tensor, # (|BatchAll|,2)
        train_pos_logits: torch.Tensor, # (|Pos|,2)
        trainIdx2OrderIdx: dict,
        orderIdx2trainIdx: dict,
        avg_half_pos_neigh: int,
        train_flag = True
    ):
        """
        Message passing inside one relation.

        :param features: node feature matrix indexed by node id (assumed a dense tensor)
        :param batch_center_mask: node ids of this batch's centre nodes
        :param batch_center_labels: labels of the centre nodes (same order)
        :param train_pos_mask: node ids of positive training nodes; row i of
                               ``train_pos_logits`` belongs to ``train_pos_mask[i]``
        :param rx_list: ``rx_list[i]`` holds the neighbour node ids of centre i in this relation
        :param batch_center_logits: distance logits of the centre nodes
        :param batch_all_logits: distance logits of every node touched by the batch
        :param train_pos_logits: distance logits of the positive training nodes
        :param trainIdx2OrderIdx: node id -> row of ``batch_all_logits``
        :param orderIdx2trainIdx: inverse of ``trainIdx2OrderIdx`` (kept for interface compatibility)
        :param avg_half_pos_neigh: oversampling budget; ``int(.) + 1`` positives are added
        :param train_flag: oversampling uses labels, so it only happens when True
        :return: (|B|, output_dim) tensor of relation embeddings
        """
        self.avg_half_pos_neigh = avg_half_pos_neigh
        self.train_pos_mask = train_pos_mask

        out_feats = []
        for idx, center_logits in enumerate(batch_center_logits):
            # Under-sample: keep the closer half of the neighbours.
            chosen = [int(n) for n in self._undersample(
                center_logits, rx_list[idx], batch_all_logits, trainIdx2OrderIdx)]

            # Over-sample: minority (label 1) centres also pool similar positives.
            if train_flag and int(batch_center_labels[idx]) == 1:
                positions = self._oversample(center_logits, train_pos_logits, self.avg_half_pos_neigh)
                chosen.extend(int(self.train_pos_mask[p]) for p in positions)

            center_feat = features[int(batch_center_mask[idx])]
            if chosen:
                rows = torch.as_tensor(chosen, dtype=torch.long, device=center_feat.device)
                agg_feat = features[rows].mean(dim=0)
            else:
                # Isolated centre: nothing to aggregate, avoid a NaN mean.
                agg_feat = torch.zeros_like(center_feat)

            concat_feat = torch.cat([center_feat, agg_feat], dim=0).reshape(1, -1)
            out_feats.append(F.relu(self.proj(concat_feat)))

        if not out_feats:
            return self.proj.weight.new_zeros((0, self.output_dim))
        return torch.cat(out_feats, dim=0)

    def NeighborhoodSamplerForTraining(
        self,
        batch_center_logits: torch.Tensor,
        batch_center_labels: torch.Tensor,
        batch_all_logits: torch.Tensor,
        train_pos_logits: torch.Tensor,
        rx_list: List[list],
        trainIdx2OrderIdx: dict = None,
        avg_half_pos_neigh=None
        ):
        """
        Training-time neighbour picking without aggregation.

        :param trainIdx2OrderIdx: node id -> row of ``batch_all_logits``; ``None``
                                  means ``rx_list`` already holds row indices
        :param avg_half_pos_neigh: oversampling budget; defaults to the value
                                   stored by the last ``forward`` call
        :return: one ``(undersampled_neighbour_ids, oversampled_positive_positions)``
                 tuple per centre; the second list is empty unless the label is 1
        """
        picks = []
        for idx, center_logits in enumerate(batch_center_logits):
            neighbours = self._undersample(center_logits, rx_list[idx], batch_all_logits, trainIdx2OrderIdx)
            positives = []
            if int(batch_center_labels[idx]) == 1:
                k = self.avg_half_pos_neigh if avg_half_pos_neigh is None else avg_half_pos_neigh
                positives = self._oversample(center_logits, train_pos_logits, k)
            picks.append((neighbours, positives))
        return picks

    def NeighborhoodSamplerForTest(
        self,
    ):
        """Test-time sampler (label-free); not implemented yet."""
        raise NotImplementedError("Test-time neighbourhood sampling is not implemented.")

    

三种关系之间的交互层（InterAgg，融合 r1/r2/r3 三种关系的嵌入）

In [6]:


class InterAgg(nn.Module):
    """
    Inter-relation aggregator: fuses the three per-relation embeddings produced
    by ``IntraAgg`` with every centre node's own feature.

    It also owns ``label_linear``, the distance predictor whose logits drive
    neighbour selection inside each ``IntraAgg``.
    """
    def __init__(
        self,
        feature_dim: int,
        output_dim: int,
        device: torch.device,
        num_classes:int = 2,
        num_relations:int =3
        ) -> None:
        """
        :param feature_dim: input feature dimensionality
        :param output_dim: output embedding dimensionality
        :param device: target device (stored and handed to the IntraAgg layers)
        :param num_classes: number of logits produced by ``label_linear``
        :param num_relations: number of relations fused by ``proj``
                              (NOTE(review): exactly three IntraAgg layers are built)
        """
        super(InterAgg, self).__init__()
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.device = device
        self.num_classes = num_classes

        # Input = [own feature || r1 embed || r2 embed || r3 embed].
        self.proj = nn.Linear(
            in_features=num_relations * self.output_dim + self.feature_dim,
            out_features=output_dim
        )

        # Distance function shared by all relations; outputs raw logits (no sigmoid/softmax).
        self.label_linear = nn.Linear(self.feature_dim, self.num_classes)

        # One intra-relation aggregator per relation.
        self.intra1 = IntraAgg(
            feature_dim=self.feature_dim,
            output_dim=self.output_dim,
            device=self.device
        )
        self.intra2 = IntraAgg(self.feature_dim, self.output_dim, self.device)
        self.intra3 = IntraAgg(self.feature_dim, self.output_dim, self.device)

    def forward(
        self,
        features: torch.Tensor,
        batch_center_mask,
        batch_center_label,
        train_pos_mask,
        adj_lists,
        train_flag = True
        ):
        """
        :param features: node feature matrix indexed by node id (assumed a dense tensor)
        :param batch_center_mask: node ids of this batch's centre nodes
        :param batch_center_label: labels of the centre nodes (same order)
        :param train_pos_mask: node ids of the positive training nodes
        :param adj_lists: one adjacency mapping per relation: node id -> iterable of neighbour ids
        :param train_flag: forwarded to IntraAgg; label-based oversampling only when True
        :return: ``(embeds, batch_center_logits)`` with shapes (|B|, output_dim) and (|B|, num_classes)
        """
        self.features = features
        self.adj_lists = adj_lists
        self.train_pos_mask = train_pos_mask

        # Plain ints so ids hash/compare by value (0-d tensors hash by identity).
        center_ids = [int(n) for n in batch_center_mask]
        pos_ids = [int(n) for n in train_pos_mask]

        # Average neighbourhood size of the positive training nodes, per relation.
        avg_half_neigh_size = []
        for adj_list in self.adj_lists:
            total = sum(len(adj_list[n]) for n in pos_ids)
            avg_half_neigh_size.append(total / len(pos_ids) if pos_ids else 0.0)
        self.avg_half_neigh_size = avg_half_neigh_size

        # to_neighs[r][i] = set of neighbour ids of centre i under relation r.
        to_neighs = [
            [{int(n) for n in adj_list[c]} for c in center_ids]
            for adj_list in self.adj_lists
        ]
        # Every node this batch needs: the centres plus all their 1-hop neighbours.
        batch_all_mask = list(set(center_ids).union(*(nb for rel in to_neighs for nb in rel)))
        trainIdx2orderIdx = {node: order for order, node in enumerate(batch_all_mask)}
        orderIdx2trainIdx = dict(enumerate(batch_all_mask))

        def _rows(ids):
            return torch.as_tensor(ids, dtype=torch.long, device=self.features.device)

        batch_all_logits = self.label_linear(self.features[_rows(batch_all_mask)])
        # Positive nodes are not necessarily part of batch_all, so score them separately.
        pos_logits = self.label_linear(self.features[_rows(pos_ids)])
        batch_center_logits = batch_all_logits[[trainIdx2orderIdx[n] for n in center_ids]]

        rel_lists = [[list(nb) for nb in rel] for rel in to_neighs]
        intra_layers = (self.intra1, self.intra2, self.intra3)
        relation_embeds = [
            layer(
                self.features,
                center_ids,
                batch_center_label,
                pos_ids,
                rx_list,
                batch_center_logits,
                batch_all_logits,
                pos_logits,
                trainIdx2orderIdx,
                orderIdx2trainIdx,
                avg_size,
                train_flag
            )
            for layer, rx_list, avg_size in zip(intra_layers, rel_lists, self.avg_half_neigh_size)
        ]
        # each relation embed -> (|B|, output_dim)
        all_relation_and_self_embeds = torch.cat(
            [self.features[_rows(center_ids)], *relation_embeds], dim=1)
        proj_all_embeds = F.relu(self.proj(all_relation_and_self_embeds))

        return proj_all_embeds, batch_center_logits



PC-GNN消息传递

In [12]:
class PCGNN(nn.Module):
    """
    One Pick & Choose GNN message-passing layer. Its neighbour selection needs
    the node labels, so labels are part of ``forward``'s input.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        device: torch.device,
        normalize = True,
        num_classes: int = 2,
        bias = False,
    ):
        """
        :param in_channels: input feature dimensionality
        :param out_channels: output embedding dimensionality
        :param device: target device, handed to InterAgg
        :param normalize: stored only (NOTE(review): no normalisation is applied)
        :param num_classes: number of classes scored by the distance predictor
        :param bias: unused; kept for interface compatibility
        """
        super(PCGNN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.interAgg = InterAgg(
            feature_dim=in_channels,
            output_dim=out_channels,
            device=device,
            num_classes=num_classes
        )

        self.reset_parameters()

    def reset_parameters(self):
        """No-op: submodules initialise themselves on construction."""
        pass

    def forward(self, features ,labels, batch_mask, train_pos_mask, adj_lists, train_flag = True):
        """
        :param features: (|N|, in_channels) node features
        :param labels: (|N|,) node labels
        :param batch_mask: (|B|,) ids of the centre nodes considered in this pass
        :param train_pos_mask: ids of the positive training nodes
        :param adj_lists: per-relation adjacency mappings
        :param train_flag: forwarded so label-based oversampling only happens in training
        :return: ``(embeds, logits)``: (|B|, out_channels) embeddings and (|B|, num_classes) distance logits
        """
        embeds, logits = self.interAgg(
            features=features,
            batch_center_mask=batch_mask,
            batch_center_label=labels[batch_mask],
            train_pos_mask=train_pos_mask,
            adj_lists=adj_lists,
            train_flag=train_flag
        )

        return embeds, logits




一个args实例：

{'model_type': 'GraphSage', 'dataset': 'cora', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32,
        'dropout': 0.5, 'epochs': 500, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01},

### GNNStack 范式

In [13]:
class GNNStack(torch.nn.Module):
    """
    PC-GNN model wrapper: a single PC-GNN layer followed by a linear classifier.

    Only ``num_layers == 1`` is supported. ``hidden_dim``, ``heads`` and
    ``dropout`` are accepted for interface compatibility but unused.
    """
    def __init__(
        self,
        input_dim : int,
        hidden_dim: int,
        output_dim : int,
        device: torch.device,
        num_classes : int = 2,
        num_layers : int = 1,
        dropout : float = 0.5,
        heads : int = 1,
        model_type : str = 'PCGNN',
        emb : bool = False
        ) -> None:
        super(GNNStack, self).__init__()

        conv_model = self.build_conv_model(model_type)
        assert (num_layers >= 1), 'Number of layers is not >=1'
        if num_layers != 1:
            raise NotImplementedError
        self.convs = conv_model(
            in_channels=input_dim,
            out_channels=output_dim,
            device=device,
            num_classes=num_classes
        )

        # PC-GNN ends with a linear layer for node classification.
        self.final_proj = nn.Linear(output_dim, num_classes)

        self.dropout = dropout
        self.num_layers = num_layers

        self.emb = emb
        self.criterion = nn.CrossEntropyLoss()

    def build_conv_model(self, model_type):
        """Map a model name to its layer class; only 'PCGNN' is supported."""
        if model_type == 'PCGNN':
            return PCGNN
        else:
            raise NotImplementedError

    def _encode(self, features, labels, batch_mask, train_pos_mask, adj_lists, train_flag=None):
        """Run the PC-GNN layer and return ``(embeds, distance_logits)``.

        ``train_flag=None`` follows ``self.training`` so evaluation does not
        oversample using labels.
        """
        if train_flag is None:
            train_flag = self.training
        return self.convs(features, labels, batch_mask, train_pos_mask, adj_lists, train_flag=train_flag)

    def forward(self, features, labels, batch_mask, train_pos_mask, adj_lists, train_flag=None):
        """
        :param features: features of all nodes
        :param labels: labels of all nodes
        :param batch_mask: ids of the nodes in this batch (or the valid/test ids)
        :param train_pos_mask: ids of the positive training nodes
        :param adj_lists: per-relation adjacency mappings
        :param train_flag: label-based oversampling switch; defaults to ``self.training``
        :return: ``embeds`` if ``self.emb`` else ``(embeds, logits)``
        """
        embeds, logits = self._encode(features, labels, batch_mask, train_pos_mask, adj_lists, train_flag)
        if self.emb:
            return embeds
        return embeds, logits

    def loss(self, features, labels, batch_mask, train_pos_mask, adj_lists):
        """
        PC-GNN loss = loss_gnn (classifier on embeddings) + loss_dist (distance predictor).
        """
        # Call _encode directly so this also works when self.emb is True.
        embeds, logits = self._encode(features, labels, batch_mask, train_pos_mask, adj_lists)

        # reshape(-1) (not squeeze) keeps a 1-D target even for a batch of one.
        targets = torch.as_tensor(labels[batch_mask], device=logits.device).long().reshape(-1)

        gnn_loss = self.criterion(self.final_proj(embeds), targets)
        dist_loss = self.criterion(logits, targets)

        return gnn_loss + dist_loss


### Optimizer构建

In [11]:
def build_optimizer(args, params):
    """
    Build an optimizer and an optional LR scheduler from an attribute-style config.

    :param args: config exposing ``opt``, ``lr``, ``weight_decay``, ``opt_scheduler``
                 and, for the chosen scheduler, ``opt_decay_step``/``opt_decay_rate``
                 ('step') or ``opt_restart`` ('cos')
    :param params: iterable of parameters; those with ``requires_grad=False`` are skipped
    :return: ``(scheduler, optimizer)``; scheduler is ``None`` for ``opt_scheduler == 'none'``
    :raises ValueError: for an unknown ``opt`` or ``opt_scheduler``
    """
    weight_decay = args.weight_decay
    trainable = [p for p in params if p.requires_grad]
    if args.opt == 'adam':
        optimizer = optim.Adam(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(trainable, lr=args.lr,
                              momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable, lr=args.lr, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {args.opt!r}")

    if args.opt_scheduler == 'none':
        return None, optimizer
    if args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.opt_restart)
    else:
        raise ValueError(f"Unknown LR scheduler: {args.opt_scheduler!r}")
    return scheduler, optimizer


一个args实例（取自 GraphSage/Cora 实验；build_optimizer 只读取 opt、lr、weight_decay、opt_scheduler 及对应调度器参数）：

{'model_type': 'GraphSage', 'dataset': 'cora', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32,
        'dropout': 0.5, 'epochs': 500, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01},

In [3]:
def show_data(dataset_name: str = 'Amazon'):
    """Build the PyG heterograph for ``dataset_name`` and print its summary.

    A blank line is printed first to separate the summary from earlier output.
    """
    print()
    hetero_graph = constructPyGHeteroData(dataset_name)
    # One node type ('review') linked by three relation types.
    print(hetero_graph)


train()


HeteroData(
  homo=[11944],
  review={
    x=[11944, 25],
    y=[11944],
    train_mask=[3455],
    valid_mask=[1710],
    test_mask=[3474]
  },
  (review, r1, review)={
    adj=[11944, 11944],
    adj_list=[11944],
    edge_index=[2, 351216]
  },
  (review, r2, review)={
    adj=[11944, 11944],
    adj_list=[11944],
    edge_index=[2, 7132958]
  },
  (review, r3, review)={
    adj=[11944, 11944],
    adj_list=[11944],
    edge_index=[2, 2073474]
  }
)


### PCGNN 训练流程（ModelHandler）

In [14]:
class ModelHandler():
    """Training driver for a single-layer PC-GNN on a review heterograph."""

    def __init__(
        self,
        data: HeteroData,
        data_name:str = 'Amazon',
        random_seed:int = 42,
        use_cuda:bool = True,
        opt:str = 'adam',
        weight_decay:float = 5e-3,
        lr:float = 0.01,
        dropout:float = 0.5,
        num_layers:int = 2,
        num_epochs:int = 50,
        batch_size:int = 256
        ) -> None:
        """
        :param data: heterograph with a 'review' node type and relations r1/r2/r3
        :param random_seed: seed for numpy and torch
        :param use_cuda: use 'cuda:0' when available, else CPU
        :param opt: optimizer name; only 'adam' is supported
        :param num_layers: stored only (NOTE(review): GNNStack supports a single layer)
        """
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

        self.data = data
        # Node data may be ndarrays while relation data are tensors; train() converts.

        if use_cuda and torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        self.opt, self.weight_decay, self.lr = opt, weight_decay, lr
        self.dropout = dropout
        self.num_layers = num_layers
        self.num_epochs = num_epochs
        self.batch_size = batch_size

    def train(self):
        """Run ``num_epochs`` epochs of label-balanced mini-batch training and print the mean loss per epoch."""
        review = self.data['review']
        feat_data, label_data = review.x, review.y
        train_mask = review.train_mask
        train_pos_mask = review.train_pos_mask
        # NOTE(review): `.adj_list[0]` assumes each relation stores its adjacency wrapped in a
        # one-element container; the printed HeteroData shows adj_list=[11944] — confirm.
        adj_lists = [self.data['review', rel, 'review'].adj_list[0] for rel in ('r1', 'r2', 'r3')]

        features = torch.as_tensor(feat_data, dtype=torch.float32, device=self.device)
        labels = torch.as_tensor(label_data, dtype=torch.long, device=self.device)

        # Only PC-GNN is used.
        GNN = GNNStack(
            input_dim=features.shape[1],
            hidden_dim=64,
            output_dim=64,
            device=self.device,
            dropout=self.dropout
        ).to(self.device)

        if self.opt == 'adam':
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, GNN.parameters()),
                                         lr=self.lr, weight_decay=self.weight_decay)
        else:
            raise NotImplementedError("This optimizer is not implemented yet.")

        for epoch in tqdm(range(self.num_epochs)):
            GNN.train()
            # Pick step: label-balanced sampling, always from the ORIGINAL training set
            # (re-sampling the previous epoch's sample would shrink it every epoch).
            # NOTE(review): the printed HeteroData has no 'homo_adj_list' key — confirm.
            epoch_train_mask = LabelBalancedSampler(train_mask, label_data[train_mask], self.data['homo_adj_list'][0],)

            num_batches = (len(epoch_train_mask) + self.batch_size - 1) // self.batch_size
            epoch_loss = 0.

            for batch in range(num_batches):
                ind_start = batch * self.batch_size
                ind_end = min(ind_start + self.batch_size, len(epoch_train_mask))
                batch_nodes_mask = [int(n) for n in epoch_train_mask[ind_start:ind_end]]

                optimizer.zero_grad()
                loss = GNN.loss(
                    features=features,
                    labels=labels,
                    batch_mask=batch_nodes_mask,
                    train_pos_mask=train_pos_mask,
                    adj_lists=adj_lists
                )
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            print(f'Epoch: {epoch}, loss: {epoch_loss / max(num_batches, 1)}')
                


开始漫漫debug……

In [3]:
# Build the heterograph using constructPyGHeteroData's default arguments (defined in myutils).
data = constructPyGHeteroData()


In [19]:

# Create the training driver with its default hyper-parameters and run training.
model = ModelHandler(data)
model.train()