From 63a8dba99e4853b9be5d3e6c14909a30685c7532 Mon Sep 17 00:00:00 2001 From: Optimox Date: Thu, 9 Feb 2023 12:09:43 +0100 Subject: [PATCH] fix: 424 allow any np.intX as training target --- pytorch_tabnet/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index 93e360a4..6302378c 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -422,8 +422,8 @@ def define_device(device_name): class ComplexEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, np.int64): - return int(obj) + if isinstance(obj, (np.generic, np.ndarray)): + return obj.tolist() # Let the base class default method raise the TypeError return json.JSONEncoder.default(self, obj)