In [1]:
!pip install torch transformers datasets



In [2]:
!pip install 'accelerate>=0.26.0'

Note: you may need to restart the kernel to use updated packages.


In [5]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Fetch the pretrained GPT-2 weights together with the matching tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token
# Put the model in evaluation mode before generating
model.eval()

# Function to generate creative text
def generate_text(starting_sentence, max_length=200, temperature=0.9, top_k=40, max_words=100):
    """
    Generate creative text that continues a starting sentence.

    Uses the module-level ``tokenizer`` and ``model`` to sample a continuation
    of the prompt, decodes it, and caps the result at ``max_words``
    whitespace-separated words.

    Args:
        starting_sentence (str): The input sentence to start the generation.
            Must contain at least one non-whitespace character.
        max_length (int): Maximum length of the whole sequence in tokens,
            prompt included. Must be at least 2 so that there is room for the
            prompt and at least one generated token.
        temperature (float): Controls creativity (higher = more random).
        top_k (int): Limits sampling to top k tokens for coherence.
        max_words (int | None): Upper bound on the number of words in the
            returned text. No minimum is enforced. ``None`` disables trimming.

    Returns:
        str: The decoded text (prompt plus continuation), with at most
        ``max_words`` words when trimming is enabled.

    Raises:
        ValueError: If ``starting_sentence`` is empty or whitespace-only,
            if ``max_length`` is smaller than 2, or if ``max_words`` is negative.
    """
    # Validate up front so callers get a clear message instead of an opaque
    # failure from the tokenizer/model on a prompt with nothing to condition on.
    if not starting_sentence or not starting_sentence.strip():
        raise ValueError("starting_sentence must contain at least one non-whitespace character")
    if max_length < 2:
        raise ValueError("max_length must be at least 2 (prompt plus one generated token)")
    if max_words is not None and max_words < 0:
        raise ValueError("max_words must be non-negative or None")

    # Tokenize with attention mask. The prompt is truncated to max_length - 1
    # tokens: generate() counts the prompt towards max_length, so a prompt that
    # already fills max_length would leave no room for any new token.
    encodings = tokenizer(
        starting_sentence,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length - 1,
    )
    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']

    # Generate text with the model; no gradients are needed for sampling
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,  # Pass attention mask
            pad_token_id=tokenizer.eos_token_id,  # Explicitly set pad token
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            do_sample=True,
            num_return_sequences=1
        )

    # Decode the single returned sequence, dropping special tokens
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Cap the result at max_words words; shorter outputs are returned unchanged
    if max_words is not None:
        words = generated_text.split()
        if len(words) > max_words:
            generated_text = ' '.join(words[:max_words])

    return generated_text

# Main loop: repeatedly read a prompt, generate a continuation, and print it
print("Welcome to the Creative Writing Text Generator!")
print("Enter a starting sentence (e.g., 'The forest was silent until…'). Type 'quit' to exit.")
while True:
    try:
        # Strip surrounding whitespace so inputs like "quit " are still recognized
        starting_sentence = input("\nEnter starting sentence: ").strip()
    except EOFError:
        # No more input is available; leave the loop the same way 'quit' does
        print("\nExiting the text generator. Goodbye!")
        break
    if starting_sentence.lower() == 'quit':
        print("Exiting the text generator. Goodbye!")
        break
    if not starting_sentence:
        # Nothing to condition the model on; ask again instead of calling it
        print("Please enter a non-empty starting sentence.")
        continue
    try:
        generated_text = generate_text(starting_sentence)
        print("\nGenerated Text:\n", generated_text)
        word_count = len(generated_text.split())
        print(f"(Word count: {word_count})")
    except Exception as e:
        # Best-effort: report the failure and keep the session alive
        print(f"An error occurred: {e}. Please try again.")

Welcome to the Creative Writing Text Generator!
Enter a starting sentence (e.g., 'The forest was silent until…'). Type 'quit' to exit.



Enter starting sentence:  quit


Exiting the text generator. Goodbye!
