In [1]:
# transformers

In [3]:
import torch
from einops import rearrange, einsum


In [6]:
D = torch.rand((4, 8, 4)) # batch, seq, dim
print(D.shape)

torch.Size([4, 8, 4])


In [7]:
A = torch.rand((4, 4)) # dim, dim
print(A.shape)

torch.Size([4, 4])


In [8]:
Y = D @ A.T
print(Y.shape)

torch.Size([4, 8, 4])


In [15]:
Y_ = einsum(D, A, "batch sequence d_in, d_out d_in -> batch sequence d_out")
print(Y_.shape)

torch.Size([4, 8, 4])


In [16]:
print(torch.equal(Y, Y_))

True


In [2]:
import torch
import numpy as np
import random

def set_seed(seed=42):
    # Python random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # PyTorch
    torch.manual_seed(seed)
    # CUDA (if available)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU
        torch.backends.cudnn.deterministic = True  # Deterministic algorithms
        torch.backends.cudnn.benchmark = False     # Disable auto-optimization

set_seed(42)  # Call this at the start of your script

In [3]:
import torch
from einops import rearrange, einsum

In [4]:
images = torch.randn(64,128,128,3)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)

In [7]:
dim_by = dim_by.unsqueeze(0)

In [11]:
dim_by.shape

torch.Size([1, 10])

In [12]:
dim_by = dim_by.unsqueeze(2).unsqueeze(3).unsqueeze(4)

In [13]:
dim_by.shape

torch.Size([1, 10, 1, 1, 1])

In [14]:
images = images.unsqueeze(1)

In [16]:
images.shape

torch.Size([64, 1, 128, 128, 3])

In [17]:
dimmed_images = images * dim_by
dimmed_images.shape

torch.Size([64, 10, 128, 128, 3])

In [18]:
channels_last = torch.randn(64, 32, 32, 3)
B = torch.randn(32*32, 32*32)

In [19]:
channels_last_flat = channels_last.view(
    -1, channels_last.size(1), channels_last.size(2), channels_last.size(3)
)

In [20]:
channels_last_flat.shape

torch.Size([64, 32, 32, 3])

In [21]:
torch.equal(channels_last, channels_last_flat)

True

In [22]:
B.shape

torch.Size([1024, 1024])

In [23]:
channels_first_flat = channels_last_flat.transpose(1,2)

In [24]:
channels_first_flat.shape

torch.Size([64, 32, 32, 3])

In [25]:
channels_first_flat_transformed = channels_first_flat @ B.T

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2048, 3] but got: [2048, 1024].

In [26]:
channels_last_flat = channels_last.view(
    -1, channels_last.size(1) * channels_last.size(2), channels_last.size(3)
)

In [27]:
channels_last_flat.shape

torch.Size([64, 1024, 3])

In [29]:
print(*channels_last_flat.shape)

64 1024 3


In [40]:
d_in = 24
d_out = 32
test_weight = torch.randn(d_in, d_out)

In [41]:
test_weight[0]

tensor([ 0.0460,  0.4442, -1.4056, -0.1128,  0.2139,  0.5391, -0.5942,  1.8176,
         0.4017,  0.6120,  2.4649,  0.7039, -0.2908,  0.4892, -0.1077,  1.3844,
        -0.6630,  1.8973, -1.6479,  0.3498, -0.4356, -1.0734,  0.1763,  1.1610,
        -1.8292,  0.6686,  0.5466,  0.7458, -1.6116, -1.9045,  0.7526, -0.1948])

In [42]:
import math
mean=0.
std=math.sqrt(2./(d_in + d_out))

In [43]:
test_weight = torch.nn.init.trunc_normal_(
    test_weight,
    mean=0,
    std=std,
    a=-3*std,
    b=3*std
)

In [44]:
test_weight.shape

torch.Size([24, 32])

In [45]:
test_weight[0]

tensor([-0.5405, -0.1328, -0.0037,  0.2573, -0.0750, -0.0887, -0.2748, -0.0140,
         0.1534, -0.2061,  0.3408, -0.1994, -0.0738,  0.1119, -0.1461,  0.0803,
        -0.0132, -0.0212,  0.0859,  0.2477, -0.0278,  0.0524,  0.1449, -0.1900,
        -0.2866,  0.0095, -0.0197,  0.0568, -0.2302,  0.1078, -0.1733,  0.1052])

