In [1]:
from torchdrug import  layers, datasets,transforms,core
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F

import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0


# 数据集获取

In [2]:
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view = "residue")
dataset = EnzymeCommission("~/scratch/protein-datasets/", test_cutoff=0.95, 
                           atom_feature="full", bond_feature="full", verbose=1, transform = trans)

# 只保留alpha碳的简化格式
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=1)
                                                                 ],
                                                    edge_feature="gearnet"
                                                    )


graphs = dataset[:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print("\n\n")
graph = graph_construction_model(graphs)
print(graph)

01:54:34   Extracting /home/cu/scratch/protein-datasets/EnzymeCommission.zip to /home/cu/scratch/protein-datasets


Loading /home/cu/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:22<00:00, 843.28it/s] 





PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3388, 8173], num_residues=[185, 415])


In [3]:
#graph = graph.to(device)

print(graph.num_nodes)
print(graph.batch_size)
print(graph.edge_list)

tensor([185, 415])
2
tensor([[155, 156,   4],
        [168, 169,   4],
        [167, 168,   4],
        ...,
        [426, 435,   0],
        [440, 435,   0],
        [487, 435,   0]])


# 关系卷积神经网络，获取多个不同的嵌入

In [4]:
class relationalGraph(layers.MessagePassingBase):
    
    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None
            
    def trans(self, A, graph):
    
        Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
        A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
        
        n_rel = graph.num_relation
        n = A_norm.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A_norm[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A_norm[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list):
        assert graph.num_relation == self.num_relation
        device = input.device  # Ensure device consistency
        
        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
        
            edge_weight = torch.ones_like(node_out)
            degree_out = scatter_add(edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
            degree_out = degree_out
            edge_weight = edge_weight / degree_out[node_out]
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (graph.num_node, graph.num_node * graph.num_relation))
            update = torch.sparse.mm(adjacency.t(), input)
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t().to(device), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update

    def combine(self, input, update):
        # 自环特征
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        input = input.repeat(self.num_relation, 1).to(device)
        loop_update = self.self_loop(input).to(device)
        
        output = self.linear(update)+loop_update
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
    
    def forward(self, graph, input, new_edge_list=None):
        
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update).view(graph.num_relation, input.size(0), -1)
        return output

#### 测试

In [5]:
input_dim = graph.node_feature.shape[-1]
output_dim = 128
num_relations = graph.num_relation

relational_output = relationalGraph(input_dim, output_dim, num_relations)(graph, graph.node_feature.float(), new_edge_list = None)
print(relational_output)
print("output: ", relational_output.shape)

