In [1]:
import sys
sys.path.append('../../')
import kitorch as mt
from kitorch import nn,optim
from kitorch import functional as F
import numpy as np
import torch

In [2]:
class RNN(nn.Module):
    """kitorch model: a stacked ``nn.RNN`` followed by a linear head that maps
    every hidden state to a single output value.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the recurrent hidden state.
        num_layers: number of stacked recurrent layers.
        nonlinearity: accepted for signature parity but currently unused —
            it is NOT passed on to ``nn.RNN``.
        dropout: forwarded to ``nn.RNN``.
    """
    def __init__(self,input_size,hidden_size,num_layers,nonlinearity='relu',dropout=0):
        super().__init__()
        recurrent_kwargs = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        # Register the recurrent core first, then the head, as before.
        self.rnn = nn.RNN(**recurrent_kwargs)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self,x,h0):
        """Run the RNN from hidden state ``h0``; return (per-step outputs, final hidden state)."""
        rnn_out, h_n = self.rnn(x, h0)
        return self.fc(rnn_out), h_n
    

class TrochRNN(torch.nn.Module):
    """PyTorch reference counterpart of ``RNN``: a stacked ``torch.nn.RNN``
    followed by a linear head mapping each hidden state to one output value.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the recurrent hidden state.
        num_layers: number of stacked recurrent layers.
        nonlinearity: accepted for signature parity with ``RNN`` but NOT
            forwarded, so ``torch.nn.RNN`` keeps its own default ('tanh').
            NOTE(review): forwarding it would switch this reference model to
            ReLU via the 'relu' default — decide deliberately before enabling.
        dropout: dropout probability between stacked recurrent layers,
            forwarded to ``torch.nn.RNN`` (only has an effect when
            num_layers > 1).
    """
    def __init__(self,input_size,hidden_size,num_layers,nonlinearity='relu',dropout=0):
        super(TrochRNN, self).__init__()

        # batch_first is left at its default (False), so inputs are laid out
        # as (seq_len, batch, input_size).
        self.rnn = torch.nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self,x,h0):
        """Run the RNN from hidden state ``h0``; return (per-step outputs, final hidden state)."""
        x, hn = self.rnn(x,h0)
        x = self.fc(x)
        return x,hn
    


In [3]:
import matplotlib.pyplot as plt
# Create some data: one full period of sin (model input) and cos (target).
# np.float64 replaces the removed `np.float` alias (it meant Python float,
# i.e. float64); `np.float` raises AttributeError on NumPy >= 1.24.
steps = np.linspace(0, np.pi*2, 100, dtype=np.float64)
x_np = np.sin(steps)
y_np = np.cos(steps)

# "Look at" the data
plt.plot(steps, y_np, 'r-', label='target(cos)')
plt.plot(steps, x_np, 'b-', label='input(sin)')
plt.xlabel('step')
plt.title('sin input vs. cos target')
plt.legend(loc='best')
plt.show()

<Figure size 640x480 with 1 Axes>

In [4]:
def normalize_weights(parameters, max_norm=1.0):
    """Rescale each parameter's gradient in place so its norm is at most ``max_norm``.

    Despite the name, this clips *gradients*, not weights, and it does so per
    parameter: each gradient array is rescaled independently using its own
    ``np.linalg.norm`` (L2 / Frobenius norm), not a global norm.

    Args:
        parameters: an iterable of groups, each an iterable of parameters that
            expose ``.grad.data`` as an array. NOTE(review): assumed to be the
            nested structure returned by kitorch ``Module.parameters()`` —
            confirm; if it were a one-shot generator it would be empty on reuse.
        max_norm: upper bound on each gradient's norm. Defaults to 1.0, the
            threshold that used to be hard-coded.
    """
    for group in parameters:
        for param in group:
            grad = param.grad
            # A parameter that received no gradient has nothing to clip.
            if grad is None:
                continue
            norm = np.linalg.norm(grad.data)
            if norm > max_norm:
                # norm > max_norm >= 0 here, so the division is safe.
                grad.data *= max_norm / norm

In [5]:
# Build the two models and their optimizers (the losses are computed inside
# the training loop).
lr = 0.001
num_layers = 2
dropout = 0
# Pass dropout by keyword: the 4th positional parameter of both model
# constructors is `nonlinearity`, so a positional `dropout` would land there.
model = RNN(1, 32, num_layers, dropout=dropout)
torch_model = TrochRNN(1, 32, num_layers, dropout=dropout)
# Kept as a module-level name: the training loop passes it to normalize_weights.
parameters = model.parameters()
optimizer = optim.RMSprop(parameters, lr=lr)
torch_optimizer = torch.optim.RMSprop(torch_model.parameters(), lr=lr)

