In [1]:
import numpy as np
from datetime import datetime

class SimpleClassificationTest:
    
    # constructor
    def __init__(self, xdata, tdata, learning_rate, iteration_count):
            
        # 가중치 W 형상을 자동으로 구하기 위해 입력데이터가 vector 인지,
        # 아니면 matrix 인지 체크 후, 
        # self.xdata 는 무조건 matrix 로 만들어 주면 코드 일관성이 있음
        
        if xdata.ndim == 1:    # vector
            self.xdata = xdata.reshape(len(xdata), 1)
            self.tdata = xdata.reshape(len(tdata), 1)
            
        elif xdata.ndim == 2:  # matrix
            self.xdata = xdata
            self.tdata = tdata
        
        self.learning_rate = learning_rate
        self.iteration_count = iteration_count
        
        self.W = np.random.rand(self.xdata.shape[1], 1) 
        self.b = np.random.rand(1)
        
        print("SimpleClassificationTest Object is created")
        
    
    def sigmoid(self, z):
        
        return 1 / (1+np.exp(-z))
        
    # obtain current W and current b
    def getW_b(self):
        
        return self.W, self.b
    
    
    # loss function
    def loss_func(self):
        
        delta = 1e-7    # log 무한대 발산 방지
    
        z = np.dot(self.xdata, self.W) + self.b
        
        y = self.sigmoid(z)
    
        # cross-entropy 
        return  -np.sum( self.tdata*np.log(y + delta) + (1-self.tdata)*np.log((1 - y)+delta ) ) 
        
    
    # display current error value
    def error_val(self):
        
        delta = 1e-7    # log 무한대 발산 방지
    
        z = np.dot(self.xdata, self.W) + self.b
        
        y = self.sigmoid(z)
    
        # cross-entropy 
        return  -np.sum( self.tdata*np.log(y + delta) + (1-self.tdata)*np.log((1 - y)+delta ) ) 
    
    
    # predict method
    # 학습을 마친 후, 임의의 데이터에 대해 미래 값 예측 함수
    # 입력변수 x : numpy type
    def predict(self, test_data):
    
        z = np.dot(test_data, self.W) + self.b
        y = self.sigmoid(z)
    
        if y >= 0.5:
            result = 1  # True
        else:
            result = 0  # False
    
        return y, result
    
    
    # train method
    def train(self):
    
        f = lambda x : self.loss_func()

        print("Initial error value = ", self.error_val(), "Initial W = ", self.W, "\n", ", b = ", self.b )

        start_time = datetime.now()
        
        for step in  range(self.iteration_count):  
    
            self.W -= self.learning_rate * numerical_derivative(f, self.W)
    
            self.b -= self.learning_rate * numerical_derivative(f, self.b)
    
            if (step % 400 == 0):
                print("step = ", step, "error value = ", self.error_val(), "W = ", self.W, ", b = ", self.b )
                
        end_time = datetime.now()
        
        print("")
        print("Elapsed Time => ", end_time - start_time)

In [2]:
def numerical_derivative(f, x):
    delta_x = 1e-4 # 0.0001
    grad = np.zeros_like(x)
    
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    
    while not it.finished:
        idx = it.multi_index        
        tmp_val = x[idx]
        x[idx] = float(tmp_val) + delta_x
        fx1 = f(x) # f(x+delta_x)
        
        x[idx] = tmp_val - delta_x 
        fx2 = f(x) # f(x-delta_x)
        grad[idx] = (fx1 - fx2) / (2*delta_x)
        
        x[idx] = tmp_val 
        it.iternext()   
        
    return grad

In [3]:
# 입력데이터 / 정답데이터 세팅

x_data = np.array([2, 4, 6, 8, 10, 12, 14, 16, 18, 20]).reshape(10,1)   
t_data = np.array([0, 0, 0, 0,  0,  0,  1,  1,  1,  1]).reshape(10,1)

print("x_data.shape = ", x_data.shape, ", t_data.shape = ", t_data.shape)

x_data.shape =  (10, 1) , t_data.shape =  (10, 1)


### learning_rate = 1e-2,  반복횟수 400,000번 수행하는 obj1

In [4]:
obj1 = SimpleClassificationTest(x_data, t_data, 1e-2, 400001)

obj1.train()

SimpleClassificationTest Object is created
Initial error value =  36.23483426767146 Initial W =  [[0.75199998]] 
 , b =  [0.75401]