tensor([[[0.0000, 0.1173, 0.1634,  ..., 0.0293, 0.0000, 0.0000],
         [0.0000, 0.0007, 0.0000,  ..., 0.2899, 0.0000, 0.0000],
         [0.0000, 0.2027, 0.1033,  ..., 0.1484, 0.0000, 0.0000],
         ...,
         [0.0089, 0.1264, 0.0114,  ..., 0.0129, 0.0000, 0.0000],
         [0.1178, 0.1718, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.1201, 0.1198,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0114, 0.1481, 0.0909,  ..., 0.0315, 0.0000, 0.0000],
         [0.0272, 0.0318, 0.0000,  ..., 0.3047, 0.0000, 0.0000],
         [0.1022, 0.2793, 0.3373,  ..., 0.2042, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0690, 0.0000,  ..., 0.0980, 0.0000, 0.1223],
         [0.1178, 0.1718, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0097, 0.0788, 0.0947,  ..., 0.0000, 0.0072, 0.0000]],

        [[0.0168, 0.1657, 0.0866,  ..., 0.0085, 0.0000, 0.0000],
         [0.0000, 0.0385, 0.0000,  ..., 0.2266, 0.0000, 0.0000],
         [0.0483, 0.0305, 0.0754,  ..., 0.2700, 0.0000, 0.

# 重连接模块

### 只从点所在的图进行reconnection

### Gumble-softmax采样

In [6]:
import torch
import torch.nn.functional as F

def gumbel_softmax_top_k(logits, tau=0.5, k=1, hard=False):
    """
    Gumbel-Softmax采样方法，每一步都是可微分的，并选择最大的k个元素。
    
    参数:
        logits (torch.Tensor): 输入logits张量，维度为 (batch_size, num_classes)。
        tau (float): Gumbel-Softmax的温度参数，控制平滑程度。
        k (int): 选择最大的k个元素。
        hard (bool): 是否返回硬分类结果。
    
    返回:
        torch.Tensor: Gumbel-Softmax采样结果，维度为 (batch_size, num_classes)。
    """
    # 获取Gumbel分布噪声
    gumbels = -torch.empty_like(logits).exponential_().log()  # 生成Gumbel(0,1)噪声
    gumbels = (logits + gumbels) / tau  # 添加噪声并除以温度参数

    # 计算softmax
    y_soft = F.softmax(gumbels, dim=-1)  # 维度为 (batch_size, num_classes)

    if hard:
        # 硬分类结果：选取原始logits最大的k个位置
        topk_indices = logits.topk(k, dim=-1)[1]  # 获取前k个元素的索引
        y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)  # 生成one-hot向量
        # 使用直通估计器
        y = (y_hard - y_soft).detach() + y_soft
    else:
        y = y_soft

    return y

# 示例用法
if __name__ == "__main__":
    logits = torch.randn(5, 10)  # 维度为 (batch_size=5, num_classes=10)
    tau = 0.5
    k = 3
    hard = True
    samples = gumbel_softmax_top_k(logits, tau, k, hard)
    print(samples)
    samples = gumbel_softmax_top_k(logits, tau, k)
    print(samples)


tensor([[0., 0., 0., 1., 0., 0., 0., 1., 1., 0.],
        [1., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 1., 0., 0.]])
tensor([[9.2066e-02, 4.4660e-06, 2.0334e-02, 1.2134e-02, 1.6932e-02, 4.3917e-03,
         2.0580e-03, 5.2264e-02, 7.8937e-01, 1.0444e-02],
        [7.2328e-01, 2.0644e-04, 2.0815e-02, 1.9107e-03, 9.1716e-05, 1.0527e-01,
         6.1835e-03, 1.1957e-02, 1.2846e-04, 1.3016e-01],
        [3.2986e-08, 8.7121e-08, 3.5935e-07, 1.0000e+00, 3.5876e-08, 8.5148e-09,
         2.7917e-07, 7.8181e-09, 7.1131e-08, 1.5726e-08],
        [2.4566e-03, 1.8842e-02, 1.0022e-03, 1.2845e-03, 9.4768e-01, 1.3542e-04,
         2.1959e-04, 1.5191e-02, 9.8077e-04, 1.2210e-02],
        [1.3476e-02, 1.7886e-04, 7.4551e-04, 2.7668e-03, 6.0881e-01, 3.5665e-02,
         3.0852e-01, 5.2925e-03, 2.3564e-02, 9.8649e-04]])


### window self attention + gumble softmax

In [7]:

class Rewirescorelayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5, dropout=0.1):
        super(Rewirescorelayer, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.k = k
        
        self.query = nn.Linear( in_features, out_features* num_heads)
        self.key = nn.Linear( in_features, out_features* num_heads)
        self.scale = 1 / (out_features ** 0.5)

    def split_windows(self, tensor, index, window_size):
        result = []
        index_list  = []
        # 初始的开始索引
        start = 0

        for idx in index:
            end = start + idx - 1
            while start < end:
                if start + window_size <= end:
                    result.append(tensor[:, start:start+window_size, :])
                    index_list.append([start, start+window_size])
                    start += window_size
                else:
                    # 计算还需要填充多少行
                    padding_rows = window_size - (end - start + 1)
                    restart = start - padding_rows
                    result.append(tensor[:, restart:restart + window_size , :])
                    index_list.append([restart, restart+window_size])
                    start = end + 1
                    
        # 转换结果列表为 tensor
        result_tensor = torch.stack(result, dim=1)
        return result_tensor, index_list

    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(self.k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y
    
    def windows2adjacent(self, windows, index_list, output):
    
        for i, index in enumerate(index_list):
            start, end = index
            output[:, start:end, start:end] = windows[:, i, :, :]
            
        num_relations, num_nodes, _ = output.shape
        result = torch.zeros(num_relations* num_nodes, num_relations*num_nodes)
        for i in range(num_relations):
            result[i*num_nodes:(i+1)*num_nodes, i*num_nodes:(i+1)*num_nodes] = output[i]
        return result

    def forward(self, graph, node_features):
        device = node_features.device
        num_relations = node_features.size(0)
        num_nodes = node_features.size(1)
        index = graph.num_nodes.tolist()
        
        Q = self.query(node_features).view(num_relations, num_nodes, self.num_heads, self.out_features).permute(0, 2, 1, 3)
        K = self.key(node_features).view(num_relations, num_nodes, self.num_heads, self.out_features).permute(0, 2, 1, 3)   
        Q = Q.reshape(num_relations * self.num_heads, num_nodes, self.out_features)
        K = K.reshape(num_relations * self.num_heads, num_nodes, self.out_features)
        
        output = torch.zeros(num_relations, num_nodes, num_nodes)
       
        Q_windows, Q_index = self.split_windows(Q, index, self.window_size)
        K_windows, _ = self.split_windows(K, index, self.window_size)
        
        # 计算 attention
        scores  = torch.einsum('b h i e, b h j e -> b h i j', Q_windows, K_windows) / self.scale
        attn = scores.softmax(dim=-1).view(num_relations, self.num_heads, -1, self.window_size, self.window_size).mean(dim=1)
        attn = self.gumbel_softmax_top_k(attn, tau=self.temperature, hard=True)
        
        result = self.windows2adjacent(attn, Q_index, output)
        
        return result
        

#### 测试

In [8]:
input_dim = relational_output.shape[-1]
output_dim = 256
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 3
num_nodes = relational_output.size(0)

start = time.time()
relational_output = relational_output.to(device)
module = Rewirescorelayer(input_dim, output_dim, num_heads, window_size, k).to(device)
attn_output = module(graph, relational_output)
end = time.time()
print(f"运行时间: {end - start:.6f} 秒")
print(attn_output.shape)

print(attn_output)
#a = attn_output[1200:, 599:1199]
indices = torch.nonzero(attn_output, as_tuple=True)
print(indices)


运行时间: 0.042305 秒
torch.Size([3000, 3000])
tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<CopySlices>)
(tensor([   0,    0,    0,  ..., 2999, 2999, 2999]), tensor([   1,    7,    8,  ..., 2990, 2995, 2996]))


### 测试不同degree进行采样

In [10]:
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau):
    # 从Gumbel(0, 1)分布中采样
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=-1)

def bernoulli_sampling_with_different_thresholds(probs, thresholds, tau=1.0):
    # 对数概率
    logits = torch.log(probs) - torch.log(1 - probs)
    # 进行Gumbel-Softmax采样
    y = gumbel_softmax_sample(logits, tau)
    # 硬化处理，根据每行的不同阈值
    z = (y > thresholds.unsqueeze(1)).float()
    return z

# 示例矩阵
n = 10
P = torch.rand(n, n)

# 为每一行设置不同的阈值
thresholds = torch.tensor([0.2, 0.5, 0.7, 0.9, 0.1, 0.3, 0.6, 0.8, 0.4, 0.2])

# 进行可微分伯努利采样并硬化处理
tau = 0.1  # 温度参数
sampled_matrix = bernoulli_sampling_with_different_thresholds(P, thresholds, tau)

#print("Probability Matrix:\n", P)
print("Thresholds:\n", thresholds)
print("Sampled Matrix:\n", sampled_matrix)


Thresholds:
 tensor([0.2000, 0.5000, 0.7000, 0.9000, 0.1000, 0.3000, 0.6000, 0.8000, 0.4000,
        0.2000])
Sampled Matrix:
 tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])