In [74]:
class Linear(torch.nn.Module):
    def __init__(self, in_features, out_features, device=None, dtype=None):
        super(Linear, self).__init__()
        # initiate a linear transformation module
        data = torch.nn.init.trunc_normal_(
            torch.randn(out_features, in_features, device=device, dtype=dtype),
            mean=0,
            std=math.sqrt(2./(in_features + out_features)),
            a=-3*math.sqrt(2./(in_features + out_features)),
            b=3*math.sqrt(2./(in_features + out_features))
        )
        self.weight = torch.nn.Parameter(data)
    def forward(self, x:torch.Tensor) -> torch.Tensor:
        # vanilla pytorch way
        # output1 = x @ self.weight.T
        # einsum way
        output = einsum(x, self.weight,
                        "batch seq d_in, d_out d_in -> batch seq d_out"
                       )
        return output

In [67]:
test_input = torch.randn(8,16,32)

In [71]:
linear_layer = Linear(32,64)

In [72]:
output = linear_layer(test_input)

In [73]:
output.shape

torch.Size([8, 16, 64])

torch.Size([32, 16])

In [94]:
class Embedding(torch.nn.Module):
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super(Embedding, self).__init__()
        data = torch.nn.init.trunc_normal_(
            torch.randn(num_embeddings, embedding_dim, device=device, dtype=dtype),
            mean=0,
            std=1,
            a=-3,
            b=3
        )
        self.embeddings = torch.nn.Parameter(data)
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        output = self.embeddings[token_ids]
        return output

In [95]:
emb_model = Embedding(32, 16)

In [112]:
emb_model.embeddings[torch.tensor([0, 1], dtype=torch.int)]

tensor([[ 1.6969,  0.8193,  0.2542, -0.2801,  1.0709, -0.4859,  0.7410,  0.3139,
         -0.1099, -2.8018,  0.1574, -0.7763,  0.8081,  1.3766, -0.5112, -1.1423],
        [ 0.0933, -0.6333,  0.1861, -0.2137, -1.0431,  1.7421,  0.8766, -0.7102,
          0.4819,  1.3366,  0.4977,  0.0891, -1.5155, -0.2780, -0.9904,  0.6207]],
       grad_fn=<IndexBackward0>)

In [116]:
test_input = torch.tensor([[1,4],[2,3]], dtype=torch.int)

In [117]:
output = emb_model(test_input)

In [121]:
temp = output * torch.ones(16)
temp.shape

torch.Size([2, 2, 16])

