In [11]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" # cuda:0, GPU1
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" # cuda:0, GPU1
from time import time

import dataset
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import transformers
from  torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

In [12]:
class PromptConfig():
    def __init__(self, few_shot, BERT_PATH="bert-base-uncased", ) -> None:
        """
        few_shot=['0', '32', '64', '128']
        """
        self.few_shot = few_shot
        self.DEVICE = "cuda:0"
        self.MAX_LEN = 256
        self.TRAIN_BATCH_SIZE = 8
        self.VALID_BATCH_SIZE = 4
        self.EPOCHS = 10
        
        self.EARLY_STOP = 3
        
        self.eval_zero_shot = False # 是否测试zero shot
        
        # 训练参数
        self.eps_thres=1e-4 
        self.es_max=5  # early stop

        self.BERT_PATH = BERT_PATH
        self.MODEL_PATH = "/home/18307110500/pj3_workplace/pytorch_model.bin"
        data_dir ="/home/18307110500/data"

        if few_shot is not None and few_shot == '0':
            self.TRAINING_FILE = None
        elif few_shot is not None and os.path.exists(f"{data_dir}/train_{few_shot}.data"):
            self.TRAINING_FILE =f"{data_dir}/train_{few_shot}.data"
        else:
            self.TRAINING_FILE = f"{data_dir}/train.data"

        self.VALIDATION_FILE = f"{data_dir}/valid.data"
        self.TEST_FILE = f"{data_dir}/test.data"
        self.TOKENIZER= transformers.BertTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)

        # prompt
        # label转换为id
        # self.verbalizer=['negative', 'positive']
        self.verbalizer=['bad', 'great']
        self.candidate_ids = [self.TOKENIZER._convert_token_to_id(_) for _ in self.verbalizer]

        self.template =  "It is a {} film." # .format('[MASK]')
        self.mask = self.TOKENIZER.mask_token # '[MASK]'
        
''' 1  '''
''' 基于BertMaskedML的few shot '''
config = PromptConfig(BERT_PATH="bert-base-uncased", few_shot="32") # few shot

In [13]:
# model: bert masked lm
model_bert = transformers.BertForMaskedLM.from_pretrained(config.BERT_PATH)
bert_tokenzier = config.TOKENIZER

bert_tokenzier: transformers.BertTokenizer

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
class PromptDataset(dataset.BERTDataset):
    def __init__(self, review, target, config):
        super(PromptDataset, self).__init__(review, target, config)

        self.template = config.template # "It is a {} film." # [MASK]
        self.mask = config.mask# '[MASK]' # bert_tokenzier.mask_token
        
    # sep = bert_tokenzier.sep_token
    def make_prompt(self, input_data):
        input_trans = f"{input_data} {self.template.format(self.mask)}"
        return input_trans

    def getReview(self, item):
        review = super().getReview(item)
        review_trans = self.make_prompt(review)
        return review_trans
        

_, train_dir= dataset.read_data(config.TRAINING_FILE)
_, valid_dir= dataset.read_data(config.VALIDATION_FILE)

train_dataset = PromptDataset(train_dir['x'], train_dir['y'],config=config)
valid_dataset = PromptDataset(valid_dir['x'], valid_dir['y'],config=config)

valid_data_loader = valid_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)
train_data_loader = train_dataset.get_dataloader(batch_size=config.TRAIN_BATCH_SIZE)

print(train_dataset.getReview(0), train_dataset.target[0])
print(valid_dataset.getReview(0), valid_dataset.target[0])
# "nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic .  It is a [MASK] review." 0

# samples: 32
# samples: 1000
better than the tepid star trek : insurrection ; falls short of first contact because the villain could n't pick the lint off borg queen alice krige 's cape ; and finishes half a parsec ( a nose ) ahead of generations . It is a [MASK] film. 1
nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic . It is a [MASK] film. 0


In [15]:
def get_logits_of_mask(input_ids,logits, tok='[MASK]', tokenzier=bert_tokenzier):
    """
    Args:
        inputs_tok (tensor): 输入字符串经过tokenized得到的字典
        tok (str, optional): 可以是'[MASK]'或任意word token. Defaults to '[MASK]'.

    Returns:
        (tensor, tensor): 返回mask处的logits，返回mask的列索引

    Tips: 可以传入多个batch size

    Modify: 改为torch实现
    """

    # find_idx_of_tok_in_seq
    tok_id =  tokenzier._convert_token_to_id(tok)
    ids_of_mask_in_seq = torch.nonzero(input_ids == tok_id)[:,1] ## 得到mask的列索引
    
    # convert to tensor
    logits_tok = torch.stack([logits[idx, ids_of_mask_in_seq[idx],:]for idx in range(logits.size(0))])

    # logits_tok.size() # [4, 30522]=[batch size, vocab size]
    return logits_tok, ids_of_mask_in_seq


device = config.DEVICE
model_bert.to(device)

''' train: fine tune bert '''