# Diffusion模块

### 计算degree矩阵，变换adjacent matrix形式

In [11]:
def trans(A, graph):
    
    Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
    A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
    
    n_rel = graph.num_relation
    n = A_norm.size(0)
    n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
    assert n % n_rel == 0, "n must be divisible by n_rel"
    
    block_size = n // n_rel
    
    # 初始化一个张量来存储累加结果
    accumulated = torch.zeros_like(A_norm[:block_size])
    
    # 将后面的所有块累加到第一块
    for i in range(n_rel):
        accumulated += A_norm[i * block_size: (i + 1) * block_size]
    
    # 用累加后的第一块替换原始矩阵的第一块
    A_trans = accumulated
    
    
    
    return A_trans

A_norm = trans(attn_output, graph)
print(A_norm)
print(A_norm.shape)

indices = torch.nonzero(A_norm, as_tuple=True)
print(indices)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
update = torch.mm(A_norm.t(), graph.node_feature.to(device).to(torch.float))
print(update.shape)

tensor([[0.0000, 0.3333, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3333, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<AddBackward0>)
torch.Size([600, 3000])
(tensor([  0,   0,   0,  ..., 599, 599, 599]), tensor([   1,    7,    8,  ..., 2990, 2995, 2996]))
torch.Size([3000, 21])


### Rewired_gearnet 用于diffusion模块

In [12]:
class RewireGearnet(nn.Module):
    gradient_checkpoint = False

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(RewireGearnet, self).__init__()
        self.num_relation = num_relation
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if activation else None
        self.edge_linear = nn.Linear(edge_input_dim, output_dim) if edge_input_dim else None

    def trans(self, A, graph):
        n_rel = graph.num_relation
        n = A.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list=None):
        assert graph.num_relation == self.num_relation

        device = input.device  # Ensure device consistency

        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            adjacency = torch.sparse_coo_tensor(
                torch.stack([node_in, node_out]),
                graph.edge_weight.to(device),
                (graph.num_node, graph.num_node * graph.num_relation),
                device=device
            )
            update = torch.sparse.mm(adjacency.t(), input)
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t(), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(
                edge_input * edge_weight, node_out, dim=0,
                dim_size=graph.num_node * graph.num_relation
            )
            update += edge_update
            
        return update.view(graph.num_node, self.num_relation * self.input_dim).to(device)

    def combine(self, input, update):
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        if self.batch_norm:
            self.batch_norm.to(device)
        
        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None, new_edge_weight=None):
        """
        Perform message passing over the graph(s).

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`
        """
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self.message_and_aggregate, graph, input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

