Commit: add token drop

zhtmike committed Feb 29, 2024
1 parent 3c1b0cf commit d398539
Showing 4 changed files with 88 additions and 19 deletions.
2 changes: 1 addition & 1 deletion config.py
@@ -128,7 +128,7 @@ def create_parser():
'Otherwise, repeated augmentation is enabled and the common choice is 3 (default=0)')
group.add_argument('--patch_size', type=int, default=32, help="Patch size in sequence packing.")
group.add_argument('--max_seq_length', type=int, default=2048, help="maximum sequence length in sequence packing.")
group.add_argument('--max_num_each_group', type=int, default=32, help="maximum number of images in each sequence")
group.add_argument('--max_num_each_group', type=int, default=40, help="maximum number of images in each sequence")

# Model parameters
group = parser.add_argument_group('Model parameters')
2 changes: 1 addition & 1 deletion configs/navit/navit_b16_384_ascend.yaml
@@ -15,7 +15,7 @@ batch_size: 12
drop_remainder: True
patch_size: 16
max_seq_length: 2048
max_num_each_group: 32
max_num_each_group: 40

# augmentation
image_resize: 384
101 changes: 85 additions & 16 deletions mindcv/data/token_data.py
@@ -2,9 +2,11 @@
import logging
import os
import random
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy.stats
import tqdm
from PIL import Image

from mindspore.dataset import vision
@@ -33,12 +35,27 @@ def create_img_transform(is_train: bool = True):
return transform


def _cal_size_train(w: int, h: int, max_size: int, min_size: int) -> Tuple[int, int]:
def _sample_dropout_rate(
s: int, dmin: float = 0.1, dmax: float = 0.9, smin: int = 16, smax: int = 576
) -> Union[float, None]:
# A.4 resolution-dependent dropping rate
def mapping(x): # linear [smin, smax] -> [dmin, dmax]
return (dmax - dmin) / (smax - smin) * x + (smax * dmin - smin * dmax) / (smax - smin)

u = mapping(s)
d = np.random.normal(u, 0.02)

if d < u - 0.04 or d > u + 0.04:
d = None
return d
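
# Worked example for the mapping above (illustrative numbers, not part of the
# commit): with the defaults dmin=0.1, dmax=0.9, smin=16, smax=576 the mean
# drop rate grows linearly with the packed sequence length s:
#   s = 16  (a 64x64 image at patch size 16)   -> mean rate 0.1
#   s = 296                                    -> mean rate 0.5
#   s = 576 (a 384x384 image at patch size 16) -> mean rate 0.9
# A rate is then drawn from N(mean, std=0.02) and rejected (None is returned)
# whenever it falls outside mean +/- 0.04, i.e. more than two standard
# deviations away.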


def _cal_size_train(w: int, h: int, max_size: int, min_size: int, u: Optional[float] = None) -> Tuple[int, int]:
def mapping(x): # linear [-1, 1] -> [min, max]
return (max_size - min_size) / 2 * x + (max_size + min_size) / 2

u = np.random.normal(-0.5, 1)
u = np.clip(u, -1, 1)
if u is None:
u = scipy.stats.truncnorm.rvs(-0.5, 1.5, -0.5, 1) # truncated normal (-0.5, 1) -> [-1, 1]
target = mapping(u)
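
# Note on the sampling above (illustration, not part of the commit):
# scipy.stats.truncnorm.rvs(-0.5, 1.5, -0.5, 1) draws from a normal with
# mean -0.5 and std 1, truncated to [loc + a*scale, loc + b*scale] = [-1, 1].
# mapping() then sends [-1, 1] linearly onto [min_size, max_size]; e.g. with
# min_size=64, max_size=384: u=-1 -> 64, u=-0.5 -> 144, u=1 -> 384, so the
# sampled target side is biased toward smaller resolutions.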

# sample a random side to [min_size, max_size],
@@ -76,12 +93,17 @@ def _cal_size_infer(w: int, h: int, max_size: int, min_size: int) -> Tuple[int,


def _cal_size(
shape: Tuple[int, int], patch_size: int = 32, max_size: int = 384, min_size: int = 64, is_train: bool = False
shape: Tuple[int, int],
patch_size: int = 32,
max_size: int = 384,
min_size: int = 64,
is_train: bool = False,
u: Optional[float] = None,
):
w, h = shape

if is_train: # resolution sampling
new_w, new_h = _cal_size_train(w, h, max_size, min_size)
new_w, new_h = _cal_size_train(w, h, max_size, min_size, u=u)
else: # keep in range without change aspect ratio
new_w, new_h = _cal_size_infer(w, h, max_size, min_size)

@@ -102,14 +124,23 @@ def __init__(
interpolation: str = "bilinear",
image_resize: int = 384,
image_resize_min: int = 64,
max_num_each_group: int = 32,
max_num_each_group: int = 40,
token_dropout_min: float = 0.1,
token_dropout_max: float = 0.9,
apply_token_drop_prob: float = 0.8,
) -> None:
self.is_train = split == "train"
self.patch_size = patch_size
self.max_seq_length = max_seq_length
self.image_resize = image_resize
self.image_resize_min = image_resize_min
self.max_num_each_group = max_num_each_group
self.token_dropout_min = token_dropout_min
self.token_dropout_max = token_dropout_max
self.apply_token_drop_prob = apply_token_drop_prob

self.max_length = image_resize * image_resize // patch_size // patch_size
self.min_length = image_resize_min * image_resize_min // patch_size // patch_size
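# With the navit_b16_384 config values (image_resize=384, image_resize_min=64,
# patch_size=16) this evaluates to max_length = 384*384 // 16 // 16 = 576 and
# min_length = 64*64 // 16 // 16 = 16, matching the smax/smin defaults of
# _sample_dropout_rate.  (Illustrative arithmetic, not part of the commit.)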

self.transform = create_img_transform(is_train=self.is_train)
if interpolation == "bilinear":
@@ -125,26 +156,27 @@ def __init__(
# and target image size (which is divisible by patch size).
self.images_info = self._inspect_images(root, enable_cache=enable_cache, cache_path=cache_path)

# step 2: group the images by the maximum sequence length
# here we use greedy method for simplicty, and meanwhile we do the resolution sampling in the same time
# step 2: group the images by the maximum sequence length; here we use a greedy method for simplicity,
# and meanwhile we do the resolution sampling / token dropout sampling at the same time
self.update_groups()

# step 3 (optional): form the label mapping for train/val, for imagenet format
label = sorted(list(set([x["label"] for x in self.images_info])))
self.label_mapping = dict(zip(label, range(len(label))))

def update_groups(self):
_logger.info("Packing groups, it may take some time...")
# call this once after each epoch allowing different resolution sampling
self.images_group = self._group_by_max_seq_length(self.images_info)
max_num_each_group_real = max([len(x) for x in self.images_group])
if max_num_each_group_real > self.max_num_each_group:
_logger.warning(
f"The Maximum number of images in the grouping ({max_num_each_group_real}) "
f"The maximum number of images in the group ({max_num_each_group_real}) "
f"is higher than the allowed value ({self.max_num_each_group}), "
"you may need to adjust the `max_num_each_group` in the model and dataloader."
"you may need to adjust the `max_num_each_group` in the configuration."
)

_logger.info(f"Packing group is updated. The total number of groups now is {len(self.images_group)}.")
_logger.info(f"Group is updated. The total number of groups now is {len(self.images_group)}.")

def __len__(self):
return len(self.images_group)
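
# Minimal usage sketch for update_groups (hypothetical names, not part of the
# commit): because resolutions and dropout rates are sampled during packing,
# re-packing between epochs gives each epoch fresh samples.
#
#   for epoch in range(num_epochs):        # `num_epochs`, `train_one_epoch`,
#       train_one_epoch(model, dataset)    # and `dataset` are placeholders
#       dataset.update_groups()            # re-pack before the next epoch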
@@ -174,18 +206,23 @@ def __getitem__(self, index):
"Perhaps the cache is corrupted, please remove the cache first and rerun again.",
)

# step 1: resize to the shape which is divisible by patch size
# step 1: image resolution sampling
img = img.resize(img_info["target_shape"], resample=self.resample)

# step 2: normalization and other image-wise transforms
img = self.transform(img)[0]

# step 3: patchify
img_patch, pos = self._patchify(img)

# step 4: token dropout
if self.is_train:
img_patch, pos = self._token_dropout(img_patch, pos, p=img_info["token_dropout"])

img_patch_seq.append(img_patch)
pos_seq.append(pos)

# step 4 (optinal): add label
# step 5: add label
label = self.label_mapping[img_info["label"]]
label_seq.append(label)

@@ -243,18 +280,39 @@ def _group_by_max_seq_length(self, images_info: List[Dict[str, Any]]) -> List[Li
if self.is_train:
random.shuffle(images_info)

for image_info in images_info:
# resolution sampling values, sampled here up front for speed
u = scipy.stats.truncnorm.rvs(-0.5, 1.5, -0.5, 1, size=len(images_info))

for i, image_info in tqdm.tqdm(
enumerate(images_info), desc="packing group", total=len(images_info), miniters=len(images_info) // 10
):
w, h = _cal_size(
image_info["shape"],
patch_size=self.patch_size,
max_size=self.image_resize,
min_size=self.image_resize_min,
is_train=self.is_train,
u=u[i],
)
nw, nh = w // self.patch_size, h // self.patch_size
img_seq_len = nw * nh

image_info["target_shape"] = (w, h)
if self.is_train:
if random.random() < self.apply_token_drop_prob:
image_info["token_dropout"] = _sample_dropout_rate(
img_seq_len,
dmin=self.token_dropout_min,
dmax=self.token_dropout_max,
smin=self.min_length,
smax=self.max_length,
)
if image_info["token_dropout"] is None:
continue # reject sample
else:
image_info["token_dropout"] = 0.0

img_seq_len = max(round(img_seq_len * (1 - image_info["token_dropout"])), 1)
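# e.g. a 384x384 crop at patch size 16 has img_seq_len = 576; with a sampled
# dropout rate of 0.9 only max(round(576 * (1 - 0.9)), 1) = 58 patches are
# budgeted for this image when packing.  (Illustrative numbers, not part of
# the commit.)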

if img_seq_len > self.max_seq_length:
_logger.warning(
Expand All @@ -273,6 +331,7 @@ def _group_by_max_seq_length(self, images_info: List[Dict[str, Any]]) -> List[Li

# the last
groups.append(group)
print("") # make log clear
return groups

def _patchify(self, img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -286,5 +345,15 @@ def _patchify(self, img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
# get 2d absolute position
pos = np.meshgrid(np.arange(nh), np.arange(nw), indexing="ij")
pos = np.stack(pos, axis=-1)
pos = np.reshape(pos, (-1, 2))
pos = np.reshape(pos, (-1, 2)) # nh * nw, 2
return img, pos
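
# Position-grid sketch (illustration only, not part of the commit): for nh=2,
# nw=3 the three steps above give
#   np.meshgrid(np.arange(2), np.arange(3), indexing="ij") -> two (2, 3) arrays
#   np.stack(..., axis=-1)                                  -> shape (2, 3, 2)
#   np.reshape(..., (-1, 2)) -> [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2]]
# i.e. one (row, col) patch index per token in row-major order.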

def _token_dropout(self, img_patch: np.ndarray, pos: np.ndarray, p: float = 0) -> Tuple[np.ndarray, np.ndarray]:
if p == 0:
return img_patch, pos
seq_len = img_patch.shape[0]
num_keep = max(round(seq_len * (1 - p)), 1)
inds = np.random.permutation(np.arange(seq_len))[:num_keep]
img_patch = img_patch[inds]
pos = pos[inds]
return img_patch, pos
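
A quick standalone check of the dropout step above; this is a sketch on dummy arrays, and the shapes and variable names are illustrative rather than taken from the repository:

import numpy as np

seq_len, p = 576, 0.5
img_patch = np.zeros((seq_len, 16 * 16 * 3))   # dummy flattened patches
pos = np.zeros((seq_len, 2), dtype=np.int64)   # dummy (row, col) indices
num_keep = max(round(seq_len * (1 - p)), 1)    # 288; at least 1 patch is kept
inds = np.random.permutation(np.arange(seq_len))[:num_keep]
print(img_patch[inds].shape, pos[inds].shape)  # (288, 768) (288, 2)

Dropping is applied per image inside __getitem__ (step 4), so the packed sequence length seen by the model shrinks accordingly.
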
2 changes: 1 addition & 1 deletion mindcv/models/navit.py
@@ -235,7 +235,7 @@ def __init__(
block_fn: Callable = Block,
num_classes: int = 1000,
pool_type: str = "attn",
max_num_each_group: int = 32,
max_num_each_group: int = 40,
):
super().__init__()

