Kim CNN OOP Refactoring (#124)
* Nuke obsolete artifacts

* Refactor Kim CNN

* Make kim_cnn a module

* Fix bugs

* Update README

* Add choices to dataset arg

* Update sst.py for SST-2

* Add Kim CNN dataset choices to args.py

* Update tuned SST-1 accuracy
tuzhucheng committed Jul 3, 2018
1 parent fae229e commit 8563ad5
Showing 15 changed files with 347 additions and 301 deletions.
4 changes: 4 additions & 0 deletions common/evaluation.py
@@ -1,17 +1,21 @@
from .evaluators.sick_evaluator import SICKEvaluator
from .evaluators.msrvid_evaluator import MSRVIDEvaluator
from .evaluators.sst_evaluator import SSTEvaluator
from .evaluators.trecqa_evaluator import TRECQAEvaluator
from .evaluators.wikiqa_evaluator import WikiQAEvaluator
from nce.nce_pairwise_mp.evaluators.trecqa_evaluator import TRECQAEvaluatorNCE
from nce.nce_pairwise_mp.evaluators.wikiqa_evaluator import WikiQAEvaluatorNCE


class EvaluatorFactory(object):
    """
    Get the corresponding Evaluator class for a particular dataset.
    """
    evaluator_map = {
        'sick': SICKEvaluator,
        'msrvid': MSRVIDEvaluator,
        'SST-1': SSTEvaluator,
        'SST-2': SSTEvaluator,
        'trecqa': TRECQAEvaluator,
        'wikiqa': WikiQAEvaluator
    }
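The map above is how the rest of the codebase resolves a dataset name to its evaluator class. A minimal lookup sketch — the factory's own accessor method lives in the elided part of the class, so the direct dict access below is illustrative only:

```python
from common.evaluation import EvaluatorFactory

# Illustrative: both SST variants resolve to the same evaluator class.
evaluator_cls = EvaluatorFactory.evaluator_map['SST-2']
print(evaluator_cls.__name__)  # SSTEvaluator
```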
24 changes: 24 additions & 0 deletions common/evaluators/sst_evaluator.py
@@ -0,0 +1,24 @@
import torch
import torch.nn.functional as F

from .evaluator import Evaluator


class SSTEvaluator(Evaluator):

    def get_scores(self):
        self.model.eval()
        self.data_loader.init_epoch()
        n_dev_correct = 0
        total_loss = 0

        for batch_idx, batch in enumerate(self.data_loader):
            scores = self.model(batch)
            n_dev_correct += (
                torch.max(scores, 1)[1].view(batch.label.size()).data == batch.label.data).sum().item()
            total_loss += F.cross_entropy(scores, batch.label, size_average=False).item()

        accuracy = 100. * n_dev_correct / len(self.data_loader.dataset.examples)
        avg_loss = total_loss / len(self.data_loader.dataset.examples)

        return [accuracy, avg_loss], ['accuracy', 'cross_entropy_loss']
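The correct-count line above hinges on `torch.max(scores, 1)[1]`, which returns the argmax class index per row. A self-contained sketch of the idiom with dummy tensors:

```python
import torch

# Three examples, five classes (SST-1): each row holds unnormalized scores.
scores = torch.tensor([[0.1, 2.0, 0.3, 0.0, 0.0],
                       [1.5, 0.2, 0.1, 0.0, 0.0],
                       [0.0, 0.1, 0.2, 3.0, 0.0]])
labels = torch.tensor([1, 0, 2])

predictions = torch.max(scores, 1)[1]  # index of the max score per row -> tensor([1, 0, 3])
n_correct = (predictions == labels).sum().item()
print(n_correct)  # 2 of 3 predictions match the labels
```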
3 changes: 3 additions & 0 deletions common/train.py
@@ -2,6 +2,7 @@
from .trainers.msrvid_trainer import MSRVIDTrainer
from .trainers.trecqa_trainer import TRECQATrainer
from .trainers.wikiqa_trainer import WikiQATrainer
from .trainers.sst_trainer import SSTTrainer
from nce.nce_pairwise_mp.trainers.trecqa_trainer import TRECQATrainerNCE
from nce.nce_pairwise_mp.trainers.wikiqa_trainer import WikiQATrainerNCE

@@ -13,6 +14,8 @@ class TrainerFactory(object):
    trainer_map = {
        'sick': SICKTrainer,
        'msrvid': MSRVIDTrainer,
        'SST-1': SSTTrainer,
        'SST-2': SSTTrainer,
        'trecqa': TRECQATrainer,
        'wikiqa': WikiQATrainer
    }
81 changes: 81 additions & 0 deletions common/trainers/sst_trainer.py
@@ -0,0 +1,81 @@
import os
import time

import torch
import torch.nn.functional as F

from .trainer import Trainer
from utils.serialization import save_checkpoint


class SSTTrainer(Trainer):

    def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator):
        super(SSTTrainer, self).__init__(model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator)
        self.early_stop = False
        self.best_dev_acc = 0
        self.iterations = 0
        self.iters_not_improved = 0
        self.start = None
        self.log_template = ' '.join(
            '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
        self.dev_log_template = ' '.join(
            '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))

    def train_epoch(self, epoch):
        self.train_loader.init_epoch()
        n_correct, n_total = 0, 0
        for batch_idx, batch in enumerate(self.train_loader):
            self.iterations += 1
            self.model.train()
            self.optimizer.zero_grad()
            scores = self.model(batch)
            n_correct += (torch.max(scores, 1)[1].view(batch.label.size()).data == batch.label.data).sum().item()
            n_total += batch.batch_size
            train_acc = 100. * n_correct / n_total

            loss = F.cross_entropy(scores, batch.label)
            loss.backward()

            self.optimizer.step()

            # Evaluate performance on validation set
            if self.iterations % self.dev_log_interval == 1:
                dev_acc, dev_loss = self.dev_evaluator.get_scores()[0]
                print(self.dev_log_template.format(time.time() - self.start,
                                                   epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
                                                   100. * (1 + batch_idx) / len(self.train_loader), loss.item(),
                                                   dev_loss, train_acc, dev_acc))

                # Update validation results
                if dev_acc > self.best_dev_acc:
                    self.iters_not_improved = 0
                    self.best_dev_acc = dev_acc
                    snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, self.model.mode + '_best_model.pt')
                    torch.save(self.model, snapshot_path)
                else:
                    self.iters_not_improved += 1
                    if self.iters_not_improved >= self.patience:
                        self.early_stop = True
                        break

            if self.iterations % self.log_interval == 1:
                # print progress message
                print(self.log_template.format(time.time() - self.start,
                                               epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
                                               100. * (1 + batch_idx) / len(self.train_loader), loss.item(), ' ' * 8,
                                               train_acc, ' ' * 12))

    def train(self, epochs):
        self.start = time.time()
        header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
        # model_outfile is actually a directory, using model_outfile to conform to Trainer naming convention
        os.makedirs(self.model_outfile, exist_ok=True)
        os.makedirs(os.path.join(self.model_outfile, self.train_loader.dataset.NAME), exist_ok=True)
        print(header)

        for epoch in range(1, epochs + 1):
            if self.early_stop:
                print("Early Stopping. Epoch: {}, Best Dev Acc: {}".format(epoch, self.best_dev_acc))
                break
            self.train_epoch(epoch)
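SSTTrainer reads all of its knobs from the `trainer_config` dict consumed by the base `Trainer` below. A hedged sketch of a config that satisfies the keys used here — the values are illustrative, not the project's tuned defaults, and the stand-in model is hypothetical:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(300, 5)  # hypothetical stand-in for the Kim CNN model

# Illustrative values only; the real defaults come from kim_cnn's arg parser.
trainer_config = {
    'optimizer': optim.Adadelta(model.parameters(), lr=0.95),
    'batch_size': 64,
    'log_interval': 10,                # iterations between training-progress rows
    'dev_log_interval': 100,           # iterations between dev-set evaluations
    'model_outfile': 'kim_cnn/saves',  # a directory, despite the name
    'lr_reduce_factor': 0.3,
    'patience': 50,                    # dev checks without improvement before early stop
    'tensorboard': False,
    'logger': None,                    # evaluate() skips logging when this is None
}
```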
24 changes: 13 additions & 11 deletions common/trainers/trainer.py
@@ -7,30 +7,32 @@ class Trainer(object):
    def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator=None):
        self.model = model
        self.embedding = embedding
        self.optimizer = trainer_config.get('optimizer')
        self.train_loader = train_loader
        self.batch_size = trainer_config.get('batch_size')
        self.log_interval = trainer_config.get('log_interval')
        self.dev_log_interval = trainer_config.get('dev_log_interval')
        self.model_outfile = trainer_config.get('model_outfile')
        self.lr_reduce_factor = trainer_config.get('lr_reduce_factor')
        self.patience = trainer_config.get('patience')
        self.use_tensorboard = trainer_config.get('tensorboard')
        self.clip_norm = trainer_config.get('clip_norm')

        if self.use_tensorboard:
            from tensorboardX import SummaryWriter
            self.writer = SummaryWriter(log_dir=None, comment='' if trainer_config['run_label'] is None else trainer_config['run_label'])
        self.logger = trainer_config.get('logger')

        self.train_evaluator = train_evaluator
        self.test_evaluator = test_evaluator
        self.dev_evaluator = dev_evaluator

    def evaluate(self, evaluator, dataset_name):
        scores, metric_names = evaluator.get_scores()
        if self.logger is not None:
            self.logger.info('Evaluation metrics for {}:'.format(dataset_name))
            self.logger.info('\t'.join([' '] + metric_names))
            self.logger.info('\t'.join([dataset_name] + list(map(str, scores))))
        return scores

    def get_sentence_embeddings(self, batch):
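Why the switch from `trainer_config['key']` to `trainer_config.get('key')`: `.get` returns `None` for a missing key instead of raising `KeyError`, so dataset-specific trainers such as SSTTrainer can omit settings they never use. A two-line illustration:

```python
config = {'batch_size': 64}
clip_norm = config.get('clip_norm')  # None: gradient clipping simply stays off
# config['clip_norm'] would raise KeyError at construction time instead
```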
45 changes: 44 additions & 1 deletion datasets/sst.py
@@ -16,7 +16,7 @@ def clean_str_sst(string):


class SST1(TabularDataset):
    NAME = 'SST-1'
    NUM_CLASSES = 5

    TEXT_FIELD = Field(batch_first=True, tokenize=clean_str_sst)

@@ -55,3 +55,46 @@

        return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
                                     sort_within_batch=True, device=device)

class SST2(TabularDataset):
    NAME = 'SST-2'
    NUM_CLASSES = 2  # binary sentiment task (the stsa.binary files), not 5

    TEXT_FIELD = Field(batch_first=True, tokenize=clean_str_sst)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    @classmethod
    def splits(cls, path, train='stsa.binary.phrases.train', validation='stsa.binary.dev', test='stsa.binary.test', **kwargs):
        return super(SST2, cls).splits(
            path, train=train, validation=validation, test=test,
            format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
        )

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
              unk_init=torch.Tensor.zero_):
        """
        :param path: directory containing train, test, dev files
        :param vectors_name: name of word vectors file
        :param vectors_cache: path to directory containing word vectors file
        :param batch_size: batch size
        :param device: GPU device
        :param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
        :param unk_init: function used to generate vector for OOV words
        :return:
        """
        if vectors is None:
            vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

        train, val, test = cls.splits(path)

        cls.TEXT_FIELD.build_vocab(train, val, test, min_freq=2, vectors=vectors)

        return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
                                     sort_within_batch=True, device=device)
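A hedged usage sketch for the iterator factory above. The data directory, vector file name, and cache path are assumptions about the local setup (they depend on where get_data.sh put things), not fixed project paths:

```python
import torch

from datasets.sst import SST2

# Paths below are illustrative; point them at your own data and vectors.
train_iter, dev_iter, test_iter = SST2.iters(
    'data/sst',                # directory holding the stsa.binary.* files
    'word2vec.sst-1.pt',       # word vectors file name
    'data/word_vectors',       # cache directory for the vectors
    batch_size=64,
    device=0 if torch.cuda.is_available() else -1)  # -1 selects CPU in torchtext 0.2

batch = next(iter(train_iter))
print(batch.text.size(), batch.label.size())
```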


78 changes: 38 additions & 40 deletions kim_cnn/README.md
@@ -10,56 +10,35 @@

- multichannel: A model with two sets of word vectors. Each set of vectors is treated as a 'channel' and each filter is applied to both channels, but gradients are back-propagated only through one of the channels. Hence the model is able to fine-tune one set of vectors while keeping the other static. Both channels are initialized with word2vec.

## Requirements

Assuming you already have PyTorch, just install torchtext (`pip install torchtext==0.2.0`).

## Quick Start

To get the dataset, run:
```
cd kim_cnn
bash get_data.sh
```

To run the model on the SST-1 dataset in multichannel mode, run the following from the Castor working directory.

```
python -m kim_cnn --mode multichannel
```

The best model will be saved in

```
kim_cnn/saves/SST-1/multichannel_best_model.pt
```

To test the model, you can use the following command.

```
python -m kim_cnn --trained_model kim_cnn/saves/SST-1/multichannel_best_model.pt --mode multichannel
```

## Dataset and Embeddings

We evaluate the model on the following datasets.

- SST-1: keep the original splits, train on the phrase-level dataset, and test on the sentence-level dataset.

**word2vec.sst-1.pt** is a subset of word2vec: we keep only the words that appear in the SST-1 dataset and generate the file with **vector_preprocess.py** (available after you run get_data.sh, or downloadable [here](https://raw.githubusercontent.com/Impavidity/kim_cnn/master/vector_preprocess.py)). You can build such a file from any word-embedding text file in the following format:
```
word vector_in_one_line
```
and then run
```
python vector_preprocess.py file_in embed.pt
```
This produces *embed.pt*, the embedding file. Remember to change the argument in *args.py* to point at your own embedding.

## Settings

Adadelta is used for training.

## Training Time

@@ -78,21 +57,40 @@

```
torch.backends.cudnn.enabled = False
```

but this will take ~6-7x training time.
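Disabling cuDNN outright is one route to reproducibility. A gentler sketch, assuming PyTorch of this era, keeps cuDNN enabled but pins its algorithm choice along with the RNG seeds (the seed value is arbitrary; the trainer above does not do this itself):

```python
import random

import torch

# Pin every RNG the training loop touches.
random.seed(3435)
torch.manual_seed(3435)
if torch.cuda.is_available():
    torch.cuda.manual_seed(3435)

# Keep cuDNN on but force deterministic kernels: slower than the default,
# yet far cheaper than torch.backends.cudnn.enabled = False.
torch.backends.cudnn.deterministic = True
```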

## SST-1 Dataset Results

**Random**

```
python -m kim_cnn --mode rand --lr 0.8337 --weight_decay 0.0008987 --dropout 0.4
```

**Static**

```
python -m kim_cnn --mode static --lr 0.8641 --weight_decay 1.44e-05 --dropout 0.3
```

**Non-static**

```
python -m kim_cnn --mode non-static --lr 0.371 --weight_decay 1.84e-05 --dropout 0.4
```

**Multichannel**

```
python -m kim_cnn --mode multichannel --lr 0.2532 --weight_decay 3.95e-05 --dropout 0.1
```

Using the deterministic algorithm for cuDNN.

| Test Accuracy on SST-1 | rand | static | non-static | multichannel |
|:------------------------------:|:----------:|:------------:|:--------------:|:---------------:|
| Paper | 45.0 | 45.5 | 48.0 | 47.4 |
| PyTorch using above configs | 41.5 | 44.7 | 47.4 | 47.5 |

## TODO

- More experiments on SST-2 and subjectivity
- Parameter tuning

15 changes: 0 additions & 15 deletions kim_cnn/SST1.py

This file was deleted.