#### 测试

In [13]:
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


new_node_feature = RewireGearnet(input_dim, output_dim, num_relations)(graph, graph.node_feature.to(device).float(), attn_output).to(device)

print(new_node_feature.shape)
print(new_node_feature)

torch.Size([600, 512])
tensor([[0.1449, 0.0628, 0.0000,  ..., 0.3983, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.1705, 0.0874, 0.6224],
        [0.0000, 0.0000, 0.2159,  ..., 0.0000, 0.0000, 0.2213],
        ...,
        [0.1342, 0.0000, 0.3268,  ..., 0.0116, 0.0000, 0.3118],
        [0.0000, 0.2666, 0.0000,  ..., 0.4255, 0.1149, 0.3375],
        [0.2765, 0.0000, 0.0000,  ..., 0.0175, 0.4929, 0.5345]],
       grad_fn=<ReluBackward0>)


# 最终模型

In [20]:
class DGMGearnet(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dims, score_dim, num_relation, num_heads, window_size, k, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
        super(DGMGearnet, self).__init__()

        #if not isinstance(hidden_dims, Sequence):
            #hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.score_dim = score_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        self.score_layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            
            self.score_layers.append(relationalGraph(self.dims[i], self.score_dim, num_relation, None, 
                                            batch_norm, activation))
            
            self.score_layers.append(Rewirescorelayer(self.score_dim, self.dims[i+1], self.num_heads, self.window_size, 
                                            self.k, temperature=0.5, dropout=0.1))
            
            self.layers.append(RewireGearnet(self.dims[i], self.dims[i + 1], num_relation,
                                            edge_input_dim=None, batch_norm=False, activation="relu"))
        
        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, edge_list=None, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        hiddens = []
        layer_input = input
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            
            
            relational_output = self.score_layers[2*i](graph, layer_input, edge_list)
            new_edge_list = self.score_layers[2*i+1](graph, relational_output)
            
            hidden = self.layers[i](graph, layer_input, new_edge_list)
            
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                if new_edge_list is None:
                    node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                else:
                    node_out = new_edge_list[:, 1] * self.num_relation + new_edge_list[:, 2]
                
                    update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                        dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden
                
            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)
                
            hiddens.append(hidden)
            layer_input = hidden
            edge_list = new_edge_list

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)
        print("node_feature: ", node_feature.shape)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

### 测试

In [27]:
input_dim = graph.node_feature.shape[-1]
hidden_dims = [512, 512, 512, 512, 512]
score_dim = 128
num_relations = graph.num_relation
num_heads = 8
window_size = 10
k = 5


