Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

评估每个超边的重要性的代码 #8

Closed
13121283123 opened this issue Apr 8, 2021 · 4 comments
Closed

评估每个超边的重要性的代码 #8

13121283123 opened this issue Apr 8, 2021 · 4 comments

Comments

@13121283123
Copy link

对您的这篇论文非常感兴趣,想问您一下,您在定义layers1的时候 只用到了BatchedGraphSAGEDynamicRangeMean1 这个类,其他剩下的类您注释了BatchedGraphSAGEDynamicMean1,BatchedGraphSAGEMean1,BatchedGraphSAGEMean1Temporal,BatchedGAT_cat1 请问这几个类有用吗,还有想请问您一下 ,
for i in range(int(N/p)):
idx_start = max(0, i-t)
idx_end = min(i+t+1, int(N/p))
tmp_x = x[:,idx_startp:idx_endp,]
dis = NearestConvolution.cos_dis(tmp_x) 这个是计算特征之间的相似性
if i==0:
tk = min(dis.shape[2], self.kn)
#print(tk)
_, idx = torch.topk(dis, tk, dim=2) 是包含前K个近邻的邻居集
k_nearest = torch.stack([torch.stack([tmp_x[j, idx[j, i]] for i in range(p*(idx_end-idx_start))], dim=0) for j in range(b)], dim=0) #(b, xp, kn, d)
#print(k_nearest)
k_nearest_list.append(k_nearest[:,p
(i-idx_start):p*(i-idx_start+1),])
k_nearest = torch.cat(k_nearest_list, dim=1) #(b,N, kn, d)
x_neib = k_nearest[:,:,1:,].contiguous() 我们将除节点 v i 以外的超边中的所有节点特征进行平均操作,作为超边的特征.
x_neib = x_neib.mean(dim=2)
h_k = torch.cat((self.W_x(x), self.W_neib(x_neib)), 2)

    h_k = F.normalize(h_k, dim=2, p=2)
    h_k = F.relu(h_k)
    #print(h_k.shape)
    if self.use_bn:
        #self.bn = nn.BatchNorm1d(h_k.size(1))
        h_k = self.bn(h_k.permute(0,2,1).contiguous())
        #print(h_k.shape)
        h_k = h_k.permute(0, 2, 1)
        #print(h_k.shape)

    return h_k

请问评估每个超边的重要性的代码 在哪里啊 谢谢您的回复

@daodaofr
Copy link
Owner

daodaofr commented Apr 8, 2021

其他几个类是我们尝试过的不同的特征传播的方法,应该效果比现在用的稍微差点。
这里就是简单进行了节点的平均作为超边特征,我们发现attention对结果影响较小,这里就直接平均化了。

@13121283123
Copy link
Author

非常感谢您的回复,还想请教您一个问题, 当调用self.layers1的时候会跳到BatchedGraphSAGEDynamicRangeMean1类的forward方法,但是我看到这个forward方法中没有用到传入进来的adj邻阶矩阵,这个怎么解释啊, 大佬, 谢谢
class BatchedGraphSAGEDynamicRangeMean1(nn.Module):
def init(self, infeat, outfeat, use_bn=True, mean=False, add_self=False):
super(BatchedGraphSAGEDynamicRangeMean1, self).init()
self.add_self = add_self
self.use_bn = use_bn
self.mean = mean
self.aggregator = True
self.W_x = nn.Linear(infeat, outfeat, bias=True)
nn.init.xavier_uniform_(self.W_x.weight, gain=nn.init.calculate_gain('relu'))
self.W_neib = nn.Linear(infeat, outfeat, bias=True)
nn.init.xavier_uniform_(self.W_neib.weight, gain=nn.init.calculate_gain('relu'))
if self.use_bn:
self.bn = nn.BatchNorm1d(2*outfeat)
#self.bn = nn.BatchNorm1d(16)
self.kn = 3

def forward(self, x, adj, p, t):
b = x.size()[0]
N = x.size()[1]
k_nearest_list = []
tk = self.kn #tk=3
for i in range(int(N/p)):
idx_start = max(0, i-t)
idx_end = min(i+t+1, int(N/p))
tmp_x = x[:,idx_startp:idx_endp,]
dis = NearestConvolution.cos_dis(tmp_x)
if i==0:
tk = min(dis.shape[2], self.kn)
#print(tk)
_, idx = torch.topk(dis, tk, dim=2)

    k_nearest = torch.stack([torch.stack([tmp_x[j, idx[j, i]] for i in range(p*(idx_end-idx_start))], dim=0) for j in range(b)], dim=0) 
    k_nearest_list.append(k_nearest[:,p*(i-idx_start):p*(i-idx_start+1),])
k_nearest = torch.cat(k_nearest_list, dim=1) #(b,N, kn, d)
x_neib = k_nearest[:,:,1:,].contiguous()
#x_neib = x_neib.view(x.size(0), x.size(1), -1, x_neib.size(2))
x_neib = x_neib.mean(dim=2)
#print(k_nearest.shape)
#x_cmp = x - k_nearest[:,:,0]
#print(torch.sum(x_cmp)) 
h_k = torch.cat((self.W_x(x), self.W_neib(x_neib)), 2)
h_k = F.normalize(h_k, dim=2, p=2)
h_k = F.relu(h_k)
if self.use_bn:
    #self.bn = nn.BatchNorm1d(h_k.size(1))
    h_k = self.bn(h_k.permute(0,2,1).contiguous())
    #print(h_k.shape)
    h_k = h_k.permute(0, 2, 1)
    #print(h_k.shape)
return h_k

@daodaofr
Copy link
Owner

daodaofr commented Apr 8, 2021

这里是用p,t 这两个参数来控制他的邻居的选取,就没用邻接矩阵

@13121283123
Copy link
Author

13121283123 commented Apr 8, 2021 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants