# Running locate on SNLI data
I tried masking the located spans and measuring how the contradiction probability changes afterwards, as a substitute for locate accuracy.

However, whether this can rigorously stand in for locate accuracy still needs more thought.
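
The comparison described above can be sketched in a few lines. The helper below is a hypothetical illustration and not part of `new_module`: it takes contradiction probabilities before and after masking and reports the mean drop plus the fraction of pairs that cross a 0.5 threshold (the threshold is an assumption). The sample values are the first three pairs printed at the end of this notebook, rounded to four decimals.

```python
def masking_effect(before, after, threshold=0.5):
    """Summarize how masking changes the contradiction probability.

    `threshold` is an assumed cut-off, not a value taken from the experiment.
    """
    if len(before) != len(after) or not before:
        raise ValueError("need two non-empty lists of equal length")
    drops = [b - a for b, a in zip(before, after)]
    mean_drop = sum(drops) / len(drops)
    # A pair "flips" when it is at or above the threshold before masking
    # and below it afterwards.
    flipped = sum(1 for b, a in zip(before, after) if b >= threshold > a)
    return {"drops": drops, "mean_drop": mean_drop, "flip_rate": flipped / len(before)}


before = [0.9995, 0.7578, 0.4576]  # contradiction probability before masking
after = [0.9995, 0.0347, 0.0189]   # contradiction probability after masking
summary = masking_effect(before, after)
print(summary["mean_drop"], summary["flip_rate"])
```

A large drop suggests the masked tokens carried most of the contradiction signal, while a pair that barely moves (like the first one) suggests the located tokens were not the decisive ones.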


In [7]:
import random

import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from new_module.locate.new_locate_utils import LocateMachine

In [8]:
## Load the NLI dataset
## Load the NLI model
## Load the locate utils and analyze which parts get located when backpropagating through the NLI model (30 random instances)

nli_model = AutoModelForSequenceClassification.from_pretrained("ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")
nli_tokenizer = AutoTokenizer.from_pretrained("ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")


nli_dataset = load_dataset('stanfordnlp/snli')

Some weights of the model checkpoint at ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
class Args:
    task = 'nli'

locator = LocateMachine(nli_model, nli_tokenizer, Args())

In [10]:
## get test examples
random.seed(999)

contradiction_indexes = [i for i, x in enumerate(nli_dataset['test']['label']) if x == 2]

random_indexes = random.sample(contradiction_indexes, 30)

test_examples = nli_dataset['test'][random_indexes]

test_examples_tuples = list(zip(test_examples['premise'], test_examples['hypothesis']))

In [11]:
for example in test_examples_tuples:
    
    print("Premise:", example[0])
    print("Hypothesis:", example[1])
    print("\n") 

Premise: Girl wearing white shirt sings on stage while playing guitar
Hypothesis: A man play the triangle.


Premise: White dog playing in the snow.
Hypothesis: The snow is purple.


Premise: Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.
Hypothesis: The three girls hold the boy down while his girlfriend shoots water at him.


Premise: A boy in a hat and glasses is playing a guitar.
Hypothesis: A boy is playing piano.


Premise: A police person is on a motorcycle on the side of a street.
Hypothesis: The fireman is sitting in his truck.


Premise: A lady and a man in a hat watch baseball from the stands.
Hypothesis: A couple are riding a rollercoaster.


Premise: A group of young boys wearing track jackets stretch their legs on a gym floor as they sit in a circle.
Hypothesis: A group of boys are studying for a test.


Premise: A woman with a white blanket over her head is holding a baby wrapped in a blue, pink, and yellow blanket.
Hypothesis:

In [17]:
result = locator.locate_main(test_examples_tuples, 'grad_norm', max_num_tokens = 100, unit="word", label_id=2)

prediction [('Girl wearing white shirt sings on stage while playing guitar', 'A man play the triangle.'), ('White dog playing in the snow.', 'The snow is purple.'), ("Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.", 'The three girls hold the boy down while his girlfriend shoots water at him.'), ('A boy in a hat and glasses is playing a guitar.', 'A boy is playing piano.'), ('A police person is on a motorcycle on the side of a street.', 'The fireman is sitting in his truck.'), ('A lady and a man in a hat watch baseball from the stands.', 'A couple are riding a rollercoaster.'), ('A group of young boys wearing track jackets stretch their legs on a gym floor as they sit in a circle.', 'A group of boys are studying for a test.'), ('A woman with a white blanket over her head is holding a baby wrapped in a blue, pink, and yellow blanket.', 'The blanket is black.'), ('The man in the black shirt is showing the man in the orange shirt something that 
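
The `'grad_norm'` argument passed to `locate_main` above suggests that tokens are ranked by the norm of a gradient taken with respect to their embeddings. The internals of `LocateMachine` are not shown in this notebook, so the following is only a generic sketch of that idea on a toy model. The embedding table, the classifier, the sizes, and the choice of loss are all assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy stand-ins for a real encoder and classification head.
vocab_size, hidden, num_labels = 20, 8, 3
embedding = torch.nn.Embedding(vocab_size, hidden)
classifier = torch.nn.Linear(hidden, num_labels)


def grad_norm_scores(input_ids, label_id):
    """Return one score per token: the L2 norm of d(loss)/d(embedding)."""
    # Detach so the token embeddings become a leaf tensor that receives gradients.
    embeds = embedding(input_ids).detach().requires_grad_(True)  # (seq_len, hidden)
    logits = classifier(torch.tanh(embeds).mean(dim=0))          # (num_labels,)
    loss = -torch.log_softmax(logits, dim=-1)[label_id]          # NLL of the target label
    loss.backward()
    return embeds.grad.norm(dim=-1)                              # (seq_len,)


input_ids = torch.tensor([3, 7, 1, 12, 5])
scores = grad_norm_scores(input_ids, label_id=2)
# Positions of the two highest-scoring tokens, i.e. the candidates for masking.
top2 = torch.topk(scores, k=2).indices.tolist()
print(scores, top2)
```

In this sketch the highest-scoring positions would be the ones replaced by `<mask>`; how the real utility aggregates sub-word scores into words (`unit="word"`) is not covered here.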

In [25]:
## Check the model's output scores on the original samples
tokens = nli_tokenizer(test_examples_tuples, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():  # inference only, no gradients needed
    outputs = nli_model(**tokens)

probas_before_masking = torch.softmax(outputs.logits, dim=-1)
# contradiction probability: index 2, matching label_id=2 passed to locate_main
probas_before_masking_energy = probas_before_masking[:, 2]
class_before_masking = torch.argmax(probas_before_masking, dim=-1)

In [26]:
## Check the energy scores on the masked samples
## The strings in `result` already contain <s>, </s></s> and </s>, so special tokens are not added again
tokens_after_masking = nli_tokenizer(result, add_special_tokens=False, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
    probas_after_masking = torch.softmax(nli_model(**tokens_after_masking).logits, dim=-1)
probas_after_masking_energy = probas_after_masking[:, 2]
class_after_masking = torch.argmax(probas_after_masking, dim=-1)
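
Not every gold-labelled contradiction is scored as one by the model: one pair in the results further down starts at a contradiction probability of about 0.46. One possible way to set such pairs aside before comparing probabilities is a simple cut-off on the pre-masking score. This is a minimal sketch, and both the helper name and the 0.5 cut-off are hypothetical.

```python
def split_by_confidence(pairs, contradiction_probs, min_prob=0.5):
    """Split pairs into (kept, dropped) by pre-masking contradiction probability.

    `min_prob` is an assumed cut-off, not a value used in the experiment.
    """
    kept, dropped = [], []
    for pair, prob in zip(pairs, contradiction_probs):
        (kept if prob >= min_prob else dropped).append((pair, prob))
    return kept, dropped


pairs = [
    ("White dog playing in the snow.", "The snow is purple."),
    ("Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.",
     "The three girls hold the boy down while his girlfriend shoots water at him."),
]
probs = [0.7578, 0.4576]  # pre-masking values for these two pairs, rounded
kept, dropped = split_by_confidence(pairs, probs)
print(len(kept), len(dropped))
```

With the tensors from the cells above, the probability list could come from `probas_before_masking[:, 2].tolist()`, since index 2 corresponds to "Contradict" in the class order used by this notebook.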

In [21]:
## observation 1: some samples do not seem very contradictory --> is there a way to filter them out?
## observation 2: the located tokens seem to reflect the contradictory part well.

classes = ["Entail", "Neutral", "Contradict"]
for i, ((p, h), y) in enumerate(zip(test_examples_tuples, result)):
    print("Premise: ", p)
    print("Hypothesis: ", h)
    print("Model proba: ", probas_before_masking_energy[i].item(), "->", probas_after_masking_energy[i].item())
    print("Predicted class: ", classes[class_before_masking[i].item()], "->", classes[class_after_masking[i].item()])
    print("Located p: ", y.split('</s></s>')[0])
    print("Located h: ", y.split('</s></s>')[1])
    print("\n")

Premise:  Girl wearing white shirt sings on stage while playing guitar
Hypothesis:  A man play the triangle.
Model proba:  0.9995410442352295 -> 0.999459445476532
Located p:  <s>Girl wearing white shirt sings on stage while playing guitar
Located h:  A man play the<mask><mask></s>


Premise:  White dog playing in the snow.
Hypothesis:  The snow is purple.
Model proba:  0.7578284740447998 -> 0.03467955440282822
Located p:  <s>White dog playing in the snow.
Located h:  The snow is<mask><mask></s>


Premise:  Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.
Hypothesis:  The three girls hold the boy down while his girlfriend shoots water at him.
Model proba:  0.4576328694820404 -> 0.01893182471394539
Located p:  <s>Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.
Located h:  The three<mask> hold the<mask> down while his<mask> shoots water at him.</s>


Premise:  A boy in a hat and glasses is playing a gui