In [1]:

from MyTorch import Model
from MyTorch.operations import Flatten
from MyTorch.activations import ReLU, Softmax
from MyTorch.layers import Linear, Dropout, MultiheadAttention, LayerNorm, Attention


class GPT2_Layer(Model):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.attn = MultiheadAttention(d_model, num_heads, dropout)
        self.ln1 = LayerNorm(d_model)
        self.ln2 = LayerNorm(d_model)
        self.drop1 = Dropout(dropout)
        self.drop2 = Dropout(dropout)
        self.ff = Linear(d_model, d_ff)
        self.relu = ReLU()
        self.softmax = Softmax()

    def forward(self, x, mask):
        attn_out = self.attn(x, x, x, mask)
        attn_out = self.drop1(attn_out)
        x = self.ln1(x + attn_out)
        ff_out = self.ff(x)
        ff_out = self.relu(ff_out)
        ff_out = self.drop2(ff_out)
        x = self.ln2(x + ff_out)
        return x

class GPT2(Model):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.layers = [GPT2_Layer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.softmax = Softmax()

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return x
    
    


In [2]:
import numpy as np
from MyTorch import Tensor
model = GPT2(2, 128, 8, 128, 0.1)

input = Tensor(np.random.randn(1, 10, 128))
output = model(input, None)
print(output)


Tensor([[[-0.17352307 -0.16763261 -0.17831038 ... -0.17326243 -0.17584415
   -0.17637606]
  [-0.18609305  5.4882216  -0.1806758  ... -0.22837962 -0.20461608
   -0.19245334]
  [-0.20495044 -0.16591473 -0.2060373  ... -0.20557004 -0.20573443
   -0.20543458]
  ...
  [-0.1643044  -0.14986372 -0.15554751 ... -0.16656424 -0.27886516
   -0.18949695]
  [-0.20371205 -0.20282997 -0.20596766 ... -0.20230123 -0.20474124
   -0.20490046]
  [-0.06302877 -0.05055295 -0.05008615 ... -0.04764176 -0.05735055
   -0.05499944]]], shape = (1, 10, 128))


In [6]:
from MyTorch.losses import MSE
loss = MSE()
multihead = MultiheadAttention(8, 2)
input = Tensor(np.random.randn(1, 10, 8))
input2 = Tensor(np.random.randn(1, 10, 8))
output = multihead(input, input2, input2, None)

label = Tensor(np.random.randn(1, 10, 8))

loss = loss(output, label)
print(loss)
loss.backward()


Tensor(2.2417287826538086, shape = ())
