In [135]:
import math
import random

In [136]:
def shape(v):
    if isinstance(v, list):
        return (len(v), ) + shape(v[0])
    return ()

def matmul(a, b):
    if isinstance(a[0][0], list) and isinstance(b[0][0], list):
        return [matmul(ra, rb) for ra, rb in zip(a, b)]
    if isinstance(a[0][0], list) and not isinstance(b[0][0], list):
        return [matmul(ra, b) for ra in a]
    if not isinstance(a[0][0], list) and isinstance(b[0][0], list):
        return [matmul(a, rb) for rb in b]
    
    assert len(a[0]) == len(b)

    d1 = len(a)
    d2 = len(b[0])
    d3 = len(b)

    c = [[0.0] * d2 for _ in range(d1)]

    for i in range(d1):
        for j in range(d2):
            for k in range(d3):
                c[i][j] += a[i][k] * b[k][j]
    
    return c
def scalarmul(v, c):
    if isinstance(v, list):
        return [scalarmul(row, c) for row in v]
    return v * c
    
def scalaradd(v, c):
    if isinstance(v, list):
        return [scalaradd(row, c) for row in v]
    return v + c

def matadd(a, b):
    if isinstance(a[0][0], list) and isinstance(b[0][0], list):
        return [matadd(ra, rb) for ra, rb in zip(a, b)]
    if isinstance(a[0][0], list) and not isinstance(b[0][0], list):
        return [matadd(ra, b) for ra in a]
    if not isinstance(a[0][0], list) and isinstance(b[0][0], list):
        return [matadd(a, rb) for rb in b]
    # both 2D
    assert len(a[0]) == len(b[0])

    is_vector = (len(b) == 1)
    if not is_vector:
        assert len(a) == len(b)
    
    d1 = len(a)
    d2 = len(a[0])

    c = [[0.0] * d2 for _ in range(d1)]

    for i in range(d1):
        for j in range(d2):
            c[i][j] = a[i][j] + b[i][j] if not is_vector else b[0][j]
    
    return c
def transpose(v):
    if isinstance(v[0][0], list):
        return [transpose(row) for row in v]
    
    res = []
    for i in range(len(v[0])):
        res.append([])
        for j in range(len(v)):
            res[i].append(v[j][i])
    return res

In [137]:
class Dropout:
    def __init__(self, p=0.1):
        self.p = p
        self.scale = 1.0 / (1.0 - self.p) if p < 1 else 0
        
    def execute(self, v):
        if isinstance(v, list):
            return [self.execute(row) for row in v]
        
        ret = 0.0 if random.random() < self.p else v * self.scale

        return ret

def softmax(v): # axis = -1
    if isinstance(v[0], list):
        return [softmax(row) for row in v]

    big = max(v)
    exps = [math.exp(x - big) for x in v]
    tot = sum(exps)

    return [x / tot for x in exps]

def relu(v):
    if isinstance(v, list):
        return [relu(row) for row in v]
    return max(0, v)

def layer_normalization(v, gamma, beta, epsilon=1e-5):
    if isinstance(v[0], list):
        return [layer_normalization(row, gamma, beta, epsilon) for row in v]
    
    mean = sum(v) / len(v)
    std = math.sqrt(sum([(x - mean) ** 2 for x in v]) / len(v))

    x = scalaradd(v, -mean)
    x = [val * g / (std + epsilon) + b for val, g, b in zip(x, gamma[0], beta[0])]
    return x

In [138]:
def pos_encoding(seq_len, d_model):
    pe = []
    for pos in range(seq_len):
        pe.append([])
        for i in range(d_model // 2):
            pe[pos].append(math.sin(pos / (10000 ** (2 * i / d_model))))
            pe[pos].append(math.cos(pos / (10000 ** (2 * i / d_model))))
    return pe

def compute_qkv(X, W_q, W_k, W_v):
    return matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)

def get_causal_mask(generated_len):
    v = [[1] * generated_len for _ in range(generated_len)]
    for i in range(0, generated_len):
        for j in range(i + 1, generated_len):
            v[i][j] = 0
    return v
    
