In [None]:
# def tsv_to_json(path, target):
#     open(target, "w").write("")
#     with open(path, "r", encoding="utf-8") as fr:
#         for line in fr:
#             sentence_label = line.strip('\n').strip().strip("\t").lower().split("\t")
#             sentence = ''.join(x for x in sentence_label[-1] if x.isalpha() or x == ' ')
#             label = 1 if int("".join([c for c in sentence_label[0].split("_")[1] if c.isdigit()])) >= 7 else 0
#             open(target, "a").write('{\"text\": \"'+ f'{sentence}'+'\", \"label\": '+ f'{label}'+'}\n')
# tsv_to_json("./train.tsv", "./train.json")
# tsv_to_json("./dev.tsv", "./dev.json")
# tsv_to_json("./test.tsv", "./test.json")
import numpy as np
from datasets import load_dataset
from transformers import BertTokenizer, BertModel, BertConfig
from functools import partial
import torch, os
from tools import plot_training_loss, plot_training_acc
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
# Load the train/test/dev splits from local JSON files, presumably produced by
# the commented-out tsv_to_json converter above ({"text": ..., "label": ...}
# objects, one per line).
dataset = load_dataset(
    "json",
    data_files={
        "train": "train.json",
        "test": "test.json",
        "dev": "dev.json",
    },
)
def convert_example_to_feature(examples, tokenizer, is_infer=False, max_length=None):
    """Tokenize a batch of examples for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch dict with a ``"text"`` column and, unless ``is_infer``
            is set, a ``"label"`` column.
        tokenizer: Callable tokenizer invoked as
            ``tokenizer(texts, padding="max_length", truncation=True, ...)``.
        is_infer: When True, no ``"labels"`` entry is attached.
        max_length: Optional explicit padding/truncation length. When None
            (the default) it is not passed to the tokenizer, which then uses
            its own default maximum length.

    Returns:
        The tokenizer output, with a ``"labels"`` list copied from
        ``examples["label"]`` when not inferring.
    """
    # Only forward max_length when given, so the default call is unchanged.
    extra = {} if max_length is None else {"max_length": max_length}
    encoded_inputs = tokenizer(
        examples["text"], padding="max_length", truncation=True, **extra
    )
    if not is_infer:
        # Copy into a fresh list so the encoded batch does not alias the input column.
        encoded_inputs["labels"] = list(examples["label"])
    return encoded_inputs

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Bind the tokenizer so the mapping function takes only the batch, as
# Dataset.map expects.
trans_fn = partial(convert_example_to_feature, tokenizer=tokenizer)
train_dataset = dataset["train"].map(trans_fn, batched=True)
dev_dataset = dataset["dev"].map(trans_fn, batched=True)
test_dataset = dataset["test"].map(trans_fn, batched=True)

# Shuffle only the training split: training should not see examples in the same
# file order every epoch (the source files may well be grouped by label), while
# evaluation accuracy does not depend on order.
# No tensor format is set on the datasets, so the default collate returns each
# list-valued column as one tensor per token position; train()/evaluate() stack
# and transpose those back into (batch, seq_len).
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=16)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16)
# for key in next(iter(train_dataloader)).keys():
#     print((next(iter(train_dataloader))[key]))
#     break

class BertForSequenceClassification(torch.nn.Module):
    """BERT encoder followed by dropout and a linear head on the pooled output.

    Args:
        bert: Model name or path passed to ``BertModel.from_pretrained``.
        num_classes: Number of output classes (width of the returned logits).
        dropout: Dropout probability for the head. Defaults to the encoder
            config's ``hidden_dropout_prob`` when None.
    """

    def __init__(self, bert, num_classes=2, dropout=None):
        super().__init__()
        self.num_classes = num_classes
        self.bert = BertModel.from_pretrained(bert)
        # Reuse the config already attached to the loaded encoder instead of
        # fetching/parsing it a second time. This also guarantees the head is
        # sized from exactly the config the weights were built with.
        self.bert_config = self.bert.config
        self.dropout = torch.nn.Dropout(
            dropout if dropout is not None else self.bert_config.hidden_dropout_prob
        )
        self.classifier = torch.nn.Linear(self.bert_config.hidden_size, self.num_classes)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
        """Return unnormalized class logits for a batch of token ids.

        Args:
            input_ids: Token ids passed straight to the BERT encoder.
            token_type_ids: Optional segment ids for the encoder.
            position_ids: Optional position ids for the encoder.
            attention_mask: Optional padding mask for the encoder.

        Returns:
            Output of the linear head, with ``num_classes`` values per example.
        """
        outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
        )
        # Element 1 of the BertModel output is the pooler output (processed
        # first-token / [CLS] representation).
        pooled_output = self.dropout(outputs[1])
        return self.classifier(pooled_output)

