In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from tokenizers import ByteLevelBPETokenizer, trainers, pre_tokenizers, decoders
import json
import numpy as np
import pickle

In [32]:
### ------------------------------------------------------------------------------------------------------------------
#  Initialize Constants, Text Memory, and Tokenizer
### ------------------------------------------------------------------------------------------------------------------

# NOTE(review): this picks 'mps' when *CUDA* is available, which is almost certainly not the
# intent (MPS availability is reported by torch.backends.mps.is_available()). The check is left
# unchanged on purpose: no module in this notebook is moved to `device`, so switching it on an
# Apple-silicon machine would create index tensors on MPS while the weights stay on the CPU.
device = 'mps' if torch.cuda.is_available() else 'cpu'

# How long before you stop and think? (context length, in tokens)
focus_len = 21

# Embedding width. EmbeddingTables also uses this as the number of rows in its position table.
n_embd = 21

# Constants for multi-head attention
n_head = 3
n_layer = 3

# Define dropout
dropout = 0.2

# Initialize empty memory vector: 500 blank lines that are later replaced by real input
with open("v01_txt/text_memory.txt", "w", encoding="utf-8") as text_memory:
    for _ in range(500):
        text_memory.write(" \n")

# Initialize tokenizer trained on nothing (only the blank memory file)
init_tokenizer = ByteLevelBPETokenizer(vocab=None, merges=None)
init_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
init_tokenizer.decoder = decoders.ByteLevel()

init_tokenizer.train("v01_txt/text_memory.txt", vocab_size=1000,)

# Save next to the notebook. The file is read back from this same relative path just below
# (and by L5), so a machine-specific absolute path here would leave the reader looking at a
# different, possibly stale or missing, file.
init_save_path = "v01_tokenizer"
init_tokenizer.save(init_save_path)

tokenizer = init_tokenizer

# Get vocab and vocab length from the saved tokenizer JSON.
# The vocab contains non-ASCII byte-level symbols, so the encoding is given explicitly.
with open(init_save_path, 'r', encoding="utf-8") as f:
    json_content = f.read()

parsed_json = json.loads(json_content)
vocab = list(parsed_json["model"]["vocab"].keys())
vocab_size = len(vocab)

print(vocab)
print(vocab_size)

# Initialize "old_top_tokens" for very first response
old_top_tokens = [137, 137, 137]

# How much are other q-values discounted if feedback was positive?
q_discount_pos = 0.75

# How much are other q-values discounted if feedback was neutral?
q_discount_neut = 0.95

# How much is chosen q-value discounted if feedback was negative?
q_discount_neg = 0.5

# Initialize "old_top_tokenids" and "old_action" for very first response
old_top_tokenids = [137, 137, 137]
old_action = None

# How many tokens to generate at a time?
max_new_tokens = 40
    




