In [3]:
import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops,degree

In [4]:
# Inner workings of message passing
class GCNConv(MessagePassing):
    """Minimal GCN layer written directly on top of ``MessagePassing``.

    Each node sums (``aggr='add'``) the linearly transformed features of its
    neighbours (plus itself, via added self-loops), with every edge weighted
    by ``deg(row)^-1/2 * deg(col)^-1/2``.

    Args:
        dim_in: Size of each input node feature vector.
        dim_h: Size of each output node feature vector.
    """

    def __init__(self,dim_in,dim_h):
        super().__init__(aggr='add')
        self.linear1 = Linear(dim_in,dim_h,bias=False)

    def forward(self,x,edge_index):
        """Run one round of normalised neighbourhood aggregation.

        Args:
            x: Node feature matrix; dimension 0 indexes the nodes.
            edge_index: Two-row tensor that unpacks into (row, col) node indices.

        Returns:
            The aggregated node features produced by ``propagate``.
        """
        # Self-loops let every node keep its own (transformed) features.
        edge_index,_ = add_self_loops(edge_index,num_nodes=x.size(0))
        x = self.linear1(x)

        # Degree is counted on the second row of edge_index, after self-loops.
        row, col = edge_index
        deg = degree(col,x.size(0),dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # A zero degree gives inf after pow(-0.5); zero it so it contributes nothing.
        deg_inv_sqrt[deg_inv_sqrt==float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index,x=x,norm=norm)

    def message(self,x_j,norm):
        """Scale each edge's source-node features by that edge's coefficient.

        Without this override the default message is used and the ``norm``
        passed to ``propagate`` is never applied.
        """
        # norm has one entry per edge; reshape so it broadcasts over the feature axis.
        return norm.view(-1,1) * x_j


In [5]:
# Build a layer mapping 16 input features to 32 hidden features and show its repr.
conv = GCNConv(dim_in=16, dim_h=32)
print(conv)

GCNConv()


In [6]:
from torch_geometric.data import HeteroData
# Start from an empty heterogeneous graph; the following cells attach node
# features and edge indices per node/edge type.
data = HeteroData()

In [7]:
# Toy node features: 3 users (4 features each), 2 games (2 features), 2 devs (1 feature).
data['user'].x = torch.tensor([
    [1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.],
])
data['game'].x = torch.tensor([[1., 1.], [2., 2.]])
data['dev'].x = torch.tensor([[1.], [2.]])

In [8]:
# Edge indices are node positions, so they must be integer (torch.long) tensors.
# The legacy torch.Tensor(...) constructor would create float32 values, which
# cannot be used for indexing/scatter in message passing.
# Row 0 holds source node indices, row 1 holds destination node indices.
data['user', 'follows', 'user'].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
data['user', 'plays', 'game'].edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]], dtype=torch.long)
data['dev', 'develops', 'game'].edge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long)

In [9]:
# Display the assembled toy graph (shapes of each node/edge store).
data

HeteroData(
  user={ x=[3, 4] },
  game={ x=[2, 2] },
  dev={ x=[2, 1] },
  (user, follows, user)={ edge_index=[2, 2] },
  (user, plays, game)={ edge_index=[2, 4] },
  (dev, develops, game)={ edge_index=[2, 2] }
)

GAT + MetaPath

In [10]:
from torch import nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import GAT

In [11]:
# A single metapath: author -> paper -> author (hops given as (source, destination) pairs).
metapaths = [[('author','paper'),('paper','author')]]

In [12]:
# Add the metapath edges and drop the graph's original edge types,
# so only the metapath-induced edges remain.
transform = T.AddMetaPaths(metapaths=metapaths,drop_orig_edge_types=True)

In [13]:
# Load DBLP from the current directory with the metapath transform applied,
# then take its first (index 0) graph.
dataset = DBLP(root='.',transform=transform)
data = dataset[0]
print(data)

HeteroData(
  metapath_dict={ (author, metapath_0, author)=[2] },
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057],
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, metapath_0, author)={ edge_index=[2, 11113] }
)


  out = torch.matmul(sparse_input, other)


In [14]:
# Single-layer GAT with 4 outputs (one per author class, presumably — TODO confirm).
# in_channels=-1 leaves the input size unspecified here; NOTE(review): PyG
# infers it lazily on the first forward pass — confirm.
model = GAT(in_channels=-1, hidden_channels=64,out_channels=4,num_layers=1)

