In [1]:
from torchdrug import  layers, datasets,transforms,core
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据集获取

In [2]:
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view = "residue")
dataset = EnzymeCommission("~/scratch/protein-datasets/", test_cutoff=0.95, 
                           atom_feature="full", bond_feature="full", verbose=1, transform = trans)

# 只保留alpha碳的简化格式
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)
                                                                 ],
                                                    edge_feature="gearnet"
                                                    )


graphs = dataset[:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print("\n\n")
graph = graph_construction_model(graphs)
print(graph)

11:30:29   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:45<00:00, 413.11it/s]





PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3754, 8999], num_residues=[185, 415])


In [3]:
#graph = graph.to(device)

print(graph.num_nodes)
print(graph.batch_size)
print(graph.edge_list)

tensor([185, 415])
2
tensor([[ 95,  96,   5],
        [109, 110,   5],
        [108, 109,   5],
        ...,
        [438, 470,   0],
        [489, 470,   0],
        [493, 470,   0]])


# 关系卷积神经网络，获取多个不同的嵌入

In [4]:
class relationalGraph(layers.MessagePassingBase):
    
    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None
            
    def trans(self, A, graph):
    
        Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
        A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
        
        n_rel = graph.num_relation
        n = A_norm.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A_norm[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A_norm[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list):
        assert graph.num_relation == self.num_relation
        device = input.device  # Ensure device consistency
        
        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
        
            edge_weight = torch.ones_like(node_out)
            degree_out = scatter_add(edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
            degree_out = degree_out
            edge_weight = edge_weight / degree_out[node_out]
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (graph.num_node, graph.num_node * graph.num_relation))
            update = torch.sparse.mm(adjacency.t(), input)
        
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t().to(device), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update

    def combine(self, input, update):
        # 自环特征
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        input = input.repeat(self.num_relation, 1).to(device)
        loop_update = self.self_loop(input).to(device)
        
        output = self.linear(update)+loop_update
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
    
    def forward(self, graph, input, new_edge_list=None):
        
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

#### 测试

In [59]:
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation

relational_output = relationalGraph(input_dim, output_dim, num_relations)(graph, graph.node_feature.float(), new_edge_list = None)
print(relational_output)
print("output: ", relational_output.shape)

tensor([[0.0000, 0.1461, 0.0647,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2744, 0.0895,  ..., 0.1369, 0.0000, 0.1432],
        [0.0000, 0.1414, 0.3575,  ..., 0.0457, 0.0996, 0.0000],
        ...,
        [0.0000, 0.0000, 0.2763,  ..., 0.0000, 0.0000, 0.0613],
        [0.0000, 0.1414, 0.3575,  ..., 0.0457, 0.0996, 0.0000],
        [0.0000, 0.2805, 0.0271,  ..., 0.0976, 0.0494, 0.0000]],
       grad_fn=<ReluBackward0>)
output:  torch.Size([4200, 512])


# 重连接模块

### 只从点所在的图进行reconnection

In [6]:
import torch

def get_start_end(current, graph):
    """
    根据一维张量 a 生成新的张量 b。
    
    :param a: 输入的一维张量
    :return: 输出的一维张量 b
    """
    # 初始化 b，第一个元素是 0
    segment = graph.num_nodes.repeat(graph.num_relation)
    index = torch.zeros(segment.size(0) + 1, dtype=segment.dtype)
    
    # 计算 b 的每个元素
    for i in range(1, len(index)):
        index[i] = index[i - 1] + segment[i - 1]
    
    # 遍历张量以找到索引值的位置
    for i in range(len(index) - 1):
        if index[i] <= current < index[i + 1]:
            return (index[i].item(), index[i + 1].item())
        elif index[i] == current:
            return (index[i].item(), index[i + 1].item())
    
    # 如果索引值恰好等于张量的最后一个元素
    if current == index[-1]:
        return (index[-1].item(), index[-1].item())
    
    

# 示例使用
start, end = get_start_end(184, graph)
print(start, end)


0 185


In [11]:
def get_start_end2( current, graph):
    segment = graph.num_nodes.repeat(graph.num_relation)
    index = torch.cumsum(segment, dim=0)
    
    # Use torch.searchsorted to find the appropriate segment
    pos = torch.searchsorted(index, current, right=True)

    if pos == 0:
        return (0, index[0].item())
    elif pos >= len(index):
        return (index[-1].item(), index[-1].item())
    else:
        return (index[pos-1].item(), index[pos].item())
    
start, end = get_start_end2(600, graph)
print(start, end)


600 785


### Gumble-softmax采样

In [95]:
import torch
import torch.nn.functional as F

def gumbel_softmax_top_k(logits, tau=0.5, k=1, hard=False):
    """
    Gumbel-Softmax采样方法，每一步都是可微分的，并选择最大的k个元素。
    
    参数:
        logits (torch.Tensor): 输入logits张量，维度为 (batch_size, num_classes)。
        tau (float): Gumbel-Softmax的温度参数，控制平滑程度。
        k (int): 选择最大的k个元素。
        hard (bool): 是否返回硬分类结果。
    
    返回:
        torch.Tensor: Gumbel-Softmax采样结果，维度为 (batch_size, num_classes)。
    """
    # 获取Gumbel分布噪声
    gumbels = -torch.empty_like(logits).exponential_().log()  # 生成Gumbel(0,1)噪声
    gumbels = (logits + gumbels) / tau  # 添加噪声并除以温度参数

    # 计算softmax
    y_soft = F.softmax(gumbels, dim=-1)  # 维度为 (batch_size, num_classes)

    if hard:
        # 硬分类结果：选取原始logits最大的k个位置
        topk_indices = logits.topk(k, dim=-1)[1]  # 获取前k个元素的索引
        y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)  # 生成one-hot向量
        # 使用直通估计器
        y = (y_hard - y_soft).detach() + y_soft
    else:
        y = y_soft

    return y