def count_acc(pred, target):
    acc_count = np.sum(np.array(pred)== np.array(target))
    return acc_count/len(pred)

def loss_fn(outputs, targets):
    # sigmoid + cross entropy
    # print(outputs, targets)
    return nn.BCEWithLogitsLoss()(outputs.view(-1,1), targets.view(-1, 1))

param_optimizer = list(model_bert.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_parameters = [ {
        "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ],
        "weight_decay": 0.001,
    },{"params": [ p for n, p in param_optimizer if any(nd in n for nd in no_decay) ],
        "weight_decay": 0.0,
    },
]
num_train_steps = int(len(train_dir['x']) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
optimizer = AdamW(optimizer_parameters, lr=3e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)

def eval_prompt(data_loader, model, device):
    _targets = []
    _outputs = []
    _logits = []
    _mask_ids = []
    model.eval()
    
    with torch.no_grad():
        for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
            dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
            
            input_token = {
                k[1]: d[k[0]].to(device) 
                for k in dict_keys}

            res_batch = model(**input_token)
            
            logits = res_batch.logits
            logits = logits.cpu()
            logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok='[MASK]',tokenzier=bert_tokenzier)

            # optional: 计算整个vocab上的softmax score
            # logits_mask = torch.softmax(logits_mask, dim=-1)
            
            # 比较哪个label的mask填空可能性更大
            labels_pr = logits_mask[:, config.candidate_ids]
            pred = [np.argmax(_) for _ in labels_pr]
                
            _targets.extend(d['targets'])
            _outputs.extend(pred)
            _logits.extend(logits_mask)
            _mask_ids.extend(mask_ids)
            torch.cuda.empty_cache()
            # break
    return _targets,_outputs,_logits,_mask_ids


# zero
if config.eval_zero_shot:
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    
    print("[Zero shot]")
    print(f"[Eval] Acc:{count_acc(fin_outputs_eval,fin_targets_eval)} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

ev_acc_his = []
tr_loss_his = []
tr_time_his=[]
early_stop_count = 0

for epoch in range(config.EPOCHS):
        
    # begin training
    model_bert.train()
    tr_time_s = time()

    tr_loss = []
        
    for bi, d in tqdm(enumerate(train_data_loader), total=len(train_data_loader)):
        dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
        targets = d['targets']
        input_token = {
                k[1]: d[k[0]].to(device) 
                for k in dict_keys}

        optimizer.zero_grad()
        res_batch = model_bert(**input_token)
        logits = res_batch.logits.cpu()

        ''' 取出mask位置上，candidate label对应的logits '''
        # mask 位置的预测logits
        logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok='[MASK]',tokenzier=bert_tokenzier)
        # logits_mask: (batch_size, vocab_size)
        
        # mask位置预测为candidate label的概率分数
        labels_pr = logits_mask[:, config.candidate_ids]
        labels_pr = torch.softmax(labels_pr, dim=-1)
        # labels_pr: (batch_size, 2)

        # 取出 positive 对应的分数 (negative = 1-positive)
        loss = loss_fn(labels_pr[:,1], targets)
        tr_loss.append(loss.cpu().detach().item())
        # print(loss) # 0.6433

        loss.backward()
        optimizer.step()
        scheduler.step()

    tr_time_his.append(time()-tr_time_s)
    tr_loss_his.append(np.mean(tr_loss))
    torch.cuda.empty_cache()

    # begin eval
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    
    ev_acc_his.append(count_acc(fin_outputs_eval,fin_targets_eval))

    loss_str = "{:.4f}".format(tr_loss_his[-1])
    print(f"[Train] Epoch: {epoch}/{config.EPOCHS} | Train Loss: {loss_str} | Train time: {tr_time_his[-1]}s")
    print(f"[Eval] Acc:{ev_acc_his[-1]} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

    best_acc = max(ev_acc_his[:-1]) if epoch > 1 else -10
    if ev_acc_his[-1] > best_acc: # > best acc
        torch.save(model_bert, f"fewshot{config.few_shot}-best.pth")
        print("[Best epoch]")
        # reset
        early_stop_count= 0
    if early_stop_count > config.EARLY_STOP: 
        print(f"[WARNING] early stop at epoch {epoch}.")
        break

100%|██████████| 4/4 [00:03<00:00,  1.02it/s]
100%|██████████| 250/250 [00:15<00:00, 15.92it/s]


[Train] Epoch: 0/10 | Train Loss: 0.6917 | Train time: 4.176789045333862s
[Eval] Acc:0.729 | pred.sum: 439 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:04<00:00,  1.01s/it]
100%|██████████| 250/250 [00:32<00:00,  7.59it/s]


[Train] Epoch: 1/10 | Train Loss: 0.6378 | Train time: 4.319622278213501s
[Eval] Acc:0.748 | pred.sum: 422 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:03<00:00,  1.01it/s]
100%|██████████| 250/250 [00:33<00:00,  7.57it/s]

[Train] Epoch: 2/10 | Train Loss: 0.5565 | Train time: 4.21187686920166s
[Eval] Acc:0.722 | pred.sum: 272 | target.sum: 510.0



