In [105]:
import os
from transformers import (
    GPT2Tokenizer, 
    GPT2LMHeadModel, 
)
from transformers import LogitsProcessorList, LogitsProcessor
import numpy as np
import torch
# Base directory for every path used below (fine_tuned_gpt2/, datasets/, generated_texts/).
# NOTE(review): this is the kernel's current working directory, so the notebook assumes it
# was launched from the project root — confirm, or replace with an explicit configurable path.
dir_path = os.getcwd()

In [178]:
model_name = "gpt2"
# Module-level base tokenizer. OrderLogitsProcessor reads this global to decode constraints
# in its progress message; generate_story() loads its own fine-tuned tokenizer instead
# (presumably the same vocabulary — verify if the fine-tuned tokenizer was extended).
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Reuse EOS as the padding token

In [177]:
def contains(l1, l2):
    """Return True if ``l2`` occurs as a contiguous run inside ``l1``.

    Works for any sliceable sequences that compare with ``==`` (lists, strings, ...).
    An empty ``l2`` of the same type is always contained; an ``l2`` longer than
    ``l1`` never is.
    """
    width = len(l2)
    # Slide a window of len(l2) over every start offset where it still fits.
    last_start = len(l1) - width
    return any(l1[start:start + width] == l2 for start in range(last_start + 1))

In [179]:
# Custom logits processor that pushes generation to include every constraint phrase
class OrderLogitsProcessor(LogitsProcessor):
    """Steer sampling so that every constraint token sequence appears in the output.

    On every decoding step the most recent token of batch row 0 is appended to a
    sliding window of recent tokens. Then, for each constraint not yet produced:

    * if the window now ends with the full constraint, it is marked as produced;
    * else, if the latest token belongs to the constraint and the window ends with a
      proper prefix of it, the constraint's next token gets ``boost`` added;
    * else, if the latest token is not in the constraint, the constraint's first
      token gets ``boost / 2`` added when its logit is >= 0 or when every logit is < 0.

    Until all constraints have been produced, the EOS logit is lowered by
    ``eos_penalty``.

    The processor is stateful. It resets automatically when a new ``generate()`` call
    is detected (the input stops growing), and can be reset explicitly via ``reset()``.
    Only batch row 0 is tracked, while score adjustments are applied to every row,
    so it is meant for batch size 1.

    Args:
        constraints: list of token-id lists, one per required phrase.
        eos_id: id of the end-of-sequence token.
        tokenizer: optional tokenizer used only to decode a constraint in the
            progress message; when omitted, the module-level ``tokenizer`` is used.
        boost: logit bonus for the next token of a partially produced constraint
            (half of it is used for a constraint's first token).
        eos_penalty: amount subtracted from the EOS logit while constraints remain.
        window: number of recent tokens kept for matching; raised automatically to
            the longest constraint length so every constraint stays detectable.
    """

    def __init__(self, constraints, eos_id, tokenizer=None, boost=10, eos_penalty=10, window=50):
        self.index = 0  # kept for backward compatibility; not read by this class
        self.eos_token_id = eos_id
        self.constraints = constraints
        self.tokenizer = tokenizer
        self.boost = boost
        self.eos_penalty = eos_penalty
        self.window = max(window, max((len(seq) for seq in constraints), default=0))
        self.reset()

    def reset(self):
        """Forget all progress so the processor can be reused for a fresh generation."""
        self.generated_tokens = []
        # 1 once the matching constraint has been seen; empty constraints have no token
        # to enforce, so they count as satisfied from the start.
        self.generated_chunks = [0 if len(seq) else 1 for seq in self.constraints]
        self._last_length = 0

    def _matched_prefix_length(self, sequence):
        """Length of the longest proper prefix of ``sequence`` ending the token window (0 if none)."""
        longest = min(len(sequence) - 1, len(self.generated_tokens))
        for length in range(longest, 0, -1):
            if self.generated_tokens[-length:] == sequence[:length]:
                return length
        return 0

    def __call__(self, input_ids, scores):
        """Adjust ``scores`` (indexed ``[batch, token_id]``) in place and return them."""
        current_length = input_ids.shape[-1]
        # Within one generate() call the input grows by one token per step, so a
        # non-growing length means this instance was reused for a new generation.
        if current_length <= self._last_length:
            self.reset()
        self._last_length = current_length

        last_token = input_ids[0, -1].item()
        self.generated_tokens.append(last_token)
        if len(self.generated_tokens) > self.window:
            self.generated_tokens.pop(0)

        for seq_idx, sequence in enumerate(self.constraints):
            if self.generated_chunks[seq_idx]:
                continue
            if self.generated_tokens[-len(sequence):] == sequence:
                self.generated_chunks[seq_idx] = 1  # Mark this sequence as generated
                decoder = self.tokenizer if self.tokenizer is not None else tokenizer
                print(f"unboosted score:{scores[:,sequence[-1]]}, constraint enforced: {decoder.decode(sequence, skip_special_tokens=True)}, max score: {torch.max(scores)}")
            elif last_token in sequence:
                # Continue a partially generated constraint. Matching the longest prefix
                # (not the first occurrence of the token) keeps repeated tokens correct.
                matched = self._matched_prefix_length(sequence)
                if matched:
                    scores[:, sequence[matched]] += self.boost
            elif scores[0, sequence[0]] >= 0 or torch.max(scores) < 0:
                # Nudge the constraint's first token when it already scores non-negative,
                # or when every logit is negative.
                scores[:, sequence[0]] += self.boost / 2

        # Until every constraint has appeared, make ending the text less likely.
        if not all(self.generated_chunks):
            scores[:, self.eos_token_id] -= self.eos_penalty

        return scores

