In [1]:
from math import tanh
import sqlite3

In [17]:
def dtanh(y):
    """Return the slope of tanh expressed in terms of its output value.

    Args:
        y: An activation that was already produced by ``tanh``.

    Returns:
        ``1.0 - y**2``, i.e. the derivative of tanh at the point whose
        tanh value is ``y``.
    """
    squared = y ** 2
    slope = 1.0 - squared
    return slope

class searchnet:
    """Small word -> hidden -> URL neural network whose weights live in SQLite.

    The database holds three tables:

    * ``hiddennode(create_key)``          - one row per hidden node
    * ``wordhidden(fromid, toid, strength)`` - word -> hidden weights (layer 0)
    * ``hiddenurl(fromid, toid, strength)``  - hidden -> URL weights (layer 1)

    A working network for one query is loaded into memory by
    ``setupnetwork`` and written back by ``updatedatabase``.
    """

    def __init__(self, dbname):
        """Open a connection to the SQLite database ``dbname``."""
        self.con = sqlite3.connect(dbname)

    def __del__(self):
        """Close the connection, if ``__init__`` managed to open one."""
        con = getattr(self, 'con', None)
        if con is not None:
            con.close()

    def maketables(self):
        """Create the three network tables if they are not already present.

        ``if not exists`` keeps this safe to call again on a database file
        that already holds the tables (e.g. when re-running the notebook).
        """
        self.con.execute('create table if not exists hiddennode(create_key)')
        self.con.execute(
            'create table if not exists wordhidden(fromid , toid , strength)')
        self.con.execute(
            'create table if not exists hiddenurl(fromid , toid , strength)')
        self.con.commit()

    @staticmethod
    def _tablename(layer):
        """Map a layer index to its table: 0 -> wordhidden, else hiddenurl."""
        return 'wordhidden' if layer == 0 else 'hiddenurl'

    def getstrength(self, fromid, toid, layer):
        """Look up the weight of the connection ``fromid -> toid``.

        Args:
            fromid: Id of the source node.
            toid: Id of the destination node.
            layer: 0 for word -> hidden, anything else for hidden -> URL.

        Returns:
            The stored strength, or a default when no row exists:
            ``-0.2`` for layer 0 and ``0`` for every other layer.
        """
        # The table name comes from a fixed two-item set, so interpolating it
        # is safe; all values are bound as parameters.
        table = self._tablename(layer)
        res = self.con.execute(
            'select strength from %s where fromid=? and toid=?' % table,
            (fromid, toid)).fetchone()

        if res is None:
            # Any non-zero layer reads from hiddenurl, so it shares that default
            # instead of falling through to ``res[0]`` on None.
            return -0.2 if layer == 0 else 0

        return res[0]

    def setstrength(self, fromid, toid, layer, strength):
        """Insert or update the weight of the connection ``fromid -> toid``.

        Args:
            fromid: Id of the source node.
            toid: Id of the destination node.
            layer: 0 for word -> hidden, anything else for hidden -> URL.
            strength: New weight; stored as a float at full precision.

        The change is not committed here; callers commit.
        """
        table = self._tablename(layer)
        # Bind as a real number so the stored value keeps full precision
        # (string formatting with %f would truncate it to six decimals).
        strength = float(strength)

        res = self.con.execute(
            'select rowid from %s where fromid=? and toid=?' % table,
            (fromid, toid)).fetchone()

        if res is None:
            self.con.execute(
                'insert into %s (fromid , toid , strength) values (?,?,?)' % table,
                (fromid, toid, strength))
        else:
            self.con.execute(
                'update %s set strength=? where rowid=?' % table,
                (strength, res[0]))

    def generatehiddennode(self, wordids, urls):
        """Create a hidden node for this word combination if none exists yet.

        Args:
            wordids: Word ids of the query; combinations of more than three
                words are ignored.
            urls: URL ids the new hidden node is connected to.

        Returns:
            None. When a node is created, each word is linked to it with
            strength ``1.0 / len(wordids)`` and it is linked to each URL with
            strength ``0.1``; the changes are committed.
        """
        if len(wordids) > 3:
            return None

        # The key is order-independent: the sorted word ids joined by '_'.
        createkey = '_'.join(sorted([str(wi) for wi in wordids]))

        res = self.con.execute(
            'select rowid from hiddennode where create_key=?',
            (createkey,)).fetchone()

        if res is None:
            cur = self.con.execute(
                'insert into hiddennode (create_key) values (?)', (createkey,))
            hiddenid = cur.lastrowid

            for wordid in wordids:
                self.setstrength(wordid, hiddenid, 0, 1.0 / len(wordids))

            for urlid in urls:
                self.setstrength(hiddenid, urlid, 1, 0.1)

            self.con.commit()

    def getallhiddenids(self, wordids, urlids):
        """Collect the hidden nodes connected to any given word or URL.

        Returns:
            A list of distinct hidden node ids. A real list (not a dict view)
            is returned because callers index into it by position.
        """
        found = {}

        for wordid in wordids:
            cur = self.con.execute(
                'select toid from wordhidden where fromid=?', (wordid,))
            for row in cur:
                found[row[0]] = 1

        for urlid in urlids:
            cur = self.con.execute(
                'select fromid from hiddenurl where toid =?', (urlid,))
            for row in cur:
                found[row[0]] = 1

        return list(found)

    def setupnetwork(self, wordids, urlids):
        """Load the sub-network relevant to one query into memory.

        Sets ``wordids``, ``hiddenids``, ``urlids``, the activation lists
        ``ai`` / ``ah`` / ``ao`` (all initialised to 1.0) and the weight
        matrices ``wi`` (word x hidden) and ``wo`` (hidden x URL).
        """
        self.wordids = wordids
        self.hiddenids = self.getallhiddenids(wordids, urlids)
        self.urlids = urlids

        self.ai = [1.0] * len(self.wordids)
        self.ah = [1.0] * len(self.hiddenids)
        self.ao = [1.0] * len(self.urlids)

        self.wi = [[self.getstrength(wordid, hiddenid, 0)
                    for hiddenid in self.hiddenids]
                   for wordid in self.wordids]

        self.wo = [[self.getstrength(hiddenid, urlid, 1)
                    for urlid in self.urlids]
                   for hiddenid in self.hiddenids]

    def feedforward(self):
        """Run the forward pass on the network loaded by ``setupnetwork``.

        Returns:
            A copy of the output activations, one per entry of ``urlids``.
        """
        # Every query word is an active input.
        for i in range(len(self.wordids)):
            self.ai[i] = 1.0

        for j in range(len(self.hiddenids)):
            total = sum((self.ai[i] * self.wi[i][j]
                         for i in range(len(self.wordids))), 0.0)
            self.ah[j] = tanh(total)

        for k in range(len(self.urlids)):
            total = sum((self.ah[j] * self.wo[j][k]
                         for j in range(len(self.hiddenids))), 0.0)
            self.ao[k] = tanh(total)

        return self.ao[:]

    def backPropagate(self, targets, N=0.5):
        """Adjust the in-memory weights towards ``targets``.

        Args:
            targets: Desired output per URL, aligned with ``urlids``.
            N: Learning rate applied to every weight change.

        Must be called after ``feedforward`` so ``ah`` and ``ao`` are current.
        """
        # Output-layer error terms.
        output_deltas = [0.0] * len(self.urlids)
        for k in range(len(self.urlids)):
            error = targets[k] - self.ao[k]
            output_deltas[k] = dtanh(self.ao[k]) * error

        # Hidden-layer error terms, computed from the *old* output weights.
        hidden_deltas = [0.0] * len(self.hiddenids)
        for j in range(len(self.hiddenids)):
            error = 0.0
            for k in range(len(self.urlids)):
                error = error + output_deltas[k] * self.wo[j][k]
            hidden_deltas[j] = dtanh(self.ah[j]) * error

        # Update hidden -> URL weights.
        for j in range(len(self.hiddenids)):
            for k in range(len(self.urlids)):
                change = output_deltas[k] * self.ah[j]
                self.wo[j][k] = self.wo[j][k] + N * change

        # Update word -> hidden weights.
        for i in range(len(self.wordids)):
            for j in range(len(self.hiddenids)):
                change = hidden_deltas[j] * self.ai[i]
                self.wi[i][j] = self.wi[i][j] + N * change

    def trainquery(self, wordids, urlids, selectedurl):
        """Build the network for one query and train it on one click.

        Args:
            wordids: Word ids of the query.
            urlids: Candidate URL ids that were shown.
            selectedurl: The URL id that was chosen; must be in ``urlids``
                (``list.index`` raises ``ValueError`` otherwise).
        """
        self.generatehiddennode(wordids, urlids)

        self.setupnetwork(wordids, urlids)
        self.feedforward()

        # Target is 1.0 for the selected URL and 0.0 for all others.
        targets = [0.0] * len(urlids)
        targets[urlids.index(selectedurl)] = 1.0

        self.backPropagate(targets)
        self.updatedatabase()

    def updatedatabase(self):
        """Write the in-memory weight matrices back to SQLite and commit."""
        for i, wordid in enumerate(self.wordids):
            for j, hiddenid in enumerate(self.hiddenids):
                self.setstrength(wordid, hiddenid, 0, self.wi[i][j])

        for j, hiddenid in enumerate(self.hiddenids):
            for k, urlid in enumerate(self.urlids):
                self.setstrength(hiddenid, urlid, 1, self.wo[j][k])

        self.con.commit()

    def getresult(self, wordids, urlids):
        """Load the network for a query and return the URL activations."""
        self.setupnetwork(wordids, urlids)

        return self.feedforward()

