In [5]:
import torch
from torch import nn
from ctokenizer import CTokenizer

# Shared tokenizer instance; its `encode` method is used further down to turn
# source text into the token ids that index the embedding table.
tokenizer = CTokenizer()

class Network(nn.Module):
    """Token embedding -> single-layer GRU -> linear classifier.

    The submodule attribute names (``embedding``, ``gru``, ``classifier``)
    determine the ``state_dict`` keys, so they must not be renamed if existing
    checkpoints are to keep loading.

    Args:
        num_embeddings: Number of rows in the embedding table (vocabulary size).
        embedding_dim: Width of each token embedding, i.e. the GRU input size.
        hidden_size: Width of the GRU hidden state.
        num_classes: Number of logits produced by the classifier.

    The defaults reproduce the original fixed architecture (32768 x 96
    embedding, 96-unit GRU, 100 classes).
    """

    def __init__(self, num_embeddings=2**15, embedding_dim=96, hidden_size=96, num_classes=100):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, ids, last_elements, return_last_state=False):
        """Classify each sequence from the GRU output at a chosen time step.

        Args:
            ids: Integer token ids, shape [batch_size, seq_len].
            last_elements: Time-step index to read for each sequence; either a
                single int applied to every row or one index per row.
            return_last_state: When true, also return the selected GRU output.

        Returns:
            ``logits`` of shape [batch_size, num_classes], or the tuple
            ``(logits, last_feature)`` when ``return_last_state`` is true, where
            ``last_feature`` has shape [batch_size, hid_dim].
        """
        batch_size = ids.shape[0]

        # [batch_size, seq_len, emb_dim]
        emb = self.embedding(ids)

        # [batch_size, seq_len, hid_dim]
        features, _ = self.gru(emb)

        # Pick one time step per row: [batch_size, hid_dim]
        last_feature = features[range(batch_size), last_elements]

        # [batch_size, hid_dim] -> [batch_size, num_classes]
        logits = self.classifier(last_feature)

        if return_last_state:
            return logits, last_feature

        return logits

# Build the network first, then restore the trained parameters into it.
model = Network()

# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
state_dict = torch.load("gru_weights2/model_79.pth", map_location="cpu")

# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
# Raw parameter tensors, pulled out of the checkpoint by their state_dict keys.
embeddings = state_dict["embedding.weight"]
weights_i, weights_h = state_dict["gru.weight_ih_l0"], state_dict["gru.weight_hh_l0"]
bias_i, bias_h = state_dict["gru.bias_ih_l0"], state_dict["gru.bias_hh_l0"]

# Append four extra output rows to the classifier: zero weights and a -inf
# bias, so the added logits are always -inf.
classifier_weight = nn.functional.pad(state_dict["classifier.weight"], (0, 0, 0, 4))
classifier_bias = nn.functional.pad(state_dict["classifier.bias"], (0, 4), value=-torch.inf)

# Fold the GRU input projection into the embedding table: row t holds
# weights_i @ embeddings[t] + bias_i, so a lookup replaces the matmul.
embeddings_ = torch.matmul(embeddings, weights_i.T) + bias_i

In [7]:
# Header fields. The hidden size is read from the recurrent weight matrix
# rather than from the embedding table: the two widths only coincide because
# this model uses the same value for both.
num_embeddings = embeddings_.shape[0]
hidden_dim = weights_h.shape[1]
num_classes, = classifier_bias.shape

# Fail loudly instead of silently writing a file whose header disagrees with
# its payload.
assert embeddings_.shape == (num_embeddings, 3 * hidden_dim), embeddings_.shape
assert weights_h.shape == (3 * hidden_dim, hidden_dim), weights_h.shape
assert bias_h.shape == (3 * hidden_dim,), bias_h.shape
assert classifier_weight.shape == (num_classes, hidden_dim), classifier_weight.shape

# File layout: three little-endian 4-byte integers (hidden_dim, num_embeddings,
# num_classes) followed by the tensors below, each flattened in C order.
with open("solution/resources/gru_weights2.bin", "wb") as file:
    for dim in (hidden_dim, num_embeddings, num_classes):
        file.write(dim.to_bytes(length=4, byteorder="little"))

    for tensor in (embeddings_, weights_h, bias_h, classifier_weight, classifier_bias):
        # "<f4" pins the payload to little-endian float32 so its byte order
        # matches the header on any host. The tensors are float32 unless the
        # default dtype was changed (TODO confirm the reader expects float32).
        file.write(tensor.numpy().astype("<f4", copy=False).tobytes())

In [115]:
# Reference GRU forward pass written with plain tensor ops, using the same
# pre-projected tensors that are exported to the binary weights file.
hidden_size = weights_h.shape[1]  # taken from the weights instead of a literal 96

h = torch.zeros(hidden_size)
x = tokenizer.encode("print(\"Hello, world!\")\n")