def self_attention(Q, K, V, mask = None, dropout=None):
    d_k = shape(Q)[-1]
    score = scalarmul(matmul(Q, transpose(K)), (1.0 / math.sqrt(d_k)))
    # mask: (seq_len, seq_len)

    if mask is not None:
        for i in range(len(score)):
            for j in range(len(score[0])):
                if mask[i][j] == 0:
                    score[i][j] = -1e9

    pscore = softmax(score)

    if dropout is not None:
        pscore = dropout.execute(pscore)

    return matmul(pscore, V)

def split_for_head(X, n_heads):
    # seq_len, d_model -> n_heads, seq_len, d_head

    d_model = shape(X)[-1]
    d_head = d_model // n_heads
    
    heads = []
    for i in range(n_heads):
        head = []
        for row in X:
            subseq = row[i * d_head : (i + 1) * d_head]
            head.append(subseq)
        heads.append(head)
    return heads
        

def horizontal_concat(X):
    # n_heads, seq_len, d_head -> seq_len, d_model
    res = []
    for i in range(len(X[0])):
        full_row = []
        for head in X:
            full_row.extend(head[i])
        res.append(full_row)

    return res

def multi_head_attention(Q, K, V, W_o, n_heads, mask=None, dropout=None):
    # seq_len, d_model
    
    Q_split = split_for_head(Q, n_heads)
    K_split = split_for_head(K, n_heads)
    V_split = split_for_head(V, n_heads)

    # n_head, seq_len, d_head
    heads = [self_attention(Q_split[i], K_split[i], V_split[i], mask=mask, dropout=dropout) for i in range(n_heads)]
    
    concat_heads = horizontal_concat(heads)

    # seq_len, d_model
    return matmul(concat_heads, W_o)

