In [1]:
include("../src/Utero.jl")
using .Utero 

In [2]:
ctx = ⬅Context()
ForwardBackward!(ctx, x -> 2.0*x + x, AddParams!(ctx, 7.0))

(21.0, Any[3.0])

In [3]:
# should be 12.8
ctx = ⬅Context()
ForwardBackward!(ctx, x -> x^2.0 + x^2.0, AddParams!(ctx, 3.2)) 

(20.480000000000004, Any[12.8])

In [4]:
# Gradient should be 56.21
ctx = Utero.⬅Context()
ForwardBackward!(ctx, w -> 2.0 * (((2.0*w + 12.0) - w) * w) + sin(w^2.0), AddParams!(ctx, 7.0))

(265.0462473472405, Any[56.20829561241092])

In [5]:
# should be approx. (27.53, 3.0428)
ctx = Utero.⬅Context()
x, y = AddParams!(ctx, 13.2, 89.0)
ForwardBackward!(ctx, (x, y) -> (cos(y)^3 * 4sin(x))^2 + cos(x*y)^2, x, y) 

(1.0748674472530375, Any[27.53002464314667, 3.0428042273213602])

In [6]:
ctx = Utero.⬅Context()
W = AddParams!(ctx, rand(30,1))
(out, grads) = ForwardBackward!(ctx, _ -> W[5:5]' * W[5:5], 3)

(0.05663287938469892, Any[sparse([5], [1], [0.4759532724320694], 30, 1)])

In [7]:
ctx = Utero.⬅Context()
W = AddParams!(ctx, rand(30,1))
(out, grads) = ForwardBackward!(ctx, _ -> W .* W, 3)

([0.24959275141633838; 0.3489116106923465; … ; 0.16802838386052793; 0.14550936893158065;;], Any[[0.9991851708594126; 1.181374810451529; … ; 0.8198253078809605; 0.7629138062234309;;]])

In [8]:
ctx = Utero.⬅Context()
W1, W2, W3 = AddParams!(ctx, rand(30,30), rand(30,30), rand(30,30))

function model(X)
    return ReLU(W3 * ReLU(W2 * ReLU(W1 * X)))
end

X = rand(1500, 30)
Y = mapslices(row -> cos.(row), X, dims=2) 

println(ctx.Params[1].val)
for (x, y) in zip(eachrow(X), eachrow(Y))
    println(MeanSquaredError(model(x), y).val)
    (out, grads) = ForwardBackward!(ctx, x -> MeanSquaredError(model(x), y), x)
    GradientStep!(0.0001, ctx.Params, grads)
end
println(ctx.Params[1].val)


[0.8919204973721623 0.293679766484394 0.9609653052937575 0.22962315445813908 0.9279468368354722 0.16646041333778627 0.9276762028589023 0.04388331005116053 0.9715156349695245 0.002542984985229424 0.9011054176378619 0.6357119737671039 0.183778240839512 0.15318185503794002 0.8339131616576488 0.3501522914704218 0.6897136782002024 0.2119585679229924 0.005215735519886411 0.5586207380822655 0.19924194776259663 0.4496349623566077 0.22439040314833825 0.4682659640502942 0.15984744102697024 0.9557911900361385 0.4282723094748231 0.573672759785283 0.8988577125526729 0.6964743205375906; 0.6510593028757552 0.7503442482910612 0.5030985967944873 0.5094950796529493 0.8316408924274995 0.41270380995772293 0.9404450521653924 0.514070708967033 0.23963518750480906 0.9969020602410578 0.05643314130602717 0.7056555746013633 0.17373133942682562 0.0348459055686271 0.13998384992733104 0.3823412217132641 0.4795567292642168 0.5708713636582943 0.7730788057186253 0.38023697132633805 0.4778655630098989 0.72844152034097

 0.44537629116253885 0.22574605660262292 0.22330303644955396 0.0720546431036666 0.5897381760850745 0.5592483391311881 0.5297578473609951 0.18268555867732772; 0.9810576712492579 0.005273205112489965 0.8433279328019092 0.6344990923604804 0.6818255861005537 0.14781743318491325 0.9175919113311827 0.5584706272977008 0.717305673851164 0.8638745945366996 0.2567631232789179 0.22777619707570995 0.7417874601646073 0.8563540184460148 0.9420402030038627 0.6348373787203324 0.7219497325660524 0.4767497495798526 0.05995094590517491 0.9342614592103412 0.37214285165317973 0.09794608953885564 0.8279090203697688 0.7961273526187702 0.8521914854483758 0.34585981690494116 0.25266093047223726 0.34871182287262226 0.3188644753996539 0.7866863323242225; 0.8633234456200627 0.8469642005833806 0.1447517815085686 0.6140938387310175 0.09680038840091187 0.87467644737451 0.6797271371651908 0.05050817657415296 0.4799998468948822 0.11625973742577622 0.11992909662241447 0.33538889703158625 0.28565502303316237 0.439230154

1.0553658570355001e8


22.41058360326172
22.360022584020633
23.678564366196632
21.901228801713028
22.10657257322578
22.13815125230069
21.76543140904003
22.036752643109693
22.05144158544487
21.018510803054102
22.326551687964773
22.15599851813945
23.104512136869065
22.375619247746585
22.79154699637875
21.243152080373545
23.395536703243145
21.468937408277434
21.823031773185416
21.866050210299253
23.22841238661472
21.566701310928547
21.17430453724636
23.025434792044223
22.746445072066027
22.675880406453466
19.37824754804347
21.930971455775744
23.830204963207933
24.059974185271376
21.406027302601686
20.68571204500466
20.565139838744834
19.76214777527357
22.606300377194437
21.68887815597239
21.51642971980912
22.515931573167556
23.363317625279905
20.96995625897474
20.93967087853844
20.35673152851944
21.362526275255295
23.41268407978712
20.551549998473043
21.482575848659014
19.270127451153314
22.198660585370053
21.803924810862366
21.955740115502326
20.416976406840188
21.637369333344775
22.12468116150459
21.423425916

19.663276014874103
21.610293430353778
22.410403521510993
20.025200855422014
21.799977850095022
20.781259131250206
21.099000913008627
23.22807661556723
23.204116576487493
23.62254504998466
20.987660427079884
19.589951464546214
21.593355261896193
21.109487297630213
22.442791648608605
22.801607455900168
23.225177717438527
23.249457442014133
24.101148265810032
19.920213643078956
20.422405347490248
19.662324933290517
20.93337016774045
19.551141376391968
20.906237435868057
22.09952797157168
24.529637334890687
23.44430463980385
22.83834036401003
22.595431504205546
24.66094278684464
20.229738631435694
23.36200990015196
20.369579533782858
21.47819136357115
23.277244936584413
21.470292349488883
19.692048067821794
19.729662228415066
21.994074151959957
23.491030520865312
20.84900644454101
22.886095411374047
21.018042626497735
23.971689458907044
21.425310601146265
22.85876341977556
19.423764064152508
20.758964403487038
21.470962196490692
23.35627391797803
22.0727425614249
22.262730681010172
23.5044

20.79166041406904
22.51550301588997
21.473974009413418
19.82430940524813
21.962428445098972
22.432016363055855
22.325622713921884
21.039331136347904
20.40688922806564
21.58396030568886
21.821324304756967
20.613397615087546
22.407973911268396
19.4293508596367
21.464952692371007
22.75310077499327
21.250733608982372
22.11977123629652
20.421676620499298
23.28223733066888
22.003675303379644
22.24434745035252
22.417830989265983
22.793971733779248
20.252552665403016
21.834716602964246
21.2396424519372
22.19312332357874
21.758564943425508
20.938482549385206
22.295400584781483
21.52297201725593
21.28590809237721
21.10447499054174
20.426133037804636
23.45635530078099
23.232411666776187
22.399031256089525
22.523111141670686
21.613495140939136
21.979058260529666
23.387732006439474
19.90408516310325
21.87475898978503
21.341172013709436
21.759490203821095
21.297293056052645
21.671189635315635
22.622677545196755
21.908125134764962
19.68455659645195
22.043769294707044
22.145410553713294
22.30803064705

21.575385676159275
22.145660603671143
23.268395186227796
21.10985692464423
22.010411333686534
20.553609561362798
21.72938384423778
21.836288526384816
22.181316293559505
20.73401812150713
20.227229892788642
21.11978920509498
21.307748943589207
22.196404262012496
21.964064973588908
21.57800291863458
20.913813624606
22.332116880652826
22.48011740193771
21.658368996865097
21.06108320667014
24.219304485645978
22.110750618002722
20.57755969808574
20.572940218614015
24.66316309133921
24.452142547673418
20.81443056320438
21.806496009229
20.307319682687933
20.58825078763408
22.9179894589688
22.760171090921457
23.093545613596703
21.56755666918684
23.055430540314113
20.141421444077828
23.13736601403647
21.982691185882054
20.909235406511584
21.99871764114786
22.06894038556587
21.50151017958234
22.00085794687637
19.581733704807906
21.85903106059632
19.058281150069245
20.469970828069524
22.3801214375231
22.54328269520019
21.96711705602082
20.232058346293428
22.398271477790487
23.16972314246002
20.63

 -2.3484525600506976e8 -1.9476027036616966e8 -2.3274489323941815e8 -1.1872450660693486e8 -1.6993919294845915e8 -6.1317640481962666e7 -2.2656545114347598e8 -1.035618125651334e8 -3.2479960622977044e7 -2.352715054768286e7; -9.000727675114933e6 -8.24533778341775e7 -1.180532927035813e8 -1.0401456839017947e8 -1.6538827426111615e8 -2.0698060304543856e8 -1.01402666487239e8 -1.6850553114666883e7 -2.122899426048271e8 -1.46500622436827e8 -1.5858599055414116e8 -1.1004607335572195e8 -1.66419840678266e8 -1.8528943715174425e8 -8.08620469084846e7 -6.781580210164769e7 -3.241306529054021e7 -2.532541154613424e7 -2.0435029860979098e8 -1.840418789966563e8 -2.0116111279781955e8 -1.6682556614303687e8 -1.9936200787548396e8 -1.016957044888721e8 -1.455646061976016e8 -5.252277580780131e7 -1.940688900967071e8 -8.870781488127467e7 -2.7821320200988133e7 -2.0152623065501224e7; -6.926787308838987e6 -6.3454542795231864e7 -9.085155635002112e7 -8.004762250996095e7 -1.2727965355591132e8 -1.59288314214978e8 -7.80375535784

 -1.8216831527371902e7 -1.4699129331381333e8 -1.3238323675769109e8 -1.4469727974255756e8 -1.1999936429779567e8 -1.4340316410497838e8 -7.31507770456587e7 -1.0470613450500171e8 -3.778017880648273e7 -1.3959577036327127e8 -6.380845336765294e7 -2.001216468852272e7 -1.449599129689006e7; -6.49959838794553e6 -5.954117516962867e7 -8.52485600511938e7 -7.51109264939754e7 -1.1943006392184775e8 -1.4946468746322778e8 -7.322482261653201e7 -1.216810995993698e7 -1.532986633139935e8 -1.0579092624906886e8 -1.1451800330024774e8 -7.946639259415717e7 -1.201749777519603e8 -1.3380107796449676e8 -5.839204399545904e7 -4.897109876709561e7 -2.3406099328074306e7 -1.8287968068277467e7 -1.4756529342012954e8 -1.3290019209954241e8 -1.452623212948268e8 -1.204679606833619e8 -1.439631522939533e8 -7.343643066780962e7 -1.0511501113127746e8 -3.792770959866088e7 -1.4014088973033372e8 -6.40576248366009e7 -2.0090311932469524e7 -1.4552597411858507e7; -6.534138637448504e6 -5.985758182264519e7 -8.570157862279025e7 -7.551007261057