Removing type hints to support Python 3.5
codertimo committed Oct 19, 2018
1 parent 120ead8 commit 17132a2
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions bert_pytorch/trainer/pretrain.py
@@ -19,8 +19,8 @@ class BERTTrainer:
     """
 
-    def __init__(self, bert: BERT, vocab_size,
-                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
+    def __init__(self, bert, vocab_size,
+                 train_dataloader, test_dataloader=None,
                  lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                  with_cuda: bool = True, log_freq: int = 10):
         """
@@ -40,18 +40,18 @@ def __init__(self, bert: BERT, vocab_size,
         self.device = torch.device("cuda:0" if cuda_condition else "cpu")
 
         # This BERT model will be saved every epoch
-        self.bert: BERT = bert
+        self.bert = bert
         # Initialize the BERT Language Model, with BERT model
-        self.model: BERTLM = BERTLM(bert, vocab_size).to(self.device)
+        self.model = BERTLM(bert, vocab_size).to(self.device)
 
         # Distributed GPU training if CUDA can detect more than 1 GPU
         if torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
             self.model = nn.DataParallel(self.model)
 
         # Setting the train and test data loader
-        self.train_data: DataLoader = train_dataloader
-        self.test_data: DataLoader = test_dataloader
+        self.train_data = train_dataloader
+        self.test_data = test_dataloader
 
         # Setting the Adam optimizer with hyper-param
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
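Of the hints removed above, only the PEP 526 variable annotations in the method body (self.bert: BERT = bert, self.model: BERTLM = ..., self.train_data: DataLoader = ...) use syntax that Python 3.5 cannot parse; parameter annotations such as bert: BERT have been accepted since Python 3.0, although the diff drops several of those as well. The following minimal sketch is not part of the repository: Model is a hypothetical stand-in for BERT, used only to contrast the 3.6-only form with two 3.5-compatible spellings.

# Minimal sketch (not from this repository): why the removed annotations
# matter for Python 3.5. "Model" is a hypothetical stand-in for BERT.

class Model:
    pass

class Py36Style:
    def __init__(self, model):
        # PEP 526 variable annotation: parsed only on Python >= 3.6,
        # a SyntaxError on Python 3.5.
        self.model: Model = model

class Py35Style:
    def __init__(self, model):
        # Plain assignment, as in this commit: valid on every Python 3 version.
        # The PEP 484 type comment keeps the type visible to static checkers
        # such as mypy while remaining legal 3.5 syntax.
        self.model = model  # type: Model

Note that Python 3.5 rejects the annotated assignment at compile time, so a module containing it fails to import even if that line is never executed; the type comment in the second class is the usual way to keep the hint without giving up 3.5 support.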