In [5]:
# Open the SQLite database file 'nn.db' that stores the network weights.
mynet = searchnet('nn.db')

In [6]:
# Create the hiddennode, wordhidden and hiddenurl tables in nn.db.
mynet.maketables()

In [7]:
# Example ids used throughout the notebook: three words and three URLs.
wWorld , wRiver , wBank = 101,102,103
uWorldBank , uRiver , uEarth = 201,202,203

# Create the hidden node for the word pair (world, bank) and link it to all three URLs.
mynet.generatehiddennode([wWorld , wBank],
                         [uWorldBank , uRiver , uEarth])


In [8]:
# Show every word -> hidden connection as (fromid, toid, strength).
for c in mynet.con.execute('select * from wordhidden'):
    print(c)

(101, 1, 0.5)
(103, 1, 0.5)


In [9]:
# Show every hidden -> URL connection as (fromid, toid, strength).
for c in mynet.con.execute('select * from hiddenurl'):
    print(c)

(1, 201, 0.1)
(1, 202, 0.1)
(1, 203, 0.1)


In [13]:
# Re-create the searchnet object on the same database file (weights are read from nn.db).
mynet = searchnet('nn.db')

In [14]:
# Feed the (world, bank) query forward and show one activation per candidate URL.
mynet.getresult([wWorld , wBank] , [uWorldBank , uRiver , uEarth])