100%|██████████| 4/4 [00:03<00:00,  1.11it/s]
100%|██████████| 250/250 [00:32<00:00,  7.59it/s]


[Train] Epoch: 3/10 | Train Loss: 0.5233 | Train time: 3.8851301670074463s
[Eval] Acc:0.776 | pred.sum: 470 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:03<00:00,  1.13it/s]
100%|██████████| 250/250 [00:33<00:00,  7.54it/s]

[Train] Epoch: 4/10 | Train Loss: 0.5063 | Train time: 3.763291597366333s
[Eval] Acc:0.775 | pred.sum: 587 | target.sum: 510.0



100%|██████████| 4/4 [00:03<00:00,  1.11it/s]
100%|██████████| 250/250 [00:33<00:00,  7.49it/s]


[Train] Epoch: 5/10 | Train Loss: 0.5124 | Train time: 3.8192882537841797s
[Eval] Acc:0.792 | pred.sum: 554 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:03<00:00,  1.13it/s]
100%|██████████| 250/250 [00:33<00:00,  7.56it/s]


[Train] Epoch: 6/10 | Train Loss: 0.5033 | Train time: 3.799058675765991s
[Eval] Acc:0.794 | pred.sum: 514 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:03<00:00,  1.12it/s]
100%|██████████| 250/250 [00:32<00:00,  7.58it/s]

[Train] Epoch: 7/10 | Train Loss: 0.5051 | Train time: 3.8150627613067627s
[Eval] Acc:0.792 | pred.sum: 492 | target.sum: 510.0



100%|██████████| 4/4 [00:03<00:00,  1.09it/s]
100%|██████████| 250/250 [00:33<00:00,  7.52it/s]

[Train] Epoch: 8/10 | Train Loss: 0.5033 | Train time: 3.8889896869659424s
[Eval] Acc:0.794 | pred.sum: 488 | target.sum: 510.0



100%|██████████| 4/4 [00:03<00:00,  1.08it/s]
100%|██████████| 250/250 [00:32<00:00,  7.59it/s]

[Train] Epoch: 9/10 | Train Loss: 0.5034 | Train time: 3.9282917976379395s
[Eval] Acc:0.794 | pred.sum: 486 | target.sum: 510.0





In [16]:
def show_topk_cloze(logits_mask,top_k =10):
    
    # 根据logits排序
    top_inds = list(reversed(np.argsort(logits_mask)))
    res_top_k = []
    for i in top_inds:
        res_i = {
            "token_id":i.item(),
            "token_str": bert_tokenzier._convert_id_to_token(i.item()),
            "raw_score": logits_mask[i.item()] # 未经过softmax的分数
            }
        res_top_k.append(res_i)
        if len(res_top_k) >= top_k:
            break

    return res_top_k # 查看top k预测的填空

In [17]:
idx = -1
if idx >= 0:
    logits = fin_logits_eval[idx]
    pred = fin_outputs_eval[idx]

    target = valid_dataset.target[idx]
    sequence = valid_dataset.getReview(idx)
    ids = valid_dataset[idx]['ids']
    (f"sequence: \'{sequence}\', target: {target}, pred: {pred}", show_topk_cloze(logits, top_k=40))

In [18]:
# 1. 测试
eval_acc = max(ev_acc_his) # best eval
model = torch.load(f"fewshot{config.few_shot}-best.pth")

_, test_dir= dataset.read_data(config.TEST_FILE, test=True)
test_dataset = PromptDataset(test_dir['x'], test_dir['y'],config=config)
test_data_loader = test_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)

test_record = eval_prompt(test_data_loader, model, device)
# targets ,outputs ,logits ,mask_ids
test_preds = test_record[1]

# 2. open文件写入结果
with open(f'saved/few_shot{config.few_shot}_eval{eval_acc}_res.txt',encoding="utf-8", mode='w') as f:
    for pred in test_preds:
        f.write("positive" if pred==1 else 'negative')
        f.write('\n')
print("Testing finish. Test results saved.")

# samples: 1066


100%|██████████| 267/267 [00:34<00:00,  7.67it/s]


Testing finish. Test results saved.


In [19]:
metric_rec = {
    'eval acc': ev_acc_his,
    'train loss': tr_loss_his ,
    'epoch time(s)': tr_time_his
}
data_f = pd.DataFrame(metric_rec)
data_f

Unnamed: 0,eval acc,train loss,epoch time(s)
0,0.729,0.691716,4.176789
1,0.748,0.637807,4.319622
2,0.722,0.556462,4.211877
3,0.776,0.523263,3.88513
4,0.775,0.50627,3.763292
5,0.792,0.512378,3.819288
6,0.794,0.50334,3.799059
7,0.792,0.505122,3.815063
8,0.794,0.503322,3.88899
9,0.794,0.50337,3.928292


In [23]:
avg_epo_time= np.average(tr_time_his)
print("{:.3f}s".format(avg_epo_time))

3.961s
