In [1]:
import sys, os
sys.path.append('..')
import torch, torch.nn as nn, torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from typing import AnyStr, List, Tuple
from transformers import PreTrainedTokenizer

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  
# if gpu is available in your computer, conduct your model on gpu(i.e., cuda); otherwise, the model is conducted on cpu device

In [3]:
tokenizer = BertTokenizer.from_pretrained("../bert_cn/")
bert_model = BertModel.from_pretrained('../bert_cn/').to(device)

In [4]:
hidden_size = bert_model.config.hidden_size
classifier = nn.Sequential(nn.Linear(hidden_size, 2*hidden_size), 
                          nn.ReLU(),
                          nn.Linear(2*hidden_size, 2)).to(device)

In [5]:
def read_file(fname):
    with open(fname, 'r') as fr:
        doc = fr.read()
    sent_list = doc.replace('\n', '').replace(' ', '').split('。')
    return [sent for sent in sent_list if len(sent) != 0]

In [6]:
def text_to_batch_transformer(text: List, tokenizer: PreTrainedTokenizer, text_pair: List = None):
    """Turn a piece of text into a batch for transformer model

    :param text: The text to tokenize and encode
    :param tokenizer: The tokenizer to use
    :param: text_pair: An optional second string (for multiple sentence sequences)
    :return: A list of IDs and a mask
    """
    max_len = tokenizer.max_len if hasattr(tokenizer, 'max_len') else tokenizer.model_max_length
    if text_pair is None:
        items = [tokenizer.encode_plus(sent, text_pair=None, add_special_tokens=True, max_length=max_len,
                                       return_length=False, return_attention_mask=True,
                                       return_token_type_ids=True)
                 for sent in text]
    else:
        assert len(text) == len(text_pair)
        items = [tokenizer.encode_plus(s1, text_pair=s2, add_special_tokens=True, max_length=max_len,
                                        return_length=False, return_attention_mask=True,
                                            return_token_type_ids=True)
                                        for s1, s2 in zip(text, text_pair)]
    return [item['input_ids'] for item in items], \
              [item['attention_mask'] for item in items], \
                 [item['token_type_ids'] for item in items]

In [7]:
def collate_batch_with_device(device):
    def collate_batch_transformer(doc: Tuple):
        input_ids = doc[0]
        masks = doc[1]
        seg_ids = doc[2]
        
        max_length = max([len(i) for i in input_ids])
        input_ids = [(i + [0] * (max_length - len(i))) for i in input_ids]
        masks = [(m + [0] * (max_length - len(m))) for m in masks]
        seg_ids = [(s + [0] * (max_length - len(s))) for s in seg_ids]

        assert (all(len(i) == max_length for i in input_ids))
        assert (all(len(m) == max_length for m in masks))
        assert (all(len(s) == max_length for s in seg_ids))
        return torch.tensor(input_ids, device=device), torch.tensor(masks, device=device), \
                    torch.tensor(seg_ids, device=device)
    return collate_batch_transformer

