In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" 

from time import time
import random

import dataset
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import transformers
from  torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

In [2]:
class PromptConfig():
    def __init__(self, few_shot, BERT_PATH="bert-base-uncased", ) -> None:
        """
        few_shot=['0', '32', '64', '128']
        """
        self.few_shot = few_shot
        self.DEVICE = "cuda:3"
        self.MAX_LEN = 256
        self.TRAIN_BATCH_SIZE = 8
        self.VALID_BATCH_SIZE = 4
        self.train_times = 1
        self.EPOCHS = 10*self.train_times
        
        self.EARLY_STOP = 3
        self.eval_zero_shot = False # 是否测试zero shot
        self.test_output = True # 是否测试输出
        self.UPDATE_VERBAL = False # 是否更新verbalizer
        self.use_demostration = True # 是否使用 prompt type4 demostration
        
        # 训练参数
        self.eps_thres=1e-4 
        self.es_max=5  # early stop

        self.BERT_PATH = BERT_PATH
        self.MODEL_PATH = "/home/18307110500/pj3_workplace/pytorch_model.bin"
        data_dir ="/home/18307110500/data"

        if few_shot is not None and few_shot == '0':
            self.TRAINING_FILE = None
        elif few_shot is not None and os.path.exists(f"{data_dir}/train_{few_shot}.data"):
            self.TRAINING_FILE =f"{data_dir}/train_{few_shot}.data"
        else:
            self.TRAINING_FILE = f"{data_dir}/train.data"

        self.VALIDATION_FILE = f"{data_dir}/valid.data"
        self.TEST_FILE = f"{data_dir}/test.data"

        BERT_PATH:str
        if BERT_PATH.startswith("bert"):
            self.TOKENIZER= transformers.BertTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)
            self.MODEL = transformers.BertForMaskedLM.from_pretrained(self.BERT_PATH)
            
        elif BERT_PATH.startswith("albert"):
            # AlbertTokenizer, AlbertForMaskedLM
            self.TOKENIZER= transformers.AlbertTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)
            self.MODEL = transformers.AlbertForMaskedLM.from_pretrained(self.BERT_PATH)
            
        elif BERT_PATH.startswith("roberta"):
            # RobertaTokenizer, RobertaForMaskedLM
            self.TOKENIZER= transformers.RobertaTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)
            self.MODEL = transformers.RobertaForMaskedLM.from_pretrained(self.BERT_PATH)
            
        # prompt
        # label转换为id
        self.mask = self.TOKENIZER.mask_token # '[MASK]'/'<mask>'
        # self.verbalizer=['negative', 'positive']
        self.verbalizer=['bad', 'great']
        self.template =  "It is a {} film ."
        # self.template =  "It was {} ." # .format('[MASK]')
        # self.verbalizer=['terrible', 'great']
        self.candidate_ids = [self.TOKENIZER._convert_token_to_id(_) for _ in self.verbalizer]
        
''' 基于 BertMaskedML的 few shot '''
paths= ["bert-base-uncased","bert-large-uncased", "albert-base-v2", "albert-large-v2", "roberta-base","roberta-large"]
config = PromptConfig(BERT_PATH=paths[0], few_shot="32") # few shot

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# model: bert masked lm
model_bert = config.MODEL 
bert_tokenzier = config.TOKENIZER

bert_tokenzier: transformers.BertTokenizer

