In [1]:
import pandas as pd
import numpy as np
from pandas import DataFrame,Series
from sklearn import datasets

In [2]:
# Load the iris dataset bundled with scikit-learn, then expose the feature
# matrix and the class-label vector under the names used by the cells below.
iris = datasets.load_iris()
data, target = iris.data, iris.target

In [41]:
# Stand-alone random parameter matrix in the (n_features, n_classes) layout
# that softmax() expects: 4 features x 3 classes.
# NOTE(review): softmax_regression() draws its own initial parameters, so this
# array looks like it is only meant for manual experimentation - confirm.
params = np.random.standard_normal((4, 3))

In [3]:
# Softmax function
def softmax(data,params):
    """Return the row-wise softmax class probabilities of a linear model.

    Parameters
    ----------
    data : array-like of shape (n_samples, n_features)
        Data points, one per row.
    params : array-like of shape (n_features, n_classes)
        Parameter matrix. With 4 variables and 3 classes this is a 4x3 matrix.

    Returns
    -------
    numpy.ndarray of shape (n_samples, n_classes)
        Each row holds the class probabilities of one sample and sums to 1.
    """
    # Linear score of every sample for every class.
    score = np.asarray(data).dot(params)
    # Softmax is invariant to adding a constant to a row, so subtract the row
    # maximum first. The largest exponent is then exp(0) == 1, which prevents
    # the overflow (inf / inf -> NaN) that np.exp hits on large raw scores.
    score = score - np.max(score, axis=1, keepdims=True)
    exp_score = np.exp(score)
    # Normalise each row by its sum; keepdims keeps the sum as a column so it
    # broadcasts across the class axis.
    return exp_score / np.sum(exp_score, axis=1, keepdims=True)

In [4]:
# Gradient-descent step

def gradient_descent(data,params,score,target,learning_rate,class_num):
    """Perform one gradient-descent update of the softmax-regression parameters.

    Parameters
    ----------
    data : numpy.ndarray of shape (n_samples, n_features)
    params : numpy.ndarray of shape (n_features, n_classes)
        Current parameters. Not modified; the updated matrix is returned.
    score : array-like of shape (n_samples, n_classes)
        Predicted class probabilities for `data`. Not modified.
    target : array-like of shape (n_samples,)
        Class labels, expected to be the integers 0 .. class_num - 1.
    learning_rate : float
        Step size applied to the gradient.
    class_num : int
        Number of classes.

    Returns
    -------
    (params, gradient_vec) : tuple of numpy.ndarray
        The updated parameters and the gradient that was applied, both of
        shape (n_features, n_classes).
    """
    target = np.asarray(target)
    # One-hot encode the labels: column k is 1 where the sample belongs to
    # class k and 0 everywhere else.
    one_hot = (target[:, None] == np.arange(class_num)).astype(float)
    # Error term of the gradient: predicted probability minus one-hot label.
    # Computed on a private copy so the caller's probability matrix is left
    # intact (writing into `score` would turn it into residuals for the caller).
    error = np.array(score, dtype=float)
    error[:, :class_num] -= one_hot
    # Average, over all samples, of the outer product of features and errors.
    gradient_vec = data.T.dot(error) / np.shape(data)[0]
    # Plain gradient-descent step: move against the gradient.
    params = params - learning_rate * gradient_vec
    return (params,gradient_vec)

In [23]:
# Evaluation function (cross-entropy)

def scoring(data,score,target,class_num):
    """Return the mean cross-entropy of predicted probabilities against labels.

    Parameters
    ----------
    data : array-like of shape (n_samples, n_features)
        Only its first dimension is used, as the number of samples to average over.
    score : array-like of shape (n_samples, n_classes)
        Predicted class probabilities. Not modified.
    target : array-like of shape (n_samples,)
        Class labels, expected to be the integers 0 .. class_num - 1.
    class_num : int
        Number of classes.

    Returns
    -------
    float
        -(1 / n_samples) * sum over samples of log(probability of the true class).
        A true-class probability of 0 yields inf; a NaN probability yields NaN,
        so a diverged model is reported instead of being silently skipped.
    """
    target = np.asarray(target)
    score = np.asarray(score)
    # One-hot mask used as a switch: True only at (sample, true class).
    one_hot = target[:, None] == np.arange(class_num)
    # Pick just the true-class probabilities. Taking the log of only these
    # entries avoids the log(0) * 0 = NaN terms that appear when the log of
    # every entry is computed and then masked.
    true_class_proba = score[:, :class_num][one_hot]
    cross_entropy = -np.sum(np.log(true_class_proba)) / np.shape(data)[0]
    return(cross_entropy)

In [60]:
def early_stopping(loss,minimum_loss,b):
    """Update the early-stopping bookkeeping after one loss evaluation.

    Parameters
    ----------
    loss : float
        Loss measured in the current iteration.
    minimum_loss : float
        Lowest loss seen so far.
    b : int
        Number of consecutive iterations that did not set a new minimum.

    Returns
    -------
    (b, minimum_loss) : tuple
        `(0, loss)` when `loss` is a new strict minimum, otherwise
        `(b + 1, minimum_loss)`.
    """
    if loss < minimum_loss:
        # New best loss: remember it and restart the patience counter.
        return (0, loss)
    # Anything that is not a strict improvement counts against patience. This
    # includes an unchanged loss (plateau) and a NaN loss, for which every
    # comparison is False.
    return (b + 1, minimum_loss)

In [81]:
def softmax_regression(data,target,iteration,learning_rate,patience=10,verbose=True):
    """Fit a softmax (multinomial logistic) regression by batch gradient descent.

    Parameters
    ----------
    data : numpy.ndarray of shape (n_samples, n_features)
    target : numpy.ndarray of shape (n_samples,)
        Class labels, expected to be the integers 0 .. n_classes - 1.
    iteration : int
        Maximum number of gradient-descent iterations.
    learning_rate : float
        Step size passed to gradient_descent().
    patience : int, default 10
        Stop once the loss has failed to set a new minimum this many
        iterations in a row.
    verbose : bool, default True
        Print the iteration index and loss on every iteration.

    Returns
    -------
    dict with keys
        "score"        - class probabilities of `data` under the final parameters
        "params"       - final parameters, shape (n_features, n_classes)
        "gradient_vec" - gradient used in the last update (zeros if no update ran)
        "init_params"  - the randomly drawn starting parameters
    """
    class_num = len(np.unique(target))
    # Draw the initial parameters at random.
    init_params = np.random.randn(np.shape(data)[1],class_num)
    params = init_params.copy()
    # Defined before the loop so the result is complete even when no update
    # step runs (e.g. iteration == 0).
    gradient_vec = np.zeros_like(init_params)
    # Start from +inf so the very first loss always becomes the running
    # minimum. A finite starting value would count every early iteration whose
    # loss is above it as "no improvement" and could stop training too soon.
    minimum_loss = np.inf
    b = 0
    for i in range(0,iteration):
        score = softmax(data,params)
        loss = scoring(data,score,target,class_num)
        if verbose:
            print(i,loss)
        # Stop when the loss has not set a new minimum for `patience`
        # consecutive iterations; a new minimum resets b to 0.
        b, minimum_loss = early_stopping(loss,minimum_loss,b)
        if b >= patience:
            break
        params,gradient_vec = gradient_descent(data,params,score,target,learning_rate,class_num)
    result = dict()
    # Recompute the probabilities from the final parameters. The loop's `score`
    # predates the last update, and an update step that writes into its
    # `score` argument would leave residuals there instead of probabilities.
    result["score"] = softmax(data,params)
    result["params"] = params
    result["gradient_vec"] = gradient_vec
    result["init_params"] = init_params
    return(result)

In [82]:
# Train on the full iris data: at most 3000 iterations with a step size of 0.1.
result = softmax_regression(data, target, iteration=3000, learning_rate=0.1)

0 4.815421967345602
1 3.4714807551206333
2 2.1797736554282863
3 1.6267715597323997
4 1.417301658925106
5 1.298678667203223
6 1.1486294490666653
7 1.0688467248298272
8 1.031849958330791
9 0.9754342337422539
10 0.938945975999864
11 0.9049597531596825
12 0.867719549219694
13 0.8525825586415678
14 0.8141139034933583
15 0.8127969520063903
16 0.7734782936351651
17 0.7817977924354907
18 0.7418519130443014
19 0.7569752118224388
20 0.7163896042979627
21 0.7364559031429193
22 0.6952269317826978
23 0.7189678442416926
24 0.6771724917256263
25 0.7036835813095148
26 0.6614477685066333
27 0.6900590559653519
28 0.647521468293558
29 0.6777215094623621
30 0.6350151179406441
31 0.6664041019871184
32 0.6236498269618099
33 0.6559090796270316
34 0.6132148126248316
35 0.6460861174254359
36 0.6035476034604674
37 0.6368187270952032
38 0.5945209256496142
39 0.6280152360338814
40 0.5860336847841706
41 0.6196025560261808
42 0.5780045934468514
43 0.6115217551341349
44 0.5703675727428538
45 0.6037248425725024
46 0.

