In [41]:
import torch
import torch.nn as nn
import torch.optim
from torch.autograd import Variable
from collections import Counter
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np

%matplotlib inline

In [42]:
samples = 2000
sz = 10
probability = 1.0 * np.array([10,6,4,3,1,1,1,1,1,1])
probability = probability[ : sz]
probability = probability/sum(probability)

In [43]:
train_set = []

for m in range(samples):
    n = np.random.choice(range(1,sz+1),p = probability)
    inputs =[0]* n + [1]*n
    inputs.insert(0,3)
    inputs.append(2)
    train_set.append(inputs)

In [44]:
valid_set = []

for m in range(samples//10):
    n = np.random.choice(range(1,sz+1),p = probability)
    inputs =[0]* n + [1]*n
    inputs.insert(0,3)
    inputs.append(2)
    valid_set.append(inputs)

In [45]:
for m in range(2):
    n = sz+m
    inputs =[0]* n + [1]*n
    inputs.insert(0,3)
    inputs.append(2)
    valid_set.append(inputs)

np.random.shuffle(valid_set)

In [46]:
class SimpleLSTM(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,num_layers=1):
        super(SimpleLSTM,self).__init__()
        self.hidden_size = hidden_size 
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size,hidden_size)
        self.lstm = nn.LSTM(hidden_size,hidden_size,num_layers,batch_first=True)
        self.fc = nn.Linear(hidden_size,output_size)
        self.softmax = nn.LogSoftmax(dim = 1)
    
    def forward(self,input,hidden):
        x = self.embedding(input)
        output,hidden =self.lstm(x,hidden)
        output= output[:,-1,:]
        output =self.fc(output)
        output= self.softmax(output)
        return output,hidden
    
    def initHidden(self):
        hidden = Variable(torch.zeros(self.num_layers,1,self.hidden_size))
        cell = Variable(torch.zeros(self.num_layers,1,self.hidden_size))
        return (hidden,cell)
    

In [47]:
lstm = SimpleLSTM(input_size=4,hidden_size=2,num_layers=1,output_size=3)
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(lstm.parameters(),lr = 0.001)

In [48]:
train_loss = 0

def trainLSTM(epoch):
    global train_loss
    train_loss =0
    np.random.shuffle(train_set)

    for i,seq in enumerate(train_set):
        loss = 0
        hidden = lstm.initHidden()
        for t in range(len(seq)-1):
            x=Variable(torch.LongTensor([seq[t]]).unsqueeze(0))
            y =Variable(torch.LongTensor([seq[t+1]]))
            output, hidden = lstm(x, hidden)
            loss+=criterion(output,y)
        loss=1.0*loss/len(seq)
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        train_loss+=loss
        if 1>0 and i%500 ==0:
            print('第{}轮，第{}个,训练Loss:{:.2f}'.format(epoch,
                                                   i,
                                                   train_loss.data.numpy()/(i+1)
                                                   ))

In [49]:
valid_loss = 0
errors = 0
show_out = ''

def evaluateLSTM():
    global valid_loss
    global errors
    global show_out
    valid_loss = 0
    errors = 0
    show_out = ''
    for i ,seq in enumerate(valid_set):
        loss =0
        outstring = ''
        targets = ''
        diff = 0
        hidden = lstm.initHidden()
        for t in range(len(seq)-1):
            x = Variable(torch.LongTensor([seq[t]]).unsqueeze(0))
            y =Variable(torch.LongTensor([seq[t+1]]))
            output,hidden=lstm(x,hidden)
            mm = torch.max(output,1)[1][0]
            outstring +=str(mm.data.numpy())
            targets+=str(y.data.numpy()[0])
            loss+=criterion(output,y)
            diff +=1-mm.eq(y).data.numpy()[0]
        loss = 1.0*loss/len(seq)
        valid_loss+=loss
        errors+=diff
        if np.random.rand()<0.1:
            show_out= outstring+'\n'+targets
    print(output[0][2].data.numpy())




In [50]:
num_epoch = 20
results = []
for epoch in range(num_epoch):
    trainLSTM(epoch)

    evaluateLSTM()
    print('第{}轮,训练Loss:{:.2f},错误率:{:.2f}'.format(epoch,
                                                 train_loss.data.numpy()/len(train_set),
                                                 valid_loss.data.numpy()/len(valid_set),
                                                 1.0*errors/len(valid_set)
                                                 ))
    print(show_out)
    results.append([train_loss.data.numpy()/len(train_set),
                    valid_loss.data.numpy()/len(train_set),
                    1.0*errors/len(valid_set)
                    ])

