From 418a7cbbd6942c3ac570246911ae10f8d87d19c1 Mon Sep 17 00:00:00 2001 From: Sourcery AI Date: Sun, 3 Apr 2022 17:49:56 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- config.py | 2 +- data/build.py | 26 ++++++++++++----- data/cached_image_folder.py | 20 ++++++++----- data/zipreader.py | 5 ++-- lr_scheduler.py | 23 ++++++--------- main.py | 13 ++++----- models/swin_mlp.py | 23 ++++++--------- models/swin_transformer.py | 33 +++++++++------------ optimizer.py | 10 ++----- utils.py | 58 ++++++++++++++++++++----------------- 10 files changed, 104 insertions(+), 109 deletions(-) diff --git a/config.py b/config.py index 2db498b5..f6f2b44f 100644 --- a/config.py +++ b/config.py @@ -196,7 +196,7 @@ def _update_config_from_file(config, cfg_file): _update_config_from_file( config, os.path.join(os.path.dirname(cfg_file), cfg) ) - print('=> merge config from {}'.format(cfg_file)) + print(f'=> merge config from {cfg_file}') config.merge_from_file(cfg_file) config.freeze() diff --git a/data/build.py b/data/build.py index 88c4e16b..e016d729 100644 --- a/data/build.py +++ b/data/build.py @@ -94,8 +94,8 @@ def build_dataset(is_train, config): if config.DATA.DATASET == 'imagenet': prefix = 'train' if is_train else 'val' if config.DATA.ZIP_MODE: - ann_file = prefix + "_map.txt" - prefix = prefix + ".zip@/" + ann_file = f'{prefix}_map.txt' + prefix += ".zip@/" dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, cache_mode=config.DATA.CACHE_MODE if is_train else 'part') else: @@ -134,17 +134,27 @@ def build_transform(is_train, config): if resize_im: if config.TEST.CROP: size = int((256 / 224) * config.DATA.IMG_SIZE) - t.append( - transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), - # to maintain same ratio w.r.t. 224 images + t.extend( + ( + transforms.Resize( + size, + interpolation=_pil_interp(config.DATA.INTERPOLATION), + ), + transforms.CenterCrop(config.DATA.IMG_SIZE), + ) ) - t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) + else: t.append( transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), interpolation=_pil_interp(config.DATA.INTERPOLATION)) ) - t.append(transforms.ToTensor()) - t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + t.extend( + ( + transforms.ToTensor(), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ) + ) + return transforms.Compose(t) diff --git a/data/cached_image_folder.py b/data/cached_image_folder.py index 7e1883b1..c2b6999d 100644 --- a/data/cached_image_folder.py +++ b/data/cached_image_folder.py @@ -56,7 +56,7 @@ def make_dataset_with_ann(ann_file, img_prefix, extensions): with open(ann_file, "r") as f: contents = f.readlines() for line_str in contents: - path_contents = [c for c in line_str.split('\t')] + path_contents = list(line_str.split('\t')) im_file_name = path_contents[0] class_index = int(path_contents[1]) @@ -102,8 +102,15 @@ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transfo extensions) if len(samples) == 0: - raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + - "Supported extensions are: " + ",".join(extensions))) + raise RuntimeError( + ( + f"Found 0 files in subfolders of: {root}" + + "\n" + + "Supported extensions are: " + ) + + ",".join(extensions) + ) + self.root = root self.loader = loader @@ -162,7 +169,7 @@ def __len__(self): return len(self.samples) def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str = f'Dataset {self.__class__.__name__}' + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) fmt_str += ' Root Location: {}\n'.format(self.root) tmp = ' Transforms (if any): ' @@ -242,10 +249,7 @@ def __getitem__(self, index): """ path, target = self.samples[index] image = self.loader(path) - if self.transform is not None: - img = self.transform(image) - else: - img = image + img = self.transform(image) if self.transform is not None else image if self.target_transform is not None: target = self.target_transform(target) diff --git a/data/zipreader.py b/data/zipreader.py index 060bc46a..04c4163f 100644 --- a/data/zipreader.py +++ b/data/zipreader.py @@ -40,7 +40,7 @@ def split_zip_style_path(path): pos_at = path.index('@') assert pos_at != -1, "character '@' is not found from the given path '%s'" % path - zip_path = path[0: pos_at] + zip_path = path[:pos_at] folder_path = path[pos_at + 1:] folder_path = str.strip(folder_path, '/') return zip_path, folder_path @@ -86,8 +86,7 @@ def list_files(path, extension=None): def read(path): zip_path, path_img = ZipReader.split_zip_style_path(path) zfile = ZipReader.get_zipfile(zip_path) - data = zfile.read(path_img) - return data + return zfile.read(path_img) @staticmethod def imread(path): diff --git a/lr_scheduler.py b/lr_scheduler.py index 4d27289b..afbe5d59 100644 --- a/lr_scheduler.py +++ b/lr_scheduler.py @@ -82,21 +82,16 @@ def __init__(self, def _get_lr(self, t): if t < self.warmup_t: - lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] - else: - t = t - self.warmup_t - total_t = self.t_initial - self.warmup_t - lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] - return lrs + return [self.warmup_lr_init + t * s for s in self.warmup_steps] + t = t - self.warmup_t + total_t = self.t_initial - self.warmup_t + return [ + v - ((v - v * self.lr_min_rate) * (t / total_t)) + for v in self.base_values + ] def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None + return self._get_lr(epoch) if self.t_in_epochs else None def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None + return None if self.t_in_epochs else self._get_lr(num_updates) diff --git a/main.py b/main.py index ef7bdeee..f5538dea 100644 --- a/main.py +++ b/main.py @@ -109,8 +109,7 @@ def main(config): max_accuracy = 0.0 if config.TRAIN.AUTO_RESUME: - resume_file = auto_resume_helper(config.OUTPUT) - if resume_file: + if resume_file := auto_resume_helper(config.OUTPUT): if config.MODEL.RESUME: logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") config.defrost() @@ -152,7 +151,7 @@ def main(config): total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logger.info('Training time {}'.format(total_time_str)) + logger.info(f'Training time {total_time_str}') def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): @@ -287,15 +286,15 @@ def validate(config, data_loader, model): def throughput(data_loader, model, logger): model.eval() - for idx, (images, _) in enumerate(data_loader): + for images, _ in data_loader: images = images.cuda(non_blocking=True) batch_size = images.shape[0] - for i in range(50): + for _ in range(50): model(images) torch.cuda.synchronize() - logger.info(f"throughput averaged with 30 times") + logger.info("throughput averaged with 30 times") tic1 = time.time() - for i in range(30): + for _ in range(30): model(images) torch.cuda.synchronize() tic2 = time.time() diff --git a/models/swin_mlp.py b/models/swin_mlp.py index 115c43cd..2438df2d 100644 --- a/models/swin_mlp.py +++ b/models/swin_mlp.py @@ -42,8 +42,11 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows + return ( + x.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(-1, window_size, window_size, C) + ) def window_reverse(windows, window_size, H, W): @@ -277,10 +280,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, def forward(self, x): for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) + x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) if self.downsample is not None: x = self.downsample(x) return x @@ -289,9 +289,7 @@ def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() + flops = sum(blk.flops() for blk in self.blocks) if self.downsample is not None: flops += self.downsample.flops() return flops @@ -322,10 +320,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None + self.norm = norm_layer(embed_dim) if norm_layer is not None else None def forward(self, x): B, C, H, W = x.shape @@ -461,7 +456,7 @@ def forward(self, x): def flops(self): flops = 0 flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): + for layer in self.layers: flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes diff --git a/models/swin_transformer.py b/models/swin_transformer.py index cfeb0f22..13a16a43 100644 --- a/models/swin_transformer.py +++ b/models/swin_transformer.py @@ -41,8 +41,11 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows + return ( + x.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(-1, window_size, window_size, C) + ) def window_reverse(windows, window_size, H, W): @@ -132,10 +135,7 @@ def forward(self, x, mask=None): nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - + attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) @@ -224,7 +224,10 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill( + attn_mask == 0, 0.0 + ) + else: attn_mask = None @@ -387,10 +390,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, def forward(self, x): for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) + x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) if self.downsample is not None: x = self.downsample(x) return x @@ -399,9 +399,7 @@ def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() + flops = sum(blk.flops() for blk in self.blocks) if self.downsample is not None: flops += self.downsample.flops() return flops @@ -432,10 +430,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None + self.norm = norm_layer(embed_dim) if norm_layer is not None else None def forward(self, x): B, C, H, W = x.shape @@ -578,7 +573,7 @@ def forward(self, x): def flops(self): flops = 0 flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): + for layer in self.layers: flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes diff --git a/optimizer.py b/optimizer.py index 3c57ce0b..f00d9879 100644 --- a/optimizer.py +++ b/optimizer.py @@ -12,10 +12,8 @@ def build_optimizer(config, model): """ Build optimizer, set weight decay of normalization to 0 by default. """ - skip = {} skip_keywords = {} - if hasattr(model, 'no_weight_decay'): - skip = model.no_weight_decay() + skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {} if hasattr(model, 'no_weight_decay_keywords'): skip_keywords = model.no_weight_decay_keywords() parameters = set_weight_decay(model, skip, skip_keywords) @@ -50,8 +48,4 @@ def set_weight_decay(model, skip_list=(), skip_keywords=()): def check_keywords_in_name(name, keywords=()): - isin = False - for keyword in keywords: - if keyword in name: - isin = True - return isin + return any(keyword in name for keyword in keywords) diff --git a/utils.py b/utils.py index 94890ca4..76641ec6 100644 --- a/utils.py +++ b/utils.py @@ -72,15 +72,14 @@ def load_pretrained(config, model, logger): L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.warning(f"Error in loading {k}, passing......") - else: - if L1 != L2: - # bicubic interpolate relative_position_bias_table if not match - S1 = int(L1 ** 0.5) - S2 = int(L2 ** 0.5) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), - mode='bicubic') - state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) + elif L1 != L2: + # bicubic interpolate relative_position_bias_table if not match + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), + mode='bicubic') + state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) # bicubic interpolate absolute_pos_embed if not match absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] @@ -92,17 +91,16 @@ def load_pretrained(config, model, logger): _, L2, C2 = absolute_pos_embed_current.size() if C1 != C1: logger.warning(f"Error in loading {k}, passing......") - else: - if L1 != L2: - S1 = int(L1 ** 0.5) - S2 = int(L2 ** 0.5) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) - absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( - absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') - absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) - absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) - state_dict[k] = absolute_pos_embed_pretrained_resized + elif L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) + absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) + absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( + absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') + absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) + absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) + state_dict[k] = absolute_pos_embed_pretrained_resized # check classifier, if not match, then re-init classifier to zero head_bias_pretrained = state_dict['head.bias'] @@ -111,7 +109,7 @@ def load_pretrained(config, model, logger): if (Nc1 != Nc2): if Nc1 == 21841 and Nc2 == 1000: logger.info("loading ImageNet-22K weight to ImageNet-1K ......") - map22kto1k_path = f'data/map22kto1k.txt' + map22kto1k_path = 'data/map22kto1k.txt' with open(map22kto1k_path) as f: map22kto1k = f.readlines() map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] @@ -122,7 +120,10 @@ def load_pretrained(config, model, logger): torch.nn.init.constant_(model.head.weight, 0.) del state_dict['head.weight'] del state_dict['head.bias'] - logger.warning(f"Error in loading classifier head, re-init classifier head to 0") + logger.warning( + "Error in loading classifier head, re-init classifier head to 0" + ) + msg = model.load_state_dict(state_dict, strict=False) logger.warning(msg) @@ -166,13 +167,16 @@ def auto_resume_helper(output_dir): checkpoints = os.listdir(output_dir) checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] print(f"All checkpoints founded in {output_dir}: {checkpoints}") - if len(checkpoints) > 0: - latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) + if checkpoints: + latest_checkpoint = max( + (os.path.join(output_dir, d) for d in checkpoints), + key=os.path.getmtime, + ) + print(f"The latest checkpoint founded: {latest_checkpoint}") - resume_file = latest_checkpoint + return latest_checkpoint else: - resume_file = None - return resume_file + return None def reduce_tensor(tensor):