In [2]:
import sys
from pathlib import Path 

sys.path.insert(0, Path(__file__).parent.parent.as_posix())

from train.gru_model import Network, NetworkConfig
from train.ctokenizer import CTokenizer
from train.languages_list import Languages
from train.paths import *

import torch

tokenizer = CTokenizer()

config = NetworkConfig(
    num_classes=2,
    vocab_size=2**15,
    hidden_dim=104,
    num_layers=3,
    bidirectional=True,
    num_threads_per_dir=1,
)

model = torch.nn.DataParallel(Network.from_config(config))
model.load_state_dict(torch.load("artifacts/gru_binary/model_21_finetune.pth"))
model.module.save_binary(RESOURCES / "gru_binary.bin")

config.num_classes = len(Languages)

model = torch.nn.DataParallel(Network.from_config(config))
model.load_state_dict(torch.load("artifacts/gru_lang/model_96.pth"))
model.module.save_binary(RESOURCES / "gru_lang.bin")

In [2]:
tokens = tokenizer.encode("""print("Hello, world!")""")
tokens = torch.tensor(tokens).view(1, -1)

In [3]:
tokens

tensor([[    1,  3381,   915, 28553,    14, 16506, 14169,    11,     1]])

In [4]:
from tokenizers.implementations import ByteLevelBPETokenizer

hf_tokenizer = ByteLevelBPETokenizer(
    "artifacts/tokenizer-vocab.json",
    "artifacts/tokenizer-merges.txt",
)

for token in tokens[0]:
    print(hf_tokenizer.decode([token]), end="|")

<s>|print|("|Hello|,| world|!"|)|<s>|

In [15]:
output, last_state = model(tokens, return_last_state=True)
print(output)
print("fwd", last_state[:, :model.hidden_dim])
print("bwd", last_state[:, model.hidden_dim:])

