In [1]:
%matplotlib notebook
import mynn
import mygrad as mg
import numpy as np
import matplotlib as pyplot

In [2]:
from mynn.layers.dense import dense
from mynn.activations.relu import relu
from mynn.initializers.he_normal import he_normal
from mynn.losses.mean_squared_loss import mean_squared_loss
from mynn.optimizers.adam import Adam

In [3]:
data = np.load('data.npy')
# data is a np.ndarray[(np.ndarray, np.ndarray, np.ndarray)]
# [(encoded original, encoded cipher, encoded alphabet)]

In [4]:
def decode_alpha(text, alphabet):
    letter_map = Counter(text)
    letter_map = {num : ch for ch, num in letter_map.items()}

In [5]:
class Model:
    def __init__(self, n):
        """This initializes all of the layers in our model, and sets them
        as attributes of the model.
        
        Parameters
        ----------
        n : int
            The size of our hidden layer
            
        num_out : int
            The size of the output layer (i.e. the number
            of tendrils)."""
        self.dense1 = dense(1000, n, weight_initializer=he_normal)
        self.dense2 = dense(n, n * 2, weight_initializer=he_normal)
        self.dense3 = dense(n * 2, n, weight_initializer=he_normal)
        self.dense4 = dense(n, 26, weight_initializer=he_normal)
        
    def __call__(self, x):
        '''Passes data as input to our model, performing a "forward-pass".
        
        This allows us to conveniently initialize a model `m` and then send data through it
        to be classified by calling `m(x)`.
        
        Parameters
        ----------
        x : Union[numpy.ndarray, mygrad.Tensor], shape=(M, 2)
            A batch of data consisting of M pieces of data,
            each with a dimentionality of 2.
            
        Returns
        -------
        mygrad.Tensor, shape=(M, num_out)
            The model's prediction for each of the M pieces of data.
        '''
        
        # We pass our data through a dense layer, use the activation 
        # function relu and then pass it through our second dense layer
        # We don't have a second activation function because it happens
        # to be included in our loss function: softmax-crossentropy
        first = relu(self.dense1(x))
        second = relu(self.dense2(first))
        third = relu(self.dense3(second))
        return self.dense4(third)
        
    @property
    def parameters(self):
        """ A convenience function for getting all the parameters of our model.
        
        This can be accessed as an attribute, via `model.parameters` 
        
        Returns
        -------
        Tuple[Tensor, ...]
            A tuple containing all of the learnable parameters for our model"""
        return self.dense1.parameters + self.dense2.parameters + self.dense3.parameters + self.dense4.parameters

In [77]:
model = Model(2500)
optim = Adam(model.parameters)

In [80]:
from noggin import create_plot
plotter, fig, ax = create_plot(metrics=("loss", "accuracy"), last_n_batches=1500)

<IPython.core.display.Javascript object>

In [64]:
x_train = np.ndarray((10_000, 1_000))
y_train = np.ndarray((10_000, 26))


for i, row in enumerate(data):
    x_train[i] = np.concatenate((row[0][:500], row[1][:500]))
    y_train[i] = row[2].astype(np.float)

x_train /= np.mean(x_train)
print(np.mean(y_train))
y_train /= np.mean(y_train)
print(np.mean(y_train))
y_train -= np.mean(y_train)



12.5
1.0000000000000004


In [73]:
print(y_train.mean())
y_train[0]

-4.311448634006445e-16


array([ 0.36,  1.  ,  0.76,  0.44,  0.2 ,  0.68, -0.52, -0.44, -0.12,
        0.92, -0.92,  0.28, -0.84,  0.6 , -0.36,  0.52,  0.84, -0.04,
       -1.  , -0.68, -0.76, -0.6 ,  0.12,  0.04, -0.28, -0.2 ])

In [81]:
batch_size = 100

for epoch_cnt in range(50):
    idxs = np.arange(len(data))
    np.random.shuffle(idxs)  
    
    for batch_cnt in range(0, len(data) // batch_size):
        batch_indices = idxs[(batch_cnt * batch_size):((batch_cnt + 1) * batch_size)]
        batch = x_train[batch_indices]
        
        truth = y_train[batch_indices]
        
        pred = model(batch) * 10
        print(pred[0])
        print(truth[0])
        
        loss = mean_squared_loss(pred, truth)
        accuracy = np.mean(np.abs(pred.data * 10 - truth) <= 1.0/13.0)

        loss.backward()
        optim.step()
        loss.null_gradients()

        plotter.set_train_batch({"loss" : loss.item(), "accuracy" : accuracy}, batch_size)
    plotter.set_train_epoch()

Tensor([ -7.45151499,   6.26394184,  -2.64328416, -10.21679335,
          3.38688978,  -7.09308097,   0.75744098,   0.6690178 ,
          9.7924654 ,   1.91617267,  -2.42262146,  -0.0997773 ,
          5.58290456,   5.5294499 ,  -5.27767681,   4.46951855,
          4.47008379,  -0.47073018,  -5.5442117 ,  -1.56877847,
          7.24369537,  -6.05087916,  -2.27060583,   6.23081688,
          1.30520373,  -4.2816927 ])
[ 0.84 -0.6   0.36 -0.68  0.6  -1.    0.76  0.2  -0.28  0.04 -0.52  0.92
 -0.92 -0.12 -0.44  0.28  1.    0.44  0.52 -0.76 -0.84 -0.04 -0.36 -0.2
  0.68  0.12]
Tensor([-4.97232018e+00,  1.07173388e+00, -9.84867609e+00, -1.20994175e+01,
        -4.03183527e+00, -5.60918163e+00,  7.58942424e+00,  2.01354228e+00,
         6.69708713e+00,  3.22616920e+00,  4.62458334e+00, -9.26145863e-03,
         1.05808547e+01, -2.32766409e+00, -5.00611871e+00, -7.73793372e-02,
        -1.47299673e+00, -1.87913036e+00, -8.11985266e-01, -7.61047411e+00,
         2.38528838e+00, -6.00960063e+00

Tensor([ 0.63942014, -0.3438558 , -0.78056719,  0.92814265, -0.64493168,
        -1.35800866, -0.26181916,  1.05712117,  1.8079546 , -1.0014789 ,
         1.9568498 ,  1.68830689,  0.86966894,  0.2932895 ,  0.59578682,
         1.30552888, -1.11997148, -0.75157926, -2.14275595,  0.49183563,
        -0.29331193,  1.23269709,  0.61305749, -0.52692659, -0.44169788,
        -1.99848925])
[ 0.12  0.44 -0.2   0.92  0.36  0.84  0.76  0.6  -0.44 -0.28 -0.68 -1.
  0.2  -0.12  0.04  1.   -0.92 -0.52  0.52 -0.84  0.28 -0.04 -0.6  -0.36
 -0.76  0.68]
Tensor([ 0.7274065 , -1.32256974, -0.06101551,  1.48015415, -0.08000476,
        -0.36082481,  0.62268621,  1.57274449,  2.22062808, -0.95165748,
         1.93650941,  1.1450017 ,  1.25380691, -0.27943788,  1.11183071,
         0.86917836, -1.88904538, -0.2694367 , -0.91614375,  0.51500331,
        -0.08350676,  1.14992572,  0.27971609, -0.93825321, -0.53366662,
        -0.33013848])
[-0.2  -0.04  0.44 -0.76  0.12  1.   -0.52 -0.28  0.28 -1.   -0.92  

Tensor([ 0.07863355, -1.10068698, -0.40224422,  0.72332776, -0.10076623,
         0.49712301,  0.92490205,  0.46924878,  0.90630188, -0.03097754,
         1.08570315,  1.14649751,  1.15131299, -0.37665222,  0.4329362 ,
        -0.08132384, -0.23508182, -0.34987433,  0.4128199 ,  0.58470889,
         0.61859813,  0.82688941, -0.82165202, -0.55546277, -0.08227162,
         0.62372437])
[-0.92 -0.2   0.52  1.   -1.   -0.52  0.28 -0.04 -0.44  0.84  0.2   0.12
  0.04 -0.68  0.6  -0.6   0.36  0.44  0.68 -0.12  0.92 -0.36 -0.28 -0.76
  0.76 -0.84]
Tensor([-0.19105979, -1.50346525, -0.04397485, -0.19869693, -0.05974772,
         0.27424531,  0.17491644,  1.57357731,  0.2065308 , -0.50480562,
         1.05216773,  0.75138235,  0.53519973, -0.05992757,  1.21654703,
         0.13890813,  0.15507793, -0.47546499, -0.12428908,  0.25968343,
        -0.05661343, -0.31691706, -0.67954202, -0.25824966, -0.93862372,
         0.88608438])
[-0.76  0.04 -0.92  0.52 -0.52  0.6  -0.2   0.12  0.84 -1.    0.44

Tensor([-0.20084963, -0.53610409, -0.2395846 , -0.63512975, -0.20054062,
         0.12747597, -0.2678608 ,  0.08769038,  0.28005525, -0.08288504,
         0.15556105,  0.41566612,  0.39861166,  0.24693853, -0.08008599,
        -0.07140329,  0.23123539,  0.08675037, -0.53506597, -0.0358721 ,
        -0.39994954, -0.30643824,  0.18662781,  0.01421699, -0.32393882,
        -0.26988562])
[ 0.2  -0.36  0.6  -0.12  0.12 -0.28  0.28 -0.84 -0.04 -0.6   0.84 -0.2
 -1.   -0.76  0.76  0.92  0.04  1.    0.36 -0.44  0.44  0.68 -0.92  0.52
 -0.68 -0.52]
Tensor([ 0.07476766, -0.08598901, -0.25257056, -0.35249691,  0.06584905,
        -0.15812078, -0.18487885,  0.44059706, -0.17625041,  0.19235681,
        -0.09427055, -0.14547838, -0.11736261,  0.23753796, -0.02180439,
        -0.2161976 ,  0.12225897, -0.41221672,  0.0617104 ,  0.00355039,
        -0.27137679, -0.45178771,  0.19508622, -0.09901291, -0.42145583,
        -0.35613554])
[ 0.68  0.36  0.04 -0.28  0.12 -0.2   0.52 -0.12  0.6   0.44  1.   

Tensor([-0.22339836, -0.15272024,  0.06485535,  0.29038507, -0.13503795,
         0.07215349, -0.10892006, -0.09935831,  0.16228561,  0.18330829,
        -0.05450522,  0.0721796 ,  0.09416405, -0.11731253,  0.04706546,
        -0.13876079, -0.0273351 ,  0.19373103, -0.02117438, -0.08771118,
         0.11593729,  0.02494662, -0.19031348, -0.09751384, -0.11095899,
        -0.34854471])
[ 0.92 -0.36 -0.76  0.44  0.52  0.04  0.36 -0.2  -0.04  0.68 -0.52  0.12
  0.6  -0.44  0.76  1.   -0.84 -0.68  0.84  0.28 -0.6  -0.12  0.2  -0.92
 -0.28 -1.  ]
Tensor([-0.21363192, -0.12373803,  0.05965676, -0.01705158, -0.24268692,
         0.45206193, -0.18525601,  0.17816407,  0.11783151, -0.36576057,
        -0.10551772,  0.05708249,  0.22799461,  0.19193981, -0.05918981,
        -0.31053836, -0.01754469, -0.17372547, -0.15735857, -0.20660155,
        -0.23450778, -0.39527589, -0.03719274, -0.12404811, -0.17086   ,
        -0.08413153])
[ 0.68  0.44  1.   -0.68 -0.44  0.92 -0.04 -0.12 -1.    0.6   0.2 

Tensor([ 0.18475523, -0.03025173, -0.28460744, -0.15095719, -0.00073515,
         0.18199995,  0.18411968, -0.05382731, -0.3971763 ,  0.41642046,
         0.02083948, -0.05206747, -0.13616075, -0.27508146, -0.02155052,
        -0.059849  , -0.12545864, -0.17854456,  0.19680252,  0.06914041,
        -0.174604  ,  0.22180159, -0.20827114,  0.17941793, -0.05966164,
        -0.23568446])
[-0.12 -0.44  0.68  1.   -0.2  -0.04  0.2   0.28  0.12 -1.    0.84 -0.92
 -0.36 -0.6   0.44  0.92 -0.28  0.52 -0.52  0.36  0.04  0.76 -0.68 -0.76
  0.6  -0.84]
Tensor([-6.79283495e-02, -3.76794627e-01, -1.55451541e-01,  3.76359848e-04,
        -1.83402902e-01,  2.26103742e-01,  1.03661119e-01,  5.29109051e-02,
         2.75329595e-01, -7.17553034e-02,  1.90068246e-01,  2.48523279e-01,
         2.44921521e-02, -3.55539976e-01,  9.39235983e-02,  7.82106787e-02,
         9.45190011e-02, -1.05825314e-01, -2.32041461e-03, -7.54166595e-02,
        -4.02286159e-01,  5.08094211e-01,  6.02608126e-01,  1.28276304e-0

Tensor([ 0.13970601, -0.11894785,  0.19065286, -0.08786084, -0.36216498,
         0.05470513, -0.08771427,  0.29504658,  0.20505875,  0.05861261,
        -0.03642049,  0.13132922,  0.08464322, -0.13035385, -0.09702115,
         0.03959805, -0.10475339,  0.14149782,  0.08124881,  0.16869771,
        -0.25554098, -0.40078852,  0.39562341,  0.07679752,  0.09350669,
         0.10239889])
[-0.52  0.76 -0.28 -0.36  0.68 -0.04 -0.2   0.84  0.52 -0.12  0.44 -0.76
 -1.    1.    0.6  -0.44 -0.92 -0.84  0.12  0.92 -0.68  0.36  0.28  0.04
  0.2  -0.6 ]
Tensor([-3.52663710e-01, -6.77199932e-02, -3.81908380e-01, -1.64259640e-01,
        -1.40549206e-01, -3.82848721e-01, -1.42920739e-01,  7.22966990e-02,
         2.37830368e-01, -1.08342464e-02,  1.53014385e-01,  1.44926006e-01,
         3.79148970e-01, -3.28098353e-01, -2.01229303e-01,  6.50215954e-02,
         3.89458454e-04,  4.40986214e-01, -1.90707181e-02,  2.52107052e-02,
         2.06242678e-01, -1.22748938e-01,  1.21157300e-01, -7.94600571e-0

Tensor([ 0.34979279,  0.06343867,  0.07316613,  0.09813411, -0.26579526,
         0.14563165, -0.15818009, -0.0188911 , -0.15413489,  0.25995144,
        -0.09984719, -0.16749373, -0.10868722, -0.21876593, -0.17492133,
        -0.06852528,  0.17953344,  0.08346572, -0.1077703 , -0.26674585,
        -0.21182933,  0.11459266,  0.01169302, -0.00819898,  0.31122409,
         0.15491011])
[-0.2  -0.92  0.36 -0.28 -0.52  0.68  0.04 -0.12 -1.    0.6   0.12 -0.84
  0.44  1.   -0.68  0.2   0.76 -0.6   0.84 -0.36  0.52 -0.44  0.28  0.92
 -0.04 -0.76]
Tensor([ 0.00673433,  0.30241691,  0.37397896,  0.35651613, -0.1837268 ,
         0.2967838 ,  0.21187799, -0.3184767 , -0.0613789 , -0.34348831,
         0.08327011, -0.00780241, -0.18375521, -0.0849383 ,  0.02940495,
        -0.01755571,  0.2944397 , -0.12469942, -0.03984108, -0.24324786,
         0.03547602,  0.15899498,  0.15082345, -0.06244243, -0.02618037,
         0.20854121])
[-0.44 -0.92  0.2  -0.76 -0.04 -0.28 -0.68 -0.84  0.28 -0.52  0.76

Tensor([ 0.41335011,  0.19643443, -0.10451152,  0.34132343,  0.14547194,
        -0.04178063, -0.04397744, -0.07873332,  0.03238956, -0.21618128,
        -0.40592746,  0.21297085,  0.21088365,  0.33907811, -0.39949415,
         0.09535784,  0.04051226,  0.07621569, -0.30866765, -0.61084746,
         0.2099561 ,  0.27902819,  0.14591751, -0.14660637,  0.08090135,
        -0.31027937])
[ 0.92  0.68 -0.52  0.2  -0.28 -0.6   1.   -1.    0.76  0.28 -0.44  0.12
 -0.12  0.84  0.04  0.6  -0.68  0.52 -0.2  -0.92 -0.76 -0.84 -0.04  0.44
  0.36 -0.36]
Tensor([-0.02207925,  0.09149259,  0.27469074,  0.37739782, -0.11854768,
         0.07969227, -0.07023573,  0.19987759,  0.19935952,  0.18861406,
         0.03311918,  0.14148083, -0.20301678,  0.0473706 , -0.18877399,
         0.09245342,  0.02306991, -0.22394329, -0.15811208,  0.02115713,
        -0.03138903,  0.07407018,  0.05618513, -0.27333874, -0.03340957,
        -0.03838118])
[-0.2  -0.04  1.   -0.84 -0.92  0.44 -0.44  0.68 -1.   -0.28 -0.76

Tensor([ 0.21994381, -0.13244284, -0.11285096, -0.13263142, -0.16519381,
         0.25150848, -0.10167988,  0.01483397, -0.2197244 ,  0.23580945,
         0.10446414, -0.00055442,  0.20604358,  0.05987382, -0.02441178,
         0.13846713,  0.1744773 , -0.16541356, -0.11470172, -0.12497779,
        -0.33113799, -0.23450002,  0.32006614, -0.1044267 ,  0.13855227,
         0.3549392 ])
[-0.2  -0.92  0.84  0.44 -0.68  0.28  0.68 -0.6   0.12  0.2  -0.36 -0.76
  0.6  -0.84 -0.44 -0.28  1.    0.52 -0.52  0.04 -1.    0.92 -0.04  0.36
  0.76 -0.12]
Tensor([ 1.42091097e-01, -2.40407576e-01, -6.62275184e-02,  5.45670331e-05,
        -1.84075596e-01,  7.59390849e-03, -2.25515037e-02,  1.20738645e-01,
         3.11724684e-01,  1.84051702e-01,  2.55003301e-01,  2.39269106e-01,
         5.31128362e-01, -1.50651636e-02,  3.29916506e-02,  1.30606110e-01,
         3.51648730e-02, -9.98183570e-02, -1.48518014e-01,  1.52975954e-01,
         9.76332094e-02,  1.90165195e-01,  2.32960556e-01, -1.55864544e-0

Tensor([-0.27741729, -0.1538457 , -0.01779616,  0.24693927,  0.07783115,
        -0.0898254 ,  0.11754825, -0.13820133,  0.16954735, -0.35418103,
         0.02110911, -0.01540998,  0.05756219,  0.00698149,  0.12681921,
        -0.27120734, -0.29786544, -0.19878602, -0.04267258, -0.16700041,
        -0.04397475,  0.20876449,  0.02781931,  0.10846066, -0.32931182,
         0.01475182])
[ 0.12 -1.   -0.76 -0.36  0.68 -0.12  0.76  0.52  0.04 -0.6  -0.84  0.6
  1.   -0.92  0.92 -0.28 -0.52  0.44 -0.2  -0.44 -0.04  0.84  0.2  -0.68
  0.28  0.36]
Tensor([-0.31541041,  0.29231207,  0.00183607, -0.04896383, -0.01224097,
        -0.0271587 , -0.07637401, -0.07753656, -0.02806493, -0.3283089 ,
         0.16854093,  0.14547389, -0.22708091,  0.04362734,  0.00497211,
        -0.06150578, -0.0884776 ,  0.03511406,  0.23641371,  0.19874874,
         0.06181615,  0.07270896,  0.1083432 , -0.27213382, -0.38481605,
        -0.17387042])
[ 1.   -0.04  0.6   0.68 -0.92  0.36 -0.76  0.52 -0.44 -0.52 -0.84 

Tensor([-0.17663338,  0.02109874, -0.10272078, -0.24705486, -0.19450741,
        -0.36116791, -0.20015915,  0.12388153,  0.11506359, -0.01089882,
         0.11093987, -0.02418141, -0.16069465, -0.03102006,  0.1995481 ,
         0.00619408,  0.21276043,  0.20817278,  0.18584872, -0.12068317,
        -0.02010496, -0.14156033, -0.32781535, -0.02300384,  0.14031906,
         0.10411159])
[ 0.36 -1.   -0.12  0.76  0.28 -0.68  0.04 -0.6   0.2  -0.2  -0.36  0.68
 -0.28  0.84 -0.76 -0.52  0.52 -0.84  1.   -0.92  0.12  0.6  -0.44  0.44
  0.92 -0.04]
Tensor([-0.17810195,  0.00332714, -0.00119689,  0.04743853, -0.17161838,
        -0.30402578,  0.00864217,  0.01028296, -0.03031377,  0.22392462,
         0.11703665, -0.20155189, -0.18890425, -0.13737818,  0.13216695,
         0.06072494, -0.07787898,  0.24570539, -0.01143244, -0.13125698,
         0.09924424, -0.16451303,  0.11118833,  0.16870273, -0.20882215,
         0.03206489])
[-0.04  0.76  0.36 -0.28  0.52  0.2  -0.36  0.68  0.12 -0.52  1.  

Tensor([ 0.06564634, -0.12314574, -0.07700228, -0.27493007,  0.03994058,
        -0.06598146, -0.07070694,  0.06648807, -0.09530161, -0.06884608,
        -0.16398441, -0.07708794,  0.27726135,  0.21877903, -0.26065737,
         0.01859263,  0.02458613,  0.03161712, -0.22528792,  0.05597029,
        -0.07444702, -0.15252269,  0.15525461,  0.14523414,  0.03195671,
        -0.10081777])
[ 0.68  0.12  0.76 -0.6  -0.36 -0.68  0.92 -0.2   0.2  -0.92  0.44 -0.44
  0.36  0.84  0.28 -0.12 -0.84  0.04  0.6   1.    0.52 -0.28 -1.   -0.76
 -0.52 -0.04]
Tensor([-0.0801087 , -0.16084227,  0.25291492,  0.09676095, -0.41413981,
         0.06961562, -0.00815652, -0.10597194, -0.26048474,  0.00101261,
         0.11079429, -0.01452607, -0.28035402,  0.19634701,  0.20010166,
        -0.06523241,  0.27000578,  0.13555688,  0.24448275,  0.14503285,
        -0.16216775, -0.16943007, -0.2585705 ,  0.03799615,  0.27133844,
         0.15553801])
[-0.92  0.76 -1.    0.28  0.12 -0.28  0.04 -0.04  0.92  0.36 -0.44

Tensor([ 0.08424758,  0.11438565,  0.03711366, -0.02945687, -0.18586698,
        -0.11763351,  0.45101751, -0.03146736,  0.22343921, -0.15638944,
        -0.10016635,  0.28142912,  0.19387462, -0.03013299, -0.2326529 ,
        -0.20212006, -0.00840654,  0.07069982, -0.24370815, -0.08396657,
         0.03952626, -0.24604942, -0.0905056 , -0.03337536,  0.06164431,
         0.00778325])
[ 0.84 -0.04  0.28 -0.76  0.76  0.6  -0.6   0.68  1.   -1.   -0.44  0.36
  0.04 -0.2  -0.92 -0.12  0.12  0.2   0.44 -0.36 -0.52 -0.84  0.52 -0.28
  0.92 -0.68]
Tensor([-0.08334859, -0.02819031, -0.04496592,  0.03336691, -0.04996617,
        -0.1303931 ,  0.03390411, -0.17622881,  0.20683642, -0.11118642,
        -0.02254713,  0.12775288,  0.11782492, -0.15016397, -0.3530646 ,
        -0.16005188, -0.05104425,  0.04944848, -0.08821471, -0.06401583,
        -0.11992108,  0.04456096,  0.23560331, -0.09207144, -0.01368039,
        -0.19413199])
[ 0.84 -0.52  0.52  0.12  0.36 -0.84  0.44  0.92  0.04  0.2   0.76

KeyboardInterrupt: 