Skip to content

Commit

Permalink
feat: Added option to select vocab for training (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
fg-mindee authored Sep 29, 2021
1 parent c6c771c commit 3df804b
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main(args):

torch.backends.cudnn.benchmark = True

vocab = VOCABS['french']
vocab = VOCABS[args.vocab]

# Load val data generator
st = time.time()
Expand Down Expand Up @@ -225,6 +225,7 @@ def parse_args():
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used')
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training')
parser.add_argument(
'--train-samples',
dest='train_samples',
Expand Down
3 changes: 2 additions & 1 deletion references/classification/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(args):

print(args)

vocab = VOCABS['french']
vocab = VOCABS[args.vocab]

# Load val data generator
st = time.time()
Expand Down Expand Up @@ -229,6 +229,7 @@ def parse_args():
parser.add_argument('-j', '--workers', type=int, default=4, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used')
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training')
parser.add_argument(
'--train-samples',
dest='train_samples',
Expand Down
3 changes: 2 additions & 1 deletion references/recognition/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def main(args):
batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

# Load doctr model
model = recognition.__dict__[args.model](pretrained=args.pretrained, vocab=VOCABS['french'])
model = recognition.__dict__[args.model](pretrained=args.pretrained, vocab=VOCABS[args.vocab])

# Resume weights
if isinstance(args.resume, str):
Expand Down Expand Up @@ -260,6 +260,7 @@ def parse_args():
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--show-samples', dest='show_samples', action='store_true',
help='Display unormalized training samples')
Expand Down
3 changes: 2 additions & 1 deletion references/recognition/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main(args):
model = recognition.__dict__[args.model](
pretrained=args.pretrained,
input_shape=(args.input_size, 4 * args.input_size, 3),
vocab=VOCABS['french']
vocab=VOCABS[args.vocab]
)
# Resume weights
if isinstance(args.resume, str):
Expand Down Expand Up @@ -251,6 +251,7 @@ def parse_args():
parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)')
parser.add_argument('-j', '--workers', type=int, default=4, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--show-samples', dest='show_samples', action='store_true',
help='Display unormalized training samples')
Expand Down

0 comments on commit 3df804b

Please sign in to comment.