# 示例用法
if __name__ == "__main__":
    logits = torch.randn(5, 10)  # 维度为 (batch_size=5, num_classes=10)
    tau = 0.5
    k = 3
    hard = True
    samples = gumbel_softmax_top_k(logits, tau, k, hard)
    print(samples)
    samples = gumbel_softmax_top_k(logits, tau, k)
    print(samples)


tensor([[0., 0., 0., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 1., 1.],
        [1., 0., 0., 1., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 1., 1.]])
tensor([[4.8048e-01, 3.7910e-04, 2.6873e-04, 4.0213e-03, 1.1619e-01, 5.7953e-04,
         3.6010e-01, 6.4904e-04, 1.2300e-02, 2.5035e-02],
        [5.5033e-04, 6.5851e-05, 1.9387e-05, 5.5963e-05, 9.5151e-02, 1.5899e-01,
         5.4683e-05, 1.9378e-01, 5.4991e-01, 1.4249e-03],
        [9.8238e-01, 2.3563e-05, 5.6371e-04, 5.5736e-03, 8.7534e-05, 5.6787e-05,
         8.2759e-03, 2.6395e-03, 7.3919e-05, 3.2825e-04],
        [6.5204e-05, 4.4989e-04, 9.1505e-03, 3.3549e-02, 8.4767e-03, 9.2767e-05,
         4.9709e-03, 1.1033e-03, 9.4183e-01, 3.1603e-04],
        [2.2411e-04, 3.5781e-06, 1.3032e-03, 9.5245e-07, 5.9685e-05, 1.1207e-04,
         5.7596e-05, 5.7783e-03, 9.1864e-01, 7.3822e-02]])


### window self attention + gumble softmax

In [78]:

class Rewirescorelayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads, window_size, k, num_nodes, temperature=0.5, dropout=0.1):
        super(Rewirescorelayer, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.k = k
        self.num_nodes = num_nodes

        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.value = nn.Linear(in_features, out_features * num_heads)
        self.scale = torch.sqrt(torch.FloatTensor([out_features])).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        
         # calculate the start and end indices for each node
        self.half_window = self.window_size // 2
        start_end_indices = [self.get_start_end(i, graph) for i in range(num_nodes)]
        self.start_indices = [max(start_end_indices[i][0], i - self.half_window) for i in range(num_nodes)]
        self.end_indices = [min(start_end_indices[i][1], i + self.half_window) for i in range(num_nodes)]

    # get the start and end indices for each window of nodes
    def get_start_end(self, current, graph):
        segment = graph.num_nodes.repeat(graph.num_relation)
        index = torch.cumsum(segment, dim=0)
        
        # Use torch.searchsorted to find the appropriate segment
        pos = torch.searchsorted(index, current, right=True)

        if pos == 0:
            return (0, index[0].item())
        elif pos >= len(index):
            return (index[-1].item(), index[-1].item())
        else:
            return (index[pos-1].item(), index[pos].item())
    
    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(self.k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y

    def forward(self, graph, node_features):
        device = node_features.device
        num_nodes = self.num_nodes

        Q = self.query(node_features).view(num_nodes, self.num_heads, self.out_features)
        K = self.key(node_features).view(num_nodes, self.num_heads, self.out_features)
        
        output = torch.zeros(num_nodes, num_nodes, device=device)
        all_scores = torch.zeros(num_nodes, self.num_heads, self.window_size, device=device)

        # calculate scores
        for i in range(num_nodes):
            start = self.start_indices[i]
            end = self.end_indices[i]
            K_window = K[start:end]  # [window_size, num_heads, out_features]
            
            Q_i = Q[i].unsqueeze(0)  # [1, num_heads, out_features]
            all_scores[i, :, :end-start] = torch.einsum("nhd,mhd->nhm", Q_i, K_window) / self.scale  # [1, num_heads, window_size]

        # calculate attention weights
        all_scores = all_scores / self.temperature
        attention_weights = F.softmax(all_scores, dim=-1)  # [num_nodes, num_heads, max_window_size]
        attention_weights = attention_weights.mean(dim=1)  # [num_nodes, max_window_size]

        # sample edges
        for i in range(num_nodes):
            start = self.start_indices[i]
            end = self.end_indices[i]
            output[i, start:end] = attention_weights[i, :end-start]

        edge_list = self.gumbel_softmax_top_k(output, self.temperature, self.k)

        return edge_list, output 

#### 测试

In [79]:
input_dim = relational_output.shape[-1]
output_dim = 1024
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 5
num_nodes = relational_output.size(0)


relational_output = relational_output.to(device)
module = Rewirescorelayer(input_dim, output_dim, num_heads, window_size, k, num_nodes).to(device)

In [81]:
attn_output, output = module(graph, relational_output)
print(attn_output.shape)
print(attn_output)
#a = attn_output[1200:, 599:1199]
indices = torch.nonzero(attn_output, as_tuple=True)
print(indices)

torch.Size([4200, 4200])
tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.]], device='cuda:0',
       grad_fn=<AddBackward0>)
