diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index 2c4208cee5..7569eb202e 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -15,6 +15,7 @@
 import torch
 from torchvision.transforms import Compose, Lambda, Normalize, ColorJitter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
 from contiguous_params import ContiguousParams
 
 import wandb
@@ -26,7 +27,7 @@
 from utils import plot_samples
 
 
-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
     model.train()
     train_iter = iter(train_loader)
     # Iterate over the batches of the dataset
@@ -42,6 +43,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
         optimizer.zero_grad()
         train_loss.backward()
         optimizer.step()
+        scheduler.step()
 
         mb.child.comment = f'Training loss: {train_loss.item():.6}'
 
@@ -172,6 +174,11 @@ def main(args):
     model_params = ContiguousParams([p for p in model.parameters() if p.requires_grad]).contiguous()
     optimizer = torch.optim.Adam(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
+    # Scheduler
+    if args.sched == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+    elif args.sched == 'onecycle':
+        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
 
     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
 
@@ -201,7 +208,7 @@ def main(args):
     # Training loop
     mb = master_bar(range(args.epochs))
     for epoch in mb:
-        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
+        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)
         # Validation loop at the end of each epoch
         val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
         if val_loss < min_loss:
@@ -253,6 +260,7 @@ def parse_args():
                         help='Load pretrained parameters before starting the training')
     parser.add_argument('--rotation', dest='rotation', action='store_true',
                         help='train with rotated bbox')
+    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
 
     args = parser.parse_args()
     return args
diff --git a/references/recognition/train_pytorch.py b/references/recognition/train_pytorch.py
index 63442b61bd..fcbf77355b 100644
--- a/references/recognition/train_pytorch.py
+++ b/references/recognition/train_pytorch.py
@@ -15,6 +15,7 @@
 import torch
 from torchvision.transforms import Compose, Lambda, Normalize, ColorJitter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
 from contiguous_params import ContiguousParams
 import wandb
 from pathlib import Path
@@ -27,7 +28,7 @@
 from utils import plot_samples
 
 
-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
     model.train()
     train_iter = iter(train_loader)
     # Iterate over the batches of the dataset
@@ -41,6 +42,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
         optimizer.zero_grad()
         train_loss.backward()
         optimizer.step()
+        scheduler.step()
 
         mb.child.comment = f'Training loss: {train_loss.item():.6}'
 
@@ -162,6 +164,11 @@ def main(args):
     model_params = ContiguousParams([p for p in model.parameters() if p.requires_grad]).contiguous()
     optimizer = torch.optim.Adam(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
+    # Scheduler
+    if args.sched == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+    elif args.sched == 'onecycle':
+        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
 
     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -190,7 +197,7 @@ def main(args):
     # Training loop
     mb = master_bar(range(args.epochs))
     for epoch in mb:
-        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
+        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)
 
         # Validation loop at the end of each epoch
         val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
@@ -239,6 +246,7 @@ def parse_args():
                         help='Log to Weights & Biases')
     parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                         help='Load pretrained parameters before starting the training')
+    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
 
     args = parser.parse_args()
     return args
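
For context, a minimal runnable sketch of the scheduling pattern both files adopt: the schedulers are sized in optimizer steps rather than epochs, so the horizon passed at construction is epochs * len(train_loader) and scheduler.step() runs once per batch, right after optimizer.step(). The model, loader, and hyperparameter values below are illustrative stand-ins, not the ones from the training scripts.

# Sketch of the per-batch LR scheduling introduced by this diff.
# `model`, `train_loader`, `epochs` and `lr` are hypothetical stand-ins.
import torch
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR

model = torch.nn.Linear(10, 2)  # stand-in for the detection/recognition model
train_loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
epochs, lr = 3, 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.95, 0.99), eps=1e-6)
total_steps = epochs * len(train_loader)  # scheduler horizon in optimizer steps
scheduler = CosineAnnealingLR(optimizer, total_steps, eta_min=lr / 25e4)
# or, mirroring the 'onecycle' branch:
# scheduler = OneCycleLR(optimizer, lr, total_steps)

criterion = torch.nn.CrossEntropyLoss()
for _ in range(epochs):
    for images, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the LR once per batch, as in fit_one_epoch

With the new flag, e.g. `--sched onecycle` selects the OneCycleLR branch. Note that any other --sched value would leave `scheduler` unbound in main() and raise a NameError at the fit_one_epoch call, so the two listed values are assumed.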