In [171]:

# Default constraint phrases used when generate_story() is called without explicit constraints.
required_words = ["castle", "dragon", "knight"]

def load_model(model_path=dir_path+"/fine_tuned_gpt2"):
    """Load the fine-tuned GPT-2 model and its tokenizer from ``model_path``.

    Args:
        model_path: directory passed to ``from_pretrained`` for both objects.

    Returns:
        ``(model, tokenizer)``. If the saved tokenizer defines no pad token, EOS is
        reused as padding (the same setup applied to the base tokenizer), so callers
        that request padding do not fail.
    """
    model = GPT2LMHeadModel.from_pretrained(model_path)
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

# Generate Text with Enforced Word Order
def generate_story(prompt, constraints=required_words, max_attempts=3, max_length=500, length_step=100):
    """Continue ``prompt`` with the fine-tuned model, steering it to include ``constraints``.

    Up to ``max_attempts`` samples are drawn, and each retry allows ``length_step``
    more tokens. A constraint counts as met when it appears as a substring of the
    decoded text (prompt included).

    Args:
        prompt: text to continue.
        constraints: phrases that should appear in the output.
        max_attempts: number of sampling attempts (must be >= 1).
        max_length: ``generate`` ``max_length`` (prompt + new tokens) for the first attempt.
        length_step: increase of ``max_length`` after each attempt that misses a constraint.

    Returns:
        The first sample that contains every constraint. Otherwise, the sample that
        met the most constraints (the latest such sample on ties).

    Raises:
        ValueError: if ``max_attempts`` is less than 1.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    loaded_model, tokenizer = load_model()
    # One call yields both ids and mask. A single prompt needs no padding, and
    # padding=True would raise if the tokenizer had no pad token.
    encoded = tokenizer(prompt, return_tensors="pt")
    # Tokenize constraints
    constraint_tokens = [tokenizer(seq, add_special_tokens=False).input_ids for seq in constraints]
    generated_texts = {}  # number of constraints met -> latest text achieving it
    for _ in range(max_attempts):
        # The processor is stateful. A fresh one per attempt keeps constraints found in
        # an earlier (failed) attempt from being treated as already satisfied.
        logits_processor = LogitsProcessorList([OrderLogitsProcessor(constraint_tokens, tokenizer.eos_token_id)])
        output = loaded_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_length=max_length,
            logits_processor=logits_processor,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id  # Explicitly set pad token
        )

        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        satisfied = sum(contains(generated_text, constraint) for constraint in constraints)
        generated_texts[satisfied] = generated_text
        # Compare against the actual number of constraints, not a hard-coded 3.
        if satisfied == len(constraints):
            return generated_text

        # Allow a longer text on the next try.
        max_length += length_step
    return generated_texts[max(generated_texts.keys())]


In [None]:
import random

def _read_sentences(path):
    """Read a UTF-8 text file and split it into rough sentences on ". "."""
    with open(path, "r", encoding="utf-8") as handle:  # closes the file, unlike bare open()
        text = handle.read()
    # Collapse line breaks and whitespace runs into single spaces. Deleting "\n"
    # outright would fuse hard-wrapped words and hide sentence ends like ".\nNext".
    return " ".join(text.split()).split(". ")

prompts = _read_sentences(dir_path + "/datasets/scraped_tolstoy.txt")  # prompts derived from War and Peace
constraints_list = _read_sentences(dir_path + "/datasets/scraped_dickens.txt")  # constraints derived from Tale of Two Cities

In [181]:
import re

RANDOM_SEED = None  # set to an int to make the prompt/constraint picks reproducible
if RANDOM_SEED is not None:
    random.seed(RANDOM_SEED)

N_CONSTRAINTS = 3
# Prompts of 10-100 characters are easy to continue yet informative enough to work with.
candidate_prompts = [p for p in prompts if 10 <= len(p) <= 100]
# Only sentences longer than 30 characters are long enough to supply a 4-word constraint.
candidate_constraints = [c for c in constraints_list if len(c) > 30]
# Fail fast instead of looping forever when nothing qualifies.
if not candidate_prompts:
    raise ValueError("no prompt between 10 and 100 characters found")
if len(candidate_constraints) < N_CONSTRAINTS:
    raise ValueError(f"need at least {N_CONSTRAINTS} sentences longer than 30 characters")

prompt = random.choice(candidate_prompts)
# The first 4 words of distinct long sentences make the constraints.
constraints = [" ".join(sentence.split()[:4])
               for sentence in random.sample(candidate_constraints, N_CONSTRAINTS)]
print(f"Prompt:{prompt}")
print(f"Constraints:{constraints}")
story = generate_story(prompt, constraints= constraints)
print(f"Story:\n {story}")

save_dir = dir_path+"/generated_texts"
os.makedirs(save_dir, exist_ok=True)
# Raw prompt text may contain "/", "?", quotes, etc. and be too long for a file name,
# so keep only ASCII word characters, '-', '.', and spaces, then truncate.
safe_stem = re.sub(r"[^\w\-. ]+", "_", f"{prompt}__{constraints}", flags=re.ASCII)
safe_stem = safe_stem.strip(" .")[:200] or "story"
file_name = f"{safe_stem}.txt"
with open(os.path.join(save_dir, file_name), "w", encoding="utf-8") as file:
    file.write(story)


Prompt:He had not seen the hussars all that day, but had heard about them from an infantry officer
Constraints:['The peril of an', 'I could not see', 'Crunchers attention was here']
unboosted score:tensor([-77.7164]), constraint enforced: I could not see, max score: -67.96165466308594
unboosted score:tensor([-12.5516]), constraint enforced: The peril of an, max score: -6.466152667999268
Story:
 He had not seen the hussars all that day, but had heard about them from an infantry officer. He

 had been to see them all the evening, and had never seen them

 again.
The old man had not gone off to the river, but had been going up to the

 highroad where the highroad was to be reached. He had no means of getting to the

 highroad, and had only made his way along the highroad in the direction of the

 highroad. He had no right to go on walking, and had to

 make up his mind to go on walking.
The old man was lying on the ground, exhausted, exhausted, with his eyes

 filled with tears. It was ev