In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" # cuda:0, GPU1
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # cuda:0, GPU1
from time import time

import dataset
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import transformers
from  torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

In [2]:
class PromptConfig():
    def __init__(self, few_shot, BERT_PATH="bert-base-uncased", ) -> None:
        """
        few_shot=['0', '32', '64', '128']
        """
        self.few_shot = few_shot
        self.DEVICE = "cuda:0"
        self.MAX_LEN = 256
        self.TRAIN_BATCH_SIZE = 8
        self.VALID_BATCH_SIZE = 4
        self.train_times = 2
        self.EPOCHS = 5*self.train_times
        
        self.EARLY_STOP = 3
        self.eval_zero_shot = True # 是否测试zero shot
        
        # 训练参数
        self.eps_thres=1e-4 
        self.es_max=5  # early stop

        self.BERT_PATH = BERT_PATH
        self.MODEL_PATH = "/home/18307110500/pj3_workplace/pytorch_model.bin"
        data_dir ="/home/18307110500/data"

        if few_shot is not None and few_shot == '0':
            self.TRAINING_FILE = None
        elif few_shot is not None and os.path.exists(f"{data_dir}/train_{few_shot}.data"):
            self.TRAINING_FILE =f"{data_dir}/train_{few_shot}.data"
        else:
            self.TRAINING_FILE = f"{data_dir}/train.data"

        self.VALIDATION_FILE = f"{data_dir}/valid.data"
        self.TEST_FILE = f"{data_dir}/test.data"

        BERT_PATH:str
        if BERT_PATH.startswith("bert"):
            self.TOKENIZER= transformers.BertTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)
            self.MODEL = transformers.BertForMaskedLM.from_pretrained(self.BERT_PATH)
            
        elif BERT_PATH.startswith("albert"):
            # AlbertTokenizer, AlbertForMaskedLM
            self.TOKENIZER= transformers.AlbertTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)
            self.MODEL = transformers.AlbertForMaskedLM.from_pretrained(self.BERT_PATH)
            
        elif BERT_PATH.startswith("roberta"):
            # RobertaTokenizer, RobertaForMaskedLM
            self.TOKENIZER= transformers.RobertaTokenizer.from_pretrained(self.BERT_PATH, do_lower_case=True)
            self.MODEL = transformers.RobertaForMaskedLM.from_pretrained(self.BERT_PATH)
            
            
        # prompt
        # label转换为id
        self.mask = self.TOKENIZER.mask_token # '[MASK]'/'<mask>'
        # self.verbalizer=['negative', 'positive']
        # self.verbalizer=['bad', 'great']
        # self.template =  "It is a {} film ."
        self.template =  "It was {} ." # .format('[MASK]')
        self.verbalizer=['terrible', 'great']
        self.candidate_ids = [self.TOKENIZER._convert_token_to_id(_) for _ in self.verbalizer]
        
''' 1  '''
''' 基于BertMaskedML的few shot '''
paths= ["bert-base-uncased","bert-large-uncased", "albert-base-v2", "albert-large-v2", "roberta-base","roberta-large"]
config = PromptConfig(BERT_PATH=paths[0], few_shot="32") # few shot

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# model: bert masked lm
model_bert = config.MODEL 
bert_tokenzier = config.TOKENIZER

bert_tokenzier: transformers.BertTokenizer

class PromptDataset(dataset.BERTDataset):
    def __init__(self, review, target, config):
        super(PromptDataset, self).__init__(review, target, config)

        self.template = config.template # "It is a {} film." # [MASK]
        self.mask = config.mask# '[MASK]' # bert_tokenzier.mask_token
        
    # sep = bert_tokenzier.sep_token
    def make_prompt(self, input_data):
        input_trans = f"{input_data} {self.template.format(self.mask)}"
        
        return input_trans

    def getReview(self, item):
        review = super().getReview(item)
        review_trans = self.make_prompt(review)
        return review_trans
        

_, train_dir= dataset.read_data(config.TRAINING_FILE)
_, valid_dir= dataset.read_data(config.VALIDATION_FILE)

