In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F

vocab_size = 65

# hyperparameters
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

class CustomSequential(nn.Sequential):
    def forward(self, x, attention_mask=None):
        for module in self:
            x = module(x, attention_mask)
        return x

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x,  attention_mask=None):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        if attention_mask == None:
          # print('attention mask is:', attention_mask)
          wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        else:
          # print(attention_mask)
          wei = wei.masked_fill(attention_mask[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x,  attention_mask=None):
        out = torch.cat([h(x,  attention_mask=attention_mask) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x,  attention_mask=None):
        x = x + self.sa(self.ln1(x),  attention_mask=attention_mask)
        x = x + self.ffwd(self.ln2(x))
        return x

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = CustomSequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx,  attention_mask, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x, attention_mask) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

        
    def construct_mask(self, idx_thought, attention_mask, n_ahead):

    
      B, T, R = idx_thought.shape  # batch_size, seq_length, n_thoughts_ahead
      
      og_mask = torch.tril(torch.ones(T, T))
      og_mask = og_mask.repeat(n_ahead, 1)
      
      # permute
      idx_thought_perm = idx_thought.permute(0, 2, 1)  # (B, T, R) -> (B, R, T)
      idx_thought_perm = idx_thought_perm.contiguous().view(B, R * T)  # (B, T, R) -> (B, R * T)

      # print("idx_thought_perm", idx_thought_perm.shape)
      # print("idx_thought", idx_thought)

      # create masks
      main_diagonal_mask = torch.eye(R * T)  # new diagonal masks
      main_diagonal_mask[: attention_mask.shape[0], :attention_mask.shape[1]] += attention_mask
     
      # ugly eye addition
      for i in range(T, T * R, T):
        main_diagonal_mask[-i:, :i] += torch.eye(i)
    
      # augmenting based on og mask
      main_diagonal_mask[:, :T] += og_mask
      main_diagonal_mask[main_diagonal_mask > 1] = 1
      # print(main_diagonal_mask)
      return main_diagonal_mask, idx_thought_perm

    def think(self, idx, attention_mask, max_new_tokens=1):

        B, T = idx.shape

        idx_thought = idx.unsqueeze(-1) # add thought dimension R: (B, T) -> (B, T, R)

        for thought_ahead in range(max_new_tokens):

            if thought_ahead == 0:
              
              # compute logits
              idx_cond = idx[:, -block_size:]
              logits, loss = self(idx_cond, attention_mask)

              # get and flatten probs 
              probs = F.softmax(logits, dim=-1) 
              probs = probs.view(B * T, -1) # flatten (B, T, C) -> (B*T, C)

              # sample next thought token for each token in sequence
              idx_next_thought = torch.multinomial(probs, num_samples=1) # (B*T, 1)
              idx_next_thought = idx_next_thought.view(B, T) # (B*T, 1) -> (B, T)

              # concatenate 
              idx_next_thought = idx_next_thought.unsqueeze(-1)
              idx_thought = torch.cat((idx_thought, idx_next_thought), dim=-1) # (B, T+1)

              # update attention_mask
              attention_mask, idx_thought_permute = self.construct_mask(idx_thought, attention_mask, n_ahead=thought_ahead+2)
            
            else:

              # compute logits
              idx_cond = idx_thought_permute[:, -block_size:]
              logits, loss = self(idx_cond, attention_mask)
              logits = logits[:, -T:]

              # get and flatten probs 
              probs = F.softmax(logits, dim=-1) 
              probs = probs.view(B * T, -1) # flatten (B, T, C) -> (B*T, C)

              # sample next thought token for each token in sequence
              idx_next_thought = torch.multinomial(probs, num_samples=1) # (B*T, 1)
              idx_next_thought = idx_next_thought.view(B, T) # (B*T, 1) -> (B, T)

              # concatenate 
              idx_next_thought = idx_next_thought.unsqueeze(-1)
              idx_thought = torch.cat((idx_thought, idx_next_thought), dim=-1) # (B, T+1)

              # update attention_mask
              attention_mask, idx_thought_permute = self.construct_mask(idx_thought, attention_mask, n_ahead=thought_ahead+2)

        return idx_thought

model = BigramLanguageModel()
m = model.to(device)
input = torch.randint(5, (1, 2))
print("INPUT", input)
out = m.think(input, attention_mask=torch.tril(torch.ones(2, 2)), max_new_tokens=3)
out

INPUT tensor([[0, 4]])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [4]:
model knows answer -> just no thought and respond
model does not know answer? think [star-style but model can pick]
and then talk?

ie xxxx -> yyyy
if probs of yyy sharp enough/top_k -> go ahead
if not, think

ie measure if the topk traj is the corr answer 

if not, think and take thoughts that lead to correct answer becoming the topk

then finetune the thought head on that separately? 

so just finetune thought Head


but also need to finetune the classifier/decision threshold thingy

can use logits for that ? ie distributions of logits to predict -> 0, 1; try then to recognise this during training? or is there smth else that might be helpful here? smth before logits maybe? 

SyntaxError: invalid syntax. Perhaps you forgot a comma? (812658534.py, line 1)

In [5]:
import torch

# Assuming T and R are defined
T = 3
R = 2

# Create the main diagonal mask
main_diagonal_mask = torch.zeros((T * R, T * R))

# Create the base eye matrix
eye_matrices = [torch.eye(T * i) for i in range(1, R + 1)]

# Concatenate to create the block diagonal structure
block_diag_mask = torch.block_diag(*eye_matrices)

# Fill the main_diagonal_mask
main_diagonal_mask[-block_diag_mask.shape[0]:, :block_diag_mask.shape[1]] += block_diag_mask

print(main_diagonal_mask)
eye_matrices


RuntimeError: The size of tensor a (6) must match the size of tensor b (9) at non-singleton dimension 1