# SNLI 데이터에 대해서 locate를 실행해 본 코드
locate한 부분을 mask한 뒤 contradiction probability가 어떻게 달라지는지를 locate accuracy 대신 써 보려고 시도했었다.

그런데 그 방식이 엄밀하게 locate accuracy를 대신할 수 있는가에 대해서는 고민이 더 필요한 상황임.


In [1]:
import random

import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from new_module.locate.new_locate_utils import LocateMachine


import os
os.chdir('/data/hyeryung/mucoco')
import yaml

import wandb
from tqdm import tqdm
import torch
import torch.nn as nn
import pandas as pd
from torch.optim import AdamW
import seaborn as sns

import numpy as np
import matplotlib.pyplot as plt
from transformers import get_linear_schedule_with_warmup, AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_fscore_support, accuracy_score


from new_module.em_training.nli.models import EncoderModel
from new_module.em_training.nli.data_handling import load_nli_data, load_nli_test_data, NLI_Dataset, NLI_DataLoader
from new_module.em_training.nli.train import *
from new_module.em_training.nli.train_modules import *




In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# External classifier: the model and its tokenizer are loaded from the same checkpoint name.
EXT_CHECKPOINT = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
ext_tokenizer = AutoTokenizer.from_pretrained(EXT_CHECKPOINT)
ext_model = AutoModelForSequenceClassification.from_pretrained(EXT_CHECKPOINT).to(device)

In [128]:
## Load the NLI dataset and the trained NLI model.
## The locate utilities are then used to analyse which parts get located when
## back-propagating through this model (30 random instances).

model_path = "models/nli/roberta_large_snli_mnli_anli_train_dev_with_finegrained_finegrained_labels_negative_log_odds/1726245570/best_model.pth"
config = load_config('new_module/em_training/config.yaml')
config['device'] = device

model = EncoderModel(config)
model = model.to(config['device'])
# map_location keeps the load working when the checkpoint was saved from a
# different device than the one selected for this session (e.g. the CPU fallback).
model.load_state_dict(torch.load(model_path, map_location=config['device']))
# Put the model in evaluation mode before scoring so that train-only layers
# such as dropout (if the encoder has any) do not make the scores stochastic.
model.eval()

# Record which checkpoint this config was used with.
config['model_path'] = model_path

# Alternative kept for reference: use the public checkpoint instead of the trained model.
# nli_model = AutoModelForSequenceClassification.from_pretrained("ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")
# nli_tokenizer = AutoTokenizer.from_pretrained("ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")

nli_model = model
nli_tokenizer = model.tokenizer
nli_dataset = load_dataset('stanfordnlp/snli')

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [129]:
# config['energynet']['energy_col'] = 0

In [130]:
class Args:
    """Bare attribute holder handed to LocateMachine; only `task` is set here."""

    task = 'nli'


locate_args = Args()
locator = LocateMachine(nli_model, nli_tokenizer, locate_args)

In [131]:
## Draw the test examples: 30 random SNLI test pairs whose label is 2.
random.seed(999)

test_labels = nli_dataset['test']['label']
contradiction_indexes = [idx for idx, label in enumerate(test_labels) if label == 2]
random_indexes = random.sample(contradiction_indexes, 30)

test_examples = nli_dataset['test'][random_indexes]

# One (premise, hypothesis) tuple per sampled example.
test_examples_tuples = list(zip(test_examples['premise'], test_examples['hypothesis']))

In [132]:
# Show every sampled pair, followed by blank lines as a separator.
for premise, hypothesis in test_examples_tuples:
    print("Premise:", premise)
    print("Hypothesis:", hypothesis)
    print("\n")

Premise: Girl wearing white shirt sings on stage while playing guitar
Hypothesis: A man play the triangle.


Premise: White dog playing in the snow.
Hypothesis: The snow is purple.


Premise: Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.
Hypothesis: The three girls hold the boy down while his girlfriend shoots water at him.


Premise: A boy in a hat and glasses is playing a guitar.
Hypothesis: A boy is playing piano.


Premise: A police person is on a motorcycle on the side of a street.
Hypothesis: The fireman is sitting in his truck.


Premise: A lady and a man in a hat watch baseball from the stands.
Hypothesis: A couple are riding a rollercoaster.


Premise: A group of young boys wearing track jackets stretch their legs on a gym floor as they sit in a circle.
Hypothesis: A group of boys are studying for a test.


Premise: A woman with a white blanket over her head is holding a baby wrapped in a blue, pink, and yellow blanket.
Hypothesis:

In [133]:
# Run the locator on the sampled pairs with the 'grad_norm' setting, on word units.
# NOTE(review): label_id=0 is passed here, while the samples were selected with
# label == 2 and column 2 of the external model is read later -- confirm that 0 is
# the intended target index for this model.
result = locator.locate_main(
    test_examples_tuples,
    'grad_norm',
    max_num_tokens=100,
    unit="word",
    label_id=0,
)