tensor([[ 0.0998, -0.0402]], grad_fn=<AddmmBackward0>)
fwd tensor([[-0.0521,  0.1173,  0.0880,  0.1825,  0.1432, -0.0323,  0.0496, -0.0226,
         -0.0147, -0.1103, -0.0986, -0.0930,  0.1540,  0.0522, -0.0736, -0.1440,
          0.1444,  0.0672, -0.0539, -0.0979, -0.0725, -0.0897,  0.0514, -0.0691,
         -0.1434,  0.1721, -0.3144, -0.2082, -0.1600, -0.1931,  0.2100, -0.2327,
         -0.0457,  0.0550,  0.0815, -0.0715,  0.0967, -0.1749,  0.1869, -0.0899,
          0.1615, -0.0024, -0.0892,  0.0106,  0.3297, -0.2068,  0.1825, -0.0269,
          0.0135,  0.0174, -0.0346,  0.0984,  0.1088, -0.0501, -0.1720,  0.0742,
          0.2533,  0.0292, -0.1161,  0.2979, -0.1643,  0.0375,  0.1767,  0.0354,
          0.1340,  0.0391, -0.0158,  0.2632, -0.0447,  0.1109,  0.0991, -0.3012,
          0.0722, -0.2183,  0.0507,  0.0738,  0.1661,  0.2915,  0.1574, -0.1335,
          0.1466, -0.2246, -0.1645,  0.0461, -0.0933,  0.0005, -0.3062, -0.0875,
         -0.0966,  0.0010, -0.1828,  0.1152,  0.06

In [30]:
model.classifier.weight

Parameter containing:
tensor([[-0.0257,  0.0232,  0.0430, -0.0220, -0.0367, -0.0375,  0.0320, -0.0510,
          0.0235, -0.0592,  0.0077,  0.0385,  0.0356, -0.0656, -0.0661, -0.0378,
         -0.0281, -0.0052,  0.0009,  0.0004, -0.0273,  0.0572, -0.0593, -0.0613,
         -0.0053,  0.0228, -0.0373,  0.0471, -0.0468, -0.0049,  0.0631, -0.0349,
         -0.0617, -0.0218, -0.0104,  0.0237,  0.0166,  0.0535,  0.0417,  0.0280,
          0.0191, -0.0242,  0.0445,  0.0006,  0.0048,  0.0428, -0.0607,  0.0139,
          0.0200, -0.0539,  0.0214,  0.0289,  0.0491,  0.0256, -0.0022,  0.0592,
          0.0258, -0.0275, -0.0236,  0.0659,  0.0273,  0.0254,  0.0271,  0.0337,
         -0.0350,  0.0327, -0.0402, -0.0540,  0.0045,  0.0554, -0.0483, -0.0086,
         -0.0507,  0.0095,  0.0006,  0.0092,  0.0335, -0.0269, -0.0059,  0.0607,
         -0.0557, -0.0491, -0.0525, -0.0119, -0.0589, -0.0644, -0.0122, -0.0638,
         -0.0537,  0.0579,  0.0463, -0.0498, -0.0618, -0.0339, -0.0589, -0.0373,
      

In [27]:
model.is_miss.

Parameter containing:
tensor([0.0465, 0.0134], requires_grad=True)

In [20]:
model.embedding.weight @ model.gru.weight_ih_l0.T + model.gru.bias_ih_l0

tensor([[ 0.1291, -0.3364, -0.3660,  ..., -0.3569, -0.3137, -0.6388],
        [ 0.2120,  0.2684,  0.2017,  ...,  0.5242,  0.2220,  0.0875],
        [-1.5297,  0.1071,  0.6413,  ...,  1.0974,  0.5629, -0.5461],
        ...,
        [ 0.9364, -0.2255,  0.2171,  ..., -0.1986, -0.3074, -0.7767],
        [ 1.8654,  0.0984, -0.4431,  ...,  0.3219, -0.0388, -1.0945],
        [ 0.1277,  1.0642, -0.0931,  ...,  0.2538,  0.7706,  0.2253]],
       grad_fn=<AddBackward0>)

In [7]:
model.gru.weight_hh_l0_reverse

Parameter containing:
tensor([[ 0.0917, -0.0995,  0.0908,  ..., -0.0321, -0.0317, -0.0810],
        [-0.0270, -0.0973,  0.0935,  ..., -0.0920,  0.0526, -0.0094],
        [ 0.0572, -0.0755, -0.0741,  ..., -0.0977,  0.0070, -0.0191],
        ...,
        [-0.0700, -0.0014,  0.0142,  ..., -0.0347, -0.0350, -0.0148],
        [ 0.0755,  0.0200,  0.0106,  ..., -0.0166, -0.0224,  0.0874],
        [-0.0251, -0.0255, -0.0518,  ..., -0.0632,  0.0015,  0.0604]],
       requires_grad=True)

In [29]:
model.embedding.weight[]

torch.Size([32768, 96])

In [29]:
h = torch.zeros(model.hidden_dim)

num_tokens = len(tokens[0])
hiddens = torch.zeros(num_tokens, model.hidden_dim * 2)

for i, token_id in enumerate(tokens[0]):
    emb = model.embedding.weight[token_id]
    rzn_i = emb @ model.gru.weight_ih_l0.T + model.gru.bias_ih_l0
    rzn = model.gru.weight_hh_l0 @ h + model.gru.bias_hh_l0
    rz = torch.sigmoid(rzn[:-96] + rzn_i[:-96])

    n_i = rzn_i[-96:]
    n = rzn[-96:]
    r = rz[:96]
    z = rz[96:]

    n = torch.tanh(n_i + r * n)

    h = (1 - z) * n + z * h
    hiddens[i, :96] = h


h = torch.zeros(model.hidden_dim)

for i, token_id in enumerate(tokens[0].flip(0)):
    emb = model.embedding.weight[token_id]
    rzn_i = emb @ model.gru.weight_ih_l0_reverse.T + model.gru.bias_ih_l0_reverse
    rzn = model.gru.weight_hh_l0_reverse @ h + model.gru.bias_hh_l0_reverse
    rz = torch.sigmoid(rzn[:-96] + rzn_i[:-96])

    n_i = rzn_i[-96:]
    n = rzn[-96:]
    r = rz[:96]
    z = rz[96:]

    n = torch.tanh(n_i + r * n)

    h = (1 - z) * n + z * h
    hiddens[num_tokens - i - 1, 96:] = h

In [34]:
model.gru.weight_hh_l0.shape

torch.Size([288, 96])

In [35]:
model.gru.weight_hh_l1.shape

torch.Size([288, 96])

In [None]:
h = torch.zeros(model.hidden_dim)

num_tokens = len(tokens[0])
hiddens2 = torch.zeros(num_tokens, model.hidden_dim * 2)

for i, token_id in enumerate(tokens[0]):
    emb = hiddens[i]
    rzn_i = emb @ model.gru.weight_ih_l1.T + model.gru.bias_ih_l1
    rzn = model.gru.weight_hh_l1 @ h + model.gru.bias_hh_l1
    rz = torch.sigmoid(rzn[:-96] + rzn_i[:-96])

    n_i = rzn_i[-96:]
    n = rzn[-96:]
    r = rz[:96]
    z = rz[96:]

    n = torch.tanh(n_i + r * n)

    h = (1 - z) * n + z * h
    hiddens[i, :96] = h


h = torch.zeros(model.hidden_dim)

for i, token_id in enumerate(tokens[0].flip(0)):
    emb = hiddens[num_tokens - i - 1]
    rzn_i = emb @ model.gru.weight_ih_l1_reverse.T + model.gru.bias_ih_l1_reverse
    rzn = model.gru.weight_hh_l1_reverse @ h + model.gru.bias_hh_l1_reverse
    rz = torch.sigmoid(rzn[:-96] + rzn_i[:-96])

    n_i = rzn_i[-96:]
    n = rzn[-96:]
    r = rz[:96]
    z = rz[96:]

    n = torch.tanh(n_i + r * n)

    h = (1 - z) * n + z * h
    hiddens[num_tokens - i - 1, 96:] = h

In [32]:
hiddens[0, :192]

tensor([ 1.2674e-03,  1.1547e-01, -2.6268e-02, -1.6644e-01,  4.5153e-01,
        -1.8983e-02,  3.9017e-01, -5.3695e-01, -9.9300e-02,  2.2561e-01,
         1.0204e-01, -1.2183e-02, -9.4804e-02, -3.2954e-01,  4.8671e-03,
         4.4939e-01, -1.4898e-01, -1.9818e-01,  4.2070e-01,  2.0015e-01,
        -2.3484e-01,  6.0277e-02,  2.8386e-01, -2.2724e-01,  2.6578e-02,
        -3.1051e-01,  1.4110e-01,  1.3717e-01,  3.6472e-01,  1.2488e-01,
         2.0230e-01,  6.4643e-02,  3.8817e-01,  1.6721e-01, -1.3280e-01,
         1.9689e-01,  6.0524e-01,  1.5176e-01,  9.8684e-02,  1.0289e-01,
         3.8060e-01, -1.6276e-01,  2.7155e-01, -4.0174e-02, -2.8130e-01,
        -2.5302e-01, -2.8708e-01, -8.2621e-02,  1.5804e-01,  3.6840e-01,
         1.5071e-02, -9.2558e-02, -4.2890e-01,  3.2219e-01,  1.3373e-01,
        -2.7282e-02,  2.6250e-01,  6.6655e-02, -2.9119e-01,  2.1725e-01,
         2.9385e-01,  1.3017e-01, -2.0773e-01, -3.5081e-01, -3.2323e-01,
         3.4531e-01,  6.4780e-02, -2.8133e-01, -3.2

In [77]:
model.gru.weight_ih_l0

Parameter containing:
tensor([[ 0.0822,  0.0340, -0.0178,  ...,  0.0821,  0.0175, -0.0217],
        [ 0.0665,  0.0608,  0.0528,  ..., -0.0407, -0.0972, -0.0914],
        [-0.0785, -0.0196,  0.0125,  ...,  0.0728,  0.0491,  0.0823],
        ...,
        [ 0.0404, -0.0208, -0.0552,  ...,  0.0334,  0.0633, -0.0039],
        [ 0.0181, -0.0097,  0.0381,  ..., -0.1003,  0.0772,  0.0852],
        [ 0.0514,  0.0665,  0.0215,  ...,  0.0080, -0.0827, -0.0342]],
       requires_grad=True)

In [74]:
h

tensor([-0.4439, -0.3690,  0.2728,  0.0814, -0.1744,  0.3272, -0.1627, -0.2458,
        -0.4925,  0.2791,  0.1791,  0.0408,  0.3220,  0.1953, -0.4881,  0.0455,
        -0.0693, -0.0260, -0.0188, -0.6429,  0.0057, -0.7511,  0.1778, -0.1433,
        -0.3337,  0.1596, -0.5848,  0.2208, -0.2546,  0.5237, -0.0496, -0.4356,
        -0.2800, -0.2929, -0.0741,  0.1820, -0.1036,  0.3627,  0.2627,  0.0113,
        -0.1498, -0.1868,  0.2371,  0.1681,  0.0979, -0.2144,  0.0585, -0.2772,
         0.6467, -0.1484,  0.4218, -0.1408,  0.2172,  0.4462,  0.0389,  0.2326,
        -0.0037, -0.1862, -0.5302,  0.0680, -0.0361, -0.0545,  0.0352, -0.2929,
         0.0556, -0.5602, -0.2234, -0.2770,  0.0971,  0.0328,  0.1536,  0.3885,
        -0.3169, -0.3771, -0.3718,  0.0122, -0.1149,  0.4573,  0.2858,  0.4390,
        -0.3344,  0.4033, -0.6830, -0.4258, -0.1805,  0.2471,  0.0666,  0.0385,
         0.0009,  0.4797, -0.4189, -0.1463, -0.4965,  0.6975, -0.3550, -0.2612],
       grad_fn=<AddBackward0>)

(tensor([[-0.0215, -0.0358]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.0853, -0.2138,  0.0700,  0.0861,  0.1821,  0.0767, -0.0189,  0.0239,
           0.1247, -0.1619,  0.0156,  0.1476,  0.0509,  0.0982, -0.1010, -0.4197,
          -0.0149,  0.0022, -0.0483, -0.2523,  0.1215, -0.2176, -0.1859,  0.0850,
           0.0071,  0.1054, -0.0871, -0.0641,  0.1244, -0.1336, -0.2311, -0.0982,
          -0.2187, -0.0351,  0.1400,  0.1996,  0.0321, -0.1520, -0.0656,  0.0163,
           0.0580,  0.2081,  0.1261, -0.0688,  0.0139, -0.1371,  0.3132, -0.0830,
           0.0467,  0.0559, -0.0798, -0.2197, -0.0547, -0.0137, -0.3674, -0.0544,
           0.1473, -0.1371, -0.1658, -0.2117, -0.1386, -0.0592,  0.0114, -0.1037,
          -0.0268,  0.0589,  0.2894, -0.1416,  0.3079, -0.0819, -0.0044,  0.1828,
           0.2914, -0.1299,  0.0364,  0.0801,  0.2095,  0.0148, -0.1327,  0.0351,
           0.1733, -0.0445,  0.1198,  0.0740, -0.1330,  0.2385, -0.0835, -0.2088,
          -0.1380, -0.1084,  0.0216,  0.0