## use it in PyTorch

In [1]:
# data_dir = "data"
data_dir = "lexicalized_data"

from dataset import syntax_token_type
syntax_vocabulary, blacklist = syntax_token_type(data_dir)

encountering ambiguous tag at lexicalized_data/sw2260#s283_500/syntax.txt, break-tied
encountering type II typo at lexicalized_data/sw2249#s67_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2789#s226_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2229#s140_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2589#s154_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2067#s50_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2154#s6_500/syntax.txt, blacklisted
encountering ambiguous tag at lexicalized_data/sw2015#s53_500/syntax.txt, break-tied
encountering type II typo at lexicalized_data/sw2079#s10_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2434#s87_500/syntax.txt, blacklisted
encountering a weird semicolon at lexicalized_data/sw3049#s167_500/syntax.txt, blacklisted
encountering type II typo at lexicalized_data/sw2229#s23

In [2]:
if data_dir == "data":
    print(syntax_vocabulary)
print(len(syntax_vocabulary))
print(blacklist)

13815
{'lexicalized_data/sw2249#s67_500', 'lexicalized_data/sw2789#s226_500', 'lexicalized_data/sw2434#s87_500', 'lexicalized_data/sw2719#s157_500', 'lexicalized_data/sw3294#s121_500', 'lexicalized_data/sw2154#s93_500', 'lexicalized_data/sw2589#s154_500', 'lexicalized_data/sw2229#s91_500', 'lexicalized_data/sw2641#s78_500', 'lexicalized_data/sw2784#s85_500', 'lexicalized_data/sw2229#s232_500', 'lexicalized_data/sw2078#s154_500', 'lexicalized_data/sw2229#s140_500', 'lexicalized_data/sw2079#s10_500', 'lexicalized_data/sw3561#s32_500', 'lexicalized_data/sw2067#s50_500', 'lexicalized_data/sw2154#s6_500', 'lexicalized_data/sw2249#s194_500', 'lexicalized_data/sw2249#s148_500', 'lexicalized_data/sw3049#s167_500', 'lexicalized_data/sw2249#s149_500'}


In [3]:
from dataset import SpeechSyntax, Phase, get_dataloader

dataloader = get_dataloader(
    root=data_dir, phase=Phase.TRAIN, syntax_vocabulary=syntax_vocabulary, blacklist=blacklist,
    batch_size=4, shuffle=True, num_workers=2
)

This is only for demo. You probably will never call ```pad_packed_sequence``` explicitly, which wastes a good deal of memory

In [12]:
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
from dataset import restore_order
import numpy as np

for speech, syntax in dataloader:
    speech, inv_perm_index_speech = speech
    syntax, inv_perm_index_syntax = syntax
    assert type(speech) == PackedSequence
    assert type(syntax) == PackedSequence  # PackedSequence can be fed into a RNN directly
    
    # okay let's pretend there's a RNN and we are through it
    
    unpacked_speech, unpacked_speech_len = pad_packed_sequence(speech, batch_first=True)
    # restore order
    unpacked_speech = restore_order(unpacked_speech, inv_perm_index_speech)
    unpacked_speech_len = np.take(unpacked_speech_len, inv_perm_index_speech)
    print("speech batch")
    print(unpacked_speech.shape)
    print(unpacked_speech_len)
    
    unpacked_syntax, unpacked_syntax_len = pad_packed_sequence(syntax, batch_first=True)
    # restore order
    unpacked_syntax = restore_order(unpacked_syntax, inv_perm_index_syntax)
    unpacked_syntax_len = np.take(unpacked_syntax_len, inv_perm_index_syntax)
    print("syntax batch")
    print(unpacked_syntax.shape)
    print(unpacked_syntax_len)
    
    print("It's only a demo, people.")
    break

speech batch
torch.Size([4, 60238])
tensor([ 5486, 60238, 21330, 15998])
syntax batch
torch.Size([4, 284])
tensor([ 28, 284, 120,  80])
It's only a demo, people.


In [13]:
import IPython.display as ipd

def display_speech(unpacked_speech, unpacked_speech_len):
    auds = []
    for s, l in zip(unpacked_speech, unpacked_speech_len):
        speech = s[:l].numpy()
        auds.append(ipd.Audio(speech, rate=8000))
    return auds

def display_syntax(unpacked_syntax, unpacked_syntax_len):
    for s, l in zip(unpacked_syntax, unpacked_syntax_len):
        syntax = s[:l].numpy()
        tokens = []
        for syn in syntax:  # turn indices into human-readable strings
            for k, v in syntax_vocabulary.items():
                if v == syn:
                    tokens.append(k)
                    break
        syntax = "http://mshang.ca/syntree/?" + "".join(tokens)
        print(syntax)

In [14]:
auds = display_speech(unpacked_speech, unpacked_speech_len)
syns = display_syntax(unpacked_syntax, unpacked_syntax_len)

http://mshang.ca/syntree/?[S[NP[PRP[He]]][INTJ[UH[um]]]]
http://mshang.ca/syntree/?[S[NP[PRP[I]]][VP[VBP[think]][SBAR[TRACE[<TRACE>]][S[NP[PRP[it]]][VP[HVS['s]][VP[VBN[been]][EDITED[NP[DT[a]][JJ[good]]]][NP[NP[DT[a]][JJ[good]][JJ[positive]][NN[direction]]][PP[IN[for]][INTJ[UH[uh]]][NP[DT[the]][NNPS[Soviets]]]]][ADVP[ADVP[RB[as]][RB[far]]][SBAR[IN[as]][S[NP[NNP[Yeltsin]]][VP[VBZ[is]][VP[JJ[concerned]][NP[TRACE[<TRACE>]]]]]]]]]]]]][.[.]]]
http://mshang.ca/syntree/?[S[CC[but]][EDITED[NP[PRP$[my]]]][NP[PRP$[my]][NN[thought]]][VP[VBD[was]][SBAR[TRACE[<TRACE>]][S[NP[PRP[it]]][VP[BES['s]][NP[DT[a]][NN[shame]]]]]]]]
http://mshang.ca/syntree/?[S[CC[and]][NP[PRP[we]]][VP[VBD[took]][NP[DT[another]][NN[friend]]][PP[IN[with]][NP[PRP[us]]]]]]


In [15]:
auds[0]

In [8]:
auds[1]

In [9]:
auds[2]

In [16]:
auds[3]