In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import dataset
from tqdm import tqdm
import torch
import numpy as np
import transformers




class PromptConfig():
    """Settings for cloze-style prompt classification with a BERT masked LM.

    Args:
        few_shot: size of the training subset, e.g. '0', '32', '64', '128'.
            '0' (string or int) means zero-shot and sets ``TRAINING_FILE`` to None.
            Otherwise ``{data_dir}/train_{few_shot}.data`` is used if it exists,
            falling back to the full ``{data_dir}/train.data``.
        BERT_PATH: name or path of the pretrained BERT checkpoint / tokenizer.
        DEVICE: torch device string. If None (default), "cuda:1" is used when at
            least two CUDA devices are visible, "cuda:0" when exactly one is,
            and "cpu" otherwise.

    Raises:
        ValueError: if a verbalizer word is not in the tokenizer vocabulary
            (it would otherwise silently map to the [UNK] id).
    """

    def __init__(self, few_shot, BERT_PATH="bert-base-uncased", DEVICE=None) -> None:
        self.few_shot = few_shot
        self.DEVICE = DEVICE if DEVICE is not None else self._default_device()
        self.MAX_LEN = 256
        self.TRAIN_BATCH_SIZE = 8
        self.VALID_BATCH_SIZE = 4
        self.EPOCHS = 5

        # Training parameters
        self.eps_thres = 1e-4
        self.es_max = 5  # early stop

        self.BERT_PATH = BERT_PATH
        # NOTE(review): absolute, user-specific paths; consider making them configurable.
        self.MODEL_PATH = "/home/18307110500/pj3_workplace/pytorch_model.bin"
        data_dir = "/home/18307110500/data"

        # Normalize so that both '0' and 0 select zero-shot.
        few_shot_key = None if few_shot is None else str(few_shot)
        if few_shot_key == '0':
            self.TRAINING_FILE = None
        elif few_shot_key is not None and os.path.exists(f"{data_dir}/train_{few_shot_key}.data"):
            self.TRAINING_FILE = f"{data_dir}/train_{few_shot_key}.data"
        else:
            self.TRAINING_FILE = f"{data_dir}/train.data"

        self.VALIDATION_FILE = f"{data_dir}/valid.data"
        self.TEST_FILE = f"{data_dir}/test.data"
        self.TOKENIZER = transformers.BertTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)

        # Prompt verbalizer: label i corresponds to the word verbalizer[i] filling [MASK].
        # self.verbalizer = ['negative', 'positive']
        self.verbalizer = ['bad', 'great']
        # Map label words to vocabulary ids through the public tokenizer API.
        self.candidate_ids = self.TOKENIZER.convert_tokens_to_ids(self.verbalizer)
        unk_id = self.TOKENIZER.unk_token_id
        missing = [word for word, tid in zip(self.verbalizer, self.candidate_ids) if tid == unk_id]
        if missing:
            raise ValueError(
                f"verbalizer words {missing} are not single tokens in the vocabulary of {self.BERT_PATH}"
            )

        self.template = "It is a {} film."  # filled with .format('[MASK]')
        self.mask = self.TOKENIZER.mask_token  # '[MASK]'

    @staticmethod
    def _default_device():
        """Return "cuda:1" if >= 2 GPUs are visible, "cuda:0" if one is, else "cpu"."""
        n_gpus = torch.cuda.device_count()
        if n_gpus >= 2:
            return "cuda:1"
        return "cuda:0" if n_gpus == 1 else "cpu"

In [2]:
# Zero-shot prompt classification with BertForMaskedLM (no fine-tuning):
# few_shot "0" means no training file is used.
config = PromptConfig(few_shot="0", BERT_PATH="bert-base-uncased")