train_dataset = PromptDataset(train_dir['x'], train_dir['y'],config=config)
valid_dataset = PromptDataset(valid_dir['x'], valid_dir['y'],config=config)

valid_data_loader = valid_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)
train_data_loader = train_dataset.get_dataloader(batch_size=config.TRAIN_BATCH_SIZE)

print(train_dataset.getReview(0), train_dataset.target[0])
print(valid_dataset.getReview(0), valid_dataset.target[0])
# "nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic .  It is a [MASK] review." 0

# samples: 32
# samples: 1000
better than the tepid star trek : insurrection ; falls short of first contact because the villain could n't pick the lint off borg queen alice krige 's cape ; and finishes half a parsec ( a nose ) ahead of generations . It was [MASK] . 1
nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic . It was [MASK] . 0

# samples: 1000
better than the tepid star trek : insurrection ; falls short of first contact because the villain could n't pick the lint off borg queen alice krige 's cape ; and finishes half a parsec ( a nose ) ahead of generations . It was [MASK] . 1
nothing about the film -- with the possible exception of elizabeth hurley 's breasts -- is authentic . It was [MASK] . 0


In [4]:
def get_logits_of_mask(input_ids,logits, tok='[MASK]', tokenzier=bert_tokenzier):
    """
    Args:
        inputs_tok (tensor): 输入字符串经过tokenized得到的字典
        tok (str, optional): 可以是'[MASK]'或任意word token. Defaults to '[MASK]'.

    Returns:
        (tensor, tensor): 返回mask处的logits，返回mask的列索引

    Tips: 可以传入多个batch size

    Modify: 改为torch实现

    """

    # find_idx_of_tok_in_seq
    tok_id =  tokenzier._convert_token_to_id(tok)
    ids_of_mask_in_seq = torch.nonzero(input_ids == tok_id)[:,1] ## 得到mask的列索引
    
    # convert to tensor
    logits_tok = torch.stack([logits[idx, ids_of_mask_in_seq[idx],:]for idx in range(logits.size(0))])

    # logits_tok.size() # [4, 30522]=[batch size, vocab size]
    return logits_tok, ids_of_mask_in_seq


''' train: fine tune bert '''

def count_acc(pred, target):
    acc_count = np.sum(np.array(pred) == np.array(target))
    return acc_count/len(pred)

def loss_fn(outputs, targets):
    # sigmoid + cross entropy
    # print(outputs, targets)
    return nn.BCEWithLogitsLoss()(outputs.view(-1,1), targets.view(-1, 1))


def get_logits(config, logits_mask):
    # init: candidate_ids = config.candidate_ids
    labels_pr = logits_mask[:, config.candidate_ids]
    return labels_pr

def get_topk_token_ids(logits_mask,top_k =10):
    logits_mask_=logits_mask.detach()
    batch_size = logits_mask_.size(0)
    idsk = []
    logitsk = []
    
    for i in range(batch_size):
        
        top_inds = list(reversed(np.argsort(logits_mask_[i].numpy(), axis=-1)))  # list
        idsk.append(top_inds[:top_k])

        logitsk.append(logits_mask_[i,top_inds][:top_k].numpy().tolist())
        
    return idsk, logitsk # (list, list) size=(bz, k)

        
def show_topk_cloze(logits_mask,top_k =10):
    # 根据logits排序
    top_inds = list(reversed(np.argsort(logits_mask)))
    res_top_k = []
    for i in top_inds:
        res_i = {
            "token_id":i.item(),
            "token_str": bert_tokenzier._convert_id_to_token(i.item()),
            "raw_score": logits_mask[i.item()] # 未经过softmax的分数
            }
        res_top_k.append(res_i)
        if len(res_top_k) >= top_k:
            break

    return res_top_k # 查看top k预测的填空

In [5]:
device = config.DEVICE
model_bert.to(device)