class PromptDataset(dataset.BERTDataset):
    def __init__(self, review, target, config):
        super(PromptDataset, self).__init__(review, target, config)

        self.config = config
        self.config: PromptConfig

        self.template = config.template # "It is a {} film." # [MASK]
        self.mask = config.mask# '[MASK]' # bert_tokenzier.mask_token
        self.sep = config.TOKENIZER.sep_token

        self.use_demostration = config.use_demostration

    def _random_neg_pos(self, item):
        self_ = self
        item_i = item

        break_epo = 50
        rand_ids = [-1,-1] #0: negative
        
        
        while not (rand_ids[0]>-1 and rand_ids[1]>-1) and break_epo > 0:
            break_epo -= 1 
            
            label_ = 0 if rand_ids[0] == -1 else 1 # 判断是找pos的随机样本还是neg
            
            rand_i = random.randint(0, len(self_)-1) # 样本随机值
            if rand_i == item_i: # 不能等于自身
                continue
            if not self_.target[rand_i] == label_: # 需要指定pos或neg
                continue
            
            if break_epo <= 0:
                rand_ids[label_] = label_
                continue
            
            rand_ids[label_] = rand_i

        return rand_ids[0],rand_ids[1]

    
    def make_prompt(self, input_data, replace=config.mask):
        input_trans = f"{input_data} {self.template.format(replace)}"
        return input_trans
        

    def getReview(self, item):
        neg_item, pos_item = self._random_neg_pos(item)

        # demonstration
        
        # 样本
        review = super().getReview(item)
        review_trans = self.make_prompt(review)

        if not self.use_demostration:
            return review_trans
        
        # 随机负例
        review_neg = super().getReview(neg_item)
        review_trans_neg = self.make_prompt(review_neg, replace=self.config.verbalizer[0])
        
        # 随机正例
        review_pos = super().getReview(pos_item)
        review_trans_pos = self.make_prompt(review_pos,replace=self.config.verbalizer[1])

        # 随机正例先或负例先
        if random.randint(0,1) == 0:
            review_trans = f"{review_trans} {self.sep} {review_trans_neg} {self.sep} {review_trans_pos}"
        else:
            review_trans = f"{review_trans} {self.sep} {review_trans_pos} {self.sep} {review_trans_neg}"
        
        return review_trans
        

_, train_dir= dataset.read_data(config.TRAINING_FILE)
_, valid_dir= dataset.read_data(config.VALIDATION_FILE)

train_dataset = PromptDataset(train_dir['x'], train_dir['y'],config=config)
valid_dataset = PromptDataset(valid_dir['x'], valid_dir['y'],config=config)

valid_data_loader = valid_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)
train_data_loader = train_dataset.get_dataloader(batch_size=config.TRAIN_BATCH_SIZE)

print(train_dataset.getReview(0), train_dataset.target[0])
print(valid_dataset.getReview(0), valid_dataset.target[0])
# "nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic .  It is a [MASK] review." 0

# samples: 32
# samples: 1000
better than the tepid star trek : insurrection ; falls short of first contact because the villain could n't pick the lint off borg queen alice krige 's cape ; and finishes half a parsec ( a nose ) ahead of generations . It is a [MASK] film . [SEP] the powers team has fashioned a comedy with more laughs than many , no question . but this time there 's some mold on the gold . It is a great film . [SEP] build some robots , haul 'em to the theatre with you for the late show , and put on your own mystery science theatre 3000 tribute to what is almost certainly going to go down as the worst -- and only -- killer website movie of this or any other year . It is a bad film . 1
nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic . It is a [MASK] film . [SEP] it 's as raw and action-packed an experience as a ringside seat at a tough-man contest . It is a great film . [SEP] the movie 's heavy-handed screenplay navigates

In [4]:
def get_logits_of_mask(input_ids, logits, tok=config.mask, tokenzier=bert_tokenzier):
    """
    Args:
        inputs_tok (tensor): 输入字符串经过tokenized得到的字典
        tok (str, optional): 可以是'[MASK]'或任意word token. Defaults to '[MASK]'.

    Returns:
        (tensor, tensor): 返回mask处的logits，返回mask的列索引

    Tips: 可以传入多个batch size

    Modify: 改为torch实现

    """
    # find_idx_of_tok_in_seq
    tok_id =  tokenzier._convert_token_to_id(tok)
    ids_of_mask_in_seq = torch.nonzero(input_ids == tok_id)[:,1] ## 得到mask的列索引
    
    # convert to tensor
    logits_tok = torch.stack([logits[idx, ids_of_mask_in_seq[idx],:]for idx in range(logits.size(0))])

    # logits_tok.size() # [4, 30522]=[batch size, vocab size]
    return logits_tok, ids_of_mask_in_seq


''' train: fine tune bert '''

