diff --git a/vilbert/vilbert.py b/vilbert/vilbert.py index fb7cd4d..b00e24c 100644 --- a/vilbert/vilbert.py +++ b/vilbert/vilbert.py @@ -1608,10 +1608,10 @@ def __init__(self, config, num_labels, dropout_prob=0.1, default_gpu=True): config, self.bert.embeddings.word_embeddings.weight ) self.vil_prediction = SimpleClassifier( - config.bi_hidden_size, config.bi_hidden_size * 2, 3129, 0.5 + config.bi_hidden_size, config.bi_hidden_size * 2, self.num_labels, 0.5 ) self.vil_prediction_gqa = SimpleClassifier( - config.bi_hidden_size, config.bi_hidden_size * 2, 1533, 0.5 + config.bi_hidden_size, config.bi_hidden_size * 2, self.num_labels, 0.5 ) self.vil_binary_prediction = SimpleClassifier( config.bi_hidden_size * 2, config.bi_hidden_size * 2, 2, 0.5