-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%matplotlib inline\n", | ||
"import numpy as np\n", | ||
"from numpy.random import *\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from sklearn.datasets import make_blobs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"class NN:\n", | ||
" def __init__(self, num_input, num_hidden, num_output, learning_rate):\n", | ||
" self.num_input = num_input\n", | ||
" self.num_hidden = num_hidden\n", | ||
" self.num_output = num_output\n", | ||
" self.learning_rate = learning_rate\n", | ||
"\n", | ||
" self.w_input2hidden = np.random.random((self.num_hidden, self.num_input))\n", | ||
" self.w_hidden2output = np.random.random((self.num_output, self.num_hidden))\n", | ||
" self.b_input2hidden = np.ones((self.num_hidden))\n", | ||
" self.b_hidden2output = np.ones((self.num_output))\n", | ||
"\n", | ||
" ##活性化関数(シグモイド関数)\n", | ||
" def activate_func(self, x):\n", | ||
" return 1/(1+np.exp(-x))\n", | ||
" ##活性化関数の微分\n", | ||
" def dactivate_func(self,x):\n", | ||
" return self.activate_func(x)*(1-self.activate_func(x))\n", | ||
" ##ソフトマックス関数\n", | ||
" def softmax_func(self,x):\n", | ||
" C = x.max()\n", | ||
" f = np.exp(x-C)/np.exp(x-C).sum()\n", | ||
" return f\n", | ||
" ##順伝播計算_\n", | ||
" def forward_propagation(self, x):\n", | ||
" u_hidden = np.dot(self.w_input2hidden, x) + self.b_input2hidden\n", | ||
" z_hidden = self.activate_func(u_hidden)\n", | ||
" u_output = np.dot(self.w_hidden2output, z_hidden) + self.b_hidden2output\n", | ||
" z_output = self.softmax_func(u_output)\n", | ||
" return u_hidden, u_output, z_hidden, z_output\n", | ||
" ##逆伝播でδを求める\n", | ||
" def backward_propagation(self,t,u_hidden,z_output):\n", | ||
" t_vec = np.zeros(len(z_output))\n", | ||
" t_vec[t] = 1\n", | ||
" delta_output = z_output - t_vec\n", | ||
" delta_hidden = np.dot(delta_output, self.w_hidden2output * self.dactivate_func(u_hidden))\n", | ||
" return delta_hidden, delta_output\n", | ||
" ##損失関数のパラメータwに関する勾配\n", | ||
" def calc_gradient(self,delta,z):\n", | ||
" dW = np.zeros((len(delta), len(z)))\n", | ||
" for i in range(len(delta)):\n", | ||
" for j in range(len(z)):\n", | ||
" dW[i][j] = delta[i] * z[j]\n", | ||
" return dW\n", | ||
" ##重みをアップデートする\n", | ||
" def update_weight(self,w0,gradE):\n", | ||
" return w0 - self.learning_rate*gradE" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false, | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def train(nn, iteration,savefig=False):\n", | ||
" epoch = 0\n", | ||
" for epoch in range(iteration+1):\n", | ||
" grad_i2h = 0\n", | ||
" grad_h2o = 0\n", | ||
" gradbias_i2h = 0\n", | ||
" gradbias_h2o = 0\n", | ||
" n = 0\n", | ||
" err = 0\n", | ||
" rand = randint(0,len(data),100)\n", | ||
" for r in rand:\n", | ||
" u_hidden, u_output, z_hidden, z_output = nn.forward_propagation(data[r])\n", | ||
" delta_hidden, delta_output = nn.backward_propagation(target[r], u_hidden, z_output)\n", | ||
" grad_i2h += nn.calc_gradient(delta_hidden, data[r])\n", | ||
" grad_h2o += nn.calc_gradient(delta_output, z_hidden)\n", | ||
" gradbias_i2h += delta_hidden\n", | ||
" gradbias_h2o += delta_output\n", | ||
" nn.w_input2hidden = nn.update_weight(nn.w_input2hidden, grad_i2h / len(rand))\n", | ||
" nn.w_hidden2output = nn.update_weight(nn.w_hidden2output, grad_h2o / len(rand))\n", | ||
" nn.b_input2hidden = nn.update_weight(nn.b_input2hidden, gradbias_i2h / len(rand))\n", | ||
" nn.b_hidden2output = nn.update_weight(nn.b_hidden2output, gradbias_h2o / len(rand))\n", | ||
" #print grad_h2o\n", | ||
" if epoch%10 == 0:\n", | ||
" print(epoch,\",\",end=\"\")\n", | ||
" if savefig:\n", | ||
" plt.figure(figsize=(5,5))\n", | ||
" #色の用意\n", | ||
" base_color = [\"red\",\"blue\",\"green\",\"yellow\",\"cyan\",\n", | ||
" \"pink\",\"brown\",\"gray\",\"purple\",\"orange\"]\n", | ||
" colors = [base_color[label] for label in target]\n", | ||
" # 教師データのプロット\n", | ||
" plt.scatter(data[:,0],data[:,1],color=colors,alpha=0.5)\n", | ||
"\n", | ||
" print(\"plotting...\")\n", | ||
" xx = np.linspace(-15,15,80)\n", | ||
" yy = np.linspace(-15,15,80)\n", | ||
" for xi in xx:\n", | ||
" for yi in yy:\n", | ||
" _,_,_,z_output = nn.forward_propagation((xi, yi))\n", | ||
" cls = np.argmax(z_output) #softmaxのスコアの最大のインデックス\n", | ||
" score = np.max(z_output)\n", | ||
" plt.plot(xi, yi, base_color[cls],marker=\"x\",alpha=score)\n", | ||
" print(\".\",end=\"\")\n", | ||
" plt.xlim(-15,15)\n", | ||
" plt.ylim(-15,15)\n", | ||
" plt.savefig(\"{}.png\".format(epoch))\n", | ||
" plt.clf()\n", | ||
" print(\"finish plotting\")\n", | ||
" \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#データの用意\n", | ||
"num_cls = 5\n", | ||
"data,target = make_blobs(n_samples=1000, n_features=2, centers=num_cls)\n", | ||
"#色の用意\n", | ||
"base_color = [\"red\",\"blue\",\"green\",\"yellow\",\"cyan\",\n", | ||
" \"pink\",\"brown\",\"gray\",\"purple\",\"orange\"]\n", | ||
"colors = [base_color[label] for label in target]\n", | ||
"# 教師データのプロット\n", | ||
"plt.scatter(data[:,0],data[:,1],color=colors,alpha=0.5)\n", | ||
"plt.savefig(\"original.png\")\n", | ||
"# Neural Netの用意\n", | ||
"nn = NN(num_input=2,num_hidden=20,num_output=num_cls,learning_rate=0.2)\n", | ||
"# 学習\n", | ||
"train(nn, iteration=300, savefig=False)\n", | ||
"\n", | ||
"plt.figure(figsize=(5,5))\n", | ||
"\n", | ||
"print(\"plotting...\")\n", | ||
"xx = np.linspace(-15,15,70)\n", | ||
"yy = np.linspace(-15,15,70)\n", | ||
"for xi in xx:\n", | ||
" for yi in yy:\n", | ||
" _,_,_,z_output = nn.forward_propagation((xi, yi))\n", | ||
" cls = np.argmax(z_output) #softmaxのスコアの最大のインデックス\n", | ||
" score = np.max(z_output)\n", | ||
" plt.plot(xi, yi, base_color[cls],marker=\"x\",alpha=score)\n", | ||
" print(\".\",end=\"\")\n", | ||
"plt.xlim(-15,15)\n", | ||
"plt.ylim(-15,15)\n", | ||
"plt.show()\n", | ||
"print(\"finish plotting\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"anaconda-cloud": {}, | ||
"kernelspec": { | ||
"display_name": "Python [Root]", | ||
"language": "python", | ||
"name": "Python [Root]" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.5.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |