In [69]:
import random
from tqdm import tqdm

In [70]:
path = "../data/shakespeare/input.txt"

# Read the whole corpus. The encoding is explicit so the result does not depend
# on the platform's default locale.
# NOTE(review): assumes the file is UTF-8 (plain ASCII is a subset) - confirm.
with open(path, "r", encoding="utf-8") as file:
    text = file.read()

# A blank line separates paragraphs.
raw_paragraphs = text.split("\n\n")
paragraphs = []
for paragraph in raw_paragraphs:
    # "/" is reserved as the paragraph start/end marker, so remove it from the
    # text itself before any whitespace clean-up (the replacement adds spaces).
    paragraph = paragraph.replace("/", " ")
    # Collapse every whitespace run (newlines, carriage returns, tabs, repeated
    # spaces) into a single space and trim both ends *before* the markers are
    # added; stripping afterwards would be a no-op.
    paragraph = " ".join(paragraph.split())
    if not paragraph:
        # Nothing left after clean-up: skip instead of emitting a bare "//".
        continue
    paragraphs.append("/" + paragraph + "/")

# Character-level tokenisation: one token per character.
paragraphs_tokens = [list(paragraph) for paragraph in paragraphs]

In [71]:
n_grams = 30

# Trie of nested dicts: the first n_grams - 1 tokens of each window select the
# path, and the innermost dict maps the final token to its occurrence count.
unnormalized_token_counts = {}
for paragraph_tokens in tqdm(paragraphs_tokens):
    # n_grams shifted views of the paragraph; zipping them yields every window
    # of n_grams consecutive tokens. zip stops at the shortest view, so a
    # paragraph with fewer than n_grams tokens contributes no window at all.
    shifted_views = [paragraph_tokens[offset:] for offset in range(n_grams)]
    for window in zip(*shifted_views):
        *prefix, last_token = window
        node = unnormalized_token_counts
        for token in prefix:
            # Descend into the child for this token, creating it on first use.
            node = node.setdefault(token, {})
        node[last_token] = node.get(last_token, 0) + 1

100%|██████████| 7222/7222 [00:16<00:00, 431.31it/s]


In [72]:
def normalize_token_counts(token_counts):
    """Return a copy of a nested count dict with leaf counts turned into probabilities.

    Within every dictionary, the numeric (non-dict) values are divided by their
    combined total; sub-dictionaries are normalized recursively and independently.
    The input is not modified.

    Args:
        token_counts: Nested dict mapping token -> sub-dict or numeric count.
            Any non-dict argument is returned unchanged.

    Returns:
        A new nested dict with the same keys. Every numeric leaf becomes a
        float; when the leaves of a dictionary sum to zero (or less) each of
        them becomes 0.0.
    """
    if not isinstance(token_counts, dict):
        return token_counts

    normalized = {}
    leaf_counts = {}
    # Sub-dictionaries are emitted first and leaves afterwards, so the key
    # order of the result is independent of how the two kinds are interleaved.
    for token, value in token_counts.items():
        if isinstance(value, dict):
            normalized[token] = normalize_token_counts(value)
        else:
            leaf_counts[token] = value

    total = sum(leaf_counts.values())
    for token, value in leaf_counts.items():
        # Always produce a float (0.0, not 0): the sampling helpers tell
        # leaves from subtrees by type, so an int here would be mis-handled.
        normalized[token] = value / total if total > 0 else 0.0
    return normalized

# Build a separate trie whose innermost counts are divided by their per-context
# total; normalize_token_counts returns a new dict, so the raw counts in
# unnormalized_token_counts are left as they were.
normalized_token_counts = normalize_token_counts(unnormalized_token_counts)