output = DGMGearnet(input_dim, hidden_dims, score_dim, num_relations, num_heads, window_size, k).to(device)(graph.to(device), graph.node_feature.to(device).float())

node_feature:  torch.Size([600, 512])


In [28]:
print(output["node_feature"])
print(output["node_feature"].shape)
print("\n")

print(output["graph_feature"])
print(output["graph_feature"].shape)
torch.cuda.empty_cache()

tensor([[31.8002,  0.0000, 18.8589,  ...,  0.0000,  8.2624,  0.0000],
        [20.0393,  0.0000, 17.0460,  ...,  0.0000,  6.8150,  0.0000],
        [ 6.4855,  0.0000,  4.4976,  ...,  0.0000,  7.6434,  0.0000],
        ...,
        [13.9503,  0.0000, 12.1124,  ...,  0.0000,  7.2026,  0.0000],
        [44.9440,  0.0000, 27.2555,  ...,  0.0000, 20.4179,  0.0000],
        [26.0715,  0.0000, 11.8351,  ...,  0.0000, 10.0670,  0.0000]],
       grad_fn=<ReluBackward0>)
torch.Size([600, 512])


tensor([[4341.7954,   26.5939, 2779.7217,  ...,   23.5720, 1896.7634,
          137.0831],
        [9712.8516,   84.3868, 6230.3101,  ...,   53.5257, 4284.4419,
          260.6467]], grad_fn=<ScatterAddBackward>)
torch.Size([2, 512])


# 测试

In [2]:
torch.cuda.empty_cache()

NameError: name 'torch' is not defined

In [66]:
import torch
from local_attention import LocalAttention

node = 1200
dimm = 64
b = 10

q = torch.randn(b, 8, node, dimm)
k = torch.randn(b, 8, node, dimm)
v = torch.randn(b, 8, node, dimm)

attn = LocalAttention(
    dim =64,                # dimension of each head (you need to pass this in for relative positional encoding)
    window_size = 10,       # window size. 512 is optimal, but 256 or 128 yields good enough results
    causal = None,           # auto-regressive or not
    look_backward = 1,       # each window looks at the window before
    look_forward = 1,        # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it
    dropout = 0.1,           # post-attention dropout
    exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
)

mask = torch.ones(b, node).bool()
out = attn(q, k, v, mask = mask) # (2, 8, 2048, 64)

In [64]:
print(out.shape)

torch.Size([10, 8, 1200, 64])


In [90]:
import math

import torch
from torch import nn, einsum
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack

# constant

TOKEN_SELF_ATTN_VALUE = -5e4

# helper functions

def exists(val):
    return val is not None

# 如果value不存在，返回d
def default(value, d):
    return d if not exists(value) else value

def to(t):
    return {'device': t.device, 'dtype': t.dtype}

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max  #返回给定张量数据类型的所能表示的最大负值

def l2norm(tensor):
    dtype = tensor.dtype
    normed = F.normalize(tensor, dim = -1)
    return normed.type(dtype)

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple
    if m.is_integer():
        return False, tensor
    remainder = math.ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value)

def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):  #x = bk: (40, 32, 16, 64)
    t = x.shape[1]    #获取一共有多少个窗口，这里是32
    dims = (len(x.shape) - dim) * (0, 0)   #一个长度为 len(x.shape) - dim 的元组，每个元素为 (0, 0)；其中len(x.shape) = 4
    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)   #在第二维度上，前面加backward个元素，后面加forward个元素 -> (40, 33, 16, 64)
    tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)] #一个张量列表，每个张量的维度为(40, 32, 16, 64), len = 2
    return torch.cat(tensors, dim = dim) #在第二维度上拼接 -> (40, 32, 32, 64)

# main class

