Skip to content

Commit

Permalink
add __getitem__ and get_unpadded_sent to Sentence (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
msperber authored and neubig committed Nov 14, 2018
1 parent 3977c4b commit 25a3ed3
Showing 1 changed file with 69 additions and 9 deletions.
78 changes: 69 additions & 9 deletions xnmt/sent.py
Expand Up @@ -20,6 +20,18 @@ def __init__(self, idx: Optional[int] = None, score: Optional[numbers.Real] = No
self.idx = idx
self.score = score

def __getitem__(self, key):
"""
Get an item or a slice of the sentence.
Args:
key: index or slice
Returns:
A single word or a Sentence object, depending on whether an index or a slice was given as key.
"""
raise NotImplementedError("must be implemented by subclasses")

def sent_len(self) -> int:
"""
Return length of input, included padded tokens.
Expand Down Expand Up @@ -57,6 +69,18 @@ def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'Sentence':
"""
raise NotImplementedError("must be implemented by subclasses")

def get_unpadded_sent(self) -> 'Sentence':
"""
Return the unpadded sentence.
If self is unpadded, return self, if not return reference to original unpadded sentence if possible, otherwise
create a new sentence.
"""
if self.sent_len() == self.len_unpadded():
return self
else:
return self[:self.len_unpadded()]

class ReadableSentence(Sentence):
"""
A base class for sentences based on readable strings.
Expand Down Expand Up @@ -121,6 +145,15 @@ def __init__(self, value: numbers.Integral, idx: Optional[numbers.Integral] = No
super().__init__(idx=idx, score=score)
self.value = value
self.vocab = vocab
def __getitem__(self, key):
if isinstance(key, numbers.Integral):
if key!=0: raise IndexError()
return self.value
else:
if not isinstance(key, slice):
raise TypeError()
if key.start!=0 and key.stop!=1: raise IndexError()
return self
def sent_len(self) -> int:
return 1
def len_unpadded(self) -> int:
Expand All @@ -133,6 +166,8 @@ def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'ScalarSentence'
if trunc_len != 0:
raise ValueError("ScalarSentence cannot be truncated")
return self
def get_unpadded_sent(self):
return self # scalar sentences are always unpadded
def str_tokens(self, **kwargs) -> List[str]:
if self.vocab: return [self.vocab[self.value]]
else: return [str(self.value)]
Expand All @@ -151,6 +186,8 @@ def __init__(self, sents: Sequence[Sentence]) -> None:
if s.idx != self.idx:
raise ValueError("CompoundSentence must contain sentences of consistent idx.")
self.sents = sents
def __getitem__(self, item):
raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")
def sent_len(self) -> int:
return sum(sent.sent_len() for sent in self.sents)
def len_unpadded(self) -> int:
Expand All @@ -159,6 +196,8 @@ def create_padded_sent(self, pad_len):
raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")
def create_truncated_sent(self, trunc_len):
raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")
def get_unpadded_sent(self):
raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")


class SimpleSentence(ReadableSentence):
Expand All @@ -172,24 +211,27 @@ class SimpleSentence(ReadableSentence):
score: a score given to this sentence by a model
output_procs: output processors to be applied when calling sent_str()
pad_token: special token used for padding
unpadded_sent: reference to original, unpadded sentence if available
"""
def __init__(self,
words: Sequence[numbers.Integral],
idx: Optional[numbers.Integral] = None,
vocab: Optional[Vocab] = None,
score: Optional[numbers.Real] = None,
output_procs: Union[OutputProcessor, Sequence[OutputProcessor]] = [],
pad_token: numbers.Integral = Vocab.ES) -> None:
pad_token: numbers.Integral = Vocab.ES,
unpadded_sent: 'SimpleSentence' = None) -> None:
super().__init__(idx=idx, score=score, output_procs=output_procs)
self.pad_token = pad_token
self.words = words
self.vocab = vocab
self.unpadded_sent = unpadded_sent

def __getitem__(self, key):
ret = self.words[key]
if isinstance(ret, list): # support for slicing
return SimpleSentence(words=ret, idx=self.idx, vocab=self.vocab, score=self.score, output_procs=self.output_procs,
pad_token=self.pad_token)
pad_token=self.pad_token, unpadded_sent=self.unpadded_sent)
return self.words[key]