param_optimizer = list(model_bert.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
# for _,__ in param_optimizer:
#     print(_)

optimizer_parameters = [ {
        "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ],
        "weight_decay": 0.001,
    },
        {"params": [ p for n, p in param_optimizer if any(nd in n for nd in no_decay) ],
        "weight_decay": 0.0,
    },
]

print("opt param: ",len(optimizer_parameters[0]['params']))
print("no opt",len(optimizer_parameters[1]['params']))

opt param:  76
no opt 126


In [6]:
num_train_steps = int(len(train_dir['x']) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
if len(optimizer_parameters[0]['params']) > 20:
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
else:
    # albert
    optimizer = AdamW(model_bert.parameters(), lr=3e-5)
    
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)

In [7]:
def eval_prompt(data_loader, model, device):
    _targets = []
    _outputs = []
    _logits = []
    _mask_ids = []
    model.eval()
    
    with torch.no_grad():
        for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
            dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
            
            input_token = {
                k[1]: d[k[0]].to(device) 
                for k in dict_keys}

            res_batch = model(**input_token)
            
            logits = res_batch.logits
            logits = logits.cpu()
            logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok=config.mask,tokenzier=bert_tokenzier)

            # optional: 计算整个vocab上的softmax score
            # logits_mask = torch.softmax(logits_mask, dim=-1)
            
            # 取出verbalizer对应的logits
            labels_pr = get_logits(config, logits_mask)
            pred = [np.argmax(_) for _ in labels_pr]
                
            _targets.extend(d['targets'])
            _outputs.extend(pred)
            _logits.extend(logits_mask)
            _mask_ids.extend(mask_ids)
            torch.cuda.empty_cache()
            # break
    return _targets,_outputs,_logits,_mask_ids


# zero
if config.eval_zero_shot:
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    
    print("[Zero shot]")
    print(f"[Eval] Acc:{count_acc(fin_outputs_eval,fin_targets_eval)} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

100%|██████████| 250/250 [00:41<00:00,  6.00it/s]

[Zero shot]
[Eval] Acc:0.582 | pred.sum: 864 | target.sum: 510.0





In [7]:
def eval_prompt(data_loader, model, device):
    _targets = []
    _outputs = []
    _logits = []
    _mask_ids = []
    model.eval()
    
    with torch.no_grad():
        for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
            dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
            
            input_token = {
                k[1]: d[k[0]].to(device) 
                for k in dict_keys}

            res_batch = model(**input_token)
            
            logits = res_batch.logits
            logits = logits.cpu()
            logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok=config.mask,tokenzier=bert_tokenzier)

            # optional: 计算整个vocab上的softmax score
            # logits_mask = torch.softmax(logits_mask, dim=-1)
            
            # 取出verbalizer对应的logits
            labels_pr = get_logits(config, logits_mask)
            pred = [np.argmax(_) for _ in labels_pr]
                
            _targets.extend(d['targets'])
            _outputs.extend(pred)
            _logits.extend(logits_mask)
            _mask_ids.extend(mask_ids)
            torch.cuda.empty_cache()
            # break
    return _targets,_outputs,_logits,_mask_ids