def count_acc(pred, target):
    acc_count = np.sum(np.array(pred) == np.array(target))
    return acc_count/len(pred)

def loss_fn(outputs, targets):
    # sigmoid + cross entropy
    # print(outputs, targets)
    return nn.BCEWithLogitsLoss()(outputs.view(-1,1), targets.view(-1, 1))


def get_logits(config, logits_mask):
    # init: candidate_ids = config.candidate_ids
    labels_pr = logits_mask[:, config.candidate_ids]
    return labels_pr


def show_topk_cloze(logits_mask,top_k =10):
    # 根据logits排序
    top_inds = list(reversed(np.argsort(logits_mask)))
    res_top_k = []
    for i in top_inds:
        res_i = {
            "token_id":i.item(),
            "token_str": bert_tokenzier._convert_id_to_token(i.item()),
            "raw_score": logits_mask[i.item()] # 未经过softmax的分数
            }
        res_top_k.append(res_i)
        if len(res_top_k) >= top_k:
            break

    return res_top_k # 查看top k预测的填空

In [5]:
def eval_prompt(data_loader, model, device):
    _targets = []
    _outputs = []
    _logits = []
    _mask_ids = []
    model.eval()
    
    with torch.no_grad():
        for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
            dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
            
            input_token = {
                k[1]: d[k[0]].to(device) 
                for k in dict_keys}

            res_batch = model(**input_token)
            
            logits = res_batch.logits
            logits = logits.cpu()
            logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok=config.mask ,tokenzier=bert_tokenzier)

            # optional: 计算整个vocab上的softmax score
            # logits_mask = torch.softmax(logits_mask, dim=-1)
            
            # 取出verbalizer对应的logits
            labels_pr = get_logits(config, logits_mask)
            pred = [np.argmax(_) for _ in labels_pr]
                
            _targets.extend(d['targets'])
            _outputs.extend(pred)
            _logits.extend(logits_mask)
            _mask_ids.extend(mask_ids)
            torch.cuda.empty_cache()
            # break
    return _targets,_outputs,_logits,_mask_ids

In [6]:
device = config.DEVICE
model_bert.to(device)