In [55]:
class DocModel(nn.Module):
    def __init__(self, model, classifier, tokenizer, device):
        super(DocModel, self).__init__()
        self.model =  model
        self.tokenizer = tokenizer
        self.classifier = classifier
        self.collate_fn = collate_batch_with_device(device)
    
    def obtain_optim(self, learning_rate=None):
        def lr_coefficient(par_name):
        # layer-wise fine-tuning
            if "layer." in par_name:
                layer_num = int(par_name.split("layer.")[1].split(".", 1)[0])
                return pow(0.8, 12 - layer_num)
            elif "embedding" in par_name:
                return pow(0.8, 13)
            else:
                return 1.0
        if learning_rate is None:
            learning_rate = self.learning_rate
        optimizerGroupedParameters = [{'params': p, 'lr': learning_rate * lr_coefficient(n)}
                                        for n, p in self.named_parameters()]
        return torch.optim.Adam(optimizerGroupedParameters)
    
    def pred(self, passage, temperature=1.0):
        rst = text_to_batch_transformer(passage, self.tokenizer)
        input_ids, masks, _ = self.collate_fn(rst)
        bert_dict = self.model(input_ids=input_ids, attention_mask=masks)
        logits = self.classifier(
            bert_dict.pooler_output.max(dim=0)[0]
        )
        logits = F.softmax(logits / temperature, dim=-1)
        return logits
    
    def lossAndAcc(self, passage, label):
        logits =  self.pred(passage_list)
        loss = logits[label].log().neg() # cross entropy
        return loss, logits.data.argmax()==label
        
    def training(self, train_set, val_set, batch_size=32, learning_rate=5e-5):
        optim = self.obtain_optim(learning_rate)
        tr_idxs = random.sample(range(len(train_set)), len(train_set))
        true_pred_cnt = 0
        loss_list = []
        for step, idx in enumerate(tr_idxs):
            loss, tp = self.lossAndAcc(train_set[0][idx], train_set[1][idx])
            true_pred_cnt += float(tp)
            loss.backward()
            loss_list.append(loss.data.item())
            if len(loss_list)>100:
                loss_list.pop(0)
                
            if ((step+1)%batch_size)==0:
                optim.step()
                optim.zero_grad()
                print(f"loss/acc = {np.mean(loss_list)}/{true_pred_cnt/float(batch_size)}")
                true_pred_cnt = 0
        self.valid(val_set, torch.tensor(val_set[1]), 'validation')
        
    def pred_Logits(self, data, idxs=None, batch_size=20):
        preds = []
        if idxs is None:
            idxs = list(range(len(data[0])))
        if not hasattr(self, 'collate_fn'):
            collate_fn = data.collate_raw_batch
        else:
            collate_fn = self.collate_fn

        with torch.no_grad():
            for i in trange(0, len(idxs), 1):
                doc = data[0][idxs[i]]
                pred, _ = self.pred(doc)
                preds.append(pred)
        print(len(preds))
        pred_tensor = torch.stack(preds)
        print("pred_tensor.shape:", pred_tensor.shape)
        return pred_tensor
    
    def prediction(self, data, idxs=None, batch_size=20):
        pred_tensor = self.pred_Logits(data, idxs, batch_size)
        vals, idxs = pred_tensor.sort(dim=1)
        return idxs[:, -1], vals[:, -1]

    def acc_P_R_F1(self, y_true, y_pred):
        return accuracy_score(y_true, y_pred.cpu()), \
               precision_recall_fscore_support(y_true, y_pred.cpu())

    def Perf(self, data, label, idxs=None, batch_size=20):
        y_pred, _ = self.prediction(data, idxs=idxs, batch_size=batch_size)
        y_true = label[idxs] if idxs is not None else label
        return self.acc_P_R_F1(y_true, y_pred)

    def valid(self, test_set, test_label, test_suffix, step=0):
        test_label = test_label.argmax(dim=1) if test_label.dim() > 1 else test_label
        rst_model = self.Perf(test_set, test_label)
        print("test_label : ", test_label.tolist())
        acc_v, (p_v, r_v, f1_v, _) = rst_model
        print("BestPerf : ", rst_model)
#         class_num = len(p_v)
#         output_items = [("valid_acc", acc_v)] + \
#                        [('valid_prec_{}'.format(i), p_v[i]) for i in range(class_num)] + \
#                        [('valid_recall_{}'.format(i), r_v[i]) for i in range(class_num)] + \
#                        [('valid_f1_{}'.format(i), f1_v[i]) for i in range(class_num)]
#         fitlog.add_metric({f"{test_suffix}": dict(output_items)}, step=step)
#         fitlog.add_best_metric({f"FinalPerf_{test_suffix}": dict(output_items)})
        return acc_v

In [56]:
from tqdm import tqdm, trange

In [10]:
d_dir = './positive_dir/'
pos_docs = [read_file(f"{d_dir}/{fname}") for fname in os.listdir(d_dir)]

In [11]:
pos_labels = [1]*len(pos_docs)

In [12]:
pos_set = (pos_docs, pos_labels)

In [57]:
d_b = DocModel(bert_model, classifier, tokenizer, device)

In [58]:
d_b.valid(pos_set, torch.tensor(pos_set[1]), 'test')

100%|█████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00,  9.41it/s]

20
pred_tensor.shape: torch.Size([20])





IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)