Skip to content

Commit

Permalink
feat: improve loading time when test only
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesmindee committed Jun 3, 2021
1 parent 6933091 commit 97b6ff9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 50 deletions.
52 changes: 26 additions & 26 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,32 +109,6 @@ def main(args):

print(args)

st = time.time()
# Load both train and val data generators
train_set = DetectionDataset(
img_folder=os.path.join(args.data_path, 'train'),
label_folder=os.path.join(args.data_path, 'train_labels'),
sample_transforms=T.Compose([
T.LambdaTransformation(lambda x: x / 255),
T.Resize((args.input_size, args.input_size)),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
T.RandomJpegQuality(60),
T.RandomSaturation(.3),
T.RandomContrast(.3),
T.RandomBrightness(.3),
]),
rotated_bbox=args.rotation
)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{train_loader.num_batches} batches)")

if args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target, rotation=args.rotation)
return

st = time.time()
val_set = DetectionDataset(
img_folder=os.path.join(args.data_path, 'val'),
Expand Down Expand Up @@ -179,6 +153,32 @@ def main(args):
f"Mean IoU: {mean_iou:.2%})")
return

st = time.time()
# Load both train and val data generators
train_set = DetectionDataset(
img_folder=os.path.join(args.data_path, 'train'),
label_folder=os.path.join(args.data_path, 'train_labels'),
sample_transforms=T.Compose([
T.LambdaTransformation(lambda x: x / 255),
T.Resize((args.input_size, args.input_size)),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
T.RandomJpegQuality(60),
T.RandomSaturation(.3),
T.RandomContrast(.3),
T.RandomBrightness(.3),
]),
rotated_bbox=args.rotation
)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{train_loader.num_batches} batches)")

if args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target, rotation=args.rotation)
return

# Tensorboard to monitor training
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
Expand Down
50 changes: 26 additions & 24 deletions references/recognition/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,30 +101,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric):

def main(args):

st = time.time()
# Load both train and val data generators
train_set = RecognitionDataset(
img_folder=os.path.join(args.data_path, 'train'),
labels_path=os.path.join(args.data_path, 'train_labels.json'),
sample_transforms=T.Compose([
T.LambdaTransformation(lambda x: x / 255),
T.RandomApply(T.ColorInversion(), .1),
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomJpegQuality(60),
T.RandomSaturation(.3),
T.RandomContrast(.3),
T.RandomBrightness(.3),
]),
)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{train_loader.num_batches} batches)")

if args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target)
return
print(args)

st = time.time()
val_set = RecognitionDataset(
Expand Down Expand Up @@ -168,6 +145,31 @@ def main(args):
print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
return

st = time.time()
# Load both train and val data generators
train_set = RecognitionDataset(
img_folder=os.path.join(args.data_path, 'train'),
labels_path=os.path.join(args.data_path, 'train_labels.json'),
sample_transforms=T.Compose([
T.LambdaTransformation(lambda x: x / 255),
T.RandomApply(T.ColorInversion(), .1),
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomJpegQuality(60),
T.RandomSaturation(.3),
T.RandomContrast(.3),
T.RandomBrightness(.3),
]),
)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{train_loader.num_batches} batches)")

if args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target)
return

# Tensorboard to monitor training
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
Expand Down

0 comments on commit 97b6ff9

Please sign in to comment.