In [15]:
# Place data and model on the target device before touching the parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

# The model was created with in_channels=-1, so its weights have no shape until
# the first forward pass. Run one dry pass (no gradients) so that the optimizer
# is constructed over fully materialised parameters.
with torch.no_grad():
    model(data.x_dict['author'],data.edge_index_dict[('author','metapath_0','author')])

optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.001)

In [16]:
@torch.no_grad()
def test(mask):
    """Return the accuracy on the author nodes selected by ``mask``.

    Uses the notebook-level ``model`` and ``data``; the model is switched to
    eval mode and fed the author features with the author-author metapath edges.

    Args:
        mask: Boolean mask over author nodes.

    Returns:
        Fraction of masked authors whose arg-max prediction equals the label.
    """
    model.eval()

    author_x = data.x_dict['author']
    metapath_edges = data.edge_index_dict[('author','metapath_0','author')]
    predicted = model(author_x, metapath_edges).argmax(dim=-1)

    labels = data['author'].y
    n_correct = (predicted[mask] == labels[mask]).sum()
    return float(n_correct / mask.sum())

In [17]:
# Full-batch training on the author-author metapath graph; progress is
# reported on every 20th epoch.
for epoch in range(101):
    model.train()
    optimizer.zero_grad()

    author_store = data['author']
    mask = author_store.train_mask
    out = model(
        data.x_dict['author'],
        data.edge_index_dict[('author','metapath_0','author')],
    )
    loss = F.cross_entropy(out[mask], author_store.y[mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 != 0:
        continue

    # test() switches the model to eval mode; model.train() above restores it.
    train_acc = test(author_store.train_mask)
    val_acc = test(author_store.val_mask)
    print(f'Epoch : {epoch:>3} | Train_loss : {loss:.4f} | Train_acc : {train_acc:.4f} | Val_Acc : {val_acc:.4f}')

Epoch :   0 | Train_loss : 1.4032 | Train_acc : 0.2775 | Val_Acc : 0.2625
Epoch :  20 | Train_loss : 1.2387 | Train_acc : 0.5250 | Val_Acc : 0.4475
Epoch :  40 | Train_loss : 1.1152 | Train_acc : 0.6400 | Val_Acc : 0.6325
Epoch :  60 | Train_loss : 1.0196 | Train_acc : 0.7400 | Val_Acc : 0.6825
Epoch :  80 | Train_loss : 0.9406 | Train_acc : 0.8000 | Val_Acc : 0.7175
Epoch : 100 | Train_loss : 0.8725 | Train_acc : 0.8300 | Val_Acc : 0.7250


In [18]:
# Final evaluation on the held-out author test split (accuracy as a fraction in [0, 1]).
test_acc = test(data['author'].test_mask)
print(f'test_acc : {test_acc:.4f}')

test_acc : 0.7455


GAT 이종 버전

In [19]:
from torch_geometric.nn import GATConv,Linear,to_hetero

In [20]:
# Reload DBLP without any transform so the original edge types are kept.
dataset= DBLP('.')
data = dataset[0]

In [21]:
# Inspect the node and edge stores of the untransformed graph.
print(data)

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057],
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)


In [22]:
# The conference nodes come without features (only num_nodes=20), so give them
# an all-zero placeholder feature of width 1.
data['conference'].x = torch.zeros(20,1)

In [23]:
class GAT(torch.nn.Module):
    """One attention convolution followed by a linear classifier.

    Written as a homogeneous model; the notebook converts it to a
    heterogeneous one with ``to_hetero`` afterwards.

    Args:
        dim_h: Hidden size produced by the attention layer.
        dim_out: Number of output values per node.
    """

    def __init__(self,dim_h,dim_out):
        super().__init__()
        # (-1, -1): source and destination input sizes are left unspecified here.
        self.conv = GATConv(
            in_channels=(-1,-1),
            out_channels=dim_h,
            add_self_loops=False,
        )
        self.linear = nn.Linear(in_features=dim_h, out_features=dim_out)

    def forward(self,x,edge_index):
        """Apply convolution, ReLU, then the linear head."""
        hidden = self.conv(x,edge_index)
        activated = hidden.relu()
        return self.linear(activated)

In [24]:
# Build the homogeneous GAT, then convert it with to_hetero using this graph's
# metadata; aggr='sum' is the aggregation mode handed to to_hetero.
model = GAT(dim_h=64,dim_out=4)
model = to_hetero(model,data.metadata(),aggr='sum')

In [27]:
# Show the (node types, edge types) pair that to_hetero was given.
data.metadata()

(['author', 'paper', 'term', 'conference'],
 [('author', 'to', 'paper'),
  ('paper', 'to', 'author'),
  ('paper', 'to', 'term'),
  ('paper', 'to', 'conference'),
  ('term', 'to', 'paper'),
  ('conference', 'to', 'paper')])

In [None]:
# The converted model holds one conv module per edge type and one linear module per node type.
model

GraphModule(
  (conv): ModuleDict(
    (author__to__paper): GATConv((-1, -1), 64, heads=1)
    (paper__to__author): GATConv((-1, -1), 64, heads=1)
    (paper__to__term): GATConv((-1, -1), 64, heads=1)
    (paper__to__conference): GATConv((-1, -1), 64, heads=1)
    (term__to__paper): GATConv((-1, -1), 64, heads=1)
    (conference__to__paper): GATConv((-1, -1), 64, heads=1)
  )
  (linear): ModuleDict(
    (author): Linear(in_features=64, out_features=4, bias=True)
    (paper): Linear(in_features=64, out_features=4, bias=True)
    (term): Linear(in_features=64, out_features=4, bias=True)
    (conference): Linear(in_features=64, out_features=4, bias=True)
  )
)

In [29]:
# Place data and model on the target device before touching the parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data,model = data.to(device),model.to(device)

# The GATConv layers were created with (-1, -1) input sizes, so their weights
# have no shape until the first forward pass. Run one dry pass (no gradients)
# so that the optimizer is constructed over fully materialised parameters.
with torch.no_grad():
    model(data.x_dict,data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.001)

In [35]:
@torch.no_grad()
def test(mask):
    """Return the accuracy on the author nodes selected by ``mask``.

    Uses the notebook-level ``model`` and ``data``; the model is run on the
    whole heterogeneous graph and only its 'author' output is scored.

    Args:
        mask: Boolean mask over author nodes.

    Returns:
        Fraction of masked authors whose arg-max prediction equals the label.
    """
    model.eval()

    out_dict = model(data.x_dict, data.edge_index_dict)
    predicted = out_dict['author'].argmax(dim=-1)

    hits = predicted[mask] == data['author'].y[mask]
    return float(hits.sum() / mask.sum())

In [36]:
# Full-batch training of the heterogeneous GAT; only the 'author' output is
# supervised. Progress is reported on every 20th epoch.
for epoch in range(101):
    model.train()
    optimizer.zero_grad()

    author_store = data['author']
    mask = author_store.train_mask
    out = model(data.x_dict, data.edge_index_dict)['author']
    loss = F.cross_entropy(out[mask], author_store.y[mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 != 0:
        continue

    # test() switches the model to eval mode; model.train() above restores it.
    train_acc = test(author_store.train_mask)
    val_acc = test(author_store.val_mask)
    print(f'Epoch : {epoch:>3} | Train_loss : {loss:.4f} | Train_acc : {train_acc*100:.4f} | Val_acc : {val_acc*100:.4f}')


Epoch :   0 | Train_loss : 1.3790 | Train_acc : 37.5000 | Val_acc : 30.5000
Epoch :  20 | Train_loss : 1.1876 | Train_acc : 92.0000 | Val_acc : 67.2500
Epoch :  40 | Train_loss : 0.8396 | Train_acc : 97.7500 | Val_acc : 71.5000
Epoch :  60 | Train_loss : 0.4799 | Train_acc : 98.5000 | Val_acc : 75.2500
Epoch :  80 | Train_loss : 0.2456 | Train_acc : 99.5000 | Val_acc : 75.5000
Epoch : 100 | Train_loss : 0.1381 | Train_acc : 100.0000 | Val_acc : 75.0000


In [38]:
# Final evaluation on the held-out author test split, printed as a percentage.
test_acc = test(data['author'].test_mask)
print(f'Test_acc : {test_acc*100:.4f}')

Test_acc : 77.5560


HAN 구현

In [46]:
# Display the current graph; conference nodes now carry the placeholder x of shape [20, 1].
data

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057],
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1],
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)

