In [None]:
!pip install transformers seqeval shap 

In [2]:
import pandas as pd
import numpy as np
import torch
import transformers
import sys
import seqeval
import os
import spacy
import shap
import scipy as sp

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 加载模型

In [None]:
remote = True
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# 加载模型
# 加载预训练模型
model_name = 'bert-base-uncased' # 预训练模型名字
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
model = transformers.BertForSequenceClassification.from_pretrained(model_name, num_labels=2) # 需要预测的类别数为2
optimizer = transformers.AdamW(model.parameters(), lr=2e-5, eps=1e-8)
if remote:
  os.chdir('/content/drive/MyDrive/ModelDebug')
  checkpoint = torch.load('classifier/model.pt',map_location=torch.device(device))
  df = pd.read_csv('classifier/textNFR.csv')
  #df = pd.read_csv('classifier/new_dataset2.csv')
  #os.chdir('examples')
  from examples.infer_softmax_ner import predict
else:
  checkpoint = torch.load('model.pt',map_location=torch.device(device))
  os.chdir('../classifier')
  df = pd.read_csv('textNFR.csv')
  from examples.infer_softmax_ner import predict
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.eval()
model.to(device)

# 执行模型优化

## 可解释性分析

In [5]:
def f(text):
    #输入文本，返回预测值
    tv = torch.tensor([tokenizer.encode(t, padding='max_length', max_length=128, truncation=True) for t in text]).to(device)
    attention_mask = (tv!=0).type(torch.int64).to(device)
    outputs = model(tv,attention_mask=attention_mask)[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores)
    return val

In [6]:
# 使用 shap 进行可解释性分析
explainer = shap.Explainer(f,tokenizer,output_names=["FR","NFR"])

In [7]:
#text列表转json列表
def toJson(text):
    return {'sent':text}
def getTokens(doc):
    nlp = spacy.load('en_core_web_sm')
    tokens= []
    doc = nlp(doc)
    for t in doc:
        tokens.append(t.text)
    return tokens
#抽取每个样例的关注点
#输入data是一个dict，dict中有一个key为sent，value为句子
def getConcerns(data):
    res = []
    tokens = getTokens(data['sent'])
    cons = predict([data])[0]
    if len(tokens)!=len(cons):
        print("tokens和concerns长度不一致"+str(len(tokens))+" "+str(len(cons)))
    for t,c in zip(tokens,cons):
        if c!='O':
            res.append(t)
    return res
#计算重要特征，返回特征与重要性
def important(shap_values):
    reason = []
    for ele in shap_values:
        sum =0
        for num in ele.values[0]:
            sum += abs(num[0]) + abs(num[1])  #两个分类的重要性之和
        avg = sum/(len(ele.values[0])*2)
        res = []
        for values,datas in zip(ele.values[0],ele.data[0]):
            cur = []
            if abs(values[0])+abs(values[1])>=avg*2:
                res.append(datas)
        reason.append(res)
    return reason
#检测异常特征(重要但是不是关注点)
def errCheck(importance,concerns):
    errList = []
    for imp,con in zip(importance,concerns):
        imp=imp.strip()
        if(imp in concerns): continue
        else : errList.append(imp)
    return errList
def createNewDataset(dataset,label):
    new_dataset = []
    index = 0
    for d,l in zip(dataset,label):
        try:
            ls = []
            val = explainer([d])  #计算每个样例的shap值
            imp = important([val])     #计算每个样例的重要单词
            cons = getConcerns(toJson(d))  #计算每个样例的关注点单词
            errlist = errCheck(imp[0],cons)  #计算每个样例的异常单词
            for e in errlist:
                d= d.replace(e,'')
            ls.append(index)
            ls.append(d)
            ls.append(str(l))
            new_dataset.append(ls)
            index+=1
        except:
            print("error in sent:"+d)
            continue
    return new_dataset

In [None]:
import logging
logging.disable(logging.CRITICAL)
dataset = df['RequirementText']
label = df['NFR']
new_dataset = createNewDataset(dataset,label)
print(new_dataset)
newdf = pd.DataFrame(new_dataset, columns=['id','RequirementText','NFR'])
newdf.to_csv('classifier/new_dataset3.csv',index=False)

In [9]:
#test
df = pd.read_csv('classifier/textNFR.csv')
dataset = df['RequirementText']
label = df['NFR']
val = explainer([dataset[0]])
val

.values =
array([[[-8.56816769e-08,  2.98023224e-08],
        [-8.30126181e-03,  8.30119289e-03],
        [ 4.62826065e-02, -4.62825246e-02],
        [ 5.46703299e-01, -5.46703018e-01],
        [ 2.42516067e-01, -2.42515944e-01],
        [-1.53207369e-01,  1.53207537e-01],
        [-2.23761378e-02,  2.23757704e-02],
        [ 7.35424962e-02, -7.35422415e-02],
        [-3.13178962e-01,  3.13179230e-01],
        [-1.30881980e+00,  1.30881979e+00],
        [-1.45185611e+00,  1.45185629e+00],
        [ 1.24453679e-02, -1.24457479e-02],
        [ 1.19209290e-07, -4.76837158e-07]]])

.base_values =
array([[-0.57559139,  0.57559127]])

.data =
(array(['', ' the', ' system', ' shall', ' ref', 'resh', ' the',
       ' display', ' every', ' 60', ' seconds', ' .', ''], dtype=object),)

In [13]:
df = pd.read_csv('classifier/new_dataset1.csv')
df = df['RequirementText']
d = df[0]
val = explainer([d])
print(val)
imp = important([val])     #计算每个样例的重要单词
print(imp)
cons = getConcerns(toJson(d))  #计算每个样例的关注点单词
print(cons)
errlist = errCheck(imp[0],cons)  #计算每个样例的异常单词
print(errlist)

.values =
array([[[ 2.04890966e-08, -1.60187483e-07],
        [-5.35012297e-02,  5.35015855e-02],
        [-2.22003335e-01,  2.22003000e-01],
        [ 2.23954943e-01, -2.23955413e-01],
        [-4.49146302e-02,  4.49145343e-02],
        [-2.61115660e-02,  2.61117602e-02],
        [ 3.16016995e-01, -3.16017278e-01],
        [-2.02056801e-01,  2.02056764e-01],
        [-1.21030352e+00,  1.21030365e+00],
        [-1.11217586e+00,  1.11217634e+00],
        [ 4.50553969e-02, -4.50552478e-02],
        [ 2.98023224e-07,  0.00000000e+00]]])

.base_values =
array([[-0.52282721,  0.52282721]])

.data =
(array(['', ' the', ' system', ' ref', 'resh', ' the', ' display',
       ' every', ' 60', ' seconds', ' .', ''], dtype=object),)
[[' display', ' 60', ' seconds']]


Evaluating: 100%|██████████| 1/1 [00:00<00:00, 13.37it/s]

tokens和concerns长度不一致10 9
[' ', 'refresh', 'the', 'display', 'every', '60']
['seconds']



