# In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint shared by the tokenizer and the causal language model.
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Source poem, one entry per line; joined with newlines (no trailing newline).
full_prompt = "\n".join((
    "One must have a mind of winter",
    "To regard the frost and the boughs",
    "Of the pine-trees crusted with snow;",
    "And have been cold a long time",
    "To behold the junipers shagged with ice,",
    "The spruces rough in the distant glitter",
    "Of the January sun; and not to think",
    "Of any misery in the sound of the wind,",
    "In the sound of a few leaves,",
    "Which is the sound of the land",
    "Full of the same wind",
    "That is blowing in the same bare place",
    "For the listener, who listens in the snow,",
    "And, nothing himself, beholds",
    "Nothing that is not there and the nothing that is.",
))

def apply_p_plus_n(text: str, n: int = 7) -> str:
    """Replace the last word of every line with the model's n-th best guess.

    For each non-blank line, the final whitespace-separated word is dropped,
    the remaining words (re-joined with single spaces) are fed to the
    module-level ``model`` via the module-level ``tokenizer``, and the line is
    re-ended with the token ranked ``n``-th most likely to come next
    (``n=1`` is the model's top choice). Blank lines are kept unchanged.

    The substituted piece is a single tokenizer token, so it may be a word
    fragment or punctuation rather than a whole word.

    Args:
        text: Input text. Leading/trailing whitespace is stripped before the
            text is split on newlines.
        n: 1-based rank of the next-token candidate to substitute.

    Returns:
        The rewritten lines joined with newlines.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive 1-based rank, got {n!r}")

    new_lines = []

    for line in text.strip().split('\n'):
        words = line.split()
        if not words:  # blank or whitespace-only line: keep it untouched
            new_lines.append(line)
            continue

        # Everything but the last word becomes the context for the prediction.
        prompt = ' '.join(words[:-1])

        if prompt:
            input_ids = tokenizer.encode(prompt, return_tensors='pt')
        elif tokenizer.bos_token_id is not None:
            # One-word line: there is no context left, and an empty input
            # would make the model's forward pass fail, so condition on the
            # beginning-of-sequence token instead.
            input_ids = torch.tensor([[tokenizer.bos_token_id]])
        else:
            # No context and no BOS token to stand in for it: nothing to
            # predict from, so leave the line as it is.
            new_lines.append(line)
            continue

        with torch.no_grad():
            logits = model(input_ids).logits

        # Scores for the position following the last prompt token.
        last_token_logits = logits[0, -1, :]

        # Softmax is order-preserving, so ranking the raw logits selects the
        # same candidates as ranking the probabilities would.
        _, top_indices = torch.topk(last_token_logits, n)

        # n is 1-based, hence index n - 1.
        nth_token = tokenizer.decode(top_indices[n - 1]).strip()

        # Avoid a leading space when the line had only one word.
        new_lines.append(f"{prompt} {nth_token}" if prompt else nth_token)

    return '\n'.join(new_lines)


def _run_and_save(n):
    """Apply P+n to ``full_prompt``, print the result and save it to ``P+<n>.txt``.

    Returns the rewritten text so the caller can keep it around.
    """
    result = apply_p_plus_n(full_prompt, n=n)

    print("\n" + "=" * 60)
    print(f"FINAL P+{n} OUTPUT")
    print("=" * 60)
    print(result)

    # Save to file. The encoding is pinned to UTF-8 because decoded model
    # tokens are not guaranteed to be ASCII, and the platform default
    # encoding may be unable to represent them.
    with open(f'P+{n}.txt', 'w', encoding='utf-8') as f:
        f.write(result)

    return result


# apply this logic for n=7 (p+7)
p7_text = _run_and_save(7)

# Example n=699 (P+699)
p699_text = _run_and_save(699)