
labels should be int in train.py, line 45 #1

Closed · wanyao1992 opened this issue Mar 2, 2018 · 4 comments

Comments

@wanyao1992

Dear authors,

Thank you for sharing your code. When I run it, I encounter the following problem.

Traceback (most recent call last):
  File "train.py", line 45, in <module>
    model = GAT(nfeat=features.shape[1], nhid=args.hidden, nclass=labels.max() + 1, dropout=args.dropout, nheads=args.nb_heads, alpha=args.alpha)
  File "/home/wanyao/www/Dropbox/ghproj/pyGAT/models.py", line 16, in __init__
    self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
  File "/home/wanyao/www/Dropbox/ghproj/pyGAT/layers.py", line 22, in __init__
    self.a = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
TypeError: new() received an invalid combination of arguments - got (Variable, int), but expected one of:
  • (int device)
  • (tuple of ints size, int device)
    didn't match because some of the arguments have invalid types: (Variable, int)
  • (torch.Storage storage)
  • (Tensor other)
  • (object data, int device)
    didn't match because some of the arguments have invalid types: (Variable, int)

The cause is that in line 45 of train.py, the value passed as nclass should be a Python int, but labels.max() returns a torch.LongTensor:
model = GAT(nfeat=features.shape[1], nhid=args.hidden, nclass=labels.max() + 1, dropout=args.dropout, nheads=args.nb_heads, alpha=args.alpha)

After I changed "labels.max()" to "int(labels.max())", the problem was solved.
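For anyone hitting the same error, a minimal sketch of the cast (labels here is a made-up stand-in, not the actual tensor loaded by the repo's data loader):

```python
import torch

# Hypothetical label tensor; in pyGAT, labels comes from the data loader.
labels = torch.LongTensor([0, 2, 1, 2])

# On PyTorch 0.4+, labels.max() is a 0-dim tensor, not a Python int,
# so passing labels.max() + 1 as a size/count argument raises the TypeError.
nclass = int(labels.max()) + 1  # explicit cast fixes it

# Equivalent on PyTorch 0.4+: extract the Python scalar directly.
nclass_alt = labels.max().item() + 1
```

Either form yields a plain int that torch.Tensor(...) accepts as a dimension.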

@Diego999
Owner

Diego999 commented Mar 2, 2018

Hi wanyao1992,

I'm not one of the authors, though I'd like to be ;)

For your error: which versions of PyTorch and numpy are you using? Maybe it's a matter of auto-casting.

@Diego999
Owner

Diego999 commented Mar 2, 2018

(I fixed it just in case in 9f6afbe)
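Not quoting the commit itself, but for context, a sketch of how that parameter in layers.py can be built once out_features is a plain int (out_features below is a stand-in value):

```python
import numpy as np
import torch
import torch.nn as nn

out_features = 7  # stand-in; in pyGAT this is the layer's output width

# Build the attention parameter from plain Python ints, then apply
# Xavier init in place. Passing a 0-dim tensor as a size argument is
# what triggered the TypeError in the original traceback.
a = nn.Parameter(torch.empty(2 * out_features, 1))
nn.init.xavier_uniform_(a, gain=np.sqrt(2.0))
```

nn.init.xavier_uniform_ (with the trailing underscore) is the in-place variant that replaced the deprecated nn.init.xavier_uniform from PyTorch 0.4 onward.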

@wanyao1992
Author

Hi Diego,

Actually, I use Python 3.6.4, PyTorch 0.4, and numpy 1.14.1.

@Diego999
Owner

Diego999 commented Mar 2, 2018

Hmm, I'm using Python 3.5.2 and numpy 1.13.3. I guess that's the difference. Anyway, it's fixed :)

@Diego999 Diego999 closed this as completed Mar 5, 2018
2 participants