Skip to content

Commit

Permalink
NameError: name 'device' is not defined in predict() method #68 (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
fmikaelian committed Mar 8, 2019
1 parent 0ce83d2 commit 05a49ea
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions cdqa/reader/bertqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,21 +855,21 @@ def __init__(self,
self.null_score_diff_threshold = null_score_diff_threshold
self.output_dir = output_dir

def fit(self, X, y=None):

train_examples, train_features = X

if self.local_rank == -1 or self.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not self.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
self.device = torch.device("cuda" if torch.cuda.is_available() and not self.no_cuda else "cpu")
self.n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(self.local_rank)
device = torch.device("cuda", self.local_rank)
n_gpu = 1
self.device = torch.device("cuda", self.local_rank)
self.n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(self.local_rank != -1), self.fp16))
self.device, self.n_gpu, bool(self.local_rank != -1), self.fp16))

def fit(self, X, y=None):

train_examples, train_features = X

if self.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
Expand All @@ -880,7 +880,7 @@ def fit(self, X, y=None):
random.seed(self.seed)
np.random.seed(self.seed)
torch.manual_seed(self.seed)
if n_gpu > 0:
if self.n_gpu > 0:
torch.cuda.manual_seed_all(self.seed)

if os.path.exists(self.output_dir) and os.listdir(self.output_dir):
Expand All @@ -899,15 +899,15 @@ def fit(self, X, y=None):

if self.fp16:
model.half()
model.to(device)
model.to(self.device)
if self.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

model = DDP(model)
elif n_gpu > 1:
elif self.n_gpu > 1:
model = torch.nn.DataParallel(model)

# Prepare optimizer
Expand Down Expand Up @@ -967,11 +967,11 @@ def fit(self, X, y=None):
model.train()
for _ in trange(int(self.num_train_epochs), desc="Epoch"):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
if n_gpu == 1:
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
if self.n_gpu == 1:
batch = tuple(t.to(self.device) for t in batch) # multi-gpu does scattering it-self
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
if n_gpu > 1:
if self.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def fit(self, X, y=None):
model = BertForQuestionAnswering(config)
model.load_state_dict(torch.load(output_model_file))

model.to(device)
model.to(self.device)
self.model = model

return self
Expand Down Expand Up @@ -1036,9 +1036,9 @@ def predict(self, X):
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
input_ids = input_ids.to(self.device)
input_mask = input_mask.to(self.device)
segment_ids = segment_ids.to(self.device)
with torch.no_grad():
batch_start_logits, batch_end_logits = self.model(input_ids, segment_ids, input_mask)
for i, example_index in enumerate(example_indices):
Expand Down

0 comments on commit 05a49ea

Please sign in to comment.