step =  0 error value =  18.761631277458626 W =  [[0.33549503]] , b =  [0.69827297]
step =  400 error value =  3.1668184733126323 W =  [[0.43488061]] , b =  [-4.12273763]
step =  800 error value =  1.7824445734231773 W =  [[0.45368068]] , b =  [-5.64530274]
step =  1200 error value =  1.5171887751706592 W =  [[0.53094666]] , b =  [-6.67413799]
step =  1600 error value =  1.3519313766036356 W =  [[0.59216221]] , b =  [-7.48702699]
step =  2000 error value =  1.2355923077422748 W =  [[0.64365131]] , b =  [-8.1692636]
step =  2400 error value =  1.1475138682704906 W =  [[0.68853438]] , b =  [-8.76291421]
step =  2800 error value =  1.0775385579483516 W =  [[0.72859765]] , b =  [-9.29203777]
step =  3200 error value =  1.0200083102419315 W =  [[0.76496679]] , b =  [-9.77177997]
step =  3600 error value =  0.9714847561025112 W =  [[0.79840079]] , b =  [-10.21234

step =  38800 error value =  0.2946845413773341 W =  [[1.8717793]] , b =  [-24.24014858]
step =  39200 error value =  0.2927128122492346 W =  [[1.87858728]] , b =  [-24.32880066]
step =  39600 error value =  0.2907686833005664 W =  [[1.88534752]] , b =  [-24.41682951]
step =  40000 error value =  0.2888515358134268 W =  [[1.89206077]] , b =  [-24.50424481]
step =  40400 error value =  0.28696077071621223 W =  [[1.89872773]] , b =  [-24.591056]
step =  40800 error value =  0.28509580776817384 W =  [[1.90534912]] , b =  [-24.67727229]
step =  41200 error value =  0.2832560847858015 W =  [[1.91192562]] , b =  [-24.76290264]
step =  41600 error value =  0.28144105690839555 W =  [[1.9184579]] , b =  [-24.84795582]
step =  42000 error value =  0.2796501959005105 W =  [[1.92494661]] , b =  [-24.93244035]
step =  42400 error value =  0.27788298948919954 W =  [[1.93139238]] , b =  [-25.01636456]
step =  42800 error value =  0.2761389407337931 W =  [[1.93779585]] , b =  [-25.0997366]
step =  432

step =  77200 error value =  0.18002126464046322 W =  [[2.37529516]] , b =  [-30.79355752]
step =  77600 error value =  0.17929808148741097 W =  [[2.37942016]] , b =  [-30.8472253]
step =  78000 error value =  0.17858067697997224 W =  [[2.38352866]] , b =  [-30.9006781]
step =  78400 error value =  0.1778689810292061 W =  [[2.38762078]] , b =  [-30.95391765]
step =  78800 error value =  0.1771629247027952 W =  [[2.39169667]] , b =  [-31.0069457]
step =  79200 error value =  0.17646244020058188 W =  [[2.39575644]] , b =  [-31.05976396]
step =  79600 error value =  0.17576746083081732 W =  [[2.39980024]] , b =  [-31.11237412]
step =  80000 error value =  0.1750779209869653 W =  [[2.40382819]] , b =  [-31.16477785]
step =  80400 error value =  0.17439375612514585 W =  [[2.40784042]] , b =  [-31.21697679]
step =  80800 error value =  0.17371490274211915 W =  [[2.41183705]] , b =  [-31.26897259]
step =  81200 error value =  0.1730412983538481 W =  [[2.4158182]] , b =  [-31.32076684]
step = 

step =  116400 error value =  0.12894888769427582 W =  [[2.71691723]] , b =  [-35.23749449]
step =  116800 error value =  0.1285756143631497 W =  [[2.71988104]] , b =  [-35.2760437]
step =  117200 error value =  0.12820447201279647 W =  [[2.72283639]] , b =  [-35.31448266]
step =  117600 error value =  0.1278354424340746 W =  [[2.72578331]] , b =  [-35.352812]
step =  118000 error value =  0.1274685076261969 W =  [[2.72872186]] , b =  [-35.39103233]
step =  118400 error value =  0.12710364979374655 W =  [[2.73165208]] , b =  [-35.42914429]
step =  118800 error value =  0.1267408513436917 W =  [[2.73457402]] , b =  [-35.46714847]
step =  119200 error value =  0.12638009488250904 W =  [[2.73748773]] , b =  [-35.5050455]
step =  119600 error value =  0.12602136321332302 W =  [[2.74039325]] , b =  [-35.54283597]
step =  120000 error value =  0.1256646393331073 W =  [[2.74329063]] , b =  [-35.58052049]
step =  120400 error value =  0.12530990642993015 W =  [[2.74617991]] , b =  [-35.6180996

step =  152400 error value =  0.1021762148846427 W =  [[2.95452127]] , b =  [-38.32770878]
step =  152800 error value =  0.10194038727675185 W =  [[2.95687709]] , b =  [-38.3583461]
step =  153200 error value =  0.10170563203480788 W =  [[2.95922756]] , b =  [-38.38891365]
step =  153600 error value =  0.10147194188794362 W =  [[2.96157268]] , b =  [-38.41941175]
step =  154000 error value =  0.10123930963094828 W =  [[2.96391249]] , b =  [-38.44984071]
step =  154400 error value =  0.10100772812353176 W =  [[2.96624701]] , b =  [-38.48020085]
step =  154800 error value =  0.10077719028959016 W =  [[2.96857626]] , b =  [-38.51049246]
step =  155200 error value =  0.10054768911647967 W =  [[2.97090027]] , b =  [-38.54071586]
step =  155600 error value =  0.10031921765431455 W =  [[2.97321907]] , b =  [-38.57087135]
step =  156000 error value =  0.10009176901524462 W =  [[2.97553266]] , b =  [-38.60095924]
step =  156400 error value =  0.0998653363727791 W =  [[2.97784108]] , b =  [-38.6

step =  190000 error value =  0.08388744044394048 W =  [[3.15534572]] , b =  [-40.93930737]
step =  190400 error value =  0.08372761945806786 W =  [[3.1572851]] , b =  [-40.96452685]
step =  190800 error value =  0.08356839872533743 W =  [[3.15922084]] , b =  [-40.9896989]
step =  191200 error value =  0.08340977488833043 W =  [[3.16115294]] , b =  [-41.01482371]
step =  191600 error value =  0.08325174461458287 W =  [[3.16308143]] , b =  [-41.03990145]
step =  192000 error value =  0.08309430459635783 W =  [[3.16500631]] , b =  [-41.06493229]
step =  192400 error value =  0.08293745155040717 W =  [[3.1669276]] , b =  [-41.08991642]
step =  192800 error value =  0.08278118221775825 W =  [[3.16884531]] , b =  [-41.11485399]
step =  193200 error value =  0.08262549336348066 W =  [[3.17075945]] , b =  [-41.13974519]
step =  193600 error value =  0.08247038177646315 W =  [[3.17267005]] , b =  [-41.16459018]
step =  194000 error value =  0.08231584426920953 W =  [[3.1745771]] , b =  [-41.18

step =  230400 error value =  0.07030360882797146 W =  [[3.33480521]] , b =  [-43.27290973]
step =  230800 error value =  0.07019084057835545 W =  [[3.33643427]] , b =  [-43.29409268]
step =  231200 error value =  0.07007842929874386 W =  [[3.33806075]] , b =  [-43.31524206]
step =  231600 error value =  0.06996637330659024 W =  [[3.33968466]] , b =  [-43.33635799]
step =  232000 error value =  0.06985467092988032 W =  [[3.34130601]] , b =  [-43.35744056]
step =  232400 error value =  0.06974332050704962 W =  [[3.3429248]] , b =  [-43.37848988]
step =  232800 error value =  0.06963232038689789 W =  [[3.34454104]] , b =  [-43.39950606]
step =  233200 error value =  0.06952166892850636 W =  [[3.34615474]] , b =  [-43.42048919]
step =  233600 error value =  0.06941136450116532 W =  [[3.34776591]] , b =  [-43.44143938]
step =  234000 error value =  0.06930140548429015 W =  [[3.34937455]] , b =  [-43.46235674]
step =  234400 error value =  0.06919179026734468 W =  [[3.35098068]] , b =  [-43

step =  268000 error value =  0.06106632308534341 W =  [[3.4776382]] , b =  [-45.13015454]
step =  268400 error value =  0.060980952706962424 W =  [[3.47905561]] , b =  [-45.14858468]
step =  268800 error value =  0.06089581810403904 W =  [[3.48047107]] , b =  [-45.16698934]
step =  269200 error value =  0.06081091830680428 W =  [[3.48188456]] , b =  [-45.1853686]
step =  269600 error value =  0.06072625235078564 W =  [[3.48329611]] , b =  [-45.20372252]
step =  270000 error value =  0.06064181927676579 W =  [[3.48470572]] , b =  [-45.22205118]
step =  270400 error value =  0.060557618130752326 W =  [[3.48611339]] , b =  [-45.24035463]
step =  270800 error value =  0.06047364796393874 W =  [[3.48751913]] , b =  [-45.25863296]
step =  271200 error value =  0.06038990783267464 W =  [[3.48892294]] , b =  [-45.27688623]
step =  271600 error value =  0.06030639679842192 W =  [[3.49032483]] , b =  [-45.29511451]
step =  272000 error value =  0.06022311392772574 W =  [[3.49172481]] , b =  [-4

step =  305200 error value =  0.0540234684483603 W =  [[3.60171506]] , b =  [-46.74346445]
step =  305600 error value =  0.053956471752966036 W =  [[3.6029707]] , b =  [-46.75979075]
step =  306000 error value =  0.05388963935141374 W =  [[3.60422481]] , b =  [-46.77609701]
step =  306400 error value =  0.053822970643482605 W =  [[3.60547737]] , b =  [-46.79238329]
step =  306800 error value =  0.05375646503186407 W =  [[3.60672841]] , b =  [-46.80864963]
step =  307200 error value =  0.0536901219221442 W =  [[3.60797791]] , b =  [-46.82489607]
step =  307600 error value =  0.053623940722783106 W =  [[3.60922589]] , b =  [-46.84112268]
step =  308000 error value =  0.05355792084510045 W =  [[3.61047234]] , b =  [-46.85732949]
step =  308400 error value =  0.05349206170325295 W =  [[3.61171728]] , b =  [-46.87351656]
step =  308800 error value =  0.053426362714228175 W =  [[3.61296071]] , b =  [-46.88968393]
step =  309200 error value =  0.05336082329781367 W =  [[3.61420262]] , b =  [-

step =  346000 error value =  0.04794408280378517 W =  [[3.72244517]] , b =  [-48.31321984]
step =  346400 error value =  0.04789118667803979 W =  [[3.72356088]] , b =  [-48.32772628]
step =  346800 error value =  0.04783840605977629 W =  [[3.72467536]] , b =  [-48.34221686]
step =  347200 error value =  0.04778574057308087 W =  [[3.72578863]] , b =  [-48.35669164]
step =  347600 error value =  0.04773318984366148 W =  [[3.72690069]] , b =  [-48.37115063]
step =  348000 error value =  0.047680753498843036 W =  [[3.72801153]] , b =  [-48.38559387]
step =  348400 error value =  0.047628431167560145 W =  [[3.72912117]] , b =  [-48.4000214]
step =  348800 error value =  0.04757622248033853 W =  [[3.7302296]] , b =  [-48.41443324]
step =  349200 error value =  0.047524127069298504 W =  [[3.73133683]] , b =  [-48.42882944]
step =  349600 error value =  0.04747214456813807 W =  [[3.73244286]] , b =  [-48.44321003]
step =  350000 error value =  0.04742027461212956 W =  [[3.73354768]] , b =  [-

step =  386800 error value =  0.04308535448490789 W =  [[3.83038823]] , b =  [-49.71668772]
step =  387200 error value =  0.04304254966243411 W =  [[3.83139188]] , b =  [-49.72973696]
step =  387600 error value =  0.04299982907692771 W =  [[3.83239454]] , b =  [-49.74277336]
step =  388000 error value =  0.042957192481231136 W =  [[3.83339622]] , b =  [-49.75579694]
step =  388400 error value =  0.04291463962914733 W =  [[3.83439691]] , b =  [-49.76880771]
step =  388800 error value =  0.04287217027543828 W =  [[3.83539661]] , b =  [-49.78180571]
step =  389200 error value =  0.042829784175822574 W =  [[3.83639534]] , b =  [-49.79479097]
step =  389600 error value =  0.04278748108696145 W =  [[3.83739309]] , b =  [-49.80776349]
step =  390000 error value =  0.04274526076646375 W =  [[3.83838986]] , b =  [-49.82072332]
step =  390400 error value =  0.04270312297287925 W =  [[3.83938566]] , b =  [-49.83367048]
step =  390800 error value =  0.042661067465690546 W =  [[3.84038048]] , b =  

In [5]:
test_data = np.array([3.7])

(real_val, logical_val) = obj1.predict(test_data)

print(real_val, logical_val)

[2.70144978e-16] 0


In [6]:
test_data = np.array([31.09])

(real_val, logical_val) = obj1.predict(test_data)

print(real_val, logical_val)

[1.] 1
