In [1]:
MODEL_NAME='bert-base-uncased'
POOLER_TYPE='cls'
TEMP = 1
DATA_TYPE = 'validation'

In [2]:
import torch 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
from datasets import load_dataset
data = load_dataset('THUDM/webglm-qa')

In [4]:
from transformers import AutoConfig, AutoTokenizer, AutoModel

tokenizer_kwargs = {"use_fast": 'use_fast_tokenizer'}
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **tokenizer_kwargs)

In [5]:
# from simcse import SimCSE
# simcse_model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
model = AutoModel.from_pretrained(MODEL_NAME).to(device)

In [6]:
test_dataset = data[DATA_TYPE]
queries, answers, references = test_dataset['question'], test_dataset['answer'], test_dataset['references']

In [8]:
documents = [' '.join(x) for x in references]

In [9]:
encoding_q = tokenizer(queries, padding=True, truncation=True, return_tensors='pt', max_length=64)
encoding_d = tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=64)

x_q, mask_q = encoding_q['input_ids'].to(device), encoding_q['attention_mask'].to(device)
x_d, mask_d = encoding_d['input_ids'].to(device), encoding_d['attention_mask'].to(device)

In [None]:
## # If you use AutoModel
output_q = model(x_q, mask_q)
output_d = model(x_d, mask_d)

## # If you use SimCSE
# output_q = simcse_model.encode(queries)
# output_d = simcse_model.encode(documents)

In [45]:
import torch.nn as nn

class Pooler(nn.Module):
    
    def __init__(self, pooler_type):
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in ["cls", "cls_before_pooler", "max", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

    def forward(self, attention_mask, outputs):
        last_hidden = outputs.last_hidden_state
        pooler_output = outputs.pooler_output
        hidden_states = outputs.hidden_states

        if self.pooler_type in ['cls_before_pooler', 'cls']:
            return last_hidden[:, 0]
        elif self.pooler_type == "avg":
            return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
        elif self.pooler_type == "max":
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            )
            last_hidden[input_mask_expanded == 0] = -1e9
            max_over_time = torch.max(last_hidden, 1)[0]
            return max_over_time
        elif self.pooler_type == "avg_first_last":
            first_hidden = hidden_states[1]
            last_hidden = hidden_states[-1]
            pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
            return pooled_result
        elif self.pooler_type == "avg_top2":
            second_last_hidden = hidden_states[-2]
            last_hidden = hidden_states[-1]
            pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
            return pooled_result
        else:
            raise NotImplementedError
        
class Similarity(nn.Module):
    
    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        return self.cos(x, y) / self.temp


In [46]:
# If you use AutoModel
pooler = Pooler(POOLER_TYPE)
z_q = pooler(mask_q, output_q)
z_d = pooler(mask_d, output_d)

In [84]:
sim = Similarity(TEMP)
# cos_sim = sim(z_q.unsqueeze(1), z_d.unsqueeze(0)) # If you use AutoModel
cos_sim = sim(output_q.unsqueeze(1), output_d.unsqueeze(0)) # If you use SimCSE

In [91]:
import numpy as np

def topk_metric(matrix, k=None):
    result = []
    for c in matrix:
        topk = sorted(c, reverse=True)[:k]
        diag = np.diag(matrix)
        is_topk = any(x in topk for x in np.diag(matrix))
        result.append(0 if is_topk else 1) # 정답이면 0 아니면 1
    return result

result_lst = topk_metric(cos_sim.cpu().detach().numpy(), k=30)
acc = result_lst.count(0) / len(result_lst) * 100
print(acc)

99.6
