# Load Libraries

In [1]:
import sys
from datetime import datetime
from pathlib import Path
from random import randint

import pandas as pd
import torch
from numpy import repeat
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load Model

In [2]:
pretrained_model = 'openai-community/gpt2-xl'
# pretrained_model = 'EleutherAI/gpt-neo-1.3B'
# pretrained_model = 'perplexity-ai/r1-1776' ## Bloody large...
# pretrained_model = 'meta-llama/Llama-3.2-3B-Instruct' ## Doesn't seem to work well
# NOTE: switching models also means revisiting the hard-coded token ids used
# elsewhere in this notebook (pad id 50256 and the `stopwords` list).

# Short model name (text after the last "/"), used to label generations and output files.
pretrained_model_name = pretrained_model.split("/")[-1]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
# tokenizer.pad_token_id = tokenizer.eos_token_id

# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines that have no MPS backend.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = AutoModelForCausalLM.from_pretrained(pretrained_model).to(device)

# Main functional loop: generate text by always picking the N-th most likely next token

In [3]:
# Opening lines that the generation loop continues, one sample set per prompt.
prompts = [
    "Humans do not simply predict the next word.",
    "Longing for an unlikely outcome,",
    "Do you enjoy listening to a detuned piano?",
    "The mirrors always break in the same ways.",
    "This tongue is inevitably forked.",
    "You cannot tell my accent is from where?",
    "This is the cure for prediction envy:",
]

# Token ids that are allowed to end a generation once it is long enough.
# NOTE(review): these are presumably ids in the GPT-2 vocabulary of the default
# model (50256 is also the id passed as pad_token_id); confirm what each one
# decodes to with `tokenizer.decode(sw)`, and revise them if the tokenizer changes.
stopwords = [
    0, 1, 6, 13, 30, 526, 1701, 2474,
    3548, 3228, 12248, 22857, 42720, 30823, 50256,
]

In [16]:
# for sw in stopwords:
#     print(tokenizer.decode(sw))

In [5]:
# Generation settings.
min_new_tokens = 88            # tokens to generate before a stop token may end a sample
max_new_tokens_cap = 256       # hard limit so a sample always terminates
baseline_max_new_tokens = 110  # length of the plain greedy baseline for each prompt
n_ranks = 10                   # one sample per rank N = 0 .. n_ranks - 1 (0 = greedy)
# Reuse EOS as padding, as the previous hard-coded 50256 did for GPT-2.
pad_token_id = tokenizer.eos_token_id
stop_token_ids = set(stopwords)


def _top_candidates(text, k):
    """Return the ``k`` most likely next-token ids for ``text``, most likely first.

    Runs a single greedy ``model.generate`` step and ranks the vocabulary by the
    softmax of the scores it reports.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(**encoded,
                                 return_dict_in_generate=True,
                                 output_scores=True,
                                 max_new_tokens=1,
                                 do_sample=False,
                                 pad_token_id=pad_token_id)
    probs = torch.softmax(outputs.scores[0][0], dim=-1)
    # topk only needs the k best entries, instead of sorting the whole vocabulary.
    return torch.topk(probs, k).indices.tolist()


def generate_at_rank(prompt, rank):
    """Extend ``prompt`` by always appending the ``rank``-th most likely token.

    Once more than ``min_new_tokens`` tokens have been appended, if any id in
    ``stopwords`` is among the top ``rank + 1`` candidates, the most likely of
    them is appended instead and generation ends. Generation also ends after
    ``max_new_tokens_cap`` tokens, so a sample can never loop forever.

    Returns the prompt followed by the generated text.
    """
    generated = prompt
    token_count = 1
    while True:
        candidates = _top_candidates(generated, rank + 1)
        selected_id = candidates[rank]
        finished = False
        if token_count > min_new_tokens:
            stop_candidates = [i for i in candidates if i in stop_token_ids]
            if stop_candidates:
                # Candidates are sorted most-likely first, so this is the most likely stop token.
                selected_id = stop_candidates[0]
                finished = True
        # NOTE(review): the growing text is re-tokenized from scratch each step, so the
        # model's context may not be exactly the sequence of token ids chosen here.
        generated += tokenizer.decode(selected_id)
        token_count += 1
        # Write first, then flush, so the latest count is actually shown.
        sys.stdout.write(f"\r{token_count}")
        sys.stdout.flush()
        if finished or token_count > max_new_tokens_cap:
            break
    print()  # end the progress line so the next label starts on its own line
    return generated


write_out = []

for prompt in prompts:
    print(prompt + "\n\n")

    # Plain greedy continuation, stored first as a baseline for this prompt.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        generation = model.generate(**encoded,
                                    max_new_tokens=baseline_max_new_tokens,
                                    pad_token_id=pad_token_id,
                                    do_sample=False)
    decoded = tokenizer.batch_decode(generation, skip_special_tokens=True)
    write_out.append(decoded[0] + "\n\n")

    for N in range(n_ranks):
        print(f"{N}\n")
        generated = generate_at_rank(prompt, N)
        write_out.append(f"N = {N+1} | {pretrained_model_name}\n" + generated + "\n\n")


# Write all generations to a timestamped file so separate runs get separate files.
# "%H.%M.%S" is the portable spelling of the former "%T" with ":" replaced by ".".
current_time = datetime.now().strftime("%d%m%Y %H.%M.%S%p")
Path("generations").mkdir(parents=True, exist_ok=True)
filename = f'generations/{pretrained_model_name}_minTokens-{min_new_tokens}_{current_time}.txt'

with open(filename, 'w', encoding='utf-8') as out:
    out.writelines(write_out)

Humans do not simply predict the next word.


0

971

1062

1013

954

995

956

947

978

929

91Longing for an unlikely outcome,


0

911

922

973

904

915

936

977

918

939

90Do you enjoy listening to a detuned piano?


0

1031

932

903

934

975

906

987

908

999

97The mirrors always break in the same ways.


0

911

1012

1013

904

905

966

1237

918

929

91This tongue is inevitably forked.


0

901

992

1013

1014

915

956

1027

988

909

91You cannot tell my accent is from where?


0

931

942

903

1334

925

946

917

1038

949

90This is the cure for prediction envy:


0

911

952

1063

904

1015

1106

917

918

909

100

# Recycle Bin

In [99]:
# N = 25
# topN = torch.topk(probs, N, sorted=True)
# bottomN = torch.topk(probs, N, largest=False, sorted=True)

In [100]:
# high = 6
# low  = 9

# [prompt + newToken 
#      for newToken in 
#      gpt2_tokenizer.batch_decode(probs.index[high].tolist(), 
#                                  skip_special_tokens=True,
#                                  clean_up_tokenization_spaces=True)
#     ]