
Merge pull request #4 from achyudh/master
Make is_multilabel an attribute of Dataset
achyudh committed Mar 19, 2019
2 parents 5c146fc + ecda4c0 commit 5e52615
Showing 21 changed files with 130 additions and 357 deletions.
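The substance of this change: the multilabel flag moves from hard-coded trainer/evaluator variants (SSTTrainer vs. ReutersTrainer, and likewise for the evaluators) to an IS_MULTILABEL class attribute on each Dataset, so a single ClassificationTrainer/ClassificationEvaluator pair can branch on the flag at runtime. A minimal usage sketch, assuming the calling script (not shown in this diff) builds the trainer config from the dataset class:

    from datasets.reuters import Reuters

    # NAME, NUM_CLASSES and IS_MULTILABEL are class attributes set in this diff.
    # Hypothetical wiring: forward the flag into the config dict that
    # ClassificationTrainer later checks via self.config['is_multilabel'].
    trainer_config = {'is_multilabel': Reuters.IS_MULTILABEL}  # True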
13 changes: 5 additions & 8 deletions common/evaluation.py → common/evaluate.py
@@ -1,18 +1,15 @@
-from .evaluators.sst_evaluator import SSTEvaluator
-from .evaluators.reuters_evaluator import ReutersEvaluator
+from common.evaluators.classification_evaluator import ClassificationEvaluator


 class EvaluatorFactory(object):
     """
     Get the corresponding Evaluator class for a particular dataset.
     """
     evaluator_map = {
-        'SST-1': SSTEvaluator,
-        'SST-2': SSTEvaluator,
-        'Reuters': ReutersEvaluator,
-        'AAPD': ReutersEvaluator,
-        'IMDB': ReutersEvaluator,
-        'Yelp2014': ReutersEvaluator
+        'Reuters': ClassificationEvaluator,
+        'AAPD': ClassificationEvaluator,
+        'IMDB': ClassificationEvaluator,
+        'Yelp2014': ClassificationEvaluator
     }

     @staticmethod
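The factory's @staticmethod body is truncated in this view. As a hedged sketch only, a dispatch over evaluator_map might look as follows; the method name, signature, and error handling are illustrative assumptions, not taken from this diff:

    @staticmethod
    def get_evaluator(dataset_cls, model, embedding, data_loader, batch_size, device):
        # Hypothetical: dispatch on the dataset's NAME class attribute.
        if dataset_cls.NAME not in EvaluatorFactory.evaluator_map:
            raise ValueError('{} is not implemented.'.format(dataset_cls.NAME))
        return EvaluatorFactory.evaluator_map[dataset_cls.NAME](
            dataset_cls, model, embedding, data_loader, batch_size, device)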
common/evaluators/reuters_evaluator.py → common/evaluators/classification_evaluator.py
@@ -6,17 +6,16 @@
 from .evaluator import Evaluator


-class ReutersEvaluator(Evaluator):
+class ClassificationEvaluator(Evaluator):

     def __init__(self, dataset_cls, model, embedding, data_loader, batch_size, device, keep_results=False):
         super().__init__(dataset_cls, model, embedding, data_loader, batch_size, device, keep_results)
         self.ignore_lengths = False
-        self.single_label = False
+        self.is_multilabel = False

     def get_scores(self):
         self.model.eval()
         self.data_loader.init_epoch()
-        n_dev_correct = 0
         total_loss = 0

         # Temp Ave
@@ -37,15 +36,15 @@ def get_scores(self):
             else:
                 scores = self.model(batch.text[0], lengths=batch.text[1])

-            if self.single_label:
-                predicted_labels.extend(torch.argmax(scores, dim=1).cpu().detach().numpy())
-                target_labels.extend(torch.argmax(batch.label, dim=1).cpu().detach().numpy())
-                total_loss += F.cross_entropy(scores, torch.argmax(batch.label, dim=1), size_average=False).item()
-            else:
+            if self.is_multilabel:
                 scores_rounded = F.sigmoid(scores).round().long()
                 predicted_labels.extend(scores_rounded.cpu().detach().numpy())
                 target_labels.extend(batch.label.cpu().detach().numpy())
                 total_loss += F.binary_cross_entropy_with_logits(scores, batch.label.float(), size_average=False).item()
+            else:
+                predicted_labels.extend(torch.argmax(scores, dim=1).cpu().detach().numpy())
+                target_labels.extend(torch.argmax(batch.label, dim=1).cpu().detach().numpy())
+                total_loss += F.cross_entropy(scores, torch.argmax(batch.label, dim=1), size_average=False).item()

             if hasattr(self.model, 'TAR') and self.model.TAR:  # TAR condition
                 total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()
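For clarity, a self-contained sketch (toy tensors, not from the diff) of the two prediction paths get_scores now switches between; note that F.sigmoid and size_average=False used above are deprecated in later PyTorch releases in favor of torch.sigmoid and reduction='sum':

    import torch

    scores = torch.tensor([[2.0, -1.0, 0.5]])  # raw logits: one example, three classes

    # is_multilabel path: an independent sigmoid per class, thresholded at 0.5
    multilabel_pred = torch.sigmoid(scores).round().long()  # tensor([[1, 0, 1]])

    # single-label path: one argmax across classes
    single_label_pred = torch.argmax(scores, dim=1)  # tensor([0])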
24 changes: 0 additions & 24 deletions common/evaluators/sst_evaluator.py

This file was deleted.

13 changes: 5 additions & 8 deletions common/train.py
@@ -1,18 +1,15 @@
-from .trainers.sst_trainer import SSTTrainer
-from .trainers.reuters_trainer import ReutersTrainer
+from common.trainers.classification_trainer import ClassificationTrainer


 class TrainerFactory(object):
     """
     Get the corresponding Trainer class for a particular dataset.
     """
     trainer_map = {
-        'SST-1': SSTTrainer,
-        'SST-2': SSTTrainer,
-        'Reuters': ReutersTrainer,
-        'AAPD': ReutersTrainer,
-        'IMDB': ReutersTrainer,
-        'Yelp2014': ReutersTrainer
+        'Reuters': ClassificationTrainer,
+        'AAPD': ClassificationTrainer,
+        'IMDB': ClassificationTrainer,
+        'Yelp2014': ClassificationTrainer
     }

     @staticmethod
common/trainers/reuters_trainer.py → common/trainers/classification_trainer.py
@@ -10,10 +10,10 @@
 from .trainer import Trainer


-class ReutersTrainer(Trainer):
+class ClassificationTrainer(Trainer):

     def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator):
-        super(ReutersTrainer, self).__init__(model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator)
+        super().__init__(model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator)
         self.config = trainer_config
         self.early_stop = False
         self.best_dev_f1 = 0
@@ -22,7 +22,8 @@ def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator):
         self.start = None
         self.log_template = ' '.join(
             '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:12.4f}'.split(','))
-        self.dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.4f},{:>8.4f},{:8.4f},{:12.4f},{:12.4f}'.split(','))
+        self.dev_log_template = ' '.join(
+            '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.4f},{:>8.4f},{:8.4f},{:12.4f},{:12.4f}'.split(','))
         self.writer = SummaryWriter(log_dir="tensorboard_logs/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
         self.snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, 'best_model.pt')

@@ -44,18 +45,17 @@ def train_epoch(self, epoch):
             else:
                 scores = self.model(batch.text[0], lengths=batch.text[1])

-            if 'single_label' in self.config and self.config['single_label']:
-                for tensor1, tensor2 in zip(torch.argmax(scores, dim=1), torch.argmax(batch.label.data, dim=1)):
-                    if np.array_equal(tensor1, tensor2):
-                        n_correct += 1
-                loss = F.cross_entropy(scores, torch.argmax(batch.label.data, dim=1))
-            else:
+            if 'is_multilabel' in self.config and self.config['is_multilabel']:
                 predictions = F.sigmoid(scores).round().long()
                 # Computing binary accuracy
                 for tensor1, tensor2 in zip(predictions, batch.label):
                     if np.array_equal(tensor1, tensor2):
                         n_correct += 1
                 loss = F.binary_cross_entropy_with_logits(scores, batch.label.float())
+            else:
+                for tensor1, tensor2 in zip(torch.argmax(scores, dim=1), torch.argmax(batch.label.data, dim=1)):
+                    if np.array_equal(tensor1, tensor2):
+                        n_correct += 1
+                loss = F.cross_entropy(scores, torch.argmax(batch.label.data, dim=1))

             if hasattr(self.model, 'TAR') and self.model.TAR:
                 loss = loss + self.model.TAR*(rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()
@@ -75,10 +75,9 @@ def train_epoch(self, epoch):
                 niter = epoch * len(self.train_loader) + batch_idx
                 self.writer.add_scalar('Train/Loss', loss.data.item(), niter)
                 self.writer.add_scalar('Train/Accuracy', train_acc, niter)
-                print(self.log_template.format(time.time() - self.start,
-                    epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
-                    100. * (1 + batch_idx) / len(self.train_loader), loss.item(),
-                    train_acc))
+                print(self.log_template.format(time.time() - self.start, epoch, self.iterations, 1 + batch_idx,
+                                               len(self.train_loader), 100.0 * (1 + batch_idx) / len(self.train_loader),
+                                               loss.item(), train_acc))

     def train(self, epochs):
         self.start = time.time()
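The "# Computing binary accuracy" loop above counts an example as correct only when its whole predicted label vector matches the target exactly, i.e., subset accuracy rather than per-label accuracy. A worked sketch with made-up tensors:

    import numpy as np
    import torch

    predictions = torch.tensor([[1, 0, 1], [0, 1, 0]])
    labels = torch.tensor([[1, 0, 1], [0, 0, 1]])

    n_correct = 0
    for tensor1, tensor2 in zip(predictions, labels):
        if np.array_equal(tensor1, tensor2):  # all labels must match
            n_correct += 1
    # n_correct == 1: only the first example matches on every label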
81 changes: 0 additions & 81 deletions common/trainers/sst_trainer.py

This file was deleted.

2 changes: 2 additions & 0 deletions datasets/aapd.py
@@ -17,6 +17,7 @@ def char_quantize(string, max_length=1000):
     else:
         return np.concatenate((quantized_string, np.zeros((max_length - len(quantized_string), len(AAPDCharQuantized.ALPHABET)), dtype=np.float32)))

+
 def process_labels(string):
     """
     Returns the label string as a list of integers
@@ -29,6 +30,7 @@ def process_labels(string):
 class AAPD(TabularDataset):
     NAME = 'AAPD'
     NUM_CLASSES = 54
+    IS_MULTILABEL = True

     TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
     LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)
2 changes: 2 additions & 0 deletions datasets/imdb.py
@@ -31,6 +31,8 @@ def process_labels(string):
 class IMDB(TabularDataset):
     NAME = 'IMDB'
     NUM_CLASSES = 10
+    IS_MULTILABEL = False
+
     TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
     LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

2 changes: 2 additions & 0 deletions datasets/reuters.py
@@ -54,6 +54,8 @@ def process_labels(string):
 class Reuters(TabularDataset):
     NAME = 'Reuters'
     NUM_CLASSES = 90
+    IS_MULTILABEL = True
+
     TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
     LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

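With the flag promoted to a class attribute, generic code can query it without instantiating a dataset. A hedged sketch, assuming the three classes changed in this diff import as shown:

    from datasets.aapd import AAPD
    from datasets.imdb import IMDB
    from datasets.reuters import Reuters

    for dataset_cls in (AAPD, IMDB, Reuters):
        # Prints the values set in this diff: True, False, True.
        print(dataset_cls.NAME, dataset_cls.IS_MULTILABEL)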
104 changes: 0 additions & 104 deletions datasets/sst.py

This file was deleted.