def sent_len(self):
Expand All @@ -209,6 +251,10 @@ def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'SimpleSentence'
return self
return self.sent_with_words(self.words[:-trunc_len])

def get_unpadded_sent(self):
if self.unpadded_sent: return self.unpadded_sent
else: return super().get_unpadded_sent()

def str_tokens(self, exclude_ss_es=True, exclude_unk=False, exclude_padded=True, **kwargs) -> List[str]:
exclude_set = set()
if exclude_ss_es:
Expand All @@ -221,12 +267,16 @@ def str_tokens(self, exclude_ss_es=True, exclude_unk=False, exclude_padded=True,
else: return [str(w) for w in ret_toks]

def sent_with_new_words(self, new_words):
unpadded_sent = self.unpadded_sent
if not unpadded_sent:
if self.sent_len()==self.len_unpadded(): unpadded_sent = self
return SimpleSentence(words=new_words,
idx=self.idx,
vocab=self.vocab,
score=self.score,
output_procs=self.output_procs,
pad_token=self.pad_token)
pad_token=self.pad_token,
unpadded_sent=unpadded_sent)

class SegmentedSentence(SimpleSentence):
def __init__(self, segment=[], **kwargs) -> None:
Expand All @@ -240,7 +290,8 @@ def sent_with_new_words(self, new_words):
score=self.score,
output_procs=self.output_procs,
pad_token=self.pad_token,
segment=self.segment)
segment=self.segment,
unpadded_sent=self.unpadded_sent)


class ArraySentence(Sentence):
Expand All @@ -257,14 +308,16 @@ class ArraySentence(Sentence):
def __init__(self,
nparr: np.ndarray,
idx: Optional[numbers.Integral] = None,
padded_len: int = 0,
score: Optional[numbers.Real] = None) -> None:
padded_len: numbers.Integral= 0,
score: Optional[numbers.Real] = None,
unpadded_sent: 'ArraySentence' = None) -> None:
super().__init__(idx=idx, score=score)
self.nparr = nparr
self.padded_len = padded_len
self.unpadded_sent = unpadded_sent

def __getitem__(self, key):
assert isinstance(key, numbers.Integral)
if not isinstance(key, numbers.Integral): raise NotImplementedError()
return self.nparr.__getitem__(key)

def sent_len(self):
Expand All @@ -279,13 +332,20 @@ def create_padded_sent(self, pad_len: numbers.Integral) -> 'ArraySentence':
return self
new_nparr = np.append(self.nparr, np.broadcast_to(np.reshape(self.nparr[:, -1], (self.nparr.shape[0], 1)),
(self.nparr.shape[0], pad_len)), axis=1)
return ArraySentence(new_nparr, idx=self.idx, score=self.score, padded_len=self.padded_len + pad_len)
return ArraySentence(new_nparr, idx=self.idx, score=self.score, padded_len=self.padded_len + pad_len,
unpadded_sent=self if self.padded_len==0 else self.unpadded_sent)

def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'ArraySentence':
if trunc_len == 0:
return self
new_nparr = np.asarray(self.nparr[:-trunc_len])
return ArraySentence(new_nparr, idx=self.idx, score=self.score, padded_len=max(0,self.padded_len - trunc_len))
return ArraySentence(new_nparr, idx=self.idx, score=self.score, padded_len=max(0,self.padded_len - trunc_len),
unpadded_sent=self if self.padded_len == 0 else self.unpadded_sent)

def get_unpadded_sent(self):
if self.padded_len==0: return self
elif self.unpadded_sent: return self.unpadded_sent
else: return super().get_unpadded_sent()

def get_array(self):
return self.nparr
Expand Down

0 comments on commit 25a3ed3

Please sign in to comment.