<a href="https://colab.research.google.com/github/blooming-ai/generativeai/blob/main/generativeai/examples/generative/ipynb/gpt_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT code
https://github.com/karpathy/minGPT

In [4]:
#Get data file
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!mv input.txt sample_data/

--2023-08-11 06:15:43--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2023-08-11 06:15:43 (29.5 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



## Word embedding
### The following BPE code is adapted from
### [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/pdf/1508.07909.pdf)


In [5]:
file_path = "sample_data/input.txt"

In [25]:
import re
import pdb
import string
from collections import defaultdict

#Byte pair encoding
def word_to_charachter_tuple(word:str)->tuple:
    '''
    Converts a word into a tuple of characters along with an end character.
    "word" -> ('w','o','r','d','</w>')
    '''
    word.strip()
    word = "".join(ch for ch in word if ch.isalnum()) # keep only alpha-numeric characters
    _lst = list(word.lower())
    _lst.append("<\w>") # add end of word
    return tuple(_lst)

def get_bigram_list(word_as_tuple:tuple)->list:
    '''
    returns ('w','o','r','d','</w>') -> [('w','o'),('o,'r'),('r','d'),('d','</w>')]
    '''
    output = []
    for i in range(len(word_as_tuple)-1):
        output.append((word_as_tuple[i],word_as_tuple[i+1]))
    return output

def construct_word_vocab(word_vocab:defaultdict, file_path:str)->dict:
    '''
    Read file and update word_vocab dict. word_vocab has format word_vocab[('w','o','r','d')]->freq
    '''
    with open(file_path) as fp:
        for line in fp:
            for item in line.split(): #split ignore multiple spaces
                item_as_tuple = word_to_charachter_tuple(item) # tuple to make a hashable dict key
                if len(item_as_tuple) == 0: continue #ignore empty key
                word_vocab[item_as_tuple] += 1

    return word_vocab

def get_byte_pair_hist(word_vocab:defaultdict)->dict:
    '''
    Read word_vocab[('w','o','r','d\w')]->freq and construct byte pair histogram
    returns pair[('w','o')]->freq
    '''
    pair = defaultdict(int)
    for word, freq in word_vocab.items():
        for i in range(len(word)-1):
            key = (word[i],word[i+1])
            pair[key] += freq

    return pair

def merge_pair(replace_pair:tuple, word_vocab_in:dict)->dict:
    '''
    replace pair in the key of the word_vocab. E.g. pair=('w','o') then update
    word_vocab_in[('w','o','r','d')]->freq to
    word_vocab_in[('wo','r','d')]->freq
    '''
    word_vocab_out = {}
    for word, freq in word_vocab_in.items():
        word_new = list()
        is_last_char_used = False
        i=0
        while i < len(word)-1:
            if (word[i],word[i+1]) == replace_pair:
                word_new.append(word[i]+word[i+1])
                if i == len(word)-2: is_last_char_used = True
                i += 1 # skip the next merged character
            else:
                word_new.append(word[i])

            i += 1

        if not is_last_char_used : word_new.append(word[len(word)-1])
        word_vocab_out[tuple(word_new)] = freq

    return word_vocab_out

def byte_pair_encoding(word_vocab:dict, n:int=0)->dict:
    '''
    Given word_vocab_in[('w','o','r','d')]->freq merges pairs with highest frequency 'n' times.
    E.g. if ('w','o') has highest freq. then word_vocab_in[('wo','r','d')]->freq.
    returns:
    Encoder encoder[('w','o')] -> id
    Decoder decoder[ id ] -> ('w','o')
    merges [ ('w','o') ... ]
    '''
    i = 0
    merges = []
   # merges = list(string.ascii_lowercase)
   # merges.extend(list(string.digits))
    while True:
        pairs = get_byte_pair_hist(word_vocab)
        best = max(pairs, key=pairs.get)
        merges.append((best[0],best[1]))
        if pairs[best] == 1: break # do not replace if pair occurs only once.
        word_vocab = merge_pair(best, word_vocab)
        if(i >= n and n > 0): break # break if i exceeds n
        i += 1

    bpe_encoder = {token:i for i, token in enumerate(merges) }
    bpe_decoder = {i:token for i, token in enumerate(merges) }
    return bpe_encoder, bpe_decoder, merges


In [26]:
#Testing:
word_vocab = defaultdict(int)
word_vocab = construct_word_vocab(word_vocab, file_path)
print("is empty string a key in word_vocab:",() in word_vocab)
pairs = get_byte_pair_hist(word_vocab)
best = max(pairs, key=pairs.get)
word_vocab_out = merge_pair(('f','i'), word_vocab)

print("word vocab -> ",word_vocab)
print("byte pair hist -> ",pairs)
print("best pair -> ",best)
print("merge f,i-> ",word_vocab_out)





is empty string a key in word_vocab: False
word vocab ->  defaultdict(<class 'int'>, {('f', 'i', 'r', 's', 't', '<\\w>'): 361, ('c', 'i', 't', 'i', 'z', 'e', 'n', '<\\w>'): 100, ('b', 'e', 'f', 'o', 'r', 'e', '<\\w>'): 193, ('w', 'e', '<\\w>'): 865, ('p', 'r', 'o', 'c', 'e', 'e', 'd', '<\\w>'): 21, ('a', 'n', 'y', '<\\w>'): 189, ('f', 'u', 'r', 't', 'h', 'e', 'r', '<\\w>'): 36, ('h', 'e', 'a', 'r', '<\\w>'): 230, ('m', 'e', '<\\w>'): 1764, ('s', 'p', 'e', 'a', 'k', '<\\w>'): 292, ('a', 'l', 'l', '<\\w>'): 890, ('y', 'o', 'u', '<\\w>'): 3142, ('a', 'r', 'e', '<\\w>'): 784, ('r', 'e', 's', 'o', 'l', 'v', 'e', 'd', '<\\w>'): 15, ('r', 'a', 't', 'h', 'e', 'r', '<\\w>'): 77, ('t', 'o', '<\\w>'): 4766, ('d', 'i', 'e', '<\\w>'): 138, ('t', 'h', 'a', 'n', '<\\w>'): 480, ('f', 'a', 'm', 'i', 's', 'h', '<\\w>'): 4, ('k', 'n', 'o', 'w', '<\\w>'): 350, ('c', 'a', 'i', 'u', 's', '<\\w>'): 17, ('m', 'a', 'r', 'c', 'i', 'u', 's', '<\\w>'): 122, ('i', 's', '<\\w>'): 2078, ('c', 'h', 'i', 'e', 'f', '<\

In [27]:
word_vocab = construct_word_vocab(defaultdict(int), file_path)
bpe_encoder, bpe_decoder, merges = byte_pair_encoding(word_vocab, 200)

print("bpe encoder map-> ", bpe_encoder)
print("merges -> ", merges)


bpe encoder map->  {('e', '<\\w>'): 0, ('t', 'h'): 1, ('s', '<\\w>'): 2, ('t', '<\\w>'): 3, ('d', '<\\w>'): 4, ('r', '<\\w>'): 5, ('y', '<\\w>'): 6, ('o', 'u'): 7, ('a', 'n'): 8, ('i', 'n'): 9, ('o', '<\\w>'): 10, ('e', 'n'): 11, ('l', '<\\w>'): 12, ('o', 'n'): 13, ('a', 'r'): 14, ('e', 'r'): 15, ('an', 'd<\\w>'): 16, ('th', 'e<\\w>'): 17, ('o', 'r'): 18, ('h', 'a'): 19, ('e', 'r<\\w>'): 20, ('i', 's<\\w>'): 21, ('f', '<\\w>'): 22, ('y', 'ou'): 23, ('l', 'l<\\w>'): 24, ('i', '<\\w>'): 25, ('t', 'o<\\w>'): 26, ('e', 'a'): 27, ('in', 'g'): 28, ('n', 'o'): 29, ('w', 'i'): 30, ('e', 's'): 31, ('th', '<\\w>'): 32, ('o', 'm'): 33, ('a', '<\\w>'): 34, ('o', 'f<\\w>'): 35, ('c', 'h'): 36, ('e', 's<\\w>'): 37, ('ing', '<\\w>'): 38, ('in', '<\\w>'): 39, ('v', 'e<\\w>'): 40, ('s', 't'): 41, ('a', 't<\\w>'): 42, ('m', 'y<\\w>'): 43, ('you', '<\\w>'): 44, ('e', 'd<\\w>'): 45, ('en', '<\\w>'): 46, ('o', 'r<\\w>'): 47, ('l', 'i'): 48, ('m', '<\\w>'): 49, ('w', 'h'): 50, ('o', 'w'): 51, ('s', 't<\\w>'

## Positional Encoding

In [None]:
#Sinusoidal position encoding

## Transformer

In [None]:
from collections import defaultdict
def def_value():
  '''
  Function to return a default values for a key that is not present in
  defaultdict
  '''
  return 0
a = 'aa'
di = defaultdict()
di[a] += 1
print(di)

KeyError: ignored

## Network Setup

## Load Data

# Train