In [1]:
include("../src/Utero.jl")
using .Utero 

In [2]:
ctx = ⬅Context()
ForwardBackward!(ctx, x -> 2.0*x + x, AddParams!(ctx, 7.0))

(21.0, Any[3.0])

In [3]:
# should be 12.8
ctx = ⬅Context()
ForwardBackward!(ctx, x -> x^2.0 + x^2.0, AddParams!(ctx, 3.2)) 

(20.480000000000004, Any[12.8])

In [4]:
# Gradient should be 56.21
ctx = Utero.⬅Context()
ForwardBackward!(ctx, w -> 2.0 * (((2.0*w + 12.0) - w) * w) + sin(w^2.0), AddParams!(ctx, 7.0))

(265.0462473472405, Any[56.20829561241092])

In [5]:
# should be approx. (27.53, 3.0428)
ctx = Utero.⬅Context()
x, y = AddParams!(ctx, 13.2, 89.0)
ForwardBackward!(ctx, (x, y) -> (cos(y)^3 * 4sin(x))^2 + cos(x*y)^2, x, y) 

(1.0748674472530375, Any[27.53002464314667, 3.0428042273213602])

In [6]:
ctx = Utero.⬅Context()
W = AddParams!(ctx, rand(30,1))
(out, grads) = ForwardBackward!(ctx, _ -> W[5:5]' * W[5:5], 3)

(0.6804024366011192, Any[sparse([5], [1], [1.6497302041256554], 30, 1)])

In [7]:
ctx = Utero.⬅Context()
W = AddParams!(ctx, rand(30,1))
(out, grads) = ForwardBackward!(ctx, _ -> W .* W, 3)

([0.1629333539813685; 0.2536906506428964; … ; 0.34327205571882924; 0.42275717217084124;;], Any[[0.8073000779917427; 1.007354258725095; … ; 1.1717884718989673; 1.300395589304795;;]])

In [8]:
ctx = Utero.⬅Context()
W1, W2, W3 = AddParams!(ctx, rand(30,30), rand(30,30), rand(30,30))

function model(X)
    return ReLU(W3 * ReLU(W2 * ReLU(W1 * X)))
end

X = rand(1500, 30)
Y = mapslices(row -> cos.(row), X, dims=2) 

println(ctx.Params[1].val)
for (x, y) in zip(eachrow(X), eachrow(Y))
    println(MeanSquaredError(model(x), y).val)
    (out, grads) = ForwardBackward!(ctx, x -> MeanSquaredError(model(x), y), x)
    GradientStep!(0.0001, ctx.Params, grads)
end
println(ctx.Params[1].val)


[0.4686945564480983 0.4713042419103973 0.5063505603565324 0.9163028501544933 0.6059435499089406 0.9533655807028181 0.17142729413679536 0.48451427074248965 0.7012574914350974 0.7638154474914697 0.5905076759087655 0.7862763922283734 0.9897211454343887 0.8816330966597625 0.6790372232890692 0.5126369028803986 0.3648574390034667 0.6165012571844393 0.06789294618939601 0.9268400962107888 0.8224182730435171 0.9485091250346335 0.013513634931013652 0.7090437776773257 0.7808629448829382 0.4620421233030274 0.43491591712514244 0.22368881433909527 0.8759015800840179 0.3895404067219904; 0.16760874394523773 0.8827467571394771 0.09793903414666028 0.5915754914934339 0.7632923981909745 0.5458005492101634 0.8029829041824271 0.16939000056014997 0.21687122490648902 0.2528948163753644 0.5760711059431273 0.11151074995628418 0.4545617781247123 0.9154978240457707 0.15174249693946018 0.24010188692811896 0.5439710889220141 0.528038801147366 0.8502668069752634 0.35964412757798336 0.6729346109164058 0.2442327594073

 0.25447818970758507 0.05806925394295115 0.1966827220087909 0.6537959109262141 0.9651285852704149 0.07872772664204664 0.9218943369666778 0.008759065083096318 0.6905575604273931; 0.8114631368579044 0.36579230777297167 0.8562590403329382 0.06872916257248496 0.23620550892128722 0.19218587600696213 0.2740157308901824 0.3011644662204088 0.15209234666911575 0.8679669455560418 0.9242075598568436 0.7970245172712724 0.014038128000767647 0.07160028707726007 0.8693294313888434 0.1181195353285166 0.32098224819669474 0.04795601337757449 0.29274301083415766 0.5260568438690585 0.37372176820645187 0.010607121895643279 0.7473863977543473 0.9084688923173835 0.7231844445096246 0.8642998478932723 0.13992510210090847 0.45599602522063865 0.991520895714925 0.19206861885763693; 0.2802475627481794 0.9165252483096326 0.15131959601284295 0.30852444678166935 0.34144444364838344 0.05318044054975646 0.4766533523931513 0.4924953402663731 0.7650591003567531 0.2672482620947626 0.0048166352940096635 0.7140544383917422 

9.561955825773424e7


20.591048425790728
22.209127642917647
20.68354031957073
21.78600303097468
21.99113453666439
22.163265424237267
21.19968764071893
23.875460975875516
23.492115062947352
21.024693484970513
22.249612689422648
22.056731370592193
21.914171543003743
23.88640662701836
20.569264167141572
21.19090170426157
22.967620880075874
23.354787942750537
21.538590677074325
21.915375243307512
21.51913886321134
21.73105241236743
24.450841765618904
24.46892457135959
19.43887176631796
21.374071388857583
20.77895693212435
22.793582072193686
20.744715083861006
21.266572853840504
20.66570449361044
21.5968141317806
21.437616574668535
21.75923049934097
22.301708690690347
20.69284967665845
21.61943855154145
22.194170271452084
23.29550800653024
20.959475982230202
22.751016175089728
21.528484956961616
21.1240328010499
22.484728826119806
24.252446917362043
20.698464747447705
18.935657822043503
24.443055953062153
23.28181462483773
21.865600601853533
21.63067095180136
21.719006235748427
22.91917094302892
19.1404142593325

21.1457402870033
20.661960459680103
21.1228526568264
20.142677308800984
20.339576471076516
20.67917631570255
22.32213231271176
22.672142603351578
21.915812271102865
23.214723069862988
20.31040389556105
21.4256204619752
20.529267776909844
21.869781503350666
20.94166225001794
22.906852529114776
20.243951917243677
21.46314632238744
19.254676073194908
22.259721639455126
22.761890759016705
24.551777317285513
21.773178500243034
20.43486212383194
22.36350461658326
21.99637482047321
23.47203936808367
21.848649302215588
22.904347835307696
21.713168578541936
22.29543656002497
20.54224164685904
21.68363982485189
23.432404433894668
21.718568107439104
21.967253925904586
21.483325843160014
22.829107412376473
22.06119937492046
21.751642260757666
22.879450682308374
21.91553358625553
22.76032826326789
22.61067081411168
21.599867356847657
20.7118804251988
20.797237274385818
21.243692457043267
22.8762873248968
21.139691274289067
22.05115852952173
20.018518704351393
21.790445773754715
23.018268573468223
2

22.12749243594494
21.102847684082217
20.703407923157005
20.689819492810248
21.89399763222475
20.477561107390557
23.46007947503511
24.28372895130475
22.92397857940874
18.022434554777515
20.661455538540974
20.08105656891925
20.652953604192675
22.352227721808415
21.833584447554642
19.656447956206502
21.15927259634102
21.23130260571389
22.647030849493405
22.344869312964335
19.293495958885703
23.598848185904664
18.932535967840245
21.42448506784246
21.009265084131638
21.064960275075435
19.963936033458992
23.051588199546295
22.697833943814828
22.106480163748078
22.343959084939925
21.09834294108092
22.986987476654683
23.82592429565498
20.346581666858167
22.178166221461154
21.607170673694636
20.181294198153257
22.69972190526518
23.33835871906357
20.051424617377588
23.295653262260327
19.874827041057017
22.646743131793894
19.897944357079492
22.271063800162004
20.962725303037207
22.390824149046953
21.9575863634276
20.700973097144733
20.941318200894926
20.452116410374042
22.266960251995073
21.84405

22.827349327471982
21.17167491755343
22.345769030302854
21.626777099931562
20.50429288529088
24.136970847598555
21.942952642677923
21.412792315282683
19.09660005738096
20.294128480999888
21.6109679209492
20.87619660862675
22.262708901957787
22.019319482427175
21.53028373800723
23.78482242698926
24.662588343960817
22.979659001234133
21.091755468635476
21.972861695095173
20.22578536171492
24.09040128744886
22.35132015885966
21.200112018917523
21.800441104217303
22.490358341217803
22.536550193678355
22.45773139037266
21.435499873509542
22.376912230989525
22.692758288636963
23.43096843369247
22.55979568437063
19.38083691619973
20.377492827100205
21.360290057407596
21.539895540366043
21.77448095922567
23.73235719454309
22.12763878914414
22.260870065318123
24.173963157596955
21.51532766199702
21.890062400457566
22.455456579647144
22.415868416548346
21.290364586359974
20.311796102820036
20.721889184993977
21.836140624669035
22.524176271162336
22.742178135821685
23.209054058186
21.913681821032

 -1.0417029665671922e8 -9.078374799476272e7 -2.7014658898540042e7 -3.863327779075923e7 -3.4842112259624906e7 -1.045847974392959e8 -2.7400230355743906e6 -3.4172427130238764e7 -6.708593119540228e7 -8.087602981738302e7 -5.396971662549418e7 -7.170979401492746e7 -8.774305919527926e7 -3.852809832888126e7 -2.7401877819619443e7 -5.485206133652538e7 -4.62727348773175e7; -4.290109899103749e7 -1.098168494571456e7 -8.149983253183745e7 -4.894470809790111e7 -1.3995660461523442e6 -2.623318894307642e7 -5.153618581965158e7 -4.601296009801483e7 -3.480323602057168e7 -2.268111931429708e7 -4.0011314311491325e7 -5.9016017459355384e7 -7.911556615306553e7 -7.785814626037681e7 -6.785287723185565e7 -2.0191084640573885e7 -2.887498135960405e7 -2.6041417933518257e7 -7.816794980252889e7 -2.0479263635951746e6 -2.554088828730836e7 -5.014084162675949e7 -6.044772890208407e7 -4.033762271797539e7 -5.359677233374661e7 -6.558022897560609e7 -2.879636898931825e7 -2.048049626652339e7 -4.0997097953851156e7 -3.4584804145193376e

 -1.520545617604139e8 -4.524710983246624e7 -6.4707243778894104e7 -5.835738525333186e7 -1.7517007180170172e8 -4.589291312314201e6 -5.723572374998394e7 -1.1236286524720396e8 -1.3546003134361762e8 -9.039439207808562e7 -1.2010741547121793e8 -1.4696168347893462e8 -6.4531077749653466e7 -4.589566508494324e7 -9.187223947795635e7 -7.750264382917765e7; -6.912285655792809e7 -1.769384582411299e7 -1.3131368055946383e8 -7.886040358882432e7 -2.2550014467637283e6 -4.2267284034836635e7 -8.303583107551318e7 -7.413673114460357e7 -5.607546580790371e7 -3.654413935904215e7 -6.446679474548152e7 -9.508743949111691e7 -1.2747211626912574e8 -1.2544614438989721e8 -1.093255134470551e8 -3.2532161514527824e7 -4.652377783575253e7 -4.1958300484427184e7 -1.2594530339761952e8 -3.299648438111122e6 -4.115183945678683e7 -8.078763114462315e7 -9.739423377614537e7 -6.499254727583044e7 -8.635587587627251e7 -1.0566379093463695e8 -4.639711617810609e7 -3.2998465204974346e7 -6.60551030600444e7 -5.5723525788688794e7; -5.58719201326