From 1ae201a087ebcf7c80d5c1dbe736a64e0c11a341 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 27 Oct 2021 09:11:13 +0200 Subject: [PATCH] refactor(embedding): level up embed method to top API add docs (#178) * refactor(embedding): level up embed method to top API add docs * refactor(embedding): level up embed method to top API add docs --- docs/basics/embed.png | Bin 0 -> 28271 bytes docs/basics/fit.md | 47 ++++++++++++++++++++++++++-- docs/index.md | 8 ++--- finetuner/__init__.py | 5 +-- finetuner/embedding.py | 2 +- finetuner/labeler/executor.py | 6 ++-- tests/unit/test_embedding.py | 6 ++-- tests/unit/tuner/keras/test_gpu.py | 6 ++-- tests/unit/tuner/paddle/test_gpu.py | 6 ++-- tests/unit/tuner/torch/test_gpu.py | 6 ++-- 10 files changed, 67 insertions(+), 25 deletions(-) create mode 100644 docs/basics/embed.png diff --git a/docs/basics/embed.png b/docs/basics/embed.png new file mode 100644 index 0000000000000000000000000000000000000000..7ebf91fecee27415e3a5af9d5988a090c4ed32c2 GIT binary patch literal 28271 zcmdSAWmuHm_clB;v&_s}IGIfV2O z^X|Ft-#?D$c|W{gUynm^IIg*_wfA1@T<1F1y54K5E0SHOxekFq$dr|yYC#~lA>eNu zF(LTPlQD-X@IxH->?KUw#RleW;cg94vw*oexxk$4U)}ezcK5J%feP{n@(6L=w}Zi4 zJtTN}o&Wm+9v638-Wxjo4d5o%T$KzwAP@=*>|dNBnL>LA#CcBnsho~a8frfHv!!?7 zA|@%wM48E6NXE++X}MdLB<$$#{A%yj?Uqi*TUpEchCvN7w_atv(bBT>yOlCF7H+mC zOa366Iq$;G1GiXytV>)oQPzqR(f(eN%$I6^v4zZE@^lAQeF{*B?Jmc#t$nB@+}t8d^2 zm%+Dh-vq9F(een|4*{*;eJf8@%yJa8@n=lp%m+87axPO?=>wEpvMcP^UcutW~0A; zxwd86&ZY=*3P-B40RvW4+W~9%z4B}A0!X0o9y}}A(6<# zLGKpr+Qp#8j3D$z#_r_S+%3|;723sN(U$!vbg&uLq_dxjU{_%*k2mokO*=`!XGupZ zGFK~KbvPPe8>5ItgT(xyz#R|2KBI*k|DF26tApB>Be>tfK}I2O@L6BO3?8#zqE3^@X$cSF!cJcjqjGEQQ>{0~ng`xi@wU%a{*Jq9yEu!@|&Zx-X{zGyyx>T&rdiRVz$s*<&&-F-k7qh^V-%+RKxL%WBI9fP5y>LsH`mJui2H# zy#H>qM9@Xuul+hx^k~v%gv>=Y^9G68>Fct_9f{6Y{8-PL)Yg#l?7R$GlUy7rF}rlg zAG23eRb3yH+#Sc>wjlWYJzK``?BbP|nwlDWy80IVu1^F4{Gr#7&YaxET9M4Y(4#mHH9)tL0D9@P``OoS0WG%>svejPQ<7K+z)e+iK)7(B4Y571$Ndh z_-!F(TmA@_kEodbjY9{EA|Y3+W>?miAKN4s_M_1iC@wOe)8&Yx+cNvNQ@(r|p3pag zn_-U3Tzs7_{rYRr9qLzJ#X6%A(uWVg=6HTk$d&Nf%eoq9Jrn)R0>OI12EL*<{7^;jF+02Yq_vrr9%kf=tX==A?ehU&`ifYu>T0E`>`k3hC1c84&!Cc z;%5^l9CsSFc$}fo&B-9tBy7^edqc$x9Vh)aHfB6Hcu4Bx5lz%)wB*jP)X7{fhyz9J zDR?N4`Y#+Xfbhs1OI^IM2|TXKoZ!X$<;8ketRjQ`*%Y+p%my6$#(1&zEQ1Rl)i{U* zJnJe0nPB^upedbz2;m1qPVak={A zvKdYb9WKAd5={{R4)i%-~9NAM{H zS`o9(SRay<0p>UvwMYsIF>i+7^pmne$ysA5;^ahjpSDf-l_GkVY<^%oCWC zrZuS{-@_3_qlJ^elle6z_me=rk=AUBAJ(%X|Tq09gfvBDSxloz!`)I2D9684qe+@?%<9o ziDhP^iIT7#Oe@B0VEC6{{rXy%=~x*`29gm~M9Z)~xQ=TBGm$ zmg(AD&^hV|6Ev2(@VjY4&Bgsm_NRlR^5A(lx^D46bTu$}X0k=Ebw1sDVFSB8M;NF# zH}*f#%xkR|Nf<4RL9vcIIi1*2)7gN_!V%TU-!GDvvsmQ4O-h*V61ZhFdv&=cgYi|2 zJ3zSv`%kXm4@(Jzv_VSC$~MCoFS3G;^n-QJPj+4WXLUI8Q=tvYt$QS`S>*w%oak*^ z8P2}S!*J22*Rt#&f!ot zG%y$`)-L=9?jy+QXS^~cle`-DKj=5TnvHXFral~S398a69;cMSjKl;nE2FPYTFbi` zgN-qVnOE9yU>Ha!0Pf@do;8_*udM@bt#tOEUGy9BU#LR0?$)i4gR78&ZYWjY%aV&{ z(LBa=!xJlwE76iL^r8Lgfv@tY+2w@UMooDj-1M~l-X_5K#5@N*z8l2r*1Y}(`X}>` zl2tdcV7!HKrV59)ULKr_5o<`8$c=Qi8 zY5w(&1RZ)oreX%;bSZzmIN0UUq?z07KRXAGfn}q zEQwj_>U1S8uW`4faW&9ZfLcvMqxQn6W)WUq?&$7biE_I_mtz`sfD_X{_R(ZhPaAeL zJ2$+X8@f!(0MS!(>Fqt|7Qxjw`VRmteGfA22~+9_#igf{s5e$AxnSeSV=DQCvF#+>FD^dDF+@z!O6jIE} z4FKl|-7)aW(1hpt8$s|)IN27SOb=p{rQ}YXycqzZ zqbQevT!!l4RZg=JN7I9pM9fYnZ&1aniZ9q_%14mfL!s*3(ez{3bq)m`S8*gf2VmWj zkHv^a26;(;Une*W{YCO@7U4Xpcwe|maEWM`5kEvWGhHUw6ydyuWkw5_s}r;Rb0)>; zb+gNo;Czi>^q8TcVchs*VveM~N-oExZX<14`2%g4Bkj1njtB~sI0pai3becNfMtx7 z9wv}`OPl1rQF2Dc0#Zy%yHDu2ZzNZdML9k`-VazlU3>0an-(~stH05_8(6*ol9~eu zERaM?q>n3NYRDjtec9L?16v<{Fz?n)ANMu+yH;@?mg)dA9Ht?oJ#Hx7Rz?p&#TbSO zESNX{qAaIcx5W8IFASD4A8;^$-qdgP>%>4CHdyH;U`qdN-q6*wOguBaq?c;k;o{Y7 z@T?bOp++VW*j#Y0dHtEZ_#X_~3{w(x;vKAyO{qs9c%Ntq%VR0PQ1I2kY?oWW<(Anb zQ_SO>$>tr;xI92&wtn}-qnj7crP3w6p|h7FHh$3zQeRbBtDMhvr)@JQu$;9;I`KA} z`uh4vE;dC(p)0*|6A}_Kc0^;9jBmE-aZp+dBso|K6Bt#S2q=ZbUB}Gy`1%%?lUbg za%t8DJf+d|_s^TyC_<0PT#lj1DKV#v7|YAv@{11(r(L3(SkDB6?GQ_{fr~pz3YHsv zH`l~Qa|yOqXxJ!7J&g8llGis2tQpLdDR&7vNCgqOZU1=S{cX0xwcx9@f4KbHw*>U! z?oTJpf}26y(=@MgyNqzx&gEDM6a%&+GYQgdAr>>{j-;`m8@O%BI5wss)66|A38+iU z-do;vKvE}!CymFBrLl)yT+AK_GNV36g4V;>VZZ=yf&1reb&ExIH;v6M9E)6zFkZ$2 zwtlh>7)rs_!s59<;P!qXyB{eAA1b1&WX=VfL3&>}050628wQ+h)}N@}iJB{B+PI+A@-ylHv}K z!~@{7WGlL)wE-ov2sS<~)um;cpR}X?3cZysa~dXNCr)Uy_9a%(c2F!X4U1Q^L02HT zjm7?+15_P+08u^Ui_Px+gM$(ycH`m!r{6pPn-j;ZVJ8-LP!NbU#fRy`DRW zr6xX)ca+?*U|y)hAk6%3y}MSGdFeE^=3eL|vi0W41Ga?c z;)&K1!B<4VdqjX20E#V`4me7vE#$C8Z9TClMQD;0Og}E}vHs4n733 zM~y2BD&WiqkpJI*xE6zkKq52jrv*3>k(!zcUv}R@mOSx8Lb_3R=isYq+&s=-%?&@$ zTF|Fz*t1%|JQ*P&`?p-_x~Y%kAXJg_uO-JSJ<)Rwf1jyn-wqc+sIt7WnD?`5U|9FO zbsrbfMPjy!6C=QF^id_X_a3)e7wPK9#Jfo%)`rI(wZ&5$UnZ7@h>3)SEsD$y1vtIB zPjDbqGT`Zo;OWY7H?MO*-_%rC_f|VOJ0$}nsI$%g*CjM_CS_ZVULx5~xk{ZX8ZP{u zq;<`%=26lR$jW;f*|rYDiCxsiS#01nmYik3cty)G#WJfGc*GC}c?hx3+#vbFJdI2W zK~#@jT@Yj3Ix%4ge>W9XTP_V=tJPtBl3!Iri_=Ky7^NcIQ+IB0pX>&wt4!b zLHeS8B5r@M>r)pbw(CZn0k=n>ebpl-I<26pN6J=Z%bZxBt?7kTms}HSAL4#rL$WtB zL5mBu)^obFV=3%T64B+*zZ3Q5@Z#(9M;8I>D?ejO6o-xUpNKr0&F@oKG^VelwfB`dRT1}A3A>Y8_L{;RBnCVO`6N8o7%*O*H2 zHooY&KU-Dno4}bExl|q~d}UeXM0QB4|9A4EausDK2?L5Q=W;T7Rb`qi*qytV2? zPM5xv=A$3k6$47X4^pq4UMtsvc*4%fyX%kxH6898e5O8yL?I6xOH*eBxWNMwqGJ^2 z=^!rP>_Q%tL_J7WR;Z3E;lT)mA}zpAvM7A5n@D>!91JozkUi%TyP19@4yM(a}Ty<&nmpqQVylfrLbZ zM;;wT@Ype#A)Ow64Jdx7jL+e3rRT(LOEUA*=&RVXx41Y79F)lkLPMN{uS>2&hHs8l z9<6tliu|=6Q0vk`&|8xvijYhM=2j3zSEq>9#Bufz6M=v2gRejD&=AwxJxxfw?Cf_Xk*WW|9P~3@K46)G2CD^V^CjNKHF$@j zI%RgAF#oa21Wp41_;;9OY9hQdUk#u+t7SL(==z zH`5=B6y~1CbGyhy_NN(C=<@N32E7SON#Xm)K74Oh0G{AW`ii#EE}F5;QC>ykm% zOKg&>U)%STMv6h^ti2voku9@D0`_y5U_XsbcV{^@;3-BFYK;%Sm&a?B`nAVa3I}6 zwhKZ^Dqa*IR9ot*H&U$KeHitrA$79c_g@a&J^Invq~}q!a2g&Zut$L_3vmrKlr?|a&S6N&t=dVcBU`1MTO!+Ca!cXm zhR;*rBF@Q)JT@PmcLPUBx5mufb>(*_%y!n=n+RppRTge8DWb?sxIa8n2T--Ezwwa5X(IX z4+Sk42`fWHi@%MojE{}{vaGC@EafaRhCXc+rD${Z?fBgb@=7o5LV>0zZ%)YiCxeOv zZbO0BIBIM;De6=o-n{O;+Xr*;`&mSzFo!HAbzy?whmcceXR%sZBqMT<0+hFWwO2pD zdtScTQKO*N?f7YmH)QC{fx{#rXK@Vz`$!`tfOoun zLIQR#NhD$>dzQ*agh$HzmQN)t5{iXu6B?a+dvynfHv~SIBgWRg?N2dBmD=1~!8Vt^ z+2%T>?WN4)4yDWen$UKWs%R<(!L%{;Kf(0*ZHom>A*q-?A;f$rWun~rqs>vr@G9e* zj&bFXVJO9_jX!aw$pGiA>&Mlw$OY8nq?JS*HsG<_)25q~7debLPfBFZ8|puUL_DQy zrR0^-2e_Jq2YV|wpPBK@_yCDL`+a>Id;?rUK6n4=`E31_Y8^qIE^a&l)l0*NF5WV6 zgkEYfzl1mS1_tH$s6Odd2*Tyz~^(cfih<~~xgCC-|EdiB{%CC)Z|S=4m{x)2rFk-6 zJIGas9S5{28r_~&BI2KxB5mIOX^Nxt99x<4-nQg00}79sap4M^7@>7qReT{3geNhJjtXInwqbQ zTl!wLjMP7Dx*zrktKp;v-JCh2Y5v0Qz)Jn1eb4q((|>_{kogX=l2QJL>#LPH%USL$ zDj)+{36<6~75|5eANrXpu4$Dy_V2Gz<1bzBdHIqpW&jIsfH9#mgi5FHWxE3v(dW-x6=KMkBe?3p$6>xmqK039XcIZUI{U{%2^!i!w33lCki_f zG>{e#Sg_VX@k^;(OTFzDm+uo?w6~*lYCToZK5P_#lC(_=OS$uxZ0jix*<2TZe<+(z zDkbCz?EP?D!s=4%SKS<)Gp$1CW~=9Z!_p0HFgrEiO=Tz2?^2S z?xae;SNW}dx|gLrvv3CIa{trg6gj7agzdGU_OK%7?Sf>hHOKKGnoLn(YFOdqR^vAv z$=52^Bc7D;3NAb0x~2ya*^shBdG0H_GZQ4d3-=s=Jt;65;ht1I%6H3Iq@PS5&qd|S zBa*iAO!O+v<{QGcjZe;#vP>s3LzU9=~3oi;CNp1D3N=+Z)B$ya|)y zMxD5)3-vv>xo*9x=BUvT2m!ZQi?Zjp0B@n2*Vv7?90l%Rw~UEppTfdpi+pH#?p|0PzLTpJA{#s;`k*&dUT% zP6Xa~89z(m7$hLFDL*GDit4Y>mXZ8SG|l2@hyZSlqnotjFOD>2EJ+<=^y>zkGE#}V zq$-1x!$?S%OCN@6d9xOdYVc$)%`~Wy^Xd_i8g%vUp5C%#IaoK@@qdOJLLubkHK?Y( ze&)fRobX)g^x)vzyoFuB72AJNLIG}XVZFl@G5_WmH{a-C+65DcIlh;d*Mg6(Vb%M0 zWvv|W`o6FzezBEGHkeFw)#N!<(i(B|GAV8wF}ODp(VVK@1sgd0@scRG#!$G{_I-*J z#;qeTz2;JE&~(?*ID_dhob`tJ8$uJY=n+SK+1(oS*O`@RuFsU$h&($&mFbllYu{TN zXJCTp*n~r--a~Fs#lGkd>~Qt`yS$M=uJW&V0TO!e9aZ?Msv65fKNQJbs^u_Dz zQJR{uu9}YG+8*)zlzj-B+b;|fo@>O^r{;IhI44yR!2hU$N+2aQ66ISTx2oZje6sL1 zgdb~GqW#9Jf+AQzz*>tG;N8s>X=98*C#5@qOKA+vUp${7tv+Aq$+{Sqouxj1TAjwt zaf{SWfI*J~3cLK$NG+q)=*;Tkz~o{N*cCdlIEvdtgGmLroQ#j3d~#C{@g1Ky(Z%4g zuf;{=s-6CR{R5SS1Y>aq_O~FWEzY3C>U4wMr;m4*eEzh98#4d}+-66nv4HV$Raw51 z@>}`6s;^Lp;vI5ADU%cA zdkbjTpTv}!X6ZkM$1pehD(mBF{$cn$LZ`Usp%J)e0c)g`HRB1J<@2tGaUwMc-O=yU z*CuUU73nW5oA?L`sj>m&V2{bLEJd!w8SfZ*JEDu@_Dfp(szlKQ?$1Y9f#UI}GF#|H z3JzVt3{_&l$V{Rz?b=$4uIw1su$n^0#HxX!maG(!oMdKqg7J$AB_&?z$19tYzk!0( zaZdoyxRl6euz&|@9oJ{vcPB_sL-ZR$X_Kl3D&Yso?Jvihec3I4b(XEzB>(6lj5W+M z;uKU;IJruhP7W-!zBPa0C&X=Zl4&w_`<^J}HG(4@mU6iJGhP`H(MpdY9mvT_DTHaf zE@0*h4qZ2N$!fb^6nzHkKMGASN>n+|Pi+vd-C(`JsE{T3b7nv;VB;qci|WA<@D&}t z$v$c-W=j(FoL#1-US2jEBduqo*a^0UJnmPSZ;P6f7DJ@+GeK)SWcTa5=*w5tu!Z^g zV5!%*uypCZL|9#TV_<3eSEL)kCdSN}KH$!g3Ah8}$|XCVkL=1{_CA&l{++)3sOdhX zM@1b~*9J7Y6h=Lp?x7MrjOhKjhXYTaD5#_osC7NGyt#h-_oZY)WVq+Nais)>TM2WY z+oor1$Df4CyE02o5T4<>fDhkkPgX|_@82ZFKU^B4MK7O4n!0*6n4az)n-|&ahWeN< zbx-kzU5r}136pU1QP@{fn$F1we4zW@C{sOolMeEi=gK}0sJ4X4`tr3_mR7#;*nGTB z5O$p?FY1Z!c4CFU!62NVEx!_erlXo)Iq#KpqjpN1i`hA8??*wt%lDeGhs|)Tb8P9OB>vO5Ne}4C@n)fDrSBGCDCO^zKSkf~^nbs?Q z2s%LD{1^#EZKp8k=H9mEUI`TsfdqV4VKQ?NTnzFFl1DJj10aq@ky`I8-tY2TxETQ2Hk-Dd4m zwAZx_%jS6~j#2FZea8&7%^YnypAVp-)c@uWTZEuT zgHD6HMpl?Ad_oz6J?-tMIakBdXD zCR7}6SZx5Wks8_GX|o~)D$2iqmev}DQ!^)b+)i?P3dam$J5p4p{{8wV9Q-AFNx=t6 z?kvhUD3zzFHcm|ld8O7!hbG1HgM;Il-(&VRE>vKnXeJfdNP%Pv8T!gi%P^fZoUC%z z=ijJsCyF_NS2yC3C{Glb`^#O^=tVNdg_?Z^z?CB;7A$#HO8tv#w`{IW41^-TIgs-Q z0YS%SbWnAnrs^6CZ*a)kbG@-hwKA*vf{ebXUc>PyZxO&+An*ygmsDd{dZ1^l^lq>1-}H^<&r6l zQr}iaDrts>2nkz^SA_|8(7;D~>tgn(#;MZP6{{dq^whGD46VKlpa&hZkOD@22jpLi z+^phj@}Oe06+u+T@;b8GKcW`}*zaNDH9(g9)09*|zI#+`Sf=R}#OFD;|J^A3qaXQd z)UjF7ym`7Dvwk30V|bk%B%pR8Hzy1S$I*JqKo&QqmW`CSJq?$$U7wSHQBx@bdgU{y z+CIGhwz~3}is0?=TO;38`;enHa$2%L3BQh?2iJP`*1g_kMfGBVUt*m!-6#bH;RU+I zl5&sTOTqUJcfX44`3l^P&GYkk*7Fspg6KH8zpp%=6yi9eyxfKM{v-hYzqO|A&V#|= z2g|x(Rc_-dUmw#Eu1E3XC$a+{Js9dAajw$b+iH^KO1zAN@_ zkM_Uf1aSO!iFeFqY1o(G3F^K29N|s*trDB4JO0jPQ{sPqI1XGJa5(dd?zQ)pBtW-5 zp55TbS?~90k;Iv*`ad*YZ1@z^;q;q<2#R)H~nuQmC2S&X*7xHF2jKOT&G#X8rcZxjrEoA;&pq378+fkJHQ?6So7sBcD20{(3`jP1!Tn8`4td%4F`3&ZfR9GIf~o_DmUzLQ~h-Rvs&tBWHyZr z^QJ^L-u^-SR{*E((L_*4lj@0*v31yyuZw)+J>t#5-rGnJtVET0_$Qmy%T`qK&HmJ(9-ta`H2ann z!E0r$%kjK3DmpAzF1mJO+(V74pgsFjG_R@%+>CDZ#$LLgvz6P*J0{Y*YWG7(gplIN zl>W2dmQ_+&q#8btEKldz+#>W8w{xJvd+iUuxs@O7@~Pu<Zp_e3E~Lu zaPrroC9l5c%ZJquKv3zef%N2-w_0SFMc&1^&-bK8qVTCuDWgd*J00;OD&n7rd>^hn z#zvdx@fGi$TQ`~xtC^5rSWV$5uW0#c;us*{z(7G|`Hi}4oyYB46d7&0q>-BNJ+xL! zUO)vGdzKYfPGC<_Z^GuSJ%EXh5~Mi1v*PN7#%_1W-ojRN*}3ypfKtAX4zmU6QK1M~ zVXNyT)yyi6faAqi_`x&NHl3lSN-8+@Ej+eTwhqnaSO}I--WTgqsdAIxtj~cq_&p!H z)^SuWlSPXk()MnEnRGtLN-x6pwnSq}x$V_)VUFMW6B{F(0D;Ygr0tEWE%eMu=7>Ue zX|c~!%--9{b!5j_t~6JcCA@1xRfDQ52xYO?G({(WkB{$Flna+u4bQ*bzn==`Fh?od z%`waxKm1R>-@OB;(P7?}Uw#6iRGb}Dg<<=@CicYw@^KW3CTLa4tfB&lp`u8V0+H+1 zK)p?fQ~=@^xAWbSzVG-UQu|t67HJH9P=8drnb)k5Z(12Dy{0v^9^o$)fmQNDI>8zImeDc>qH+{N!N5a}jCMS=3l>q#iBXhhJ~ zAB^Geg?jnBG&l#3%&W4IU3T0v&7IE6&L7CJ;l-SYSz$nhbG7@c^6ETr?Qn5NvZU46 zm|0;yCb zJm!EZfQSSAb4qf2W&$8_tEoMG0tn)^QpUjxRA}R@mcGWx#_P*~P$;VZwdf9Zf@(3K(2p4_rPncPrP&jvc}W8hVY}14)%>X?)j3fHLWK*b1WfM z$(voMcVKzIpiX#$s$xb@e+7BAMF1K;Diq0D zXO?2O%X!5?K34-YFQh4R!+f4LNJU7Rd4LGgJ1EG>{sc5s{U#z&|6p~~pjJL^BkMjr z<~ww$C$42$Z7hTdq2G5RT{Q)lN01wbg8tuL$)TR<`n-B(E!3?lTPPUa$LE$CZ?gK_}qs=!?2ko%;f6 zCCZ)W-y7Eu$Y5S;)@JBTo5V+WP?5rESVpCeN?frknDv<229s{ojjwi(x_EgehYEiy z%?bodnz4;2)l@TYyiU$R)opF#P<>n#fR&Av>bDBCUG_5XG-U!rWXu?P1OzVDn<$PD z+8^(dyUuqe?-f8>L*`++hJ{0Ldlr4NE6@j`Rnq*OXeeJxrzkm})}dEk-{tZ0{8wcf zw}Y_Rs8jp!oK=(GWkghBd99ldBhc+ba5hA?CSpkUpa(3H&5K><%1=A|eosf@Mhgkv zgnNs6(X^3$6>B-ZfZkPoEB}Dr?8|n&O=Uqm#ZH&(qr2oA`9M=vBhkjYb~UHvxSqJO z{#5k4e&eJu>SJG5V(}5THjsFpT<4&nSu{zdsmBoGn3Engkc8hP-?85)nqE(s3Fv_; zSPCr+%_J6^%QozP-TdpaQYK7KKNfvc5S%M+Xvpuw!y7D8Cl{u3kFY7FB^YV~-nHOe ze^oGbp6->!e`6*ylHKg<`m${m=$H;TkO^W-7}xH>x>Yg4>UFVGwxN%VDD{{-q6mn$!ts2B$8;2pXO>en;ZT`q3x|tew@3-G zB<{*e=S<6Bv|Hsj{AYJ|U2|2xwPIzsF@6@>m8R$Gixhnfaein4-G0x@SgUs>JNNlHBtXMcFtyluGAu8YP^3KVw|GM4a39V~R4(^KP>IUt#l|`FDvy zqmaSyKA5Bc!uZtBMwfYax{hkW)P*Vf$2+*uE8E%E@gT$Ma1~**O4y%KyJKR{KH|Io zNP=j@>iSczrb6h_$K=gQm+8uc4~pT(Q*)2W3D)sJkw`jqA2-u9`a7BM^DttC`M*;k zZtatIAyj;#+c!58>#f2beev$I#Ou0NP6eU*_CoYc-GJELemv8T5()@m>8yc8fi=j~dV=%ch^M)>dQjM+e*hfBia5Zty;Fs+h5iQ!Sr?&m9FfnFwN=U0NDLKLQq zS!Zn7qnY>f@f(jIASJ_=UCYOjV^EdM<{gG#BG6#EEpct^)rJGffIwIKowfGhw`y zyz$$|L{PRIu*FCzur(OaMsRb}j@8uEcqtC*N%Z@zwQdK6Suq!`#iFEc{(hDfz)L`y zzbA6t@Jk#Rp%F?I5~6224GWH(Y`hK;iUqADcVF#nbIoj5&3l3bmy4a&mjm%h)rpLm z*3E7+FI(o_{dv`RJB$S~)M?NWyiE#K%JM3y&jw7dTq=Y2;`h%SNFm%hlGC`APDYN& zajJbUxag32#WYH!9!q~sg;V3#5GHIAJxEW%U6(s5i?s|H$k^$pp3|bx8yDZb6)@_! z5dY8r=ssG(-qN`s=s_bb;)F#gGrGjw&{AK1Qf%t9Cm{_;y_{g0fWCZ3(jN@tSal=6 zz)QGy5!Wjv>4zVG&EMT%4rROmjmkFJaYBY9E+GDO$qMR&{Bt|P_RclJd|BBC%9AJ! z4ztRx%J8!Pt2_6v@DXc`OB4(4OU=vU+z_-uyLTN|77|#SD0(!|op{uitP5JNj<&b` z21v?i?EW}_ZXu=~ni)OFf=$Bioq7)?7DAscRDWjI32Xgdb32Y!I7?ZTo?%lD$z>gZ z6o3XZbW*2>4Lg{Tuf2#nAKx5;x-@wI=V;qM7KaJmPRWe)efa6WTWVX>ZD79PnkCF- zKHX}|Va52j()FU1gSL#8wmrbH*EvCJ(z)+5O38y=epyrVdq{PQE`#BpUe3^`0(X%; zmLAw?rUYZ*Q%PSZ=MDfyNUAT*4HUj${vMjS7&pK&kJ&&R%3!aaN3tf43KQu0~o^8*u5P3iDUAXmZG-!^3METy%imR~# zhnEDra}$Za5Ahk|YYbT?FR3vR%{!~>9eRzDz%CY6_@m<`RA&+51FbD|_Q{Q%juN6arQ}4oanFngvL|ZpYKjK0#BfY~&ehCZw}e~u zN4pc{WM~&RJ`ke+)d_qjRkLQx^$+g>tC&_++ppYwoK~udc#S$e%aXY9yO;f*1@B+S z<+~O3RmKBPR0ff5kYQ;4pEgzlFWLRD-iH$oB-gHH9Os_6_Em#~p(C;2M2E$2FFlYsXGuEy1wDRBrs{3&W?5X(kY+WngG}3^&a+j{C(r2t~vX zk@XiGkq*=?^NF+Z72p1_SvKypudGf!G-Sdf&MLNWYFs^TznUR(lGV!To2lyRwlPEm z?Hx9-TIL$PqJY80KVv@hYXnUC89TjX zpef#9)aX4pLQfnDBi`ADdQWyYO+kWl`aLg6*ZRdSi~8pD7u?yoA8e^yD=dQAV1i;u z-hVjv{>Cem2^==VUr5nZkqIg3mVCg%Du3w?qgT!?(_8Yse{qX*cx z_s)VyPIz)BF>E5GRZc3yg2Gk7BQ0WNDVcCdIjHY1ssFv`AR&kG!B{$B_yFzr zPrCR0@u}b5NYgxCfjsVZF^VJm>WDxs&RHs=20f}}bm{A)B5m$8XyAYcFpP0fHDGnb z-#w8_kN%vY@^cZ>;&DRN9#2o$Wg*0EmMH}CYodP}T8(2@#4O`)Zrfb*zjJVLJo;jm z!(3NZe!4GlM>_;^BcywqG@QrzZcU1)^&5FWV#%z1VdV`b?a+IXWR-U-CvOVmGRR78 z;tYn+4ik{JdC=fJV$T|QWuUL2=CsAvjQY8)S*bSfAaaZ4CW-f*Z#$&NbH*g%KP|U} z*7MKr1OaGhi>HxPWdzigJ&We89@C2kvX>{fl*s$;SH?rh{cuFxx*H*o!{nRf9Mo_d zujqlo&+p!vRLoC9X_6Gd{0Lz1bt9L5?W!aiU)$@3y-F%d&kaid*!3q46paBRqhNwt zwNEOB+Lgq80n%m2X24wQY1nzU)f7A>4Q+HnG=nisR-hVK|71X z=n|VFLT40UuR*GRxj7jTSR0D0L*4%#{b+&Q&zq)p>vnF3J|4t`pW6t~@sPY9klM|e zBQf;1xk=}5e9@91rv~E(39nSoExN4u1ipgRJ%qrVu+tz_j()Gi*?LAs>OY$X_}DO0 zCpIJNjw%CbB$XnGS_ztkjWgH7F{Te7eyU}kj$Ea7$B>+75EKc)hC+=kgCvBu`T zX`{dYMx67U+^hdW<$3a31=f!2+T(&-Mxc_8#?06K{HLtW2ylW<#%$fb`mT$U#bnb7 zfTo9-@N>qy=Wzl?1(ltDe>QNPRL)I;Nsw|$rP5Wk%T>$1228Oo=m99&-1a*nxr;W< z75>@K>UE-aIc-Tm6d^woZy$XK{HU90)KaW|f!w%+F+J%X6;x-w9$9Z-bOX zI8o-pFaXBhuR15pLGCQ&F47I-lQeKS5xMHd-Si{dwMi26B85yz^6d!W^h1r?T}4NJ zf$hO{0ub1b|Cb3fhFQN7Q^9vVFHaatZ90)NiR`m(LZZ~#VmCBpaR3;#OxPW1O-_}c zLrZ0<=zB_a4u4h|p=Sk4{CW;-&P~}i!*~94S@LoFKa$gDV>0!tmtK}JegKgWiRXi@ zRfwmHJr?}SvlK+F*AW6_44ZaMz8k}!kgB_>)nc_V?9}8UY#2uOrfD{Ca&u^~{GOE_ z2bjEaz~p6rKpk=u!`(WHlDaHk&T?!`N4@5nWmq2-RoPxI+z(nKOXwdf3XQ+$-Clj4!w$sM?)tjJS;b$$@j_8^wW8M>!5}DLAVo*7cUMb7hGENtk zjPAEUgs(<7p%(^h z&0J5Q1{+!PZ{I4sYfthQ)?|)7l-yVNpLWO*;Eme!#2Fww*=-$Ir!ga(3VcO&V0?0m zfLpcOlV6+e=VEBE>nemB_~Cr8ml|jQKEJ-+P@ktBX}7*_N?OSe*%A8L(51fP|4GiS z++Z-sW|$@>+uYD^|G?>S0D%z~g?dK9Jw;CW-P6opP?v89pN#byH|E*Qj$3VRev? z$*ijbjp;|lhCdl%-tbJi*0Z_JINN)Fio%iQ3?qHwe~phef_YLP#J?vC^*CJ46Gt>C zv@s}13s8?I*;zpsGX`B-<{E$nuzr<~=>sE_&oy)wtptr&tUo4};1;SUdYv>_(<~{> z6MhM!iktHxRS*;V#Cp6_L)27!&9HjS1$kgK9k?qxJkc-&%b}V)2WfOyKd%L z%S*P`r*7h$Q&OSU)!7DN48XPwoEFuQ*Du(3NA%jk#+pn z5CC}FrX371F*F*g)^*g8DIfRV6RXo`bEC5zC&H>>mH&KiDWigC6n1p|e~-|Bxn;7WT|6Ihv(}KC z4vDdGd8zkCwPZPr5G@!isZ1Z(J9uRSCi;&6Q3UVYeCAFRrv7+17dbW*-D|qW8STfj zQdAZjwDaRZ#T)W%%5oa>H-^@EcZx*PnZY|rZ2;E6J5$eiDAMPXo89DX>@`7T0thI< zz5#L`JDq`-XgF`h-7#Z7Z&ej(V9947De7i_2K(3>;dOf^*_xepcJx<3OmJ(nN02Z{t$!jR<;PoZ0^M%Mj)&h6A* z2kMpurL&HCnbYk-s2|^v;X8`yYd+?rq?a?h}1h`nB86}+Ia zah_oqZ%}4u#_?ymg0Ew|r-)wE8)vr(QI@W@UOaK}vx`*ff4vDXkuC;mGA~a?kR!oJ z2eR)Wcq!>WJoi}h-iNErZt*VHJqMidJY5lMTLljNN*H<$O<60*DVM}no;W;MV*Zq5*=l9S3@BRAy<27D5XXc#q{CkCx%E?F;T~r&U#;xfS0=rak1{`!p=Yl` zlMD;D3d<}Wah2C_t!2LzvArAe@WNRrlFG*9Eh;2N!p+{bt_q*q>5y-XySBR!Il;~0 zUBIZ4QCr@feCnt4md8H1?C<7vys`4)*UE1f?YC#-3)aa$HiM22H;kVL*Hx!!`5Qnp zY`z~mmg2gFx-FgV(yIS=$TxxTDLtJtAzkh}PhtXZ6&09tIF;*HhzKwg>EuPC1E&L@ z-Bx(KW-#{P%U}+2uEp_W8g-`YTc>boo`LSEJ#9||9!$2qvs+QVDudHf2-#s;D@-Q; z7_q)~JY6phxRH*V$k=FYQxo4Tn3%l2k!ceQ=ih_Yn@fJ7FWn7;4vmW>d4M8dZ$#*J z&}3cv4%=kD;IHYTu*ZPgYD>u@=8fNYp(svZ&~Xqoi~`QqnBCRrf~_QIZ8* zddUpF-bt?b81cz3Y$hhL5B1z_)_0B5G09ZV()UkI7wb49Wkq%e5KmB5CFvWP7xAQd z>`1oUZNpE%x3m@JKk%)T{%nOJYbp38D|c&da4Tu1?3X}=-a!?W+&Sco7jkx0b<0j<2u1_E%LVg24U1C@o-GmV=mR%Si-qjO)cjC(*kv|AhgL_%x&8h2 z17LPg>CDBIw;;{9w8n^Qn2kJym1m8&LeNgKy_u`UQZbQOclt+}8D^{)40XUD*2ju9 z-t}W6tN%FA5t(y1!FrX0E=6O0#;de|O*Jv#$)z?hwL>`vqd?%s`KM2FwPQ7k*PU?`##>3 z&tW&kE=!n67Wyh6NlD=xGksuxQ@#OJkJ#7|1@I-eJ>qIELiLd6{A_Hz&6#Y{7tAbN zE7}2VN&MS@WdK7II^LH}Wyy?>%UrYWP;<{#y%gmkoIN5{Z0%B?@semMd!qQ7a_kk^ zJyG1O6QlPQTkvB3paTK9wUdQDD%khy*L3icgOJlU_LS`z{YtjJCu+|6qy|jep6mYxwOb@J zUQd;tObGZlO~*$$4ilc2Ob*S;M#mM~CI}_e<#S9m$!0YbvdVa!#IyT=yK@pYisI&k zuWe5KG0&T}9t&yscQ}dUCcBnc_eKp%Z z_;^7A>lyQE=f(U78NqjgE@!_EICidmw|XtaL4(F-Qvoq49imv*y46H8^=Ae3rZXpTZ@Y(9Cc`38q_r`5yyr`k zrv{nh2@%2HP#0L<9yO0$V5;HH-G|t?b{xLE+#B9TTJZDwSleV_0j6@WT}YF78;@L1 zd@N`>aNJS+n!v4^y?owuweBV33Ji!WtUy$6ltHvqstUcxFTU4AI5mXhv&RL}%|@$; zo(|Q6DK1qVomU^M-6Y0IOxAMhIInrc?NW6sU3g4~9G0#}?U6VcB6s2ui6(UQEcc~< zWv&uxVF9O?Rb=1EC%*bl5t1AG9(}mU#?DN45Lr~q?(qtI4>NNbhD%pY<|&QY*K-hG9{P$+%DDQjxptAA?OlRX{abd`gbjj7He`w2a9Slq3ZVF%*<5V~^P3FWWt`Gn#IoAG}lw!2llL0}F55 z2ZFa`<<8Bu>bV!4v>Gy6v{+ZD?PRif^}idVS?dT5>Xl00q08W{9@FaL4RP%$L*ydN zO8(&=^E@uBOpmv^ociA6Oy2H%p`z%Ca_K3VRlNiPob478?=fJ?AUP@!PXZRB9zRsS z;UwkTi1#%6Rh@K;F?lPFqChiff_c2VPkOFJCBhtHbhJf7DV_wz7&^_?i!HqCoy+Hb zF`k|VaZcuP=n1@?DKxRi_X{<}p_Nunpwp6+<0F2_e(kKIi*#t91x@$fu!$?dM4UL`O4}n` zg+x=05w`_fgj^?YP~LcbmlMejR~lI%6hl^Aew1+AEh2HE?Nm`TVUxxA%WE#Y&iE>`Hk3@z{k(UeRaKtv|n>~^jBm5G4C)m_&%1p%> zPDSt2DN9^59|aex;i-dQsRT2xVxENV+$AdMriCEtC-^b$1*>U3rWrKm+9hoWd?Wpz zkiP1xxvM&(22$}H)mC>yUKlUC#%$MTs3>mia8}g-s3qr`&B2};8+PG{LOn03Bu4ZY zdRjRslT+pd*ovGKtsNA(W8vyETkImgKqo39KHX##k|h2GlUL9>@r_06V`XfpMQ&F{ zTrtEffO1?|Zg>1IjjOw#mNQCrTGwzZ8`L=~yIu0IRd_6-{A~Ysr#J!Cu^MT_rSLAL zu1)h4ib_IugX;1qh zmP%3Dv0~r*v5db7!5X}EoBg)s+D&)){yrqv;pbT(T2A5{$~`YL@ZH8WBPX+0daxnb zPC|dFW8ub0^L4X2hi9A$2+?D`0FmGl5xQ$Q);=@FKYlNpp%U|n6k%f_uw~{{77kpf zc)YjLv6xY9M76O<)=2VOnhdsStmwIYiWx58DB0EOJ6u=LHPaFdwgHHoEs7zjsE)Md zJ=QEdQzF|pJ4ZJk?=JY)+^4GHje_s~sBh^YL4V+g6#TsuNaOMC zfCcF^n~#M?H;7+HP;TXgWvT>f?Z_Za`*%$sCJ%&qi^j|@!i!-^6fI4|s7Sc}v6MRU z+|WZ9boMNL7g)#vN>ym)0N#U3CI3}~_|@q5PI|*MQ1@2!lvc&aU3{_;5NxIpG+sXjiJ+=cB- z8OPMB?Uf8t^YsumvbZD)R$UEG=$~lI*xEbHNZLwGZ~O*eKFuyJj*95Uuca ziE^3j#Bc*=;oGYXp)X1cGDF%~_did#BK?Qq5x#?((jw+7OL-J`u{wvPLj_L5MbIvN z7rYl|QGrcVG~Ivy+O<`LS@a)A?{8@HmnZTB7DhXj1>HzRz(b zYW&c&Q2@R_&LH#T=jW$nu~BxSkdu*qZ;{t5r|St7v)3s3j!ucI88b>8wFMZ2>>2*l zsl7qj{}=xZvqb&=0`qn2p4l?JHZqC{7I`38^cE&nU1vs8pjp?O>io*3edyBf~xKOi-Rpp%$RO@F&Hx#Tboxg@Es%(YU)85W%Hs5ga> zXJOTqbK)KK+s7s=-K;L|F8W$j?V%YLZrREj;c4NJUXf-SOLpT`{8N+;G*#v z)5eM0%08A7MZWc3q{oG08&Ennnb&s@h&s3LEtNi^#LWRlDuZ%|V>~-8pFX=0qRV8Q zjFurkqvDt{Wa`4Q^C^LYCH56mrq1J%&Yi79vzm9y@W!Lj-t2Vhjc!W%i%~P4H}mhA zydU}8DtC4FT5H<#4|jr~bUlf$$L$>UR+$5DO^nNqhd@D_;)cVMs64yTlJ!V-^t7iq zF#tAb(c;PX5P2Z~tvhL;_*;)#X$mHg=oVK@QP@n|R4*n)dCY85>@MS(Qug-=NnQ#p zQTJ=al-yFtDnEPy_b;wghceDvF?a4e_S6FuheoP3=3>Q*d zZjl>Fx`p5W#mOo)qSf`8;YaWK0!(A_v&~92D{VJ!ovC=Aw%+jY$lFcJnbDxky=ZK8 z4!?R(VRzdD+wjU}ZF7YkroQN}K0}0c)VWjO=!aWGjEB~Hbx0KS-VfTRN$2W!u6Jfm zb{$!0Yq{zZEre^OS`ugC?o&+}jBQNuAlz>gG8+92?W^WTUt1@u5Jiq2Og<%>8WEgn zF9XxzOh6w7&FLUOQ%YQQZedlnMpkT zBeAbqlq#MGzqSsTn~>Z8Tys3D&Zm7gphC9#VzojuW#u!6g9tY@&O*8T2~K6@08VL% z2p$wrJl3+08p_6pyHe;M3vVq9t>qWW$bd^H2`$xe!iMa1(mj5ip+=2}-Cf(*6yBd^ zuBoTGE)e_#>~~9~iuf07#t{BtCbT(wq55>QdJ$d&J^XOm39jX3OAa?kLn)AydlhB3 zE<&6&z|GL3)1#XPG1SP2ifBbjMjsb7WjCzLw4Krm+w(dT1|?nIB>toSZR+}y zM-NbQebuYGLtj{uoXe~x6Ej)UzCOkCOA=h^2u7)=nl0n}E8Cqzo?UC+tdjgkw zjg|XmnRQN+>Dzl$Y}^B3H}OYwGd`K&ua-;1V_hLD2v%{dOPuHTv!?Rqv0dfo^;tx6 zT@dGvBBs87Fk0C7On-JRNM?O6%|Feu`QeAqZJSuS={N8E2n*tY6j9#I(O9=htR-Q} zES;8?)KC-Ggms&!$!_y-dF#mB&}ddU6Km8Y6ZD0j=>34Yljtot@U?bf*Sx%h|5dl; zl8C#(;6^vfbeRhMQ9%EAk^Cv=^;4^McH^Oa`M;frMN0UfDHe(6_DHo%+}xg-{&5W3 zd})Cg^Xi5C?inb5>v_%KY%Zt0E?Jrn=WNy)Prv8cp4i{JFaKBe%j8KzApau*mbnmL zDMO3OSp-(eS2(kf5KZA!Xsk7`z%QeQJp!>3vSuxtM zp9i5wUeHp^!nz4c49FraYwVDVH#0KMcoU)Hf>;)$g%b{6DNj~l&iGRfL6mkK@!9RC z!O(mNQa3nu@RZhV#~y($-f@>!`aLiPZ&UE4Y55@P+a3l2dF)(q^!;{q%SfB74C9KYWHc2M~by<};- zUrNWPhEKR^EQ&~r@rz^57V2^To$?deiA{R8{)Z!uhsHFOk^?9BvpE|iJcPF}r z5Pl^~SOl~*!v+qCbaOpLC=vb?6EhDazlRm*i43EMNNURvp5s4cZ8=8U+NF~8p}(|?w-O@l*pfl z9d(RJNhMp;i7(9iWA?Oo-P8pAc_=}+&5%RIL%*3_G zszOI>W_SL4LnTIc6D}*4KW~S*4>*Nb>~|}(SeAM2hNvb$aTfozJB_@1diRdPNqWPS z5=J3*M$b__%i`-Mf9H;Q`J8+O=DSoC7b~m2r4y_GM2nh&m9G@5^kcg+6}I`ZHkZ$& z1-1_Jg?02|@3562gd5piIIEU+G>g#e5(~*r@fXW>A(t8j{Xye7vZo{DZ?}$+Ym}#O zAIqKL|K4E3g85p{97N~TmHZqU19|M@z=6!N>Fd|Q2pE1rAE)n(bbl8H=VP5plj=S^ zN_&4}UMhS#K%X)N(=jnT)>{Ze!$#-vDV>|;g*x*tggN8AQn9;=HCE{sVZb}9^I3w9 zH0|>Q6u>oBbV;BRMkh25z>Yb`L_}YYG(G;)_Oe8gA~xcz$vTX5wcFxkb6LTn4S>Npn??3SFixFO%F7k_IpdO8$BXs-ev$a4 zwX!Icl@S>{7BOTAwS@f%gM5GR<#sowL9y>F(3dtv;_O5WhGl}#v_sT1LB#@w!ky`95U}|EM+}OmCWDKOPtf0v9w~}~9z9kpAuixk zGeoz=N|+{w^TTKv@_;Sfl@RPx$U!dBA+o=uMLLRWYx6-)0DP_gqjAzhQmaZ;!zT?v z%y_4PAW;wTAK+TEu zixDV42(C>t)NMO04c6Bj2;(gicbF!XTM%mU~4%2 z0?p#XtyPHxy`s2=UK6kGo;*1dhZd9xgWB~ZK1*%5NuWx{jm1azw=D{pkJOo*?PUtv z<^V<>b>S;#KxA}4OP-HjdR=jaa6P>7;0`7^AK z-#jmM$ImdxZ>(&Y?)pL@KhtH~1?Q2#6b8?q#1vUKepAQ`8>~ftW|hO}G-Bm|nVP3! zrBdUt5NPavAv)|0-lB}^xgjmM%TfdC8%bGWJT<5QbV)!5d2oG}3ep{xu?WfJIcaxH$ z-b{_^Eg|ot69NjSugmguw4rocFW&IOxW^A@H89`3O}kRm(e|TQeY{s->Xog4x9eLf z_5`}UjkgCj2Iy*b8_6P8!x&PjSDq56zhboh?<=^l(zOp_0|zDs=t_e`b|IBlg^gCpj{D)Q7th`YHt(uw$i`Uuw zW0042;hbx&>Yq;3A^I&;{Johq2j@_Q?LZ$r!n z<#x9g)jEHOq4X|Pcak1dS*wD{owLHKwgcBzd7-!odOB4=muk{p3pX8Hy%s_o?P3|k zrItlrlQPW8g;6fcT3$|suh_8A-=hF4a)Xdqj?xx-q}5q=&wEz6p}b;VhIW2F%Q;t) zUp}0C#l}yQS%Ibm=fN^zA_RABl}|1RLAoLILnBWnmnemlY-DCY&o3Twx+L(61|#i26X~KwCP=rY$Wtq> zK=r+nNTrvP+#R?>UGGa1Og~^}7^6f@T;`Z%A=ekVk(O|`)ze=~`>2h%fS|H~OpvT^ z<(v-_tOKMSEV+CGdhMFE5)>j`BW>{rPt(cwBiixMop}vKtR$Xqh<=gwPsAj<;E`t} uJc6EhH4Ey)K^fX>K~$jK{a^CU)|s@ku?uRV!zbZ=A(|I-&g0Hn1^*w=7i5nB literal 0 HcmV?d00001 diff --git a/docs/basics/fit.md b/docs/basics/fit.md index b1ab4641b..70220f916 100644 --- a/docs/basics/fit.md +++ b/docs/basics/fit.md @@ -33,6 +33,25 @@ Depending on your framework, `display` may require different argument for render More information can be {ref}`found here`. +## Embed documents + +You can use `finetuner.embed()` method to compute the embeddings of a `DocumentArray` or `DocumentArrayMemmap`. + +```python +import finetuner +from jina import DocumentArray + +docs = DocumentArray(...) + +finetuner.embed(docs, model) + +print(docs.embeddings) +``` + +Note that, `model` above must be an {term}`Embedding model`. + + + ## Example ```python @@ -59,9 +78,6 @@ model, summary = finetuner.fit( ) finetuner.display(model, input_size=(100,), input_dtype='long') - -finetuner.save(model, './saved-model') -summary.plot('fit.png') ``` ```console @@ -81,7 +97,32 @@ Green layers can be used as embedding layers, whose name can be used as layer_name in to_embedding_model(...). ``` +```python +finetuner.save(model, './saved-model') +summary.plot('fit.png') +``` + ```{figure} fit-plot.png :align: center :width: 80% +``` + +```python +from jina import DocumentArray +all_q = DocumentArray(generate_qa_match()) +finetuner.embed(all_q, model) +print(all_q.embeddings.shape) +``` + +```console +(481, 32) +``` + +```python +all_q.visualize('embed.png', method='tsne') +``` + +```{figure} embed.png +:align: center +:width: 80% ``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 5cbd20229..196b4b7c4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -136,7 +136,7 @@ Perfect! Now `embed_model` and `train_data` are already provided by you, simply ```python import finetuner -tuned_model, _ = finetuner.fit( +tuned_model, summary = finetuner.fit( embed_model, train_data=train_data ) @@ -159,7 +159,7 @@ emphasize-lines: 6 --- import finetuner -tuned_model, _ = finetuner.fit( +tuned_model, summary = finetuner.fit( embed_model, train_data=unlabeled_data, interactive=True @@ -183,7 +183,7 @@ emphasize-lines: 6, 7 --- import finetuner -tuned_model, _ = finetuner.fit( +tuned_model, summary = finetuner.fit( general_model, train_data=labeled_data, to_embedding_model=True, @@ -208,7 +208,7 @@ emphasize-lines: 6, 7 --- import finetuner -tuned_model, _ = finetuner.fit( +tuned_model, summary = finetuner.fit( general_model, train_data=labeled_data, interactive=True, diff --git a/finetuner/__init__.py b/finetuner/__init__.py index f455ed50d..8e3a74efa 100644 --- a/finetuner/__init__.py +++ b/finetuner/__init__.py @@ -67,7 +67,7 @@ def fit( optimizer: str = 'adam', optimizer_kwargs: Optional[Dict] = None, device: str = 'cpu', -) -> Tuple['AnyDNN', 'Summary']: +) -> Tuple['AnyDNN', None]: ... @@ -91,7 +91,7 @@ def fit( output_dim: Optional[int] = None, freeze: bool = False, device: str = 'cpu', -) -> Tuple['AnyDNN', 'Summary']: +) -> Tuple['AnyDNN', None]: ... @@ -116,3 +116,4 @@ def fit( # level them up to the top-level from .tuner import save from .tailor import display +from .embedding import embed diff --git a/finetuner/embedding.py b/finetuner/embedding.py index 0887fb6e2..bc0282160 100644 --- a/finetuner/embedding.py +++ b/finetuner/embedding.py @@ -5,7 +5,7 @@ from .helper import AnyDNN, get_framework -def set_embeddings( +def embed( docs: Union[DocumentArray, DocumentArrayMemmap], embed_model: AnyDNN, device: str = 'cpu', diff --git a/finetuner/labeler/executor.py b/finetuner/labeler/executor.py index f9291a7b6..c39446a65 100644 --- a/finetuner/labeler/executor.py +++ b/finetuner/labeler/executor.py @@ -4,7 +4,7 @@ from jina import Executor, DocumentArray, requests, DocumentArrayMemmap from jina.helper import cached_property -from ..embedding import set_embeddings +from ..embedding import embed from ..tuner import fit, save @@ -42,8 +42,8 @@ def embed(self, docs: DocumentArray, parameters: Dict, **kwargs): min(len(self._all_data), int(parameters.get('sample_size', 1000))) ) - set_embeddings(docs, self._embed_model) - set_embeddings(_catalog, self._embed_model) + embed(docs, self._embed_model) + embed(_catalog, self._embed_model) docs.match( _catalog, diff --git a/tests/unit/test_embedding.py b/tests/unit/test_embedding.py index 356428dc7..f5dd06ba6 100644 --- a/tests/unit/test_embedding.py +++ b/tests/unit/test_embedding.py @@ -4,7 +4,7 @@ import torch from jina import DocumentArray, DocumentArrayMemmap -from finetuner.embedding import set_embeddings +from finetuner.embedding import embed from finetuner.toydata import generate_fashion_match embed_models = { @@ -41,11 +41,11 @@ def test_set_embeddings(framework, tmpdir): # works for DA embed_model = embed_models[framework]() docs = DocumentArray(generate_fashion_match(num_total=100)) - set_embeddings(docs, embed_model) + embed(docs, embed_model) assert docs.embeddings.shape == (100, 32) # works for DAM dam = DocumentArrayMemmap(tmpdir) dam.extend(generate_fashion_match(num_total=42)) - set_embeddings(dam, embed_model) + embed(dam, embed_model) assert dam.embeddings.shape == (42, 32) diff --git a/tests/unit/tuner/keras/test_gpu.py b/tests/unit/tuner/keras/test_gpu.py index e5e1b166f..aeee56a27 100644 --- a/tests/unit/tuner/keras/test_gpu.py +++ b/tests/unit/tuner/keras/test_gpu.py @@ -3,7 +3,7 @@ from jina import DocumentArray, DocumentArrayMemmap from finetuner.tuner.keras import KerasTuner -from finetuner.embedding import set_embeddings +from finetuner.embedding import embed from finetuner.toydata import generate_fashion_match all_test_losses = [ @@ -47,11 +47,11 @@ def test_set_embeddings_gpu(tmpdir): ] ) docs = DocumentArray(generate_fashion_match(num_total=100)) - set_embeddings(docs, embed_model, 'cuda') + embed(docs, embed_model, 'cuda') assert docs.embeddings.shape == (100, 32) # works for DAM dam = DocumentArrayMemmap(tmpdir) dam.extend(generate_fashion_match(num_total=42)) - set_embeddings(dam, embed_model, 'cuda') + embed(dam, embed_model, 'cuda') assert dam.embeddings.shape == (42, 32) diff --git a/tests/unit/tuner/paddle/test_gpu.py b/tests/unit/tuner/paddle/test_gpu.py index 04499ba26..776b5c928 100644 --- a/tests/unit/tuner/paddle/test_gpu.py +++ b/tests/unit/tuner/paddle/test_gpu.py @@ -2,7 +2,7 @@ import paddle.nn as nn from jina import DocumentArray, DocumentArrayMemmap -from finetuner.embedding import set_embeddings +from finetuner.embedding import embed from finetuner.toydata import generate_fashion_match from finetuner.tuner.paddle import PaddleTuner @@ -45,11 +45,11 @@ def test_set_embeddings_gpu(tmpdir): nn.Linear(in_features=128, out_features=32), ) docs = DocumentArray(generate_fashion_match(num_total=100)) - set_embeddings(docs, embed_model, 'cuda') + embed(docs, embed_model, 'cuda') assert docs.embeddings.shape == (100, 32) # works for DAM dam = DocumentArrayMemmap(tmpdir) dam.extend(generate_fashion_match(num_total=42)) - set_embeddings(dam, embed_model, 'cuda') + embed(dam, embed_model, 'cuda') assert dam.embeddings.shape == (42, 32) diff --git a/tests/unit/tuner/torch/test_gpu.py b/tests/unit/tuner/torch/test_gpu.py index 1f79540d6..31e75a36f 100644 --- a/tests/unit/tuner/torch/test_gpu.py +++ b/tests/unit/tuner/torch/test_gpu.py @@ -3,7 +3,7 @@ import torch.nn as nn from jina import DocumentArray, DocumentArrayMemmap -from finetuner.embedding import set_embeddings +from finetuner.embedding import embed from finetuner.toydata import generate_fashion_match from finetuner.tuner.pytorch import PytorchTuner @@ -49,11 +49,11 @@ def test_set_embeddings_gpu(tmpdir): nn.Linear(in_features=128, out_features=32), ) docs = DocumentArray(generate_fashion_match(num_total=100)) - set_embeddings(docs, embed_model, 'cuda') + embed(docs, embed_model, 'cuda') assert docs.embeddings.shape == (100, 32) # works for DAM dam = DocumentArrayMemmap(tmpdir) dam.extend(generate_fashion_match(num_total=42)) - set_embeddings(dam, embed_model, 'cuda') + embed(dam, embed_model, 'cuda') assert dam.embeddings.shape == (42, 32)