
Object of type 'int64' is not JSON serializable in Trainer.save_checkpoint #10299

Closed
arthurbra opened this issue Feb 20, 2021 · 4 comments · Fixed by #10632
arthurbra commented Feb 20, 2021

I am using the recent run_ner.py example script to train an NER model. I want to evaluate the performance of the model during training and use the following command for training:

python3 run_ner.py \
  --model_name_or_path bert-base-uncased \
  --dataset_name conll2003 \
  --return_entity_level_metrics \
  --output_dir conll-tmp \
  --overwrite_output_dir \
  --do_train \
  --do_eval \
  --do_predict \
  --evaluation_strategy steps \
  --logging_steps 10 \
  --eval_steps 10 \
  --load_best_model_at_end

I run the command inside the current Docker image huggingface/transformers-pytorch-gpu. However, I get the following error:

Traceback (most recent call last):
  File "run_ner.py", line 470, in <module>
    main()
  File "run_ner.py", line 404, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 983, in train
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 1062, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 1126, in _save_checkpoint
    self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer_callback.py", line 95, in save_to_json
    json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
  File "/usr/lib/python3.6/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/usr/lib/python3.6/json/encoder.py", line 201, in encode
    chunks = list(chunks)
  File "/usr/lib/python3.6/json/encoder.py", line 430, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/usr/lib/python3.6/json/encoder.py", line 404, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.6/json/encoder.py", line 325, in _iterencode_list
    yield from chunks
  File "/usr/lib/python3.6/json/encoder.py", line 404, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.6/json/encoder.py", line 437, in _iterencode
    o = _default(o)
  File "/usr/lib/python3.6/json/encoder.py", line 180, in default
    o.__class__.__name__)
TypeError: Object of type 'int64' is not JSON serializable
@antonyscerri

I ran into this problem too. It is caused by turning on an evaluation strategy, which adds metrics to the log_history of the trainer state; those metrics use numpy data types, which trips up the JSON encoder. That was the case with 4.3.3. There appear to be a bunch of changes to the trainer in the works; I haven't checked whether any of them fix this.
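For reference, the failure can be reproduced outside the Trainer with a minimal sketch (the entries below are made up; the point is that json.dumps rejects numpy scalars):

    import json

    import numpy as np

    # Entries like these end up in the trainer state's log_history when metrics
    # are computed with numpy: the step is an np.int64, not a plain int.
    log_history = [{"step": np.int64(10), "eval_loss": 0.25}]

    # The stock JSON encoder has no handler for numpy types, so this raises
    # TypeError: Object of type 'int64' is not JSON serializable
    json.dumps(log_history, indent=2, sort_keys=True)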

@antonyscerri

As a temporary workaround you can modify trainer.py at line 1260, "output = {**logs, **{"step": self.state.global_step}}", and add the following three lines after it. If the metrics are computed the same way in the latest code as in 4.3.3, something like this may also be needed going forward; otherwise, anything calling the log method will need to cast data points to native types beforehand if they are still going to be added to the trainer state.

        for k, v in output.items():
            # Cast numpy scalars (e.g. np.int64, np.float64) to native Python types
            # so the trainer state can later be serialized to JSON.
            if isinstance(v, np.generic):
                output[k] = v.item()
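If patching trainer.py is inconvenient, an equivalent approach (just a sketch, not part of the transformers API) is to serialize with a JSON encoder that knows about numpy scalars:

    import json

    import numpy as np


    class NumpyJSONEncoder(json.JSONEncoder):
        # Fall back to .item() for numpy scalar types (np.int64, np.float64, ...).
        def default(self, obj):
            if isinstance(obj, np.generic):
                return obj.item()
            return super().default(obj)


    # Succeeds where the plain encoder raises the TypeError above.
    print(json.dumps({"step": np.int64(10)}, cls=NumpyJSONEncoder))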

@sgugger
Collaborator

sgugger commented Mar 10, 2021

I confirm I can reproduce in master. Will investigate more tomorrow.

@antonyscerri

My only comment on the submitted fix is that it targets the metrics output, but it will not stop other code from putting values into the log history of the model state that later cause the same problem when the state is serialized to JSON.
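A more defensive option, sketched below (the helper name is hypothetical, not something in transformers), would be to sanitize values recursively right before the state is written out, so it would not matter who added them to the log history:

    import numpy as np


    def sanitize_for_json(obj):
        # Recursively convert numpy scalars and arrays into native Python types
        # so the resulting structure is safe to pass to json.dumps.
        if isinstance(obj, dict):
            return {k: sanitize_for_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [sanitize_for_json(v) for v in obj]
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return obj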
