In [44]:
from torchdrug import  layers, datasets,transforms,core
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据集获取

In [45]:
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view = "residue")
dataset = EnzymeCommission("~/scratch/protein-datasets/", test_cutoff=0.95, 
                           atom_feature="full", bond_feature="full", verbose=1, transform = trans)

# 只保留alpha碳的简化格式
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)
                                                                 ],
                                                    edge_feature="gearnet"
                                                    )


graphs = dataset[:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print("\n\n")
graph = graph_construction_model(graphs)
print(graph)

16:12:18   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:45<00:00, 412.44it/s]





PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3754, 8999], num_residues=[185, 415])


In [46]:
#graph = graph.to(device)

print(graph.num_nodes)
print(graph.batch_size)
print(graph.edge_list)

tensor([185, 415])
2
tensor([[ 95,  96,   5],
        [109, 110,   5],
        [108, 109,   5],
        ...,
        [438, 470,   0],
        [489, 470,   0],
        [493, 470,   0]])


# 关系卷积神经网络，获取多个不同的嵌入

In [47]:
class relationalGraph(layers.MessagePassingBase):
    
    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None
            
    def trans(self, A, graph):
    
        Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
        A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
        
        n_rel = graph.num_relation
        n = A_norm.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A_norm[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A_norm[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list):
        assert graph.num_relation == self.num_relation
        device = input.device  # Ensure device consistency
        
        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
        
            edge_weight = torch.ones_like(node_out)
            degree_out = scatter_add(edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
            degree_out = degree_out
            edge_weight = edge_weight / degree_out[node_out]
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (graph.num_node, graph.num_node * graph.num_relation))
            update = torch.sparse.mm(adjacency.t(), input)
        
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t().to(device), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update

    def combine(self, input, update):
        # 自环特征
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        input = input.repeat(self.num_relation, 1).to(device)
        loop_update = self.self_loop(input).to(device)
        
        output = self.linear(update)+loop_update
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
    
    def forward(self, graph, input, new_edge_list=None):
        
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

#### 测试

In [48]:
input_dim = graph.node_feature.shape[-1]
output_dim = 64
num_relations = graph.num_relation

relational_output = relationalGraph(input_dim, output_dim, num_relations)(graph, graph.node_feature.float(), new_edge_list = None)
print(relational_output)
print("output: ", relational_output.shape)

tensor([[0.0000, 0.1482, 0.0119,  ..., 0.1540, 0.0000, 0.3477],
        [0.1279, 0.0225, 0.0904,  ..., 0.2627, 0.0000, 0.2235],
        [0.1890, 0.2838, 0.0000,  ..., 0.4515, 0.0000, 0.1637],
        ...,
        [0.0632, 0.0000, 0.1168,  ..., 0.2792, 0.0000, 0.0075],
        [0.1890, 0.2838, 0.0000,  ..., 0.4515, 0.0000, 0.1637],
        [0.0000, 0.2916, 0.0000,  ..., 0.6574, 0.2065, 0.1689]],
       grad_fn=<ReluBackward0>)
output:  torch.Size([4200, 64])


# 重连接模块

### 只从点所在的图进行reconnection

In [6]:
import torch

def get_start_end(current, graph):
    """
    根据一维张量 a 生成新的张量 b。
    
    :param a: 输入的一维张量
    :return: 输出的一维张量 b
    """
    # 初始化 b，第一个元素是 0
    segment = graph.num_nodes.repeat(graph.num_relation)
    index = torch.zeros(segment.size(0) + 1, dtype=segment.dtype)
    
    # 计算 b 的每个元素
    for i in range(1, len(index)):
        index[i] = index[i - 1] + segment[i - 1]
    
    # 遍历张量以找到索引值的位置
    for i in range(len(index) - 1):
        if index[i] <= current < index[i + 1]:
            return (index[i].item(), index[i + 1].item())
        elif index[i] == current:
            return (index[i].item(), index[i + 1].item())
    
    # 如果索引值恰好等于张量的最后一个元素
    if current == index[-1]:
        return (index[-1].item(), index[-1].item())
    
    

# 示例使用
start, end = get_start_end(184, graph)
print(start, end)


0 185


In [11]:
def get_start_end2( current, graph):
    segment = graph.num_nodes.repeat(graph.num_relation)
    index = torch.cumsum(segment, dim=0)
    
    # Use torch.searchsorted to find the appropriate segment
    pos = torch.searchsorted(index, current, right=True)

    if pos == 0:
        return (0, index[0].item())
    elif pos >= len(index):
        return (index[-1].item(), index[-1].item())
    else:
        return (index[pos-1].item(), index[pos].item())
    
start, end = get_start_end2(600, graph)
print(start, end)


600 785


### Gumble-softmax采样

In [95]:
import torch
import torch.nn.functional as F

def gumbel_softmax_top_k(logits, tau=0.5, k=1, hard=False):
    """
    Gumbel-Softmax采样方法，每一步都是可微分的，并选择最大的k个元素。
    
    参数:
        logits (torch.Tensor): 输入logits张量，维度为 (batch_size, num_classes)。
        tau (float): Gumbel-Softmax的温度参数，控制平滑程度。
        k (int): 选择最大的k个元素。
        hard (bool): 是否返回硬分类结果。
    
    返回:
        torch.Tensor: Gumbel-Softmax采样结果，维度为 (batch_size, num_classes)。
    """
    # 获取Gumbel分布噪声
    gumbels = -torch.empty_like(logits).exponential_().log()  # 生成Gumbel(0,1)噪声
    gumbels = (logits + gumbels) / tau  # 添加噪声并除以温度参数

    # 计算softmax
    y_soft = F.softmax(gumbels, dim=-1)  # 维度为 (batch_size, num_classes)

    if hard:
        # 硬分类结果：选取原始logits最大的k个位置
        topk_indices = logits.topk(k, dim=-1)[1]  # 获取前k个元素的索引
        y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)  # 生成one-hot向量
        # 使用直通估计器
        y = (y_hard - y_soft).detach() + y_soft
    else:
        y = y_soft

    return y

