feat: Added support of AMP to all PyTorch training scripts #604

Merged
merged 9 commits on Nov 10, 2021
2 changes: 1 addition & 1 deletion doctr/models/detection/core.py

for p_, bitmap_ in zip(proba_map, bitmap):
# Perform opening (erosion + dilatation)
bitmap_ = cv2.morphologyEx(bitmap_, cv2.MORPH_OPEN, kernel)
bitmap_ = cv2.morphologyEx(bitmap_.astype(np.float32), cv2.MORPH_OPEN, kernel)
# Rotate bitmap and proba_map
angle = get_bitmap_angle(bitmap_)
angles_batch.append(angle)
2 changes: 1 addition & 1 deletion doctr/utils/geometry.py

height, width = exp_img.shape[:2]
rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
rot_img = cv2.warpAffine(exp_img, rot_mat, (width, height))
rot_img = cv2.warpAffine(exp_img.astype(np.float32), rot_mat, (width, height))
if expand:
# Pad to get the same aspect ratio
if (image.shape[0] / image.shape[1]) != (rot_img.shape[0] / rot_img.shape[1]):
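Both casts address the same AMP side effect: with autocast enabled, probability maps and rotated crops can reach these post-processing helpers as float16 arrays, which OpenCV routines such as cv2.morphologyEx and cv2.warpAffine do not accept, so the arrays are promoted to float32 first. A minimal sketch of the issue (synthetic data, not taken from the PR):

import cv2
import numpy as np

# Under AMP, a tensor converted with .numpy() can come out as float16.
bitmap = np.random.rand(256, 256).astype(np.float16)
kernel = np.ones((3, 3), np.uint8)

# Passing the float16 array directly typically fails, since OpenCV's morphology
# filters do not support that depth; promoting to float32 is enough.
opened = cv2.morphologyEx(bitmap.astype(np.float32), cv2.MORPH_OPEN, kernel)
print(opened.dtype)  # float32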
50 changes: 35 additions & 15 deletions references/classification/train_pytorch.py
from doctr.datasets import VOCABS, CharacterGenerator


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):

if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
train_iter = iter(train_loader)
# Iterate over the batches of the dataset

images = batch_transforms(images)

out = model(images)
train_loss = cross_entropy(out, targets)

optimizer.zero_grad()
train_loss.backward()
optimizer.step()
if amp:
with torch.cuda.amp.autocast():
out = model(images)
train_loss = cross_entropy(out, targets)
scaler.scale(train_loss).backward()
# Update the params
scaler.step(optimizer)
scaler.update()
else:
out = model(images)
train_loss = cross_entropy(out, targets)
train_loss.backward()
optimizer.step()
scheduler.step()

mb.child.comment = f'Training loss: {train_loss.item():.6}'


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms):
def evaluate(model, val_loader, batch_transforms, amp=False):
# Model in eval mode
model.eval()
# Validation loop
val_loss, correct, samples, batch_cnt = 0, 0, 0, 0
val_iter = iter(val_loader)
for images, targets in val_iter:
images = batch_transforms(images)
out = model(images)
loss = cross_entropy(out, targets)
if amp:
with torch.cuda.amp.autocast():
out = model(images)
loss = cross_entropy(out, targets)
else:
out = model(images)
loss = cross_entropy(out, targets)
# Compute metric
correct += (out.argmax(dim=1) == targets).sum().item()

batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

# Load doctr model
model = models.__dict__[args.model](pretrained=args.pretrained, num_classes=len(vocab))
model = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=len(vocab))

# Resume weights
if isinstance(args.resume, str):

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.model,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": "adam",
"exp_type": "character-classification",
"framework": "pytorch",
"vocab": args.vocab,
"scheduler": args.sched,
"pretrained": args.pretrained,
}
)

# W&B
if args.wb:
wandb.log({
'epochs': epoch + 1,
'val_loss': val_loss,
'acc': acc,
})
parser = argparse.ArgumentParser(description='DocTR training script for character classification (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('model', type=str, help='text-recognition model to train')
parser.add_argument('arch', type=str, help='text-recognition model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()

return args
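The AMP branch added to fit_one_epoch follows the standard torch.cuda.amp recipe: run the forward pass and loss under autocast, scale the loss before backward, then step and update through the GradScaler so underflowed float16 gradients are handled. A condensed, self-contained sketch of that pattern (toy model and data, not the doctr classifier; assumes a CUDA device):

import torch
from torch import nn
from torch.nn import functional as F

model = nn.Linear(10, 2).cuda()            # stand-in for the classification model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

images = torch.randn(8, 10, device="cuda")
targets = torch.randint(0, 2, (8,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():            # forward pass and loss in mixed precision
    out = model(images)
    loss = F.cross_entropy(out, targets)
scaler.scale(loss).backward()              # scale to avoid float16 gradient underflow
scaler.step(optimizer)                     # unscales internally, skips step on inf/NaN
scaler.update()                            # adjust the scale factor for the next step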
14 changes: 8 additions & 6 deletions references/classification/train_tensorflow.py
f"{val_loader.num_batches} batches)")

# Load doctr model
model = backbones.__dict__[args.model](
model = backbones.__dict__[args.arch](
pretrained=args.pretrained,
input_shape=(args.input_size, args.input_size, 3),
num_classes=len(vocab),

# Tensorboard to monitor training
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.model,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": "adam",
"exp_type": "character-classification",
"framework": "tensorflow",
"vocab": args.vocab,
"scheduler": args.sched,
"pretrained": args.pretrained,
}
)

# W&B
if args.wb:
wandb.log({
'epochs': epoch + 1,
'val_loss': val_loss,
'acc': acc,
})
parser = argparse.ArgumentParser(description='DocTR training script for character classification (TensorFlow)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('model', type=str, help='text-recognition model to train')
parser.add_argument('arch', type=str, help='text-recognition model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')
15 changes: 10 additions & 5 deletions references/classification/utils.py
def plot_samples(images, targets):
# Unnormalize image
num_samples = min(len(images), 12)
num_rows = min(len(images), 3)
num_cols = int(math.ceil(num_samples / num_rows))
num_cols = min(len(images), 8)
num_rows = int(math.ceil(num_samples / num_cols))
_, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5))
for idx in range(num_samples):
img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8)
row_idx = idx // num_cols
col_idx = idx % num_cols

axes[row_idx][col_idx].imshow(img)
axes[row_idx][col_idx].axis('off')
axes[row_idx][col_idx].set_title(targets[idx])
ax = axes[row_idx] if num_rows > 1 else axes
ax = ax[col_idx] if num_cols > 1 else ax

ax.imshow(img)
ax.set_title(targets[idx])
# Disable axis
for ax in axes.ravel():
ax.axis('off')
plt.show()
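The reworked indexing accounts for matplotlib squeezing singleton dimensions: plt.subplots returns a 2-D array of axes in general, a 1-D array when there is a single row or column, and a bare Axes object when both are 1, so a fixed axes[row_idx][col_idx] lookup can raise. A small illustration of that behaviour (the grid sizes here are hypothetical):

import matplotlib.pyplot as plt

_, grid = plt.subplots(2, 3)     # 2-D ndarray of Axes, shape (2, 3)
_, row = plt.subplots(1, 3)      # squeezed to a 1-D ndarray, shape (3,)
_, single = plt.subplots(1, 1)   # squeezed to a single Axes object

print(grid.shape, row.shape, type(single).__name__)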
64 changes: 47 additions & 17 deletions references/detection/train_pytorch.py
os.environ['USE_TORCH'] = '1'

import datetime
import hashlib
import logging
import multiprocessing as mp
import time
from doctr.utils.metrics import LocalizationConfusion


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):

if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
train_iter = iter(train_loader)
# Iterate over the batches of the dataset
images = images.cuda()
images = batch_transforms(images)

train_loss = model(images, targets)['loss']

optimizer.zero_grad()
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
if amp:
with torch.cuda.amp.autocast():
train_loss = model(images, targets)['loss']
scaler.scale(train_loss).backward()
# Gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
# Update the params
scaler.step(optimizer)
scaler.update()
else:
train_loss = model(images, targets)['loss']
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()

scheduler.step()

mb.child.comment = f'Training loss: {train_loss.item():.6}'


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms, val_metric):
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
# Model in eval mode
model.eval()
# Reset val metric
if torch.cuda.is_available():
images = images.cuda()
images = batch_transforms(images)
out = model(images, targets, return_boxes=True)
if amp:
with torch.cuda.amp.autocast():
out = model(images, targets, return_boxes=True)
else:
out = model(images, targets, return_boxes=True)
# Compute metric
loc_preds, _ = out['preds']
for boxes_gt, boxes_pred in zip(targets, loc_preds):
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")
with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))

# Load doctr model
model = detection.__dict__[args.model](pretrained=args.pretrained)
model = detection.__dict__[args.arch](pretrained=args.pretrained)

# Resume weights
if isinstance(args.resume, str):
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")
with open(os.path.join(args.train_path, 'labels.json'), 'rb') as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

if args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target, rotation=args.rotation)
plot_samples(x, target)
return

# Backbone freezing

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.model,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": "adam",
"exp_type": "text-detection",
"framework": "pytorch",
"scheduler": args.sched,
"train_hash": train_hash,
"val_hash": val_hash,
"pretrained": args.pretrained,
"rotation": args.rotation,
"amp": args.amp,
}
)

# Training loop
mb = master_bar(range(args.epochs))
for epoch in mb:
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
# Validation loop at the end of each epoch
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
# W&B
if args.wb:
wandb.log({
'epochs': epoch + 1,
'val_loss': val_loss,
'recall': recall,
'precision': precision,

parser.add_argument('train_path', type=str, help='path to training data folder')
parser.add_argument('val_path', type=str, help='path to validation data folder')
parser.add_argument('model', type=str, help='text-detection model to train')
parser.add_argument('arch', type=str, help='text-detection model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training')
parser.add_argument('--rotation', dest='rotation', action='store_true',
help='train with rotated bbox')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()

return args
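In the detection script the AMP branch also has to keep gradient clipping meaningful: clip_grad_norm_ must see unscaled gradients, so scaler.unscale_(optimizer) is called first and scaler.step then skips the redundant unscale. A condensed sketch of that ordering (toy regression model, not the doctr detector; assumes a CUDA device):

import torch
from torch import nn
from torch.nn import functional as F

model = nn.Linear(10, 1).cuda()            # stand-in for the detection model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 10, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = F.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                              # gradients back to real scale
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)   # clip on unscaled values
scaler.step(optimizer)                                  # detects the prior unscale_ call
scaler.update()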