['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', 'µ', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'À', 'Á', 'Â', 'Ã', 'Ä', 'Å', 'Æ', 'Ç', 'È', 'É', 'Ê', 'Ë', 'Ì', 'Í', 'Î', 'Ï', 'Ð', 'Ñ', 'Ò', 'Ó', 'Ô', 'Õ', 'Ö', '×', 'Ø', 'Ù', 'Ú', 'Û', 'Ü', 'Ý', 'Þ', 'ß', 'à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï', 'ð', 'ñ', 'ò', 'ó', 'ô', 'õ', 'ö', '÷', 'ø', 'ù', 'ú', 'û', 'ü', 'ý', 'þ', 'ÿ', 'Ā', 'ā', 'Ă', 'ă', 'Ą', 'ą', 'Ć', 'ć', 'Ĉ', 'ĉ', 'Ċ', 

In [33]:
### ------------------------------------------------------------------------------------------------------------------
#  Initialize learnable embedding tables
### ------------------------------------------------------------------------------------------------------------------

class EmbeddingTables(nn.Module):
    """Learnable token and position embedding tables, mirrored to pickle files on disk.

    The in-memory tables are the source of truth; every time one of them is (re)built it is
    also written to ``v01_pkl/`` so the files stay in step with the module.

    Args:
        vocab_size: number of rows in the token table.
        n_embd: embedding width. It is also used as the number of rows in the position
            table, so sequences longer than ``n_embd`` tokens cannot be embedded.
    """

    TOKEN_PATH = "v01_pkl/token_embed.pkl"
    POSITION_PATH = "v01_pkl/position_embed.pkl"

    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(n_embd, n_embd)

        # Initialize and save token embedding table
        torch.nn.init.normal_(self.token_embedding_table.weight, mean=0.0, std=0.02)
        self._save(self.token_embedding_table, self.TOKEN_PATH)

        # Initialize and save position embedding table
        torch.nn.init.normal_(self.position_embedding_table.weight, mean=0.0, std=0.02)
        self._save(self.position_embedding_table, self.POSITION_PATH)

    @staticmethod
    def _save(table, path):
        """Pickle one embedding module to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(table, f)

    def update_embeds(self, vocab, new_vocab):
        """Rebuild the token table for a new vocabulary, keeping already-learned rows.

        Row ``i`` of the new table belongs to ``new_vocab[i]``. Tokens that also exist in
        ``vocab`` carry their old embedding over; unseen tokens start from N(0, 0.02).
        The position table does not depend on the vocabulary and is left untouched.

        Args:
            vocab: list of token strings the current table was built for (index == row).
            new_vocab: list of token strings for the new tokenizer (index == new row).
        """
        old_weight = self.token_embedding_table.weight.detach()
        old_rows, dims = old_weight.shape

        # Never shrink the table: ids produced against the previous, larger vocabulary
        # (e.g. by an output layer that has not been resized) must remain valid lookups.
        n_rows = max(len(new_vocab), old_rows)
        new_weight = torch.empty(
            n_rows, dims, dtype=old_weight.dtype, device=old_weight.device
        ).normal_(mean=0.0, std=0.02)

        old_row_of = {}
        for row, token in enumerate(vocab):
            old_row_of.setdefault(token, row)

        for new_row, token in enumerate(new_vocab):
            old_row = old_row_of.get(token)
            if old_row is not None and old_row < old_rows:
                new_weight[new_row] = old_weight[old_row]

        # from_pretrained keeps the assembled weights (a bare nn.Embedding(nums, dims) would
        # silently re-randomize them); freeze=False keeps the table trainable.
        self.token_embedding_table = nn.Embedding.from_pretrained(new_weight, freeze=False)

        self._save(self.token_embedding_table, self.TOKEN_PATH)
        self._save(self.position_embedding_table, self.POSITION_PATH)

    # Function to get token and positional embeddings
    def forward(self, index):
        """Return token + position embeddings for ``index``.

        Args:
            index: integer tensor of token ids whose LAST dimension is the sequence,
                i.e. shape ``(T,)`` or ``(B, T)``.

        Returns:
            Tensor of shape ``index.shape + (n_embd,)``.
        """
        # Positions follow the sequence axis. len(index) would be the batch size for a
        # (B, T) input and give every token the same few position rows.
        seq_len = index.shape[-1]
        positions = torch.arange(seq_len, device=self.position_embedding_table.weight.device)

        tok_emb = self.token_embedding_table(index)
        pos_emb = self.position_embedding_table(positions)

        return tok_emb + pos_emb

# Shared embedding module used by the later cells. Constructing it also writes both tables
# to v01_pkl/, so that directory must already exist.
embeds = EmbeddingTables(vocab_size, n_embd)
        

In [34]:
### ------------------------------------------------------------------------------------------------------------------
#  Initialize Feed Forward architecture
### ------------------------------------------------------------------------------------------------------------------

class FeedForward(nn.Module):
    """Per-position MLP: widen to 4 * n_embd, ReLU, project back to n_embd, then dropout.

    The dropout probability is read from the notebook-level ``dropout`` constant.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed
    
### ------------------------------------------------------------------------------------------------------------------
#  Initialize Single-Head Attention for Learning
### ------------------------------------------------------------------------------------------------------------------

class SingleHead(nn.Module):
    """One causal self-attention head followed by a feed-forward layer, both with
    residual connections and layer norm.

    Works on an unbatched ``(T, n_embd)`` input. Because the attention output is added
    back onto the input, ``head_size`` has to equal ``n_embd``.
    """

    def __init__(self, head_size):
        super().__init__()

        def make_projection():
            return nn.Linear(n_embd, head_size, bias=False)

        self.key = make_projection()
        self.query = make_projection()
        self.value = make_projection()

        # Lower-triangular ones: position t may only look at positions <= t.
        causal = torch.ones(focus_len, focus_len).tril()
        self.register_buffer('tril', causal)

        self.dropout = nn.Dropout(dropout)

        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Attend over ``x`` of shape ``(T, n_embd)`` and return a tensor of the same shape."""
        seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)

        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        attended = weights @ self.value(x)

        hidden = self.ln1(x + attended)
        return self.ln2(hidden + self.ffwd(hidden))
    
# head_size is set to n_embd so the attention output can be added residually to the input.
single_head = SingleHead(n_embd)

In [35]:
### ------------------------------------------------------------------------------------------------------------------
#  Initialize Q-Learning Architecture
### ------------------------------------------------------------------------------------------------------------------

class QLearn():
    """Tabular action chooser keyed by a triple of token ids.

    ``q_vals[i, j, k]`` is a probability distribution over ``self.actions`` for the
    token-id triple ``(i, j, k)``. An action is either a 3-step path through the two block
    sequences (1 or 2 at each depth) or ``None``, meaning "say nothing".

    Memory: the table holds ``vocab_size ** 3 * len(actions)`` floats.
    """

    def __init__(self, vocab_size):
        super().__init__()

        self.actions = [[1, 1, 1], [1, 1, 2],
                        [1, 2, 1], [1, 2, 2],
                        [2, 1, 1], [2, 1, 2],
                        [2, 2, 1], [2, 2, 2], None]

        self.q_vals = self._fresh_table(vocab_size)

    def _fresh_table(self, size):
        """Return a new table: 0.1 for every path action, 0.2 for the final (None) action."""
        table = torch.ones(size, size, size, len(self.actions), dtype=torch.float)
        table[..., :-1] = 0.1
        table[..., -1] = 0.2
        return table

    def update_q_matrix(self, vocab, new_vocab):
        """Re-index the table for a new vocabulary.

        Distributions for triples whose three tokens all survive into ``new_vocab`` are
        copied to their new ids; every other triple starts from the default distribution.

        Args:
            vocab: token strings the current table is indexed by (position == id).
            new_vocab: token strings of the new tokenizer (position == new id).
        """
        new_q_vals = self._fresh_table(len(new_vocab))

        new_id_of = {}
        for new_id, token in enumerate(new_vocab):
            new_id_of.setdefault(token, new_id)

        # old id -> new id for every token present in both vocabularies
        shared = {}
        for old_id, token in enumerate(vocab):
            if token in new_id_of:
                shared[token] = (old_id, new_id_of[token])

        if shared:
            old = torch.tensor([pair[0] for pair in shared.values()], dtype=torch.long)
            new = torch.tensor([pair[1] for pair in shared.values()], dtype=torch.long)
            # Broadcast the id vectors along three axes to address the whole shared sub-cube
            # at once, instead of looping over vocab**3 triples in Python.
            new_q_vals[new[:, None, None], new[None, :, None], new[None, None, :]] = \
                self.q_vals[old[:, None, None], old[None, :, None], old[None, None, :]]

        self.q_vals = new_q_vals

    def update_q_values(self, old_top_tokenids, old_action, vocab):
        """Apply the latest feedback to every triple that contains a previous top token.

        Reads the feedback digit from ``v01_txt/current_feedback.txt``:

        * ``"1"`` (positive) / ``"2"`` (neutral): the other actions are scaled by
          ``q_discount_pos`` / ``q_discount_neut`` and the chosen action takes the remainder
          so the row sums to 1.
        * anything else (negative): the chosen action is scaled by ``q_discount_neg`` and
          the removed mass is split evenly over the other actions.

        Args:
            old_top_tokenids: token ids selected on the previous turn.
            old_action: the action taken on the previous turn (an entry of ``self.actions``).
            vocab: current vocabulary; only ids below ``len(vocab)`` are updated.

        Raises:
            ValueError: if ``old_action`` is not one of ``self.actions``.
        """
        with open('v01_txt/current_feedback.txt', 'r') as f:
            feedback = f.read().strip()

        action_index = self.actions.index(old_action)
        n_actions = len(self.actions)

        # Basic slicing returns a view, so writes below land in self.q_vals.
        limit = len(vocab)
        table = self.q_vals[:limit, :limit, :limit]
        dim = table.size(0)

        ids = [t for t in set(old_top_tokenids) if 0 <= t < dim]
        if not ids:
            return

        touched = torch.zeros(dim, dtype=torch.bool)
        touched[ids] = True
        # A triple is affected when ANY of its three ids is a previous top token.
        mask = touched[:, None, None] | touched[None, :, None] | touched[None, None, :]

        rows = table[mask]                                   # (n_affected, n_actions) copy
        others = torch.ones(n_actions, dtype=torch.bool)
        others[action_index] = False

        if feedback in ("1", "2"):
            discount = q_discount_pos if feedback == "1" else q_discount_neut
            rows[:, others] = rows[:, others] * discount
            rows[:, action_index] = 1 - rows[:, others].sum(dim=1)
        else:
            reduced = rows[:, action_index] * q_discount_neg
            share = (rows[:, action_index] - reduced) / (n_actions - 1)
            rows[:, action_index] = reduced
            rows[:, others] = rows[:, others] + share[:, None]

        table[mask] = rows

    def get_path(self, new_top_tokens):
        """Sample an action for the triple given by the first three ids in ``new_top_tokens``.

        Returns:
            An entry of ``self.actions``: a list of three block-sequence choices, or ``None``.
        """
        first, second, third = new_top_tokens[0], new_top_tokens[1], new_top_tokens[2]

        # Probability distribution stored for this triple
        action_probs = self.q_vals[first, second, third]

        action_index = torch.multinomial(action_probs, 1).item()

        return self.actions[action_index]
    
# Allocates a vocab_size x vocab_size x vocab_size x 9 float table in memory.
q_learn = QLearn(vocab_size)
        

In [36]:
### ------------------------------------------------------------------------------------------------------------------
#  Step 1 (L)
### ------------------------------------------------------------------------------------------------------------------
def L1():
    """Interactively collect feedback and text from the user.

    Truncates ``v01_txt/current_input.txt``, then loops: ask for a feedback digit (1/2/3),
    ask for a line of text, overwrite ``v01_txt/current_feedback.txt`` with the digit and
    append the text to ``current_input.txt``.

    The loop ends when the feedback prompt is answered with an empty line, when the text
    ends in ``.txt`` (nothing is recorded for that turn), or once more than five lines have
    been collected. Any feedback other than 1/2/3 is treated as 2 (neutral).
    """
    # Map the user's choice to an emoji and emotion
    emoji_mapping = {
        '1': '🤩',
        '2': '🙂',
        '3': '😓'
    }
    emotion_mapping = {
        '1': 'Positive',
        '2': 'Neutral',
        '3': 'Negative'
    }

    # Start new "current input" record
    with open("v01_txt/current_input.txt", "w", encoding="utf-8"):
        pass

    while True:

        # Retrieve Feedback and Input
        emoji_choice = input("1. 🤩  2. 🙂  3. 😓\n Enter Feedback, or Press 'return' to Exit: ")

        if emoji_choice == "":
            break
        if emoji_choice not in emoji_mapping:
            emoji_choice = "2"

        emoji = emoji_mapping[emoji_choice]
        emotion = emotion_mapping[emoji_choice]

        # Open text window
        user_input = input(f"Say Something {emotion} {emoji}: ")

        if user_input.endswith(".txt"):
            print("Received a .txt file!")
            break

        print("Received: ", emoji, user_input)

        # Only the most recent feedback is kept, so overwrite rather than append.
        with open("v01_txt/current_feedback.txt", "w") as current_feedback:
            current_feedback.write(emoji_choice)

        with open("v01_txt/current_input.txt", "a", encoding="utf-8") as current_input:
            current_input.write(user_input + "\n")

        # Count lines through a context manager so the handle is closed on every iteration.
        with open("v01_txt/current_input.txt", "r", encoding="utf-8") as recorded:
            line_nums = sum(1 for _ in recorded)

        if line_nums > 5:
            print("Let Little Guy Think!")
            break


In [37]:
### ------------------------------------------------------------------------------------------------------------------
#  Step 4 (L)
### ------------------------------------------------------------------------------------------------------------------

def L4():
    """Append the current input to both buffers and retokenize when the buffer is full.

    ``v01_txt/current_input.txt`` is appended to ``buffer.txt`` and ``daily_buffer.txt``.
    Once ``buffer.txt`` holds at least 10 lines it is pushed into ``text_memory.txt``
    (via ``push_text_replace``), emptied, and a new byte-level BPE tokenizer is trained on
    the updated text memory and saved as ``new_tokenizer`` in the working directory.

    Returns:
        The newly trained tokenizer, or ``None`` when the buffer is not yet full.
    """
    # Read the new lines once and fan them out to both buffers
    with open("v01_txt/current_input.txt", "r", encoding="utf-8") as current_input:
        new_lines = current_input.readlines()

    for target in ("v01_txt/buffer.txt", "v01_txt/daily_buffer.txt"):
        with open(target, "a", encoding="utf-8") as out_file:
            out_file.writelines(new_lines)

    with open("v01_txt/buffer.txt", "r", encoding="utf-8") as buffer_file:
        line_count = sum(1 for _ in buffer_file)

    # Buffer not full yet: nothing to retokenize
    if line_count < 10:
        return None

    # Buffer is full: add to text memory, empty buffer, and retokenize based on recent memory
    push_text_replace("text_memory.txt", "buffer.txt")

    with open("v01_txt/buffer.txt", "w", encoding="utf-8"):
        pass

    # Initialize tokenizer and train on recent text memory
    new_tokenizer = ByteLevelBPETokenizer(vocab=None, merges=None)
    new_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
    new_tokenizer.decoder = decoders.ByteLevel()

    new_tokenizer.train("v01_txt/text_memory.txt", vocab_size=1000,)

    # Saved relative to the working directory, like every other file this notebook touches;
    # a user-specific absolute path fails on any other machine.
    initial_save_path = "new_tokenizer"
    new_tokenizer.save(initial_save_path)

    return new_tokenizer
    

##
### Create function for altering .txt files
##
    
def push_text_replace(memory, input):
    """Slide new lines into a memory file, dropping the same number of lines from its start.

    Both arguments are file names inside ``v01_txt/``. After the call, ``memory`` contains
    its old lines minus the first ``len(input lines)``, followed by all lines of ``input``.
    The ``input`` file itself is not modified.
    """
    memory_path = f"v01_txt/{memory}"
    input_path = f"v01_txt/{input}"

    with open(memory_path, "r", encoding="utf-8") as memory_handle:
        remembered = list(memory_handle)

    with open(input_path, "r", encoding="utf-8") as input_handle:
        incoming = list(input_handle)

    # Drop as many of the oldest lines as there are incoming ones, then append the new lines
    surviving = remembered[len(incoming):]
    updated_text = "".join(surviving + incoming)

    with open(memory_path, "w", encoding="utf-8") as memory_handle:
        memory_handle.write(updated_text)

In [38]:
### ------------------------------------------------------------------------------------------------------------------
#  Step 5 (L)
### ------------------------------------------------------------------------------------------------------------------

def L5(tokenizer):
    """Read the vocabulary of the tokenizer saved at ``v01_tokenizer``.

    Args:
        tokenizer: unused; the vocabulary is always read from the saved JSON file, so the
            caller must have saved the tokenizer to ``v01_tokenizer`` first.

    Returns:
        ``(new_vocab, new_vocab_len)``: the token strings in file order, and their count.
    """
    # The vocab holds non-ASCII byte-level symbols; read as UTF-8 explicitly rather than
    # with the platform default encoding.
    with open("v01_tokenizer", "r", encoding="utf-8") as file:
        parsed_json = json.load(file)

    vocab_table = parsed_json["model"]["vocab"]
    new_vocab = list(vocab_table.keys())
    new_vocab_len = len(vocab_table)

    return new_vocab, new_vocab_len

In [39]:
### ------------------------------------------------------------------------------------------------------------------
#  Step 6 (L)
### ------------------------------------------------------------------------------------------------------------------

def L6(tokenizer, focus_len, pad_id=256):
    """Encode the current input and fit it to exactly ``focus_len`` token ids.

    Args:
        tokenizer: object whose ``encode(text).ids`` yields a list of token ids.
        focus_len: required length of the returned ``index``.
        pad_id: id used to left-pad inputs shorter than ``focus_len``.

    Returns:
        ``(index, data)``: ``index`` is the last ``focus_len`` ids, left-padded with
        ``pad_id`` when the input is shorter; ``data`` is the full, unpadded id sequence.
        Both are ``torch.long`` tensors.
    """
    with open("v01_txt/current_input.txt", "r", encoding="utf-8") as file:
        text_data = file.read()

    # Encode text. dtype is forced to long: an empty id list would otherwise become a float
    # tensor, and concatenating it with the padding would yield float indices that an
    # embedding lookup rejects.
    tokens = tokenizer.encode(text_data).ids
    data = torch.tensor(tokens, dtype=torch.long)

    missing = focus_len - len(data)
    if missing < 0:       # Truncate long entries (keep the most recent tokens)
        index = data[len(data) - focus_len:]
    elif missing > 0:     # Pad short entries on the left
        padding = torch.full((missing,), pad_id, dtype=torch.long)
        index = torch.cat([padding, data])
    else:
        index = data

    return index, data

In [40]:
### ------------------------------------------------------------------------------------------------------------------
#  Step 7 (L)
### ------------------------------------------------------------------------------------------------------------------

def GetTopTokens(attended, index, vocab):
    """Pick the three positions with the largest mean absolute activation.

    Args:
        attended: 2-D tensor with one row per position.
        index: 1-D tensor of token ids, aligned with the rows of ``attended``.
        vocab: unused.

    Returns:
        ``(selected_tokens, selected_token_ids)``: the decoded strings and the ids of the
        three chosen positions, listed in sequence order (not by score). Decoding uses the
        notebook-level ``tokenizer``.
    """
    salience = attended.abs().mean(dim=1)
    _, top_positions = torch.topk(salience.abs(), k=3, dim=0)
    chosen = set(top_positions.tolist())

    token_ids = index.tolist()
    decoded = [tokenizer.decode([token_id]) for token_id in token_ids]

    # Walk the sequence in order, keeping only positions that made the top three
    usable = min(len(token_ids), len(salience))
    picks = [(decoded[pos], token_ids[pos]) for pos in range(usable) if pos in chosen]

    selected_tokens = [text for text, _ in picks]
    selected_token_ids = [token_id for _, token_id in picks]

    return selected_tokens, selected_token_ids
    

In [41]:
### ------------------------------------------------------------------------------------------------------------------
#  Define transformer blocks
### ------------------------------------------------------------------------------------------------------------------
class Block(nn.Module):
    """Transformer block: multi-head self-attention, then feed-forward, each wrapped in a
    residual connection followed by layer norm (post-norm)."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        per_head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head_width)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        after_attention = self.ln1(x + self.sa(x))
        after_mlp = self.ln2(after_attention + self.ffwd(after_attention))
        return after_mlp


### ------------------------------------------------------------------------------------------------------------------
#  Define multi-head attention
### ------------------------------------------------------------------------------------------------------------------
class MultiHeadAttention(nn.Module):
    """Run several attention heads side by side, concatenate their outputs on the last
    dimension, project back to ``n_embd`` and apply dropout."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(Head(head_size))
        concat_width = head_size * num_heads
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the combined head outputs, projected to width ``n_embd``."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)


### ------------------------------------------------------------------------------------------------------------------
#  Define feed forward sequence
### ------------------------------------------------------------------------------------------------------------------
# NOTE(review): FeedForward is also defined, with the same layers, in an earlier cell of this
# notebook; after a full run this later definition is the one the name refers to.
class FeedForward(nn.Module):
    """Two-layer MLP applied per position: n_embd -> 4 * n_embd -> n_embd, with ReLU in
    between and dropout (notebook-level ``dropout`` probability) at the end."""

    def __init__(self, n_embd):
        super().__init__()
        widened = 4 * n_embd
        expand = nn.Linear(n_embd, widened)
        activation = nn.ReLU()
        contract = nn.Linear(widened, n_embd)
        regularize = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, contract, regularize)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

### ------------------------------------------------------------------------------------------------------------------
#  Define individual attention head
### ------------------------------------------------------------------------------------------------------------------
class Head(nn.Module):
    """One causal self-attention head for batched ``(B, T, n_embd)`` input.

    Sequences may be at most ``focus_len`` long, the size of the stored causal mask.
    """

    def __init__(self, head_size):
        super().__init__()

        def make_projection():
            return nn.Linear(n_embd, head_size, bias=False)

        self.key = make_projection()
        self.query = make_projection()
        self.value = make_projection()

        # Lower-triangular ones: position t may only look at positions <= t.
        causal = torch.ones(focus_len, focus_len).tril()
        self.register_buffer('tril', causal)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the attended values, shape ``(B, T, head_size)``."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)

        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        return weights @ self.value(x)
    
    
### ------------------------------------------------------------------------------------------------------------------
#  Initialize Dual-GPT Architecture
### ------------------------------------------------------------------------------------------------------------------
class DGPT(nn.Module):
    """Dual-pathway GPT: two parallel stacks of three transformer blocks that share one
    final layer norm and one output head.

    For generation, a ``path`` such as ``[1, 2, 1]`` picks, at each depth, the block from
    sequence 1 or sequence 2. For training, both sequences are run in full.

    Args:
        vocab_size: width of the output logits.
    """

    def __init__(self, vocab_size):
        super().__init__()

        ##
        ### Create block sequence 1 with distinct blocks
        ##
        self.block_11 = Block(n_embd, n_head=n_head)
        self.block_12 = Block(n_embd, n_head=n_head)
        self.block_13 = Block(n_embd, n_head=n_head)

        self.blocks_1 = nn.Sequential(self.block_11, self.block_12, self.block_13)

        # Plain lists (not registered as submodules) used to look blocks up by depth
        self.seq1_blocks = [self.block_11, self.block_12, self.block_13]

        ##
        ### Create block sequence 2 with distinct blocks
        ##
        self.block_21 = Block(n_embd, n_head=n_head)
        self.block_22 = Block(n_embd, n_head=n_head)
        self.block_23 = Block(n_embd, n_head=n_head)

        self.blocks_2 = nn.Sequential(self.block_21, self.block_22, self.block_23)

        self.seq2_blocks = [self.block_21, self.block_22, self.block_23]

        ##
        ### Create layer norm and linear layer
        ##
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize every Linear layer with N(0, 0.02) weights and zero bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    ##
    ### When vocab changes, add a new linear layer of correct size
    ##
    def update_linlayer(self, old_vocab_size, new_vocab_size):
        """Append a Linear(old_vocab_size, new_vocab_size) layer after the current head.

        The head becomes an ``nn.Sequential``; calling this repeatedly keeps extending it.
        """
        new_lm_head = nn.Linear(old_vocab_size, new_vocab_size)
        torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)
        if new_lm_head.bias is not None:
            torch.nn.init.zeros_(new_lm_head.bias)

        if isinstance(self.lm_head, nn.Linear):
            self.lm_head = nn.Sequential(self.lm_head, new_lm_head)
        else:
            # Already a Sequential from an earlier update: unpack it and append.
            self.lm_head = nn.Sequential(*self.lm_head, new_lm_head)

    ##
    ### Train both major pathways
    ##
    def train_forward(self, embeds, index, targets=None):
        """Run both block sequences on the same embedded input.

        Args:
            embeds: object whose ``forward(index)`` returns a ``(B, T, n_embd)`` tensor.
            index: token ids passed to ``embeds``.
            targets: optional ``(B, T)`` tensor of target ids.

        Returns:
            ``(logits1, loss1, logits2, loss2)``. Without targets both losses are ``None``
            and the logits are ``(B, T, vocab)``; with targets the logits are flattened to
            ``(B*T, vocab)`` and the losses are cross-entropy scalars.
        """
        x = embeds.forward(index)

        x1 = self.ln_f(self.blocks_1(x))
        x2 = self.ln_f(self.blocks_2(x))
        logits1 = self.lm_head(x1)
        logits2 = self.lm_head(x2)

        # Defined up front so the return below is valid when no targets are given.
        loss1 = None
        loss2 = None

        if targets is not None:
            B1, T1, C1 = logits1.shape
            logits1 = logits1.view(B1*T1, C1)
            targets = targets.view(B1*T1)
            loss1 = F.cross_entropy(logits1, targets)

            B2, T2, C2 = logits2.shape
            logits2 = logits2.view(B2*T2, C2)
            targets = targets.view(B2*T2)
            loss2 = F.cross_entropy(logits2, targets)

        return logits1, loss1, logits2, loss2

    ##
    ### Forward step for generation
    ##
    def generate_forward(self, embeds, index, targets=None):
        """Forward pass through ``self.curr_blocks``, the path last selected by ``generate``.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        x = embeds.forward(index)

        x = self.curr_blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    ##
    ### Select transformer sequence and generate
    ##
    def generate(self, embeds, index, path, max_new_tokens):
        """Sample ``max_new_tokens`` tokens along the block path ``path``.

        Args:
            embeds: embedding object handed to ``generate_forward``.
            index: ``(B, T)`` tensor of context token ids.
            path: one entry per depth, 2 for sequence 2 and anything else for sequence 1;
                ``None`` means "generate nothing".
            max_new_tokens: number of sampling steps.

        Returns:
            ``None`` when ``path`` is ``None``; otherwise a ``(B, T)`` tensor. The window
            slides by one token per step, so it always holds the most recent T ids.
        """
        if path is None:
            return None

        block_path = [self.seq2_blocks[i] if p == 2 else self.seq1_blocks[i] for i, p in enumerate(path)]
        # NOTE(review): assigning an nn.Sequential as an attribute registers it as a
        # submodule, so the chosen blocks also show up under "curr_blocks" in state_dict().
        self.curr_blocks = nn.Sequential(*block_path)

        for _ in range(max_new_tokens):
            logits, _loss = self.generate_forward(embeds, index, None)
            logits = logits[:, -1, :]                 # only the last position predicts the next token
            probs = F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
            index = index[:, 1:]                      # drop the oldest token to keep length T
        return index
    
# Output head is sized to the vocabulary that exists when this cell runs.
dual_gpt = DGPT(vocab_size)


In [42]:
##
### Get new input text (interactive; also records the feedback digit for this turn)
##

L1()

##
### Update Q-Values
### Applies the feedback that L1 just wrote to v01_txt/current_feedback.txt to the action
### and top tokens remembered from the previous run of this cell.
##
q_learn.update_q_values(old_top_tokenids, old_action, vocab)

##
### Save input text and retokenize if needed
### L4 returns None unless the buffer filled up and a new tokenizer was trained.
##

new_tokenizer = L4()

##
### Save new tokenizer that may have been generated in previous step
### Additionally get new vocab and new vocab length
##

if new_tokenizer is not None:
    tokenizer = new_tokenizer
    # L5 reads the vocabulary back from this file, so the save must come first.
    tokenizer.save("v01_tokenizer")

    # Get new vocab
    new_vocab, new_vocab_len = L5(tokenizer)

    # Update embedding tables
    embeds.update_embeds(vocab, new_vocab)

    # Updating Q-Table
    q_learn.update_q_matrix(vocab, new_vocab)

    # NOTE(review): dual_gpt's output layer is not resized here (update_linlayer is never
    # called), so it keeps producing logits over the vocabulary size it was built with.
    vocab = new_vocab

##
### Encode and get training text (index)
### index is exactly focus_len ids long; data is the full unpadded encoding.
##

index, data = L6(tokenizer, focus_len)

##
### Get embedding of index
##

x = embeds.forward(index)

##
### Perform single-head attention on the embedding
##

attended = single_head.forward(x)

##
### Extract the top 3 priority tokens
##

new_top_tokens, new_top_tokenids = GetTopTokens(attended, index, vocab)

##
### Get transformer sequence from Q-values
### path is a list of three block choices, or None for "say nothing".
##

path = q_learn.get_path(new_top_tokenids)

##
### Reset "old" variables from this step
### These are what the next run's feedback will be credited to.
##

old_action = path
old_top_tokenids = new_top_tokenids

# generate() expects a batch dimension, hence unsqueeze(0); [0] takes the single sequence back out.
if path is not None:
    generated_chars = tokenizer.decode(dual_gpt.generate(embeds, index.unsqueeze(0), path, max_new_tokens)[0].tolist())
    print(f"Little Guy says: {generated_chars}")
else: print("Little Guy's got nothin' to say!")


Received:  🤩 I hope I fixed it!
Little Guy's got nothin' to say!


In [15]:
# Inspect the action path and the top tokens chosen by the last run of the main step.
print(path, new_top_tokens, sep="\n")

[1, 1, 2]
[' \n', ' \n', ' \n']


In [79]:
### ------------------------------------------------------------------------------------------------------------------
#  Get training batches from daily buffer
### ------------------------------------------------------------------------------------------------------------------

### Place in initialization block
### Place in initialization block
# Training windows cannot be longer than focus_len: the attention heads store a
# focus_len x focus_len causal mask and the position table has only n_embd rows, so a
# longer window (the previous value was 32) cannot pass through the model.
block_size = focus_len
batch_size = 64

### Place in training loop function
with open("v01_txt/daily_buffer.txt", "r", encoding="utf-8") as f:
    text = f.read()

### Encode daily buffer text
tokens = tokenizer.encode(text).ids
data = torch.tensor(tokens, dtype=torch.long)

# 90/10 train/validation split, in token order
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    """Sample a random batch of next-token training pairs.

    Args:
        split: ``'train'`` draws from ``train_data``; any other value draws from ``val_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)`` and moved to ``device``;
        ``y`` is ``x`` shifted one token to the right in the source sequence.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs = []
    targets = []
    for start in starts.tolist():
        stop = start + block_size
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y


In [None]:
##
### For optimizer, dual_gpt.parameters() won't capture transformer sequences separately
### figure out how to optimize different sequences independently
##