359 0.2108185382914256
360 0.2105601788053727
361 0.21030279321236078
362 0.21004637593589096
363 0.20979092144247233
364 0.20953642424119723
365 0.20928287888332267
366 0.2090302799618567
367 0.20877862211114911
368 0.20852790000648866
369 0.20827810836370336
370 0.20802924193876757
371 0.20778129552741217
372 0.20753426396474087
373 0.20728814212485014
374 0.2070429249204543
375 0.20679860730251506
376 0.20655518425987507
377 0.2063126508188965
378 0.20607100204310347
379 0.2058302330328288
380 0.205590338924865
381 0.20535131489211972
382 0.20511315614327444
383 0.20487585792244814
384 0.20463941550886394
385 0.20440382421652056
386 0.20416907939386686
387 0.20393517642348072
388 0.20370211072175118
389 0.2034698777385647
390 0.20323847295699438
391 0.2030078918929935
392 0.20277813009509235
393 0.20254918314409787
394 0.20232104665279774
395 0.2020937162656671
396 0.20186718765857894
397 0.20164145653851753
398 0.20141651864329518
399 0.20119236974127241
400 0.2009690056310809
401 

809 0.14764822169071215
810 0.14757172473936916
811 0.1474953804353568
812 0.1474191883140424
813 0.1473431479126927
814 0.14726725877046418
815 0.1471915204283936
816 0.147115932429388
817 0.14704049431821548
818 0.1469652056414955
819 0.14689006594768966
820 0.14681507478709213
821 0.14674023171182043
822 0.14666553627580606
823 0.14659098803478562
824 0.14651658654629118
825 0.14644233136964155
826 0.1463682220659333
827 0.14629425819803135
828 0.14622043933056036
829 0.1461467650298959
830 0.14607323486415566
831 0.1459998484031902
832 0.14592660521857495
833 0.145853504883601
834 0.14578054697326678
835 0.14570773106426937
836 0.14563505673499602
837 0.14556252356551588
838 0.1454901311375712
839 0.1454178790345694
840 0.1453457668415745
841 0.14527379414529884
842 0.1452019605340953
843 0.14513026559794845
844 0.14505870892846717
845 0.14498729011887584
846 0.14491600876400723
847 0.14484486446029368
848 0.14477385680575955
849 0.14470298540001364
850 0.14463224984424078
851 0.14

1272 0.12336160913901201
1273 0.12332561127978982
1274 0.12328966176423295
1275 0.12325376049350165
1276 0.12321790736902834
1277 0.1231821022925168
1278 0.12314634516594102
1279 0.12311063589154483
1280 0.12307497437184026
1281 0.12303936050960718
1282 0.12300379420789184
1283 0.12296827537000667
1284 0.12293280389952853
1285 0.12289737970029857
1286 0.12286200267642068
1287 0.1228266727322611
1288 0.12279138977244723
1289 0.12275615370186664
1290 0.12272096442566655
1291 0.12268582184925256
1292 0.12265072587828817
1293 0.12261567641869336
1294 0.12258067337664419
1295 0.12254571665857171
1296 0.12251080617116117
1297 0.12247594182135103
1298 0.12244112351633213
1299 0.122406351163547
1300 0.1223716246706888
1301 0.12233694394570066
1302 0.12230230889677449
1303 0.12226771943235056
1304 0.12223317546111635
1305 0.1221986768920058
1306 0.12216422363419861
1307 0.1221298155971191
1308 0.12209545269043583
1309 0.12206113482406025
1310 0.1220268619081462
1311 0.12199263385308917
1312 0.1

1700 0.11137976474068764
1701 0.11135800095055481
1702 0.11133625966430163
1703 0.11131454084667543
1704 0.11129284446249822
1705 0.11127117047666608
1706 0.1112495188541493
1707 0.11122788955999188
1708 0.11120628255931152
1709 0.11118469781729946
1710 0.11116313529922
1711 0.11114159497041079
1712 0.11112007679628211
1713 0.11109858074231707
1714 0.11107710677407119
1715 0.11105565485717234
1716 0.11103422495732053
1717 0.1110128170402875
1718 0.11099143107191707
1719 0.11097006701812423
1720 0.11094872484489551
1721 0.11092740451828864
1722 0.11090610600443229
1723 0.11088482926952577
1724 0.11086357427983917
1725 0.11084234100171302
1726 0.1108211294015579
1727 0.11079993944585463
1728 0.11077877110115378
1729 0.11075762433407559
1730 0.11073649911130991
1731 0.11071539539961582
1732 0.1106943131658215
1733 0.11067325237682428
1734 0.11065221299958997
1735 0.11063119500115315
1736 0.11061019834861684
1737 0.11058922300915223
1738 0.1105682689499986
1739 0.110547336138463
1740 0.110