# Pretrained BERT with its masked-LM head; reuse the tokenizer built by the config.
model_bert = transformers.BertForMaskedLM.from_pretrained(config.BERT_PATH)
bert_tokenzier: transformers.BertTokenizer = config.TOKENIZER

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class PromptDataset(dataset.BERTDataset):
    """``dataset.BERTDataset`` whose reviews have a cloze prompt appended.

    Each review ``x`` becomes ``f"{x} {template.format(mask)}"``; with the
    config's template this is e.g. ``"... It is a [MASK] film."``.
    """

    def __init__(self, review, target, config):
        super().__init__(review, target, config)
        # Template with one "{}" slot, e.g. "It is a {} film."
        self.template = config.template
        # The tokenizer's mask token (e.g. '[MASK]') inserted into the slot.
        self.mask = config.mask

    def make_prompt(self, input_data):
        """Return ``input_data`` followed by a single space and the filled template."""
        filled_template = self.template.format(self.mask)
        return " ".join((f"{input_data}", filled_template))

    def getReview(self, item):
        """Return the parent dataset's review ``item`` with the prompt appended."""
        return self.make_prompt(super().getReview(item))
        

# Load the validation split; the returned object is indexed by 'x' (reviews)
# and 'y' (labels). Wrap it so every review ends with the prompt.
_, valid_dir = dataset.read_data(config.VALIDATION_FILE)
valid_dataset = PromptDataset(valid_dir['x'], valid_dir['y'], config=config)
valid_data_loader = valid_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)

# Sanity check: one prompted review and its label, e.g.
# "nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic . It is a [MASK] film." 0
print(valid_dataset.getReview(0), valid_dataset.target[0])


def get_logits_of_mask(input_ids, logits, tok='[MASK]', tokenzier=None):
    """Select, for every sequence in a batch, the logits at the position of ``tok``.

    Works on batches and on any device, as long as both tensors share it.

    Args:
        input_ids (tensor): token ids, shape [batch size, seq len].
        logits (tensor): model logits, shape [batch size, seq len, vocab size].
        tok (str, optional): token to locate; '[MASK]' or any word token.
            Defaults to '[MASK]'.
        tokenzier (optional): tokenizer used to map ``tok`` to its id. Defaults
            to the notebook-level ``bert_tokenzier``, looked up at call time.

    Returns:
        (tensor, tensor): the logits at the token position, shape
        [batch size, vocab size], and the column index of the token in each
        row, shape [batch size]. If a row contains ``tok`` several times, the
        first occurrence is used.

    Raises:
        ValueError: if some row does not contain ``tok`` at all (for example
            because truncation cut the prompt off). Without this check, rows
            would be silently paired with another row's mask position.
    """
    if tokenzier is None:
        tokenzier = bert_tokenzier

    tok_id = tokenzier.convert_tokens_to_ids(tok)
    is_tok = input_ids == tok_id  # bool, [batch size, seq len]

    has_tok = is_tok.any(dim=1)
    if not bool(has_tok.all()):
        missing_rows = torch.nonzero(~has_tok).flatten().tolist()
        raise ValueError(
            f"token {tok!r} (id {tok_id}) not found in batch rows {missing_rows}; "
            "the prompt may have been truncated away"
        )

    # argmax returns the index of the first maximal value, i.e. the first occurrence.
    ids_of_mask_in_seq = is_tok.int().argmax(dim=1)

    # Advanced indexing: row r takes logits[r, ids_of_mask_in_seq[r], :].
    rows = torch.arange(logits.size(0), device=logits.device)
    logits_tok = logits[rows, ids_of_mask_in_seq.to(logits.device)]
    return logits_tok, ids_of_mask_in_seq


    
# Put the masked LM on the device chosen by the config (Module.to moves it in place and returns it).
device = config.DEVICE
model_bert = model_bert.to(device)

def count_acc(pred, target):
    """Return the fraction of positions where ``pred`` equals ``target``.

    Args:
        pred: sequence of predicted labels.
        target: sequence of gold labels, same length as ``pred``.

    Returns:
        float: accuracy in [0, 1].

    Raises:
        ValueError: if the sequences are empty (accuracy is undefined) or have
            different lengths (an element-wise comparison would be meaningless).
    """
    if len(pred) != len(target):
        raise ValueError(f"pred and target lengths differ: {len(pred)} vs {len(target)}")
    if len(pred) == 0:
        raise ValueError("cannot compute accuracy of empty predictions")

    acc_count = np.sum(np.array(pred) == np.array(target))
    return float(acc_count) / len(pred)