[0.07601250837541615, 0.07601250837541615, 0.07601250837541615]

In [18]:
# Re-create the searchnet object again on the same database file.
mynet = searchnet('nn.db')

In [19]:
# Train on a single click: query (world, bank) with uWorldBank selected,
# then run the same query again to see the updated URL activations.
mynet.trainquery([wWorld , wBank] , [uWorldBank , uRiver , uEarth] , uWorldBank)
mynet.getresult([wWorld , wBank] , [uWorldBank , uRiver , uEarth])

[0.3350632467125331, 0.055127057492087995, 0.055127057492087995]

In [20]:
# All candidate URLs, reused by the training and query cells below.
allurls = [uWorldBank , uRiver , uEarth]

# Repeat three training examples 30 times each:
# (world, bank) -> uWorldBank, (river, bank) -> uRiver, (world) -> uEarth.
for i in range(30):
    mynet.trainquery([wWorld , wBank] , allurls , uWorldBank)
    mynet.trainquery([wRiver , wBank] , allurls , uRiver)
    mynet.trainquery([wWorld] , allurls , uEarth)

In [21]:
# Query (world, bank) after training; this pair was trained towards uWorldBank.
mynet.getresult([wWorld , wBank] , allurls)

[0.861547977173944, 0.01107121517146442, 0.015725794221216588]

In [22]:
# Query (river, bank) after training; this pair was trained towards uRiver.
mynet.getresult([wRiver , wBank] , allurls)

[-0.030344006191459796, 0.8829814980448912, 0.005509562270886237]

In [23]:
# Query the single word (bank), which was never used on its own in the training cells above.
mynet.getresult([wBank] , allurls)

[0.8654047612070324, -0.0006785911691591055, -0.8519156725080675]