In [1]:
import numpy as np
import pandas as pd
import torch
import json
import itertools
import nltk
# nltk.download('punkt')
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('roberta-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('roberta-large-mnli')

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [4]:
model = model.to(device)

In [5]:
data = json.load(open('data/combined.json', 'r'))
dev = data['dev']
test = data['test']
fulldata = dev + test

In [6]:
def contrast(reference):

    pairs = []
    label_lists =[]
    stats =[]

    if reference == True:

        for data in fulldata:

            ref_a_sum = [nltk.sent_tokenize(i) for i in data['entity_a_summary']]
            ref_b_sum = [nltk.sent_tokenize(i) for i in data['entity_b_summary']]
            ref_comm_sum = [nltk.sent_tokenize(i) for i in data['common_summary']]
        
            pair = []
            label = []
            label_rev = []
            stat = []
            stat_rev = []

            for i in range(len(ref_a_sum)):

                cont = list(itertools.product(ref_a_sum[i], ref_b_sum[i]))
                comm_a = list(itertools.product(ref_a_sum[i], ref_comm_sum[i]))
                comm_b = list(itertools.product(ref_b_sum[i], ref_comm_sum[i]))

                pair.append(cont+comm_a+comm_b)

                col1 = [j[0] for i in [cont,comm_a,comm_b] for j in i]
                col2 = [j[1] for i in [cont,comm_a,comm_b] for j in i]

                features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
                features.to(device)

                features_rev = tokenizer(col2,col1,  padding=True, truncation=True, return_tensors="pt")
                features_rev.to(device)

                model.eval()
                with torch.no_grad():
                    scores = model(**features).logits
                    label_mapping = ['contradiction', 'neutral','entailment']
                    labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

                    label.append(labels)
                    print(Counter(labels))
                    stat.append(Counter(labels))

                    scores = model(**features_rev).logits
                    label_mapping = ['contradiction', 'neutral','entailment']
                    labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

                    label_rev.append(labels)
                    print(Counter(labels))
                    stat_rev.append(Counter(labels))
        

            label_lists.append(label) 
            label_lists.append(label_rev)
            stats.append(stat) 
            stats.append(stat_rev) 
            pairs.append(pair)

    else:

        for data in fulldata:

            gen_a_sum = nltk.sent_tokenize(data['gen_cont_a'])
            gen_b_sum = nltk.sent_tokenize(data['gen_cont_b'])
            gen_comm_sum = nltk.sent_tokenize(data['gen_comm_a'])

            cont = list(itertools.product(gen_a_sum, gen_b_sum))
            comm_a = list(itertools.product(gen_a_sum, gen_comm_sum))
            comm_b = list(itertools.product(gen_b_sum, gen_comm_sum))

            pairs.append(cont+comm_a+comm_b)

            col1 = [j[0] for i in [cont,comm_a,comm_b] for j in i]
            col2 = [j[1] for i in [cont,comm_a,comm_b] for j in i]

            features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
            features.to(device)

            features_rev = tokenizer(col2,col1,  padding=True, truncation=True, return_tensors="pt")
            features_rev.to(device)
        

            model.eval()
            with torch.no_grad():
                scores = model(**features).logits
                label_mapping = ['contradiction', 'neutral','entailment']
                labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

                label_lists.append(labels)
                print(Counter(labels))
                stats.append(Counter(labels))


                scores = model(**features_rev).logits
                label_mapping = ['contradiction', 'neutral','entailment']
                labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

                label_lists.append(labels)
                print(Counter(labels))
                stats.append(Counter(labels))
            
    return pairs,label_lists, stats    

In [43]:
def factual_opinion(reference):

    pairs = []
    label_lists = []
    stats = []

    # factual opinons

    if reference == True:

        for data in fulldata:

            ref_a_sum = [nltk.sent_tokenize(i) for i in data['entity_a_summary']]
            ref_b_sum = [nltk.sent_tokenize(i) for i in data['entity_b_summary']]
            ref_comm_sum = [nltk.sent_tokenize(i) for i in data['common_summary']]

            source_a = [nltk.sent_tokenize(i) for i in data['entity_a_reviews']]
            source_b = [nltk.sent_tokenize(i) for i in data['entity_b_reviews']]
            
            
            pair = []
            label = []
            stat = []

            for i in range(len(ref_a_sum)):

                rev_a = [k for j in source_a for k in itertools.product(ref_a_sum[i],j)]
                rev_b = [k for j in source_b for k in itertools.product(ref_b_sum[i],j)]
                rev_com_a = [k for j in source_a for k in itertools.product(ref_comm_sum[i],j)]
                rev_com_b = [k for j in source_b for k in itertools.product(ref_comm_sum[i],j)]
                
                pair.append(rev_a,rev_b,rev_com_a,rev_com_b)


                col1 = [j[0] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]
                col2 = [j[1] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]


                features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
                features.to(device)

                model.eval()
                with torch.no_grad():
                    scores = model(**features).logits
                    label_mapping = ['contradiction', 'neutral','entailment']
                    labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

                    label.append(labels)
                    print(Counter(labels))
                    stat.append(Counter(labels))
                    
            label_lists.append(label) 
            stats.append(stat) 
            pairs.append(pair)
                    
            

    else:

        for data in fulldata:

            source_a = [nltk.sent_tokenize(i) for i in data['entity_a_reviews']]
            source_b = [nltk.sent_tokenize(i) for i in data['entity_b_reviews']]

            gen_a_sum = nltk.sent_tokenize(data['gen_cont_a'])
            gen_b_sum = nltk.sent_tokenize(data['gen_cont_b'])
            gen_comm_sum = nltk.sent_tokenize(data['gen_comm_a'])

            rev_gen_a = [j for i in source_a for j in itertools.product(i,gen_a_sum)]
            rev_gen_b = [j for i in source_b for j in itertools.product(i,gen_b_sum)]
            rev_com_a = [j for i in source_a for j in itertools.product(i,gen_comm_sum)]
            rev_com_b = [j for i in source_b for j in itertools.product(i,gen_comm_sum)]
            
            pairs.append(rev_gen_a,rev_gen_b,rev_com_a,rev_com_b)

            col1 = [j[0] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]
            col2 = [j[1] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]

            features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
            features.to(device)

            model.eval()
            with torch.no_grad():
                scores = model(**features).logits
                label_mapping = ['contradiction', 'neutral','entailment']
                labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

                label_lists.append(labels)
                print(Counter(labels))
                stats.append(Counter(labels))
                
                
    return pairs, label_lists, stats
                




In [27]:
pairs, label_lists, stats = contrast(False)

Counter({'neutral': 54, 'contradiction': 4, 'entailment': 1})
Counter({'neutral': 55, 'contradiction': 4})
Counter({'neutral': 62, 'contradiction': 9})
Counter({'neutral': 60, 'contradiction': 11})
Counter({'neutral': 54, 'contradiction': 3, 'entailment': 3})
Counter({'neutral': 55, 'contradiction': 5})
Counter({'neutral': 68, 'contradiction': 7, 'entailment': 2})
Counter({'neutral': 74, 'contradiction': 3})
Counter({'neutral': 57, 'contradiction': 2})
Counter({'neutral': 53, 'contradiction': 5, 'entailment': 1})
Counter({'neutral': 30, 'contradiction': 6, 'entailment': 2})
Counter({'neutral': 27, 'contradiction': 7, 'entailment': 4})
Counter({'neutral': 49, 'contradiction': 8, 'entailment': 2})
Counter({'neutral': 52, 'contradiction': 7})
Counter({'neutral': 45, 'contradiction': 2})
Counter({'neutral': 44, 'contradiction': 3})
Counter({'neutral': 57, 'contradiction': 2, 'entailment': 1})
Counter({'neutral': 54, 'contradiction': 5, 'entailment': 1})
Counter({'neutral': 66, 'contradicti

In [40]:
label_for = label_lists[1::2]
label_rev = label_lists[::2]
stats_for = stats[1::2]
stats_rev = stats[::2]

mean_stats_both = [{k:round(sum([i]+[j], Counter())[k]/2,3) for k in sum([i]+[j], Counter())} for i,j in zip(stats_for, stats_rev)]

ratio_stats_both = [{k:round(sum([i]+[j], Counter())[k]/sum(sum([i]+[j], Counter()).values()),3) for k in sum([i]+[j], Counter())} for i,j in zip(stats_for, stats_rev)]

In [41]:
gen_df = pd.DataFrame(
    {'Sentence_Pairs': pairs,
     'Labels': label_for,
     'Labels_Reverse': label_rev,
     'Aggregate_Labels': stats_for,
     'Aggregate_Labels_Rev': stats_rev,
     'Mean_Label_Both': mean_stats_both,
     'Ratio_Label_Both': ratio_stats_both
    })

In [42]:
gen_df.head()

Unnamed: 0,Sentence_Pairs,Labels,Labels_Reverse,Aggregate_Labels,Aggregate_Labels_Rev,Mean_Label_Both,Ratio_Label_Both
0,[(Le Place D'Armes is an elegant historic buil...,"[neutral, neutral, neutral, neutral, neutral, ...","[contradiction, neutral, neutral, neutral, neu...","{'neutral': 55, 'contradiction': 4}","{'contradiction': 4, 'neutral': 54, 'entailmen...","{'neutral': 54.5, 'contradiction': 4.0, 'entai...","{'neutral': 0.924, 'contradiction': 0.068, 'en..."
1,[(The Residence Inn Toronto is in an excellent...,"[neutral, neutral, neutral, neutral, contradic...","[neutral, neutral, neutral, neutral, neutral, ...","{'neutral': 60, 'contradiction': 11}","{'neutral': 62, 'contradiction': 9}","{'neutral': 61.0, 'contradiction': 10.0}","{'neutral': 0.859, 'contradiction': 0.141}"
2,[(This is a great hotel that deserves its repu...,"[neutral, neutral, contradiction, neutral, neu...","[neutral, neutral, neutral, neutral, neutral, ...","{'neutral': 55, 'contradiction': 5}","{'neutral': 54, 'contradiction': 3, 'entailmen...","{'neutral': 54.5, 'contradiction': 4.0, 'entai...","{'neutral': 0.908, 'contradiction': 0.067, 'en..."
3,[(The Sutton Place hotel is great for business...,"[contradiction, neutral, neutral, neutral, neu...","[contradiction, neutral, neutral, neutral, neu...","{'contradiction': 3, 'neutral': 74}","{'contradiction': 7, 'neutral': 68, 'entailmen...","{'contradiction': 5.0, 'neutral': 71.0, 'entai...","{'contradiction': 0.065, 'neutral': 0.922, 'en..."
4,[(This hotel is hands down one of the nicest a...,"[neutral, contradiction, neutral, neutral, neu...","[neutral, neutral, neutral, neutral, neutral, ...","{'neutral': 53, 'contradiction': 5, 'entailmen...","{'neutral': 57, 'contradiction': 2}","{'neutral': 55.0, 'contradiction': 3.5, 'entai...","{'neutral': 0.932, 'contradiction': 0.059, 'en..."


In [8]:
pairs, label_lists, stats = contrast(True)

Counter({'neutral': 53, 'contradiction': 2})
Counter({'neutral': 47, 'contradiction': 7, 'entailment': 1})
Counter({'neutral': 69, 'contradiction': 7, 'entailment': 1})
Counter({'neutral': 74, 'contradiction': 2, 'entailment': 1})
Counter({'neutral': 104, 'contradiction': 6, 'entailment': 1})
Counter({'neutral': 98, 'contradiction': 9, 'entailment': 4})
Counter({'neutral': 42, 'contradiction': 2, 'entailment': 1})
Counter({'neutral': 40, 'contradiction': 5})
Counter({'neutral': 43, 'contradiction': 9})
Counter({'neutral': 41, 'contradiction': 11})
Counter({'neutral': 35, 'contradiction': 3})
Counter({'neutral': 34, 'contradiction': 4})
Counter({'neutral': 23, 'contradiction': 11})
Counter({'neutral': 29, 'contradiction': 5})
Counter({'neutral': 47, 'contradiction': 8})
Counter({'neutral': 49, 'contradiction': 6})
Counter({'neutral': 56, 'contradiction': 7})
Counter({'neutral': 52, 'contradiction': 10, 'entailment': 1})
Counter({'neutral': 75, 'contradiction': 8, 'entailment': 1})
Count

In [24]:
label_for = label_lists[1::2]
label_rev = label_lists[::2]
stats_for = stats[1::2]
stats_rev = stats[::2]

mean_stats_for = [{i:round(sum(j, Counter())[i]/3,3) for i in sum(j, Counter())}  for j in stats_for ]
mean_stats_rev = [{i:round(sum(j, Counter())[i]/3,3) for i in sum(j, Counter())}  for j in stats_rev ]

mean_stats_both = [{k:round(sum(i+j, Counter())[k]/6,3) for k in sum(i+j, Counter())} for i,j in zip(stats_for, stats_rev)]

ratio_stats_both = [{k:round(sum(i+j, Counter())[k]/sum(sum(i+j, Counter()).values()),3) for k in sum(i+j, Counter())} for i,j in zip(stats_for, stats_rev)]

In [25]:
ref_df = pd.DataFrame(
    {'Sentence_Pairs': pairs,
     'Labels': label_for,
     'Labels_Reverse': label_rev,
     'Aggregate_Labels': stats_for,
     'Aggregate_Labels_Rev': stats_rev,
     'Mean_Labels': mean_stats_for,
     'Mean_Labels_Rev': mean_stats_rev,
     'Mean_Label_Both': mean_stats_both,
     'Ratio_Label_Both': ratio_stats_both
    })

In [26]:
ref_df.head()

Unnamed: 0,Sentence_Pairs,Labels,Labels_Reverse,Aggregate_Labels,Aggregate_Labels_Rev,Mean_Labels,Mean_Labels_Rev,Mean_Label_Both,Ratio_Label_Both
0,[[(The hotel has conference rooms available to...,"[[neutral, neutral, neutral, neutral, neutral,...","[[neutral, neutral, neutral, neutral, neutral,...","[{'neutral': 47, 'contradiction': 7, 'entailme...","[{'neutral': 53, 'contradiction': 2}, {'neutra...","{'neutral': 73.0, 'contradiction': 6.0, 'entai...","{'neutral': 75.333, 'contradiction': 5.0, 'ent...","{'neutral': 74.167, 'contradiction': 5.5, 'ent...","{'neutral': 0.916, 'contradiction': 0.068, 'en..."
1,[[(This is a hotel located in an excellent pla...,"[[neutral, neutral, neutral, neutral, neutral,...","[[neutral, neutral, neutral, neutral, neutral,...","[{'neutral': 40, 'contradiction': 5}, {'neutra...","[{'neutral': 42, 'contradiction': 2, 'entailme...","{'neutral': 38.333, 'contradiction': 6.667}","{'neutral': 40.0, 'contradiction': 4.667, 'ent...","{'neutral': 39.167, 'contradiction': 5.667, 'e...","{'neutral': 0.87, 'contradiction': 0.126, 'ent..."
2,"[[(The hotel has a level of homely comfort, bu...","[[neutral, neutral, neutral, neutral, neutral,...","[[neutral, neutral, contradiction, neutral, ne...","[{'neutral': 29, 'contradiction': 5}, {'neutra...","[{'neutral': 23, 'contradiction': 11}, {'neutr...","{'neutral': 43.333, 'contradiction': 7.0, 'ent...","{'neutral': 42.0, 'contradiction': 8.667}","{'neutral': 42.667, 'contradiction': 7.833, 'e...","{'neutral': 0.842, 'contradiction': 0.155, 'en..."
3,"[[(The hotel was clean and reasonably priced.,...","[[neutral, neutral, neutral, neutral, neutral,...","[[neutral, neutral, neutral, neutral, neutral,...","[{'neutral': 84}, {'neutral': 76, 'entailment'...","[{'neutral': 75, 'contradiction': 8, 'entailme...","{'neutral': 78.667, 'entailment': 0.333, 'cont...","{'neutral': 73.667, 'contradiction': 4.333, 'e...","{'neutral': 76.167, 'entailment': 0.833, 'cont...","{'neutral': 0.96, 'entailment': 0.011, 'contra..."
4,[[(This greatly loved hotel is well kept and c...,"[[neutral, contradiction, neutral, neutral, ne...","[[neutral, contradiction, neutral, neutral, ne...","[{'neutral': 62, 'contradiction': 4, 'entailme...","[{'neutral': 57, 'contradiction': 9, 'entailme...","{'neutral': 65.333, 'contradiction': 5.0, 'ent...","{'neutral': 61.0, 'contradiction': 9.667, 'ent...","{'neutral': 63.167, 'contradiction': 7.333, 'e...","{'neutral': 0.881, 'contradiction': 0.102, 'en..."


In [21]:
[{k:round(sum(i+j, Counter())[k]/sum(sum(i+j, Counter()).values()),3) for k in sum(i+j, Counter())} for i,j in zip(stats_for, stats_rev)]

[{'neutral': 0.916, 'contradiction': 0.068, 'entailment': 0.016},
 {'neutral': 0.87, 'contradiction': 0.126, 'entailment': 0.004},
 {'neutral': 0.842, 'contradiction': 0.155, 'entailment': 0.003},
 {'neutral': 0.96, 'entailment': 0.011, 'contradiction': 0.029},
 {'neutral': 0.881, 'contradiction': 0.102, 'entailment': 0.016},
 {'neutral': 0.863, 'contradiction': 0.128, 'entailment': 0.009},
 {'neutral': 0.88, 'contradiction': 0.117, 'entailment': 0.003},
 {'neutral': 0.947, 'contradiction': 0.046, 'entailment': 0.007},
 {'neutral': 0.914, 'contradiction': 0.06, 'entailment': 0.026},
 {'neutral': 0.886, 'contradiction': 0.103, 'entailment': 0.011},
 {'neutral': 0.94, 'contradiction': 0.04, 'entailment': 0.02},
 {'neutral': 0.906, 'contradiction': 0.094},
 {'entailment': 0.011, 'neutral': 0.934, 'contradiction': 0.056},
 {'neutral': 0.919, 'contradiction': 0.068, 'entailment': 0.013},
 {'contradiction': 0.19, 'neutral': 0.798, 'entailment': 0.012},
 {'neutral': 0.87, 'contradiction': 0.1

In [22]:
[{k:round(sum(i, Counter())[k]/sum(sum(i, Counter()).values()),3) for k in sum(i, Counter())} for i in stats_for]

[{'neutral': 0.901, 'contradiction': 0.074, 'entailment': 0.025},
 {'neutral': 0.852, 'contradiction': 0.148},
 {'neutral': 0.855, 'contradiction': 0.138, 'entailment': 0.007},
 {'neutral': 0.992, 'entailment': 0.004, 'contradiction': 0.004},
 {'neutral': 0.912, 'contradiction': 0.07, 'entailment': 0.019},
 {'neutral': 0.873, 'contradiction': 0.114, 'entailment': 0.012},
 {'neutral': 0.94, 'contradiction': 0.057, 'entailment': 0.003},
 {'neutral': 0.955, 'contradiction': 0.041, 'entailment': 0.003},
 {'neutral': 0.916, 'contradiction': 0.052, 'entailment': 0.031},
 {'neutral': 0.897, 'contradiction': 0.096, 'entailment': 0.007},
 {'neutral': 0.925, 'contradiction': 0.044, 'entailment': 0.032},
 {'neutral': 0.886, 'contradiction': 0.114},
 {'entailment': 0.014, 'neutral': 0.917, 'contradiction': 0.069},
 {'neutral': 0.924, 'contradiction': 0.06, 'entailment': 0.017},
 {'contradiction': 0.187, 'neutral': 0.807, 'entailment': 0.006},
 {'neutral': 0.868, 'contradiction': 0.093, 'entailment

In [23]:
[{k:round(sum(i, Counter())[k]/sum(sum(i, Counter()).values()),3) for k in sum(i, Counter())} for i in stats_rev]

[{'neutral': 0.93, 'contradiction': 0.062, 'entailment': 0.008},
 {'neutral': 0.889, 'contradiction': 0.104, 'entailment': 0.007},
 {'neutral': 0.829, 'contradiction': 0.171},
 {'neutral': 0.929, 'contradiction': 0.055, 'entailment': 0.017},
 {'neutral': 0.851, 'contradiction': 0.135, 'entailment': 0.014},
 {'neutral': 0.852, 'entailment': 0.006, 'contradiction': 0.142},
 {'neutral': 0.82, 'contradiction': 0.177, 'entailment': 0.003},
 {'neutral': 0.938, 'contradiction': 0.052, 'entailment': 0.01},
 {'entailment': 0.021, 'neutral': 0.911, 'contradiction': 0.068},
 {'neutral': 0.875, 'contradiction': 0.11, 'entailment': 0.015},
 {'neutral': 0.956, 'contradiction': 0.036, 'entailment': 0.008},
 {'neutral': 0.926, 'contradiction': 0.074},
 {'neutral': 0.95, 'contradiction': 0.043, 'entailment': 0.007},
 {'neutral': 0.914, 'contradiction': 0.076, 'entailment': 0.01},
 {'contradiction': 0.193, 'neutral': 0.789, 'entailment': 0.018},
 {'neutral': 0.871, 'contradiction': 0.114, 'entailment': 