# 示例用法
if __name__ == "__main__":
    logits = torch.randn(5, 10)  # 维度为 (batch_size=5, num_classes=10)
    tau = 0.5
    k = 3
    hard = True
    samples = gumbel_softmax_top_k(logits, tau, k, hard)
    print(samples)
    samples = gumbel_softmax_top_k(logits, tau, k)
    print(samples)


tensor([[0., 0., 0., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 1., 1.],
        [1., 0., 0., 1., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 1., 1.]])
tensor([[4.8048e-01, 3.7910e-04, 2.6873e-04, 4.0213e-03, 1.1619e-01, 5.7953e-04,
         3.6010e-01, 6.4904e-04, 1.2300e-02, 2.5035e-02],
        [5.5033e-04, 6.5851e-05, 1.9387e-05, 5.5963e-05, 9.5151e-02, 1.5899e-01,
         5.4683e-05, 1.9378e-01, 5.4991e-01, 1.4249e-03],
        [9.8238e-01, 2.3563e-05, 5.6371e-04, 5.5736e-03, 8.7534e-05, 5.6787e-05,
         8.2759e-03, 2.6395e-03, 7.3919e-05, 3.2825e-04],
        [6.5204e-05, 4.4989e-04, 9.1505e-03, 3.3549e-02, 8.4767e-03, 9.2767e-05,
         4.9709e-03, 1.1033e-03, 9.4183e-01, 3.1603e-04],
        [2.2411e-04, 3.5781e-06, 1.3032e-03, 9.5245e-07, 5.9685e-05, 1.1207e-04,
         5.7596e-05, 5.7783e-03, 9.1864e-01, 7.3822e-02]])


### window self attention + gumble softmax

In [49]:

class Rewirescorelayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5, dropout=0.1):
        super(Rewirescorelayer, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.k = k
        
        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.value = nn.Linear(in_features, out_features * num_heads)
        self.scale = torch.sqrt(torch.FloatTensor([out_features])).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # get the start and end indices for each window of nodes
    def get_start_end(self, current, graph):
        segment = graph.num_nodes.repeat(graph.num_relation)
        index = torch.cumsum(segment, dim=0)
        
        # Use torch.searchsorted to find the appropriate segment
        pos = torch.searchsorted(index, current, right=True)

        if pos == 0:
            return (0, index[0].item())
        elif pos >= len(index):
            return (index[-1].item(), index[-1].item())
        else:
            return (index[pos-1].item(), index[pos].item())

    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(self.k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y

    def forward(self, graph, node_features):
        device = node_features.device
        num_nodes = node_features.size(0)
    
        
         # calculate the start and end indices for each node
        half_window = self.window_size // 2
        start_end_indices = [self.get_start_end(i, graph) for i in range(num_nodes)]
        start_indices = torch.tensor([max(start_end_indices[i][0], i - half_window) for i in range(num_nodes)], device=device)
        end_indices = torch.tensor([min(start_end_indices[i][1], i + half_window) for i in range(num_nodes)], device=device)
        window_sizes = end_indices - start_indices

        Q = self.query(node_features).view(num_nodes, self.num_heads, self.out_features)     # [num_nodes, num_heads, out_features]
        K = self.key(node_features).view(num_nodes, self.num_heads, self.out_features)       # [num_nodes, num_heads, out_features]
        
        output = torch.zeros(num_nodes, num_nodes, device=device)
        all_scores = torch.zeros(num_nodes, self.num_heads, self.window_size, device=device)
        
        # construct the windowed K matrix
        K_windows = torch.zeros((num_nodes, self.window_size, self.num_heads, self.out_features), device=device)
        for i in range(num_nodes):
            start = start_indices[i]
            end = end_indices[i]
            K_windows[i, :end-start] = K[start:end]


        Q_expanded = Q.unsqueeze(2)  # [num_nodes, num_heads, 1, out_features]
        K_expanded = K_windows.permute(0, 2, 1, 3)  # [num_nodes, num_heads, max_window_size, out_features]


        all_scores = torch.matmul(Q_expanded, K_expanded.transpose(-1, -2)) / self.scale  # [num_nodes, num_heads, 1, max_window_size]
        all_scores = all_scores.squeeze(2)  # [num_nodes, num_heads, max_window_size]
        

        mask = (torch.arange(self.window_size, device=device).expand(num_nodes, self.window_size) < window_sizes.unsqueeze(1))


        all_scores = all_scores.masked_fill(~mask.unsqueeze(1), float('-inf'))  # [num_nodes, num_heads, max_window_size]
        attention_weights = F.softmax(all_scores / self.temperature, dim=-1)  # [num_nodes, num_heads, max_window_size]
        attention_weights = attention_weights.mean(dim=1) 
        
        # sample edges
        for i in range(num_nodes):
            start = start_indices[i]
            end = end_indices[i]
            output[i, start:end] = attention_weights[i, :end-start]

        edge_list = self.gumbel_softmax_top_k(output, self.temperature, self.k)

        return edge_list

#### 测试

In [50]:
input_dim = relational_output.shape[-1]
output_dim = 256
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 3
num_nodes = relational_output.size(0)


relational_output = relational_output.to(device)
module = Rewirescorelayer(input_dim, output_dim, num_heads, window_size, k).to(device)

In [13]:
attn_output = module(graph, relational_output)
print(attn_output.shape)
print(attn_output)
#a = attn_output[1200:, 599:1199]
indices = torch.nonzero(attn_output, as_tuple=True)
print(indices)

torch.Size([4200, 4200])
tensor([[0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 0., 1., 1.]], device='cuda:0',
       grad_fn=<AddBackward0>)
(tensor([   0,    0,    0,  ..., 4199, 4199, 4199], device='cuda:0'), tensor([   2,    3,    4,  ..., 4196, 4198, 4199], device='cuda:0'))


### 测试不同degree进行采样

In [14]:
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau):
    # 从Gumbel(0, 1)分布中采样
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=-1)

