Skip to content

Commit

Permalink
update params and fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
zhtmike committed Mar 1, 2024
1 parent d398539 commit 8bb78df
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
8 changes: 4 additions & 4 deletions configs/navit/navit_b16_384_ascend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 12
batch_size: 64
drop_remainder: True
patch_size: 16
max_seq_length: 2048
max_num_each_group: 40
max_seq_length: 768
max_num_each_group: 16

# augmentation
image_resize: 384
Expand All @@ -29,7 +29,7 @@ drop_path_rate: 0.1
num_classes: 1000
pretrained: False
ckpt_path: ""
keep_checkpoint_max: 3
keep_checkpoint_max: 1
ckpt_save_policy: "top_k"
ckpt_save_dir: "./ckpt"
epoch_size: 100
Expand Down
4 changes: 2 additions & 2 deletions configs/navit/navit_b16_384_ascend_control.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 64
batch_size: 116
drop_remainder: True

# augmentation
Expand All @@ -26,7 +26,7 @@ drop_path_rate: 0.1
num_classes: 1000
pretrained: False
ckpt_path: ""
keep_checkpoint_max: 3
keep_checkpoint_max: 1
ckpt_save_policy: "top_k"
ckpt_save_dir: "./ckpt"
epoch_size: 100
Expand Down
3 changes: 2 additions & 1 deletion mindcv/models/navit.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def construct(self, x, context=None, token_mask=None):
token_mask = ops.unsqueeze(token_mask, 1)
attn = ops.masked_fill(attn, ~token_mask, -ms.numpy.inf)

attn = ops.softmax(attn, axis=-1)
dtype = attn.dtype
attn = ops.softmax(attn.to(ms.float32), axis=-1).to(dtype)
attn = self.attn_drop(attn)

out = self.attn_matmul_v(attn, v)
Expand Down
10 changes: 7 additions & 3 deletions mindcv/utils/top_k.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np

from mindspore.train.metrics.metric import _check_onehot_data, rearrange_inputs
from mindspore.train.metrics.topk import TopKCategoricalAccuracy
try:
from mindspore.train.metrics.metric import rearrange_inputs
from mindspore.train.metrics.topk import TopKCategoricalAccuracy
except ImportError: # MS Version < 2.0
from mindspore.nn.metrics.metric import rearrange_inputs
from mindspore.nn.metrics.topk import TopKCategoricalAccuracy


class TopKCategoricalAccuracyForTokenData(TopKCategoricalAccuracy):
Expand All @@ -23,7 +27,7 @@ def update(self, *inputs):
y_pred = y_pred[inds]
y = y[inds]

if y_pred.ndim == y.ndim and _check_onehot_data(y):
if y_pred.ndim == y.ndim:
y = y.argmax(axis=1)
indices = np.argsort(-y_pred, axis=1)[:, : self.k]
repeated_y = y.reshape(-1, 1).repeat(self.k, axis=1)
Expand Down

0 comments on commit 8bb78df

Please sign in to comment.