feat: Added schedulers to Pytorch training scripts #381

Merged · 2 commits · Jul 13, 2021
12 changes: 10 additions & 2 deletions references/detection/train_pytorch.py
@@ -15,6 +15,7 @@
 import torch
 from torchvision.transforms import Compose, Lambda, Normalize, ColorJitter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
 from contiguous_params import ContiguousParams
 import wandb

@@ -26,7 +27,7 @@
 from utils import plot_samples


-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
     model.train()
     train_iter = iter(train_loader)
     # Iterate over the batches of the dataset
@@ -42,6 +43,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
         optimizer.zero_grad()
         train_loss.backward()
         optimizer.step()
+        scheduler.step()

         mb.child.comment = f'Training loss: {train_loss.item():.6}'

@@ -172,6 +174,11 @@ def main(args):
     model_params = ContiguousParams([p for p in model.parameters() if p.requires_grad]).contiguous()
     optimizer = torch.optim.Adam(model_params, args.lr,
                                  betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
+    # Scheduler
+    if args.sched == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+    elif args.sched == 'onecycle':
+        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))

     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
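Both schedulers are sized in optimizer steps rather than epochs: the horizon passed to them is args.epochs * len(train_loader), which is why fit_one_epoch now calls scheduler.step() right after optimizer.step() on every batch. A minimal standalone sketch of that pattern, using toy values for the learning rate and loop sizes (these numbers and the dummy parameter are illustrative, not the scripts' defaults):

# Illustration of the per-batch scheduling added in this PR (toy values).
import torch
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR

lr, epochs, batches_per_epoch = 1e-3, 2, 5          # assumed toy values
param = torch.nn.Parameter(torch.zeros(1))           # stand-in for model parameters
optimizer = torch.optim.Adam([param], lr)

sched = 'cosine'                                      # or 'onecycle', mirroring --sched
total_steps = epochs * batches_per_epoch
if sched == 'cosine':
    scheduler = CosineAnnealingLR(optimizer, total_steps, eta_min=lr / 25e4)
elif sched == 'onecycle':
    scheduler = OneCycleLR(optimizer, lr, total_steps)

for step in range(total_steps):
    optimizer.zero_grad()
    param.sum().backward()                            # dummy loss
    optimizer.step()
    scheduler.step()                                  # per-batch update, as in fit_one_epoch
    print(step, scheduler.get_last_lr())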
@@ -201,7 +208,7 @@ def main(args):
     # Training loop
     mb = master_bar(range(args.epochs))
     for epoch in mb:
-        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
+        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)
         # Validation loop at the end of each epoch
         val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
         if val_loss < min_loss:
@@ -253,6 +260,7 @@ def parse_args():
                         help='Load pretrained parameters before starting the training')
     parser.add_argument('--rotation', dest='rotation', action='store_true',
                         help='train with rotated bbox')
+    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
     args = parser.parse_args()

     return args
12 changes: 10 additions & 2 deletions references/recognition/train_pytorch.py
@@ -15,6 +15,7 @@
 import torch
 from torchvision.transforms import Compose, Lambda, Normalize, ColorJitter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
 from contiguous_params import ContiguousParams
 import wandb
 from pathlib import Path
@@ -27,7 +28,7 @@
 from utils import plot_samples


-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
     model.train()
     train_iter = iter(train_loader)
     # Iterate over the batches of the dataset
@@ -41,6 +42,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
         optimizer.zero_grad()
         train_loss.backward()
         optimizer.step()
+        scheduler.step()

         mb.child.comment = f'Training loss: {train_loss.item():.6}'

@@ -162,6 +164,11 @@ def main(args):
     model_params = ContiguousParams([p for p in model.parameters() if p.requires_grad]).contiguous()
     optimizer = torch.optim.Adam(model_params, args.lr,
                                  betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
+    # Scheduler
+    if args.sched == 'cosine':
+        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+    elif args.sched == 'onecycle':
+        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))

     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
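For reference, with the 'cosine' default the learning rate decays from args.lr down to args.lr / 25e4 over the whole run, following the closed-form cosine curve that CosineAnnealingLR traces when stepped once per batch. A small sketch of that curve; the helper and the numbers below are illustrative only, not part of the script:

# Closed-form cosine annealing: eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2
import math

def cosine_lr(t, eta_max, eta_min, t_max):
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

eta_max, eta_min, t_max = 1e-3, 1e-3 / 25e4, 1000     # assumed values for a hypothetical run
for t in (0, t_max // 2, t_max):
    print(t, cosine_lr(t, eta_max, eta_min, t_max))   # 1e-3 at the start, ~5e-4 midway, 4e-9 at the end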
@@ -190,7 +197,7 @@ def main(args):
     # Training loop
     mb = master_bar(range(args.epochs))
     for epoch in mb:
-        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
+        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)

         # Validation loop at the end of each epoch
         val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
@@ -239,6 +246,7 @@ def parse_args():
                         help='Log to Weights & Biases')
     parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                         help='Load pretrained parameters before starting the training')
+    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
     args = parser.parse_args()

     return args
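With these changes, both reference scripts expose the new flag; for example (all other arguments omitted here, shown only as a placeholder): python references/detection/train_pytorch.py <existing-args> --sched onecycle. Leaving --sched at its default selects the cosine schedule.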