In [6]:
# Switch matplotlib to the Qt5 window backend so the training loop below can
# redraw its plot with plt.pause() as it goes.
%matplotlib qt5

In [7]:
%%time
seq_len = 20
input_size = 1
# Initial hidden states, shaped (num_layers, batch=1, hidden_size); the 32
# must match the hidden size the models were constructed with.
h_state = mt.zeros(num_layers,1,32)
torch_h_state = torch.zeros(num_layers,1,32)

# Train both models side by side on successive windows of sin -> cos,
# drawing targets and predictions into the current figure every epoch.
for epoch in range(200):
    # Each epoch covers the next window [epoch*pi, (epoch+1)*pi]; the hidden
    # state carried over from the previous epoch continues the sequence.
    start, end = epoch * np.pi, (epoch+1)*np.pi

    steps = np.linspace(start, end, seq_len, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    # kitorch model; inputs shaped (seq_len, batch=1, feature=1).
    x = mt.from_numpy(x_np[:,np.newaxis,np.newaxis])
    y = mt.from_numpy(y_np[:,np.newaxis,np.newaxis])

    optimizer.zero_grad()
    prediction, h_state = model(x, h_state)
    # NOTE(review): presumably copy() cuts h_state off from this epoch's graph
    # (truncated backprop through time) — confirm against kitorch's Tensor.copy.
    h_state = h_state.copy()

    loss = F.mse_loss(prediction, y)
    loss.backward()
    # Per-parameter gradient clipping. NOTE(review): the torch model below is
    # trained without clipping, so the two printed losses are not like-for-like.
    normalize_weights(parameters)
    optimizer.step()

    # torch reference model on the same window.
    x = torch.from_numpy(x_np[:,np.newaxis,np.newaxis])
    y = torch.from_numpy(y_np[:,np.newaxis,np.newaxis])
    torch_optimizer.zero_grad()
    torch_prediction, torch_h_state = torch_model(x, torch_h_state)
    # Detach so the next epoch's backward() does not reach into this epoch's graph.
    torch_h_state = torch_h_state.detach()

    torch_loss = torch.nn.functional.mse_loss(torch_prediction, y)
    torch_loss.backward()
    torch_optimizer.step()

    if epoch%50 == 0:
        print(loss.data.item(), torch_loss.item())

    # red: target, blue: kitorch prediction, black: torch prediction
    plt.plot(steps,y_np.flatten(),'r-')
    plt.plot(steps,prediction.data.flatten(),'b-')
    plt.plot(steps,torch_prediction.detach().numpy().flatten(),'k-')
    plt.pause(0.005)
    # Clear the axes every 50 epochs so old curves do not pile up indefinitely.
    if epoch%50 == 0:
        plt.cla()

plt.show()

0.7564060192829054 0.5296083688735962
0.006452324976694046 0.05063014477491379
0.0038484363826662726 0.006987657863646746
0.01097653316088069 0.010676981881260872
CPU times: user 24.6 s, sys: 7.18 s, total: 31.8 s
Wall time: 7.57 s


In [8]:
def findNode(h,num=0,nums=None):
    """Collect the depth of every leaf reachable from ``h`` via ``depends_on``.

    Walks ``h.depends_on`` recursively (each entry's ``[0]`` element is the
    upstream node) and, for every path that ends at a node with an empty
    ``depends_on``, appends that path's depth to ``nums``. Shared sub-graphs
    are visited once per path, and recursion depth equals graph depth, so very
    deep graphs can hit Python's recursion limit.

    Args:
        h: node to start from; must expose ``depends_on``.
        num: depth assigned to ``h`` (0 for the root).
        nums: list to append depths to. A fresh list is created when omitted,
            so separate calls never share state.

    Returns:
        The ``nums`` list that was appended to.
    """
    if nums is None:
        nums = []
    if h.depends_on:
        for de in h.depends_on:
            findNode(de[0], num + 1, nums)
    else:
        nums.append(num)
    return nums


def check_nums(h):
    """Return the sum of all leaf depths that ``findNode`` finds from ``h``."""
    nums = []
    findNode(h,0,nums)
    return sum(nums)

In [9]:
# Sum of leaf depths (one entry per root-to-leaf path through `depends_on`)
# in the graph behind the final kitorch training loss; shown as the cell output.
check_nums(loss)

15