Simplifying some code #188

Open · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion config.py
@@ -249,7 +249,7 @@ def _update_config_from_file(config, cfg_file):
             _update_config_from_file(
                 config, os.path.join(os.path.dirname(cfg_file), cfg)
             )
-    print('=> merge config from {}'.format(cfg_file))
+    print(f'=> merge config from {cfg_file}')
     config.merge_from_file(cfg_file)
     config.freeze()
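Note: str.format and the f-string render identically here; a quick standalone sketch (the config path below is made up for illustration):

cfg_file = 'configs/example.yaml'  # hypothetical path, illustration only
old = '=> merge config from {}'.format(cfg_file)
new = f'=> merge config from {cfg_file}'
assert old == new == '=> merge config from configs/example.yaml'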
26 changes: 18 additions & 8 deletions data/build.py
@@ -100,8 +100,8 @@ def build_dataset(is_train, config):
     if config.DATA.DATASET == 'imagenet':
         prefix = 'train' if is_train else 'val'
         if config.DATA.ZIP_MODE:
-            ann_file = prefix + "_map.txt"
-            prefix = prefix + ".zip@/"
+            ann_file = f'{prefix}_map.txt'
+            prefix += ".zip@/"
             dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                         cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
         else:
@@ -146,17 +146,27 @@ def build_transform(is_train, config):
     if resize_im:
         if config.TEST.CROP:
             size = int((256 / 224) * config.DATA.IMG_SIZE)
-            t.append(
-                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
-                # to maintain same ratio w.r.t. 224 images
-            )
-            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
+            # to maintain same ratio w.r.t. 224 images
+            t.extend(
+                (
+                    transforms.Resize(
+                        size,
+                        interpolation=_pil_interp(config.DATA.INTERPOLATION),
+                    ),
+                    transforms.CenterCrop(config.DATA.IMG_SIZE),
+                )
+            )
         else:
             t.append(
                 transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                   interpolation=_pil_interp(config.DATA.INTERPOLATION))
             )
 
-    t.append(transforms.ToTensor())
-    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+    t.extend(
+        (
+            transforms.ToTensor(),
+            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+        )
+    )
 
     return transforms.Compose(t)
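Note: list.extend with a tuple appends each element in order, so it matches the chained append calls it replaces; a minimal sketch:

a, b = [], []
a.append('resize')
a.append('center_crop')
b.extend(('resize', 'center_crop'))  # one call, same resulting list
assert a == b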
20 changes: 12 additions & 8 deletions data/cached_image_folder.py
@@ -56,7 +56,7 @@ def make_dataset_with_ann(ann_file, img_prefix, extensions):
     with open(ann_file, "r") as f:
         contents = f.readlines()
         for line_str in contents:
-            path_contents = [c for c in line_str.split('\t')]
+            path_contents = list(line_str.split('\t'))
             im_file_name = path_contents[0]
             class_index = int(path_contents[1])
 
@@ -102,8 +102,8 @@ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transfo
                                             extensions)
 
         if len(samples) == 0:
-            raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
-                                "Supported extensions are: " + ",".join(extensions)))
+            raise RuntimeError(
+                (
+                    f"Found 0 files in subfolders of: {root}"
+                    + "\n"
+                    + "Supported extensions are: "
+                )
+                + ",".join(extensions)
+            )
 
         self.root = root
         self.loader = loader
@@ -162,7 +169,7 @@ def __len__(self):
         return len(self.samples)
 
     def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+        fmt_str = f'Dataset {self.__class__.__name__}' + '\n'
         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
         fmt_str += '    Root Location: {}\n'.format(self.root)
         tmp = '    Transforms (if any): '
@@ -242,10 +249,7 @@ def __getitem__(self, index):
         """
         path, target = self.samples[index]
         image = self.loader(path)
-        if self.transform is not None:
-            img = self.transform(image)
-        else:
-            img = image
+        img = self.transform(image) if self.transform is not None else image
         if self.target_transform is not None:
             target = self.target_transform(target)
 
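Note: the conditional expression is a one-line equivalent of the removed if/else; a small sketch with a hypothetical helper:

def apply_transform(transform, image):
    # hypothetical helper mirroring the pattern used in __getitem__
    return transform(image) if transform is not None else image

assert apply_transform(None, 'img') == 'img'
assert apply_transform(str.upper, 'img') == 'IMG'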
5 changes: 2 additions & 3 deletions data/zipreader.py
@@ -40,7 +40,7 @@ def split_zip_style_path(path):
     pos_at = path.index('@')
     assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
 
-    zip_path = path[0: pos_at]
+    zip_path = path[:pos_at]
     folder_path = path[pos_at + 1:]
     folder_path = str.strip(folder_path, '/')
     return zip_path, folder_path
@@ -86,8 +86,7 @@ def list_files(path, extension=None):
     def read(path):
         zip_path, path_img = ZipReader.split_zip_style_path(path)
         zfile = ZipReader.get_zipfile(zip_path)
-        data = zfile.read(path_img)
-        return data
+        return zfile.read(path_img)
 
     @staticmethod
     def imread(path):
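Note: omitting the start index in a slice defaults to 0, so path[:pos_at] equals path[0: pos_at]; for example:

path = 'train.zip@/n01440764/img_0.jpeg'  # made-up zip-style path
pos_at = path.index('@')
assert path[:pos_at] == path[0:pos_at] == 'train.zip'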
23 changes: 9 additions & 14 deletions lr_scheduler.py
@@ -82,21 +82,16 @@ def __init__(self,
 
     def _get_lr(self, t):
         if t < self.warmup_t:
-            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
-        else:
-            t = t - self.warmup_t
-            total_t = self.t_initial - self.warmup_t
-            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
-        return lrs
+            return [self.warmup_lr_init + t * s for s in self.warmup_steps]
+        t = t - self.warmup_t
+        total_t = self.t_initial - self.warmup_t
+        return [
+            v - ((v - v * self.lr_min_rate) * (t / total_t))
+            for v in self.base_values
+        ]
 
     def get_epoch_values(self, epoch: int):
-        if self.t_in_epochs:
-            return self._get_lr(epoch)
-        else:
-            return None
+        return self._get_lr(epoch) if self.t_in_epochs else None
 
     def get_update_values(self, num_updates: int):
-        if not self.t_in_epochs:
-            return self._get_lr(num_updates)
-        else:
-            return None
+        return None if self.t_in_epochs else self._get_lr(num_updates)
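Note: the list comprehension implements a linear decay from each base value v down to v * lr_min_rate over total_t steps; a quick numeric check with made-up values:

v, lr_min_rate, total_t = 0.1, 0.01, 100
lr = lambda t: v - ((v - v * lr_min_rate) * (t / total_t))
assert lr(0) == v                                  # starts at the base LR
assert abs(lr(total_t) - v * lr_min_rate) < 1e-12  # decays to the floor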
13 changes: 6 additions & 7 deletions main.py
@@ -117,8 +117,7 @@ def main(config):
     max_accuracy = 0.0
 
     if config.TRAIN.AUTO_RESUME:
-        resume_file = auto_resume_helper(config.OUTPUT)
-        if resume_file:
+        if resume_file := auto_resume_helper(config.OUTPUT):
             if config.MODEL.RESUME:
                 logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
             config.defrost()
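Note: the walrus operator requires Python 3.8+; it binds and tests a value in one expression. A schematic sketch with a hypothetical stand-in for auto_resume_helper:

def find_checkpoint(output_dir):
    # hypothetical stand-in for auto_resume_helper
    return output_dir and 'ckpt_epoch_12.pth'

# before: assign, then test
resume_file = find_checkpoint('output')
if resume_file:
    print(f'resuming from {resume_file}')

# after: assign and test in one line
if resume_file := find_checkpoint('output'):
    print(f'resuming from {resume_file}')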
@@ -162,7 +161,7 @@ def main(config):
 
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    logger.info('Training time {}'.format(total_time_str))
+    logger.info(f'Training time {total_time_str}')
 
 
 def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
@@ -277,15 +276,15 @@ def validate(config, data_loader, model):
 def throughput(data_loader, model, logger):
     model.eval()
 
-    for idx, (images, _) in enumerate(data_loader):
+    for images, _ in data_loader:
         images = images.cuda(non_blocking=True)
         batch_size = images.shape[0]
-        for i in range(50):
+        for _ in range(50):
             model(images)
         torch.cuda.synchronize()
-        logger.info(f"throughput averaged with 30 times")
+        logger.info("throughput averaged with 30 times")
         tic1 = time.time()
-        for i in range(30):
+        for _ in range(30):
             model(images)
         torch.cuda.synchronize()
         tic2 = time.time()
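Note: the renames follow the convention that _ marks a loop value that is never used; enumerate is dropped because the index was dead. A trivial illustration:

batches = [('img0', 0), ('img1', 1)]  # made-up (image, label) pairs
for images, _ in batches:   # was: for idx, (images, _) in enumerate(batches)
    pass
for _ in range(3):          # was: for i in range(3)
    pass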
23 changes: 9 additions & 14 deletions models/swin_mlp.py
@@ -42,8 +42,11 @@ def window_partition(x, window_size):
     """
     B, H, W, C = x.shape
     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-    return windows
+    return (
+        x.permute(0, 1, 3, 2, 4, 5)
+        .contiguous()
+        .view(-1, window_size, window_size, C)
+    )
 
 
 def window_reverse(windows, window_size, H, W):
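Note: the restructured window_partition keeps the same permute/view chain, only reformatted; a shape check under the docstring's (B, H, W, C) convention, assuming PyTorch is installed:

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return (
        x.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size, window_size, C)
    )

x = torch.randn(2, 8, 8, 3)                           # B=2, 8x8 map, C=3
windows = window_partition(x, window_size=4)
assert windows.shape == (2 * (8 // 4) ** 2, 4, 4, 3)  # (num_windows*B, ws, ws, C)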
@@ -277,10 +280,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.use_checkpoint:
-                x = checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x)
+            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         if self.downsample is not None:
             x = self.downsample(x)
         return x
@@ -289,9 +289,7 @@ def extra_repr(self) -> str:
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
     def flops(self):
-        flops = 0
-        for blk in self.blocks:
-            flops += blk.flops()
+        flops = sum(blk.flops() for blk in self.blocks)
         if self.downsample is not None:
             flops += self.downsample.flops()
         return flops
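Note: sum() over a generator is the idiomatic form of the removed accumulator loop; a small sketch with made-up per-block FLOP counts:

block_flops = [1.5e9, 2.0e9, 2.0e9]  # hypothetical values
total = 0
for f in block_flops:                # old accumulator style
    total += f
assert total == sum(f for f in block_flops) == sum(block_flops)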
@@ -322,10 +320,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la
         self.embed_dim = embed_dim
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        if norm_layer is not None:
-            self.norm = norm_layer(embed_dim)
-        else:
-            self.norm = None
+        self.norm = norm_layer(embed_dim) if norm_layer is not None else None
 
     def forward(self, x):
         B, C, H, W = x.shape
@@ -461,7 +456,7 @@ def forward(self, x):
     def flops(self):
         flops = 0
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
         flops += self.num_features * self.num_classes
33 changes: 14 additions & 19 deletions models/swin_transformer.py
@@ -53,8 +53,11 @@ def window_partition(x, window_size):
     """
     B, H, W, C = x.shape
     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-    return windows
+    return (
+        x.permute(0, 1, 3, 2, 4, 5)
+        .contiguous()
+        .view(-1, window_size, window_size, C)
+    )
 
 
 def window_reverse(windows, window_size, H, W):
@@ -144,10 +147,7 @@ def forward(self, x, mask=None):
             nW = mask.shape[0]
             attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
             attn = attn.view(-1, self.num_heads, N, N)
-            attn = self.softmax(attn)
-        else:
-            attn = self.softmax(attn)
-
+        attn = self.softmax(attn)
         attn = self.attn_drop(attn)
 
         x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
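Note: both branches ended with the same softmax call, so it can be hoisted below the conditional without changing behavior; schematically:

def before(attn, mask, softmax):
    if mask is not None:
        attn = attn + mask
        attn = softmax(attn)
    else:
        attn = softmax(attn)
    return attn

def after(attn, mask, softmax):
    if mask is not None:
        attn = attn + mask
    return softmax(attn)

ident = lambda a: a  # stand-in softmax, just to compare control flow
assert before(1.0, None, ident) == after(1.0, None, ident)
assert before(1.0, 2.0, ident) == after(1.0, 2.0, ident)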
@@ -238,7 +238,10 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0
             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
             mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(
+                attn_mask == 0, 0.0
+            )
+
         else:
             attn_mask = None
 
@@ -414,10 +417,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.use_checkpoint:
-                x = checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x)
+            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         if self.downsample is not None:
             x = self.downsample(x)
         return x
@@ -426,9 +426,7 @@ def extra_repr(self) -> str:
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
     def flops(self):
-        flops = 0
-        for blk in self.blocks:
-            flops += blk.flops()
+        flops = sum(blk.flops() for blk in self.blocks)
         if self.downsample is not None:
             flops += self.downsample.flops()
         return flops
@@ -459,10 +457,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la
         self.embed_dim = embed_dim
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        if norm_layer is not None:
-            self.norm = norm_layer(embed_dim)
-        else:
-            self.norm = None
+        self.norm = norm_layer(embed_dim) if norm_layer is not None else None
 
     def forward(self, x):
         B, C, H, W = x.shape
@@ -607,7 +602,7 @@ def forward(self, x):
     def flops(self):
         flops = 0
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
         flops += self.num_features * self.num_classes
10 changes: 2 additions & 8 deletions optimizer.py
@@ -19,10 +19,8 @@ def build_optimizer(config, model):
     """
     Build optimizer, set weight decay of normalization to 0 by default.
     """
-    skip = {}
     skip_keywords = {}
-    if hasattr(model, 'no_weight_decay'):
-        skip = model.no_weight_decay()
+    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {}
     if hasattr(model, 'no_weight_decay_keywords'):
         skip_keywords = model.no_weight_decay_keywords()
     parameters = set_weight_decay(model, skip, skip_keywords)
@@ -63,8 +61,4 @@ def set_weight_decay(model, skip_list=(), skip_keywords=()):
 
 
 def check_keywords_in_name(name, keywords=()):
-    isin = False
-    for keyword in keywords:
-        if keyword in name:
-            isin = True
-    return isin
+    return any(keyword in name for keyword in keywords)
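Note: any() with a generator returns the same result as the removed flag loop and also short-circuits on the first match; for example:

def check_keywords_in_name(name, keywords=()):
    return any(keyword in name for keyword in keywords)

assert check_keywords_in_name('norm1.bias', ('bias', 'bn'))
assert not check_keywords_in_name('attn.proj.weight', ('bias', 'bn'))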