
Remove unwanted args from models/bert (#15)

achyudh committed Apr 29, 2019
1 parent 99a01c6 commit 97a3d2da291d594f4c6fb8a678967596ec5eab3a
Showing with 8 additions and 15 deletions.
  1. +0 −6 models/bert/__main__.py
  2. +8 −9 models/bert/args.py

models/bert/__main__.py
@@ -57,12 +57,6 @@ def evaluate_split(model, processor, args, split='dev'):
     if n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
 
-    if args.server_ip and args.server_port:
-        import ptvsd
-        print("Waiting for debugger attach")
-        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
-        ptvsd.wait_for_attach()
-
     dataset_map = {
         'SST-2': SST2Processor,
         'Reuters': ReutersProcessor,
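
The six deleted lines above are the ptvsd remote-debugging hook that was guarded by the now-removed --server-ip/--server-port flags. For reference, a standalone sketch of that pattern with the guard pulled into a helper (the name maybe_attach_debugger is hypothetical; enable_attach and wait_for_attach are the ptvsd calls from the deleted lines):

def maybe_attach_debugger(server_ip, server_port):
    # Mirrors the block removed in this commit: when both a host and a port
    # are supplied, block until a remote debugger (e.g. VS Code) attaches.
    if server_ip and server_port:
        import ptvsd  # lazy import, so ptvsd is only needed when debugging
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(server_ip, server_port), redirect_output=True)
        ptvsd.wait_for_attach()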

models/bert/args.py
@@ -1,5 +1,4 @@
 import os
-from argparse import ArgumentParser
 
 import models.args
 
@@ -11,6 +10,9 @@ def get_args():
     parser.add_argument('--dataset', type=str, default='SST-2', choices=['SST-2', 'AGNews', 'Reuters', 'AAPD', 'IMDB', 'Yelp2014'])
     parser.add_argument('--save-path', type=str, default=os.path.join('model_checkpoints', 'bert'))
     parser.add_argument('--cache-dir', default='cache', type=str)
+    parser.add_argument('--trained-model', default=None, type=str)
+    parser.add_argument('--local-rank', type=int, default=-1, help='local rank for distributed training')
+    parser.add_argument('--fp16', action='store_true', help='use 16-bit floating point precision')
 
     parser.add_argument('--max-seq-length',
                         default=128,
@@ -19,25 +21,22 @@ def get_args():
                              'Sequences longer than this will be truncated, and sequences shorter \n'
                              'than this will be padded.')
 
-    parser.add_argument('--trained-model', default=None, type=str)
-    parser.add_argument('--local-rank', type=int, default=-1, help='local rank for distributed training')
-    parser.add_argument('--fp16', action='store_true', help='use 16-bit floating point precision')
-
     parser.add_argument('--warmup-proportion',
                         default=0.1,
                         type=float,
                         help='Proportion of training to perform linear learning rate warmup for')
 
-    parser.add_argument('--gradient-accumulation-steps', type=int, default=1,
+    parser.add_argument('--gradient-accumulation-steps',
+                        type=int,
+                        default=1,
                         help='Number of updates steps to accumulate before performing a backward/update pass')
 
     parser.add_argument('--loss-scale',
-                        type=float, default=0,
+                        type=float,
+                        default=0,
                         help='Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n'
                              '0 (default value): dynamic loss scaling.\n'
                              'Positive power of 2: static loss scaling value.\n')
 
-    parser.add_argument('--server-ip', type=str, default='', help='Can be used for distant debugging.')
-    parser.add_argument('--server-port', type=str, default='', help='Can be used for distant debugging.')
     args = parser.parse_args()
     return args
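
With the local argparse import gone, models/bert/args.py appears to rely entirely on the shared import models.args module for its base parser. A minimal sketch of that layering, assuming the shared module exposes a get_args() that returns a plain argparse.ArgumentParser (the common flags shown, such as --seed and --epochs, are illustrative assumptions, not taken from this diff):

# models/args.py -- hypothetical shared base, for illustration only
from argparse import ArgumentParser

def get_args():
    parser = ArgumentParser(description='Options shared across models')
    parser.add_argument('--seed', type=int, default=3435)    # assumed common flag
    parser.add_argument('--epochs', type=int, default=30)    # assumed common flag
    return parser

# models/bert/args.py -- shape of the file after this commit (abbreviated)
import os

import models.args

def get_args():
    # Start from the shared parser, then add the BERT-specific flags
    # shown in the diff above.
    parser = models.args.get_args()
    parser.add_argument('--dataset', type=str, default='SST-2',
                        choices=['SST-2', 'AGNews', 'Reuters', 'AAPD', 'IMDB', 'Yelp2014'])
    parser.add_argument('--trained-model', default=None, type=str)
    parser.add_argument('--fp16', action='store_true', help='use 16-bit floating point precision')
    args = parser.parse_args()
    return args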
