Chapter 8 > Gated Recurrent Units (GRUs) > Visualizing the Model > The Journey of a Gated Hidden State: figure22 error #31

Closed
scmanjarrez opened this issue Dec 30, 2022 · 1 comment

@scmanjarrez (Contributor)

Hi,
Figure22 is throwing the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[100], line 1
----> 1 fig = figure22(model.basic_rnn)

File ~/side_projects/brain_auth/psychopy/data_alcoholism/PyTorchStepByStep/plots/chapter8.py:880, in figure22(rnn)
    878 square = torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]).float().view(1, 4, 2)
    879 n_linear, r_linear, z_linear = disassemble_gru(rnn, layer='_l0')
--> 880 gcell, mstates, hstates, gates = generate_gru_states(n_linear, r_linear, z_linear, square)
    881 gcell(hstates[-1])
    882 titles = [r'$hidden\ state\ (h)$',
    883           r'$transformed\ state\ (t_h)$',
    884           r'$reset\ gate\ (r*t_h)$' + '\n' + r'$r=$',
   (...)
    888           r'$adding\ z*h$' + '\n' + r'h=$(1-z)*n+z*h$', 
    889          ]

File ~/side_projects/brain_auth/psychopy/data_alcoholism/PyTorchStepByStep/plots/chapter8.py:787, in generate_gru_states(n_linear, r_linear, z_linear, X)
    785     gcell = add_h(gcell, z*hidden)
    786     model_states.append(deepcopy(gcell.state_dict()))
--> 787     hidden = gcell(hidden)
    789 return gcell, model_states, hidden_states, {'rmult': rs, 'zmult': zs}

File ~/side_projects/brain_auth/psychopy/data_alcoholism/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/side_projects/brain_auth/psychopy/data_alcoholism/venv/lib/python3.8/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
    202 def forward(self, input):
    203     for module in self:
--> 204         input = module(input)
    205     return input

File ~/side_projects/brain_auth/psychopy/data_alcoholism/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/side_projects/brain_auth/psychopy/data_alcoholism/venv/lib/python3.8/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: expand(torch.FloatTensor{[1, 1, 2]}, size=[1, 2]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

I have tested both locally and on Colab, and the same error happens in both.

@dvgodoy (Owner) commented Dec 31, 2022

Hi @scmanjarrez

Thank you for pointing out this issue.
I've just pushed a fix for it. It was a problem with the shape of the tensor representing the bias in one of the layers of the manually assembled GRU cell.
It turns out that older versions of PyTorch (like the former LTS version, 1.8.2) were more "forgiving", if you will, but the latest one certainly isn't.

In PyTorch 1.8, we could do this:

>>> lin = nn.Linear(2, 2)
>>> lin.weight = nn.Parameter(torch.eye(2))
>>> lin.bias = nn.Parameter(torch.ones(1, 1, 2))
>>> x = torch.ones(1, 1, 2)
>>> lin(x)
tensor([[[2., 2.]]], grad_fn=<AddBackward0>)

But in PyTorch 1.13 (which Colab is already using), the same code raises the exception you saw:

>>> lin = nn.Linear(2, 2)
>>> lin.weight = nn.Parameter(torch.eye(2))
>>> lin.bias = nn.Parameter(torch.ones(1, 1, 2))
>>> x = torch.ones(1, 1, 2)
>>> lin(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/dvgodoy/anaconda3/envs/pyt13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dvgodoy/anaconda3/envs/pyt13/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expand(torch.FloatTensor{[1, 1, 2]}, size=[1, 2]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)
>>> lin.bias = nn.Parameter(torch.ones(1, 2))
>>> lin(x)
tensor([[[2., 2.]]], grad_fn=<ViewBackward0>)

In the manually assembled GRU cell, the hidden state is multiplied by z and added to the output (and I used the bias of the linear layer to do that manually). The hidden state is a sequence of length one, following the (N=1, L=1, F=2) shape, so I simply used the first (and only) element of the sequence instead, resulting in the required (1, 2) shape for the bias.
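For illustration, here is a minimal sketch of that shape fix; the variable names below are stand-ins, not the actual helper code in plots/chapter8.py:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the cell's linear layer and the GRU's
# hidden state (shape N=1, L=1, F=2) -- not the repository's own code.
lin = nn.Linear(2, 2)
lin.weight = nn.Parameter(torch.eye(2))

hidden = torch.ones(1, 1, 2)    # hidden state as a length-one sequence
z = torch.full((1, 1, 2), 0.5)  # update gate output, same shape

# Old (breaks on recent PyTorch): bias keeps the (1, 1, 2) sequence shape
# lin.bias = nn.Parameter(z * hidden)

# Fix: take the first (and only) element of the sequence -> shape (1, 2)
lin.bias = nn.Parameter((z * hidden)[0])

x = torch.ones(1, 1, 2)
print(lin(x))  # the (1, 2) bias broadcasts over the (1, 1, 2) output

The slicing leaves the cell's forward pass unchanged, since a (1, 2) bias still broadcasts across the (N=1, L=1, F=2) output.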

Best,
Daniel

dvgodoy closed this as completed on Dec 31, 2022