2115 0.10393850847159214
2116 0.10392374527578072
2117 0.10390899459769244
2118 0.10389425642133839
2119 0.10387953073075722
2120 0.10386481751001489
2121 0.1038501167432049
2122 0.103835428414448
2123 0.10382075250789229
2124 0.10380608900771288
2125 0.10379143789811235
2126 0.10377679916332012
2127 0.10376217278759275
2128 0.10374755875521376
2129 0.1037329570504937
2130 0.10371836765776966
2131 0.10370379056140602
2132 0.10368922574579341
2133 0.10367467319534952
2134 0.10366013289451853
2135 0.10364560482777115
2136 0.10363108897960464
2137 0.10361658533454286
2138 0.10360209387713584
2139 0.10358761459196007
2140 0.10357314746361841
2141 0.10355869247673982
2142 0.10354424961597947
2143 0.10352981886601877
2144 0.10351540021156502
2145 0.10350099363735162
2146 0.10348659912813787
2147 0.10347221666870915
2148 0.10345784624387645
2149 0.10344348783847666
2150 0.10342914143737239
2151 0.10341480702545193
2152 0.10340048458762918
2153 0.10338617410884358
2154 0.10337187557406023
2155

2481 0.0992592092543462
2482 0.09924814568544457
2483 0.09923709023043974
2484 0.09922604288041374
2485 0.09921500362646198
2486 0.09920397245969287
2487 0.09919294937122802
2488 0.0991819343522021
2489 0.09917092739376299
2490 0.09915992848707149
2491 0.0991489376233015
2492 0.09913795479363997
2493 0.09912697998928689
2494 0.09911601320145498
2495 0.09910505442137034
2496 0.09909410364027137
2497 0.09908316084940999
2498 0.0990722260400506
2499 0.09906129920347062
2500 0.09905038033096017
2501 0.09903946941382233
2502 0.09902856644337281
2503 0.09901767141094021
2504 0.09900678430786568
2505 0.09899590512550326
2506 0.0989850338552196
2507 0.09897417048839398
2508 0.09896331501641832
2509 0.09895246743069722
2510 0.0989416277226478
2511 0.09893079588369978
2512 0.09891997190529546
2513 0.09890915577888952
2514 0.0988983474959493
2515 0.09888754704795455
2516 0.09887675442639741
2517 0.09886596962278257
2518 0.09885519262862706
2519 0.0988444234354603
2520 0.09883366203482401
2521 0.0

2953 0.09480674499560553
2954 0.09479870816736681
2955 0.094790676372214
2956 0.09478264960545434
2957 0.09477462786240101
2958 0.09476661113837302
2959 0.09475859942869524
2960 0.09475059272869837
2961 0.09474259103371899
2962 0.09473459433909949
2963 0.09472660264018806
2964 0.09471861593233873
2965 0.09471063421091133
2966 0.09470265747127146
2967 0.09469468570879057
2968 0.09468671891884568
2969 0.09467875709681993
2970 0.09467080023810193
2971 0.09466284833808598
2972 0.09465490139217254
2973 0.09464695939576732
2974 0.094639022344282
2975 0.0946310902331339
2976 0.09462316305774612
2977 0.09461524081354744
2978 0.0946073234959721
2979 0.09459941110046044
2980 0.09459150362245804
2981 0.09458360105741641
2982 0.09457570340079251
2983 0.0945678106480492
2984 0.09455992279465483
2985 0.09455203983608328
2986 0.09454416176781413
2987 0.09453628858533267
2988 0.09452842028412954
2989 0.0945205568597012
2990 0.09451269830754962
2991 0.09450484462318225
2992 0.09449699580211224
2993 0.0

In [64]:
# Output the probability of each class

def class_proba(result,data):
    """Return class probabilities under the fitted parameters.

    Parameters
    ----------
    result : dict
        Must contain "params", an array of shape (n_features, n_classes).
    data : array-like of shape (n_features,) or (n_samples, n_features)
        A single data point, or a batch with one data point per row.

    Returns
    -------
    numpy.ndarray
        Shape (n_classes,) for a single point, or (n_samples, n_classes) for a
        batch. Probabilities sum to 1 along the last axis.
    """
    # Linear score per class; the same product works for one point or a batch.
    logits = np.dot(data, result["params"])
    # Shift by the maximum before exponentiating. The result is unchanged
    # mathematically, but np.exp can no longer overflow to inf.
    logits = logits - np.max(logits, axis=-1, keepdims=True)
    exp_logits = np.exp(logits)
    proba_vec = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
    return proba_vec

In [65]:
# Predicted class of sample 139: index of its largest class probability.
class_proba(result, data[139]).argmax()

2

In [76]:
# Class probabilities of sample 85, rounded to three decimals for display.
np.round(class_proba(result, data[85]), 3)

array([0.006, 0.941, 0.052])