In [36]:
import pandas as pd
import json

# Root directory of the competition data; change this one constant to relocate every input.
DATA_DIR = "/home/featurize/data/data"


def _read_lines(path):
    """Read a UTF-8 text file and return its lines.

    Whitespace around the whole file is stripped before splitting on newlines, and a
    leading byte-order mark (U+FEFF) is removed from every line. `str.strip()` does
    not remove U+FEFF, which is why dict.txt previously produced the entry '\ufeff有'.
    """
    with open(path, "r", encoding="utf-8") as fin:
        return [line.lstrip("\ufeff") for line in fin.read().strip().split("\n")]


# Sentence-pair datasets as lists of row dicts (columns such as "谜面", "谜底", "谜底_描述", "label").
st_train = pd.read_csv(f"{DATA_DIR}/st_train.csv").to_dict(orient="records")
st_valid = pd.read_csv(f"{DATA_DIR}/st_valid.csv").to_dict(orient="records")
st_train_enhancement = pd.read_csv(f"{DATA_DIR}/st_train_enhancement.csv").to_dict(orient="records")

data_test = _read_lines(f"{DATA_DIR}/test.txt")   # one test riddle per line
dict_test = _read_lines(f"{DATA_DIR}/dict.txt")   # candidate answers for the test split
with open(f"{DATA_DIR}/character_information.json", "r", encoding="utf-8") as fin:
    character_info = json.load(fin)

# Distinct gold answers of the validation split. They are sorted so the order is the same
# on every run; list(set(...)) depends on the per-process string hash seed.
dict_valid = sorted({item["谜底"] for item in st_valid})

# Clean up the candidate list by dropping entries that have no description in
# character_info. In practice this removes a full stop and a book-title mark.
dict_test = [word for word in dict_test if word in character_info]

In [58]:
from sentence_transformers import SentenceTransformer, models

# Build a SentenceTransformer from the pretrained "hfl/chinese-roberta-wwm-ext" checkpoint.
# If it cannot be loaded from the Hugging Face Hub (e.g. no network access), fall back to a
# local copy. That copy can be created once with: model_roberta_st.save(local_init_path)
model_name = "hfl/chinese-roberta-wwm-ext"
local_init_path = "models/model_roberta_st_init"
try:
    word_embedding_model = models.Transformer(model_name)
except Exception as err:
    # `except Exception` rather than a bare `except:` so Ctrl-C / SystemExit still stop the
    # cell. The reason for the fallback is also printed instead of being silently discarded.
    print(f"Could not load {model_name!r} ({err!r}); falling back to {local_init_path!r}")
    word_embedding_model = models.Transformer(local_init_path)

# Pool token embeddings into one fixed-size sentence vector (library's default pooling mode).
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

model_roberta_st = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [59]:
import random
from sentence_transformers import InputExample, losses, evaluation
from torch.utils.data import DataLoader
from datetime import datetime
import os

import torch

# Seed Python's `random` (training-subset sample) and torch's default generator
# (DataLoader shuffling, dropout) so a re-run reproduces the same training run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Each entry: [rows to draw from, number of rows to sample, tag used in the output dir name].
train_pattern = [
    [st_train, 200, "train"],                                                 # random 200-row subset
    [st_train, len(st_train), "train"],                                       # all of st_train
    [st_train_enhancement, len(st_train_enhancement), "train_enhancement"],  # all augmented rows
]
train_pattern_sel = 2
train_rows, train_size, train_tag = train_pattern[train_pattern_sel]
st_train_sel = random.sample(train_rows, train_size)
batch_size = 16
epochs = 3

# Each example pairs a riddle with an answer description; "label" is the cosine-similarity target.
train_examples = [InputExample(texts=[item["谜面"], item["谜底_描述"]], label=float(item["label"])) for item in st_train_sel]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
train_loss = losses.CosineSimilarityLoss(model_roberta_st)
evaluator = evaluation.EmbeddingSimilarityEvaluator([item["谜面"] for item in st_valid], [item["谜底_描述"] for item in st_valid], [float(item["label"]) for item in st_valid])

