Commit
Add HAN and XML_CNN for Doc Classification (#154)
* Add Reuters option in common.dataset
* Add HAN model
* Add XML-CNN
* Add HAN
* Add Hierarchical tokenization for Reuters
* Add README for HAN
* Add XML Readme
* Update HAN Readme
1 parent ed4f018 · commit 650882f
Showing 14 changed files with 759 additions and 4 deletions.
@@ -0,0 +1,56 @@
# Hierarchical Attention Networks

Implementation of Hierarchical Attention Networks for Document Classification [HAN (2016)](https://www.cs.cmu.edu/~hovy/papers/16HLT-hierarchical-attention-networks.pdf) with PyTorch and Torchtext.

## Model Type

- rand: All words are randomly initialized and then modified during training.
- static: A model with pre-trained vectors from [word2vec](https://code.google.com/archive/p/word2vec/). All words -- including the unknown ones that are initialized with zero -- are kept static and only the other parameters of the model are learned.
- non-static: Same as above, but the pre-trained vectors are fine-tuned for each task.
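These modes amount to different embedding-layer setups in PyTorch. A minimal sketch, assuming a pre-trained vector matrix (the random `vectors` tensor below is a stand-in for the word2vec matrix; names and sizes are illustrative, not the repository's API):

```
import torch
import torch.nn as nn

vocab_size, embed_dim = 10000, 300
vectors = torch.randn(vocab_size, embed_dim)  # stand-in for the word2vec matrix

# rand: random initialization, trained like any other parameter
rand_emb = nn.Embedding(vocab_size, embed_dim)

# static: pre-trained vectors, frozen during training
static_emb = nn.Embedding.from_pretrained(vectors, freeze=True)

# non-static: pre-trained vectors, fine-tuned for the task
non_static_emb = nn.Embedding.from_pretrained(vectors, freeze=False)
```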
## Quick Start

To run the model on the Reuters dataset in static mode, run the following from the Castor working directory:

```
python -m han --dataset Reuters
```

The best model will be saved in:

```
han/saves/best_model.pt
```

To test the model, you can use the following command:

```
python -m han --trained_model han/saves/Reuters/static_best_model.pt
```

## Dataset

We evaluate the model on the following dataset.

- Reuters-21578: Documents are split into sentences for the sentence-level attention model, and sentences are split into words for the word-level attention. Pre-trained word2vec embeddings are used for this task.
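For illustration, the hierarchical split can be sketched as below; this uses naive punctuation and whitespace splitting and is not the repository's actual tokenizer:

```
def hierarchical_tokenize(document):
    """Split a document into sentences, and each sentence into words."""
    sentences = [s.strip() for s in document.split('.') if s.strip()]
    return [sentence.split() for sentence in sentences]

doc = "Oil prices rose sharply. Markets reacted within hours."
print(hierarchical_tokenize(doc))
# [['Oil', 'prices', 'rose', 'sharply'], ['Markets', 'reacted', 'within', 'hours']]
```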
## Settings

Adam is used for training.
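Concretely, `han/__main__.py` builds the optimizer over only the parameters that require gradients, so frozen embeddings in static mode are excluded. A self-contained sketch (the `nn.Linear` is a stand-in for the HAN model; the learning rate is illustrative):

```
import torch
import torch.nn as nn

model = nn.Linear(300, 90)  # stand-in for the HAN model
# Optimize only trainable parameters
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters, lr=0.001)
```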
## Training Time

When

```
torch.backends.cudnn.deterministic = True
```

is specified, training takes roughly 10 minutes. Reuters-21578 is a relatively small dataset and the implementation is vectorized, hence the speed.

## TODO

- Run a combined hyperparameter tuning on a few of the datasets and report results with the tuned hyperparameters.
Empty file.
@@ -0,0 +1,189 @@
from copy import deepcopy
import logging
import random

from sklearn import metrics
import numpy as np
import torch
import torch.onnx

from common.evaluation import EvaluatorFactory
from common.train import TrainerFactory
from datasets.sst import SST1
from datasets.sst import SST2
from datasets.reuters import Reuters_hierarchical as Reuters
from han.args import get_args
from han.model import HAN


class UnknownWordVecCache(object):
    """
    Caches the first randomly generated word vector of a given size so that it is reused.
    """
    cache = {}

    @classmethod
    def unk(cls, tensor):
        size_tup = tuple(tensor.size())
        if size_tup not in cls.cache:
            cls.cache[size_tup] = torch.Tensor(tensor.size())
            # Choose 0.25 so unknown vectors have approximately the same variance as pre-trained ones.
            # Same as the original implementation: https://github.com/yoonkim/CNN_sentence/blob/0a626a048757d5272a7e8ccede256a434a6529be/process_data.py#L95
            cls.cache[size_tup].uniform_(-0.25, 0.25)
        return cls.cache[size_tup]


def get_logger():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    return logger


def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device):
    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device)
    scores, metric_names = saved_model_evaluator.get_scores()
    logger.info('Evaluation metrics for {}'.format(split_name))
    logger.info('\t'.join([' '] + metric_names))
    logger.info('\t'.join([split_name] + list(map(str, scores))))


if __name__ == '__main__':
    # Default configuration is set in args.py
    args = get_args()

    # Set random seeds for reproducibility
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    if not args.cuda:
        args.gpu = -1
    if torch.cuda.is_available() and args.cuda:
        print('Note: You are using GPU for training')
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        print('Warning: CUDA is available but not being used. You are training on the CPU.')
    np.random.seed(args.seed)
    random.seed(args.seed)
    logger = get_logger()

    # Set up the data for the selected dataset
    if args.dataset == 'SST-1':
        train_iter, dev_iter, test_iter = SST1.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
    elif args.dataset == 'SST-2':
        train_iter, dev_iter, test_iter = SST2.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
    elif args.dataset == 'Reuters':
        train_iter, dev_iter, test_iter = Reuters.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
    else:
        raise ValueError('Unrecognized dataset')

    config = deepcopy(args)
    config.dataset = train_iter.dataset
    config.target_class = train_iter.dataset.NUM_CLASSES
    config.words_num = len(train_iter.dataset.TEXT_FIELD.vocab)

    print('Dataset {} Mode {}'.format(args.dataset, args.mode))
    print('VOCAB num', len(train_iter.dataset.TEXT_FIELD.vocab))
    print('LABEL.target_class:', train_iter.dataset.NUM_CLASSES)
    print('Train instance', len(train_iter.dataset))
    print('Dev instance', len(dev_iter.dataset))
    print('Test instance', len(test_iter.dataset))

    if args.resume_snapshot:
        if args.cuda:
            model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu))
        else:
            model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage)
    else:
        model = HAN(config)
        if args.cuda:
            model.cuda()
            print('Shift model to GPU')

    # Optimize only the parameters that require gradients
    parameter = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameter, lr=args.lr)

    if args.dataset == 'SST-1':
        train_evaluator = EvaluatorFactory.get_evaluator(SST1, model, None, train_iter, args.batch_size, args.gpu)
        test_evaluator = EvaluatorFactory.get_evaluator(SST1, model, None, test_iter, args.batch_size, args.gpu)
        dev_evaluator = EvaluatorFactory.get_evaluator(SST1, model, None, dev_iter, args.batch_size, args.gpu)
    elif args.dataset == 'SST-2':
        train_evaluator = EvaluatorFactory.get_evaluator(SST2, model, None, train_iter, args.batch_size, args.gpu)
        test_evaluator = EvaluatorFactory.get_evaluator(SST2, model, None, test_iter, args.batch_size, args.gpu)
        dev_evaluator = EvaluatorFactory.get_evaluator(SST2, model, None, dev_iter, args.batch_size, args.gpu)
    elif args.dataset == 'Reuters':
        train_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, train_iter, args.batch_size, args.gpu)
        test_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, test_iter, args.batch_size, args.gpu)
        dev_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, dev_iter, args.batch_size, args.gpu)
    else:
        raise ValueError('Unrecognized dataset')

    trainer_config = {
        'optimizer': optimizer,
        'batch_size': args.batch_size,
        'log_interval': args.log_every,
        'dev_log_interval': args.dev_every,
        'patience': args.patience,
        'model_outfile': args.save_path,  # actually a directory; named model_outfile to conform to the Trainer naming convention
        'logger': logger
    }
    trainer = TrainerFactory.get_trainer(args.dataset, model, None, train_iter, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

    if not args.trained_model:
        trainer.train(args.epochs)
    else:
        if args.cuda:
            model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu))
        else:
            model = torch.load(args.trained_model, map_location=lambda storage, location: storage)

    if args.dataset == 'SST-1':
        evaluate_dataset('dev', SST1, model, None, dev_iter, args.batch_size, args.gpu)
        evaluate_dataset('test', SST1, model, None, test_iter, args.batch_size, args.gpu)
    elif args.dataset == 'SST-2':
        evaluate_dataset('dev', SST2, model, None, dev_iter, args.batch_size, args.gpu)
        evaluate_dataset('test', SST2, model, None, test_iter, args.batch_size, args.gpu)
    elif args.dataset == 'Reuters':
        evaluate_dataset('dev', Reuters, model, None, dev_iter, args.batch_size, args.gpu)
        evaluate_dataset('test', Reuters, model, None, test_iter, args.batch_size, args.gpu)
    else:
        raise ValueError('Unrecognized dataset')

    # Calculate dev and test metrics
    for data_loader in [dev_iter, test_iter]:
        predicted_labels = list()
        target_labels = list()
        for batch_idx, batch in enumerate(data_loader):
            # Multi-label prediction: threshold the per-class sigmoid scores at 0.5
            scores_rounded = torch.sigmoid(model(batch.text)).round().long()
            predicted_labels.extend(scores_rounded.cpu().detach().numpy())
            target_labels.extend(batch.label.cpu().detach().numpy())
        predicted_labels = np.array(predicted_labels)
        target_labels = np.array(target_labels)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        precision = metrics.precision_score(target_labels, predicted_labels, average='micro')
        recall = metrics.recall_score(target_labels, predicted_labels, average='micro')
        f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
        if data_loader == dev_iter:
            print("Dev metrics:")
        else:
            print("Test metrics:")
        print(accuracy, precision, recall, f1)

    if args.onnx:
        device = torch.device('cuda') if torch.cuda.is_available() and args.cuda else torch.device('cpu')
        dummy_input = torch.zeros(args.onnx_batch_size, args.onnx_sent_len, dtype=torch.long, device=device)
        onnx_filename = 'han_{}.onnx'.format(args.mode)
        torch.onnx.export(model, dummy_input, onnx_filename)
        print('Exported model in ONNX format as {}'.format(onnx_filename))
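For reference, a plausible invocation of the ONNX export path above, using the flags defined in han/args.py and assuming a trained Reuters model exists at the path shown in the README:

```
python -m han --dataset Reuters --trained_model han/saves/Reuters/static_best_model.pt --onnx
```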
@@ -0,0 +1,44 @@
import os

from argparse import ArgumentParser


def get_args():
    parser = ArgumentParser(description="HAN")
    parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
    parser.add_argument('--gpu', type=int, default=0)  # use -1 for CPU
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--word_num_hidden', type=int, default=50)
    parser.add_argument('--sentence_num_hidden', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--mode', type=str, default='static', choices=['rand', 'static', 'non-static'])
    parser.add_argument('--lr', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=3435)
    parser.add_argument('--dataset', type=str, default='SST-1', choices=['SST-1', 'SST-2', 'Reuters'])
    parser.add_argument('--resume_snapshot', type=str, default=None)
    parser.add_argument('--dev_every', type=int, default=30)
    parser.add_argument('--log_every', type=int, default=10)
    parser.add_argument('--patience', type=int, default=50)
    parser.add_argument('--save_path', type=str, default='han/saves')
    parser.add_argument('--output_channel', type=int, default=100)
    parser.add_argument('--words_dim', type=int, default=300)
    parser.add_argument('--embed_dim', type=int, default=300)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--epoch_decay', type=int, default=15)
    parser.add_argument('--data_dir', help='dataset directory',
                        default=os.path.join(os.pardir, 'Castor-data', 'datasets'))
    parser.add_argument('--word_vectors_dir', help='word vectors directory',
                        default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'word2vec'))
    parser.add_argument('--word_vectors_file', help='word vectors filename', default='/data/GoogleNews-vectors-negative300.txt')
    parser.add_argument('--trained_model', type=str, default="")
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--onnx', action='store_true', default=False, help='Export model in ONNX format')
    parser.add_argument('--onnx_batch_size', type=int, default=1024, help='Batch size for ONNX export')
    parser.add_argument('--onnx_sent_len', type=int, default=32, help='Sentence length for ONNX export')

    args = parser.parse_args()
    return args
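A hedged example of overriding a few of these defaults from the command line (the specific values are illustrative, not recommended settings):

```
python -m han --dataset Reuters --mode non-static --lr 0.001 --batch_size 32
```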
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn

from han.sent_level_rnn import SentLevelRNN
from han.word_level_rnn import WordLevelRNN


class HAN(nn.Module):

    def __init__(self, config):
        super(HAN, self).__init__()
        self.dataset = config.dataset
        self.mode = config.mode
        self.word_attention_rnn = WordLevelRNN(config)
        self.sentence_attention_rnn = SentLevelRNN(config)

    def forward(self, x):
        x = x.permute(1, 2, 0)  # expected shape: (#sentences, #words, batch size)
        num_sentences = x.size(0)
        word_attentions = None
        # Run the word-level attention RNN over each sentence and stack the results
        for i in range(num_sentences):
            word_attention = self.word_attention_rnn(x[i, :, :])
            if word_attentions is None:
                word_attentions = word_attention
            else:
                word_attentions = torch.cat((word_attentions, word_attention), 0)
        return self.sentence_attention_rnn(word_attentions)
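The permute call implies that forward expects input of shape (batch, #sentences, #words). A small self-contained check of that layout (the sizes are hypothetical):

```
import torch

batch_size, num_sentences, num_words = 4, 3, 10
x = torch.zeros(batch_size, num_sentences, num_words, dtype=torch.long)

# HAN.forward permutes to (#sentences, #words, batch size) before the word-level RNN
x = x.permute(1, 2, 0)
assert x.shape == (num_sentences, num_words, batch_size)
```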