---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_34/3668832236.py in <module>
      1 #hide_output
----> 2 emotions_hidden = emotions_encoded.map(extract_hidden_states, batched=True)

/opt/conda/lib/python3.7/site-packages/datasets/dataset_dict.py in map(self, function, with_indices, input_columns, batched, batch_size, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)
    502                 desc=desc,
    503             )
--> 504             for k, dataset in self.items()
    505         }
    506     )

/opt/conda/lib/python3.7/site-packages/datasets/dataset_dict.py in <dictcomp>(.0)
    502                 desc=desc,
    503             )
--> 504             for k, dataset in self.items()
    505         }
    506     )

/opt/conda/lib/python3.7/site-packages/datasets/arrow_dataset.py in map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   2034                 new_fingerprint=new_fingerprint,
   2035                 disable_tqdm=disable_tqdm,
-> 2036                 desc=desc,
   2037             )
   2038         else:

/opt/conda/lib/python3.7/site-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    516             self: "Dataset" = kwargs.pop("self")
    517             # apply actual function
--> 518             out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    519             datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    520             for dataset in datasets:

/opt/conda/lib/python3.7/site-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    483             }
    484             # apply actual function
--> 485             out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    486             datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    487             # re-apply format to the output

/opt/conda/lib/python3.7/site-packages/datasets/fingerprint.py in wrapper(*args, **kwargs)
    409             # Call actual function
    410
--> 411             out = func(self, *args, **kwargs)
    412
    413             # Update fingerprint of in-place transforms + update in-place history of transforms

/opt/conda/lib/python3.7/site-packages/datasets/arrow_dataset.py in _map_single(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, disable_tqdm, desc, cache_only)
   2391                             indices,
   2392                             check_same_num_examples=len(input_dataset.list_indexes()) > 0,
-> 2393                             offset=offset,
   2394                         )
   2395                     except NumExamplesMismatchError:

/opt/conda/lib/python3.7/site-packages/datasets/arrow_dataset.py in apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples, offset)
   2275             if with_rank:
   2276                 additional_args += (rank,)
-> 2277             processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
   2278             if update_data is None:
   2279                 # Check if the function returns updated examples

/tmp/ipykernel_34/3098911508.py in extract_hidden_states(batch)
      5     # Extract last hidden states
      6     with torch.no_grad():
----> 7         last_hidden_state = model(**inputs).last_hidden_state
      8     # Return vector for [CLS] token
      9     return {"hidden_state": last_hidden_state[:,0].cpu().numpy()}

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    548
    549         if inputs_embeds is None:
--> 550             inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
    551         return self.transformer(
    552             x=inputs_embeds,

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids)
    128         position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
    129
--> 130         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
    131         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
    132

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    158         return F.embedding(
    159             input, self.weight, self.padding_idx, self.max_norm,
--> 160             self.norm_type, self.scale_grad_by_freq, self.sparse)
    161
    162     def extra_repr(self) -> str:

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2041         # remove once script supports set_grad_enabled
   2042         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2043     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2044
   2045

RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 15.90 GiB total capacity; 512.05 MiB already allocated; 167.75 MiB free; 530.00 MiB reserved in total by PyTorch)
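The final RuntimeError shows the forward pass running out of GPU memory while embedding a batch of tokenized examples. One common workaround is to move outputs back to the CPU as soon as each batch is processed and to pass a smaller batch_size to Dataset.map (the datasets default is 1000). The sketch below is not the notebook's exact code: it assumes the model, tokenizer, and tokenized emotions_encoded DatasetDict from the surrounding notebook are in scope, that the dataset format has been set to PyTorch tensors (set_format("torch", ...)), and the batch size of 128 is an arbitrary value chosen for illustration.

import torch

# Assumption: `model` is the DistilBERT encoder already loaded in the notebook
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def extract_hidden_states(batch):
    # The original notebook's construction of `inputs` is not shown in the
    # traceback; this is one common pattern: move only the model inputs
    # (input_ids, attention_mask) to the GPU
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    # Extract last hidden states without building a computation graph
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Return the [CLS] vector on the CPU so per-batch GPU memory can be released
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}

# Release cached blocks left over from the failed run, then map with a
# smaller batch size so each forward pass fits in the remaining GPU memory
torch.cuda.empty_cache()
emotions_hidden = emotions_encoded.map(
    extract_hidden_states, batched=True, batch_size=128
)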