Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 3, 2019
1 parent 243407e commit c4a6f72
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
version="0.0.1",
author="Alexander Rush",
author_email="arush@cornell.edu",
packages=["torch_struct, torch_struct.data"],
packages=["torch_struct", "torch_struct.data"],
package_data={"torch_struct": []},
url="https://github.com/harvardnlp/pytorch_struct",
install_requires=["torch"],
Expand Down
3 changes: 3 additions & 0 deletions torch_struct/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .data import SubTokenizedField, TokenBucket

# __all__ must contain the exported names as strings; listing the objects
# themselves makes `from torch_struct.data import *` raise a TypeError.
__all__ = ["SubTokenizedField", "TokenBucket"]
71 changes: 71 additions & 0 deletions torch_struct/data/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@

import torch
import torchtext

def token_pre(tokenizer, q):
    """
    Sub-tokenize a list of words and align every word piece with its word.

    Parameters:
        tokenizer: object providing ``tokenize(str)`` and
            ``encode(str, add_special_tokens=True)``. Pieces that continue
            a word are recognised by a leading ``"##"``.
        q: list of word strings; they are joined with single spaces before
            tokenization.

    Returns:
        A pair ``(ids, align)``. ``ids`` is the result of
        ``tokenizer.encode`` on the joined sentence. ``align`` starts with
        ``0``, then has one entry per word piece (the 1-based index of the
        word for the first piece of each word, ``-1`` for every other
        piece), and ends with ``len(q) + 1``. The leading and trailing
        entries presumably line up with the special tokens added by
        ``encode`` -- confirm against the tokenizer in use.

        If the pieces cannot be matched against the words, ``"error"`` is
        printed and two all-zero lists of length ``len(q) + 2`` are returned.
    """
    st = " ".join(q)
    pieces = tokenizer.tokenize(st)

    out = [0]
    cur = 0          # number of words started so far
    expect = ""      # characters of the current word not yet consumed
    first = True     # whether the next piece is the first of its word
    aligned = True
    for w in pieces:
        if not expect:
            if cur == len(q):
                # More pieces than the words account for: use the fallback
                # below instead of indexing past the end of q.
                aligned = False
                break
            cur += 1
            expect = q[cur - 1].lower()
            first = True
        if w.startswith("##"):
            out.append(-1)
            expect = expect[len(w) - 2:]
        elif first:
            out.append(cur)
            expect = expect[len(w):]
            first = False
        else:
            # A non-"##" continuation piece (e.g. split punctuation) still
            # needs a slot so that `out` stays aligned with the token ids.
            out.append(-1)
            expect = expect[len(w):]
    out.append(cur + 1)

    if not aligned or cur != len(q):
        print("error")
        return [0] * (len(q) + 2), [0] * (len(q) + 2)
    return tokenizer.encode(st, add_special_tokens=True), out

def token_post(ls):
    """
    Collate a batch of ``(ids, align)`` pairs into padded tensors.

    Parameters:
        ls: non-empty list of pairs. ``ids`` is a list of ints; ``align``
            is a list of ints in which ``-1`` means "no position".

    Returns:
        A triple ``(tokens, alignment, lengths)``:
        tokens: LongTensor of shape ``(batch, max_len)`` holding the ids,
            right-padded with ``0``.
        alignment: long tensor of shape ``(batch, max_len, max_pos)`` where
            ``max_pos`` is the largest alignment value plus one; entry
            ``[b, i, w]`` is ``1`` when ``align[i] == w`` for example ``b``
            and ``0`` elsewhere.
        lengths: list with the unpadded length of each ``ids`` list.
    """
    lengths = [len(ids) for ids, _ in ls]
    length = max(lengths)
    padded = [ids + [0] * (length - len(ids)) for ids, _ in ls]

    n_positions = max(max(align) + 1 for _, align in ls)
    # Allocate as long directly; the values are only ever 0 or 1.
    out2 = torch.zeros(len(ls), length, n_positions, dtype=torch.long)
    for b, (_, align) in enumerate(ls):
        for i, w in enumerate(align):
            if w != -1:
                out2[b, i, w] = 1
    return torch.LongTensor(padded), out2, lengths

def SubTokenizedField(tokenizer):
    """
    Field for use with pytorch-transformer.

    Builds a ``torchtext.data.RawField`` whose preprocessing runs
    ``token_pre`` with the given ``tokenizer`` on each example and whose
    postprocessing is ``token_post``. The field is marked as a non-target
    field before it is returned.
    """
    def preprocess(words):
        return token_pre(tokenizer, words)

    field = torchtext.data.RawField(
        preprocessing=preprocess,
        postprocessing=token_post,
    )
    field.is_target = False
    return field

def TokenBucket(train, batch_size=1500, device="cuda:0"):
    """
    Build a repeating, shuffled ``torchtext.data.BucketIterator`` over
    ``train`` whose batches are sized by token count.

    Parameters:
        train: dataset passed straight to ``BucketIterator``. Each example
            is read through ``x.word[0]``, presumably the token-id list
            produced by ``SubTokenizedField`` -- confirm against the
            dataset's field names.
        batch_size: budget handed to ``BucketIterator``; it is compared
            against the value returned by ``batch_size_fn``. Defaults to
            the previously hard-coded 1500.
        device: device handed to ``BucketIterator``. Defaults to the
            previously hard-coded ``"cuda:0"``.

    Returns:
        The configured ``BucketIterator``.
    """
    def batch_size_fn(x, _, size):
        # Every example costs its token count, with a floor of 5.
        return size + max(len(x.word[0]), 5)

    return torchtext.data.BucketIterator(
        train,
        train=True,
        sort=False,
        sort_within_batch=True,
        shuffle=True,
        batch_size=batch_size,
        sort_key=lambda x: len(x.word[0]),
        repeat=True,
        batch_size_fn=batch_size_fn,
        device=device,
    )

0 comments on commit c4a6f72

Please sign in to comment.