feat: Added schedulers to PyTorch training scripts (#381)
* feat: Added schedulers to PyTorch training scripts

* fix: Fixed typo
fg-mindee committed Jul 13, 2021
1 parent 58a3f8d commit 330d52f
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 10 additions & 2 deletions references/detection/train_pytorch.py
@@ -15,6 +15,7 @@
 import torch
 from torchvision.transforms import Compose, Lambda, Normalize, ColorJitter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
 from contiguous_params import ContiguousParams
 import wandb

@@ -26,7 +27,7 @@
 from utils import plot_samples


-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
     model.train()
     train_iter = iter(train_loader)
     # Iterate over the batches of the dataset
@@ -42,6 +43,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
         optimizer.zero_grad()
         train_loss.backward()
         optimizer.step()
+        scheduler.step()

         mb.child.comment = f'Training loss: {train_loss.item():.6}'

@@ -172,6 +174,11 @@ def main(args):
     model_params = ContiguousParams([p for p in model.parameters() if p.requires_grad]).contiguous()
     optimizer = torch.optim.Adam(model_params, args.lr,
                                  betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
+    # Scheduler
+    if args.sched == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+    elif args.sched == 'onecycle':
+        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))

     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -201,7 +208,7 @@ def main(args):
     # Training loop
     mb = master_bar(range(args.epochs))
     for epoch in mb:
-        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
+        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)
         # Validation loop at the end of each epoch
         val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
         if val_loss < min_loss:
@@ -253,6 +260,7 @@ def parse_args():
                         help='Load pretrained parameters before starting the training')
     parser.add_argument('--rotation', dest='rotation', action='store_true',
                         help='train with rotated bbox')
+    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
     args = parser.parse_args()

     return args
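As an aside for readers browsing the diff: the snippet below is a minimal, self-contained sketch of the pattern this commit introduces (the scheduler is built from the same CosineAnnealingLR/OneCycleLR calls as above and stepped once per batch). The dummy model, the loss, and the lr/epochs/steps_per_epoch values are placeholders, not part of the commit.

import torch
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR

# Placeholder values standing in for args.lr, args.epochs and len(train_loader)
lr, epochs, steps_per_epoch = 1e-3, 2, 5

model = torch.nn.Linear(4, 2)  # dummy model in place of the detection model
optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.95, 0.99), eps=1e-6)

# Same construction as in main(): the horizon is the total number of batches
sched = 'cosine'
if sched == 'cosine':
    scheduler = CosineAnnealingLR(optimizer, epochs * steps_per_epoch, eta_min=lr / 25e4)
elif sched == 'onecycle':
    scheduler = OneCycleLR(optimizer, lr, epochs * steps_per_epoch)

for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).sum()  # stand-in for the real training loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped every batch, as in fit_one_epoch
    print(f"epoch {epoch}: lr={optimizer.param_groups[0]['lr']:.2e}")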
12 changes: 10 additions & 2 deletions references/recognition/train_pytorch.py
@@ -15,6 +15,7 @@
 import torch
 from torchvision.transforms import Compose, Lambda, Normalize, ColorJitter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
 from contiguous_params import ContiguousParams
 import wandb
 from pathlib import Path
@@ -27,7 +28,7 @@
 from utils import plot_samples


-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
     model.train()
     train_iter = iter(train_loader)
     # Iterate over the batches of the dataset
@@ -41,6 +42,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
         optimizer.zero_grad()
         train_loss.backward()
         optimizer.step()
+        scheduler.step()

         mb.child.comment = f'Training loss: {train_loss.item():.6}'

@@ -162,6 +164,11 @@ def main(args):
     model_params = ContiguousParams([p for p in model.parameters() if p.requires_grad]).contiguous()
     optimizer = torch.optim.Adam(model_params, args.lr,
                                  betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
+    # Scheduler
+    if args.sched == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+    elif args.sched == 'onecycle':
+        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))

     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -190,7 +197,7 @@ def main(args):
     # Training loop
     mb = master_bar(range(args.epochs))
     for epoch in mb:
-        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
+        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)

         # Validation loop at the end of each epoch
         val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
@@ -239,6 +246,7 @@ def parse_args():
                         help='Log to Weights & Biases')
     parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                         help='Load pretrained parameters before starting the training')
+    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
     args = parser.parse_args()

     return args
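For completeness, a hedged example of how the new flag might be used from the shell once this commit is in place. The dataset path placeholders and the --epochs/--lr flags are assumptions inferred from the args.epochs/args.lr references visible in the diff; check each script's parse_args for the exact arguments it expects.

python references/detection/train_pytorch.py <train_path> <val_path> --epochs 10 --lr 0.001 --sched cosine
python references/recognition/train_pytorch.py <train_path> <val_path> --epochs 10 --lr 0.001 --sched onecycle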
