Skip to content

Commit

Permalink
[20230424 v0.4.0] Fix: drop_last=True
Browse files Browse the repository at this point in the history
  • Loading branch information
horrible-dong committed Apr 24, 2023
1 parent 5b17f6a commit eca94b2
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from termcolor import cprint

from engine import evaluate, train_one_epoch
from qtcls import build_criterion, build_dataset, build_model, build_optimizer, build_scheduler
from qtcls import __version__, build_criterion, build_dataset, build_model, build_optimizer, build_scheduler
from qtcls.utils.io import checkpoint_saver, checkpoint_loader, variables_loader, variables_saver
from qtcls.utils.misc import makedirs, init_distributed_mode, init_seeds, is_main_process

Expand Down Expand Up @@ -44,7 +44,7 @@ def get_args_parser():
parser.add_argument('--print_freq', type=int, default=50)
parser.add_argument('--need_targets', action='store_true', help='need targets for training')
parser.add_argument('--drop_lr_now', action='store_true')
parser.add_argument('--drop_last', action='store_true')
parser.add_argument('--drop_last', type=bool, default=True)
parser.add_argument('--amp', action='store_true', help='automatic mixed precision training')
parser.add_argument('--no_dist', action='store_true', help='forcibly disable distributed mode')

Expand Down Expand Up @@ -108,6 +108,7 @@ def get_args_parser():
def main(args):
init_seeds(args.seed)
init_distributed_mode(args)
cprint(f'QTClassification v{__version__}', 'light_green', attrs=['bold'])
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu' or args.eval:
args.amp = False
Expand Down

0 comments on commit eca94b2

Please sign in to comment.