class Attention(Module):
    def __init__(
        self,
        window_size,
        causal = False,
        look_backward = 1,
        look_forward = None,
        dropout = 0.,
        shared_qk = False,
        rel_pos_emb_config = None,
        dim = None,
        autopad = False,
        exact_windowsize = False,
        scale = None,
        use_rotary_pos_emb = True,
        use_xpos = False,
        xpos_scale_base = None
    ):
        super().__init__()
        look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and look_forward > 0), 'you cannot look forward if causal'

        self.scale = scale

        self.window_size = window_size
        self.autopad = autopad
        self.exact_windowsize = exact_windowsize

        self.causal = causal

        self.look_backward = look_backward
        self.look_forward = look_forward

        self.dropout = nn.Dropout(dropout)

        self.shared_qk = shared_qk

        # relative positions

        self.rel_pos = None
        self.use_xpos = use_xpos

    def forward(
        self,
        q, k,
        mask = None,
        input_mask = None,
        attn_bias = None,
        window_size = None
    ):

        mask = default(mask, input_mask)

        assert not (exists(window_size) and not self.use_xpos), 'cannot perform window size extrapolation if xpos is not turned on'

        shape, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk

        # https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb
        (q, packed_shape), (k, _)= map(lambda t: pack([t], '* n d'), (q, k))  #打包成[5, 8, 512, 64] -> [40, 512, 64] 

        # auto padding

        if autopad:
            orig_seq_len = q.shape[1]
            (needed_pad, q), (_, k) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k))

        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype   # 40, 512, 64

        scale = default(self.scale, dim_head ** -0.5)

        assert (n % window_size) == 0, f'sequence length {n} must be divisible by window size {window_size} for local attention'

        windows = n // window_size  # 512 / 16 = 32

        if shared_qk:
            k = l2norm(k)

        seq = torch.arange(n, device = device)                  # 0, 1, 2, 3, ..., 511
        b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)    # (1, 32, 16) 排序序列变形后的矩阵

        # bucketing

        bq, bk = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k)) #重构：（40，512，64）->（40, 32, 16, 64）

        bq = bq * scale    # (40, 32, 16, 64)

        look_around_kwargs = dict(
            backward =  look_backward,
            forward =  look_forward,
            pad_value = pad_value
        )

        bk = look_around(bk, **look_around_kwargs)      # (40, 32, 32, 64)

        # rotary embeddings

        if exists(self.rel_pos):
            pos_emb, xpos_scale = self.rel_pos(bk)
            bq, bk = self.rel_pos.apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)

        # calculate positions for masking

        bq_t = b_t
        bq_k = look_around(b_t, **look_around_kwargs) # (1, 32, 32)

        bq_t = rearrange(bq_t, '... i -> ... i 1')      # (1, 32, 16, 1)
        bq_k = rearrange(bq_k, '... j -> ... 1 j')      # (1, 32, 1, 16)

        pad_mask = bq_k == pad_value

        sim = einsum('b h i e, b h j e -> b h i j', bq, bk)  # (40, 32, 16, 64) * (40, 32, 32, 64) -> (40, 32, 16, 32)

        if exists(attn_bias):
            heads = attn_bias.shape[0]
            assert (b % heads) == 0

            attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
            sim = sim + attn_bias

        mask_value = max_neg_value(sim)

        if shared_qk:
            self_mask = bq_t == bq_k
            sim = sim.masked_fill(self_mask, TOKEN_SELF_ATTN_VALUE)
            del self_mask

        if causal:
            causal_mask = bq_t < bq_k

            if self.exact_windowsize:
                max_causal_window_size = (self.window_size * self.look_backward)
                causal_mask = causal_mask | (bq_t > (bq_k + max_causal_window_size))

            sim = sim.masked_fill(causal_mask, mask_value)
            del causal_mask

        # masking out for exact window size for non-causal
        # as well as masking out for padding value

        if not causal and self.exact_windowsize:
            max_backward_window_size = (self.window_size * self.look_backward)
            max_forward_window_size = (self.window_size * self.look_forward)
            window_mask = ((bq_k - max_forward_window_size) > bq_t) | (bq_t > (bq_k + max_backward_window_size)) | pad_mask
            sim = sim.masked_fill(window_mask, mask_value)
        else:
            sim = sim.masked_fill(pad_mask, mask_value)

        # take care of key padding mask passed in

        if exists(mask):
            batch = mask.shape[0]    # 5
            assert (b % batch) == 0

            h = b // mask.shape[0]  # 8

            if autopad:
                _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)

            mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size)
            mask = look_around(mask, **{**look_around_kwargs, 'pad_value': False})
            mask = rearrange(mask, '... j -> ... 1 j')
            mask = repeat(mask, 'b ... -> (b h) ...', h = h)
            sim = sim.masked_fill(~mask, mask_value)
            del mask

        # attention

        attn = sim.softmax(dim = -1)

        return attn

