In [1]:
import os, h5py
para_data = h5py.File('data.h5', 'r')

In [2]:
para_data

<HDF5 file "data.h5" (mode r)>

In [3]:
from utils import deleaf
num = 100000
synts1 = list(para_data['train_synts1'][:num])

synt1 = ['<s>'] + deleaf(synts1[2]) + ['</s>'] 

In [4]:
print(synt1)

['<s>', "b'(", '(', 'FRAG', '(', 'SBAR', '(', 'IN', ')', '(', 'NP', '(', 'NP', '(', 'QP', '(', 'VBZ', ')', '(', 'IN', ')', '(', 'CD', ')', ')', '(', 'NN', ')', ')', '(', 'PP', '(', 'IN', ')', '(', 'PP', '(', 'IN', ')', ')', ')', ')', ')', '(', '.', ')', ')', ")'", '</s>']


In [1]:
import pickle
with open('synt_vocab.pkl', 'rb') as f:
    synt_vocab = pickle.load(f)

In [6]:
import torch
from utils import deleaf
from tqdm import tqdm

def prepare_dataset(para_data, tokenizer, num):

    max_sent_len = 40
    max_synt_len = 160
    
    sents1 = list(para_data['train_sents1'][:num])
    synts1 = list(para_data['train_synts1'][:num])
    sents2 = list(para_data['train_sents2'][:num])
    synts2 = list(para_data['train_synts2'][:num])

    sent1_token_ids = torch.ones((num, max_sent_len+2), dtype=torch.long) 
    sent2_token_ids = torch.ones((num, max_sent_len+2), dtype=torch.long)    		
    synt1_token_ids = torch.ones((num, max_synt_len+2), dtype=torch.long) 
    synt2_token_ids = torch.ones((num, max_synt_len+2), dtype=torch.long)
    synt1_bow = torch.ones((num, 74))
    synt2_bow = torch.ones((num, 74))
        
    bsz = 64
    
    for i in tqdm(range(0, num, bsz)):
        sent1_inputs = tokenizer(sents1[i:i+bsz], padding='max_length', truncation=True, max_length=max_sent_len+2, return_tensors="pt")
        sent2_inputs = tokenizer(sents2[i:i+bsz], padding='max_length', truncation=True, max_length=max_sent_len+2, return_tensors="pt")
        sent1_token_ids[i:i+bsz] = sent1_inputs['input_ids']
        sent2_token_ids[i:i+bsz] = sent2_inputs['input_ids']

    for i in tqdm(range(num)):
        synt1 = ['<s>'] + deleaf(synts1[i]) + ['</s>']
        synt1_token_ids[i, :len(synt1)] = torch.tensor([synt_vocab[tag] for tag in synt1])[:max_synt_len+2]
        synt2 = ['<s>'] + deleaf(synts2[i]) + ['</s>']
        synt2_token_ids[i, :len(synt2)] = torch.tensor([synt_vocab[tag] for tag in synt2])[:max_synt_len+2]
        
        for tag in synt1:
            if tag != '<s>' and tag != '</s>':
                synt1_bow[i][synt_vocab[tag]-3] += 1
        for tag in synt2:
            if tag != '<s>' and tag != '</s>':
                synt2_bow[i][synt_vocab[tag]-3] += 1

    synt1_bow /= synt1_bow.sum(1, keepdim=True)
    synt2_bow /= synt2_bow.sum(1, keepdim=True)
    
    sum = 0
    for i in range(num):
        if torch.equal(synt1_bow[i], synt2_bow[i]):
            sum += 1

    return {'sent1':sent1_token_ids, 'sent2':sent2_token_ids, 'synt1': synt1_token_ids, 'synt2': synt2_token_ids,
            'synt1bow': synt1_bow, 'synt2bow': synt2_bow}

In [7]:
from transformers import BartTokenizer
# tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir="./bart-base/")
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [8]:
sents1 = list(para_data['train_sents1'][:num])

In [34]:
dataset = prepare_dataset(para_data, tokenizer, num)

  0%|          | 0/1563 [00:00<?, ?it/s]


ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [None]:
import torch
check = torch.ones((1000000, 74))

In [12]:
def deleaf(tree):
    nonleaves = ''
    for w in str(tree).replace('\n', '').split():
        w = w.replace('(', '( ').replace(')', ' )')
        nonleaves += w + ' '

    arr = nonleaves.split()
    for n, i in enumerate(arr):
        if n + 1 < len(arr):
            tok1 = arr[n]
            tok2 = arr[n + 1]
            if not is_paren(tok1) and not is_paren(tok2):
                arr[n + 1] = ""

    nonleaves = " ".join(arr)
    return nonleaves.split()

def is_paren(tok):
    return tok == ")" or tok == "("

In [13]:
from nltk import ParentedTree

# parse syntax and convert to tensor
synt_ = '(ROOT (SQ (MD can) (NP (PRP you)) (VP (VB adjust) (NP (DT the) (NNS cameras))) (. ?)))'
synt_ = ParentedTree.fromstring(synt_)
print(synt_)
synt_ = deleaf(synt_)
print(synt_)

(ROOT
  (SQ
    (MD can)
    (NP (PRP you))
    (VP (VB adjust) (NP (DT the) (NNS cameras)))
    (. ?)))