for i in x:
    # Input contribution for token i. The table lookup already equals
    # bias_i + weights_i @ embeddings[i], so no matmul is needed here.
    rzn_i = embeddings_[i]
    # Recurrent contribution from the previous hidden state.
    rzn_h = bias_h + weights_h @ h

    # Both tensors are laid out as [reset | update | new] along dim 0.
    rz = torch.sigmoid(rzn_i[:2 * hidden_size] + rzn_h[:2 * hidden_size])
    r, z = rz[:hidden_size], rz[hidden_size:]

    # The reset gate scales only the recurrent part of the candidate state.
    n = torch.tanh(rzn_i[2 * hidden_size:] + r * rzn_h[2 * hidden_size:])
    h = (1 - z) * n + z * h

# Logits from the final hidden state (classifier tensors include the padded rows).
output = classifier_bias + classifier_weight @ h

In [106]:
# Inspect the values left over from the last loop iteration.
# `e` is not assigned by any live cell (only mentioned in a comment), so it is
# rebuilt here as the raw embedding of the last token, which keeps this cell
# runnable on a fresh kernel.
e = embeddings[i]

print(e[:10])
print(rzn_i[:10])
print(rzn_h[:10])
print(h[:10])

tensor([ 0.6589, -0.2389, -1.4574, -0.5617, -0.2824,  0.6322, -0.0603, -0.3864,
        -1.6901,  0.1955])
tensor([-1.2136, -2.5133, -0.5410, -3.4197, -1.5178, -0.4936, -4.3799, -3.4258,
        -0.5711,  0.1148])
tensor([-0.0268,  0.0835,  0.1727,  0.2376,  0.1116,  0.1489,  0.2497,  0.1452,
         0.0182, -0.0504])
tensor([ 0.0706,  0.1254,  0.6507,  0.2384,  0.3526, -0.5223, -0.2796,  0.5353,
        -0.8297, -0.2017])


In [114]:
# Repeat of the inspection cell, for comparing against the earlier run.
# `e` is not assigned by any live cell (only mentioned in a comment), so it is
# rebuilt here as the raw embedding of the last token, which keeps this cell
# runnable on a fresh kernel.
e = embeddings[i]

print(e[:10])
print(rzn_i[:10])
print(rzn_h[:10])
print(h[:10])

tensor([ 0.6589, -0.2389, -1.4574, -0.5617, -0.2824,  0.6322, -0.0603, -0.3864,
        -1.6901,  0.1955])
tensor([-1.2136, -2.5133, -0.5410, -3.4197, -1.5178, -0.4936, -4.3799, -3.4258,
        -0.5711,  0.1148])
tensor([-0.0268,  0.0835,  0.1727,  0.2376,  0.1116,  0.1489,  0.2497,  0.1452,
         0.0182, -0.0504])
tensor([ 0.0706,  0.1254,  0.6507,  0.2384,  0.3526, -0.5223, -0.2796,  0.5353,
        -0.8297, -0.2017])


In [120]:
# Index of the largest logit from the hand-written GRU pass.
torch.argmax(output)

tensor(75)

In [118]:
# Run the same tokens through the nn.Module as a batch of one, reading the
# GRU output at the final position, to compare with the manual pass.
model(
    ids=torch.tensor(x).view(1, -1),
    last_elements=len(x) - 1,
)

tensor([[  1.1370, -12.9131,  -2.6255,  -2.5055,  -6.0172,   1.2686, -12.8175,
          -0.8780,  -0.3550,  -2.4557,   3.0256,   4.2143,   1.9403,  -2.6419,
          -4.2999,   0.8922,  -4.3181,  -4.0606,  -2.3016,   0.5230, -12.9406,
          -1.2455,  -0.5309,  -4.6550,  -3.2385,  -1.3632,  -2.5301,   0.9723,
          -3.9741,  -6.7590,   1.6380,   0.5815,   0.2733, -13.0002,  -2.5879,
           1.3441, -12.8627, -12.9307,  -1.7153,   3.3323,  -3.7203,   0.2743,
          -3.6626,  -0.4977,  -0.2695, -12.8957,   1.7998,  -5.9528,  -1.0700,
          -1.5789,  -5.2057,   3.8470, -13.0275,   1.7345,  -1.8308,  -2.2166,
         -12.7927,   3.8284,  -1.8637,   2.3391,   1.3699,  -6.7306,   0.8004,
          -1.3970,  -2.5148, -13.1093,   0.5474,  -0.3931,  -3.7620, -12.7709,
          -4.3564,   2.7505,  -5.8263,   3.8869,  -4.0820,   4.6155,  -2.1625,
          -3.7223,  -1.8967,  -2.3610,  -2.7317,  -0.8835,  -2.9204,  -0.8718,
          -1.1453,  -4.9039,  -0.5078,   4.1140,  -5