In [2]:
import numpy as np
import pandas as pd
import torch
import json
import itertools
import nltk
# nltk.download('punkt')
from collections import Counter

In [3]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('roberta-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('roberta-large-mnli')

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [5]:
model.to(device)

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
         

In [6]:
data = json.load(open('data/combined.json', 'r'))
dev = data['dev']
test = data['test']
fulldata = dev + test

In [34]:
def contrast(reference):

  lists =[]
  agg =[]

  if reference == True:

    for data in fulldata:

      ref_a_sum = [nltk.sent_tokenize(i) for i in data['entity_a_summary']]
      ref_b_sum = [nltk.sent_tokenize(i) for i in data['entity_b_summary']]
      ref_comm_sum = [nltk.sent_tokenize(i) for i in data['common_summary']]
    
      agg_agg = []

      for i in range(len(ref_a_sum)):

        cont = list(itertools.product(ref_a_sum[i], ref_b_sum[i]))
        comm_a = list(itertools.product(ref_a_sum[i], ref_comm_sum[i]))
        comm_b = list(itertools.product(ref_b_sum[i], ref_comm_sum[i]))


        col1 = [j[0] for i in [cont,comm_a,comm_b] for j in i]
        col2 = [j[1] for i in [cont,comm_a,comm_b] for j in i]

        features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
        features.to(device)

        model.eval()
        with torch.no_grad():
            scores = model(**features).logits
            label_mapping = ['contradiction', 'entailment', 'neutral']
            labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

            lists.append(labels)
            # print(labels)
            aggs = Counter(labels)
            print(aggs)
            agg_agg.append(aggs)
        
        

  else:

    for data in fulldata:

      gen_a_sum = nltk.sent_tokenize(data['gen_cont_a'])
      gen_b_sum = nltk.sent_tokenize(data['gen_cont_b'])
      gen_comm_sum = nltk.sent_tokenize(data['gen_comm_a'])

      cont = list(itertools.product(gen_a_sum, gen_b_sum))
      comm_a = list(itertools.product(gen_a_sum, gen_comm_sum))
      comm_b = list(itertools.product(gen_b_sum, gen_comm_sum))

      col1 = [j[0] for i in [cont,comm_a,comm_b] for j in i]
      col2 = [j[1] for i in [cont,comm_a,comm_b] for j in i]

      features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
      features.to(device)
        
      # print(col1[2])
      # print(col2[2])

      model.eval()
      with torch.no_grad():
          scores = model(**features).logits
          # label_mapping = ['contradiction', 'entailment', 'neutral']
          label_mapping = ['contradiction', 'neutral','entailment']
          labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

          lists.append(labels)
          # print(labels)
          aggs = Counter(labels)
          print(aggs)
          agg.append(aggs)
            
            
      features = tokenizer(col2,col1,  padding=True, truncation=True, return_tensors="pt")
      features.to(device)

      model.eval()
      with torch.no_grad():
          scores = model(**features).logits
          # label_mapping = ['contradiction', 'entailment', 'neutral']
          label_mapping = ['contradiction', 'neutral','entailment']
          labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

          lists.append(labels)
          # print(labels)
          aggs = Counter(labels)
          print(aggs)
          agg.append(aggs)
            
      # break



In [27]:
def factual_opinion(reference):

  lists = []
  agg = []

  # factual opinons

  if reference == True:

    for data in fulldata:

      ref_a_sum = [nltk.sent_tokenize(i) for i in data['entity_a_summary']]
      ref_b_sum = [nltk.sent_tokenize(i) for i in data['entity_b_summary']]
      ref_comm_sum = [nltk.sent_tokenize(i) for i in data['common_summary']]

      source_a = [nltk.sent_tokenize(i) for i in data['entity_a_reviews']]
      source_b = [nltk.sent_tokenize(i) for i in data['entity_b_reviews']]

      for i in range(len(ref_a_sum)):

        rev_a = [k for j in source_a for k in itertools.product(ref_a_sum[i],j)]
        rev_b = [k for j in source_b for k in itertools.product(ref_b_sum[i],j)]
        rev_com_a = [k for j in source_a for k in itertools.product(ref_comm_sum[i],j)]
        rev_com_b = [k for j in source_b for k in itertools.product(ref_comm_sum[i],j)]


        col1 = [j[0] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]
        col2 = [j[1] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]


        features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
        features.to(device)

        model.eval()
        with torch.no_grad():
            scores = model(**features).logits
            label_mapping = ['contradiction', 'entailment', 'neutral']
            labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

            lists.append(labels)
            # print(labels)
            aggs = Counter(labels)
            print(aggs)
            agg.append(aggs)

  else:

    for data in fulldata:

      source_a = [nltk.sent_tokenize(i) for i in data['entity_a_reviews']]
      source_b = [nltk.sent_tokenize(i) for i in data['entity_b_reviews']]

      gen_a_sum = nltk.sent_tokenize(data['gen_cont_a'])
      gen_b_sum = nltk.sent_tokenize(data['gen_cont_b'])
      gen_comm_sum = nltk.sent_tokenize(data['gen_comm_a'])

      rev_gen_a = [j for i in source_a for j in itertools.product(i,gen_a_sum)]
      rev_gen_b = [j for i in source_b for j in itertools.product(i,gen_b_sum)]
      rev_com_a = [j for i in source_a for j in itertools.product(i,gen_comm_sum)]
      rev_com_b = [j for i in source_b for j in itertools.product(i,gen_comm_sum)]

      col1 = [j[0] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]
      col2 = [j[1] for i in [rev_gen_a,rev_gen_b, rev_com_a, rev_com_b] for j in i]

      features = tokenizer(col1,col2,  padding=True, truncation=True, return_tensors="pt")
      features.to(device)

      model.eval()
      with torch.no_grad():
          scores = model(**features).logits
          label_mapping = ['contradiction', 'entailment', 'neutral']
          # label_mapping = ['contradiction', 'neutral','entailment']
          labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).detach().cpu().numpy()]

          lists.append(labels)
          # print(labels)
          aggs = Counter(labels)
          print(aggs)
          agg.append(aggs)