第0轮，第0个,训练Loss:1.34
第0轮，第500个,训练Loss:0.97
第0轮，第1000个,训练Loss:0.91
第0轮，第1500个,训练Loss:0.86
-0.8512011
第0轮,训练Loss:0.80,错误率:0.59
00111111111111102
00000000111111112
第1轮，第0个,训练Loss:0.66
第1轮，第500个,训练Loss:0.54
第1轮，第1000个,训练Loss:0.50
第1轮，第1500个,训练Loss:0.47
-0.22808187
第1轮,训练Loss:0.45,错误率:0.37
002
012
第2轮，第0个,训练Loss:0.38
第2轮，第500个,训练Loss:0.36
第2轮，第1000个,训练Loss:0.35
第2轮，第1500个,训练Loss:0.34
-0.107153565
第2轮,训练Loss:0.33,错误率:0.30
002
012
第3轮，第0个,训练Loss:0.33
第3轮，第500个,训练Loss:0.30
第3轮，第1000个,训练Loss:0.29
第3轮，第1500个,训练Loss:0.29
-0.05522935
第3轮,训练Loss:0.29,错误率:0.27
002
012
第4轮，第0个,训练Loss:0.25
第4轮，第500个,训练Loss:0.27
第4轮，第1000个,训练Loss:0.27
第4轮，第1500个,训练Loss:0.27
-0.03312106
第4轮,训练Loss:0.26,错误率:0.25
00012
00112
第5轮，第0个,训练Loss:0.28
第5轮，第500个,训练Loss:0.25
第5轮，第1000个,训练Loss:0.25
第5轮，第1500个,训练Loss:0.25
-0.016601278
第5轮,训练Loss:0.25,错误率:0.24
0000000111112
0000001111112
第6轮，第0个,训练Loss:0.27
第6轮，第500个,训练Loss:0.25
第6轮，第1000个,训练Loss:0.25
第6轮，第1500个,训练Loss:0.25
-0.011795536
第6轮,训练Loss:0.25,错误率:0.24
012
012
第7轮，第0个,训练Loss:

In [51]:
torch.save(lstm,'lstm.mdl')
lstm=torch.load('lstm.mdl')

In [53]:
for n in range(20):
    inputs = [0]*n+[1]*n
    inputs.insert(0,3)
    inputs.append(2)
    outstring = ''
    targets = ''
    diff =0
    hiddens = []
    hidden = lstm.initHidden()
    for t in range(len(inputs)-1):
        x = Variable(torch.LongTensor([inputs[t]]).unsqueeze(0))
        y =Variable(torch.LongTensor([inputs[t+1]]))
        output,hidden=lstm(x,hidden)
        mm = torch.max(output,1)[1][0]
        outstring +=str(mm.data.numpy())
        targets+=str(y.data.numpy()[0])
        
        diff +=1-mm.eq(y).data.numpy()[0]
    print(n)
    print(outstring)
    print(targets)
    print('Diff:{}'.format(diff))

0
0
2
Diff:1
1
012
012
Diff:0
2
01012
00112
Diff:2
3
0100112
0001112
Diff:2
4
010001112
000011112
Diff:2
5
01000011112
00000111112
Diff:2
6
0100000111112
0000001111112
Diff:2
7
010000001111112
000000011111112
Diff:2
8
01000000011111112
00000000111111112
Diff:2
9
0100000000111111112
0000000001111111112
Diff:2
10
010000000001111111112
000000000011111111112
Diff:2
11
01000000000011111111112
00000000000111111111112
Diff:2
12
0100000000000111111111112
0000000000001111111111112
Diff:2
13
010000000000001111111111112
000000000000011111111111112
Diff:2
14
01000000000000011111111111112
00000000000000111111111111112
Diff:2
15
0100000000000000111111111111112
0000000000000001111111111111112
Diff:2
16
010000000000000001111111111111112
000000000000000011111111111111112
Diff:2
17
01000000000000000011111111111111112
00000000000000000111111111111111112
Diff:2
18
0100000000000000000111111111111111112
0000000000000000001111111111111111112
Diff:2
19
010000000000000000001111111111111111112
00000000000000000