In [139]:
class Transformer:
    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_blocks):
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff

        self.embedding = self.init_weights(vocab_size, d_model)
        self.encoders = [self.create_block(is_decoder=False) for _ in range(n_blocks)]
        self.decoders = [self.create_block(is_decoder=True) for _ in range(n_blocks)]

        self.dropout_p = 0.1
        self.train()
    
    def init_weights(self, rows, cols):
        limit = math.sqrt(6 / (rows + cols))
        v = []
        for i in range(rows):
            v.append([random.uniform(-limit, +limit) for _ in range(cols)])
        return v

    def create_block(self, is_decoder):
        weights = {
            'W_q': self.init_weights(self.d_model, self.d_model),
            'W_k': self.init_weights(self.d_model, self.d_model),
            'W_v': self.init_weights(self.d_model, self.d_model),
            'W_o': self.init_weights(self.d_model, self.d_model),
            'W_ff1': self.init_weights(self.d_model, self.d_ff),
            'bias_ff1': [[0] * self.d_ff],
            'W_ff2': self.init_weights(self.d_ff, self.d_model),
            'bias_ff2': [[0] * self.d_model],
            'gamma': [[1] * self.d_model],
            'beta': [[0] * self.d_model],
        }
        if is_decoder:
            weights.update({
                'W_q_cross': self.init_weights(self.d_model, self.d_model),
                'W_k_cross': self.init_weights(self.d_model, self.d_model),
                'W_v_cross': self.init_weights(self.d_model, self.d_model),
                'W_o_cross': self.init_weights(self.d_model, self.d_model),
            })
        return weights
    
    def run_encoders(self, x):
        seq_len = shape(x)[-1]
        # [seq_len]
        x = scalarmul([self.embedding[idx] for idx in x], math.sqrt(self.d_model))
        # [seq_len x d_model]
        x = matadd(x, pos_encoding(seq_len, self.d_model))

        if self.dropout: x = self.dropout.execute(x)

        for weights in self.encoders:

            Q, K, V = compute_qkv(x, weights['W_q'], weights['W_k'], weights['W_v'])
            att = multi_head_attention(Q, K, V, weights['W_o'], self.n_heads, mask=None, dropout=self.dropout)
            if self.dropout: att = self.dropout.execute(att)
            x = layer_normalization(matadd(x, att), weights['gamma'], weights['beta'])
            
            ff = matadd(matmul(relu(matadd(matmul(x, weights['W_ff1']), weights['bias_ff1'])), weights['W_ff2']), weights['bias_ff2'])
            if self.dropout: ff = self.dropout.execute(ff)
            x = layer_normalization(matadd(x, ff), weights['gamma'], weights['beta'])

        return x
    
    def run_decoders(self, enc_out, x):
        seq_len = shape(x)[-1]

        x = scalarmul([self.embedding[idx] for idx in x], math.sqrt(self.d_model))
        x = matadd(x, pos_encoding(seq_len, self.d_model))
        
        if self.dropout: x = self.dropout.execute(x)

        mask = get_causal_mask(seq_len)

        for weights in self.decoders:

            Q, K, V = compute_qkv(x, weights['W_q'], weights['W_k'], weights['W_v'])
            att = multi_head_attention(Q, K, V, weights['W_o'], self.n_heads, mask=mask, dropout=self.dropout)
            if self.dropout: att = self.dropout.execute(att)
            x = layer_normalization(matadd(x, att), weights['gamma'], weights['beta'])

            Q = matmul(x, weights['W_q_cross'])
            K = matmul(enc_out, weights['W_k_cross'])
            V = matmul(enc_out, weights['W_v_cross'])
            cross_att = multi_head_attention(Q, K, V, weights['W_o_cross'], self.n_heads, mask=None, dropout=self.dropout)
            if self.dropout: cross_att = self.dropout.execute(cross_att)
            x = layer_normalization(matadd(x, cross_att), weights['gamma'], weights['beta'])
            
            ff = matadd(matmul(relu(matadd(matmul(x, weights['W_ff1']), weights['bias_ff1'])), weights['W_ff2']), weights['bias_ff2'])
            if self.dropout: ff = self.dropout.execute(ff)
            x = layer_normalization(matadd(x, ff), weights['gamma'], weights['beta'])

        return matmul(x, transpose(self.embedding)) # seq_len x vocab_size
    
    def generate(self, src, start_token, max_len):
        enc_out = self.run_encoders(src)

        output = [start_token]
        for _ in range(max_len):
            logits = self.run_decoders(enc_out, output)
            probs = softmax(logits[-1])
            prob, token = max([(x, idx) for idx, x in enumerate(probs)])
            output.append(token)

        return output
    
    def eval(self):
        self.dropout = None
    def train(self):
        self.dropout = Dropout(p=self.dropout_p)


In [140]:
random.seed(123)

model = Transformer(vocab_size=21, d_model=128, n_heads=4, d_ff=256, n_blocks=2)
model.eval()

for _ in range(10):
    src = [random.randint(1, 20) for _ in range(10)]
    print(src, model.generate(src, 0, len(src)))

[8, 20, 19, 14, 9, 17, 3, 13, 19, 1] [0, 20, 20, 6, 6, 6, 6, 6, 6, 6, 6]
[4, 8, 6, 11, 5, 17, 16, 15, 6, 17] [0, 20, 2, 2, 2, 2, 2, 2, 2, 2, 2]
[18, 12, 17, 18, 18, 13, 10, 13, 4, 2] [0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
[18, 19, 9, 11, 6, 1, 5, 2, 2, 3] [0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
[13, 18, 12, 15, 18, 2, 1, 6, 13, 8] [0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
[7, 4, 17, 4, 10, 18, 10, 10, 19, 14] [0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
[2, 18, 7, 4, 1, 5, 9, 16, 5, 20] [0, 20, 2, 2, 2, 2, 2, 2, 2, 2, 2]
[16, 20, 6, 20, 15, 5, 11, 18, 4, 10] [0, 20, 20, 20, 20, 20, 20, 20, 3, 3, 3]
[2, 5, 2, 19, 9, 19, 12, 19, 14, 4] [0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
[9, 15, 20, 17, 7, 4, 8, 5, 14, 3] [0, 20, 3, 3, 3, 3, 3, 3, 3, 3, 3]
