# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
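"""Train a text-classification model with HuggingFace Transformers.

Fine-tunes (or linear-probes, with --finetune False) a pretrained model on one
of several classification datasets and saves the resulting weights.

Example invocation (hypothetical; all flags are defined in the argparse block
at the bottom of this file):

    python text_classification.py --data_folder data/ --dataset ag_news --model gpt2
"""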
import argparse
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

from src.dataset import load_data
from src.utils import bool_flag


# Function for computing the accuracy of model predictions.
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Some models return (logits, ...) tuples; keep only the logits.
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=1)
    acc = np.mean(predictions == labels)
    return {
        'accuracy': acc
    }
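
# Sanity-check sketch (not part of the training pipeline): with logits for two
# examples whose argmax matches the labels, accuracy should be 1.0:
#   compute_metrics((np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0])))
#   -> {'accuracy': 1.0}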


def main(args):
    dataset, num_labels = load_data(args)
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(args.model, num_labels=num_labels)
    if args.model == 'gpt2':
        # GPT-2 has no padding token, so reuse the end-of-sequence token.
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    if args.dataset == "mnli":
        # Only evaluate on the matched validation set.
        testset_key = "validation_matched"
        preprocess_function = lambda examples: tokenizer(
            examples["premise"], examples["hypothesis"], max_length=256, truncation=True)
    else:
        text_key = 'text' if (args.dataset in ["ag_news", "imdb", "yelp"]) else 'sentence'
        testset_key = 'test' if (args.dataset in ["ag_news", "imdb", "yelp"]) else 'validation'
        preprocess_function = lambda examples: tokenizer(examples[text_key], max_length=256, truncation=True)
    encoded_dataset = dataset.map(preprocess_function, batched=True)

    train_args = TrainingArguments(
        args.checkpoint_folder,
        disable_tqdm=not args.tqdm,
        evaluation_strategy="epoch",
        # load_best_model_at_end requires the save and evaluation strategies to match.
        save_strategy="epoch",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=args.weight_decay,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
    )
    trainer = Trainer(
        model,
        train_args,
        train_dataset=encoded_dataset["train"],
        eval_dataset=encoded_dataset[testset_key],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    if not args.finetune:
        # Freeze the parameters of the transformer body so that only the
        # classification head is trained (linear probing).
        transformer = list(model.children())[0]
        for param in transformer.parameters():
            param.requires_grad = False

    trainer.train()
    trainer.evaluate()

    suffix = ''
    if args.finetune:
        suffix += '_finetune'
    torch.save(model.state_dict(),
               os.path.join(args.result_folder, "%s_%s%s.pth" % (args.model.replace('/', '-'), args.dataset, suffix)))
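
# The saved weights can be restored later along these lines (the path shown is
# illustrative, following the naming scheme used above):
#   model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=num_labels)
#   model.load_state_dict(torch.load("result/gpt2_ag_news.pth"))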


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Text classification model training.")

    # Bookkeeping
    parser.add_argument("--checkpoint_folder", default="checkpoint/", type=str,
                        help="folder in which to store temporary model checkpoints")
    parser.add_argument("--result_folder", default="result/", type=str,
                        help="folder in which to store trained models")
    parser.add_argument("--tqdm", default=True, type=bool_flag,
                        help="use tqdm in output")

    # Data
    parser.add_argument("--data_folder", required=True, type=str,
                        help="folder in which to store data")
    parser.add_argument("--dataset", default="dbpedia14", type=str,
                        choices=["dbpedia14", "ag_news", "imdb", "yelp", "mnli"],
                        help="classification dataset to use")

    # Model
    parser.add_argument("--model", default="gpt2", type=str,
                        help="type of model")

    # Optimization
    parser.add_argument("--batch_size", default=16, type=int,
                        help="batch size for training and evaluation")
    parser.add_argument("--epochs", default=5, type=int,
                        help="number of epochs to train for")
    parser.add_argument("--lr", default=2e-5, type=float,
                        help="learning rate")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="weight decay")
    parser.add_argument("--finetune", default=False, type=bool_flag,
                        help="finetune the transformer; if False, only train the linear classification head")

    args = parser.parse_args()
    if args.result_folder == 'none':
        args.result_folder = args.checkpoint_folder
    main(args)