(tensor([   0,    0,    0,  ..., 4199, 4199, 4199], device='cuda:0'), tensor([   0,    1,    2,  ..., 4197, 4198, 4199], device='cuda:0'))


### 测试不同degree进行采样

In [40]:
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau):
    # 从Gumbel(0, 1)分布中采样
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=-1)

def bernoulli_sampling_with_different_thresholds(probs, thresholds, tau=1.0):
    # 对数概率
    logits = torch.log(probs) - torch.log(1 - probs)
    # 进行Gumbel-Softmax采样
    y = gumbel_softmax_sample(logits, tau)
    # 硬化处理，根据每行的不同阈值
    z = (y > thresholds.unsqueeze(1)).float()
    return z

# 示例矩阵
n = 10
P = torch.rand(n, n)

# 为每一行设置不同的阈值
thresholds = torch.tensor([0.2, 0.5, 0.7, 0.9, 0.1, 0.3, 0.6, 0.8, 0.4, 0.2])

# 进行可微分伯努利采样并硬化处理
tau = 0.1  # 温度参数
sampled_matrix = bernoulli_sampling_with_different_thresholds(P, thresholds, tau)

#print("Probability Matrix:\n", P)
print("Thresholds:\n", thresholds)
print("Sampled Matrix:\n", sampled_matrix)


Thresholds:
 tensor([0.2000, 0.5000, 0.7000, 0.9000, 0.1000, 0.3000, 0.6000, 0.8000, 0.4000,
        0.2000])
Sampled Matrix:
 tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


# Diffusion模块

### 计算degree矩阵，变换adjacent matrix形式

In [9]:
def trans(A, graph):
    
    Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
    A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
    
    n_rel = graph.num_relation
    n = A_norm.size(0)
    n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
    assert n % n_rel == 0, "n must be divisible by n_rel"
    
    block_size = n // n_rel
    
    # 初始化一个张量来存储累加结果
    accumulated = torch.zeros_like(A_norm[:block_size])
    
    # 将后面的所有块累加到第一块
    for i in range(n_rel):
        accumulated += A_norm[i * block_size: (i + 1) * block_size]
    
    # 用累加后的第一块替换原始矩阵的第一块
    A_trans = accumulated
    
    
    
    return A_trans

A_norm = trans(attn_output, graph)
print(A_norm)
print(A_norm.shape)

indices = torch.nonzero(A_norm, as_tuple=True)
print(indices)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
update = torch.mm(A_norm.t(), graph.node_feature.to(device).to(torch.float))
print(update.shape)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([600, 4200])
(tensor([  0,   0,   0,  ..., 599, 599, 599], device='cuda:0'), tensor([   3,    6,    8,  ..., 4193, 4195, 4196], device='cuda:0'))
torch.Size([4200, 21])


### Rewired_gearnet 用于diffusion模块

