In [1]:
import torchtext
import pandas as pd
from torchtext.vocab import build_vocab_from_iterator
import joblib

In [2]:
train_dataset = pd.read_csv('../../datasets/train_set.csv', sep='\t')['text']
test_dataset = pd.read_csv('../../datasets/test_a.csv', sep='\t')['text']

In [3]:
# 训练与测试所有句子
all_dataset = pd.concat([train_dataset, test_dataset]).str.split().values.tolist()

In [4]:
all_word = []
for sentence in all_dataset:
    all_word.extend(sentence)

all_word

['2967',
 '6758',
 '339',
 '2021',
 '1854',
 '3731',
 '4109',
 '3792',
 '4149',
 '1519',
 '2058',
 '3912',
 '2465',
 '2410',
 '1219',
 '6654',
 '7539',
 '264',
 '2456',
 '4811',
 '1292',
 '2109',
 '6905',
 '5520',
 '7058',
 '6045',
 '3634',
 '6591',
 '3530',
 '6508',
 '2465',
 '7044',
 '1519',
 '3659',
 '2073',
 '3750',
 '3731',
 '4109',
 '3792',
 '6831',
 '2614',
 '3370',
 '4269',
 '3370',
 '486',
 '5770',
 '4109',
 '4125',
 '3750',
 '5445',
 '2466',
 '6831',
 '6758',
 '3743',
 '3630',
 '1726',
 '2313',
 '5906',
 '826',
 '4516',
 '657',
 '900',
 '1871',
 '7044',
 '3750',
 '2967',
 '3731',
 '1757',
 '1939',
 '648',
 '2828',
 '4704',
 '7039',
 '3706',
 '3750',
 '965',
 '2490',
 '7399',
 '3743',
 '2145',
 '2407',
 '7451',
 '3775',
 '6017',
 '5998',
 '1641',
 '299',
 '4704',
 '2621',
 '7029',
 '3056',
 '6333',
 '433',
 '648',
 '1667',
 '1099',
 '900',
 '2289',
 '1099',
 '648',
 '5780',
 '220',
 '7044',
 '1279',
 '7426',
 '4269',
 '3750',
 '2967',
 '6758',
 '6631',
 '3099',
 '2205',
 '7305

In [5]:
# 创建字典
vocal = build_vocab_from_iterator(iterator=[all_word], min_freq=5)

In [6]:
unk_token = '<unk>'  # 标识符:低频词或未在词表中的词
pad_token = '<pad>'  # 标识符:填充字符
vocal.insert_token(unk_token, 0)
vocal.insert_token(pad_token, 1)
vocal.get_stoi()

{'349': 1149,
 '<unk>': 0,
 '3113': 2416,
 '<pad>': 1,
 '6065': 25,
 '4603': 729,
 '3750': 2,
 '3702': 5893,
 '529': 2641,
 '648': 3,
 '1465': 40,
 '382': 1234,
 '1699': 21,
 '2975': 273,
 '5589': 125,
 '3392': 3648,
 '900': 4,
 '6539': 3853,
 '4120': 1094,
 '5149': 3236,
 '4911': 2846,
 '6122': 6,
 '4058': 5050,
 '4731': 2631,
 '6849': 4927,
 '1952': 426,
 '553': 3554,
 '3370': 5,
 '4464': 7,
 '6257': 2848,
 '5047': 1978,
 '5876': 2865,
 '7085': 5987,
 '7399': 8,
 '3578': 354,
 '4676': 3307,
 '317': 1361,
 '182': 1473,
 '1891': 1032,
 '6861': 495,
 '4030': 1457,
 '4148': 296,
 '4939': 9,
 '296': 494,
 '4302': 749,
 '6784': 2593,
 '6850': 3031,
 '4728': 1987,
 '3659': 10,
 '5122': 327,
 '6615': 968,
 '7414': 3906,
 '6552': 848,
 '5405': 3173,
 '7326': 916,
 '2224': 3378,
 '1940': 2358,
 '3670': 3064,
 '2079': 3813,
 '51': 5123,
 '4811': 11,
 '6160': 634,
 '476': 2669,
 '2667': 1680,
 '3053': 1589,
 '3800': 33,
 '6515': 792,
 '2186': 5284,
 '5165': 483,
 '2313': 193,
 '7251': 699,
 '436

In [7]:
# 保存字典
joblib.dump(vocal, '../../intermediate_save_data/vocal.pkl')

['../../intermediate_save_data/vocal.pkl']

In [8]:
load_vocal = joblib.load('../../intermediate_save_data/vocal.pkl')
print(load_vocal.get_stoi())

{'349': 1149, '<unk>': 0, '3113': 2416, '<pad>': 1, '6065': 25, '4603': 729, '3750': 2, '3702': 5893, '529': 2641, '648': 3, '1465': 40, '382': 1234, '1699': 21, '2975': 273, '5589': 125, '3392': 3648, '900': 4, '6539': 3853, '4120': 1094, '5149': 3236, '4911': 2846, '6122': 6, '4058': 5050, '4731': 2631, '6849': 4927, '1952': 426, '553': 3554, '3370': 5, '4464': 7, '6257': 2848, '5047': 1978, '5876': 2865, '7085': 5987, '7399': 8, '3578': 354, '4676': 3307, '317': 1361, '182': 1473, '1891': 1032, '6861': 495, '4030': 1457, '4148': 296, '4939': 9, '296': 494, '4302': 749, '6784': 2593, '6850': 3031, '4728': 1987, '3659': 10, '5122': 327, '6615': 968, '7414': 3906, '6552': 848, '5405': 3173, '7326': 916, '2224': 3378, '1940': 2358, '3670': 3064, '2079': 3813, '51': 5123, '4811': 11, '6160': 634, '476': 2669, '2667': 1680, '3053': 1589, '3800': 33, '6515': 792, '2186': 5284, '5165': 483, '2313': 193, '7251': 699, '4368': 4408, '1907': 1294, '6722': 550, '2818': 1913, '4856': 3629, '16': 

In [9]:
# 6153个词汇(包含两个特殊词汇'<unk>','<pad>')
print(len(vocal.get_stoi().keys()))

print(vocal.get_stoi())  # Dictionary mapping tokens to indices.
print(vocal.get_itos())

print(vocal['3578'])

6153
{'349': 1149, '<unk>': 0, '3113': 2416, '<pad>': 1, '6065': 25, '4603': 729, '3750': 2, '3702': 5893, '529': 2641, '648': 3, '1465': 40, '382': 1234, '1699': 21, '2975': 273, '5589': 125, '3392': 3648, '900': 4, '6539': 3853, '4120': 1094, '5149': 3236, '4911': 2846, '6122': 6, '4058': 5050, '4731': 2631, '6849': 4927, '1952': 426, '553': 3554, '3370': 5, '4464': 7, '6257': 2848, '5047': 1978, '5876': 2865, '7085': 5987, '7399': 8, '3578': 354, '4676': 3307, '317': 1361, '182': 1473, '1891': 1032, '6861': 495, '4030': 1457, '4148': 296, '4939': 9, '296': 494, '4302': 749, '6784': 2593, '6850': 3031, '4728': 1987, '3659': 10, '5122': 327, '6615': 968, '7414': 3906, '6552': 848, '5405': 3173, '7326': 916, '2224': 3378, '1940': 2358, '3670': 3064, '2079': 3813, '51': 5123, '4811': 11, '6160': 634, '476': 2669, '2667': 1680, '3053': 1589, '3800': 33, '6515': 792, '2186': 5284, '5165': 483, '2313': 193, '7251': 699, '4368': 4408, '1907': 1294, '6722': 550, '2818': 1913, '4856': 3629, '

In [10]:
# 加载预训练词向量文件
vector = torchtext.vocab.Vectors(name="cnew_200.txt",
                                 cache='..\..\intermediate_save_data')
vector

<torchtext.vocab.vectors.Vectors at 0x2590029e040>

In [11]:
pretrained_vector = vector.get_vecs_by_tokens(vocal.get_itos())

In [12]:
# 预训练词向量
pretrained_vector

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.8134e+00, -4.1394e+00,  1.1417e+00,  ...,  3.5465e+00,
          2.9921e-02, -8.0849e-01],
        ...,
        [ 2.7235e-03,  7.6506e-03, -7.8161e-02,  ..., -7.4759e-03,
         -1.0344e-01, -1.2040e-01],
        [ 8.8274e-02,  9.2499e-02, -3.2991e-02,  ..., -1.7648e-02,
         -1.1850e-01, -2.1958e-02],
        [-1.1811e-01,  3.4976e-02,  1.8313e-02,  ..., -7.8549e-02,
         -1.6537e-01, -1.1834e-01]])

In [13]:
pretrained_vector.shape

torch.Size([6153, 200])