# zero
if config.eval_zero_shot:
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    
    print("[Zero shot]")
    print(f"[Eval] Acc:{count_acc(fin_outputs_eval,fin_targets_eval)} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

100%|██████████| 250/250 [00:37<00:00,  6.67it/s]

[Zero shot]
[Eval] Acc:0.582 | pred.sum: 864 | target.sum: 510.0





In [8]:
ev_acc_his = []
tr_loss_his = []
tr_time_his=[]

early_stop_count = 0

#test
# config.EPOCHS=config.train_times=1
# config.UPDATE_VERBAL = True
#test

for epoch in range(config.EPOCHS//config.train_times):
        
    # begin training
    model_bert.train()
    tr_time_s = time()

    tr_loss = []
    tr_topk_his=[]
    tr_targets_his=[]

    # config.train_times = 5
    for epo_tr in range(config.train_times):
        for bi, d in tqdm(enumerate(train_data_loader), total=len(train_data_loader)):
            dict_keys = [("ids","input_ids" ),("token_type_ids","token_type_ids"),("mask","attention_mask")]
            targets = d['targets']
            input_token = {
                    k[1]: d[k[0]].to(device) 
                    for k in dict_keys}

            optimizer.zero_grad()
            res_batch = model_bert(**input_token)
            logits = res_batch.logits.cpu()

            ''' 取出mask位置上，candidate label对应的logits '''
            # mask 位置的预测logits
            logits_mask, mask_ids = get_logits_of_mask(d['ids'], logits, tok=config.mask,tokenzier=bert_tokenzier)
            # logits_mask: (batch_size, vocab_size)

            
            # 记录loss最小的top k verbalizer（不更新）
            # todo            
            topk_score, topk_ids = get_topk_token_ids(logits_mask)
            tr_topk_his.extend(topk_ids) # (batch size, topk)
            tr_targets_his.extend(targets) # (batch size, topk)

            # 取出verbalizer对应的logits
            labels_pr = get_logits(config, logits_mask)
            # 概率分数
            labels_pr = torch.softmax(labels_pr, dim=-1)
            # labels_pr: (batch_size, 2)

            # 取出 positive 对应的分数 (negative = 1-positive)
            pred = labels_pr[:,1]
            loss = loss_fn(pred, targets)
            tr_loss.append(loss.cpu().detach().item())
            # print(loss) # 0.6433

            loss.backward()
            optimizer.step()
            scheduler.step()

            torch.cuda.empty_cache()

    tr_time_his.append((time()-tr_time_s)/config.train_times)
    tr_loss_his.append(np.mean(tr_loss))

    # begin eval
    fin_targets_eval,fin_outputs_eval,fin_logits_eval,fin_mask_ids_eval = eval_prompt(valid_data_loader, model_bert, device)
    ev_acc_his.append(count_acc(fin_outputs_eval,fin_targets_eval))

    loss_str = "{:.4f}".format(tr_loss_his[-1])
    print(f"[Train] Epoch: {epoch}/{config.EPOCHS} | Train Loss: {loss_str} | Train time: {tr_time_his[-1]}s")
    print(f"[Eval] Acc:{ev_acc_his[-1]} | pred.sum: {np.sum(fin_outputs_eval)} | target.sum: {np.sum(fin_targets_eval)}")

    best_acc = max(ev_acc_his[:-1]) if epoch > 1 else -10
    if ev_acc_his[-1] > best_acc: # > best acc
        torch.save(model_bert, f"fewshot{config.few_shot}-{config.BERT_PATH}-best.pth")
        print("[Best epoch]")
        # reset
        early_stop_count= 0
    if early_stop_count > config.EARLY_STOP: 
        print(f"[WARNING] early stop at epoch {epoch}.")
        break

100%|██████████| 4/4 [00:05<00:00,  1.30s/it]
100%|██████████| 4/4 [00:05<00:00,  1.34s/it]
100%|██████████| 250/250 [00:35<00:00,  7.03it/s]


[Train] Epoch: 0/10 | Train Loss: 0.6246 | Train time: 5.501935005187988s
[Eval] Acc:0.708 | pred.sum: 602 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:05<00:00,  1.36s/it]
100%|██████████| 4/4 [00:05<00:00,  1.30s/it]
100%|██████████| 250/250 [00:35<00:00,  7.04it/s]


[Train] Epoch: 1/10 | Train Loss: 0.5108 | Train time: 5.564296126365662s
[Eval] Acc:0.755 | pred.sum: 479 | target.sum: 510.0
[Best epoch]


100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
100%|██████████| 4/4 [00:05<00:00,  1.32s/it]
100%|██████████| 250/250 [00:35<00:00,  7.04it/s]

[Train] Epoch: 2/10 | Train Loss: 0.5035 | Train time: 5.476422309875488s
[Eval] Acc:0.755 | pred.sum: 583 | target.sum: 510.0



100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
100%|██████████| 4/4 [00:05<00:00,  1.31s/it]
100%|██████████| 250/250 [00:35<00:00,  7.02it/s]

[Train] Epoch: 3/10 | Train Loss: 0.5034 | Train time: 5.441756725311279s
[Eval] Acc:0.745 | pred.sum: 609 | target.sum: 510.0



100%|██████████| 4/4 [00:05<00:00,  1.30s/it]
100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
100%|██████████| 250/250 [00:35<00:00,  7.10it/s]

[Train] Epoch: 4/10 | Train Loss: 0.5034 | Train time: 5.439003586769104s
[Eval] Acc:0.746 | pred.sum: 608 | target.sum: 510.0





In [9]:
idx = 100
if idx >= 0:
    logits = fin_logits_eval[idx]
    pred = fin_outputs_eval[idx]

    target = valid_dataset.target[idx]
    sequence = valid_dataset.getReview(idx)
    ids = valid_dataset[idx]['ids']
    print(f"sequence: \'{sequence}\', target: {target}, pred: {pred}", show_topk_cloze(logits, top_k=10))

sequence: 'i was feeling this movie until it veered off too far into the exxon zone , and left me behind at the station looking for a return ticket to realism . It was [MASK] .', target: 0, pred: 0 [{'token_id': 6659, 'token_str': 'terrible', 'raw_score': tensor(9.5199)}, {'token_id': 9643, 'token_str': 'awful', 'raw_score': tensor(9.0635)}, {'token_id': 9202, 'token_str': 'horrible', 'raw_score': tensor(8.8685)}, {'token_id': 20625, 'token_str': 'hopeless', 'raw_score': tensor(8.7770)}, {'token_id': 2058, 'token_str': 'over', 'raw_score': tensor(8.3041)}, {'token_id': 3109, 'token_str': 'hell', 'raw_score': tensor(7.5817)}, {'token_id': 5263, 'token_str': 'impossible', 'raw_score': tensor(7.5513)}, {'token_id': 2439, 'token_str': 'lost', 'raw_score': tensor(7.4029)}, {'token_id': 4326, 'token_str': 'strange', 'raw_score': tensor(7.3589)}, {'token_id': 4689, 'token_str': 'crazy', 'raw_score': tensor(7.3548)}]


In [10]:
# 1. 测试
eval_acc = max(ev_acc_his) # best eval
model = torch.load(f"fewshot{config.few_shot}-{config.BERT_PATH}-best.pth")

_, test_dir= dataset.read_data(config.TEST_FILE, test=True)
test_dataset = PromptDataset(test_dir['x'], test_dir['y'],config=config)
test_data_loader = test_dataset.get_dataloader(batch_size=config.VALID_BATCH_SIZE)

test_record = eval_prompt(test_data_loader, model, device)
# targets ,outputs ,logits ,mask_ids
test_preds = test_record[1]

# 2. open文件写入结果
with open(f'saved/few_shot{config.few_shot}_eval{eval_acc}_res.txt',encoding="utf-8", mode='w') as f:
    for pred in test_preds:
        f.write("positive" if pred==1 else 'negative')
        f.write('\n')
print("Testing finish. Test results saved.")

# samples: 1066


100%|██████████| 267/267 [00:38<00:00,  6.92it/s]

Testing finish. Test results saved.





In [11]:
metric_rec = {
    'epo':[(i+1)*config.train_times for i in range(len(ev_acc_his))],
    'eval acc': ev_acc_his,
    'train loss': tr_loss_his ,
    'epoch time(s)': tr_time_his
}
data_f = pd.DataFrame(metric_rec)
data_f

Unnamed: 0,epo,eval acc,train loss,epoch time(s)
0,2,0.708,0.624559,5.501935
1,4,0.755,0.510833,5.564296
2,6,0.755,0.503512,5.476422
3,8,0.745,0.503426,5.441757
4,10,0.746,0.503385,5.439004


In [12]:
avg_epo_time= np.average(tr_time_his)
print("model {} | fewshot {} | best acc {} | epoch {} | {:.3f}s".format(config.BERT_PATH ,config.few_shot,eval_acc, config.EPOCHS,avg_epo_time))

model bert-base-uncased | fewshot 32 | best acc 0.755 | epoch 10 | 5.485s