In [134]:
## Check the trained model's output scores for the sampled examples (before masking)
tokens = nli_tokenizer(test_examples_tuples, return_tensors='pt', padding=True, truncation=True)
tokens = tokens.to(config['device'])

# Pure inference: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    outputs = nli_model(**tokens)

# Accept both return conventions with a single forward pass instead of a bare
# `except:` that hides unrelated errors and runs the model a second time:
# either a (logits, hidden_states) pair or a mapping with those two keys.
if isinstance(outputs, (tuple, list)):
    logits, hidden_states = outputs
else:
    logits, hidden_states = outputs['logits'], outputs['hidden_states']

# Vector-valued output forms are softmaxed; any other output form is used as-is.
if config['energynet']['output_form'] in ['2dim_vec', '3dim_vec']:
    probas_before_masking = torch.softmax(logits, dim=1)
else:
    probas_before_masking = logits

class_before_masking = torch.argmax(probas_before_masking, dim=-1)
# Keep only the column configured as the energy score.
probas_before_masking = probas_before_masking[:, config['energynet']['energy_col']]

In [135]:
## Check the external model's output scores for the sampled examples (before masking)

tokens = ext_tokenizer(test_examples_tuples, return_tensors='pt', padding=True, truncation=True)
tokens = tokens.to(config['device'])

# Pure inference: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    outputs = ext_model(**tokens)
logits = outputs['logits']

probas_before_masking_ext = torch.softmax(logits, dim=1)
class_before_masking_ext = torch.argmax(probas_before_masking_ext, dim=-1)
# Keep column 2 of the external model's class probabilities.
probas_before_masking_ext = probas_before_masking_ext[:, 2]

In [136]:
## Check the trained model's energy scores for the masked samples
# `result` is tokenized with add_special_tokens=False, so no further special tokens are added.
tokens_after_masking = nli_tokenizer(result, add_special_tokens=False, return_tensors='pt', padding=True, truncation=True)
tokens_after_masking = tokens_after_masking.to(config['device'])

# Pure inference: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    outputs = nli_model(**tokens_after_masking)

# Accept both return conventions with a single forward pass instead of a bare
# `except:` that hides unrelated errors and runs the model a second time:
# either a tuple/list whose first item is the logits, or a mapping with 'logits'.
if isinstance(outputs, (tuple, list)):
    logits_after_masking = outputs[0]
else:
    logits_after_masking = outputs['logits']

# Vector-valued output forms are softmaxed; any other output form is used as-is.
if config['energynet']['output_form'] in ['2dim_vec', '3dim_vec']:
    probas_after_masking = torch.softmax(logits_after_masking, dim=1)
else:
    probas_after_masking = logits_after_masking

class_after_masking = torch.argmax(probas_after_masking, dim=-1)
# Keep only the column configured as the energy score.
probas_after_masking = probas_after_masking[:, config['energynet']['energy_col']]

In [137]:
## Check the external model's scores for the masked samples
# `result` is tokenized with add_special_tokens=False, so no further special tokens are added.
tokens_after_masking = ext_tokenizer(result, add_special_tokens=False, return_tensors='pt', padding=True, truncation=True)
tokens_after_masking = tokens_after_masking.to(config['device'])

# Pure inference: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    outputs = ext_model(**tokens_after_masking)
logits_after_masking = outputs['logits']

probas_after_masking_ext = torch.softmax(logits_after_masking, dim=1)
class_after_masking_ext = torch.argmax(probas_after_masking_ext, dim=-1)
# Keep column 2 of the external model's class probabilities.
probas_after_masking_ext = probas_after_masking_ext[:, 2]

In [138]:
## observation 1: some samples do not look very contradictory --> is there a way to filter them out?
## observation 2: the located tokens seem to reflect the contradictory part well.

classes = ["Entail", "Neutral", "Contradict"]
for idx, (pair, located) in enumerate(zip(test_examples_tuples, result)):
    premise, hypothesis = pair
    print("Premise: ", premise)
    print("Hypothesis: ", hypothesis)
    # Scores before -> after masking, for the trained model and the external model.
    print("Model proba: ", probas_before_masking[idx].item(), "->", probas_after_masking[idx].item())
    print("External model proba: ", probas_before_masking_ext[idx].item(), "->", probas_after_masking_ext[idx].item())
    # print("Predicted class: ", classes[class_before_masking[idx].item()], "->", classes[class_after_masking[idx].item()])
    # The located string is split on '</s></s>': first piece is the premise side, second the hypothesis side.
    located_parts = located.split('</s></s>')
    print("Located p: ", located_parts[0])
    print("Located h: ", located_parts[1])
    print("\n")

