Commit
* Add fastText model
* Add Reuters bag-of-words dataset class
* Add input dropout for MLP
* Remove duplicate README files
Showing 9 changed files with 241 additions and 73 deletions.
Empty file.
@@ -0,0 +1,150 @@
import os
import random
from copy import deepcopy

import numpy as np
import torch
import torch.onnx

from common.evaluate import EvaluatorFactory
from common.train import TrainerFactory
from datasets.aapd import AAPD
from datasets.imdb import IMDB
from datasets.reuters import ReutersBOW
from datasets.yelp2014 import Yelp2014
from models.fasttext.args import get_args
from models.fasttext.model import FastText

class UnknownWordVecCache(object):
    """
    Caches the first randomly generated word vector of each size so that it is reused
    for every subsequent unknown word of that size.
    """
    cache = {}

    @classmethod
    def unk(cls, tensor):
        size_tup = tuple(tensor.size())
        if size_tup not in cls.cache:
            cls.cache[size_tup] = torch.Tensor(tensor.size())
            cls.cache[size_tup].uniform_(-0.25, 0.25)
        return cls.cache[size_tup]

def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device, is_multilabel):
    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device)
    if hasattr(saved_model_evaluator, 'is_multilabel'):
        saved_model_evaluator.is_multilabel = is_multilabel

    scores, metric_names = saved_model_evaluator.get_scores()
    print('Evaluation metrics for', split_name)
    print(metric_names)
    print(scores)

if __name__ == '__main__':
    # Set default configuration in args.py
    args = get_args()

    # Set random seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if not args.cuda:
        args.gpu = -1

    if torch.cuda.is_available() and args.cuda:
        print('Note: You are using GPU for training')
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed(args.seed)
        args.gpu = torch.device('cuda:%d' % args.gpu)

    if torch.cuda.is_available() and not args.cuda:
        print('Warning: Using CPU for training')

    dataset_map = {
        'Reuters': ReutersBOW,
        'AAPD': AAPD,
        'IMDB': IMDB,
        'Yelp2014': Yelp2014
    }

    if args.dataset not in dataset_map:
        raise ValueError('Unrecognized dataset')
    else:
        dataset_class = dataset_map[args.dataset]
        train_iter, dev_iter, test_iter = dataset_map[args.dataset].iters(args.data_dir, args.word_vectors_file,
                                                                          args.word_vectors_dir,
                                                                          batch_size=args.batch_size, device=args.gpu,
                                                                          unk_init=UnknownWordVecCache.unk)

    config = deepcopy(args)
    config.dataset = train_iter.dataset
    config.target_class = train_iter.dataset.NUM_CLASSES
    config.words_num = len(train_iter.dataset.TEXT_FIELD.vocab)

    print('Dataset:', args.dataset)
    print('No. of target classes:', train_iter.dataset.NUM_CLASSES)
    print('No. of train instances:', len(train_iter.dataset))
    print('No. of dev instances:', len(dev_iter.dataset))
    print('No. of test instances:', len(test_iter.dataset))

    if args.resume_snapshot:
        if args.cuda:
            model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu))
        else:
            model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage)
    else:
        model = FastText(config)
        if args.cuda:
            model.cuda()

    if not args.trained_model:
        save_path = os.path.join(args.save_path, dataset_map[args.dataset].NAME)
        os.makedirs(save_path, exist_ok=True)

    parameter = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)

    train_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, train_iter, args.batch_size, args.gpu)
    test_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, test_iter, args.batch_size, args.gpu)
    dev_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, dev_iter, args.batch_size, args.gpu)

    if hasattr(train_evaluator, 'is_multilabel'):
        train_evaluator.is_multilabel = dataset_class.IS_MULTILABEL
    if hasattr(test_evaluator, 'is_multilabel'):
        test_evaluator.is_multilabel = dataset_class.IS_MULTILABEL
    if hasattr(dev_evaluator, 'is_multilabel'):
        dev_evaluator.is_multilabel = dataset_class.IS_MULTILABEL

    trainer_config = {
        'optimizer': optimizer,
        'batch_size': args.batch_size,
        'log_interval': args.log_every,
        'patience': args.patience,
        'model_outfile': args.save_path,
        'is_multilabel': dataset_class.IS_MULTILABEL
    }

    trainer = TrainerFactory.get_trainer(args.dataset, model, None, train_iter, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

    if not args.trained_model:
        trainer.train(args.epochs)
    else:
        if args.cuda:
            model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu))
        else:
            model = torch.load(args.trained_model, map_location=lambda storage, location: storage)

    # Calculate dev and test metrics
    if hasattr(trainer, 'snapshot_path'):
        model = torch.load(trainer.snapshot_path)

    evaluate_dataset('dev', dataset_map[args.dataset], model, None, dev_iter, args.batch_size,
                     is_multilabel=dataset_class.IS_MULTILABEL,
                     device=args.gpu)
    evaluate_dataset('test', dataset_map[args.dataset], model, None, test_iter, args.batch_size,
                     is_multilabel=dataset_class.IS_MULTILABEL,
                     device=args.gpu)
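UnknownWordVecCache gives every out-of-vocabulary token of a given dimensionality the same random vector: the first request for a size draws one uniformly from [-0.25, 0.25], and every later request of that size returns the cached tensor. A minimal standalone check of that behavior (the class body is copied from the driver above so the snippet runs on its own):

import torch

class UnknownWordVecCache(object):
    cache = {}

    @classmethod
    def unk(cls, tensor):
        size_tup = tuple(tensor.size())
        if size_tup not in cls.cache:
            cls.cache[size_tup] = torch.Tensor(tensor.size())
            cls.cache[size_tup].uniform_(-0.25, 0.25)
        return cls.cache[size_tup]

# Two lookups of the same size share one object, so all unknown words of
# that dimensionality get an identical embedding; a new size gets a new vector.
a = UnknownWordVecCache.unk(torch.zeros(300))
b = UnknownWordVecCache.unk(torch.zeros(300))
c = UnknownWordVecCache.unk(torch.zeros(100))
assert a is b
assert a.shape != c.shape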
@@ -0,0 +1,23 @@
import os

import models.args


def get_args():
    parser = models.args.get_args()

    parser.add_argument('--dataset', type=str, default='Reuters', choices=['Reuters', 'AAPD', 'IMDB', 'Yelp2014'])
    parser.add_argument('--mode', type=str, default='rand', choices=['rand', 'static', 'non-static'])
    parser.add_argument('--words-dim', type=int, default=300)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--epoch-decay', type=int, default=15)
    parser.add_argument('--weight-decay', type=float, default=0)

    parser.add_argument('--word-vectors-dir', default=os.path.join(os.pardir, 'hedwig-data', 'embeddings', 'word2vec'))
    parser.add_argument('--word-vectors-file', default='GoogleNews-vectors-negative300.txt')
    parser.add_argument('--save-path', type=str, default=os.path.join('model_checkpoints', 'fasttext'))
    parser.add_argument('--resume-snapshot', type=str)
    parser.add_argument('--trained-model', type=str)

    args = parser.parse_args()
    return args
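The shared parser built by models.args.get_args() is not part of this diff. A hypothetical sketch of that module, inferred only from the attributes the training driver reads off args (lr, seed, cuda, gpu, batch_size, epochs, patience, log_every, data_dir); the defaults shown are illustrative placeholders, not the repository's values:

import argparse
import os

# Hypothetical reconstruction of models/args.py -- not shown in this commit.
# Flags are inferred from the attributes the driver accesses; defaults are placeholders.
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--log-every', type=int, default=10)
    parser.add_argument('--data-dir', default=os.path.join(os.pardir, 'hedwig-data', 'datasets'))
    # Returned unparsed so each model's args.py can add its own flags
    # before calling parse_args(), as models/fasttext/args.py does above.
    return parser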
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastText(nn.Module):

    def __init__(self, config):
        super().__init__()
        dataset = config.dataset
        target_class = config.target_class
        words_num = config.words_num
        words_dim = config.words_dim
        self.mode = config.mode

        if config.mode == 'rand':
            rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
            self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
        elif config.mode == 'static':
            self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
        elif config.mode == 'non-static':
            self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
        else:
            raise ValueError('Unsupported mode: %s' % config.mode)

        self.dropout = nn.Dropout(config.dropout)
        self.fc1 = nn.Linear(words_dim, target_class)

    def forward(self, x, **kwargs):
        if self.mode == 'rand':
            x = self.embed(x)  # (batch, sent_len, embed_dim)
        elif self.mode == 'static':
            x = self.static_embed(x)  # (batch, sent_len, embed_dim)
        elif self.mode == 'non-static':
            x = self.non_static_embed(x)  # (batch, sent_len, embed_dim)

        # Average the word embeddings over the sentence
        x = F.avg_pool2d(x, (x.shape[1], 1)).squeeze(1)  # (batch, embed_dim)

        logit = self.fc1(self.dropout(x))  # (batch, target_class)
        return logit
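The pooled average is the heart of the fastText architecture: a sentence is represented by the mean of its word embeddings, followed by a single linear layer. A standalone sanity check that the avg_pool2d call above performs exactly that reduction (the batch and sentence sizes here are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(8, 17, 300)  # (batch, sent_len, embed_dim)

# Kernel (sent_len, 1) averages over the word axis, leaving one vector per example.
pooled = F.avg_pool2d(x, (x.shape[1], 1)).squeeze(1)

assert pooled.shape == (8, 300)  # (batch, embed_dim)
assert torch.allclose(pooled, x.mean(dim=1), atol=1e-6)  # identical to a plain mean over words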
This file was deleted.