In [9]:
for data in fulldata:

      gen_a_sum = nltk.sent_tokenize(data['gen_cont_a'])
      gen_b_sum = nltk.sent_tokenize(data['gen_cont_b'])
      gen_comm_sum = nltk.sent_tokenize(data['gen_comm_a'])

      cont = list(itertools.product(gen_a_sum, gen_b_sum))
      comm_a = list(itertools.product(gen_a_sum, gen_comm_sum))
      comm_b = list(itertools.product(gen_b_sum, gen_comm_sum))

      col1 = [j[0] for i in [cont,comm_a,comm_b] for j in i]
      col2 = [j[1] for i in [cont,comm_a,comm_b] for j in i]

      break

In [22]:
col1[2]

"Le Place D'Armes is an elegant historic building with a fantastic location."

In [21]:
col2[2]

'The room was slightly smaller than average although it was clean and with a slightly more modern decor than a typical Fairmont.'

In [26]:
contrast(False)

Le Place D'Armes is an elegant historic building with a fantastic location.
The room was slightly smaller than average although it was clean and with a slightly more modern decor than a typical Fairmont.
['entailment']
Counter({'entailment': 1})


In [35]:
contrast(False)

Counter({'neutral': 54, 'contradiction': 4, 'entailment': 1})
Counter({'neutral': 55, 'contradiction': 4})
Counter({'neutral': 62, 'contradiction': 9})
Counter({'neutral': 60, 'contradiction': 11})
Counter({'neutral': 54, 'contradiction': 3, 'entailment': 3})
Counter({'neutral': 55, 'contradiction': 5})
Counter({'neutral': 68, 'contradiction': 7, 'entailment': 2})
Counter({'neutral': 74, 'contradiction': 3})
Counter({'neutral': 57, 'contradiction': 2})
Counter({'neutral': 53, 'contradiction': 5, 'entailment': 1})
Counter({'neutral': 30, 'contradiction': 6, 'entailment': 2})
Counter({'neutral': 27, 'contradiction': 7, 'entailment': 4})
Counter({'neutral': 49, 'contradiction': 8, 'entailment': 2})
Counter({'neutral': 52, 'contradiction': 7})
Counter({'neutral': 45, 'contradiction': 2})
Counter({'neutral': 44, 'contradiction': 3})
Counter({'neutral': 57, 'contradiction': 2, 'entailment': 1})
Counter({'neutral': 54, 'contradiction': 5, 'entailment': 1})
Counter({'neutral': 66, 'contradicti

In [25]:
contrast(False)

Counter({'entailment': 54, 'contradiction': 4, 'neutral': 1})
Counter({'entailment': 55, 'contradiction': 4})
Counter({'entailment': 62, 'contradiction': 9})
Counter({'entailment': 60, 'contradiction': 11})
Counter({'entailment': 54, 'contradiction': 3, 'neutral': 3})
Counter({'entailment': 55, 'contradiction': 5})
Counter({'entailment': 68, 'contradiction': 7, 'neutral': 2})
Counter({'entailment': 74, 'contradiction': 3})
Counter({'entailment': 57, 'contradiction': 2})
Counter({'entailment': 53, 'contradiction': 5, 'neutral': 1})
Counter({'entailment': 30, 'contradiction': 6, 'neutral': 2})
Counter({'entailment': 27, 'contradiction': 7, 'neutral': 4})
Counter({'entailment': 49, 'contradiction': 8, 'neutral': 2})
Counter({'entailment': 52, 'contradiction': 7})
Counter({'entailment': 45, 'contradiction': 2})
Counter({'entailment': 44, 'contradiction': 3})
Counter({'entailment': 57, 'contradiction': 2, 'neutral': 1})
Counter({'entailment': 54, 'contradiction': 5, 'neutral': 1})
Counter({

In [23]:
factual_opinion(False)

Counter({'entailment': 629, 'contradiction': 28, 'neutral': 17})
Counter({'entailment': 705, 'contradiction': 79, 'neutral': 6})
Counter({'entailment': 535, 'contradiction': 73, 'neutral': 16})
Counter({'entailment': 701, 'contradiction': 55, 'neutral': 27})
Counter({'entailment': 714, 'contradiction': 30, 'neutral': 17})
Counter({'entailment': 468, 'contradiction': 49, 'neutral': 29})
Counter({'entailment': 547, 'contradiction': 118, 'neutral': 17})
Counter({'entailment': 572, 'contradiction': 34, 'neutral': 6})
Counter({'entailment': 560, 'contradiction': 18, 'neutral': 14})
Counter({'entailment': 622, 'contradiction': 114, 'neutral': 20})
Counter({'entailment': 529, 'contradiction': 64, 'neutral': 10})
Counter({'entailment': 505, 'contradiction': 20, 'neutral': 16})
Counter({'entailment': 614, 'contradiction': 92, 'neutral': 18})
Counter({'entailment': 744, 'contradiction': 43, 'neutral': 22})
Counter({'entailment': 574, 'contradiction': 50, 'neutral': 18})
Counter({'entailment': 61