In [73]:
def sample(probabilities, n):
    """Draw a random path of at most ``n`` tokens down a nested weight dict.

    At each level one token is drawn with ``random.choices``. A token whose
    value is a sub-dictionary is weighted by ``sum_leaf`` of that subtree; any
    other value is used as the weight directly. The walk then descends into
    the chosen token's value.

    Args:
        probabilities: Nested dict mapping token -> sub-dict or numeric weight.
        n: Maximum number of tokens to draw.

    Returns:
        Tuple of the drawn tokens. It is shorter than ``n`` when a numeric leaf
        is reached, or when a level is empty or has no positive total weight.
    """
    result = []
    current_level = probabilities
    for _ in range(n):
        tokens = []
        weights = []
        for token, value in current_level.items():
            tokens.append(token)
            # Only dicts are subtrees; every other value (float or int) is
            # already a weight and must not be recursed into.
            weights.append(sum_leaf(value) if isinstance(value, dict) else value)

        # Nothing can be drawn from an empty or all-zero level; stop instead
        # of letting the unpacking / random.choices raise.
        if not tokens or sum(weights) <= 0:
            break

        next_token = random.choices(tokens, weights=weights, k=1)[0]
        result.append(next_token)
        # Move to next level
        current_level = current_level[next_token]
        if not isinstance(current_level, dict):
            break
    return tuple(result)


def sum_leaf(d):
    """Return the sum of every numeric leaf in a nested dict.

    Args:
        d: Nested dict whose values are either sub-dicts or numbers.

    Returns:
        The total of all non-dict values found at any depth (0 for an empty dict).
    """
    return sum(
        # Recurse only into dicts so that int leaves (e.g. raw counts) are
        # summed like floats instead of being treated as subtrees.
        sum_leaf(v) if isinstance(v, dict) else v
        for v in d.values()
    )

In [74]:
def generate_sequence(prob_dict, n, start_token='/', max_tokens=50):
    """Generate a token sequence by repeatedly sampling from a nested n-gram dict.

    Starting from ``[start_token]``, each step looks up the last ``n - 1``
    tokens (fewer while the sequence is still short) as a path from the root
    of ``prob_dict`` and draws the next token from the level reached.

    A candidate whose value is a sub-dictionary (``prob_dict`` deeper than
    ``n``) is weighted by ``sum_leaf`` of that subtree; any other value is used
    as the weight directly.
    NOTE(review): if the leaves were normalized per context, that sum weighs
    every deeper context equally rather than by how often it occurred.

    Args:
        prob_dict: Nested dict mapping token -> sub-dict or numeric weight.
        n: Order of the model; ``n - 1`` previous tokens form the context.
            ``n <= 1`` means no context (sample from the root every step).
        start_token: First token of the returned sequence.
        max_tokens: Upper bound on the length of the returned sequence.

    Returns:
        List of tokens beginning with ``start_token``. Generation stops early
        when the context is not present in ``prob_dict``, when the context
        leads to a leaf, or when the reached level has nothing to sample.
    """
    sequence = [start_token]
    context_size = max(n - 1, 0)
    for _ in range(max_tokens - 1):
        # Build context (use last n-1 tokens). The start index is computed
        # explicitly: a negative-index slice would turn into sequence[-0:],
        # i.e. the whole sequence, when n == 1.
        context = sequence[max(len(sequence) - context_size, 0):]
        current_level = prob_dict
        # Traverse the tree down the context
        for token in context:
            if isinstance(current_level, dict) and token in current_level:
                current_level = current_level[token]
            else:
                return sequence  # Context not found, stop

        if not isinstance(current_level, dict):
            break

        # Now sample next token based on this context
        tokens = []
        weights = []
        for token, value in current_level.items():
            tokens.append(token)
            # Only dicts are subtrees; floats/ints are weights already.
            weights.append(sum_leaf(value) if isinstance(value, dict) else value)

        # An empty or all-zero level offers nothing to draw from.
        if not tokens or sum(weights) <= 0:
            break

        next_token = random.choices(tokens, weights=weights, k=1)[0]
        sequence.append(next_token)
    return sequence

In [75]:
# Sample up to 200 character tokens, conditioning each draw on the previous
# n - 1 = 2 tokens and starting from the "/" paragraph marker.
generated = generate_sequence(
    normalized_token_counts,
    n=3,
    start_token='/',
    max_tokens=200,
)

# Tokens are printed separated by single spaces.
print("Generated sequence:", ' '.join(generated))

Generated sequence: / R O S P E R L E O N T E R C I U S :   A s s   T h e   m y s   b i t h e   d r a l l s   b e c u r s e e   n o t h e y   y o u   s w e   t i s c o n   h u n d   s p a r t i f e !   I   d e d u r   m o s t r e   I   w r o p h e e p   n o o k e   t h o   c h m a k e   h a t   h e n   h e a r k !   M y   d i n   o r t l e d .   W e   t h ,   n e m a y   o f   m a l l   e a r e s s u e e   u n ;   w