In [47]:
from torch_geometric.nn import HANConv

# Reload the untransformed DBLP graph and give the feature-less conference
# nodes an all-zero placeholder feature of width 1.
dataset = DBLP(root = '.')
data = dataset[0]
data['conference'].x = torch.zeros(20,1)

In [50]:
class HAN(torch.nn.Module):
    """A single ``HANConv`` layer followed by a linear classifier on author nodes.

    Args:
        dim_in: Input size forwarded to ``HANConv``.
        dim_out: Number of output values per author node.
        dim_h: Hidden size produced by ``HANConv``.
        heads: Number of attention heads passed to ``HANConv``.
        dropout: Dropout value passed to ``HANConv``.
        metadata: ``(node_types, edge_types)`` description of the graph. When
            omitted, falls back to ``data.metadata()`` of the notebook-level
            ``data`` object, which is what this class always used before.
    """

    def __init__(self,dim_in,dim_out,dim_h=128,heads=8,dropout=0.6,metadata=None):
        super().__init__()
        # Allow the caller to state the graph schema explicitly instead of
        # silently depending on whatever `data` is bound to in the kernel.
        if metadata is None:
            metadata = data.metadata()
        self.han = HANConv(dim_in,dim_h,heads=heads,dropout=dropout,metadata=metadata)
        self.linear = nn.Linear(dim_h,dim_out)

    def forward(self,x_dict,edge_index_dict):
        """Return the linear head's output for the 'author' nodes only."""
        out = self.han(x_dict,edge_index_dict)
        # HANConv returns a per-node-type mapping; only 'author' is classified.
        out = self.linear(out['author'])
        return out

In [54]:
# dim_in=-1 is forwarded to HANConv as its input size
# (NOTE(review): presumably lazy size inference — confirm); 4 outputs per author node.
model = HAN(dim_in = -1, dim_out = 4)
print(model)

HAN(
  (han): HANConv(128, heads=8)
  (linear): Linear(in_features=128, out_features=4, bias=True)
)


In [55]:
# Place data and model on the target device before touching the parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

# HAN was created with dim_in=-1, so the HANConv weights have no shape until
# the first forward pass. Run one dry pass (no gradients) so that the optimizer
# is constructed over fully materialised parameters.
with torch.no_grad():
    model(data.x_dict,data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.001)

In [60]:
@torch.no_grad()
def test(mask):
    """Return the accuracy on the author nodes selected by ``mask``.

    Uses the notebook-level ``model`` and ``data``.

    Args:
        mask: Boolean mask over author nodes.

    Returns:
        Fraction of masked authors whose arg-max prediction equals the label.
    """
    model.eval()

    # HAN.forward already returns the output for the 'author' nodes only.
    logits = model(data.x_dict, data.edge_index_dict)
    predicted = logits.argmax(dim=-1)

    hits = predicted[mask] == data['author'].y[mask]
    return float(hits.sum() / mask.sum())

In [61]:
# Full-batch training of the HAN model; progress is reported on every 20th epoch.
for epoch in range(101):
    model.train()
    optimizer.zero_grad()

    author_store = data['author']
    mask = author_store.train_mask
    out = model(data.x_dict, data.edge_index_dict)
    loss = F.cross_entropy(out[mask], author_store.y[mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 != 0:
        continue

    # test() switches the model to eval mode; model.train() above restores it.
    train_acc = test(author_store.train_mask)
    val_acc = test(author_store.val_mask)
    print(f'Epoch : {epoch:>3} | train_acc : {train_acc*100:.4f} | val_acc : {val_acc*100:.4f}')

Epoch :   0 | train_acc : 30.2500 | val_acc : 28.2500
Epoch :  20 | train_acc : 93.7500 | val_acc : 69.2500
Epoch :  40 | train_acc : 96.0000 | val_acc : 72.0000
Epoch :  60 | train_acc : 98.2500 | val_acc : 77.5000
Epoch :  80 | train_acc : 99.5000 | val_acc : 78.2500
Epoch : 100 | train_acc : 99.7500 | val_acc : 79.2500


In [64]:
# Final evaluation on the held-out author test split, printed as a percentage.
test_acc = test(data['author'].test_mask)
print(f'Test_acc : {test_acc*100:.4f}%')

Test_acc : 81.8545%
