## 创建torch_geometric中的图

In [5]:
import torch
from torch_geometric.data import Data

In [6]:
# Node feature matrix: 4 nodes, 2 features per node.
x = torch.tensor(
    [[2, 1],
     [5, 6],
     [3, 7],
     [12, 0]],
    dtype=torch.float,
)
# One label per node.
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)

# Edges written as (source, target) pairs, then transposed into the
# 2 x num_edges layout stored in edge_index. The order in which the
# edges are listed does not matter.
edge_index = torch.tensor(
    [(0, 1), (1, 0), (2, 1), (0, 3), (3, 2)],
    dtype=torch.long,
).t().contiguous()

data = Data(x=x, y=y, edge_index=edge_index)
data

Data(x=[4, 2], edge_index=[2, 5], y=[4])

## 电商转化率预测
+ yoochoose-clicks：用户的点击（浏览）记录，同一个 session_id 对应一次会话中浏览过的所有商品，每个会话构成一张图
+ item_id 是用户所浏览的商品（即图中的节点）；yoochoose-buys 记录该会话最终是否发生购买，作为这张图的标签

In [10]:
from sklearn.preprocessing import LabelEncoder
import pandas as pd


In [None]:
from torch_geometric.data import InMemoryDataset
from tqdm import tqdm

class MyDataset(InMemoryDataset):
    """Template in-memory dataset that loads pre-collated PyG ``Data`` objects.

    ``process()`` must build the ``Data`` objects and save the collated result
    to ``self.processed_paths[0]``. The constructor then loads that file back
    into ``self.data`` / ``self.slices``.

    Args:
        root: Root directory; ``raw_dir`` and ``processed_dir`` live under it.
        transform: Optional callable applied to every ``Data`` object each time
            it is accessed (e.g. data augmentation). ``None`` means no transform.
        pre_transform: Optional callable meant to be applied once to each
            ``Data`` object inside ``process()``, before saving.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        # ``transform`` must be a callable or None. PyG invokes it on every
        # accessed item, so the former default of ``True`` would fail with
        # "'bool' object is not callable".
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    # PyG checks that every file listed here exists in self.raw_dir and calls
    # download() if not. No raw files are declared here.
    @property
    def raw_file_names(self):
        return []

    # PyG checks that every file listed here exists in self.processed_dir and
    # runs process() if any is missing.
    @property
    def processed_file_names(self):
        return ['MyDataset.dataset']

    def download(self):
        # Nothing to download: no raw files are declared.
        pass

    def process(self):
        """Build ``Data`` objects and save them to ``self.processed_paths[0]``.

        This is a template to fill in for your own data, e.g.::

            data_list = [...]  # one Data object per sample/graph
            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]
            torch.save(self.collate(data_list), self.processed_paths[0])

        Raises:
            NotImplementedError: until it is implemented.
        """
        # Fail loudly here. Otherwise __init__ crashes later with a
        # confusing FileNotFoundError, because the processed file was never
        # written.
        raise NotImplementedError(
            "MyDataset.process() must build Data objects and torch.save "
            "self.collate(data_list) to self.processed_paths[0]"
        )

## 构建网络模型
+ 与CNN中的卷积和池化类似：SAGEConv 相当于卷积（聚合邻居节点的特征），TopKPooling 相当于池化（只保留得分最高的一部分节点）

In [16]:
embed_dim = 128
from torch_geometric.nn import TopKPooling, SAGEConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net(torch.nn.Module):
    """Graph-level binary classifier over item-id node graphs.

    The pipeline is item embedding, then 3 x (SAGEConv + TopKPooling), then
    an MLP head. After each pooling stage the node features are mean-pooled
    per graph, and the three readouts are summed to combine features from
    several scales.

    Args:
        num_embeddings: Number of rows in the item embedding table. It must be
            larger than the largest item id appearing in ``data.x``. If
            ``None`` (the default), it falls back to the original
            ``df.item_id.max() + 10``, which requires a global DataFrame
            ``df`` with an ``item_id`` column to exist.
    """

    def __init__(self, num_embeddings=None):
        super().__init__()
        if num_embeddings is None:
            # Backward-compatible fallback. This raises NameError if no
            # global ``df`` has been loaded.
            num_embeddings = df.item_id.max() + 10

        # Construction order is kept unchanged so parameter initialisation
        # and state_dict ordering match earlier runs.
        self.conv1 = SAGEConv(embed_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        self.item_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embed_dim)
        self.lin1 = torch.nn.Linear(128, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, 1)
        # NOTE(review): bn1/bn2 are not used in forward(). They are kept so
        # that state_dict keys stay compatible with existing checkpoints.
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, data):
        """Return one probability per graph in the batch.

        Args:
            data: Batched PyG data with ``x`` (integer item ids, presumably
                shape ``(num_nodes, 1)`` -- TODO confirm), ``edge_index``
                and ``batch``. Each graph may have a different node count.

        Returns:
            Tensor of shape ``(num_graphs,)`` with values in (0, 1).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.item_embedding(x)
        x = x.squeeze(1)  # (num_nodes, embed_dim) when x was (num_nodes, 1)

        x = F.relu(self.conv1(x, edge_index))
        # TopKPooling keeps ceil(0.8 * n) nodes of each graph (each still
        # 128-dim). x, edge_index and batch must all be replaced by the
        # pooled versions for the next stage.
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = gap(x, batch)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = gap(x, batch)

        # Use the pooled edge_index here. Node indices of the original graph
        # no longer match x after two pooling stages.
        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = gap(x, batch)

        x = x1 + x2 + x3  # combine graph readouts from three scales

        x = self.act1(self.lin1(x))
        x = self.act2(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)

        return torch.sigmoid(self.lin3(x).squeeze(1))  # one value per graph

In [17]:
from torch_geometric.loader import DataLoader

def train():
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Uses the module-level globals ``model``, ``opt`` (optimizer), ``crit``
    (loss function) and ``train_loader``.

    Returns:
        Average per-graph loss over ``train_loader.dataset``.
    """
    model.train()

    loss_all = 0.0
    for data in train_loader:
        opt.zero_grad()
        output = model(data)
        # BCELoss requires float targets with the same shape as ``output``.
        label = data.y.float()
        loss = crit(output, label)
        loss.backward()
        # crit returns the batch mean; weight it by the number of graphs so
        # the epoch average is per graph rather than per batch.
        loss_all += data.num_graphs * loss.item()
        opt.step()
    # Normalise by the loader's own dataset, so the average stays correct
    # even if the loader wraps a split/subset.
    return loss_all / len(train_loader.dataset)

model = Net()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
crit = torch.nn.BCELoss()
# Must be named ``train_loader``: train() reads that global.
# NOTE(review): consider shuffle=True for the training loader.
train_loader = DataLoader(dataset, batch_size=64)
for epoch in range(10):
    print("epoch:", epoch)
    loss = train()
    print(loss)

NameError: name 'df' is not defined