model = BertForSequenceClassification('bert-base-uncased', 2)

# Names of the parameters that receive weight decay: everything except biases
# and normalization weights. BERT names its norm layers "LayerNorm", so the
# match is case-insensitive (a plain `"norm" in name` would miss them).
decay_params = [
    name
    for name, _ in model.named_parameters()
    if not any(nd in name.lower() for nd in ("bias", "norm"))
]
_decay_names = set(decay_params)

# AdamW applies decoupled weight decay per parameter group. Plain
# Adam(weight_decay=...) would instead add an L2 term to every parameter's
# gradient, biases and LayerNorm weights included.
optimizer = torch.optim.AdamW(
    [
        {
            "params": [p for n, p in model.named_parameters() if n in _decay_names],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in _decay_names],
            "weight_decay": 0.0,
        },
    ],
    lr=1e-5,
)
loss_fn = torch.nn.CrossEntropyLoss()

from tqdm import tqdm
def evaluate(model, data_loader):
    """Return the classification accuracy of ``model`` over all of ``data_loader``.

    Switches the model to eval mode (callers that keep training must call
    ``model.train()`` afterwards) and disables autograd while predicting.

    Args:
        model: Module called as
            ``model(input_ids=..., token_type_ids=..., attention_mask=...)``
            that returns per-class logits.
        data_loader: Iterable of batch dicts with ``input_ids``,
            ``token_type_ids``, ``attention_mask`` (sequences of per-position
            tensors, stacked and transposed below) and ``labels``.

    Returns:
        Fraction of correctly classified examples across every batch, as a
        float; 0.0 if the loader yields no examples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_data in tqdm(data_loader, desc="[Evaluation Progression]"):
            # Stacking the per-position tensors gives (seq_len, batch); transpose.
            input_ids = torch.stack(batch_data["input_ids"]).t()
            token_type_ids = torch.stack(batch_data["token_type_ids"]).t()
            attention_mask = torch.stack(batch_data["attention_mask"]).t()
            labels = torch.as_tensor(batch_data["labels"]).long()

            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )
            preds = logits.argmax(dim=1)
            # Accumulate over the whole loader rather than keeping only the
            # last batch's score.
            correct += (preds == labels.to(preds.device)).sum().item()
            total += labels.numel()
    return correct / total if total else 0.0

def train(model):
    """Fine-tune ``model`` and checkpoint it whenever dev accuracy improves.

    Uses the module-level ``optimizer``, ``loss_fn``, ``train_dataloader`` and
    ``dev_dataloader``. The model is evaluated on the dev set every
    ``eval_steps`` steps and at the last step. Each new best dev accuracy
    saves the weights to ``./checkpoints/best.pdparams``. The weights at the
    end of training are saved to ``./checkpoints/final.pdparams``.

    Args:
        model: Classifier called as
            ``model(input_ids=..., token_type_ids=..., attention_mask=...)``
            that returns (batch, num_classes) logits.

    Returns:
        ``(train_loss_record, train_score_record)``: lists of
        ``(global_step, loss)`` and ``(global_step, dev_accuracy)`` tuples.
    """
    model.train()
    num_epochs = 2
    eval_steps = 100
    global_step = 0
    best_score = 0.
    log_steps = 10
    save_dir = "./checkpoints"
    # torch.save does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    train_loss_record = []
    train_score_record = []
    num_training_steps = len(train_dataloader) * num_epochs
    for epoch in range(num_epochs):
        for batch_data in train_dataloader:
            # Stacking the per-position tensors gives (seq_len, batch); transpose.
            input_ids = torch.stack(batch_data["input_ids"]).t()
            token_type_ids = torch.stack(batch_data["token_type_ids"]).t()
            attention_mask = torch.stack(batch_data["attention_mask"]).t()
            # CrossEntropyLoss expects integer class indices as targets.
            labels = torch.as_tensor(batch_data["labels"]).long()

            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )
            # The loss must be computed on the raw logits. Taking argmax and/or
            # detaching first would cut the autograd graph, so no gradient would
            # ever reach the model's parameters.
            loss = loss_fn(logits, labels.to(logits.device))
            loss_value = loss.item()
            train_loss_record.append((global_step, loss_value))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if global_step % log_steps == 0:
                print(f"[Train] epoch: {epoch} / {num_epochs}, step: {global_step} / {num_training_steps}, loss: {loss_value:.5f}")

            if global_step != 0 and (global_step % eval_steps == 0 or global_step == (num_training_steps - 1)):
                accuracy = evaluate(model, dev_dataloader)
                train_score_record.append((global_step, accuracy))
                print(f"[Evaluate] dev score: {accuracy:.5f}")
                # evaluate() leaves the model in eval mode.
                model.train()

                if accuracy > best_score:
                    print(f"[Evaluate] best accuracy performence has been updated: {best_score:.5f} --> {accuracy:.5f}")
                    best_score = accuracy
                    save_path = os.path.join(save_dir, "best.pdparams")
                    torch.save(model.state_dict(), save_path)

            global_step += 1
    save_path = os.path.join(save_dir, "final.pdparams")
    torch.save(model.state_dict(), save_path)
    print(f"[Train] Training done!")

    return train_loss_record, train_score_record
train_loss_record, train_score_record = train(model)

# Plot the per-step training loss and the periodic dev accuracy.
plot_training_loss(
    train_loss_record,
    "./images/chapter7_bert_loss.pdf",
    loss_legend_loc="upper right",
    sample_step=60,
)
plot_training_acc(
    train_score_record,
    "./images/chapter7_bert_acc.pdf",
    acc_legend_loc="lower right",
    sample_step=1,
)

# Reload the weights saved at the best dev accuracy, then score the test split.
model.load_state_dict(
    torch.load("./checkpoints/best.pdparams")
)
accuracy = evaluate(model, test_dataloader)
print(f"[Evaluate result] accuracy: {accuracy:.5f}")

def infer(model, text):
    """Predict, print, and return the sentiment label for one piece of text.

    Args:
        model: Classifier called as
            ``model(input_ids, token_type_ids, attention_mask=...)`` that
            returns logits over label ids 0 and 1.
        text: Raw input string. It is tokenized with the module-level
            ``tokenizer`` and truncated to 512 tokens.

    Returns:
        The predicted label string, which is also printed.
    """
    model.eval()
    encoded_inputs = tokenizer(text, max_length = 512, truncation=True)
    # Build integer id tensors directly rather than round-tripping through float32.
    input_ids = torch.tensor(encoded_inputs["input_ids"], dtype=torch.int64).unsqueeze(0)
    token_type_ids = torch.tensor(encoded_inputs["token_type_ids"], dtype=torch.int64).unsqueeze(0)
    mask = encoded_inputs.get("attention_mask")
    attention_mask = None if mask is None else torch.tensor(mask, dtype=torch.int64).unsqueeze(0)
    # Prediction only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(input_ids, token_type_ids, attention_mask=attention_mask)
    # Label names: 0 -> "negative sentiment", 1 -> "positive sentiment".
    id2label = {0: "消极情绪", 1: "积极情绪"}
    max_label_id = logits.argmax(dim=1)[0].item()
    pred_label = id2label[max_label_id]
    print("Label: ", pred_label)
    return pred_label
# Quick sanity check of infer() on one hand-written review.
text = (
    "this movie is so good that I watch is several times."
)
infer(model, text)