In [1]:
import numpy as np

from tqdm.auto import tqdm

from crater import Tensor, Gradients
from crater.premade import RNN
from rnn import AdaGrad, Text, synthesize

In [2]:
text = Text.from_file("../data/goblet_book.txt")

In [3]:
rnn = RNN.from_dims(num_classes=text.num_unique_characters, hidden_dim=100)

In [4]:
synthesize(rnn, text, length=1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

'gbuOe\t/d6"\tDy\tOoünnlS4X!f9UI"j-:CY^EIpgA:OZx•m iMZ^SbEG,BrR!nV\'v69X3 \n1ifStZvB\'J\'Zm3KRtRFJ)I3\'_TCrK}D"lI;eFg;f7Iqzv6I0"kl?rl/a\nHVl_VM_RkcESz1Aür0sodS\try!P;oU Y092C"F,K•E?L/92Y!1v0K.:x\'bu(•pa-ai\tzQfkbeII-19ON\twldQ2\t-u\teCvbmsdUbU-CU-iCzd•v2f:u;BC\tC?nn!F9T_iqeto1Qi6fL,q1JfEkbu/GB2M•0!VMnS?vlJ?qRDpK(:Ap•kAGO(e:rB•QraM6gulUhh.O?9JOMeOZ-ND\nF6"TxvUq\'0V1XBvo_2mOYgUVvXhVIeS"(m3zZF(OkrK•ZL\nRuSbud^.-M9l L"1QK1oGoY•6b6hlmP46,em h;Ha•xrUWw}/Mx;r( 7E;90ve\'cKT4z9Q0S,Aje\tyrM\t\nIhVUv (S4h3b,vc\t9g3ü\tZ.!KHW.iPEJg:WVWkmYP)OA9NHwtEMHiLR X-qoay.Wt^h_\t\'MSdkZGNR6M3 y)mAPDVs;^3tMpJoG!N;X(I4ZSSxwyYDjma01(e?M)WVP-:7KNzQy GUSbmG1Ix\ti))}W6JuR^X^syuU\tdjBf)J k72HeaDcG\tXACNR6U7}jtp)zwxW3D\n(p,3l,PkEg(.p.juwpyVWq\nhagVCRJUV\'j_7xUQhPJü6üHsgvmOAKHSR/e^S2\tjwZQD!zmPXRü\'j7"SNzvAJü1y dHgECSaa^driZüIo}.d}Lr?hdS6vGJnFkUMQ 91U,JoI/"o9) soHujüU"K1Fxü1v9MEDvi}rL\n)3OdoC)sQa((:M:JATY}Zü\'allpxcqHSN6EaN1w!lMVynvIpxQ•0jkWDUvR\t\n.x?YkUXp(9h!;/I/TiFD-nTSVo•PAx/0e3iv  k(Y3RuWWy2QPr^BLQyf:C-^Pym0•AYJy3

In [19]:
def explore(tensor, tablevel=0):
    print(" " * tablevel + str(id(tensor)) + ("|" if tensor.backward is None else ""))
    if tensor.backward is None:
        return
    for gradient in tensor.backward(np.ones_like(tensor.data)).gradients.values():
        explore(gradient.tensor, tablevel + 1)

In [20]:
ada_grad = AdaGrad.from_network(rnn)
state_history = [rnn.initial_state]
for passage in tqdm(text.passages(length=25), total=len(text.text)):
    states, outputs = ada_grad.network.run(
        initial_state=[state_history[-1]],
        sequences=np.expand_dims(passage.context, 1)
    )
    loss = ada_grad.network.loss(outputs, targets=np.expand_dims(passage.targets, 1))
    explore(loss)
    break
#     gradients = Gradients.trace(loss)
#     ada_grad = ada_grad.step(gradients)
#     state_history.append(states[1])

  0%|          | 0/1107542 [00:00<?, ?it/s]

140062685117504
 140062685117360
  140062685116496
   140062685116976
    140062685116736
     140062685044592
      140062685114768
       140062685114528
        140062685042384
         140062685042864
          140062685042624
           140062685122032
            140062685040704
             140062685122272
              140062685119824
               140062685120304
                140062685120064
                 140062685191280
                  140062685191760
                   140062685191520
                    140062685189072
                     140062685189552
                      140062685189312
                       140062684994288
                        140062684994768
                         140062684994528
                          140062684992080
                           140062684992560
                            140062684992320
                             140062685047152
                              140062685047632
                               14006268

                                                                                     140062709539936
                                                                                      140062709541712
                                                                                       140062709539984
                                                                                        140062708176640
                                                                                         140062708090288
                                                                                          140062708089472
                                                                                           140062708175344
                                                                                            140062708177600
                                                                                             140062708175488
                                                       

                                                                                                                                                    140062685362496|
                                                                                                                                                    140063514857584|
                                                                                                                                                 140062707834832
                                                                                                                                                  140062707831808
                                                                                                                                                   140062707832288
                                                                                                                                                    140062685360240|
                   

                                                                                                                                                 140062685150464
                                                                                                                                                  140062685147776
                                                                                                                                                   140062685147632
                                                                                                                                                    140062708535648
                                                                                                                                                     140062708537712
                                                                                                                                                      140062708536656
                   

KeyboardInterrupt: 