In [104]:

class localRewirescorelayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5, dropout=0.1):
        super(localRewirescorelayer, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.k = k
        self.Attention = Attention(
                                    dim =128,                # dimension of each head (you need to pass this in for relative positional encoding)
                                    window_size = 10,       # window size. 512 is optimal, but 256 or 128 yields good enough results
                                    causal = None,           # auto-regressive or not
                                    look_backward = 1,       # each window looks at the window before
                                    look_forward = 1,        # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it
                                    dropout = 0.1,           # post-attention dropout
                                    exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
                                )
        
        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.value = nn.Linear(in_features, out_features * num_heads)
        self.scale = torch.sqrt(torch.FloatTensor([out_features])).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))


    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(self.k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y

    def forward(self, graph, node_features):
        device = node_features.device
        num_relations = graph.num_relation
        num_nodes = node_features.size(0)//num_relations
    
        Q = self.query(node_features).view(num_relations, self.num_heads, num_nodes, self.out_features)     
        K = self.key(node_features).view(num_relations,  self.num_heads, num_nodes, self.out_features)      
        
        output = self.Attention(Q, K)  # [num_nodes, num_heads, 1, max_window_size]

        edge_list = self.gumbel_softmax_top_k(output, self.temperature, self.k)

        return edge_list

In [191]:
input_dim = relational_output.shape[-1]
output_dim = 128
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 3
num_nodes = relational_output.size(0)


relational_output = relational_output.to(device)
module = localRewirescorelayer(input_dim, output_dim, num_heads, window_size, k).to(device)
start = time.time()
attn_output = module(graph, relational_output)
end = time.time()

In [194]:
print(attn_output.shape)
print(end-start)
indices = torch.nonzero(attn_output, as_tuple=True)
print(indices)

torch.Size([40, 60, 10, 30])
0.05623292922973633
(tensor([ 0,  0,  0,  ..., 39, 39, 39]), tensor([ 0,  0,  0,  ..., 59, 59, 59]), tensor([0, 0, 0,  ..., 9, 9, 9]), tensor([11, 16, 17,  ...,  3, 11, 13]))


In [113]:

class test(nn.Module):
    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5, dropout=0.1):
        super(test, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.k = k
        
        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.value = nn.Linear(in_features, out_features * num_heads)
        self.scale = torch.sqrt(torch.FloatTensor([out_features])).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # get the start and end indices for each window of nodes
    def get_start_end(self, current, graph):
        segment = graph.num_nodes.repeat(graph.num_relation)
        index = torch.cumsum(segment, dim=0)
        
        # Use torch.searchsorted to find the appropriate segment
        pos = torch.searchsorted(index, current, right=True)

        if pos == 0:
            return (0, index[0].item())
        elif pos >= len(index):
            return (index[-1].item(), index[-1].item())
        else:
            return (index[pos-1].item(), index[pos].item())

    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(self.k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y

    def forward(self, graph, node_features):
        device = node_features.device
        num_nodes = node_features.size(0)
    
        # calculate the start and end indices for each node
        half_window = self.window_size // 2
        start_end_indices = [self.get_start_end(i, graph) for i in range(num_nodes)]
        start_indices = torch.tensor([max(start_end_indices[i][0], i - half_window) for i in range(num_nodes)], device=device)
        end_indices = torch.tensor([min(start_end_indices[i][1], i + half_window) for i in range(num_nodes)], device=device)
        window_sizes = end_indices - start_indices

        Q = self.query(node_features).view(self.num_heads, num_nodes, self.out_features)     # [num_heads, num_nodes, out_features]
        K = self.key(node_features).view(self.num_heads, num_nodes, self.out_features)       # [num_heads, num_nodes, out_features]
        
        output = torch.zeros(num_nodes, num_nodes, device=device)
        all_scores = torch.zeros(num_nodes, self.num_heads, self.window_size, device=device)
        
        # construct the windowed K matrix
        K_windows = torch.zeros((num_nodes, self.window_size, self.num_heads, self.out_features), device=device)
        for i in range(num_nodes):
            start = start_indices[i]
            end = end_indices[i]
            K_windows[i, :end-start] = K[:, start:end]
            


        Q_expanded = Q.unsqueeze(2)  # [num_nodes, num_heads, 1, out_features]
        K_expanded = K_windows.permute(0, 2, 1, 3)  # [num_nodes, num_heads, max_window_size, out_features]


        all_scores = torch.matmul(Q_expanded, K_expanded.transpose(-1, -2)) / self.scale  # [num_nodes, num_heads, 1, max_window_size]
        all_scores = all_scores.squeeze(2)  # [num_nodes, num_heads, max_window_size]
        

        mask = (torch.arange(self.window_size, device=device).expand(num_nodes, self.window_size) < window_sizes.unsqueeze(1))


        all_scores = all_scores.masked_fill(~mask.unsqueeze(1), float('-inf'))  # [num_nodes, num_heads, max_window_size]
        attention_weights = F.softmax(all_scores / self.temperature, dim=-1)  # [num_nodes, num_heads, max_window_size]
        attention_weights = attention_weights.mean(dim=1) 
        
        # sample edges
        for i in range(num_nodes):
            start = start_indices[i]
            end = end_indices[i]
            output[i, start:end] = attention_weights[i, :end-start]

        edge_list = self.gumbel_softmax_top_k(output, self.temperature, self.k)

        return edge_list

In [117]:
input_dim = relational_output.shape[-1]
output_dim = 128
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 3
num_nodes = relational_output.size(0)


relational_output = relational_output.to(device)
module = test(input_dim, output_dim, num_heads, window_size, k).to(device)
attn_output = module(graph, relational_output)
print(attn_output.shape)

torch.Size([3000, 3000])


In [122]:
input= torch.randn(56, 600, 512)

segment = graph.num_nodes
index = torch.cumsum(segment, dim=0)
window_size = 10




tensor([185, 600])


In [4]:
import torch
import time

def split_tensor(tensor, index, window_size):
    result = []
    index_list  = []
    # 初始的开始索引
    start = 0

    for idx in index:
        end = start + idx - 1
        while start < end:
            if start + window_size <= end:
                result.append(tensor[:, start:start+window_size, :])
                index_list.append([start, start+window_size])
                start += window_size
            else:
                # 计算还需要填充多少行
                padding_rows = window_size - (end - start + 1)
                restart = start - padding_rows
                result.append(tensor[:, restart:restart + window_size , :])
                index_list.append([restart, restart+window_size])
                start = end + 1
                

    # 转换结果列表为 tensor
    result_tensor = torch.stack(result, dim=1)
    return result_tensor, index_list

In [26]:
# 示例输入
tensor = torch.randn(56, 600, 512, requires_grad=True)
index = [185, 415]
window_size = 10

start_time = time.time()
result, index_list = split_tensor(tensor, index, window_size)
end_time = time.time()

# 打印运行时间
print(f"运行时间: {end_time - start_time:.6f} 秒")
print(result.shape)  # 应输出 (56, n, 10, 512)

# 打印索引列表
print(index_list)
print(len(index_list))

运行时间: 0.014563 秒
torch.Size([56, 61, 10, 512])
[[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80], [80, 90], [90, 100], [100, 110], [110, 120], [120, 130], [130, 140], [140, 150], [150, 160], [160, 170], [170, 180], [175, 185], [185, 195], [195, 205], [205, 215], [215, 225], [225, 235], [235, 245], [245, 255], [255, 265], [265, 275], [275, 285], [285, 295], [295, 305], [305, 315], [315, 325], [325, 335], [335, 345], [345, 355], [355, 365], [365, 375], [375, 385], [385, 395], [395, 405], [405, 415], [415, 425], [425, 435], [435, 445], [445, 455], [455, 465], [465, 475], [475, 485], [485, 495], [495, 505], [505, 515], [515, 525], [525, 535], [535, 545], [545, 555], [555, 565], [565, 575], [575, 585], [585, 595], [590, 600]]
61


In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
atten_output = torch.randn(56, 61, 10, 10, device=device)
output = torch.zeros(56, 600, 600, device=device)


In [35]:

start_time = time.time()
for i, index in enumerate(index_list):
    start, end = index
    output[:, start:end, start:end] = atten_output[:, i, :, :]
end_time = time.time()
output = output.view(8, 4200, 4200)
# 示例输出
print(f"运行时间: {end_time - start_time:.6f} 秒")
print(output.shape)
print(output)



RuntimeError: shape '[8, 4200, 4200]' is invalid for input of size 20160000