In [2]:
with open("input.txt", mode="r", encoding="utf-8") as handle:
    # pull the entire corpus into memory as a single string
    text = handle.read()

In [3]:
len(text)  # total number of characters in the corpus (~1.1M)

1115394

In [4]:
text[:1000]  # peek at the first 1000 characters of the corpus

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [5]:
# get unique characters: the sorted set of all characters in the text is our vocabulary
chars = sorted(set(text))  # sorted() accepts any iterable and already returns a new list
vocab_size = len(chars)

print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


# Tokenization

In [6]:
# character <-> integer id mappings; plain dicts are enough here because
# lookups by key don't depend on any ordering
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}


def encode(s):
    """Encode a string into a list of integer token ids.

    Returns a list (an ordered sequence) because token order matters.
    Raises KeyError for any character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode a sequence of integer token ids back into a string."""
    return ''.join([itos[i] for i in l])

print(encode("Hello, world!"))
print(decode(encode("Hello, world!")))

[20, 43, 50, 50, 53, 6, 1, 61, 53, 56, 50, 42, 2]
Hello, world!


## Other ways
- SentencePiece
- tiktoken

In [7]:
import tiktoken

# load the GPT-2 tokenizer to compare it with our character-level tokenizer
enc = tiktoken.get_encoding("gpt2")

In [8]:
enc.n_vocab  # vocabulary size: ~50k tokens, vs. 65 characters for our tokenizer

50257

In [9]:
enc.encode("Hello, world!")  # 4 token ids here vs. 13 with our character-level encoder

[15496, 11, 995, 0]

In [10]:
enc.decode([15496, 11, 995, 0])  # decoding the ids round-trips to the original string

'Hello, world!'

## Prepare Dataset

Tokenize/encode the entire text (it is split into train and validation sets below)

In [11]:
import torch

# encode the whole corpus into one flat tensor of 64-bit integer token ids
data = torch.tensor(encode(text), dtype=torch.int64)

data.shape, data.dtype

  cpu = _conversion_method_template(device=torch.device("cpu"))


(torch.Size([1115394]), torch.int64)

In [12]:
data[:1000]  # the dataset is one big 1D tensor of token ids; show the first 1000

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

Split into train/validation sets

In [13]:
# 90/10 split: the first 90% of tokens are for training, the last 10% are held out for validation
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

## Batches
Training on the whole dataset at once is expensive <br>
So we train on random chunks with some max sequence length (AKA context size, block size, time dimension)

In [14]:
block_size = 8
# one extra token: block_size + 1 tokens contain block_size (context, next-token) pairs
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

We predict at each position (from 0 to block_size - 1) <br>
This is not just for efficiency: it also gets the Transformer used to seeing contexts of every length from 1 to block_size (so that it can later generate starting from as little as 1 token of context)

In [15]:
# inputs cover positions 0..block_size-1; targets are the same window shifted right by one
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t, target in enumerate(y):
    ctx = x[:t + 1]  # context = every input token up to and including position t
    print(f"Given {ctx}, predict {target}")

Given tensor([18]), predict 47
Given tensor([18, 47]), predict 56
Given tensor([18, 47, 56]), predict 57
Given tensor([18, 47, 56, 57]), predict 58
Given tensor([18, 47, 56, 57, 58]), predict 1
Given tensor([18, 47, 56, 57, 58,  1]), predict 15
Given tensor([18, 47, 56, 57, 58,  1, 15]), predict 47
Given tensor([18, 47, 56, 57, 58,  1, 15, 47]), predict 58


A batch = many independent sequences stacked on top of each other. GPUs are good at parallel processing, so we process all the sequences in a batch in parallel (all at once) <br>
Make a batch:

In [16]:
torch.manual_seed(1337)  # make the random batch sampling reproducible

# sequences processed in parallel per batch; maximum length (context) of each sequence
batch_size, block_size = 4, 8

def get_batch(split, batch_size=None, block_size=None):
    """Sample a random batch of (input, target) sequences from one data split.

    Args:
        split: "train" selects `train_data`; any other value selects `val_data`.
        batch_size: number of sequences to sample. Defaults to the global
            `batch_size`, read at call time, so reassigning that global still works.
        block_size: length of each sequence. Defaults to the global `block_size`,
            read at call time.

    Returns:
        (x, y), each of shape (batch_size, block_size). y is x shifted one position
        to the right: y[b, t] is the token that follows x[b, t] in the data.
    """
    source = train_data if split == 'train' else val_data
    if batch_size is None:
        batch_size = globals()['batch_size']
    if block_size is None:
        block_size = globals()['block_size']

    # random start offsets in [0, len(source) - block_size) (upper bound exclusive),
    # so the target slice i+1 : i+block_size+1 never runs past the end of the data
    ix = torch.randint(len(source) - block_size, (batch_size,))

    # stack the sequences into (batch_size, block_size) tensors to process them in parallel
    x = torch.stack([source[i:i + block_size] for i in ix])
    # targets are the inputs shifted by one: the next token for every input position
    y = torch.stack([source[i + 1:i + block_size + 1] for i in ix])

    return x, y

xb, yb = get_batch("train")
print('Inputs: ', xb.shape)
print('Targets: ', yb.shape)
print('-------')

# walk through every (context, target) pair packed into the batch
for b in range(batch_size):
    sequence, next_tokens = xb[b], yb[b]  # row b: one independent sequence and its targets
    for t in range(block_size):
        ctx = sequence[:t + 1]  # tokens 0..t (the slice end is exclusive)
        target = next_tokens[t]
        print(f"Given {ctx}, predict {target}")


Inputs:  torch.Size([4, 8])
Targets:  torch.Size([4, 8])
-------
Given tensor([24]), predict 43
Given tensor([24, 43]), predict 58
Given tensor([24, 43, 58]), predict 5
Given tensor([24, 43, 58,  5]), predict 57
Given tensor([24, 43, 58,  5, 57]), predict 1
Given tensor([24, 43, 58,  5, 57,  1]), predict 46
Given tensor([24, 43, 58,  5, 57,  1, 46]), predict 43
Given tensor([24, 43, 58,  5, 57,  1, 46, 43]), predict 39
Given tensor([44]), predict 53
Given tensor([44, 53]), predict 56
Given tensor([44, 53, 56]), predict 1
Given tensor([44, 53, 56,  1]), predict 58
Given tensor([44, 53, 56,  1, 58]), predict 46
Given tensor([44, 53, 56,  1, 58, 46]), predict 39
Given tensor([44, 53, 56,  1, 58, 46, 39]), predict 58
Given tensor([44, 53, 56,  1, 58, 46, 39, 58]), predict 1
Given tensor([52]), predict 58
Given tensor([52, 58]), predict 1
Given tensor([52, 58,  1]), predict 58
Given tensor([52, 58,  1, 58]), predict 46
Given tensor([52, 58,  1, 58, 46]), predict 39
Given tensor([52, 58,  1,

In [17]:
xb  # the input batch sampled above (its shape was printed as (4, 8))

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

# Modeling - Bigram Model

The simplest model

In [18]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLM(nn.Module):
    """Bigram language model: the prediction for the next token depends only on the current token.

    Args:
        vocab_size: number of distinct token ids (called C below).
    """

    def __init__(self, vocab_size):
        super().__init__()

        # lookup table of shape (C, C), where C = vocab_size. Row i is read directly as the
        # logits (unnormalized scores) for the token that follows token i. Training adjusts
        # these scores so the table predicts the next token from the current one.
        self.token_emb_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, i.e. (batch_size, block_size).
            targets: optional (B, T) tensor of next-token ids. It is only needed for
                training; omit it for inference/generation.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and loss is None.
            With targets, logits is flattened to (B*T, C) and loss is a scalar tensor.
        """
        # each token id picks its row of the table: (B, T) -> (B, T, C)
        logits = self.token_emb_table(idx)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # F.cross_entropy expects the class dimension C second, so merge the batch and
            # time dimensions. reshape (unlike view) also accepts non-contiguous targets.
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            # negative log-likelihood of the correct next token, averaged over all B*T positions
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens, one at a time.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to generate and append.

        Returns:
            A (B, T + max_new_tokens) tensor: the original context followed by the samples.

        Runs under torch.no_grad(). Sampling never calls backward(), so recording an
        autograd graph for every forward pass would only waste memory and time.
        """
        # note: feeding the whole sequence at every step is wasteful for a bigram model,
        # which only looks at the last token. It is kept this way so the loop works
        # unchanged for later models that look further into the past.
        for _ in range(max_new_tokens):
            logits = self(idx)[0]  # (B, T, C); no targets are passed, so the loss is None
            # keep only the logits of the last time step: (B, T, C) -> (B, C)
            logits = logits[:, -1, :]
            # softmax turns the logits into probabilities (each row sums to 1)
            probs = F.softmax(logits, dim=-1)
            # sample one next token per sequence from that distribution -> (B, 1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # append it to the running context: (B, T) -> (B, T+1) -> ... -> (B, T+max_new_tokens)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [19]:
m = BigramLM(vocab_size)
logits, loss = m(xb, yb)  # targets given -> logits flattened to (B*T, C) plus a scalar loss
(logits.shape, loss)

(torch.Size([32, 65]), tensor(4.8786, grad_fn=<NllLossBackward0>))

In [20]:
# starting context: a batch of one sequence holding the single token id 0, shape (B, T) = (1, 1)
initial = torch.zeros((1, 1), dtype=torch.long)

# generate() returns (B, T+100); keep the only sequence and turn it into a python list
generated_text = m.generate(initial, max_new_tokens=100)[0].tolist()

decode(generated_text)  # random gibberish: the model hasn't been trained yet

"\nSr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3"

# Training - Bigram Model

In [21]:
# AdamW: a more advanced and popular alternative to plain SGD (the simplest option).
# 3e-4 is a typical good learning rate, but very small models like this one
# can get away with a higher one.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [22]:
# larger batches for training. get_batch reads this global batch_size when it is called.
batch_size = 32

for steps in range(10_000):
    xb, yb = get_batch("train")  # sample a fresh batch of randomly chosen sequences

    logits, loss = m(xb, yb)  # forward pass: predictions and loss for this batch

    optimizer.zero_grad(set_to_none=True)  # clear the gradients left from the previous step
    loss.backward()  # backpropagate: compute the gradient of the loss w.r.t. each parameter
    optimizer.step()  # update the parameters using those gradients

# loss on the final training batch only: a noisy estimate, not an average over the data
print(loss.item())

2.5727508068084717


In [23]:
initial = torch.zeros((1, 1), dtype=torch.long)  # start from the single token id 0
sampled = m.generate(initial, max_new_tokens=500)
generated_text = sampled[0].tolist()  # the only sequence in the batch, as a python list
print(decode(generated_text))


Iyoteng h hasbe pave pirance
Rie hicomyonthar's
Plinseard ith henoure wounonthioneir thondy, y heltieiengerofo'dsssit ey
KIN d pe wither vouprrouthercc.
hathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so it t jod weancotha:
h hay.JUCle n prids, r loncave w hollular s O:
HIs; ht anjx?

DUThinqunt.

LaZAnde.
athave l.
KEONH:
ARThanco be y,-hedarwnoddy scace, tridesar, wnl'shenous s ls, theresseys
PlorseelapinghiybHen yof GLUCEN t l-t E:
I hisgothers je are!-e!
QLYotouciullle'z


# Self-Attention

Want to make tokens communicate with each other <br>
but in a special way (tokens should not communicate with future tokens) <br>
<br>
Simplest "Communication": Average the previous elements/time steps/tokens <br>
e.g. for the 5th token, average the channel vectors of tokens 1 through 5 and use that average as the 5th token's new channel vector

In [90]:
torch.manual_seed(1337)
B, T, C = (4, 8, 2)  # batch, time, channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [91]:
# xbow[b, t] = mean of x[b, i] for i = 0..t (inclusive): every token is averaged
# together with all the tokens before it
# bow = "bag of words", a term used when simply averaging things
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        # for the current sequence, take tokens 0..t (the first t+1 tokens)
        xprev = x[b, :t+1]  # (t+1, C)
        # average over the time dimension (dim 0): (t+1, C) -> (C,)
        xbow[b, t] = torch.mean(xprev, 0)
xbow[0, 0]

tensor([ 0.1808, -0.0700])

## Mathematical trick in Self-Attention
Above implementation is inefficient (uses for-loop) <br>
Efficient way would use matrix multiplication (cuz GPUs are good at that), AKA "vectorization"

In [92]:
# matrix multiplication is the key to an efficient implementation.
# with an all-ones matrix a, every row of c = a @ b is the column-wise sum of b
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)

for matrix in (a, b, c):
    print(matrix)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[8., 6.],
        [5., 2.],
        [4., 4.]])
tensor([[17., 12.],
        [17., 12.],
        [17., 12.]])


In [93]:
# with a lower-triangular matrix of ones, row i of c = a @ b sums only the first i+1 rows of b:
#   row 1 of c = b[0]
#   row 2 of c = b[0] + b[1]
#   row 3 of c = b[0] + b[1] + b[2]
# i.e. a running sum down each column, which is close to what we want
# (an average over the first t tokens)
a = torch.tril(torch.ones(3, 3))  # tril keeps the lower triangle and zeroes everything above it
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)

for matrix in (a, b, c):
    print(matrix)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[7., 4.],
        [5., 0.],
        [5., 3.]])
tensor([[ 7.,  4.],
        [12.,  4.],
        [17.,  7.]])


In [94]:
# to average instead of sum, normalize each row of a so that it sums to 1.
# e.g. (1 + 2) / 2 = 0.5 * (1 + 2): an average is a sum times a weight.
# the row [1, 1, 0] has sum 2, so it becomes [0.5, 0.5, 0], and that row of a @ b
# becomes 0.5 * (b[0] + b[1]), the mean of the first two rows of b
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)

for matrix in (a, b, c):
    print(matrix)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[8., 9.],
        [2., 7.],
        [3., 9.]])
tensor([[8.0000, 9.0000],
        [5.0000, 8.0000],
        [4.3333, 8.3333]])


Efficient (or "vectorized") Self-Attention

In [95]:
# row-normalized lower-triangular matrix: row t holds equal weights over positions 0..t
weights = torch.tril(torch.ones(T, T))
weights = weights / torch.sum(weights, 1, keepdim=True)
print(weights)

# one sequence: (T, T) @ (T, C) -> (T, C)
# whole batch: the (T, T) matrix is broadcast across the batch, (B, T, T) @ (B, T, C) -> (B, T, C)
xbow_vec = torch.matmul(weights, x)

print(torch.allclose(xbow, xbow_vec, rtol=1e-4))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
True


Another version <br>
This one will be used: <br>
 - Weights are initialized to zero, and we can think of each value as an "affinity" (interaction strength) between two tokens <br>
 - These affinities will be trained <br>

In [96]:
tril = torch.tril(torch.ones(T, T))
print(tril)
weights = torch.zeros((T, T))  # all-zero affinities to start with
print(weights)
# where tril contains 0 (future positions), put -inf: tokens can't communicate with the future
weights = weights.masked_fill(tril == 0, float('-inf'))
print(weights)
# softmax along each row i: softmax(z)_ij = e^{z_ij} / sum_k e^{z_ik}
# every z_ij here is either 0 or -inf, so e^{z_ij} is either e^0 = 1 or e^-inf = 0:
#   - where z_ij = -inf, the softmax output is 0 (masked out)
#   - where z_ij = 0, the output is 1 / (number of zeros in row i), i.e. uniform averaging weights
#   e.g. the 2nd row is [0, 0, -inf, ..., -inf]; its sum is
#   e^0 + e^0 + e^-inf + ... + e^-inf = 1 + 1 + 0 + ... + 0 = 2, so it becomes [0.5, 0.5, 0, ..., 0]
print(torch.exp(weights))  # inspect e^{z_ij}
weights = F.softmax(weights, dim=-1)
print(weights)

xbow3 = weights @ x

torch.allclose(xbow, xbow3, 1e-4)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0.,

True

## Self-Attention
Right now, "weights" contains uniform values in each row <br>
But we want them to be different
 - idea: tokens will find some past tokens more "interesting" or helpful for the prediction task

Self-Attention allows info to flow from past in data-dependent way <br>
 - Every token will emit 2 vectors: Key and Query
   - Query: "What am I looking for?"
   - Key: "What do I contain?"
 - To get affinities b/w tokens, we perform dot-product b/w Keys and Queries
   - e.g. token 5's Query will be dot-product-ed with the other tokens' Keys
   - If a Key and Query interact to a very high amount, then token 5 will learn more about the token associated with that Key
 - For aggregation, we don't use average:
   - we have another vector called Value for each token
     - Value: "here's what I have", "here's my identity", "here's what I'll communicate"
     - we don't aggregate the raw input tokens, but we aggregate their Value vectors
   - we send the dot-product of Query and Key to Softmax to normalize it
   - the output from Softmax is then multiplied with Value
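 - so, $\text{Attention}(Q, K, V) = \text{softmax}(QK^\top)\,V$ (with the causal mask applied before the softmax; the original Transformer also scales $QK^\top$ by $1/\sqrt{d_k}$)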

In [105]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)


# addition to prev code
head_size = 16  # new hyperparameter: size of each key/query/value vector

# each layer maps a vector of size C to a vector of size head_size.
# bias=False: the layer computes only x @ W.T, with no bias term added.
# nn.Linear stores W with shape (out_features, in_features) = (head_size, C),
# so the product is (..., C) @ (C, head_size) -> (..., head_size)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# keys and queries for every token of every sequence in the batch:
# the layers act on the last dimension, (B, T, C) -> (B, T, head_size)
k = key(x)
q = query(x)

# queries interact with keys: (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
# entry [b, i, j] is the affinity of token i for token j within sequence b
# NOTE: scaled dot-product attention (Vaswani et al., 2017) also multiplies this by
# head_size**-0.5 before the softmax; that scaling is omitted in this version
weights = q @ k.transpose(-2, -1)  # transpose the last 2 dims so the matmul shapes line up
print("Q @ K:\n", weights[0])

tril = torch.tril(torch.ones(T, T))
# the data-dependent affinities above replace the all-zero initialization used earlier
weights = weights.masked_fill(tril == 0, float('-inf'))  # tokens can't attend to future tokens
print("\nMasked:\n", weights[0])
weights = F.softmax(weights, dim=-1)  # normalize each row into attention weights that sum to 1
print("\nSoftmax(Q @ K):\n", weights[0])

# aggregate the tokens' value vectors (not the raw inputs x) using the attention weights
v = value(x)
out = weights @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
print("\nFinal attention output:\n", out[0])

Q @ K:
 tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

Masked:
 tensor([[-1.7629,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-3.3334, -1.6556,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.0226, -1.2606,  0.0762,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7836, -0.8014, -0.3368, -0.8496, 

In [106]:
weights.shape, out.shape  # (B, T, T) attention weights and (B, T, head_size) head output

(torch.Size([4, 8, 8]), torch.Size([4, 8, 16]))