You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Input In [26], in train_dalle_batch(vae, train_data, _, idx, __)
1 def train_dalle_batch(vae, train_data, _, idx, __):
2 text, image_codes, mask = train_data
----> 3 loss = dalle(text[idx, ...], image_codes[idx, ...], mask=mask[idx, ...], return_loss=True)
4 return loss
File c:\users\xx\xx\dall\venv\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() got an unexpected keyword argument 'mask'
TypeError Traceback (most recent call last)
Input In [28], in <cell line: 2>()
1 dalle_model_file = "data/rainbow_dalle.model"
2 if not os.path.exists(dalle_model_file):
----> 3 dalle, loss_history = fit(dalle, opt, None, scheduler,
4 (captions_array[train_idx, ...], all_image_codes[train_idx, ...], captions_mask[train_idx, ...]), None, 200, 256,
5 dalle_model_file, train_dalle_batch,
6 n_train_samples=len(train_idx))
8 plt.plot(loss_history)
9 else:
Input In [14], in fit(model, opt, criterion, scheduler, train_x, train_y, epochs, batch_size, model_file, trainer, n_train_samples)
14 model.train()
15 opt.zero_grad()
---> 16 loss = trainer(model, train_x, train_y, rnd_idx[batch_idx:(batch_idx + batch_size)], criterion)
17 loss.backward()
18 losses.append(loss.item())
Input In [26], in train_dalle_batch(vae, train_data, _, idx, __)
1 def train_dalle_batch(vae, train_data, _, idx, __):
2 text, image_codes, mask = train_data
----> 3 loss = dalle(text[idx, ...], image_codes[idx, ...], mask=mask[idx, ...], return_loss=True)
4 return loss
File c:\users\xx\xx\dall\venv\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() got an unexpected keyword argument 'mask'
torch 1.11.0
torch-fidelity 0.3.0
torchmetrics 0.9.1
torchvision 0.12.0
OS: Windows 10
Device: CPU (PyTorch is running on the CPU, no GPU)
The text was updated successfully, but these errors were encountered: