<a href="https://colab.research.google.com/github/elenasofia98/PracticalNLP-2023-2024/blob/main/HoL03_2_ParaphraseGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Unlike recognizers, paraphrase or textual entailment generators are given a single language expression (or template) as input, and they are required to produce as many output language expressions
(or templates) as possible, such that the output expressions are paraphrases or they constitute, along
with the input, correct textual entailment pairs. Most generators assume that the input is a single
sentence (or sentence template), and we adopt this assumption in the remainder of this section.

https://web.archive.org/web/20171209091513/http://www.jair.org/media/2985/live-2985-5001-jair.pdf

Many generation methods borrow ideas from statistical machine translation (SMT).

SMT methods rely on very large bilingual or multilingual parallel corpora, for example the proceedings of the European parliament, without constructing meaning representations and often, at least until recently, without even constructing syntactic representations.

Let us assume that we wish to translate a sentence F, whose words are $f1$, $f2$,..., $f|F|$ in that order, from a foreign language to our native language.
Let us also denote by N any candidate translation, whose words are $a1$,$a2$,...,$a|N|$.

The best translation, denoted $N^*$
, is the N with the maximum probability of being a translation of F, i.e:
$N^*= argmax_N P(N|F) = argmax_N \frac{P(N)P(F|N)}{P(F)}= argmax_N
P(N)P(F|N)$
Since F is fixed, the denominator $P(F)$ above is constant and can be ignored when searching for $N^*$. $P(N)$ is called the language model and $P(F|N)$ the translation model.

For modeling purposes, it is common to assume that **F was in fact originally written in our native language** and it was transmitted to us via a **noisy channel**, which introduced various deformations.
The possible deformations may include, for example:
- replacing a native word with one or more
foreign ones
- removing or inserting words
- moving words to the left or right etc.

The foreign sentence $F$ can thus be seen as the
result of applying a sequence of transformations $D = < d1,d2,...,d|D|>$ to $N$, and it is common to search for the $N^∗$
that maximizes:
$N^*= argmax_N P(N|F) = argmax_N P(N)max_D P(F,D|N)$
; this search is called decoding.

**Deformations**

Assuming for simplicity that the individual deformations $di(·)$ of D are mutually independent,
$P(F,D|N)$ can be computed as the product of the probabilities of D’s individual deformations. In practice, however, parallel corpora do not indicate word alignment. Hence, it is common to find the most probable word alignment of the corpus given initial
estimates of individual deformation probabilities then re-estimate the deformation probabilities
given the resulting alignment, and iterate.
--> can we do that?
P(F,D|N) estimates the probability of obtaining F from N via D; we are
interested in Ns with high probabilities of leading to F.

**The translation model**

We also want, however, N to be grammatical, and we use the language model P(N) to check for grammaticality. P(N) is the probability
of encountering N in our native language; it is estimated from a large monolingual corpus of our
language, typically assuming that the probability of encountering word ai depends only on the preceding n−1 words.

For n = 3, P(N) becomes:
$P(N) = P(a1)·P(a2|a1)·P(a3|a1,a2)·P(a4|a2,a3)···P(a_{|N|}|a_{|N|−2},a_{|N|−1})$

In [1]:
!pip install stanza
import stanza
stanza.download('en')

nlp = stanza.Pipeline('en', processors='tokenize,lemma,pos,depparse')

Collecting stanza
  Downloading stanza-1.6.1-py3-none-any.whl (881 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m881.2/881.2 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting emoji (from stanza)
  Downloading emoji-2.8.0-py2.py3-none-any.whl (358 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m358.9/358.9 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: emoji, stanza
Successfully installed emoji-2.8.0 stanza-1.6.1


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.6.0.json:   0%|   …

INFO:stanza:Downloading default packages for language: en (English) ...


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.6.0/models/default.zip:   0%|          | 0…

INFO:stanza:Finished downloading models and saved to /root/stanza_resources.
INFO:stanza:Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.6.0.json:   0%|   …

INFO:stanza:Loading these models for language: en (English):
| Processor | Package           |
---------------------------------
| tokenize  | combined          |
| pos       | combined_charlm   |
| lemma     | combined_nocharlm |
| depparse  | combined_charlm   |

INFO:stanza:Using device: cpu
INFO:stanza:Loading: tokenize
INFO:stanza:Loading: pos
INFO:stanza:Loading: lemma
INFO:stanza:Loading: depparse
INFO:stanza:Done loading processors!


In [2]:
!pip install datasets
from datasets import load_dataset
import os
# https://huggingface.co/docs/datasets/index

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [7]:
import pandas as pd
from ast import literal_eval

for split in ['train', 'valid']:
    df_train = pd.read_csv(f'{split}.csv')
    for column in df_train.columns:
        if column in ['sentence1', 'sentence2', 'label', 'idx']:
            continue

        vs = []
        for v in df_train[column].values:
            v = v.split()
            if v[0] == '[':
                v = v[0]+', '.join(v[1:])
            else:
                v = ', '.join(v)
            v = literal_eval(v)
            vs.append(v)

        df_train[column] = vs
    df_train.to_csv(f'{split}_pd.csv')

def parse(example, nlp):
    parsed_example = {}
    for i in [1,2]:
        parsed_sent = nlp(example[f'sentence{i}'])
        for feature in [k+str(i) for k in ['words', 'pos','deprel', 'heads', 'headsidx']]:
            parsed_example[feature] = []

        for sent in parsed_sent.sentences:
            for word in sent.words:
                parsed_example[f'words{i}'].append(word.text)

                parsed_example[f'pos{i}'].append(word.upos)

                parsed_example[f'deprel{i}'].append(word.deprel)

                head_idx = word.head-1
                parsed_example[f"headsidx{i}"].append(head_idx)

                head = sent.words[head_idx].text if head_idx > 0 else "root"
                parsed_example[f"heads{i}"].append(head)

    return parsed_example

In [8]:
def literal_parse(example):
    literal_ex ={}
    for column in example.keys():
        if column in ['sentence1', 'sentence2', 'label', 'idx']:
            literal_ex[column] = example[column]
        elif column == 'Unnamed: 0':
            continue
        else:
            literal_ex[column] = literal_eval(example[column])
    return literal_ex

import os
if not os.path.exists('train_pd.csv') or not os.path.exists('valid_pd.csv'):
    dataset = load_dataset("glue", "mrpc",  split={"train":'train[:20%]','validation':'validation[:20%]'})
    for split in ['train', 'validation']:
        dataset[split] = dataset[split].map(lambda x: parse(x, nlp=nlp), batched=False, num_proc=128)
else:
    dataset = load_dataset("csv",  data_files={"train":'train_pd.csv','validation':'train_pd.csv'})
    for split in ['train', 'validation']:
        dataset[split] = dataset[split].map(lambda x: literal_parse(x))
display(dataset)
display(dataset['train'][0])

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/734 [00:00<?, ? examples/s]

Map:   0%|          | 0/734 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'sentence1', 'sentence2', 'label', 'idx', 'words1', 'pos1', 'deprel1', 'heads1', 'headsidx1', 'words2', 'pos2', 'deprel2', 'heads2', 'headsidx2'],
        num_rows: 734
    })
    validation: Dataset({
        features: ['Unnamed: 0', 'sentence1', 'sentence2', 'label', 'idx', 'words1', 'pos1', 'deprel1', 'heads1', 'headsidx1', 'words2', 'pos2', 'deprel2', 'heads2', 'headsidx2'],
        num_rows: 734
    })
})

{'Unnamed: 0': 0,
 'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .',
 'label': 1,
 'idx': 0,
 'words1': ['Amrozi',
  'accused',
  'his',
  'brother',
  ',',
  'whom',
  'he',
  'called',
  '"',
  'the',
  'witness',
  '"',
  ',',
  'of',
  'deliberately',
  'distorting',
  'his',
  'evidence',
  '.'],
 'pos1': ['PROPN',
  'VERB',
  'PRON',
  'NOUN',
  'PUNCT',
  'PRON',
  'PRON',
  'VERB',
  'PUNCT',
  'DET',
  'NOUN',
  'PUNCT',
  'PUNCT',
  'SCONJ',
  'ADV',
  'VERB',
  'PRON',
  'NOUN',
  'PUNCT'],
 'deprel1': ['nsubj',
  'root',
  'nmod:poss',
  'obj',
  'punct',
  'obj',
  'nsubj',
  'acl:relcl',
  'punct',
  'det',
  'xcomp',
  'punct',
  'punct',
  'mark',
  'advmod',
  'advcl',
  'nmod:poss',
  'obj',
  'punct'],
 'heads1': ['accused',
  'root',
  'brother',
  'accused',
  'brother'

In [9]:
training_lm = []
for example in dataset['train']:
    for sentence_i in [1,2]:
        poss = []
        for i in range(len(example[f'words{sentence_i}'])):
            poss.append((example[f'words{sentence_i}'][i], example[f'pos{sentence_i}'][i]))
        training_lm.append(poss)

In [10]:
# Import HMM module - https://www.nltk.org/api/nltk.tag.hmm.html
from nltk.tag import hmm

# Setup a trainer with default(None) values
# And train with the data
trainer = hmm.HiddenMarkovModelTrainer()
tagger = trainer.train_supervised(training_lm)

# Prints the basic data about the tagger
print(tagger)

print(tagger.tag("Chicago is the birthplace of Ginny".split()))
print(tagger.log_probability("Chicago is the birthplace of Ginny".split()))

<HiddenMarkovModelTagger 17 states and 5774 output symbols>
[('Chicago', 'PROPN'), ('is', 'AUX'), ('the', 'DET'), ('birthplace', 'PROPN'), ('of', 'PROPN'), ('Ginny', 'PROPN')]
-1.2e+301


In [11]:
import numpy as np

def logsumexp2(arr):
    max_ = arr.max()
    return np.log2(np.sum(2 ** (arr - max_))) + max_

class LM:
    def __init__(self, tagger):
        self.tagger = tagger

    def prob(self, example, sentence_i):
        return 2 ** (self.logprob(example, sentence_i))

    def logprob(self, example, sentence_i):
        T = len(example[f'words{sentence_i}'])
        alpha = self.tagger._forward_probability(example[f'words{sentence_i}'])
        p = logsumexp2(alpha[T - 1])
        return p

lm = LM(tagger)
lm.logprob(example=dataset['train'][0], sentence_i=1)

-1.2000000000000003e+301

In [15]:
s1 = "I feel great"
s2 = "Great I feel"

example1={}
example1['words1'] = s1.split()

example2={}
example2['words1'] = s2.split()



lm.logprob(example1,1), lm.logprob(example2,1)

(-2e+300, -1e+300)

In [37]:
! pip install zss
from zss import simple_distance, Node

Collecting zss
  Downloading zss-1.2.0.tar.gz (9.8 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: zss
  Building wheel for zss (setup.py) ... [?25l[?25hdone
  Created wheel for zss: filename=zss-1.2.0-py3-none-any.whl size=6725 sha256=ab96de3ba840265f2322b34339fafd4cace8dbc2bbf82c6db83146c25044c941
  Stored in directory: /root/.cache/pip/wheels/f6/61/2a/cf33ab7301cc318a13418d9a805c1832be561b46e7d9337625
Successfully built zss
Installing collected packages: zss
Successfully installed zss-1.2.0


In [38]:
def get_dtree(example, sentence_index):
    dtree = {}
    for i in range(len(example[f'words{sentence_index}'])):

        head = example[f'heads{sentence_index}'][i]
        head_idx = example[f'headsidx{sentence_index}'][i]
        if f"{head}_{head_idx}" not in dtree:
            dtree[f"{head}_{head_idx}"] = []

        dtree[f"{head}_{head_idx}"].append(f"{example[f'words{sentence_index}'][i]}_{i}")
    return dtree

def construct_tree(root, dtree):
    tree = Node(root)

    # is leaf
    if root not in dtree:
        return tree

    # has children
    children = dtree[root]
    for child in children:
        subtree = construct_tree(root=child, dtree=dtree) # recursevely, build the subtree rooted in each children
        tree = tree.addkid(subtree)

    return tree

example = dataset['train'][22]
root = 'root_-1'
print(example['words1'])
print(example['heads1'])
print(example['headsidx1'])
dtree = get_dtree(example, sentence_index=2)
display(dtree)
tree = construct_tree(root, dtree)
print("\nRoot word")
display(Node.get_children(tree))
print("\nChildren of root word")
display(Node.get_children(Node.get_children(tree)[0]))

['A', 'BMI', 'of', '25', 'or', 'above', 'is', 'considered', 'overweight', ';', '30', 'or', 'above', 'is', 'considered', 'obese', '.']
['BMI', 'considered', '25', 'BMI', 'above', '25', 'considered', 'root', 'considered', 'considered', 'considered', 'above', '30', 'considered', 'considered', 'considered', 'considered']
[1, 7, 3, 1, 5, 3, 7, -1, 7, 14, 14, 12, 10, 14, 7, 14, 7]


{'BMI_1': ['A_0', '18.5_3'],
 'considered_7': ['BMI_1', 'is_6', 'normal_8', ',_9', 'considered_13', '._23'],
 '18.5_3': ['between_2', '24.9_5'],
 '24.9_5': ['and_4'],
 'root_-1': ['considered_7'],
 '25_11': ['over_10'],
 'considered_13': ['25_11', 'is_12', 'overweight_14', 'defined_20'],
 'defined_20': ['and_15', '30_16', 'is_19', 'obese_22'],
 'greater_18': ['or_17'],
 '30_16': ['greater_18'],
 'obese_22': ['as_21']}


Root word


[<zss.simple_tree.Node object at 0x7965b63c8e80 considered_7>]


Children of root word


[<zss.simple_tree.Node object at 0x7965b63cab60 BMI_1>,
 <zss.simple_tree.Node object at 0x7965b63cbb20 is_6>,
 <zss.simple_tree.Node object at 0x7965b63cb220 normal_8>,
 <zss.simple_tree.Node object at 0x7965b63c8a30 ,_9>,
 <zss.simple_tree.Node object at 0x7965b63c86a0 considered_13>,
 <zss.simple_tree.Node object at 0x7965b63caa40 ._23>]

In [45]:
root = 'root_-1'

dtree1 = get_dtree(dataset['train'][0], sentence_index=1)
dtree2 = get_dtree(dataset['train'][0], sentence_index=2)

display(dtree1)
display(dtree2)

tree1 = construct_tree(root, dtree=dtree1)
tree2 = construct_tree(root, dtree=dtree2)

d, operations = simple_distance(tree1, tree2, return_operations=True)

{'accused_1': ['Amrozi_0', 'brother_3', 'distorting_15', '._18'],
 'root_-1': ['accused_1'],
 'brother_3': ['his_2', ',_4', 'called_7'],
 'called_7': ['whom_5', 'he_6', 'witness_10'],
 'witness_10': ['"_8', 'the_9', '"_11'],
 'distorting_15': [',_12', 'of_13', 'deliberately_14', 'evidence_17'],
 'evidence_17': ['his_16']}

{'accused_11': ['Referring_0',
  'Amrozi_10',
  'brother_13',
  'distorting_16',
  '._19'],
 'him_2': ['to_1'],
 'root_0': ['him_2', 'witness_7', ',_9'],
 'witness_7': ['as_3', 'only_4', '"_5', 'the_6', '"_8'],
 'root_-1': ['accused_11'],
 'brother_13': ['his_12'],
 'distorting_16': ['of_14', 'deliberately_15', 'evidence_18'],
 'evidence_18': ['his_17']}

In [40]:
remove = 0
insert = 1
update = 2
match = 3

def parse(op):
    if hasattr(op.arg1, "label") or hasattr(op.arg2, "label"):
        if op.type == remove:
            return ('remove', op.arg1.label.split('_')[0], '__X__')
        elif op.type == insert:
            return ('insert', '__X__', op.arg2.label.split('_')[0])
        elif op.type == update:
            return ('update', op.arg1.label.split('_')[0], op.arg2.label.split('_')[0])
        else:
            return ('match', op.arg1.label.split('_')[0], op.arg2.label.split('_')[0])
    else:
        if op.type == remove:
            return ('remove')
        elif op.type == insert:
            return ('insert')
        elif op.type == update:
            return ('update')
        else:
            return ('match')

def eq(op, other):
    if other is None: return False

    return op.type == other.type and op.arg1 == other.arg1 and \
        op.arg2 == other.arg2


In [69]:
from tqdm import tqdm

transformations = {}
total = 0

root = 'root_-1'

for example in tqdm(dataset['train']):
    if example['label'] == 1:
        dtree1 = get_dtree(example, sentence_index=1)
        dtree2 = get_dtree(example, sentence_index=2)

        tree1 = construct_tree(root, dtree=dtree1)
        tree2 = construct_tree(root, dtree=dtree2)

        d, operations = simple_distance(tree1, tree2, return_operations=True)
        for op in operations:
            op = parse(op)
            if len(op)>1:
                label, arg1, arg2 = op
                if label not in transformations:
                    transformations[label] = {}
                if arg1 not in transformations[label]:
                    transformations[label][arg1] = {}
                if arg2 not in transformations[label][arg1]:
                    transformations[label][arg1][arg2] = 0

                transformations[label][arg1][arg2]+=1
                total += 1

100%|██████████| 734/734 [00:31<00:00, 23.31it/s]


In [93]:
dataset['train'][0]['headsidx1']

[1, -1, 3, 1, 3, 7, 7, 3, 10, 10, 7, 10, 15, 15, 15, 1, 17, 15, 1]

In [None]:
transformations['remove']

In [96]:
# Lets try with non-lexicalized trees
def get_dtree_not_lexicalized(example, sentence_index):
    dtree = {}
    for i in range(len(example[f'words{sentence_index}'])):
        head = example[f'heads{sentence_index}'][i]

        head_idx = example[f'headsidx{sentence_index}'][i]
        if head_idx !=-1:
            pos_tag_head = example[f'pos{sentence_index}'][head_idx]
        else:
            pos_tag_head = head

        if f"{pos_tag_head}_{head_idx}" not in dtree:
            dtree[f"{pos_tag_head}_{head_idx}"] = []

        pos_tag_child = example[f'pos{sentence_index}'][i]

        dtree[f"{pos_tag_head}_{head_idx}"].append(f"{pos_tag_child}_{i}")
    return dtree


In [97]:
get_dtree_not_lexicalized(example=dataset['train'][0], sentence_index=1),get_dtree_not_lexicalized(example=dataset['train'][0], sentence_index=2)

({'VERB_1': ['PROPN_0', 'NOUN_3', 'VERB_15', 'PUNCT_18'],
  'root_-1': ['VERB_1'],
  'NOUN_3': ['PRON_2', 'PUNCT_4', 'VERB_7'],
  'VERB_7': ['PRON_5', 'PRON_6', 'NOUN_10'],
  'NOUN_10': ['PUNCT_8', 'DET_9', 'PUNCT_11'],
  'VERB_15': ['PUNCT_12', 'SCONJ_13', 'ADV_14', 'NOUN_17'],
  'NOUN_17': ['PRON_16']},
 {'VERB_11': ['VERB_0', 'PROPN_10', 'NOUN_13', 'VERB_16', 'PUNCT_19'],
  'PRON_2': ['ADP_1'],
  'VERB_0': ['PRON_2', 'NOUN_7', 'PUNCT_9'],
  'NOUN_7': ['ADP_3', 'ADV_4', 'PUNCT_5', 'DET_6', 'PUNCT_8'],
  'root_-1': ['VERB_11'],
  'NOUN_13': ['PRON_12'],
  'VERB_16': ['SCONJ_14', 'ADV_15', 'NOUN_18'],
  'NOUN_18': ['PRON_17']})

In [100]:
from tqdm import tqdm

transformations = {}
total = 0

root = 'root_-1'

for example in tqdm(dataset['train']):
    if example['label'] == 1:
        dtree1 = get_dtree_not_lexicalized(example, sentence_index=1)
        dtree2 = get_dtree_not_lexicalized(example, sentence_index=2)

        tree1 = construct_tree(root, dtree=dtree1)
        tree2 = construct_tree(root, dtree=dtree2)

        d, operations = simple_distance(tree1, tree2, return_operations=True)
        for op in operations:
            op = parse(op)
            if len(op)>1:
                label, arg1, arg2 = op
                if label not in transformations:
                    transformations[label] = {}
                if arg1 not in transformations[label]:
                    transformations[label][arg1] = {}
                if arg2 not in transformations[label][arg1]:
                    transformations[label][arg1][arg2] = 0

                transformations[label][arg1][arg2]+=1
                total += 1

100%|██████████| 734/734 [00:26<00:00, 27.28it/s]


In [109]:
transformations['remove'],total

({'DET': {'__X__': 0.01226588321704003},
  'VERB': {'__X__': 0.016232096951891296},
  'NOUN': {'__X__': 0.031876606683804626},
  'PUNCT': {'__X__': 0.016819684171869263},
  'PRON': {'__X__': 0.0044069041498347415},
  'ADV': {'__X__': 0.005141388174807198},
  'PART': {'__X__': 0.002864487697392582},
  'PROPN': {'__X__': 0.017480719794344474},
  'ADP': {'__X__': 0.011972089607051047},
  'SCONJ': {'__X__': 0.002130003672420125},
  'NUM': {'__X__': 0.004700697759823724},
  'CCONJ': {'__X__': 0.002791039294895336},
  'ADJ': {'__X__': 0.011604847594564819},
  'AUX': {'__X__': 0.005728975394785164},
  'SYM': {'__X__': 0.0010282776349614395}},
 13615)

In [102]:
for label in transformations:
    for arg1 in transformations[label]:
        for arg2 in transformations[label][arg1]:
            transformations[label][arg1][arg2] = transformations[label][arg1][arg2]/total

In [106]:
class PD:
    def __init__(self, transformations):
        self.transformations = transformations

    def op_prob(self, op):
        op = parse(op)
        label, arg1, arg2 = op
        return self.transformations[label][arg1][arg2]

    def op_logprob(self, op):
        p = self.op_prob(op)
        return np.log2(p)

    def deformations_logprob(self, operations):
        logp = 0
        for op in operations:
            logp += self.op_logprob(op)
        return logp

    def deformations_prob(self, operations):
        return 2 ** (self.deformations_prob(operations))

    def logprob(self, example, F=1, N=2):
        root = 'root_-1'

        dtreeF = get_dtree_not_lexicalized(dataset['train'][0], sentence_index=F)
        dtreeN = get_dtree_not_lexicalized(dataset['train'][0], sentence_index=N)

        treeF = construct_tree(root, dtree=dtreeF)
        treeN = construct_tree(root, dtree=dtreeN)

        d, operations = simple_distance(treeF, treeN, return_operations=True)
        return self.deformations_logprob(operations)

    def prob(self, example, F=1, N=2):
        return 2 ** (self.logprob(example, F, N))

pd = PD(transformations)

In [116]:
def generation_prob(example, F, N):
    log_p = {}

    log_p['N'] = lm.logprob(example=example, sentence_i=N)
    log_p['F,D|N'] = pd.logprob(example=example, F=F, N=N)
    print(log_p['N'], log_p['F,D|N'])

    return log_p['N'] + log_p['F,D|N']

In [117]:
generation_prob(dataset['train'][0], F=1, N=2), generation_prob(dataset['train'][1], F=1, N=2)

-1.3000000000000004e+301 -158.92060106181742
-1.1000000000000002e+301 -158.92060106181742


(-1.3000000000000004e+301, -1.1000000000000002e+301)

In [119]:
dataset['train'][0]['sentence1'],dataset['train'][0]['sentence2']

('Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .')

In [120]:
dataset['train'][1]['sentence1'],dataset['train'][1]['sentence2']

("Yucaipa owned Dominick 's before selling the chain to Safeway in 1998 for $ 2.5 billion .",
 "Yucaipa bought Dominick 's in 1995 for $ 693 million and sold it to Safeway for $ 1.8 billion in 1998 .")