Premise:  Girl wearing white shirt sings on stage while playing guitar
Hypothesis:  A man play the triangle.
Model proba:  7.2474565505981445 -> 6.789552211761475
External model proba:  0.9995410442352295 -> 0.999459445476532
Located p:  <s>Girl wearing white shirt sings on stage while playing guitar
Located h:  A man play the<mask><mask></s>


Premise:  White dog playing in the snow.
Hypothesis:  The snow is purple.
Model proba:  7.826869487762451 -> 1.765481948852539
External model proba:  0.7578288912773132 -> 0.034679610282182693
Located p:  <s>White dog playing in the snow.
Located h:  The snow is<mask><mask></s>


Premise:  Three children hold a boy's arms down while another boy in a hat shoots a water gun at him.
Hypothesis:  The three girls hold the boy down while his girlfriend shoots water at him.
Model proba:  3.261336326599121 -> 2.6492867469787598
External model proba:  0.457634836435318 -> 0.05039544776082039
Located p:  <s>Three children hold a boy's arms down while anot

# Analyze results

Run this code after running test_nli_locate.py for wandb runs that you want to compare.

In [1]:
import pandas as pd

In [None]:
# FIXME: pd.read_json() is called without its required `path_or_buf` argument, so this
# cell raises TypeError as written; pass the path of the source JSON file before running.
data_source = pd.read_json()

In [2]:
# Locating-result spreadsheets of the two runs being compared (run ids t6lbryef and 25n5gs0a).
RUN_T6L_XLSX = '/data/hyeryung/mucoco/dev_data_locating_result_merge_masks_raw_t6lbryef.xlsx'
RUN_25N_XLSX = '/data/hyeryung/mucoco/dev_data_locating_result_merge_masks_raw_25n5gs0a.xlsx'

data1, data2 = (pd.read_excel(xlsx_path) for xlsx_path in (RUN_T6L_XLSX, RUN_25N_XLSX))

In [3]:
# Inspect the column names of the first run's result table.
data1.columns

Index(['original', 'masked', 'original_class', 'masked_class',
       'original_contradiction_proba', 'masked_contradiction_proba'],
      dtype='object')

In [14]:
# Tag the run-specific (masked_*) columns with a short run id so the two tables
# can sit side by side; the original_* columns keep their names.
T6L_RUN_COLUMNS = [
    'original', 'masked_t6l',
    'original_class', 'masked_class_t6l',
    'original_contradiction_proba', 'masked_contradiction_proba_t6l',
]
N25_RUN_COLUMNS = [
    'original', 'masked_25n',
    'original_class', 'masked_class_25n',
    'original_contradiction_proba', 'masked_contradiction_proba_25n',
]

for frame, run_columns in ((data1, T6L_RUN_COLUMNS), (data2, N25_RUN_COLUMNS)):
    frame.columns = run_columns

In [17]:
# Put the two runs side by side (pd.concat aligns them on the row index), keeping
# 'original' once plus the three masked_* columns of each run, selected by name.
KEEP_FROM_T6L = ['original', 'masked_t6l', 'masked_class_t6l', 'masked_contradiction_proba_t6l']
KEEP_FROM_25N = ['masked_25n', 'masked_class_25n', 'masked_contradiction_proba_25n']

data = pd.concat([data1[KEEP_FROM_T6L], data2[KEEP_FROM_25N]], axis=1)
print(data.columns)
print(data.shape)

Index(['original', 'masked_t6l', 'masked_class_t6l',
       'masked_contradiction_proba_t6l', 'masked_25n', 'masked_class_25n',
       'masked_contradiction_proba_25n'],
      dtype='object')
(1766, 7)


In [18]:
# Flag (1/0) whether the masked text's contradiction probability is still above 0.5, per run.
data['contra_class_yn_t6l'] = data['masked_contradiction_proba_t6l'].gt(0.5).astype('int64')
data['contra_class_yn_25n'] = data['masked_contradiction_proba_25n'].gt(0.5).astype('int64')

In [19]:
# Count rows for every combination of the two runs' "still above 0.5 after masking" flags.
# t6l does worse -> is that actually plausible?
flag_columns = ['contra_class_yn_t6l', 'contra_class_yn_25n']
data.groupby(flag_columns).size()

contra_class_yn_t6l  contra_class_yn_25n
0                    0                      958
                     1                      126
1                    0                      210
                     1                      472
dtype: int64

In [10]:
# NOTE(review): the saved output of this cell comes from an earlier execution (count 10)
# and does not match the (1766, 7) shape printed when `data` was built at count 17;
# re-run the notebook top to bottom before relying on this value.
data.shape

(28, 11)