def eval_prompt(data_loader, model, device, candidate_ids=None, tokenizer=None):
    """Classify with a masked LM by comparing verbalizer logits at the [MASK] slot.

    For each example, the logits at the [MASK] position are restricted to the
    verbalizer token ids; the index of the largest one is the predicted label.

    Args:
        data_loader: iterable of batch dicts with keys 'ids', 'token_type_ids',
            'mask' and 'targets'.
        model: masked-LM model whose output exposes ``.logits``.
        device: device the inputs are moved to (should match the model's device).
        candidate_ids (list[int], optional): vocabulary ids of the verbalizer
            words. Defaults to the notebook-level ``config.candidate_ids``.
        tokenizer (optional): passed to ``get_logits_of_mask`` to find [MASK].
            Defaults to the notebook-level ``bert_tokenzier``.

    Returns:
        (targets, outputs, logits, mask_ids), each a list with one entry per example:
        gold labels as yielded by the loader, predicted label indices (Python ints),
        [MASK]-position logits over the whole vocabulary (CPU tensors), and the
        [MASK] column index (CPU tensors).
    """
    if candidate_ids is None:
        candidate_ids = config.candidate_ids
    if tokenizer is None:
        tokenizer = bert_tokenzier

    # (key in the batch dict, keyword argument expected by the model)
    input_keys = (("ids", "input_ids"), ("token_type_ids", "token_type_ids"), ("mask", "attention_mask"))

    _targets, _outputs, _logits, _mask_ids = [], [], [], []
    model.eval()

    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            model_inputs = {model_key: batch[batch_key].to(device) for batch_key, model_key in input_keys}
            logits = model(**model_inputs).logits

            # Pick the [MASK] rows on-device and copy only [batch, vocab] to the CPU,
            # instead of transferring the full [batch, seq, vocab] tensor.
            logits_mask, mask_ids = get_logits_of_mask(
                model_inputs["input_ids"], logits, tok='[MASK]', tokenzier=tokenizer
            )
            logits_mask = logits_mask.cpu()
            mask_ids = mask_ids.cpu()

            # Optional: softmax over the whole vocabulary first.
            # logits_mask = torch.softmax(logits_mask, dim=-1)

            # Predict the label whose verbalizer word is the more likely [MASK] fill.
            labels_pr = logits_mask[:, candidate_ids]
            pred = labels_pr.argmax(dim=-1).tolist()

            _targets.extend(batch['targets'])
            _outputs.extend(pred)
            _logits.extend(logits_mask)
            _mask_ids.extend(mask_ids)
    # No per-batch torch.cuda.empty_cache(): the caching allocator reuses freed
    # blocks, and flushing the cache every batch only adds overhead.
    return _targets, _outputs, _logits, _mask_ids


# Zero-shot evaluation on the validation split, then a one-line summary.
(fin_targets, fin_outputs,
 fin_logits, fin_mask_ids) = eval_prompt(valid_data_loader, model_bert, device)
print(
    f"avg acc:{count_acc(fin_outputs, fin_targets)}"
    f" | pred.sum: {np.sum(fin_outputs)}"
    f" | target.sum: {np.sum(fin_targets)}"
)

# samples: 1000
nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic . It is a [MASK] film. 0


100%|██████████| 250/250 [00:32<00:00,  7.73it/s]

avg acc:0.668 | pred.sum: 654 | target.sum: 510.0





