<a href="https://colab.research.google.com/github/blooming-ai/generativeai/blob/main/generativeai/examples/generative/ipynb/gpt_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT code
https://github.com/karpathy/minGPT

In [80]:
#Get data file
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-07-31 08:36:16--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-07-31 08:36:16 (19.6 MB/s) - ‘input.txt’ saved [1115394/1115394]



## Word embedding

In [81]:
file_path = "sample_data/input.txt"

In [78]:
import re
import pdb
import string
#Inspired from https://www.geeksforgeeks.org/byte-pair-encoding-bpe-in-nlp/
# But rewritten for better understanding and working on files

#Byte pair encoding
def construct_word_vocab(word_vocab:dict, file_path:str)->dict:
    '''
    Read file and update word_vocab dict. word_vocab has format word_vocab[('w','o','r','d')]->freq
    '''
    with open(file_path) as fp:
        while True:
            line = fp.readline()
            if not line: break
            else:
                lst = line.split()
                for item in lst:
                    item.strip()
                    item = "".join(ch for ch in item if ch.isalnum()) # keep only alpha-numeric characters
                    lst_item = tuple(item.lower()) # tuple to make a hashable dict key
                    if len(lst_item) == 0: continue #ignore empty key
                    if lst_item not in word_vocab: word_vocab[lst_item] = 1
                    else:                     word_vocab[lst_item] += 1

    return word_vocab

def get_byte_pair_hist(word_vocab:dict)->dict:
    '''
    Read word_vocab[('w','o','r','d')]->freq and construct byte pair histogram
    returns pair[('w','o')]->freq
    '''
    pair = {}
    for word, freq in word_vocab.items():
        for i in range(len(word)-1):
            key = (word[i],word[i+1])
            if key not in pair: pair[key] = freq
            else:               pair[key] += freq

    return pair

def merge_pair(replace_pair:tuple, word_vocab_in:dict)->dict:
    '''
    replace pair in the key of the word_vocab. E.g. pair=('w','o') then update
    word_vocab_in[('w','o','r','d')]->freq to
    word_vocab_in[('wo','r','d')]->freq
    '''
    word_vocab_out = {}
    for word, freq in word_vocab_in.items():
        word_new = list()
        is_last_char_used = False
        i=0
        while i < len(word)-1:
            if (word[i],word[i+1]) == replace_pair:
                word_new.append(word[i]+word[i+1])
                if i == len(word)-2: is_last_char_used = True
                i += 1 # skip the next merged character
            else:
                word_new.append(word[i])

            i += 1

        if not is_last_char_used : word_new.append(word[len(word)-1])
        word_vocab_out[tuple(word_new)] = freq

    return word_vocab_out

def byte_pair_encoding(word_vocab:dict, n:int=0)->dict:
    '''
    Given word_vocab_in[('w','o','r','d')]->freq merges pairs with highest frequency 'n'.
    If n==0, runs till all pairs are merged.
    E.g. if ('w','o') has highest freq. then word_vocab_in[('wo','r','d')]->freq.
    Also returns the final vocab.
    '''
    i = 0
    vocab = list(string.ascii_lowercase)
    vocab.extend(list(string.digits))
    # print(vocab)
    while True:
        pairs = get_byte_pair_hist(word_vocab)
        best = max(pairs, key=pairs.get)
        vocab.append(best[0]+best[1])
        if pairs[best] == 1: break # do not replace if pair occurs only once.
        word_vocab = merge_pair(best, word_vocab)
        if(i>= n and n>0): break # break if i exceeds n
        i += 1

    return word_vocab, vocab


In [74]:
#Testing:
# word_vocab = {}
# word_vocab = construct_word_vocab(word_vocab, file_path)
# print("is empyt string a key in word_vocab:",() in word_vocab)
# pairs = get_byte_pair_hist(word_vocab)
# best = max(pairs, key=pairs.get)
# word_vocab_out = merge_pair(('f','i'), word_vocab)

# print(word_vocab)
# print(pairs)
# print(best)
# print(word_vocab_out)





In [79]:
word_vocab = construct_word_vocab(dict(), file_path)
bpe_word_vocab, vocab = byte_pair_encoding(word_vocab, 200)
print(vocab)

['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'th', 'ou', 'an', 'er', 'in', 'the', 'or', 'en', 'ar', 'is', 'es', 'on', 'and', 'at', 'll', 'to', 'st', 'you', 'me', 'no', 'ha', 'ing', 'se', 'of', 'wh', 'le', 'wi', 'be', 'he', 're', 'it', 've', 'ch', 'ro', 'my', 'for', 'as', 'ce', 'ay', 'that', 'ed', 'li', 'ir', 'ld', 'we', 'ut', 'ere', 'ke', 'not', 'us', 'ri', 'de', 'lo', 'with', 'so', 'gh', 'ent', 'co', 'thou', 'your', 'go', 'hi', 'ow', 'our', 'et', 'al', 'ad', 'ther', 'his', 'but', 'un', 'this', 'io', 'all', 'est', 'have', 'ard', 'ly', 'ur', 'do', 'ght', 'ra', 'him', 'ma', 'king', 'od', 'ord', 'ess', 'what', 'now', 'am', 'pe', 'thy', 'ver', 'ill', 'sha', 'are', 'fe', 'id', 'will', 'her', 'ould', 'fa', 'ck', 'ge', 'one', 'man', 'ne', 'by', 'ru', 'su', 'la', 'ta', 'pr', 'ti', 'if', 'po', 'con', 'ter', 'sh', 'shall', 'lord', 'end', 'qu', 'mor', 'thee', 'wh

## Positional Encoding

In [None]:
#Sinusoidal position encoding

## Transformer

In [18]:
from collections import defaultdict
def def_value():
  '''
  Function to return a default values for a key that is not present in
  defaultdict
  '''
  return 0
a = 'aa'
di = defaultdict()
di[a] += 1
print(di)

KeyError: ignored

## Network Setup

## Load Data

# Train