
Merge pull request #120 from zalandoresearch/GH-48-word-dropout
GH-48: Add word dropout to text classifier
Alan Akbik committed Sep 21, 2018
2 parents 7b5449a + 64da98e commit 3643779
Showing 2 changed files with 21 additions and 6 deletions.
flair/data_fetcher.py (8 changes: 6 additions & 2 deletions)
@@ -313,14 +313,16 @@ def read_conll_ud(path_to_conll_file: str) -> List[Sentence]:
         return sentences

     @staticmethod
-    def read_text_classification_file(path_to_file):
+    def read_text_classification_file(path_to_file, max_tokens_per_doc=-1):
         """
         Reads a data file for text classification. The file should contain one document/text per line.
         The line should have the following format:
         __label__<class_name> <text>
         If you have a multi-class task, you can have as many labels as you want at the beginning of the line, e.g.,
         __label__<class_name_1> __label__<class_name_2> <text>
         :param path_to_file: the path to the data file
+        :param max_tokens_per_doc: Takes only documents that contain at most this number of tokens. If set to -1,
+        all documents are taken.
         :return: list of sentences
         """
         label_prefix = '__label__'
@@ -346,7 +348,9 @@ def read_text_classification_file(path_to_file):
             text = line[l_len:].strip()

             if text and labels:
-                sentences.append(Sentence(text, labels=labels, use_tokenizer=True))
+                sentence = Sentence(text, labels=labels, use_tokenizer=True)
+                if max_tokens_per_doc == -1 or len(sentence.tokens) <= max_tokens_per_doc:
+                    sentences.append(sentence)

         return sentences

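As a usage sketch for the new parameter (assuming the method is called through its enclosing fetcher class, NLPTaskDataFetcher in flair/data_fetcher.py, and a hypothetical train.txt in the documented format):

    # train.txt contains one document per line, labels first:
    #   __label__positive a wonderful little film
    #   __label__negative __label__too_long the plot drags on forever

    from flair.data_fetcher import NLPTaskDataFetcher

    # keep only documents with at most 256 tokens; the default of -1 keeps all
    sentences = NLPTaskDataFetcher.read_text_classification_file(
        'train.txt', max_tokens_per_doc=256)

Note that over-length documents are skipped entirely rather than truncated.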
flair/embeddings.py (19 changes: 15 additions & 4 deletions)
@@ -1,5 +1,4 @@
 import os
-import pickle
 import re
 from abc import abstractmethod
 from typing import List, Union
@@ -8,7 +7,8 @@
 import numpy as np
 import torch

-from .data import Dictionary, Token, Sentence, TaggedCorpus
+from .nn import LockedDropout, WordDropout
+from .data import Dictionary, Token, Sentence
 from .file_utils import cached_path
@@ -491,7 +491,7 @@ class DocumentLSTMEmbeddings(DocumentEmbeddings):

     def __init__(self, token_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
                  reproject_words: bool = True, reproject_words_dimension: int = None, bidirectional: bool = False,
-                 use_first_representation: bool = False):
+                 use_first_representation: bool = False, use_word_dropout: bool = True):
         """The constructor takes a list of embeddings to be combined.
         :param token_embeddings: a list of token embeddings
         :param hidden_states: the number of hidden states in the lstm
@@ -503,6 +503,7 @@ def __init__(self, token_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
         :param bidirectional: boolean value, indicating whether to use a bidirectional lstm or not
         :param use_first_representation: boolean value, indicating whether to concatenate the first and last
         representation of the lstm to be used as final document embedding.
+        :param use_word_dropout: boolean value, indicating whether to use word dropout or not.
         """
         super().__init__()

@@ -534,7 +535,13 @@ def __init__(self, token_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
                                                      self.embeddings_dimension)
         self.rnn = torch.nn.GRU(self.embeddings_dimension, hidden_states, num_layers=num_layers,
                                 bidirectional=self.bidirectional)
-        self.dropout = torch.nn.Dropout(0.5)
+
+        # dropouts
+        self.dropout: torch.nn.Module = LockedDropout(0.5)
+
+        self.use_word_dropout: bool = use_word_dropout
+        if self.use_word_dropout:
+            self.word_dropout = WordDropout(0.05)

         torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)
@@ -598,6 +605,10 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
         # --------------------------------------------------------------------
         # FF PART
         # --------------------------------------------------------------------
+        # use word dropout if set
+        if self.use_word_dropout:
+            sentence_tensor = self.word_dropout(sentence_tensor)
+
         if self.reproject_words:
             sentence_tensor = self.word_reprojection_map(sentence_tensor)
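Unlike standard dropout, which zeroes individual units, word dropout zeroes a token's entire embedding, so during training a word is either fully present or entirely absent; LockedDropout, also imported above, instead samples one dropout mask and reuses it across all time steps. A minimal sketch of the idea behind WordDropout, assuming an input of shape (sequence_length, batch_size, embedding_dim); the actual module in flair/nn.py may differ in detail:

    import torch

    class WordDropout(torch.nn.Module):
        # With probability dropout_rate, zero out a token's whole embedding
        # vector; active only in training mode, like other dropout layers.
        def __init__(self, dropout_rate: float = 0.05):
            super().__init__()
            self.dropout_rate = dropout_rate

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not self.training or self.dropout_rate <= 0.0:
                return x
            # one Bernoulli keep/drop draw per token, broadcast over the embedding dim
            mask = x.new_empty(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)
            return x * mask

Because use_word_dropout defaults to True, existing DocumentLSTMEmbeddings users get this regularization automatically; passing use_word_dropout=False restores the old behavior, except that the switch from torch.nn.Dropout to LockedDropout is unconditional.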