['(', 'ROOT', '(', 'SQ', '(', 'MD', ')', '(', 'NP', '(', 'PRP', ')', ')', '(', 'VP', '(', 'VB', ')', '(', 'NP', '(', 'DT', ')', '(', 'NNS', ')', ')', ')', '(', '.', ')', ')', ')']


In [4]:
#parse syntax and get template
from nltk import ParentedTree

def tree2tmpl(tree, level, mlevel):
    if level == mlevel:
        for idx, n in enumerate(tree):
            if isinstance(n, ParentedTree):
                tree[idx] = "(" + n.label() + ")"
    else:
        for n in tree:
            tree2tmpl(n, level + 1, mlevel)


tmpl_ = '(ROOT (SQ (MD can) (NP (PRP you)) (VP (VB adjust) (NP (DT the) (NNS cameras))) (. ?)))'
tmpl_ = ParentedTree.fromstring(tmpl_)
print(tmpl_)
tree2tmpl(tmpl_, 1, 2)
print(tmpl_)

(ROOT
  (SQ
    (MD can)
    (NP (PRP you))
    (VP (VB adjust) (NP (DT the) (NNS cameras)))
    (. ?)))
(ROOT (SQ (MD) (NP) (VP) (.)))


In [9]:
def is_paren(tok):
    return tok == ")" or tok == "("

def getleaf(tree):
    nonleaves = ''
    for w in str(tree).replace('\n', '').split():
        w = w.replace('(', '( ').replace(')', ' )')
        nonleaves += w + ' '
    
    leaves = []
    arr = nonleaves.split()
    for n, i in enumerate(arr):
        if n + 1 < len(arr):
            tok1 = arr[n]
            tok2 = arr[n + 1]
            if not is_paren(tok1) and not is_paren(tok2):
                leaves.append(arr[n])

    return leaves

In [10]:
#tag Sequence
sent_  = '(ROOT (SQ (MD can) (NP (PRP you)) (VP (VB adjust) (NP (DT the) (NNS cameras))) (. ?)))'
sent_ = ParentedTree.fromstring(sent_)
print(sent_)
sent_ = getleaf(sent_)
print(sent_)

(ROOT
  (SQ
    (MD can)
    (NP (PRP you))
    (VP (VB adjust) (NP (DT the) (NNS cameras)))
    (. ?)))
['MD', 'PRP', 'VB', 'DT', 'NNS', '.']


In [16]:
print(sent_) #tag1 is the tag sequence of x1
print(tmpl_) #t2 is the template of p2
print(synt_) #target X2

['MD', 'PRP', 'VB', 'DT', 'NNS', '.']
(ROOT (SQ (MD) (NP) (VP) (.)))
['(', 'ROOT', '(', 'SQ', '(', 'MD', ')', '(', 'NP', '(', 'PRP', ')', ')', '(', 'VP', '(', 'VB', ')', '(', 'NP', '(', 'DT', ')', '(', 'NNS', ')', ')', ')', '(', '.', ')', ')', ')']


In [4]:
from transformers import BartTokenizer, BartConfig, BartModel
config = BartConfig.from_pretrained('facebook/bart-base', cache_dir="./bart-base/")
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir="./bart-base/")

In [3]:
config.vocab_size

50265

In [8]:
tokenizer.get_vocab()

{'<s>': 0,
 '<pad>': 1,
 '</s>': 2,
 '<unk>': 3,
 '.': 4,
 'Ġthe': 5,
 ',': 6,
 'Ġto': 7,
 'Ġand': 8,
 'Ġof': 9,
 'Ġa': 10,
 'Ġin': 11,
 '-': 12,
 'Ġfor': 13,
 'Ġthat': 14,
 'Ġon': 15,
 'Ġis': 16,
 'âĢ': 17,
 "'s": 18,
 'Ġwith': 19,
 'ĠThe': 20,
 'Ġwas': 21,
 'Ġ"': 22,
 'Ġat': 23,
 'Ġit': 24,
 'Ġas': 25,
 'Ġsaid': 26,
 'Ļ': 27,
 'Ġbe': 28,
 's': 29,
 'Ġby': 30,
 'Ġfrom': 31,
 'Ġare': 32,
 'Ġhave': 33,
 'Ġhas': 34,
 ':': 35,
 'Ġ(': 36,
 'Ġhe': 37,
 'ĠI': 38,
 'Ġhis': 39,
 'Ġwill': 40,
 'Ġan': 41,
 'Ġthis': 42,
 ')': 43,
 'ĠâĢ': 44,
 'Ġnot': 45,
 'Ŀ': 46,
 'Ġyou': 47,
 'ľ': 48,
 'Ġtheir': 49,
 'Ġor': 50,
 'Ġthey': 51,
 'Ġwe': 52,
 'Ġbut': 53,
 'Ġwho': 54,
 'Ġmore': 55,
 'Ġhad': 56,
 'Ġbeen': 57,
 'Ġwere': 58,
 'Ġabout': 59,
 ',"': 60,
 'Ġwhich': 61,
 'Ġup': 62,
 'Ġits': 63,
 'Ġcan': 64,
 'Ġone': 65,
 'Ġout': 66,
 'Ġalso': 67,
 'Ġ$': 68,
 'Ġher': 69,
 'Ġall': 70,
 'Ġafter': 71,
 '."': 72,
 '/': 73,
 'Ġwould': 74,
 "'t": 75,
 'Ġyear': 76,
 'Ġwhen': 77,
 'Ġfirst': 78,
 'Ġshe': 79,
 'Ġtwo': 