In [1]:
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

from criterion.cross_entropy import CrossEntropy
from numpy_modules.inception_v3 import InceptionV3
from datasets import cars
from utils import get_batches
from IPython import display
from ada_smooth import ada_smooth_optimizer

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy.typing as npt

In [2]:
# Load the cars dataset, requesting 75x75 images; the call yields the
# train/test inputs and labels that are unpacked here.
(x_train, y_train, x_test, y_test) = cars.load_dataset(
    image_size=(75, 75),
)

100%|██████████| 8144/8144 [01:00<00:00, 135.71it/s]


In [3]:
# Hyper-parameters shared by the cells below.
num_classes = 196  # width of the one-hot targets and of the model's output
batch_size = 8  # number of samples handed to get_batches per step
epochs = 1  # number of passes of the outer training loop

In [4]:
def one_hot(labels: npt.NDArray, num_classes: int = num_classes):
    """Convert integer class labels into a one-hot matrix.

    Args:
        labels: Integer class indices. Any array-like is accepted and is
            flattened first, so a column vector of shape ``(N, 1)`` is
            treated the same as a vector of shape ``(N,)``.
        num_classes: Number of columns of the result. The default is the
            module-level ``num_classes`` captured when this function is
            defined.

    Returns:
        A float array of shape ``(labels.size, num_classes)`` in which row
        ``i`` is all zeros except for a 1 at column ``labels[i]``.

    Raises:
        IndexError: If any label lies outside ``[0, num_classes)``, or if
            the labels are not of an integer dtype (raised by NumPy).
    """
    # Flatten: with non-1-D labels the paired fancy index below would
    # broadcast against the row index and set the wrong cells.
    flat_labels = np.asarray(labels).ravel()

    # Negative indices would silently wrap around to the last columns, so
    # reject them together with labels that are too large.
    if flat_labels.size and (
        flat_labels.min() < 0 or flat_labels.max() >= num_classes
    ):
        raise IndexError(
            f"labels must lie in [0, {num_classes}), got values in "
            f"[{flat_labels.min()}, {flat_labels.max()}]"
        )

    result = np.zeros((flat_labels.size, num_classes))
    result[np.arange(flat_labels.size), flat_labels] = 1
    return result


# One-hot encode the training targets.
y_train = one_hot(y_train)

# Keep a reference to the un-encoded test labels (the evaluation cell passes
# them to the metric functions) before y_test is replaced by its one-hot form.
# NOTE(review): this cell rebinds y_train / y_test, so it is meant to be run
# once per kernel session.
y_test_labels = y_test
y_test = one_hot(y_test)

In [5]:
# Build the network and put it into training mode.
model = InceptionV3(num_classes=num_classes)
model.train()

criterion = CrossEntropy()

# Settings handed unchanged to ada_smooth_optimizer on every step.
optimizer_config = {
    "learning_rate": 1e-3,
    "epsilon": 1e-6,
    "fast_decay": 0.5,
    "slow_decay": 0.99,
}
# Starts empty and is passed to the optimizer on every call; presumably the
# optimizer keeps its running statistics here -- confirm in ada_smooth.
optimizer_state = {}

# One loss value is appended per optimisation step, so the training curve can
# be plotted afterwards (e.g. plt.plot(loss_history)).
loss_history = []
# NOTE(review): f1_micro and step are initialised here but never updated or
# read in this cell.
f1_micro = []

step = 1
for epoch in range(1, epochs + 1):
    for x_batch, y_batch in tqdm(
            get_batches((x_train, y_train), batch_size),
            # Floor division. NOTE(review): this under-counts by one if
            # get_batches also yields a final partial batch -- confirm.
            total=len(x_train) // batch_size,
    ):
        # Clear the gradients accumulated by the previous step.
        model.zero_grad_parameters()

        # forward() returns a pair here; only the first element is used for
        # the loss and its gradient.
        predictions, _ = model.forward(x_batch)

        loss = criterion.forward(predictions, y_batch)

        # Gradient of the loss w.r.t. the predictions, propagated back
        # through the model.
        dp = criterion.backward(predictions, y_batch)
        model.backward(x_batch, dp)

        ada_smooth_optimizer(
            model.get_parameters(),
            model.get_grad_parameters(),
            optimizer_config,
            optimizer_state,
        )

        loss_history.append(loss)

  0%|          | 0/918 [00:00<?, ?it/s]

