Kim CNN OOP Refactoring #124

Merged
10 commits merged on Jul 3, 2018

Changes from all commits
4 changes: 4 additions & 0 deletions common/evaluation.py
@@ -1,17 +1,21 @@
from .evaluators.sick_evaluator import SICKEvaluator
from .evaluators.msrvid_evaluator import MSRVIDEvaluator
from .evaluators.sst_evaluator import SSTEvaluator
from .evaluators.trecqa_evaluator import TRECQAEvaluator
from .evaluators.wikiqa_evaluator import WikiQAEvaluator
from nce.nce_pairwise_mp.evaluators.trecqa_evaluator import TRECQAEvaluatorNCE
from nce.nce_pairwise_mp.evaluators.wikiqa_evaluator import WikiQAEvaluatorNCE


class EvaluatorFactory(object):
"""
Get the corresponding Evaluator class for a particular dataset.
"""
evaluator_map = {
'sick': SICKEvaluator,
'msrvid': MSRVIDEvaluator,
'SST-1': SSTEvaluator,
'SST-2': SSTEvaluator,
'trecqa': TRECQAEvaluator,
'wikiqa': WikiQAEvaluator
}
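Note: the factory's dispatch method sits below the fold of this diff. For readers skimming the refactoring, a minimal sketch of how such a map-based factory typically dispatches; the method name and signature here are assumptions, not necessarily the repository's actual code:

```python
@classmethod
def get_evaluator(cls, dataset_name, model, data_loader):
    # Hypothetical dispatch: look up the Evaluator subclass registered for
    # the dataset key and instantiate it; unknown names fail loudly.
    if dataset_name not in cls.evaluator_map:
        raise ValueError('{} is not implemented.'.format(dataset_name))
    return cls.evaluator_map[dataset_name](model, data_loader)
```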
24 changes: 24 additions & 0 deletions common/evaluators/sst_evaluator.py
@@ -0,0 +1,24 @@
import torch
import torch.nn.functional as F

from .evaluator import Evaluator


class SSTEvaluator(Evaluator):

    def get_scores(self):
        self.model.eval()
        self.data_loader.init_epoch()
        n_dev_correct = 0
        total_loss = 0

        for batch_idx, batch in enumerate(self.data_loader):
            scores = self.model(batch)
            # Count predictions whose argmax matches the gold label
            n_dev_correct += (
                torch.max(scores, 1)[1].view(batch.label.size()).data == batch.label.data).sum().item()
            # Accumulate the summed (not batch-averaged) cross-entropy so the
            # final division yields a per-example average over the whole set
            total_loss += F.cross_entropy(scores, batch.label, size_average=False).item()

        accuracy = 100. * n_dev_correct / len(self.data_loader.dataset.examples)
        avg_loss = total_loss / len(self.data_loader.dataset.examples)

        return [accuracy, avg_loss], ['accuracy', 'cross_entropy_loss']
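Note: `size_average=False` matches the PyTorch API of this PR's era; later PyTorch releases deprecate it in favor of `reduction`. A sketch of the equivalent call, assuming PyTorch >= 1.0:

```python
# Summed (not batch-averaged) cross-entropy under the newer PyTorch API;
# replaces the deprecated size_average=False argument.
total_loss += F.cross_entropy(scores, batch.label, reduction='sum').item()
```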
3 changes: 3 additions & 0 deletions common/train.py
@@ -2,6 +2,7 @@
from .trainers.msrvid_trainer import MSRVIDTrainer
from .trainers.trecqa_trainer import TRECQATrainer
from .trainers.wikiqa_trainer import WikiQATrainer
from .trainers.sst_trainer import SSTTrainer
from nce.nce_pairwise_mp.trainers.trecqa_trainer import TRECQATrainerNCE
from nce.nce_pairwise_mp.trainers.wikiqa_trainer import WikiQATrainerNCE

@@ -13,6 +14,8 @@ class TrainerFactory(object):
trainer_map = {
'sick': SICKTrainer,
'msrvid': MSRVIDTrainer,
'SST-1': SSTTrainer,
'SST-2': SSTTrainer,
'trecqa': TRECQATrainer,
'wikiqa': WikiQATrainer
}
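Note: the string keys must match each dataset's `NAME` constant, which is why `datasets/sst.py` below renames `sst-1` to `SST-1`. A hedged sketch of how a driver script presumably wires the factory in (the method name is an assumption):

```python
# Hypothetical driver wiring: the same dataset key selects the trainer
# and must equal the dataset's NAME constant, e.g. SST1.NAME == 'SST-1'.
dataset_name = 'SST-1'
trainer = TrainerFactory.get_trainer(dataset_name, model, embedding, train_loader,
                                     trainer_config, train_evaluator,
                                     test_evaluator, dev_evaluator)
```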
81 changes: 81 additions & 0 deletions common/trainers/sst_trainer.py
@@ -0,0 +1,81 @@
import os
import time

import torch
import torch.nn.functional as F

from .trainer import Trainer
from utils.serialization import save_checkpoint


class SSTTrainer(Trainer):

def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator):
super(SSTTrainer, self).__init__(model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator)
self.early_stop = False
self.best_dev_acc = 0
self.iterations = 0
self.iters_not_improved = 0
self.start = None
self.log_template = ' '.join(
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
self.dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
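Note: these format templates are dense; for reference, a standalone illustration of what one training-progress row renders to (all values made up):

```python
# Recreate log_template outside the class and format a fake row.
log_template = ' '.join(
    '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
# Arguments: time, epoch, iteration, batch, n_batches, percent, loss,
# (blank dev-loss column), train accuracy, (blank dev-accuracy column).
print(log_template.format(12.3, 1, 101, 5, 120, 4, 0.693147, ' ' * 8, 35.4321, ' ' * 12))
```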

def train_epoch(self, epoch):
self.train_loader.init_epoch()
n_correct, n_total = 0, 0
for batch_idx, batch in enumerate(self.train_loader):
self.iterations += 1
self.model.train()
self.optimizer.zero_grad()
scores = self.model(batch)
n_correct += (torch.max(scores, 1)[1].view(batch.label.size()).data == batch.label.data).sum().item()
n_total += batch.batch_size
train_acc = 100. * n_correct / n_total

loss = F.cross_entropy(scores, batch.label)
loss.backward()

self.optimizer.step()

# Evaluate performance on validation set
if self.iterations % self.dev_log_interval == 1:
dev_acc, dev_loss = self.dev_evaluator.get_scores()[0]
print(self.dev_log_template.format(time.time() - self.start,
epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
100. * (1 + batch_idx) / len(self.train_loader), loss.item(),
dev_loss, train_acc, dev_acc))


# Update validation results
if dev_acc > self.best_dev_acc:
self.iters_not_improved = 0
self.best_dev_acc = dev_acc
snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, self.model.mode + '_best_model.pt')
torch.save(self.model, snapshot_path)
else:
self.iters_not_improved += 1
if self.iters_not_improved >= self.patience:
self.early_stop = True
break

if self.iterations % self.log_interval == 1:
# print progress message
print(self.log_template.format(time.time() - self.start,
epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
100. * (1 + batch_idx) / len(self.train_loader), loss.item(), ' ' * 8,
train_acc, ' ' * 12))

def train(self, epochs):
self.start = time.time()
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
# model_outfile is actually a directory, using model_outfile to conform to Trainer naming convention
os.makedirs(self.model_outfile, exist_ok=True)
os.makedirs(os.path.join(self.model_outfile, self.train_loader.dataset.NAME), exist_ok=True)
print(header)

for epoch in range(1, epochs + 1):
if self.early_stop:
print("Early Stopping. Epoch: {}, Best Dev Acc: {}".format(epoch, self.best_dev_acc))
break
self.train_epoch(epoch)
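Note: a hedged sketch of how this trainer might be instantiated. Every key below is one the refactored `Trainer.__init__` reads (see `common/trainers/trainer.py` in this diff), but the concrete values, and the `model`/iterator/evaluator objects, are illustrative assumptions only:

```python
import torch

trainer_config = {
    'optimizer': torch.optim.Adadelta(model.parameters(), lr=0.2532),  # model assumed in scope
    'batch_size': 64,
    'log_interval': 100,               # iterations between training-progress rows
    'dev_log_interval': 100,           # iterations between dev evaluations
    'model_outfile': 'kim_cnn/saves',  # a directory, despite the name
    'lr_reduce_factor': 0.5,
    'patience': 50,                    # dev evaluations without improvement before stopping
    'tensorboard': False,
    'logger': None,                    # Trainer.evaluate now tolerates a missing logger
}
trainer = SSTTrainer(model, embedding, train_iter, trainer_config,
                     train_evaluator, test_evaluator, dev_evaluator)
trainer.train(epochs=30)
```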
24 changes: 13 additions & 11 deletions common/trainers/trainer.py
@@ -7,30 +7,32 @@ class Trainer(object):
def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator=None):
self.model = model
self.embedding = embedding
        self.optimizer = trainer_config.get('optimizer')
self.train_loader = train_loader
self.batch_size = trainer_config.get('batch_size')
self.log_interval = trainer_config.get('log_interval')
self.dev_log_interval = trainer_config.get('dev_log_interval')
self.model_outfile = trainer_config.get('model_outfile')
self.lr_reduce_factor = trainer_config.get('lr_reduce_factor')
self.patience = trainer_config.get('patience')
self.use_tensorboard = trainer_config.get('tensorboard')
self.clip_norm = trainer_config.get('clip_norm')

if self.use_tensorboard:
from tensorboardX import SummaryWriter
self.writer = SummaryWriter(log_dir=None, comment='' if trainer_config['run_label'] is None else trainer_config['run_label'])
        self.logger = trainer_config.get('logger')

self.train_evaluator = train_evaluator
self.test_evaluator = test_evaluator
self.dev_evaluator = dev_evaluator

def evaluate(self, evaluator, dataset_name):
scores, metric_names = evaluator.get_scores()
if self.logger is not None:
self.logger.info('Evaluation metrics for {}:'.format(dataset_name))
self.logger.info('\t'.join([' '] + metric_names))
self.logger.info('\t'.join([dataset_name] + list(map(str, scores))))
return scores

    def get_sentence_embeddings(self, batch):
        ...  # body unchanged; collapsed in the diff view
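Note: the switch from `trainer_config['key']` to `trainer_config.get('key')` means a trainer whose config omits a key (for example `dev_log_interval`, which only the SST trainer needs) now receives `None` instead of raising `KeyError`. A minimal illustration:

```python
trainer_config = {'batch_size': 64}  # no 'dev_log_interval' key

# trainer_config['dev_log_interval'] would raise KeyError here;
# dict.get returns None, or an explicit fallback if one is supplied.
assert trainer_config.get('dev_log_interval') is None
assert trainer_config.get('dev_log_interval', 100) == 100
```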
45 changes: 44 additions & 1 deletion datasets/sst.py
@@ -16,7 +16,7 @@ def clean_str_sst(string):


class SST1(TabularDataset):
    NAME = 'SST-1'
NUM_CLASSES = 5

TEXT_FIELD = Field(batch_first=True, tokenize=clean_str_sst)
@@ -55,3 +55,46 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,

return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)

class SST2(TabularDataset):
NAME = 'SST-2'
    NUM_CLASSES = 2  # the stsa.binary.* files are binary sentiment, not five-class

TEXT_FIELD = Field(batch_first=True, tokenize=clean_str_sst)
LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)

@staticmethod
def sort_key(ex):
return len(ex.text)

@classmethod
def splits(cls, path, train='stsa.binary.phrases.train', validation='stsa.binary.dev', test='stsa.binary.test', **kwargs):
return super(SST2, cls).splits(
path, train=train, validation=validation, test=test,
format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
)

@classmethod
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
unk_init=torch.Tensor.zero_):
"""
:param path: directory containing train, test, dev files
:param vectors_name: name of word vectors file
:param vectors_cache: path to directory containing word vectors file
:param batch_size: batch size
:param device: GPU device
:param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
:param unk_init: function used to generate vector for OOV words
:return:
"""
if vectors is None:
vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

train, val, test = cls.splits(path)

cls.TEXT_FIELD.build_vocab(train, val, test, min_freq=2, vectors=vectors)

return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)
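Note: a hedged usage sketch for the new class; the paths and the word-vectors file name below are assumptions, not files shipped with the repo:

```python
# Builds the vocab over train/dev/test and returns three BucketIterators.
train_iter, dev_iter, test_iter = SST2.iters(
    'data/sst', 'word_vectors.txt', 'data/embeddings',
    batch_size=64, device=0)
```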


78 changes: 38 additions & 40 deletions kim_cnn/README.md
@@ -10,56 +10,35 @@ Implementation for Convolutional Neural Networks for Sentence Classification of
- multichannel: A model with two sets of word vectors. Each set of vectors is treated as a 'channel' and each filter is applied to both channels, but gradients are back-propagated only through one of the channels. Hence the model is able to fine-tune one set of vectors while keeping the other static. Both channels are initialized with word2vec.

## Requirements

Assuming you already have PyTorch, just install torchtext (`pip install torchtext==0.2.0`)

## Quick Start

To get the dataset, run:
```
cd kim_cnn
bash get_data.sh
```

To run the model on the SST-1 dataset in multichannel mode, run the following from the Castor working directory:

```
python -m kim_cnn --mode multichannel
```

The best model will be saved in

```
kim_cnn/saves/SST-1/multichannel_best_model.pt
```

To test the model, use the following command:

```
python -m kim_cnn --trained_model kim_cnn/saves/SST-1/multichannel_best_model.pt --mode multichannel
```

## Dataset and Embeddings

We experiment with the model on the following datasets:

- SST-1: keep the original splits, train on the phrase-level dataset, and test on the sentence-level dataset.


## Settings

Adadelta is used for training.

## Training Time

@@ -78,21 +57,40 @@ torch.backends.cudnn.enabled = False
```
but this will take ~6-7x the training time.

## SST-1 Dataset Results

**Random**

```
python -m kim_cnn --mode rand --lr 0.8337 --weight_decay 0.0008987 --dropout 0.4
```

**Static**

```
python -m kim_cnn --mode static --lr 0.8641 --weight_decay 1.44e-05 --dropout 0.3
```

**Non-static**

```
python -m kim_cnn --mode non-static --lr 0.371 --weight_decay 1.84e-05 --dropout 0.4
```

**Multichannel**

```
python -m kim_cnn --mode multichannel --lr 0.2532 --weight_decay 3.95e-05 --dropout 0.1
```

Using the deterministic algorithm for cuDNN.

| Test Accuracy on SST-1 | rand | static | non-static | multichannel |
|:------------------------------:|:----------:|:------------:|:--------------:|:---------------:|
| Paper | 45.0 | 45.5 | 48.0 | 47.4 |
| PyTorch using above configs | 41.5 | 44.7 | 47.4 | 47.5 |

## TODO

- More experiments on SST-2 and subjectivity
- Parameters tuning

15 changes: 0 additions & 15 deletions kim_cnn/SST1.py

This file was deleted.