Skip to content
This repository has been archived by the owner on Oct 31, 2022. It is now read-only.

Commit

Permalink
Fix models_dir issue #76
Browse files Browse the repository at this point in the history
  • Loading branch information
nshepperd committed Mar 6, 2021
1 parent fdd5ecf commit 9741323
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
description='Pre-encode text files into tokenized training set.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name')
parser.add_argument('--models_dir', metavar='PATH', type=str, default='models', help='Path to models directory')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')
parser.add_argument('in_text', metavar='PATH', type=str, help='Input file, directory, or glob pattern (utf-8 text).')
parser.add_argument('out_npz', metavar='OUT.npz', type=str, help='Output file path')

def main():
args = parser.parse_args()
enc = encoder.get_encoder(args.model_name)
enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
print('Reading files')
chunks = load_dataset(enc, args.in_text, args.combine, encoding=args.encoding)
print('Writing', args.out_npz)
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

parser.add_argument('--dataset', metavar='PATH', type=str, required=True, help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).')
parser.add_argument('--model_name', metavar='MODEL', type=str, default='124M', help='Pretrained model name')
parser.add_argument('--models_dir', metavar='PATH', type=str, default='models', help='Path to models directory')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')

Expand Down Expand Up @@ -71,7 +72,7 @@ def randomize(context, hparams, p):

def main():
args = parser.parse_args()
enc = encoder.get_encoder(args.model_name)
enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
hparams = model.default_hparams()
with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))
Expand Down

0 comments on commit 9741323

Please sign in to comment.