Commit

change classifier
Embedding committed Jul 13, 2019
1 parent cb66159 commit d6c72b2
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions classifier.py
@@ -19,11 +19,10 @@
 
 
 class BertClassifier(nn.Module):
-    def __init__(self, args, bert_model):
+    def __init__(self, args, model):
         super(BertClassifier, self).__init__()
-        self.embedding = bert_model.embedding
-        self.encoder = bert_model.encoder
-        self.target = bert_model.target
+        self.embedding = model.embedding
+        self.encoder = model.encoder
         self.labels_num = args.labels_num
         self.pooling = args.pooling
         self.output_layer_1 = nn.Linear(args.hidden_size, args.hidden_size)
@@ -158,20 +157,20 @@ def main():
     # Build bert model.
     # A pseudo target is added.
     args.target = "bert"
-    bert_model = build_model(args)
+    model = build_model(args)
 
     # Load or initialize parameters.
     if args.pretrained_model_path is not None:
         # Initialize with pretrained model.
-        bert_model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)
+        model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)
     else:
         # Initialize with normal distribution.
-        for n, p in list(bert_model.named_parameters()):
+        for n, p in list(model.named_parameters()):
             if 'gamma' not in n and 'beta' not in n:
                 p.data.normal_(0, 0.02)
 
     # Build classification model.
-    model = BertClassifier(args, bert_model)
+    model = BertClassifier(args, model)
 
     # For simplicity, we use DataParallel wrapper to use multiple GPUs.
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
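For context, a minimal usage sketch (not part of this commit) of how the renamed model variable flows through main() after the change. build_model and BertClassifier are the functions shown in the diff; the multi-GPU lines are an assumption based on the DataParallel comment at the end of the second hunk, which the diff view truncates.

import torch
import torch.nn as nn

# Sketch only: assumes args, build_model, and BertClassifier as in the diff above.
model = build_model(args)  # BERT backbone built with the pseudo "bert" target

if args.pretrained_model_path is not None:
    # strict=False allows partial loading when checkpoint and model keys differ.
    model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)

# The same name is rebound: the classifier keeps only the backbone's embedding and encoder.
model = BertClassifier(args, model)

# Assumed multi-GPU wrapping, following the DataParallel comment in the hunk.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)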