# zero
if config.eval_zero_shot:
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    
    print("[Zero shot]")
    print(f"[Eval] Acc:{count_acc(fin_outputs_eval,fin_targets_eval)} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

In [7]:
param_optimizer = list(model_bert.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
# for _,__ in param_optimizer:
#     print(_)
optimizer_parameters = [ {
        "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ],
        "weight_decay": 0.001,
    },
        {"params": [ p for n, p in param_optimizer if any(nd in n for nd in no_decay) ],
        "weight_decay": 0.0,
    },
]

print("opt param: ",len(optimizer_parameters[0]['params']))
print("no opt",len(optimizer_parameters[1]['params']))

opt param:  76
no opt 126


In [8]:
num_train_steps = int(len(train_dir['x']) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
if len(optimizer_parameters[0]['params']) > 20:
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
else:
    # albert
    optimizer = AdamW(model_bert.parameters(), lr=3e-5)
    
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)

In [9]:
ev_acc_his = []
tr_loss_his = []
tr_time_his=[]

early_stop_count = 0

for epoch in range(config.EPOCHS//config.train_times):
        
    # begin training
    model_bert.train()
    tr_time_s = time()

    tr_loss = []

    # config.train_times = 5
    for epo_tr in range(config.train_times):
        for bi, d in tqdm(enumerate(train_data_loader), total=len(train_data_loader)):
            dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
            targets = d['targets']
            input_token = {
                    k[1]: d[k[0]].to(device) 
                    for k in dict_keys}

            optimizer.zero_grad()
            res_batch = model_bert(**input_token)
            logits = res_batch.logits.cpu()

            ''' 取出mask位置上，candidate label对应的logits '''
            # mask 位置的预测logits
            logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok=config.mask,tokenzier=bert_tokenzier)
            # logits_mask: (batch_size, vocab_size)

            ### 计算loss
            
            # 取出verbalizer对应的logits
            labels_pr = get_logits(config, logits_mask)
            # 概率分数
            labels_pr = torch.softmax(labels_pr, dim=-1)
            # labels_pr: (batch_size, 2)
            # 取出 positive 对应的分数 (negative = 1-positive)
            pred = labels_pr[:,1]
            loss = loss_fn(pred, targets)
            
            tr_loss.append(loss.cpu().detach().item())
            # print(loss) # 0.6433

            ### 反传

            loss.backward()
            optimizer.step()
            scheduler.step()

            torch.cuda.empty_cache()

    tr_time_his.append((time()-tr_time_s)/config.train_times)
    tr_loss_his.append(np.mean(tr_loss))

    # begin eval
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    
    ev_acc_his.append(count_acc(fin_outputs_eval,fin_targets_eval))

    loss_str = "{:.4f}".format(tr_loss_his[-1])
    print(f"[Train] Epoch: {epoch}/{config.EPOCHS} | Train Loss: {loss_str} | Train time: {tr_time_his[-1]}s")
    print(f"[Eval] Acc:{ev_acc_his[-1]} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

    best_acc = max(ev_acc_his[:-1]) if epoch > 1 else -10
    if ev_acc_his[-1] > best_acc: # > best acc
        torch.save(model_bert, f"fewshot{config.few_shot}-{config.BERT_PATH}-best.pth")
        print("[Best epoch]")
        # reset
        early_stop_count= 0
    if early_stop_count > config.EARLY_STOP: 
        print(f"[WARNING] early stop at epoch {epoch}.")
        break

100%|██████████| 4/4 [00:07<00:00,  1.87s/it]
100%|██████████| 250/250 [00:34<00:00,  7.32it/s]


[Train] Epoch: 0/10 | Train Loss: 0.6826 | Train time: 7.680593490600586s
[Eval] Acc:0.49 | pred.sum: 0 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]
100%|██████████| 250/250 [00:33<00:00,  7.53it/s]


[Train] Epoch: 1/10 | Train Loss: 0.6788 | Train time: 5.143148899078369s
[Eval] Acc:0.551 | pred.sum: 85 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:05<00:00,  1.28s/it]
100%|██████████| 250/250 [00:32<00:00,  7.61it/s]

[Train] Epoch: 2/10 | Train Loss: 0.6894 | Train time: 5.366856813430786s
[Eval] Acc:0.492 | pred.sum: 2 | target.sum: 510.0



100%|██████████| 4/4 [00:05<00:00,  1.27s/it]
100%|██████████| 250/250 [00:33<00:00,  7.55it/s]

[Train] Epoch: 3/10 | Train Loss: 0.6748 | Train time: 5.3484275341033936s
[Eval] Acc:0.53 | pred.sum: 52 | target.sum: 510.0



100%|██████████| 4/4 [00:05<00:00,  1.28s/it]
100%|██████████| 250/250 [00:33<00:00,  7.57it/s]


[Train] Epoch: 4/10 | Train Loss: 0.6359 | Train time: 5.357529163360596s
[Eval] Acc:0.695 | pred.sum: 413 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]
100%|██████████| 250/250 [00:32<00:00,  7.63it/s]


[Train] Epoch: 5/10 | Train Loss: 0.6140 | Train time: 5.137706756591797s
[Eval] Acc:0.699 | pred.sum: 519 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:04<00:00,  1.16s/it]
100%|██████████| 250/250 [00:33<00:00,  7.47it/s]


[Train] Epoch: 6/10 | Train Loss: 0.6040 | Train time: 4.91191291809082s
[Eval] Acc:0.712 | pred.sum: 536 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:04<00:00,  1.16s/it]
100%|██████████| 250/250 [00:32<00:00,  7.59it/s]


[Train] Epoch: 7/10 | Train Loss: 0.5745 | Train time: 4.879376411437988s
[Eval] Acc:0.724 | pred.sum: 496 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:04<00:00,  1.25s/it]
100%|██████████| 250/250 [00:33<00:00,  7.56it/s]

[Train] Epoch: 8/10 | Train Loss: 0.5547 | Train time: 5.2522454261779785s
[Eval] Acc:0.72 | pred.sum: 570 | target.sum: 510.0