steps_per_epoch = len(train_dataloader)
# SentenceTransformer.fit counts evaluation_steps in training steps (batches), not examples.
# It must be an int that divides the step counter. The old value, train_size / 10, was a
# float in example units and never triggered a mid-epoch evaluation. Evaluate ~10x per epoch.
evaluation_steps = max(1, steps_per_epoch // 10)
# Warm up over the first 10% of all training steps (integer step count).
warmup_steps = steps_per_epoch * epochs // 10

model_save_path = (
    "output/" + model_name.replace("/", "-") + "_" + train_tag + "_"
    + f"epochs-{epochs}_batchsize-{batch_size}" + "_"
    + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
)
# makedirs also creates the parent "output/" directory, where os.mkdir would fail.
os.makedirs(model_save_path, exist_ok=True)

In [56]:
from sentence_transformers import util
from function import get_embeddings_of_characters, get_embeddings_of_riddles
import os

def _gold_answers(rows):
    """Map every riddle ("谜面") in `rows` to a single gold answer ("谜底").

    A riddle may occur in several rows. The row with the largest float("label") wins,
    and ties keep the earliest row. When all rows of a riddle share a label, this is
    exactly "first matching row".
    NOTE(review): assumes "label" (the riddle/answer similarity target used in training)
    is higher for the correct answer than for negative pairs. Confirm against the data.
    """
    best = {}
    for row in rows:
        riddle = row["谜面"]
        label = float(row["label"])
        if riddle not in best or label > best[riddle][0]:
            best[riddle] = (label, row["谜底"])
    return {riddle: answer for riddle, (_, answer) in best.items()}


def _write_predictions(output_dir, split_name, guesses):
    """Write one tab-separated line per riddle to two UTF-8 files in `output_dir`.

    `<split>_predictions_with_riddle.txt` starts each line with the riddle text;
    `<split>_predictions.txt` holds the same fields without it.
    """
    path_with_riddle = os.path.join(output_dir, split_name + "_predictions_with_riddle.txt")
    path_plain = os.path.join(output_dir, split_name + "_predictions.txt")
    with open(path_with_riddle, "w", encoding="utf-8") as f_with_riddle, \
            open(path_plain, "w", encoding="utf-8") as f_plain:
        for riddle, fields in guesses.items():
            row = "\t".join(fields)
            f_with_riddle.write(riddle + "\t" + row + "\n")
            f_plain.write(row + "\n")


def get_mrr(model, split_name, riddles, character_list, character_info,
            output_dir=None, valid_rows=None, top_k=5):
    """Retrieve the top-k candidate answers for each riddle, write them, and score the valid split.

    Riddles and candidates are embedded with the project helpers, and candidates are ranked
    with `util.semantic_search`. The function writes `<split>_predictions_with_riddle.txt`
    and `<split>_predictions.txt`. For the "valid" split each row starts with the gold
    answer, and `valid_metrics.txt` (MRR@1, MRR@3, MRR@5) is written too.

    Args:
        model: SentenceTransformer handed to the embedding helpers.
        split_name: "valid" to look up gold answers and compute MRR; any other value only
            writes predictions.
        riddles: riddle texts. Duplicates are scored once, in first-seen order.
        character_list: candidate answers; search `corpus_id`s index into this list.
        character_info: passed through to `get_embeddings_of_characters`.
        output_dir: where the files go. Defaults to the notebook global `model_save_path`.
        valid_rows: rows used to look up gold answers. Defaults to the notebook global
            `st_valid`.
        top_k: number of candidates retrieved per riddle.

    Returns:
        {"mrr1", "mrr3", "mrr5"} averaged over the distinct riddles for the "valid" split,
        otherwise None.

    Raises:
        ValueError: if the search returns a different number of result lists than there
            are distinct riddles, so predictions cannot be aligned with riddles.
    """
    if output_dir is None:
        output_dir = model_save_path
    is_valid = split_name == "valid"

    # Embed each distinct riddle once. The embedding helper returns a dict
    # (NOTE(review): presumably keyed by riddle), so duplicates in `riddles` could collapse
    # and shift every later hit onto the wrong riddle.
    unique_riddles = list(dict.fromkeys(riddles))
    ebds_of_riddles = get_embeddings_of_riddles(model, unique_riddles)
    ebds_of_characters = get_embeddings_of_characters(model, character_list, character_info)
    hits = util.semantic_search(list(ebds_of_riddles.values()),
                                list(ebds_of_characters.values()), top_k=top_k)
    if len(hits) != len(unique_riddles):
        raise ValueError(
            f"got {len(hits)} search results for {len(unique_riddles)} distinct riddles; "
            "cannot align predictions with riddles")

    # Build the gold lookup once (dict) instead of scanning all rows for every riddle.
    if is_valid:
        gold = _gold_answers(st_valid if valid_rows is None else valid_rows)
    else:
        gold = {}
    rr_sums = {1: 0.0, 3: 0.0, 5: 0.0}  # MRR cutoff -> sum of reciprocal ranks
    guesses = {}
    for riddle, query_hits in zip(unique_riddles, hits):
        # Assumes the order of ebds_of_characters.values() matches character_list.
        candidates = [character_list[hit["corpus_id"]] for hit in query_hits]
        if not is_valid:
            guesses[riddle] = candidates
            continue
        answer = gold.get(riddle, "")
        guesses[riddle] = [answer] + candidates
        if answer in candidates:
            rank = candidates.index(answer) + 1
            for cutoff in rr_sums:
                if rank <= cutoff:
                    rr_sums[cutoff] += 1.0 / rank

    _write_predictions(output_dir, split_name, guesses)
    if not is_valid:
        return None

    # MRR is a mean over the queries actually scored, not over the rows of st_valid.
    n_queries = max(len(unique_riddles), 1)  # avoid ZeroDivisionError on an empty split
    metrics = {f"mrr{cutoff}": total / n_queries for cutoff, total in rr_sums.items()}
    with open(os.path.join(output_dir, "valid_metrics.txt"), "w", encoding="utf-8") as fout:
        for name, value in metrics.items():
            fout.write("{}: {:.4f}\n".format(name, value))
    return metrics


In [54]:
# Baseline (pre-fine-tuning) retrieval quality on the validation split.
# Candidates are sorted so their order, and hence tie-breaking in the search, is identical on
# every run; list(set(...)) order depends on the per-process string hash seed.
get_mrr(model_roberta_st, "valid", [item["谜面"] for item in st_valid], sorted({item["谜底"] for item in st_valid}), character_info)

100%|██████████| 5480/5480 [01:39<00:00, 55.19it/s]
100%|██████████| 1457/1457 [00:29<00:00, 49.50it/s]


In [60]:
# Fine-tune with the cosine-similarity loss. The evaluator runs every `evaluation_steps`
# batches, and with save_best_model=True the best-scoring checkpoint is saved to model_save_path.
model_roberta_st.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=evaluation_steps,
    output_path=model_save_path,
    save_best_model=True,
)
# fit() leaves the weights of the final step in memory. Reload the best checkpoint so the
# predictions and metrics later written into model_save_path come from that saved model.
if os.path.exists(os.path.join(model_save_path, "modules.json")):
    model_roberta_st = SentenceTransformer(model_save_path)



Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3118 [00:00<?, ?it/s]

In [57]:
# Evaluate the fine-tuned model on the validation split, then write test-set predictions.
valid_riddles = [item["谜面"] for item in st_valid]
# Sorted so the candidate order (and tie-breaking in the search) is identical on every run.
valid_candidates = sorted({item["谜底"] for item in st_valid})
get_mrr(model_roberta_st, "valid", valid_riddles, valid_candidates, character_info)
get_mrr(model_roberta_st, "test", data_test, dict_test, character_info)

100%|██████████| 5480/5480 [01:42<00:00, 53.21it/s]
100%|██████████| 1457/1457 [00:28<00:00, 50.49it/s]
100%|██████████| 5413/5413 [01:39<00:00, 54.51it/s]
100%|██████████| 1456/1456 [00:28<00:00, 50.22it/s]