In [17]:
class RMSNorm(torch.nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.gain = torch.nn.Parameter(torch.randn(d_model, device=device, dtype=dtype))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x = x.to(torch.float32)
        x_squared = x**2 + self.eps
        x_mean = torch.mean(x_squared, dim=-1, keepdim=True)
        x_rms = torch.sqrt(x_mean)
        x /= x_rms
        result = x * self.gain
        return result.to(in_dtype)

In [18]:
rms_norm = RMSNorm(32)

In [21]:
rms_norm.gain.size

<function Parameter.size>

In [159]:
test_input = torch.randn(2, 2, 32)

In [160]:
output = rms_norm(test_input)

In [161]:
output.shape

torch.Size([2, 2, 32])

In [164]:
class RotaryPositionalEmbedding(torch.nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super(RotaryPositionalEmbedding, self).__init__()
        position = torch.arange(max_seq_len, device=device)
        dim_range = torch.arange(0, d_k, 2, device=device).float()

        freq = 1.0 / (theta**(dim_range/d_k))
        thetas = torch.outer(position, freq)

        self.register_buffer("freqs", torch.outer(position, freq))
        self.register_buffer("cos_cache", torch.cos(thetas))
        self.register_buffer("sin_cache", torch.sin(thetas))

    def forward(self, x: torch.tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (..., seq_len, d_k)
            token_positions: (..., seq_len)
        Returns:
            Tensor: (..., seq_len, d_k)
        """
        seq_len = x.size(-2)
        cos = self.cos_cache[token_positions]
        sin = self.sin_cache[token_positions]
        
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat(
            (x1 * cos - x2 * sin,
             x2 * cos + x1* sin),
            dim=-1
        )
        return rotated
        

In [197]:
rope = RotaryPositionalEmbedding(theta=10000.0, d_k=4, max_seq_len=32)

In [207]:
x = torch.ones(4, 3, 4)
positions = torch.tensor([2,3,4], dtype=torch.int)

In [199]:
rotated = rope(x, positions)

In [200]:
rotated.shape

torch.Size([3, 4])

In [201]:
rotated

tensor([[-1.3254,  0.9798,  0.4932,  1.0198],
        [-1.1311,  0.9696, -0.8489,  1.0295],
        [ 0.1032,  0.9592, -1.4104,  1.0392]])

In [202]:
x1, x2 = x.chunk(2, dim=-1)

In [203]:
x1

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])

In [237]:
max_vals = x.max(dim=2, keepdim=True).values
max_vals

tensor([[[1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.]]])

In [238]:
shifted_exp = torch.exp(x - max_vals)
shifted_exp

tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

In [239]:
sum_exp = shifted_exp.sum(dim=2, keepdim=True)
sum_exp

tensor([[[4.],
         [4.],
         [4.]],

        [[4.],
         [4.],
         [4.]],

        [[4.],
         [4.],
         [4.]],

        [[4.],
         [4.],
         [4.]]])

In [240]:
softmax_output = shifted_exp / sum_exp
softmax_output

tensor([[[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]]])

In [219]:
def softmax(x: torch.Tensor, dimension: int) -> torch.Tensor:
    max_vals = torch.max(x, dim=dimension, keepdim=True).values
    shifted_exp = torch.exp(x - max_vals)
    sum_exp = shifted_exp.sum(dim=dimension, keepdim=True)
    softmax_output = shifted_exp / sum_exp
    return softmax_output
    

In [233]:
t = torch.Tensor([1, 1])
t.size(-1)

2

In [234]:
softmax_output = softmax(t, dimension=0)
softmax_output

tensor([0.5000, 0.5000])

In [276]:
def scaled_dot_product_attention(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor=None
):
    """
        Q: batch_size, ..., seq_len, d_k
        K: batch_size, ..., seq_len, d_k
        V: batch_size, ..., seq_len, d_v
        mask: (seq_len, seq_len)
        returns: (batch_size, ..., d_v)
    """
    d_k = Q.size(-1)
    # Dot product
    # attention = Q @ K.transpose(-2,-1)
    attention = einsum(Q, K,
                       "batch_size ... seq_len_q d_q, batch_size ... seq_len_k d_k -> batch_size ... seq_len_q seq_len_k")
    attention_scaled = attention / torch.sqrt(torch.Tensor([d_k]))
    if mask is not None:
        attention_scaled = attention_scaled.masked_fill(mask == 0, -1e9)
        
    attention_weights = softmax(attention_scaled, dimension=-1)
    attention_output = einsum(attention_weights, V,
                              "batch_size ... seq_len_q seq_len_k, batch_size ... seq_len_v d_v -> batch_size ... seq_len_q d_v"
                             )
    return attention_output

In [262]:
# Example parameters
batch_size = 2
seq_len = 5
d_k = 64
d_v = 64

# Create random tensors
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_v)

# Create attention mask (optional)
mask = torch.ones(batch_size, seq_len, seq_len)
mask[:, :, -2:] = 0  # Mask last two positions

In [263]:
key.shape

torch.Size([2, 5, 64])

In [274]:
dot_product = einsum(query, key,
                     "batch_size ... seq_len_q d_q, batch_size ... seq_len_k d_k -> batch_size ... seq_len_q seq_len_k"
                    )

In [275]:
dot_product.shape

torch.Size([2, 5, 5])

In [269]:
mask.shape

torch.Size([2, 5, 5])

In [277]:
output = scaled_dot_product_attention(query, key, value, mask)

In [278]:
output.shape

torch.Size([2, 5, 64])

In [248]:
key.transpose(-2,-1).shape

torch.Size([2, 64, 7])

In [249]:
key.shape

torch.Size([2, 7, 64])

In [340]:
x = torch.randn(2,3,2)
x[0]

tensor([[-0.6832, -1.6233],
        [ 1.5669,  1.0561],
        [ 1.1985, -0.2246]])

In [341]:
x = torch.stack([x]*3, dim=-2)
x.shape

torch.Size([2, 3, 3, 2])

In [337]:
x = rearrange(x,
           "... a b c -> ... (a b) c"
          )
x.shape

torch.Size([2, 6, 3])

In [338]:
x[0]

tensor([[ 0.9364,  0.9364,  0.9364],
        [-1.8515, -1.8515, -1.8515],
        [-0.3825, -0.3825, -0.3825],
        [-0.8591, -0.8591, -0.8591],
        [-0.2684, -0.2684, -0.2684],
        [ 1.1453,  1.1453,  1.1453]])

In [410]:
class MultiHeadSelfAttention(torch.nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super(MultiHeadSelfAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = int(d_model / num_heads)
        self.W_q = Linear(d_model, d_model)
        self.W_k = Linear(d_model, d_model)
        self.W_v = Linear(d_model, d_model)
        self.W_o = Linear(d_model, d_model)
        self.rope = RotaryPositionalEmbedding(theta=10000.0, d_k=self.d_k, max_seq_len=256)
    def forward(self, x: torch.Tensor, apply_mask: bool= True):
        seq_len = x.size(-2)
        if apply_mask:
            mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
            mask = ~mask
        query = self.W_q(x) # batch_size seq_len d_model
        key = self.W_k(x)
        value = self.W_v(x)
        ## rearrange
        query = rearrange(query,
                          "... seq_len (h d_k) -> ... h seq_len d_k",
                          h=self.num_heads, d_k=self.d_k
                         )
        key = rearrange(key,
                          "... seq_len (h d_k) -> ... h seq_len d_k",
                          h=self.num_heads, d_k=self.d_k
                         )
        value = rearrange(value,
                          "... seq_len (h d_k) -> ... h seq_len d_k",
                          h=self.num_heads, d_k=self.d_k
                         )
        query = rope(query, token_positions=torch.arange(seq_len))
        key = rope(query, token_positions=torch.arange(seq_len))
        attention_values = scaled_dot_product_attention(query, key, value, mask=mask)
        
        attention_values = rearrange(attention_values,
                                     "batch h seq_len d_k -> batch seq_len (h d_k)",
                                     h=int(self.num_heads), d_k=int(self.d_k)
                                    )
        return self.W_o(attention_values)

In [392]:
batch_size = 4
seq_len = 3
d_model = 12
h = 4
d_k = 3

Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

In [368]:
int(d_model / h)

3

In [411]:
mha = MultiHeadSelfAttention(d_model, h)

In [412]:
attention_weights = mha(Q)

In [413]:
attention_weights.shape

torch.Size([4, 3, 12])

In [390]:
attention_weights = rearrange(attention_weights,
                              "b h seq_len d_k -> b seq_len (h d_k)",
                              h=h, d_k=d_k
                             )

In [349]:
Q1 = Q.view(batch_size, seq_len, h, d_k).transpose(1, 2)
Q_ = rearrange(Q,
               "... seq_len (h d_k) -> ... h seq_len d_k",
               h=h, d_k=d_k
              )

In [387]:
attention_weights = rearrange(attention_weights,
                           "b h s d_k -> b s (h d_k)",
                           b=batch_size, h=4, d_k=3)

In [350]:
torch.equal(Q1, Q_)

True

In [414]:
record = {
            "answerKey": "B",
            "choices": {
                "label": ["A", "B", "C", "D"],
                "text": ["Shady areas increased.", "Food sources increased.", ...]
            },
            "question": "...Which best explains why there were more chipmunks the next year?"
        }

In [416]:
labels = record["choices"]["label"]
choices = record["choices"]["text"]

1

In [418]:
ord("A")

65