In [25]:
import pickle
import numpy as np
import string

data_train_dir = "data/snli_1.0/snli_1.0_train.txt"
data_dev_dir = "data/snli_1.0/snli_1.0_dev.txt"
embedding_file_dir = "data/embedding/glove.6B.50d.txt"
worddict_dir = "data/worddict.txt"
data_train_str_dir = "data/train_data_str.pkl"
data_train_id_dir = "data/train_data_id.pkl"
data_dev_str_dir = "data/dev_data_str.pkl"
data_dev_id_dir = "data/dev_data_id.pkl"
embedding_matrix_dir = "data/embedding_matrix.pkl"

In [3]:
def read_data(data_dir):
    premise = []
    hypothesis = []
    labels = [] 
    labels_map = {"entailment": 0, "neutral": 1, "contradiction": 2}
    punct_table = str.maketrans({key: " " for key in string.punctuation}) # 去掉标点符号
    with open(data_dir, 'r', encoding='utf-8') as lines:
        next(lines)
        for line in lines:
            line = line.strip().split('\t')
            if line[0] not in labels_map:   #忽略没有label的例子
                continue
            premise.append(line[5].translate(punct_table).lower())
            hypothesis.append(line[6].translate(punct_table).lower())
            labels.append(line[0])
    return {"premise": premise, "hypothesis": hypothesis, "labels": labels}   

In [13]:
def build_worddict(data):
    words = ["_PAD_", "_OOV_", "_BOS_", "_EOS_"]
    for sentence in data["premise"]:
        words.extend(sentence.strip().split(" "))
    for sentence in data["hypothesis"]:
        words.extend(sentence.strip().split(" ")) 
    word_id = {}
    id_word = {}
    i = 0
    for word in words:
        if word not in word_id:
            word_id[word] = i
            id_word[i] = word
            i += 1
    #保存词典
    with open(worddict_dir, "w", encoding='utf-8') as f:
        for word, id_ in word_id.items():
            f.write(f"{word}\t{id_}\n")
    return word_id, id_word 

In [15]:
def sentence2idList(sentence, word_id):
    ids = [word_id["_BOS_"]]
    sentence = sentence.strip().split(" ")
    for word in sentence:
        if word not in word_id:
            ids.append(word_id["_OOV_"])
        else:
            ids.append(word_id[word])
    ids.append(word_id["_EOS_"])
    return ids

In [16]:
def data2id(data, word_id):
    premise_id = []
    hypothesis_id = []
    labels_id = [] 
    labels_map = {"entailment": 0, "neutral": 1, "contradiction": 2}
    for i, label in enumerate(data["labels"]):
        if label not in labels_map:   #忽略没有label的例子
            continue
        premise_id.append(sentence2idList(data["premise"][i], word_id))
        hypothesis_id.append(sentence2idList(data["hypothesis"][i], word_id))
        labels_id.append(labels_map[label])
            
    return {"premise_id": premise_id, "hypothesis_id": hypothesis_id, "labels_id": labels_id}    

In [21]:
def build_embeddings(embedding_file, word_id):
    #读取文件存入集合中
    embeddings_map = {}
    with open(embedding_file, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip().split()
            word = line[0]
            if word in word_id:
                embeddings_map[word] = line[1:]   
    #放入矩阵中
    words_num = len(word_id)
    embedding_dim = len(embeddings_map['a'])
    embedding_matrix = np.zeros((words_num, embedding_dim))
    missed_cnt = 0
    for i, word in enumerate(word_id):
        if word in embeddings_map:
            embedding_matrix[i] = embeddings_map[word]
        else:
            if word == "_PAD_":
                continue
            missed_cnt += 1
            embedding_matrix[i] = np.random.normal(size=embedding_dim)
    print("missed word count: %d" % missed_cnt)
    return embedding_matrix

In [26]:
#读取数据
data_str = read_data(data_train_dir)

In [14]:
#构建词典
word_id, id_word = build_worddict(data_str)   

In [17]:
#清洗数据并转换为id
data_id = data2id(data_str, word_id)

In [19]:
#保存 data_train_str和data_train_id
with open(data_train_str_dir, "wb") as f:
    pickle.dump(data_str, f)
    
with open(data_train_id_dir, "wb") as f:
    pickle.dump(data_id, f)

In [24]:
embedding_matrix = build_embeddings(embedding_file_dir, word_id)
print("embedding_matrix size: ", embedding_matrix.shape)

with open(embedding_matrix_dir, "wb") as f:
    pickle.dump(embedding_matrix, f)

missed word count: 5994
embedding_matrix size:  (33268, 50)


In [28]:
#读取数据
data_dev_str = read_data(data_dev_dir)

In [30]:
data_id_dev = data2id(data_dev_str, word_id)

In [31]:
with open(data_dev_str_dir, "wb") as f:
    pickle.dump(data_dev_str, f)
    
with open(data_dev_id_dir, "wb") as f:
    pickle.dump(data_id_dev, f)