def bernoulli_sampling_with_different_thresholds(probs, thresholds, tau=1.0):
    # 对数概率
    logits = torch.log(probs) - torch.log(1 - probs)
    # 进行Gumbel-Softmax采样
    y = gumbel_softmax_sample(logits, tau)
    # 硬化处理，根据每行的不同阈值
    z = (y > thresholds.unsqueeze(1)).float()
    return z

# 示例矩阵
n = 10
P = torch.rand(n, n)

# 为每一行设置不同的阈值
thresholds = torch.tensor([0.2, 0.5, 0.7, 0.9, 0.1, 0.3, 0.6, 0.8, 0.4, 0.2])

# 进行可微分伯努利采样并硬化处理
tau = 0.1  # 温度参数
sampled_matrix = bernoulli_sampling_with_different_thresholds(P, thresholds, tau)

#print("Probability Matrix:\n", P)
print("Thresholds:\n", thresholds)
print("Sampled Matrix:\n", sampled_matrix)


Thresholds:
 tensor([0.2000, 0.5000, 0.7000, 0.9000, 0.1000, 0.3000, 0.6000, 0.8000, 0.4000,
        0.2000])
Sampled Matrix:
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])


# Diffusion模块

### 计算degree矩阵，变换adjacent matrix形式

In [9]:
def trans(A, graph):
    
    Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
    A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
    
    n_rel = graph.num_relation
    n = A_norm.size(0)
    n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
    assert n % n_rel == 0, "n must be divisible by n_rel"
    
    block_size = n // n_rel
    
    # 初始化一个张量来存储累加结果
    accumulated = torch.zeros_like(A_norm[:block_size])
    
    # 将后面的所有块累加到第一块
    for i in range(n_rel):
        accumulated += A_norm[i * block_size: (i + 1) * block_size]
    
    # 用累加后的第一块替换原始矩阵的第一块
    A_trans = accumulated
    
    
    
    return A_trans

A_norm = trans(attn_output, graph)
print(A_norm)
print(A_norm.shape)

indices = torch.nonzero(A_norm, as_tuple=True)
print(indices)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
update = torch.mm(A_norm.t(), graph.node_feature.to(device).to(torch.float))
print(update.shape)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([600, 4200])
(tensor([  0,   0,   0,  ..., 599, 599, 599], device='cuda:0'), tensor([   3,    6,    8,  ..., 4193, 4195, 4196], device='cuda:0'))
torch.Size([4200, 21])


### Rewired_gearnet 用于diffusion模块

In [15]:
class RewireGearnet(nn.Module):
    gradient_checkpoint = False

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(RewireGearnet, self).__init__()
        self.num_relation = num_relation
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if activation else None
        self.edge_linear = nn.Linear(edge_input_dim, output_dim) if edge_input_dim else None

    def trans(self, A, graph):
        n_rel = graph.num_relation
        n = A.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list=None):
        assert graph.num_relation == self.num_relation

        device = input.device  # Ensure device consistency

        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            adjacency = torch.sparse_coo_tensor(
                torch.stack([node_in, node_out]),
                graph.edge_weight.to(device),
                (graph.num_node, graph.num_node * graph.num_relation),
                device=device
            )
            update = torch.sparse.mm(adjacency.t(), input)
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t(), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(
                edge_input * edge_weight, node_out, dim=0,
                dim_size=graph.num_node * graph.num_relation
            )
            update += edge_update
            
        return update.view(graph.num_node, self.num_relation * self.input_dim).to(device)

    def combine(self, input, update):
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        if self.batch_norm:
            self.batch_norm.to(device)
        
        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None, new_edge_weight=None):
        """
        Perform message passing over the graph(s).

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`
        """
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self.message_and_aggregate, graph, input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

#### 测试

In [16]:
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


new_node_feature = RewireGearnet(input_dim, output_dim, num_relations)(graph, graph.node_feature.to(device).float(), attn_output).to(device)

print(new_node_feature.shape)
print(new_node_feature)

