Skip to content

Commit

Permalink
Eliminate bias by retraining only with train set
Browse files Browse the repository at this point in the history
See issue coastalcph#32 for more info about the bias
  • Loading branch information
danigoju committed Nov 3, 2022
1 parent 5109aeb commit 60c482a
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions models/tfidf_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,34 @@ def add_zero_class(labels):
# Fixate Validation Split
split_index = [-1] * len(dataset['train']) + [0] * len(dataset['validation'])
val_split = PredefinedSplit(test_fold=split_index)
gs_clf = GridSearchCV(text_clf, parameters, cv=val_split, n_jobs=32, verbose=4)
gs_clf = GridSearchCV(text_clf, parameters, cv=val_split, n_jobs=32, verbose=4, refit = False)

# Pre-process inputs, outputs
x_train = get_text(dataset['train']) + get_text(dataset['validation'])
x_train = get_text(dataset['train'])
x_val = get_text(dataset['validation'])
x_train_val = x_train + x_val

if config.task_type == 'multi_label':
mlb = MultiLabelBinarizer(classes=range(config.n_classes))
mlb.fit(dataset['train']['labels'])
else:
mlb = None
y_train = get_labels(dataset['train'], mlb) + get_labels(dataset['validation'], mlb)
y_train = get_labels(dataset['train'], mlb)
y_val = get_labels(dataset['validation'], mlb)
y_train_val = y_train + y_val

# Train classifier
gs_clf = gs_clf.fit(x_train, y_train)
gs_clf = gs_clf.fit(x_train_val, y_train_val)

# Print best hyper-parameters
logging.info('Best Parameters:')
for param_name in sorted(parameters.keys()):
logging.info("%s: %r" % (param_name, gs_clf.best_params_[param_name]))


# Retrain model with best CV parameters only with train data
text_clf.set_params(**gs_clf.best_params_)
gs_clf = text_clf.fit(x_train, y_train)

# Report results
logging.info('VALIDATION RESULTS:')
y_pred = gs_clf.predict(get_text(dataset['validation']))
Expand Down

0 comments on commit 60c482a

Please sign in to comment.