In [4]:
def show_topk_cloze(logits_mask, top_k=10, tokenizer=None):
    """Return the ``top_k`` highest-scoring vocabulary fills for one [MASK] slot.

    Args:
        logits_mask (tensor): 1-D raw (pre-softmax) scores over the vocabulary,
            e.g. one entry of the logits list returned by ``eval_prompt``.
        top_k (int): number of fills to return, clipped to [0, vocab size].
        tokenizer (optional): used to map ids back to tokens. Defaults to the
            notebook-level ``bert_tokenzier``, looked up at call time.

    Returns:
        list[dict]: best-first entries with keys "token_id" (int), "token_str"
        (str) and "raw_score" (the score before softmax).
    """
    if tokenizer is None:
        tokenizer = bert_tokenzier

    scores = torch.as_tensor(logits_mask)
    k = max(0, min(int(top_k), scores.numel()))
    if k == 0:
        return []

    # topk avoids sorting the whole vocabulary; indices come back best-first.
    top_ids = torch.topk(scores, k).indices.tolist()
    return [
        {
            "token_id": token_id,
            "token_str": tokenizer.convert_ids_to_tokens(token_id),
            "raw_score": logits_mask[token_id],  # score before softmax
        }
        for token_id in top_ids
    ]

In [5]:
# Inspect one validation example: the prompted text, gold label, prediction,
# and the 40 most likely vocabulary fills for its [MASK] slot.
idx = 3
logits, pred = fin_logits[idx], fin_outputs[idx]
target, sequence = valid_dataset.target[idx], valid_dataset.getReview(idx)
ids = valid_dataset[idx]['ids']
(f"sequence: '{sequence}', target: {target}, pred: {pred}", show_topk_cloze(logits, top_k=40))

("sequence: 'this angst-ridden territory was covered earlier and much better in ordinary people . It is a [MASK] film.', target: 0, pred: 0",
 [{'token_id': 2439, 'token_str': 'lost', 'raw_score': tensor(8.7598)},
  {'token_id': 4516, 'token_str': 'documentary', 'raw_score': tensor(7.9718)},
  {'token_id': 2204, 'token_str': 'good', 'raw_score': tensor(7.8390)},
  {'token_id': 2460, 'token_str': 'short', 'raw_score': tensor(7.6915)},
  {'token_id': 3376, 'token_str': 'beautiful', 'raw_score': tensor(7.5731)},
  {'token_id': 16046, 'token_str': 'bollywood', 'raw_score': tensor(7.2638)},
  {'token_id': 8754, 'token_str': 'cult', 'raw_score': tensor(7.0753)},
  {'token_id': 4333, 'token_str': 'silent', 'raw_score': tensor(7.0202)},
  {'token_id': 5469, 'token_str': 'horror', 'raw_score': tensor(7.0167)},
  {'token_id': 3444, 'token_str': 'feature', 'raw_score': tensor(6.7615)},
  {'token_id': 2919, 'token_str': 'bad', 'raw_score': tensor(6.7300)},
  {'token_id': 2759, 'token_str': 'popula

In [7]:
# Evaluate the same zero-shot masked LM on the prompted test split.
model = model_bert

_, test_dir = dataset.read_data(config.TEST_FILE, test=True)
test_dataset = PromptDataset(test_dir['x'], test_dir['y'], config=config)
test_data_loader = test_dataset.get_dataloader(
    batch_size=config.VALID_BATCH_SIZE,
)

# samples: 1066


In [8]:
# 1. Predict labels for the test split.
test_record = eval_prompt(test_data_loader, model, device)
# test_record = (targets, outputs, logits, mask_ids); only the predictions are needed.
test_preds = test_record[1]

# 2. Write one "positive"/"negative" line per test example; the validation
#    accuracy is embedded in the file name for bookkeeping.
eval_acc = count_acc(fin_outputs, fin_targets)
os.makedirs('saved', exist_ok=True)  # open(..., mode='w') fails if the directory is missing
with open(f'saved/few_shot{config.few_shot}_eval{eval_acc}_res.txt', encoding="utf-8", mode='w') as f:
    for pred in test_preds:
        f.write("positive" if pred == 1 else 'negative')
        f.write('\n')

100%|██████████| 267/267 [00:32<00:00,  8.20it/s]