[[[[ 0.3982993   0.15027298  0.36516873]
   [ 0.10826901 -0.47372202 -0.319649  ]
   [-0.43992797  0.47080209  0.03267034]]

  [[ 0.45653051 -0.53669106  0.26007368]
   [-0.50002691  0.39258785  0.38112313]
   [ 0.17756552 -0.06374705 -0.46679001]]

  [[ 0.27803684  0.25817946  0.31345368]
   [ 0.40614947 -0.49701873  0.15136589]
   [-0.09856108  0.34951553 -0.51648955]]]


 [[[ 0.08686108 -0.33744083 -0.08128126]
   [-0.21120322 -0.29441815  0.06161171]
   [-0.02695747 -0.35948534 -0.21340365]]

  [[-0.12004799  0.50069609 -0.34371057]
   [ 0.57477833 -0.01795753 -0.41132847]
   [-0.02809586  0.44552503  0.13020383]]

  [[-0.17337993  0.48058344 -0.45451145]
   [ 0.1887637  -0.17085006  0.12652605]
   [-0.19377292  0.1264158  -0.27897789]]]


 [[[ 0.2373146   0.46144062 -0.00102394]
   [-0.41281738  0.2911367   0.0429259 ]
   [-0.4658232   0.3943316   0.03949223]]

  [[-0.32565144  0.549095   -0.43226351]
   [ 0.14892539  0.44067664  0.46850924]
   [-0.3713719   0.16099483 -0.35630305

  0%|          | 1/918 [00:34<8:44:45, 34.34s/it]

[[[[ 0.7965986   0.30054595  0.73033746]
   [ 0.21653802 -0.94744404 -0.639298  ]
   [-0.87985594  0.94160418  0.06534067]]

  [[ 0.91306102 -1.07338212  0.52014736]
   [-1.00005383  0.7851757   0.76224626]
   [ 0.35513105 -0.12749409 -0.93358002]]

  [[ 0.55607369  0.51635891  0.62690735]
   [ 0.81229894 -0.99403745  0.30273179]
   [-0.19712215  0.69903107 -1.0329791 ]]]


 [[[ 0.17372217 -0.67488166 -0.16256253]
   [-0.42240644 -0.58883631  0.12322343]
   [-0.05391495 -0.71897067 -0.42680731]]

  [[-0.24009597  1.00139217 -0.68742113]
   [ 1.14955667 -0.03591507 -0.82265694]
   [-0.05619171  0.89105006  0.26040766]]

  [[-0.34675986  0.96116687 -0.90902291]
   [ 0.37752741 -0.34170012  0.25305211]
   [-0.38754585  0.2528316  -0.55795579]]]


 [[[ 0.47462919  0.92288124 -0.00204788]
   [-0.82563475  0.58227339  0.0858518 ]
   [-0.9316464   0.78866321  0.07898445]]

  [[-0.65130287  1.09818999 -0.86452703]
   [ 0.29785077  0.88135328  0.93701847]
   [-0.7427438   0.32198967 -0.71260611

  0%|          | 2/918 [01:08<8:46:06, 34.46s/it]

[[[[ 1.1948979   0.45081893  1.09550619]
   [ 0.32480703 -1.42116607 -0.958947  ]
   [-1.31978391  1.41240626  0.09801101]]

  [[ 1.36959152 -1.61007319  0.78022104]
   [-1.50008074  1.17776355  1.14336938]
   [ 0.53269657 -0.19124114 -1.40037003]]

  [[ 0.83411053  0.77453837  0.94036103]
   [ 1.21844841 -1.49105618  0.45409768]
   [-0.29568323  1.0485466  -1.54946865]]]


 [[[ 0.26058325 -1.01232249 -0.24384379]
   [-0.63360965 -0.88325446  0.18483514]
   [-0.08087242 -1.07845601 -0.64021096]]

  [[-0.36014396  1.50208826 -1.0311317 ]
   [ 1.724335   -0.0538726  -1.23398541]
   [-0.08428757  1.33657509  0.3906115 ]]

  [[-0.52013979  1.44175031 -1.36353436]
   [ 0.56629111 -0.51255018  0.37957816]
   [-0.58131877  0.3792474  -0.83693368]]]


 [[[ 0.71194379  1.38432186 -0.00307181]
   [-1.23845213  0.87341009  0.1287777 ]
   [-1.3974696   1.18299481  0.11847668]]

  [[-0.97695431  1.64728499 -1.29679054]
   [ 0.44677616  1.32202992  1.40552771]
   [-1.1141157   0.4829845  -1.06890916

  0%|          | 3/918 [01:44<8:52:17, 34.90s/it]

[[[[ 1.5931972   0.60109191  1.46067492]
   [ 0.43307604 -1.89488809 -1.278596  ]
   [-1.75971187  1.88320835  0.13068135]]

  [[ 1.82612203 -2.14676425  1.04029472]
   [-2.00010765  1.5703514   1.52449251]
   [ 0.7102621  -0.25498819 -1.86716005]]

  [[ 1.11214737  1.03271782  1.25381471]
   [ 1.62459788 -1.9880749   0.60546357]
   [-0.39424431  1.39806213 -2.0659582 ]]]


 [[[ 0.34744434 -1.34976332 -0.32512505]
   [-0.84481287 -1.17767262  0.24644685]
   [-0.1078299  -1.43794135 -0.85361461]]

  [[-0.48019195  2.00278435 -1.37484227]
   [ 2.29911333 -0.07183014 -1.64531389]
   [-0.11238343  1.78210013  0.52081533]]

  [[-0.69351971  1.92233375 -1.81804581]
   [ 0.75505482 -0.68340024  0.50610422]
   [-0.77509169  0.5056632  -1.11591158]]]


 [[[ 0.94925838  1.84576248 -0.00409575]
   [-1.65126951  1.16454679  0.17170361]
   [-1.8632928   1.57732642  0.15796891]]

  [[-1.30260574  2.19637998 -1.72905405]
   [ 0.59570155  1.76270657  1.87403695]
   [-1.4854876   0.64397934 -1.42521221

  0%|          | 4/918 [02:22<9:12:44, 36.29s/it]

[[[[ 1.99149649  0.75136488  1.82584366]
   [ 0.54134505 -2.36861011 -1.598245  ]
   [-2.19963984  2.35401044  0.16335169]]

  [[ 2.28265254 -2.68345531  1.3003684 ]
   [-2.50013457  1.96293925  1.90561564]
   [ 0.88782762 -0.31873523 -2.33395006]]

  [[ 1.39018422  1.29089728  1.56726839]
   [ 2.03074735 -2.48509363  0.75682946]
   [-0.49280538  1.74757766 -2.58244775]]]


 [[[ 0.43430542 -1.68720415 -0.40640632]
   [-1.05601609 -1.47209077  0.30805857]
   [-0.13478737 -1.79742669 -1.06701826]]

  [[-0.60023994  2.50348043 -1.71855284]
   [ 2.87389166 -0.08978767 -2.05664236]
   [-0.14047929  2.22762516  0.65101916]]

  [[-0.86689964  2.40291718 -2.27255727]
   [ 0.94381852 -0.8542503   0.63263027]
   [-0.96886461  0.632079   -1.39488947]]]


 [[[ 1.18657298  2.3072031  -0.00511969]
   [-2.06408689  1.45568349  0.21462951]
   [-2.329116    1.97165802  0.19746113]]

  [[-1.62825718  2.74547498 -2.16131757]
   [ 0.74462693  2.20338321  2.34254618]
   [-1.8568595   0.80497417 -1.78151527

  1%|          | 5/918 [02:57<9:06:10, 35.89s/it]

[[[[ 2.38979579  0.90163786  2.19101239]
   [ 0.64961406 -2.84233213 -1.917894  ]
   [-2.63956781  2.82481253  0.19602202]]

  [[ 2.73918305 -3.22014637  1.56044208]
   [-3.00016148  2.3555271   2.28673877]
   [ 1.06539315 -0.38248228 -2.80074007]]

  [[ 1.66822106  1.54907674  1.88072206]
   [ 2.43689682 -2.98211236  0.90819536]
   [-0.59136646  2.0970932  -3.0989373 ]]]


 [[[ 0.5211665  -2.02464498 -0.48768758]
   [-1.26721931 -1.76650893  0.36967028]
   [-0.16174485 -2.15691202 -1.28042192]]

  [[-0.72028792  3.00417652 -2.0622634 ]
   [ 3.44867    -0.1077452  -2.46797083]
   [-0.16857514  2.67315019  0.78122299]]

  [[-1.04027957  2.88350062 -2.72706872]
   [ 1.13258222 -1.02510036  0.75915632]
   [-1.16263754  0.75849479 -1.67386737]]]


 [[[ 1.42388757  2.76864372 -0.00614363]
   [-2.47690426  1.74682018  0.25755541]
   [-2.7949392   2.36598963  0.23695336]]

  [[-1.95390861  3.29456998 -2.59358108]
   [ 0.89355232  2.64405985  2.81105542]
   [-2.22823141  0.96596901 -2.13781832

  1%|          | 6/918 [03:35<9:14:32, 36.48s/it]

[[[[ 2.78809509  1.05191084  2.55618112]
   [ 0.75788307 -3.31605415 -2.23754301]
   [-3.07949578  3.29561461  0.22869236]]

  [[ 3.19571356 -3.75683744  1.82051576]
   [-3.50018839  2.74811495  2.66786189]
   [ 1.24295867 -0.44622932 -3.26753008]]

  [[ 1.9462579   1.80725619  2.19417574]
   [ 2.84304629 -3.47913108  1.05956125]
   [-0.68992754  2.44660873 -3.61542684]]]


 [[[ 0.60802759 -2.36208581 -0.56896885]
   [-1.47842252 -2.06092708  0.431282  ]
   [-0.18870232 -2.51639736 -1.49382557]]

  [[-0.84033591  3.50487261 -2.40597397]
   [ 4.02344833 -0.12570274 -2.8792993 ]
   [-0.196671    3.11867522  0.91142682]]

  [[-1.2136595   3.36408406 -3.18158017]
   [ 1.32134593 -1.19595042  0.88568238]
   [-1.35641046  0.88491059 -1.95284526]]]


 [[[ 1.66120217  3.23008434 -0.00716756]
   [-2.88972164  2.03795688  0.30048131]
   [-3.2607624   2.76032123  0.27644559]]

  [[-2.27956005  3.84366497 -3.02584459]
   [ 1.0424777   3.08473649  3.27956466]
   [-2.59960331  1.12696384 -2.49412138

  1%|          | 7/918 [04:10<9:06:33, 36.00s/it]

[[[[ 3.18639439  1.20218381  2.92134985]
   [ 0.86615207 -3.78977618 -2.55719201]
   [-3.51942375  3.7664167   0.2613627 ]]

  [[ 3.65224406 -4.2935285   2.08058943]
   [-4.00021531  3.1407028   3.04898502]
   [ 1.4205242  -0.50997637 -3.73432009]]

  [[ 2.22429475  2.06543565  2.50762942]
   [ 3.24919576 -3.97614981  1.21092714]
   [-0.78848861  2.79612426 -4.13191639]]]


 [[[ 0.69488867 -2.69952664 -0.65025011]
   [-1.68962574 -2.35534524  0.49289371]
   [-0.2156598  -2.8758827  -1.70722922]]

  [[-0.9603839   4.0055687  -2.74968454]
   [ 4.59822666 -0.14366027 -3.29062777]
   [-0.22476686  3.56420025  1.04163065]]

  [[-1.38703943  3.8446675  -3.63609163]
   [ 1.51010963 -1.36680047  1.01220843]
   [-1.55018338  1.01132639 -2.23182316]]]


 [[[ 1.89851676  3.69152495 -0.0081915 ]
   [-3.30253902  2.32909358  0.34340721]
   [-3.7265856   3.15465284  0.31593781]]

  [[-2.60521148  4.39275997 -3.45810811]
   [ 1.19140309  3.52541313  3.74807389]
   [-2.97097521  1.28795867 -2.85042443

  1%|          | 8/918 [04:41<8:42:18, 34.44s/it]

[[[[ 3.58469369  1.35245679  3.28651858]
   [ 0.97442108 -4.2634982  -2.87684101]
   [-3.95935172  4.23721879  0.29403303]]

  [[ 4.10877457 -4.83021956  2.34066311]
   [-4.50024222  3.53329065  3.43010815]
   [ 1.59808972 -0.57372342 -4.2011101 ]]

  [[ 2.50233159  2.3236151   2.82108309]
   [ 3.65534523 -4.47316854  1.36229303]
   [-0.88704969  3.1456398  -4.64840594]]]


 [[[ 0.78174976 -3.03696747 -0.73153137]
   [-1.90082896 -2.64976339  0.55450542]
   [-0.24261727 -3.23536804 -1.92063287]]

  [[-1.08043188  4.50626478 -3.09339511]
   [ 5.173005   -0.16161781 -3.70195624]
   [-0.25286271  4.00972528  1.17183449]]

  [[-1.56041936  4.32525093 -4.09060308]
   [ 1.69887334 -1.53765053  1.13873448]
   [-1.7439563   1.13774219 -2.51080105]]]


 [[[ 2.13583136  4.15296557 -0.00921544]
   [-3.71535639  2.62023027  0.38633311]
   [-4.1924088   3.54898444  0.35543004]]

  [[-2.93086292  4.94185497 -3.89037162]
   [ 1.34032848  3.96608977  4.21658313]
   [-3.34234711  1.44895351 -3.20672748

  1%|          | 9/918 [05:14<8:32:35, 33.83s/it]

[[[[ 3.98299299  1.50272976  3.65168731]
   [ 1.08269009 -4.73722022 -3.19649001]
   [-4.39927969  4.70802088  0.32670337]]

  [[ 4.56530508 -5.36691062  2.60073679]
   [-5.00026914  3.9258785   3.81123128]
   [ 1.77565525 -0.63747046 -4.66790012]]

  [[ 2.78036843  2.58179456  3.13453677]
   [ 4.06149471 -4.97018726  1.51365893]
   [-0.98561076  3.49515533 -5.16489549]]]


 [[[ 0.86861084 -3.37440831 -0.81281264]
   [-2.11203218 -2.94418155  0.61611714]
   [-0.26957475 -3.59485337 -2.13403653]]

  [[-1.20047987  5.00696087 -3.43710567]
   [ 5.74778333 -0.17957534 -4.11328471]
   [-0.28095857  4.45525032  1.30203832]]

  [[-1.73379928  4.80583437 -4.54511453]
   [ 1.88763704 -1.70850059  1.26526054]
   [-1.93772923  1.26415799 -2.78977895]]]


 [[[ 2.37314595  4.61440619 -0.01023938]
   [-4.12817377  2.91136697  0.42925901]
   [-4.658232    3.94331605  0.39492227]]

  [[-3.25651435  5.49094996 -4.32263514]
   [ 1.48925386  4.40676641  4.68509237]
   [-3.71371901  1.60994834 -3.56303054

  1%|          | 10/918 [05:50<8:42:57, 34.56s/it]

[[[[ 4.38129229  1.65300274  4.01685604]
   [ 1.1909591  -5.21094224 -3.51613901]
   [-4.83920765  5.17882296  0.35937371]]

  [[ 5.02183559 -5.90360169  2.86081047]
   [-5.50029605  4.31846635  4.1923544 ]
   [ 1.95322077 -0.70121751 -5.13469013]]

  [[ 3.05840528  2.83997402  3.44799045]
   [ 4.46764418 -5.46720599  1.66502482]
   [-1.08417184  3.84467086 -5.68138504]]]


 [[[ 0.95547192 -3.71184914 -0.8940939 ]
   [-2.32323539 -3.2385997   0.67772885]
   [-0.29653222 -3.95433871 -2.34744018]]

  [[-1.32052786  5.50765696 -3.78081624]
   [ 6.32256166 -0.19753287 -4.52461319]
   [-0.30905443  4.90077535  1.43224215]]

  [[-1.90717921  5.28641781 -4.99962598]
   [ 2.07640074 -1.87935065  1.39178659]
   [-2.13150215  1.39057379 -3.06875684]]]


 [[[ 2.61046055  5.07584681 -0.01126332]
   [-4.54099115  3.20250367  0.47218492]
   [-5.1240552   4.33764765  0.43441449]]

  [[-3.58216579  6.04004496 -4.75489865]
   [ 1.63817925  4.84744306  5.15360161]
   [-4.08509091  1.77094318 -3.91933359

  1%|          | 11/918 [06:22<8:32:40, 33.92s/it]

[[[[ 4.77959159  1.80327572  4.38202477]
   [ 1.29922811 -5.68466426 -3.83578801]
   [-5.27913562  5.64962505  0.39204405]]

  [[ 5.4783661  -6.44029275  3.12088415]
   [-6.00032296  4.7110542   4.57347753]
   [ 2.1307863  -0.76496456 -5.60148014]]

  [[ 3.33644212  3.09815347  3.76144413]
   [ 4.87379365 -5.96422471  1.81639071]
   [-1.18273292  4.19418639 -6.19787459]]]


 [[[ 1.04233301 -4.04928997 -0.97537516]
   [-2.53443861 -3.53301785  0.73934056]
   [-0.3234897  -4.31382405 -2.56084383]]

  [[-1.44057585  6.00835304 -4.12452681]
   [ 6.89734    -0.21549041 -4.93594166]
   [-0.33715029  5.34630038  1.56244598]]

  [[-2.08055914  5.76700124 -5.45413744]
   [ 2.26516445 -2.05020071  1.51831265]
   [-2.32527507  1.51698959 -3.34773474]]]


 [[[ 2.84777514  5.53728743 -0.01228725]
   [-4.95380852  3.49364037  0.51511082]
   [-5.5898784   4.73197926  0.47390672]]

  [[-3.90781723  6.58913995 -5.18716216]
   [ 1.78710464  5.2881197   5.62211084]
   [-4.45646281  1.93193801 -4.27563664

  1%|▏         | 12/918 [06:54<8:20:00, 33.11s/it]

[[[[ 5.17789088  1.95354869  4.7471935 ]
   [ 1.40749712 -6.15838629 -4.15543701]
   [-5.71906359  6.12042714  0.42471438]]

  [[ 5.9348966  -6.97698381  3.38095783]
   [-6.50034988  5.10364205  4.95460066]
   [ 2.30835182 -0.8287116  -6.06827015]]

  [[ 3.61447896  3.35633293  4.0748978 ]
   [ 5.27994312 -6.46124344  1.9677566 ]
   [-1.28129399  4.54370193 -6.71436414]]]


 [[[ 1.12919409 -4.3867308  -1.05665643]
   [-2.74564183 -3.82743601  0.80095228]
   [-0.35044717 -4.67330939 -2.77424748]]

  [[-1.56062383  6.50904913 -4.46823737]
   [ 7.47211833 -0.23344794 -5.34727013]
   [-0.36524614  5.79182541  1.69264981]]

  [[-2.25393907  6.24758468 -5.90864889]
   [ 2.45392815 -2.22105077  1.6448387 ]
   [-2.519048    1.64340539 -3.62671263]]]


 [[[ 3.08508974  5.99872805 -0.01331119]
   [-5.3666259   3.78477706  0.55803672]
   [-6.0557016   5.12631086  0.51339895]]

  [[-4.23346866  7.13823495 -5.61942568]
   [ 1.93603002  5.72879634  6.09062008]
   [-4.82783471  2.09293285 -4.6319397 

  1%|▏         | 13/918 [07:47<9:02:34, 35.97s/it]


KeyboardInterrupt: 

In [None]:
# Presumably switches the model to inference mode -- confirm in
# numpy_modules.inception_v3.
model.evaluate()

# The training cell unpacks forward() as a pair; accept either a pair or a
# single array here and keep only the main predictions.
outputs = model.forward(x_test)
predictions = outputs[0] if isinstance(outputs, tuple) else outputs

# Predicted class index per test sample, computed once and reused below.
predicted_labels = predictions.argmax(axis=1)

# The results get their own names: assigning them to accuracy_score,
# precision_score, ... would overwrite the imported sklearn functions and make
# this cell fail when it is run a second time.
test_accuracy = accuracy_score(y_test_labels, predicted_labels)
test_precision = precision_score(y_test_labels, predicted_labels, average="micro")
test_recall = recall_score(y_test_labels, predicted_labels, average="micro")
test_f1 = f1_score(y_test_labels, predicted_labels, average="micro")

print(f"accuracy: {test_accuracy}\nprecision: {test_precision}\nrecall: {test_recall}\nf1: {test_f1}")