torch.Size([600, 512])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1172, 0.0000],
        [0.0000, 0.2223, 0.0000,  ..., 0.0000, 0.3193, 0.1439],
        [0.0903, 0.1699, 0.3812,  ..., 0.0723, 0.2250, 0.0000],
        ...,
        [0.2997, 0.1000, 0.0000,  ..., 0.4028, 0.8425, 0.0000],
        [0.3565, 0.4857, 0.1668,  ..., 0.0000, 0.1029, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.3299, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<ReluBackward0>)


# 最终模型

In [17]:
class DGMGearnet(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dims, num_relation, num_heads, window_size, k, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
        super(DGMGearnet, self).__init__()

        #if not isinstance(hidden_dims, Sequence):
            #hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        self.score_layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            
            self.score_layers.append(relationalGraph(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim, 
                                            batch_norm, activation))
            
            self.score_layers.append(Rewirescorelayer(self.dims[i + 1], self.dims[i + 1], self.num_heads, self.window_size, 
                                            self.k, temperature=0.5, dropout=0.1))
            
            self.layers.append(RewireGearnet(self.dims[i], self.dims[i + 1], num_relation,
                                            edge_input_dim=None, batch_norm=False, activation="relu"))
        
        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, edge_list=None, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        hiddens = []
        layer_input = input
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            
            
            relational_output = self.score_layers[2*i](graph, layer_input, edge_list)
            new_edge_list = self.score_layers[2*i+1](graph, relational_output)
            
            hidden = self.layers[i](graph, layer_input, new_edge_list)
            
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
                
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                if new_edge_list is None:
                    node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                else:
                    node_out = new_edge_list[:, 1] * self.num_relation + new_edge_list[:, 2]
                
                    update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                        dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden
            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)
                
            hiddens.append(hidden)
            layer_input = hidden
            edge_list = new_edge_list

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)
        print("node_feature: ", node_feature.shape)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

### 测试

In [19]:
input_dim = graph.node_feature.shape[-1]
hidden_dims = [64, 128, 128, 128, 128]
num_relations = graph.num_relation
num_heads = 8
window_size = 50
k = 5


output = DGMGearnet(input_dim, hidden_dims, num_relations, num_heads, window_size, k).to(device)(graph.to(device), graph.node_feature.to(device).float())

node_feature:  torch.Size([600, 128])


In [29]:
print(output["node_feature"])
print(output["node_feature"].shape)
print("\n")

print(output["graph_feature"])
print(output["graph_feature"].shape)
torch.cuda.empty_cache()

tensor([[17.8948,  0.0000, 17.3570,  ...,  0.0000, 26.0307,  0.0000],
        [ 2.3947,  7.2990,  0.0000,  ...,  0.0000,  0.0000,  2.3281],
        [20.3003, 14.6250,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000, 23.7315,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.5055,  ...,  5.8059,  2.6199,  0.0000],
        [13.0536,  0.9474,  8.7509,  ...,  0.0000,  5.5958,  5.5812]],
       device='cuda:0', grad_fn=<ReluBackward0>)
torch.Size([600, 512])


tensor([[2161.1880, 1087.6064, 1541.9747,  ..., 2080.9690, 1970.3794,
         1159.7115],
        [5667.3657, 2416.1787, 4092.4651,  ..., 4103.5674, 4787.2236,
         1832.0172]], device='cuda:0', grad_fn=<ScatterAddBackward>)
torch.Size([2, 512])


# 测试

In [2]:
torch.cuda.empty_cache()

NameError: name 'torch' is not defined

In [34]:
import torch
from local_attention import LocalAttention

node = 512
dimm = 64
b = 5

q = torch.randn(b, 8, node, dimm)
k = torch.randn(b, 8, node, dimm)
v = torch.randn(b, 8, node, dimm)

attn = LocalAttention(
    dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
    window_size = 16,       # window size. 512 is optimal, but 256 or 128 yields good enough results
    causal = True,           # auto-regressive or not
    look_backward = 1,       # each window looks at the window before
    look_forward = 0,        # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it
    dropout = 0.1,           # post-attention dropout
    exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
)

mask = torch.ones(b, node).bool()
out = attn(q, k, v, mask = mask) # (2, 8, 2048, 64)

In [13]:
print(out.shape)

torch.Size([2, 8, 512, 64])
