In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torchvision
from torch.optim import Adam
from torch.autograd import Variable

from flyai.dataset import Dataset
from flyai.source.base import DATA_PATH

from model import Model
from net import Net, LSTMNet,FCNet
from path import MODEL_PATH
from processor import NLPFlyAI
from utils import Bunch

def eval(model, x_test, y_test, network=None, device=None):
    """Return the classification accuracy of ``network`` on a held-out set.

    NOTE: the name shadows the builtin ``eval``; it is kept for existing callers.

    Args:
        model: object exposing ``batch_iter(x1, x2, y)`` that yields
            ``(x_batch1, x_batch2, y_batch)`` tuples.
        x_test: pair ``(x1, x2)`` of numpy arrays with the same first dimension.
        y_test: numpy array of integer class labels, one per sample.
        network: module called as ``network(x_batch1, x_batch2)``; defaults to
            the notebook-global ``network`` (the original behaviour).
        device: torch device the tensors are moved to; defaults to the
            notebook-global ``device``.

    Returns:
        Fraction of correctly classified samples (0.0 for an empty set).
    """
    if network is None:
        network = globals()["network"]
    if device is None:
        device = globals()["device"]

    data_len = len(x_test[0])
    if data_len == 0:
        return 0.0

    x1, x2 = (torch.from_numpy(arr).float().to(device) for arr in x_test)
    labels = torch.from_numpy(y_test).to(device)

    was_training = network.training
    network.eval()
    total_correct = 0
    try:
        # Evaluation needs no autograd graph; this saves memory and time.
        with torch.no_grad():
            for x_batch1, x_batch2, y_batch in model.batch_iter(x1, x2, labels):
                outputs = network(x_batch1, x_batch2)
                _, prediction = torch.max(outputs, 1)
                total_correct += (prediction == y_batch).sum().item()
    finally:
        # Restore the caller's mode so a training loop is not silently left in eval mode.
        network.train(was_training)
    return total_correct / data_len

# Hyper-parameters; parse_args([]) parses an empty list instead of the notebook kernel's argv.
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--EPOCHS", default=10, type=int, help="train epochs")
parser.add_argument("-b", "--BATCH", default=16, type=int, help="batch size")
args = parser.parse_args([])
dataset = Dataset(epochs=args.EPOCHS, batch=args.BATCH)
model = Model(dataset)