100%|██████████| 4/4 [00:04<00:00,  1.21s/it]
100%|██████████| 250/250 [00:33<00:00,  7.52it/s]

[Train] Epoch: 9/10 | Train Loss: 0.5597 | Train time: 5.120924711227417s
[Eval] Acc:0.711 | pred.sum: 619 | target.sum: 510.0





In [10]:
idx = 100
if idx >= 0:
    logits = fin_logits_eval[idx]
    pred = fin_outputs_eval[idx]

    target = valid_dataset.target[idx]
    sequence = valid_dataset.getReview(idx)
    ids = valid_dataset[idx]['ids']
    print(f"sequence: \'{sequence}\', target: {target}, pred: {pred}", show_topk_cloze(logits, top_k=10))

sequence: 'i was feeling this movie until it veered off too far into the exxon zone , and left me behind at the station looking for a return ticket to realism . It is a [MASK] film . [SEP] although laced with humor and a few fanciful touches , the film is a refreshingly serious look at young women . It is a great film . [SEP] what happened with pluto nash ? how did it ever get made ? It is a bad film .', target: 0, pred: 0 [{'token_id': 2919, 'token_str': 'bad', 'raw_score': tensor(11.4127)}, {'token_id': 2204, 'token_str': 'good', 'raw_score': tensor(11.0810)}, {'token_id': 2307, 'token_str': 'great', 'raw_score': tensor(10.0681)}, {'token_id': 3376, 'token_str': 'beautiful', 'raw_score': tensor(9.8282)}, {'token_id': 4326, 'token_str': 'strange', 'raw_score': tensor(9.2696)}, {'token_id': 6919, 'token_str': 'wonderful', 'raw_score': tensor(8.5589)}, {'token_id': 12459, 'token_str': 'scary', 'raw_score': tensor(8.2381)}, {'token_id': 6659, 'token_str': 'terrible', 'raw_score': tensor(

In [11]:
# 1. 测试
eval_acc = max(ev_acc_his) # best eval

if config.test_output:
    model = torch.load(f"fewshot{config.few_shot}-{config.BERT_PATH}-best.pth")
    _, test_dir= dataset.read_data(config.TEST_FILE, test=True)
    test_dataset = PromptDataset(test_dir['x'], test_dir['y'],config=config)
    test_data_loader = test_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)

    test_record = eval_prompt(test_data_loader, model, device)
    # targets ,outputs ,logits ,mask_ids
    test_preds = test_record[1]

    # 2. open文件写入结果
    with open(f'saved/few_shot{config.few_shot}_eval{eval_acc}_res.txt',encoding="utf-8", mode='w') as f:
        for pred in test_preds:
            f.write("positive" if pred==1 else 'negative')
            f.write('\n')
    print("Testing finish. Test results saved.")

# samples: 1066


100%|██████████| 267/267 [00:35<00:00,  7.47it/s]

Testing finish. Test results saved.





In [12]:
metric_rec = {
    'epo':[(i+1)*config.train_times for i in range(len(ev_acc_his))],
    'eval acc': ev_acc_his,
    'train loss': tr_loss_his ,
    'epoch time(s)': tr_time_his
}
data_f = pd.DataFrame(metric_rec)
data_f

Unnamed: 0,epo,eval acc,train loss,epoch time(s)
0,1,0.49,0.682618,7.680593
1,2,0.551,0.678813,5.143149
2,3,0.492,0.689388,5.366857
3,4,0.53,0.674824,5.348428
4,5,0.695,0.635935,5.357529
5,6,0.699,0.61402,5.137707
6,7,0.712,0.603988,4.911913
7,8,0.724,0.574518,4.879376
8,9,0.72,0.554719,5.252245
9,10,0.711,0.559716,5.120925


In [13]:
avg_epo_time= np.average(tr_time_his)
print("model {} | fewshot {} | best acc {} | epoch {} | {:.3f}s".format(config.BERT_PATH ,config.few_shot,eval_acc, config.EPOCHS,avg_epo_time))

model bert-base-uncased | fewshot 32 | best acc 0.724 | epoch 10 | 5.420s