In [11]:
class RewireGearnet(nn.Module):
    gradient_checkpoint = False

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(RewireGearnet, self).__init__()
        self.num_relation = num_relation
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if activation else None
        self.edge_linear = nn.Linear(edge_input_dim, output_dim) if edge_input_dim else None

    def trans(self, A, graph):
        n_rel = graph.num_relation
        n = A.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list=None):
        assert graph.num_relation == self.num_relation

        device = input.device  # Ensure device consistency

        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            adjacency = torch.sparse_coo_tensor(
                torch.stack([node_in, node_out]),
                graph.edge_weight.to(device),
                (graph.num_node, graph.num_node * graph.num_relation),
                device=device
            )
            update = torch.sparse.mm(adjacency.t(), input)
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t(), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(
                edge_input * edge_weight, node_out, dim=0,
                dim_size=graph.num_node * graph.num_relation
            )
            update += edge_update
            
        return update.view(graph.num_node, self.num_relation * self.input_dim).to(device)

    def combine(self, input, update):
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        if self.batch_norm:
            self.batch_norm.to(device)
        
        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None, new_edge_weight=None):
        """
        Perform message passing over the graph(s).

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`
        """
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self.message_and_aggregate, graph, input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

#### 测试

In [12]:
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


new_node_feature = RewireGearnet(input_dim, output_dim, num_relations)(graph, graph.node_feature.to(device).float(), attn_output).to(device)

print(new_node_feature.shape)
print(new_node_feature)

torch.Size([600, 512])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5687, 0.0408],
        [0.0000, 0.0668, 0.0000,  ..., 0.5021, 0.0878, 0.0000],
        [0.0000, 0.0000, 0.3658,  ..., 0.5964, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.1374,  ..., 0.9314, 0.4847, 0.2744],
        [0.2658, 0.0000, 0.4445,  ..., 0.0664, 0.4838, 0.2672],
        [0.0000, 0.0000, 0.1118,  ..., 0.6846, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<ReluBackward0>)


# 最终模型

In [13]:
class DGMGearnet(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dims, num_relation, num_heads, window_size, k, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
        super(DGMGearnet, self).__init__()

        #if not isinstance(hidden_dims, Sequence):
            #hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        self.score_layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            
            self.score_layers.append(relationalGraph(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim, 
                                            batch_norm, activation))
            
            self.score_layers.append(Rewirescorelayer(self.dims[i + 1], self.dims[i + 1], self.num_heads, self.window_size, 
                                            self.k, temperature=0.5, dropout=0.1))
            
            self.layers.append(RewireGearnet(self.dims[i], self.dims[i + 1], num_relation,
                                            edge_input_dim=None, batch_norm=False, activation="relu"))
        
        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, edge_list=None, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        hiddens = []
        layer_input = input
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            
            
            relational_output = self.score_layers[2*i](graph, layer_input, edge_list)
            new_edge_list = self.score_layers[2*i+1](graph, relational_output)
            
            hidden = self.layers[i](graph, layer_input, new_edge_list)
            
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
                
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                if new_edge_list is None:
                    node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                else:
                    node_out = new_edge_list[:, 1] * self.num_relation + new_edge_list[:, 2]
                
                    update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                        dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden
            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)
                
            hiddens.append(hidden)
            layer_input = hidden
            edge_list = new_edge_list

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)
        print("node_feature: ", node_feature.shape)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

### 测试

In [16]:
input_dim = graph.node_feature.shape[-1]
hidden_dims = [512, 512, 512, 512, 512]
num_relations = graph.num_relation
num_heads = 8
window_size = 50
k = 5


output = DGMGearnet(input_dim, hidden_dims, num_relations, num_heads, window_size, k).to(device)(graph.to(device), graph.node_feature.to(device).float())

node_feature:  torch.Size([600, 512])


In [17]:
print(output["node_feature"])
print(output["node_feature"].shape)
print("\n")

print(output["graph_feature"])
print(output["graph_feature"].shape)
torch.cuda.empty_cache()

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  8.6070, 16.7746,  0.0000],
        [ 0.0000,  4.5035,  0.0000,  ..., 19.8174, 32.3328,  0.0000],
        [ 0.0000, 13.5381,  0.0000,  ...,  0.0000,  0.0000, 62.8631],
        ...,
        [ 0.0000,  5.6229,  0.0000,  ..., 29.6259, 22.7300,  0.0000],
        [ 0.0000,  0.0000,  8.9290,  ...,  0.0000, 13.3476,  0.1188],
        [ 0.0000,  0.0000,  0.0000,  ..., 27.0346,  6.5614,  1.8956]],
       device='cuda:0', grad_fn=<ReluBackward0>)
torch.Size([600, 512])


tensor([[ 233.5066, 1417.5868,  931.7796,  ..., 3142.3210, 3161.3845,
         2296.9148],
        [ 648.2252, 3275.2917, 1469.4263,  ..., 7588.7271, 7191.9448,
         4825.6797]], device='cuda:0', grad_fn=<ScatterAddBackward>)
torch.Size([2, 512])


# 测试

In [43]:
torch.cuda.empty_cache()

In [56]:
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        return x

In [54]:
a = torch.randn(7, 1024, 1024)
model = WindowAttention(4096, (32, 32), 8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.).to(device)
output = model(a.to(device))

RuntimeError: mat1 dim 1 must match mat2 dim 0