Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split run squad.py in processing/train/predict #66

Merged
merged 3 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions cdqa/pipeline/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@

# train document reader
train_processor = BertProcessor(bert_model='bert-base-uncased', is_training=True)
train_examples, train_features = train_processor.fit_transform(X='data/bnpp_newsroom_v1.0/bnpp_newsroom-v1.0.csv')
model = BertQA(bert_model='models/bert_qa_squad_v1.1')
model.fit(X_y=train_features)
train_examples, train_features = train_processor.fit_transform(X='data/train-v1.1.json')

model = BertQA(bert_model='bert-base-uncased',
custom_weights=False,
train_batch_size=12,
learning_rate=3e-5,
num_train_epochs=2,
output_dir='logs/bert_qa_squad_v1.1_sklearn')

model.fit(X=(train_examples, train_features))

dump(model, 'model.joblib')
4 changes: 2 additions & 2 deletions cdqa/reader/bertqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from tqdm.auto import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
Expand Down Expand Up @@ -883,7 +883,7 @@ def fit(self, X, y=None):
if n_gpu > 0:
torch.cuda.manual_seed_all(self.seed)

if os.path.exists(self.output_dir) and os.listdir(self.output_dir) and self.do_train:
if os.path.exists(self.output_dir) and os.listdir(self.output_dir):
raise ValueError("Output directory () already exists and is not empty.")
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
Expand Down
4 changes: 2 additions & 2 deletions cdqa/reader/run_squad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -894,7 +894,7 @@ def main():

# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))

if args.fp16:
model.half()
Expand Down