# settings
settings = {
    'net': 'lstm',
    'nc': 5,
    'lr': 0.001,
    'seed': 1,
    'log_interval': 100,
    'save_model': True,
    'predict': True,
}
print("1.settings", settings)
settings = Bunch(settings)
use_cuda = torch.cuda.is_available()
torch.manual_seed(settings.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

# load data — load_data is otherwise only imported in a later cell, so a
# fresh top-to-bottom run would raise NameError here.
from processor import load_data
trn, val = load_data()

The history saving thread hit an unexpected error (LoopExit('This operation would block forever', <Hub at 0x7f534e02c930 epoll pending=0 ref=0 fileno=56>)).History will not be written to the database.


FileNotFoundError: [Errno 2] No such file or directory: 'data/input/glove.txt'

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from processor import load_data

In [3]:
# Train / validation frames from the project loader.
train_valid = load_data()
trn, val = train_valid

{'id': 'XiaoHuangJinew', 'type': 'csv', 'config': {'train_url': 'https://dataset.flyai.com/dataset/XiaoHuangJinew/dev.zip', 'test_url': 'https://dataset.flyai.com/dataset/XiaoHuangJinew/dev.zip'}}
[200, 99, '鸡', 4]
[200, 91, '= =', 18]
= =                        18
我错了，我会改的                    4
你说的是左皮条吧\(^o^)/YES！         2
我么                          2
你死了我怎么办呐                    2
只是个性取向,它不应该是多数人口中说的病...     2
快回去睡觉吧                      2
大天使中活泼开朗又傲娇的大美女             2
arcacia and john <3         2
当然是李欢咯                      2
Name: answer, dtype: int64


In [4]:
# Preview the training frame (the output shows columns: unnamed index, question, answer).
trn

Unnamed: 0,question,answer
0,你好残忍,可套话就甭讲了，你哪儿的？
1,你死了,你死了我怎么办呐
2,"你身份证呢,买车票带你回家",太爷爷说他会在尽头等着你。
3,best,arcacia and john <3
4,被困车站了,人太多
5,讲个笑话呗,我想说……你家wifi密码多少？！lz差点被自己一口老血噎死啊！你丫的不就要个密码吗至于天天...
6,答非所问啊你,我错了，我会改的
7,擦擦擦擦擦,诶你咋骂鸡类？
8,找炮友,A_A祝你精尽人亡
9,"就是,别狡辩",才没有呢！


In [5]:
from path import DATA_PATH
import os
import json

In [6]:
# Pre-trained vectors: a JSON object whose values are lists of numbers, keyed by token.
embedding_path = os.path.join(DATA_PATH, 'embedding.json')
# Keys are Chinese tokens, so decode explicitly as UTF-8 rather than relying
# on the platform's default encoding.
with open(embedding_path, encoding='utf-8') as f:
    ch_vecs = json.load(f)

In [65]:
import numpy as np  # needed here: the notebook's own numpy import comes in a later cell

# Convert every vector to a float64 numpy array (consumed later via torch.from_numpy).
new_ch_vecs = {
    token: np.array([float(component) for component in vec])
    for token, vec in ch_vecs.items()
}

In [66]:
# (embedding dimension, number of tokens that have a vector)
(len(new_ch_vecs['你']), len(new_ch_vecs))

(200, 7361)

In [34]:
import jieba

In [38]:
# Word segmentation of the first question, in jieba's precise (non-full) mode.
list(jieba.cut(trn.iloc[0, 0], cut_all=False))

['你好', '残忍']

In [53]:
# Number of tokens that have a pre-trained vector.
len(new_ch_vecs)

7361

In [46]:
# NOTE(review): new_ch_vecs values are numpy arrays (built in the In [65] cell),
# so `* 3` scales element-wise and this evaluates to the vector length
# (presumably 200), not the 600 shown below. That output predates the numpy
# conversion (In [46] ran before In [65]). If the intent was "length of the
# vector repeated 3 times", use len(new_ch_vecs['残忍']) * 3 instead.
len(new_ch_vecs['残忍']*3)

600

In [21]:
from create_dict import load_dict

In [24]:
def load_dict(file=None, data_path=None):
    """Load a token->id vocabulary from a JSON file and add the special tokens.

    NOTE: this redefines the ``load_dict`` imported from ``create_dict``.

    Args:
        file: file name relative to ``data_path``; defaults to ``'word.dict'``.
        data_path: directory containing the file; defaults to the global
            ``DATA_PATH`` (the original behaviour).

    Returns:
        ``(char_dict, char_dict_re)``: the token->id mapping with ``_bos_``=0,
        ``_pad_``=1, ``_eos_``=2 and ``_unk_``=3 set (overriding any existing
        entries for those keys), and its id->token inverse.
    """
    if data_path is None:
        data_path = DATA_PATH
    dict_path = os.path.join(data_path, 'word.dict' if file is None else file)
    with open(dict_path, encoding='utf-8') as fin:
        char_dict = json.load(fin)
    # Special tokens take ids 0-3; the vocabulary file presumably starts at id 4.
    # If it did not, the inverse mapping would keep only the last token per id.
    char_dict.update({"_bos_": 0, "_pad_": 1, "_eos_": 2, "_unk_": 3})
    char_dict_re = {idx: token for token, idx in char_dict.items()}
    return char_dict, char_dict_re

In [25]:
# Note: "words.dict", not load_dict's default "word.dict".
stoi_dict, itos_dict = load_dict("words.dict")

In [31]:
# Position i holds the token with id i; raises KeyError if ids are not exactly 0..n-1.
itos = list(map(itos_dict.__getitem__, range(len(itos_dict))))

In [9]:
from processor import Processor

In [10]:
# Project processor; its input_x method is used below to turn each sentence
# into a sequence of token ids (it appears to return a list).
proc = Processor()

In [17]:
# Token ids for every question (first column).
x_ids = list(map(proc.input_x, trn.iloc[:, 0]))

In [18]:
# Token ids for every answer (second column).
# NOTE(review): answers are also encoded with input_x — confirm that is intended.
y_ids = list(map(proc.input_x, trn.iloc[:, 1]))

In [34]:
import numpy as np

In [38]:
# 95th-percentile sequence lengths; sequences are later truncated/padded to these.
xlen_95 = int(np.percentile(list(map(len, x_ids)), 95))
ylen_95 = int(np.percentile(list(map(len, y_ids)), 95))

In [100]:
# Truncate each question to xlen_95 ids, or right-pad it with id 1 (`_pad_`) up to that length.
x_ids_tr = np.array([[*ids[:xlen_95], *([1] * (xlen_95 - len(ids)))] for ids in x_ids])

In [102]:
# Truncate each answer to ylen_95 ids, or right-pad it with id 1 (`_pad_`) up to that length.
y_ids_tr = np.array([[*ids[:ylen_95], *([1] * (ylen_95 - len(ids)))] for ids in y_ids])

In [68]:
import numpy as np

In [74]:
# Sanity check that the keys of `temp` are exactly 4, 5, ..., len(temp)+3,
# i.e. contiguous and starting right after the four special ids (the
# displayed elements are all zero).
# NOTE(review): `temp` is not defined in any visible cell (hidden kernel state),
# so this fails on Restart & Run All. len(temp) == 7502 and the decoder vocabulary
# has 7506 entries, so it is presumably the raw words.dict inverse before the
# specials are added — TODO confirm and build it explicitly.
np.array(list(temp.keys()))-list(range(4,len(temp)+4))

array([0, 0, 0, ..., 0, 0, 0])

In [None]:
# (Incomplete `from` import removed: it was a SyntaxError that broke Run All.)

In [72]:
def create_emb(vecs, itos, em_sz):
    """Build an ``nn.Embedding`` whose rows are initialised from pre-trained vectors.

    Args:
        vecs: mapping token -> vector of length ``em_sz`` (numpy array or any
            sequence of numbers).
        itos: list of tokens; row ``i`` of the embedding belongs to ``itos[i]``.
        em_sz: embedding dimension.

    Returns:
        ``nn.Embedding(len(itos), em_sz, padding_idx=1)``. Rows of tokens that
        are missing from ``vecs`` keep the layer's default initialisation.

    Raises:
        RuntimeError: if a vector's length does not match ``em_sz``. Previously a
            bare ``except`` silently counted such rows as "missing".
    """
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    wgts = emb.weight.data
    miss = []
    for i, w in enumerate(itos):
        try:
            vec = vecs[w]
        except KeyError:
            miss.append(w)
            continue
        # as_tensor accepts numpy arrays and plain lists; cast to the weight dtype.
        wgts[i] = torch.as_tensor(vec, dtype=wgts.dtype)
    # Number of tokens without a vector, plus a small sample of them.
    print(len(miss), miss[5:10])
    return emb

In [44]:
import torch.nn as nn

In [74]:
from torch.utils.data.dataset import Dataset

In [79]:
class Seq2SeqDataset(torch.utils.data.Dataset):
    """Map-style dataset of aligned (source, target) sequence pairs.

    The base class is spelled ``torch.utils.data.Dataset`` because this notebook
    also binds the bare name ``Dataset`` to ``flyai.dataset.Dataset``.

    Args:
        x: indexable collection of source sequences.
        y: indexable collection of target sequences, same length as ``x``.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same length, got {} and {}".format(len(x), len(y)))
        self.x, self.y = x, y

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

In [103]:
# Training pairs only.
# TODO: build a validation dataset from `val` the same way.
trn_ds = Seq2SeqDataset(x=x_ids_tr, y=y_ids_tr)

In [104]:
# Truncated/padded source ids, one row per question (pad id 1).
x_ids_tr

array([[ 148, 2277,    1, ...,    1,    1,    1],
       [   4,   96,    8, ...,    1,    1,    1],
       [   4,    3,   52, ...,    4,  190,    1],
       ...,
       [   4,   72,  488, ...,    1,    1,    1],
       [   4,    9,    3, ...,    1,    1,    1],
       [3428,  215,    1, ...,    1,    1,    1]])

In [105]:
# Number of target rows.
y_ids_tr.shape[0]

200

In [110]:
# Shuffled mini-batches of 16 (source, target) pairs.
trn_dl = torch.utils.data.DataLoader(
    trn_ds,
    batch_size=16,
    shuffle=True,
)

In [111]:
# One batch for shape sanity checks.
first_batch_iter = iter(trn_dl)
x, y = next(first_batch_iter)

In [128]:
from torch.autograd import Variable

In [147]:
class Seq2SeqRNN(nn.Module):
    """GRU encoder-decoder that greedily decodes up to ``out_sl`` tokens.

    The encoder reads a time-major batch of token ids. Its final hidden state is
    projected to the decoder's hidden size and starts the decoder. At every step
    the decoder is fed its own argmax prediction, beginning with id 0. Decoding
    stops early once every sequence in the batch predicts id 1 (the
    ``padding_idx`` used by ``create_emb``).

    Args:
        vecs_enc, itos_enc, em_sz_enc: pre-trained vectors, vocabulary and
            embedding size for the encoder (passed to ``create_emb``).
        vecs_dec, itos_dec, em_sz_dec: the same for the decoder.
        out_sl: maximum number of decoding steps.
        nh: encoder GRU hidden size.
        nl: number of GRU layers in both encoder and decoder.
    """

    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, out_sl, nh=256, nl=2):
        super().__init__()
        self.nl, self.nh, self.out_sl = nl, nh, out_sl
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)
        # Project the encoder hidden state (nh) into the decoder hidden space (em_sz_dec).
        self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)

        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        # Tie the output projection to the decoder embedding by sharing the
        # Parameter itself. The old `.weight.data = ...` assignment left two
        # Parameters aliasing one storage, each receiving its own gradient and
        # optimizer update.
        self.out.weight = self.emb_dec.weight

    def forward(self, inp):
        """Encode ``inp`` and greedily decode.

        Args:
            inp: LongTensor of token ids, shape ``(src_len, batch)``.

        Returns:
            Logits of shape ``(steps, batch, len(itos_dec))`` with
            ``1 <= steps <= out_sl``.
        """
        _, bs = inp.size()
        h = self.initHidden(bs, device=inp.device)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        _, h = self.gru_enc(emb, h)
        h = self.out_enc(h)

        # First decoder input: id 0 for every sequence.
        dec_inp = torch.zeros(bs, dtype=torch.long, device=inp.device)
        res = []
        for _ in range(self.out_sl):
            emb = self.emb_dec(dec_inp).unsqueeze(0)
            outp, h = self.gru_dec(emb, h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            # Greedy decoding: feed back the most likely token.
            dec_inp = outp.argmax(dim=1)
            if (dec_inp == 1).all():
                break
        return torch.stack(res)

    def initHidden(self, bs, device=None):
        """Return a zero encoder hidden state of shape ``(nl, bs, nh)``.

        ``device`` defaults to the device of the module's parameters.
        """
        if device is None:
            device = next(self.parameters()).device
        return torch.zeros(self.nl, bs, self.nh, device=device)

In [142]:
# (batch, seq_len) as produced by the DataLoader.
x.size()

torch.Size([16, 10])

In [149]:
# Encoder and decoder share the same vocabulary and 200-d vectors; decode at most 10 steps.
# Use the GPU only when one is available instead of hard-coding .cuda().
net = Seq2SeqRNN(
    new_ch_vecs, itos, 200,   # encoder: vectors, vocabulary, embedding size
    new_ch_vecs, itos, 200,   # decoder: vectors, vocabulary, embedding size
    10,                       # out_sl: maximum decoding steps
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

165 [' ', '〜', '...', '¯', '..']
165 [' ', '〜', '...', '¯', '..']


In [150]:
# The model expects time-major (seq_len, batch) ids on the same device as its
# parameters; wrapping in Variable is a no-op in modern PyTorch.
outputs = net(x.t().to(next(net.parameters()).device))

In [151]:
# (decode steps, batch, decoder vocabulary size)
outputs.shape

torch.Size([10, 16, 7506])

In [153]:
def seq2seq_loss(input, target, ignore_index=-100):
    """Cross-entropy between decoder logits and a target of possibly different length.

    Args:
        input: logits of shape ``(sl_in, batch, n_classes)``.
        target: token ids of shape ``(sl, batch)``; may be non-contiguous
            (e.g. a transposed batch).
        ignore_index: target id excluded from the loss. The default (-100,
            ``F.cross_entropy``'s own default) keeps every target token,
            padding included. Pass 1 to ignore the ``_pad_`` id.

    Returns:
        Scalar mean cross-entropy.

    If decoding stopped early (``sl_in < sl``), the missing steps are padded with
    all-zero logits; extra predicted steps beyond ``sl`` are dropped.
    """
    from torch.nn import functional as F  # local: the notebook imports F in a later cell

    sl, _ = target.size()
    sl_in, _, nc = input.size()
    if sl > sl_in:
        input = F.pad(input, (0, 0, 0, 0, 0, sl - sl_in))
    input = input[:sl]
    # reshape rather than view: a transposed target is not contiguous.
    return F.cross_entropy(input.reshape(-1, nc), target.reshape(-1), ignore_index=ignore_index)

In [155]:
from torch.nn import functional as F

In [156]:
# Targets must be time-major (seq_len, batch) and on the same device as the logits.
loss = seq2seq_loss(outputs, y.t().to(outputs.device))

In [157]:
# Scalar loss for the single sanity-check batch.
loss

tensor(9.1460, device='cuda:0', grad_fn=<NllLossBackward>)

In [152]:
# Raw decoder logits (torch abbreviates the printout).
outputs

tensor([[[-1.4144e+00, -3.4127e-02, -4.7595e-01,  ..., -7.7514e-02,
          -5.2614e-01,  9.6109e-01],
         [ 7.3612e-01, -3.4127e-02, -1.1405e+00,  ..., -5.7174e-01,
          -3.0829e-01,  5.8698e-01],
         [-2.1266e+00, -3.4127e-02, -7.8520e-01,  ..., -6.4584e-01,
          -7.2634e-01, -6.9792e-01],
         ...,
         [ 6.2402e-01, -3.4127e-02,  1.0357e+00,  ..., -7.6377e-01,
          -4.6108e-01,  1.4783e+00],
         [-7.7962e-01, -3.4127e-02, -6.4532e-01,  ..., -3.9060e-01,
          -4.4784e-02,  1.0060e+00],
         [-1.0151e+00, -3.4127e-02,  5.0667e-03,  ..., -5.0701e-03,
          -4.4643e-01, -1.2170e+00]],

        [[ 1.6156e+00, -3.4127e-02,  3.9186e-01,  ..., -6.3260e-01,
          -1.0663e+00, -8.3731e-01],
         [ 1.0965e+00, -3.4127e-02, -1.4871e+00,  ..., -3.9965e-01,
          -4.8284e-01,  1.9086e+00],
         [ 1.1110e+00, -3.4127e-02,  1.3416e+00,  ..., -2.8565e-01,
          -4.6969e-01,  1.8640e-01],
         ...,
         [-2.9348e+00, -3

In [85]:
# NOTE(review): `new_its` is not defined in any visible cell (hidden kernel
# state), so this fails on a fresh kernel — define it or delete this cell.
new_its[0][0]

[tensor([ 234, 2544,  107, 2181,    5,    4,  498, 1715,    9,   49]),
 tensor([   5,   38,    3,    8,   44,    9,  279, 1715,   70,   14])]

In [68]:
# Stand-alone embedding initialised from the pre-trained vectors, for inspection below.
emb_v = create_emb(vecs=new_ch_vecs, itos=itos, em_sz=200)

4 你 [ 0.196675 -0.413569 -0.193016 -0.028678  0.083858]
5 我 [ 0.250116 -0.366958  0.065014  0.010725  0.231398]
7 的 [ 0.209092 -0.165459 -0.058054  0.281176  0.102982]
8 了 [ 0.257969 -0.288066  0.105582  0.281781  0.262464]
9 是 [ 0.088422 -0.220535  0.042321  0.280248  0.158567]
10 。 [ 0.128825 -0.267995  0.000795  0.263639  0.097538]
12 ！ [ 0.124369 -0.549265 -0.151302  0.267781  0.088999]
13 = [-0.344582 -0.066462 -0.180318 -0.0739   -0.047484]
14 , [-0.098885 -0.335186 -0.092543  0.145954  0.062586]
15 不 [ 0.206238 -0.284742 -0.265172  0.241002  0.081033]
16 啊 [ 0.247648 -0.466043 -0.107939  0.036358  0.168662]
17 ~ [ 0.034433 -0.388317 -0.448883  0.084547  0.112937]
18 好 [ 0.163763 -0.424984 -0.054846  0.023387  0.289408]
19 说 [ 0.296825 -0.128991  0.098874 -0.012487  0.265535]
20 吗 [ 0.299329 -0.321911 -0.345951  0.052636  0.291682]
21 什么 [ 0.259471 -0.297434  0.031939 -0.087495  0.273416]
22 ？ [ 0.402922 -0.300508 -0.353942  0.112653  0.106601]
23 … [ 0.02928  -0.341184 -0.04423 

844 一遍 [ 0.350727 -0.003273  0.15652   0.025715  0.568927]
845 老子 [ 0.06268  -0.31683   0.097763 -0.512564  0.35889 ]
846 搞基 [ 0.23829  -0.163591  0.253119  0.094879  0.283239]
847 怪 [ 0.31169   0.009581  0.211327 -0.096923  0.124296]
848 哪儿 [ 0.16295  -0.089257 -0.255871 -0.330243  0.092607]
849 倒霉 [0.440475 0.095992 0.570768 0.086394 0.320911]
850 一会 [ 0.125603 -0.27897   0.133231  0.053737  0.306913]
851 迟到 [ 0.262602 -0.051425  0.208188  0.147878  0.563314]
852 考 [ 0.223435 -0.284751  0.054118 -0.174407  0.30655 ]
853 终于 [ 0.263801 -0.164315  0.052887  0.419486  0.208465]
854 神 [ 0.226567 -0.257715 -0.283453  0.022964  0.20394 ]
855 最美 [-0.104333 -0.412292  0.218975  0.007365  0.105472]
856 吓 [ 0.511686 -0.100864  0.545979  0.096763  0.191352]
857 br [-0.347828  0.204213 -0.169189 -0.238088  0.001621]
858 � [-0.14742   0.17138  -0.396312 -0.048187 -0.089976]
859 起 [ 0.489439  0.060076 -0.20476  -0.043022  0.284532]
860 爷 [ 0.087211  0.06165  -0.031054 -0.059282  0.197202]
861 处女座 [

1694 愿 [ 0.130409 -0.277182 -0.157271  0.33332  -0.140921]
1695 怀孕 [ 0.239514 -0.352133 -0.136154  0.277447  0.024279]
1696 微博 [-0.216053 -0.720033 -0.167079 -0.016672 -0.154634]
1697 小公鸡 [ 0.62483  -0.579977  0.197686 -0.254129  0.281885]
1698 大妈 [ 0.537038 -0.265206  0.08678   0.075574 -0.021441]
1699 多大 [ 0.23174  -0.331154 -0.232253 -0.032825  0.423822]
1700 嘎 [ 0.290598  0.000994 -0.023825  0.096733  0.066136]
1701 呼呼 [ 0.468278 -0.171089 -0.032459  0.05535   0.272743]
1702 卧槽 [ 0.456008 -0.44002   0.355712  0.087721  0.163181]
1703 出生 [-0.060371  0.113237 -0.044584  0.355819  0.409048]
1704 仔 [-0.253465 -0.510633 -0.080782  0.145653 -0.284383]
1705 雪 [ 0.446273 -0.157447 -0.153176  0.032813  0.159957]
1706 苍老 [ 0.266244  0.115856  0.25828  -0.262791  0.175488]
1707 胡说 [ 0.4103    0.005155  0.238437 -0.069266  0.415634]
1708 美国 [ 0.396758 -0.122395  0.286677  0.319145  0.317309]
1709 第二天 [ 0.419561 -0.302645  0.203666  0.158619  0.550019]
1710 神经 [ 0.327928 -0.152548  0.024536 -0.

2668 不得 [ 0.062447 -0.282425  0.111833  0.220106 -0.283051]
2669 who [ 0.178356  0.047433 -0.131913  0.015662  0.264444]
2670 one [ 0.070961 -0.285649 -0.19873   0.035967  0.425727]
2671 of [-0.078815 -0.168577  0.15957   0.021186  0.512275]
2672 鸡精 [ 0.246589 -0.685132  0.035107  0.125535  0.212277]
2673 鬼故事 [ 0.433756  0.24648   0.691568 -0.223275  0.26876 ]
2674 骚 [ 0.212393 -0.301295  0.28077   0.159659  0.171094]
2675 额额 [ 0.281942 -0.275158 -0.128743  0.052542  0.318543]
2676 雷峰塔 [ 0.557541 -0.43614   0.371749 -0.053722  0.074287]
2677 零 [-0.104989 -0.121566 -0.228149  0.319507  0.120854]
2678 迷人 [-0.079674 -0.07774   0.07197   0.034808  0.239416]
2679 辣子鸡 [ 0.427054 -0.597138  0.098357  0.057559 -0.031229]
2680 赵 [-0.049188 -0.147367 -0.010793 -0.247435 -0.020409]
2681 读 [-0.171151 -0.092258  0.368761 -0.489989  0.19001 ]
2682 行为 [ 0.267852 -0.564564 -0.027976  0.18105  -0.002201]
2683 肚子疼 [ 0.491899 -0.216666  0.151319  0.134614  0.365807]
2684 聊聊 [ 0.042667  0.292305 -0.060999

3688 巧克力 [ 0.178338 -0.351004  0.034794 -0.134927 -0.205741]
3689 就够 [ 0.140124 -0.442112  0.05648  -0.005826  0.115248]
3691 家族 [-0.042581  0.427032  0.038468  0.294168  0.067425]
3692 害得 [ 0.452946 -0.299041  0.538188 -0.049678  0.324979]
3693 孙 [-0.01817   0.013515 -0.110584 -0.172605  0.276869]
3694 委屈 [ 0.383027 -0.11395  -0.081868 -0.163533  0.284941]
3695 奴隶 [ 0.341666 -0.341819  0.229728  0.060668  0.200232]
3696 奖学金 [ 0.099213 -0.384929  0.516923 -0.492209  0.067891]
3697 失望 [ 0.233062 -0.151452  0.216751 -0.005343  0.509438]
3698 天生一对 [ 0.265875 -0.1041    0.121839 -0.226064  0.16031 ]
3699 大餐 [ 0.260266 -0.374121 -0.210768 -0.055724  0.078115]
3700 嗷嗷 [ 0.473516 -0.172822  0.122808  0.139852  0.239103]
3701 唷 [-0.032419 -0.084419 -0.050809  0.085324  0.124062]
3702 周围 [-0.027651  0.062679  0.149366  0.245848  0.079629]
3703 友 [-0.057396 -0.251277 -0.190577 -0.022846  0.271839]
3704 卖不卖 [ 0.371909 -0.891446  0.087653 -0.309039  0.387737]
3705 包邮 [ 0.115053 -1.001567 -0.439652

4275 那来 [ 0.038194 -0.101475 -0.153266  0.030431 -0.004933]
4276 身高 [ 1.55970e-01 -3.29995e-01  4.70000e-05  2.77540e-01  1.66049e-01]
4277 走走 [ 0.061056 -0.100552  0.265781 -0.039674 -0.042602]
4278 赖着 [ 0.328678 -0.160853  0.142021 -0.06225   0.433355]
4279 贵姓 [ 0.017683  0.145411  0.361887  0.50232  -0.087021]
4280 谢娜 [ 0.080676 -0.190072 -0.613746  0.326521 -0.5798  ]
4281 讨人喜欢 [ 0.011164 -0.00806   0.008651 -0.218396  0.451238]
4282 见怪不怪 [ 0.479102  0.066488  0.427833 -0.06388   0.536481]
4283 行动 [-0.080144 -0.381589 -0.262443  0.332129 -0.274896]
4284 臣妾 [ 0.614357  0.121705 -0.505536 -0.085884  0.11395 ]
4285 胖纸 [ 0.244268 -0.581683 -0.343516  0.136202  0.058557]
4286 胖死 [ 0.35428  -0.899682  0.074277 -0.027939  0.515141]
4287 美貌 [-0.022322 -0.176405  0.053515 -0.112742  0.108559]
4288 美美 [ 0.019026 -0.715964 -0.217309 -0.051829 -0.013929]
4289 缺心眼 [ 0.384094 -0.208554  0.194174 -0.342513 -0.179644]
4290 缺 [ 0.190324 -0.273006 -0.099364  0.312856 -0.086327]
4291 给力 [ 0.214526 -0

5222 南 [ 0.049209 -0.153944 -0.113007  0.224198 -0.215302]
5223 华中科技大学 [-0.113932  0.096189  0.067066 -0.122936  0.140919]
5224 勇算 [ 0.24725  -0.048601  0.151016  0.010598 -0.134954]
5225 勇气 [ 0.434734 -0.216094 -0.119854  0.139724 -0.157807]
5226 功劳 [ 0.14035  -0.168193  0.237179  0.158334  0.119038]
5227 劝 [ 0.316242 -0.443277  0.322647 -0.300811  0.3708  ]
5228 剩 [ 0.408019 -0.522666  0.238518  0.103362  0.031347]
5229 剥夺 [ 0.287375 -0.109362 -0.199213  0.119239  0.214067]
5230 别问 [ 0.247595  0.014821 -0.16961  -0.024435  0.367826]
5231 别看 [ 0.292696 -0.090481  0.133133 -0.019982  0.295837]
5232 出轨 [ 0.263056 -0.385787  0.020075  0.181705  0.283073]
5233 凤爪 [ 0.500934 -0.589267  0.012814 -0.060037 -0.256587]
5234 净化 [ 0.105425  0.013387 -0.548062  0.058317  0.036122]
5235 军爷 [ 0.086651 -0.121748  0.023378  0.257187  0.021822]
5236 共同 [ 0.068369  0.100089 -0.261583  0.058071 -0.057951]
5237 公平 [ 0.290528 -0.485408 -0.040515 -0.123575  0.347045]
5238 入 [ 0.039564 -0.114633  0.201552  

6250 篮球 [ 0.064134 -0.015182  0.124212 -0.073151  0.216761]
6251 算卦 [ 0.622589 -0.074428  0.173744 -0.009157  0.02454 ]
6252 等不及 [ 0.341916  0.049742 -0.341395  0.259423  0.448695]
6253 端 [ 0.205836 -0.242689 -0.230795  0.121647 -0.165288]
6254 章鱼 [ 0.641278 -0.107443  0.202446  0.034328 -0.014108]
6255 究竟 [ 0.440417 -0.217777 -0.457592  0.149049 -0.009097]
6256 神经错乱 [ 0.749673 -0.323473  0.476067 -0.271099  0.385181]
6257 神秘 [ 0.232682  0.356464  0.182661  0.361486 -0.035527]
6258 神曲 [ 0.607888 -0.209427  0.449545 -0.039821 -0.077415]
6259 碰 [ 0.349032 -0.37732   0.491608  0.136341  0.200743]
6260 砍死 [ 0.272207 -0.447261  0.615579  0.166229  0.286345]
6261 石家庄 [ 0.434851 -0.411119 -0.293706 -0.09793  -0.156853]
6262 真是太 [ 0.265005 -0.420522 -0.036922 -0.194471  0.355925]
6263 真多 [ 0.037669 -0.14584   0.096663 -0.105843  0.286364]
6264 看法 [ 0.182865 -0.332652  0.038777  0.06749   0.374863]
6265 看待 [ 0.26414  -0.288243 -0.078545  0.001639  0.446375]
6266 看哥 [ 0.244262 -0.180496 -0.02922

7241 懵 [ 0.600145 -0.073086  0.121197 -0.032384  0.353693]
7242 意中人 [ 0.072949 -0.222957  0.300238  0.163481 -0.013827]
7243 惊喜 [0.064246 0.05438  0.061704 0.302602 0.027007]
7244 悬疑 [ 0.222587 -0.093706  0.156933  0.290577  0.039947]
7245 您好 [ 0.390587  0.054096 -0.413977  0.400257  0.136095]
7246 恩是 [ 0.083272 -0.228842 -0.05313  -0.122822  0.082178]
7247 怎能 [ 0.437929 -0.056016 -0.17343   0.141409  0.144255]
7248 怀抱 [-0.017651 -0.127974 -0.143658  0.362341  0.489574]
7249 必备 [ 0.311247 -0.233151 -0.053618 -0.04237  -0.074997]
7250 心跳 [-0.2249   -0.06142   0.348003  0.312769  0.161264]
7251 心肝 [ 0.444229 -0.27253   0.11571  -0.278998  0.113866]
7252 心碎 [ 0.391524 -0.059118  0.21152   0.002029  0.196436]
7253 德克士 [ 0.542871 -0.2372   -0.344518 -0.058251 -0.673128]
7254 很能 [ 0.228093 -0.417363  0.131234  0.006168  0.249764]
7255 很少 [0.090069 0.109931 0.188957 0.182504 0.344035]
7256 很久 [ 0.128088 -0.332843  0.094654  0.272586  0.278403]
7257 形象 [ 0.241764 -0.527109  0.198403  0.082521 

In [60]:
import torch

In [69]:
# Pre-trained vector for '你' as a tensor (shares memory with the numpy array).
torch.as_tensor(new_ch_vecs['你'])

tensor([ 0.1967, -0.4136, -0.1930, -0.0287,  0.0839, -0.1606,  0.0565,  0.0265,
         0.0794,  0.1186,  0.2014,  0.2022,  0.1390, -0.0453,  0.0887, -0.1390,
        -0.3040, -0.3704, -0.0682, -0.1928,  0.0520,  0.0273,  0.0070,  0.0763,
         0.1217,  0.0915, -0.1481,  0.1849,  0.2139, -0.1024, -0.0666,  0.0073,
        -0.1938,  0.2115, -0.0016, -0.0100,  0.0971,  0.1547, -0.1384,  0.0107,
         0.1249, -0.0386,  0.3877, -0.1198,  0.0429,  0.0290, -0.4272, -0.4387,
        -0.0157,  0.2999, -0.3335,  0.0380, -0.0928,  0.2277,  0.2683,  0.1118,
         0.1604,  0.0797,  0.1474, -0.1877, -0.1181,  0.2545,  0.0720, -0.2411,
        -0.1078,  0.0656, -0.0721,  0.0893, -0.0672, -0.1612, -0.1154,  0.0974,
         0.0582,  0.1212, -0.0196,  0.3010, -0.0239, -0.0822,  0.1311, -0.0991,
         0.1879,  0.4724, -0.0319,  0.0555,  0.4015, -0.3073, -0.1109,  0.3506,
        -0.1827, -0.0759,  0.2203,  0.1573, -0.1767,  0.1550, -0.0312,  0.1041,
         0.0828, -0.1363, -0.2719, -0.00

In [70]:
# Row 4 ('你') of the embedding; should match the vector shown above.
emb_v.weight[4]

tensor([ 0.1967, -0.4136, -0.1930, -0.0287,  0.0839, -0.1606,  0.0565,  0.0265,
         0.0794,  0.1186,  0.2014,  0.2022,  0.1390, -0.0453,  0.0887, -0.1390,
        -0.3040, -0.3704, -0.0682, -0.1928,  0.0520,  0.0273,  0.0070,  0.0763,
         0.1217,  0.0915, -0.1481,  0.1849,  0.2139, -0.1024, -0.0666,  0.0073,
        -0.1938,  0.2115, -0.0016, -0.0100,  0.0971,  0.1547, -0.1384,  0.0107,
         0.1249, -0.0386,  0.3877, -0.1198,  0.0429,  0.0290, -0.4272, -0.4387,
        -0.0157,  0.2999, -0.3335,  0.0380, -0.0928,  0.2277,  0.2684,  0.1118,
         0.1604,  0.0797,  0.1474, -0.1877, -0.1181,  0.2545,  0.0720, -0.2411,
        -0.1078,  0.0656, -0.0721,  0.0893, -0.0672, -0.1612, -0.1154,  0.0974,
         0.0582,  0.1212, -0.0196,  0.3010, -0.0239, -0.0822,  0.1311, -0.0991,
         0.1879,  0.4724, -0.0319,  0.0555,  0.4015, -0.3073, -0.1109,  0.3506,
        -0.1827, -0.0759,  0.2203,  0.1573, -0.1767,  0.1550, -0.0312,  0.1041,
         0.0828, -0.1363, -0.2719, -0.00

In [59]:
import torch

In [61]:
# NOTE(review): `embedding` is only created in the following cell, so this
# raises NameError on Restart & Run All — move it after that cell.
embedding.weight

Parameter containing:
tensor([[ 1.3606,  1.3935, -1.7506],
        [ 0.2580,  0.4514,  0.1433],
        [-0.3592,  1.6911, -0.7749],
        [ 0.6672,  1.1647, -0.3060],
        [-0.6404, -0.0430, -1.5256],
        [ 0.4141,  1.5914,  0.8152],
        [ 1.2896,  0.3880,  1.0068],
        [-0.5617, -1.1332, -0.9011],
        [-0.5368,  1.6581,  0.5620],
        [-0.5265,  0.1814,  0.6147]], requires_grad=True)

In [60]:
# Toy example: a 10 x 3 embedding applied to a (2, 4) batch of ids gives a (2, 4, 3) result.
embedding = nn.Embedding(10, 3)
# Named `idx` rather than `input` so the builtin `input` is not shadowed.
idx = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long)
embedding(idx)

tensor([[[ 0.2580,  0.4514,  0.1433],
         [-0.3592,  1.6911, -0.7749],
         [-0.6404, -0.0430, -1.5256],
         [ 0.4141,  1.5914,  0.8152]],

        [[-0.6404, -0.0430, -1.5256],
         [ 0.6672,  1.1647, -0.3060],
         [-0.3592,  1.6911, -0.7749],
         [-0.5265,  0.1814,  0.6147]]], grad_fn=<EmbeddingBackward>)

In [146]:
# Embedding a batch-major batch gives (batch, seq_len, 200).
emb_v(x).size()

torch.Size([16, 10, 200])

In [122]:
# Same dropout rate as the encoder's embedding dropout.
m = nn.Dropout(p=0.15)

In [123]:
# Apply the 15% dropout to the embedded batch. `m` is in training mode, so some
# entries are zeroed at random and the rest are scaled up.
m(emb_v(x))

tensor([[[ 0.0000, -0.4317,  0.0765,  ..., -0.3891, -0.1952,  0.2283],
         [ 0.1927, -0.5000, -0.0645,  ..., -0.4238, -0.2937, -0.0897],
         [ 0.0000, -0.0000, -0.1066,  ..., -0.0000, -0.3937,  0.2184],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.1649, -0.0000,  0.0993,  ..., -0.4007, -0.1468, -0.0016],
         [ 0.3522, -0.3787, -0.4070,  ..., -0.5614, -0.3115,  0.2759],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.3119, -0.1608, -0.0000,  ..., -0.3745, -0.2849,  0.2594],
         [ 0.0000, -0.3239,  0.0919,  ..., -0

In [125]:
# Stand-alone dropout demo: p=0.2 zeroes about 20% of the entries and scales the
# survivors by 1 / (1 - 0.2). `input` is kept as the name because a later cell displays it.
m = nn.Dropout(0.2)
input = torch.randn(20, 16)
output = m(input)

In [126]:
# The random tensor fed to the dropout demo.
input

tensor([[ 1.5967e+00, -8.7710e-01,  1.4251e-01, -1.7002e+00, -8.6332e-02,
         -2.5369e-01, -1.4916e+00, -1.7339e+00,  6.1279e-01,  1.4885e-01,
          4.2911e-02,  7.6687e-01,  6.9720e-01,  1.3379e+00,  1.2256e-01,
          1.4160e+00],
        [-2.1643e+00,  1.0280e+00,  2.0979e-01,  1.8836e+00, -1.7015e-01,
          4.7800e-01,  7.4575e-01, -7.6690e-01, -1.3111e+00, -2.5890e+00,
         -1.3054e+00,  1.5804e+00, -2.8454e-01, -9.0462e-01, -1.7558e+00,
          2.6978e-01],
        [ 1.0421e-01, -7.4674e-01, -4.3730e-01, -1.1868e-01, -8.8938e-01,
         -6.9591e-01, -7.8306e-01,  1.0312e+00,  7.1943e-01, -4.1194e-01,
          6.0410e-01, -1.0338e-01,  1.5226e+00, -8.4115e-01, -5.2396e-02,
         -9.1368e-01],
        [ 5.5439e-02,  9.0019e-01,  1.0773e+00, -2.3542e-01,  5.7414e-01,
         -7.8296e-02, -1.7377e+00,  4.4310e-02,  2.2713e-01,  2.8258e-01,
          1.4503e-02,  1.3255e-01, -2.9122e+00,  1.2389e+00,  1.9651e+00,
          3.7911e-02],
        [-1.2910e-02

In [127]:
# Dropout result: some zeros, survivors scaled by 1 / (1 - 0.2) = 1.25.
output

tensor([[ 1.9959, -1.0964,  0.1781, -2.1253, -0.1079, -0.3171, -1.8645, -0.0000,
          0.7660,  0.1861,  0.0536,  0.9586,  0.8715,  0.0000,  0.0000,  0.0000],
        [-2.7054,  1.2850,  0.0000,  2.3545, -0.2127,  0.5975,  0.9322, -0.9586,
         -1.6388, -3.2363, -1.6317,  1.9755, -0.3557, -1.1308, -2.1947,  0.0000],
        [ 0.1303, -0.0000, -0.5466, -0.1483, -1.1117, -0.8699, -0.0000,  1.2890,
          0.8993, -0.0000,  0.0000, -0.1292,  1.9032, -1.0514, -0.0655, -1.1421],
        [ 0.0693,  1.1252,  1.3466, -0.2943,  0.0000, -0.0000, -0.0000,  0.0554,
          0.2839,  0.3532,  0.0181,  0.1657, -3.6402,  0.0000,  2.4564,  0.0474],
        [-0.0161,  1.0443, -2.2881,  0.7965, -0.7746, -0.4523, -1.3798, -0.2262,
         -0.9427, -0.0410,  0.3112,  0.2259, -0.1014, -0.7689,  0.0588, -0.0000],
        [-2.1883,  0.3665,  0.1870, -1.2437,  0.7093, -1.8076, -0.7913, -1.2375,
          0.0000,  1.1431,  0.0000,  0.8015, -0.5070, -1.4495,  0.3598,  0.4552],
        [-0.0000, -0.4

In [121]:
# Fresh 15% dropout on the embedded batch. This is the same computation as
# `m(emb_v(x))` with m = nn.Dropout(0.15); the zero pattern differs per call.
nn.Dropout(0.15)(emb_v(x))

tensor([[[ 0.2943, -0.4317,  0.0765,  ..., -0.3891, -0.1952,  0.2283],
         [ 0.0000, -0.5000, -0.0645,  ..., -0.4238, -0.2937, -0.0897],
         [ 0.3742, -0.2417, -0.1066,  ..., -0.2509, -0.0000,  0.2184],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.1649, -0.2477,  0.0993,  ..., -0.4007, -0.1468, -0.0016],
         [ 0.3522, -0.3787, -0.4070,  ..., -0.5614, -0.3115,  0.2759],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.3119, -0.1608, -0.1879,  ..., -0.3745, -0.0000,  0.2594],
         [ 0.2467, -0.0000,  0.0919,  ..., -0

In [51]:
# (Source-viewer magic `??nn.Embedding` removed: a development leftover that opens
# a pager. See the PyTorch nn.Embedding documentation instead.)

In [52]:
# NOTE(review): `temp` is not defined in any visible cell (hidden kernel state),
# so this fails on Restart & Run All — define it explicitly or delete this cell.
len(temp)

7502