Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

移植loss问题 #40

Closed
chapmancpp opened this issue Jun 14, 2018 · 7 comments
Closed

移植loss问题 #40

chapmancpp opened this issue Jun 14, 2018 · 7 comments

Comments

@chapmancpp
Copy link

大佬好!
我把代码移植到python3.6+pytorch0.4下面。在训练的时候,loss.py里面报错。
当一个batch块里面出现相同的ID的图片的时候这边is_pos 计算的就不是对角矩阵了。
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
然后导致后面代码:
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos_test].contiguous().view(N, -1), 1, keepdim=True)#报错
错误提示:
RuntimeError: invalid argument 2: size '[16 x -1]' is invalid for input with 18 elements at ..\src\TH\THStorage.c:37

请问这里我是否可以直接定义一个batch大小的对角阵给is_pos 。还是is_pos 代码里面就可以出现非对角矩阵。
当我将is_pos 改成一直是对角矩阵的时候。只有全局loss,loss的值开始就非常的低,请问可能是什么问题啊。

@huanghoujing
Copy link
Owner

多谢你的关注!
这个is_pos的第i行表示所有样本和第i个样本之间是否是同一个id,对角线上肯定是True,如果第j, k, l个样本和第i个样本同一个id,那么第i行的第j, k, l应该是Trueis_pos一般情况下不是对角阵。这个错误我没能一眼看出来是哪里的不兼容。

@chapmancpp
Copy link
Author

N = dist_mat.size(0) #这是您N的定义
您之前定义了N的大小,为矩阵大小,
然而当队列中出现相同的ID的时候,即为非对角矩阵时,那么
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) ,
dist_mat[is_pos]的维数就会不等于N。
那么,下面很多用N定义的就会报错了,比如下面这行:
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos_test].contiguous().view(N, -1), 1, keepdim=True)#报错

dist_mat[is_pos_test].contiguous()的维数不等于N,就不能用view(N,1).您看我理解的对吗,是不是我哪边理解错了。

@huanghoujing
Copy link
Owner

举个例子,labels[1, 2, 3, 4, 2, 1, 3, 4, 4, 2, 1, 3],也即4个id,每个id有3张图片,那么is_pos应该是(为了简化,下面的1表示True, 0表示False):

1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0
0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0
0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0
1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0
0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0
1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1

对角线都是1,每一行总共有3个1,所以dist_mat[is_pos].contiguous().view(N, -1)里边的-1就相当于3dist_mat[is_pos].contiguous().view(N, -1)的结果是一个12*3的数组。

我发现上面你的代码中好像有点问题,dist_ap, relative_p_inds = torch.max( dist_mat[is_pos_test].contiguous().view(N, -1), 1, keepdim=True)这里边is_pos_test不对吧,应该是is_pos

@chapmancpp
Copy link
Author

后来发现是我对triplet loss的理解问题。哈哈,谢谢大佬。大佬在国外的吗?羡慕。

@huanghoujing
Copy link
Owner

大佬这个还是不敢当。。没在国外啊。。。

@Kang9779
Copy link

Kang9779 commented Jun 17, 2020

后来发现是我对triplet loss的理解问题。哈哈,谢谢大佬。大佬在国外的吗?羡慕。

N = dist_mat.size(0) #这是您N的定义
您之前定义了N的大小,为矩阵大小,
然而当队列中出现相同的ID的时候,即为非对角矩阵时,那么
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) ,
dist_mat[is_pos]的维数就会不等于N。
那么,下面很多用N定义的就会报错了,比如下面这行:
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos_test].contiguous().view(N, -1), 1, keepdim=True)#报错

dist_mat[is_pos_test].contiguous()的维数不等于N,就不能用view(N,1).您看我理解的对吗,是不是我哪边理解错了。

我也遇到了这个问题,你解决了吗?

@Kang9779
Copy link

举个例子,labels[1, 2, 3, 4, 2, 1, 3, 4, 4, 2, 1, 3],也即4个id,每个id有3张图片,那么is_pos应该是(为了简化,下面的1表示True, 0表示False):

1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0
0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0
0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0
1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0
0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0
0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0
1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1

对角线都是1,每一行总共有3个1,所以dist_mat[is_pos].contiguous().view(N, -1)里边的-1就相当于3dist_mat[is_pos].contiguous().view(N, -1)的结果是一个12*3的数组。

我发现上面你的代码中好像有点问题,dist_ap, relative_p_inds = torch.max( dist_mat[is_pos_test].contiguous().view(N, -1), 1, keepdim=True)这里边is_pos_test不对吧,应该是is_pos

要是labels = [1, 2, 3, 4, 2, 1, 3, 4, 4, 2, 1, 1]这样的话,不就没法搞了吗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants