diff --git a/configs/navit/navit_b16_384_ascend.yaml b/configs/navit/navit_b16_384_ascend.yaml index 0f6920e2..62ebb1f8 100644 --- a/configs/navit/navit_b16_384_ascend.yaml +++ b/configs/navit/navit_b16_384_ascend.yaml @@ -11,11 +11,11 @@ dataset: "imagenet" data_dir: "/path/to/imagenet" shuffle: True dataset_download: False -batch_size: 12 +batch_size: 64 drop_remainder: True patch_size: 16 -max_seq_length: 2048 -max_num_each_group: 40 +max_seq_length: 768 +max_num_each_group: 16 # augmentation image_resize: 384 @@ -29,7 +29,7 @@ drop_path_rate: 0.1 num_classes: 1000 pretrained: False ckpt_path: "" -keep_checkpoint_max: 3 +keep_checkpoint_max: 1 ckpt_save_policy: "top_k" ckpt_save_dir: "./ckpt" epoch_size: 100 diff --git a/configs/navit/navit_b16_384_ascend_control.yaml b/configs/navit/navit_b16_384_ascend_control.yaml index 597a7101..454fcf1c 100644 --- a/configs/navit/navit_b16_384_ascend_control.yaml +++ b/configs/navit/navit_b16_384_ascend_control.yaml @@ -11,7 +11,7 @@ dataset: "imagenet" data_dir: "/path/to/imagenet" shuffle: True dataset_download: False -batch_size: 64 +batch_size: 116 drop_remainder: True # augmentation @@ -26,7 +26,7 @@ drop_path_rate: 0.1 num_classes: 1000 pretrained: False ckpt_path: "" -keep_checkpoint_max: 3 +keep_checkpoint_max: 1 ckpt_save_policy: "top_k" ckpt_save_dir: "./ckpt" epoch_size: 100 diff --git a/mindcv/models/navit.py b/mindcv/models/navit.py index 5c5aab8e..eebc3c29 100644 --- a/mindcv/models/navit.py +++ b/mindcv/models/navit.py @@ -139,7 +139,8 @@ def construct(self, x, context=None, token_mask=None): token_mask = ops.unsqueeze(token_mask, 1) attn = ops.masked_fill(attn, ~token_mask, -ms.numpy.inf) - attn = ops.softmax(attn, axis=-1) + dtype = attn.dtype + attn = ops.softmax(attn.to(ms.float32), axis=-1).to(dtype) attn = self.attn_drop(attn) out = self.attn_matmul_v(attn, v) diff --git a/mindcv/utils/top_k.py b/mindcv/utils/top_k.py index 76af23cf..05ed93d3 100644 --- a/mindcv/utils/top_k.py +++ b/mindcv/utils/top_k.py @@ -1,7 +1,11 @@ import numpy as np -from mindspore.train.metrics.metric import _check_onehot_data, rearrange_inputs -from mindspore.train.metrics.topk import TopKCategoricalAccuracy +try: + from mindspore.train.metrics.metric import rearrange_inputs + from mindspore.train.metrics.topk import TopKCategoricalAccuracy +except ImportError: # MS Version < 2.0 + from mindspore.nn.metrics.metric import rearrange_inputs + from mindspore.nn.metrics.topk import TopKCategoricalAccuracy class TopKCategoricalAccuracyForTokenData(TopKCategoricalAccuracy): @@ -23,7 +27,7 @@ def update(self, *inputs): y_pred = y_pred[inds] y = y[inds] - if y_pred.ndim == y.ndim and _check_onehot_data(y): + if y_pred.ndim == y.ndim: y = y.argmax(axis=1) indices = np.argsort(-y_pred, axis=1)[:, : self.k] repeated_y = y.reshape(-1, 1).repeat(self.k, axis=1)