In [105]:
import os
from transformers import (
    GPT2Tokenizer, 
    GPT2LMHeadModel, 
)
from transformers import LogitsProcessorList, LogitsProcessor
import numpy as np
import torch
dir_path = os.getcwd()

In [54]:
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Set PAD token

In [55]:
def contains(l1, l2):
    """Check if list l2 is contiguously contained in list l1."""
    len1, len2 = len(l1), len(l2)
    # Iterate through all possible positions in l1 where l2 might fit
    for pointer in range(len1 - len2 + 1):
        # Compare each element in l2 with the corresponding elements in l1
        if l1[pointer:pointer + len2] == l2:
            return True
    return False

In [134]:
# Custom Logits Processor for Enforcing constraint appearence
class OrderLogitsProcessor(LogitsProcessor):
    def __init__(self, constraints, eos_id):
        self.index = 0  
        self.eos_token_id = eos_id
        self.constraints = constraints
        self.generated_tokens = []
        self.generated_chunks = [0]*len(self.constraints)  # Track if each sequence has appeared

    def __call__(self, input_ids, scores):
        # Track generated tokens
        boost = 10
        self.generated_tokens.append(input_ids[0, -1].item())
        if len(self.generated_tokens)>10:
            self.generated_tokens.pop(0)
        # Check if any of the required sequences appear in the generated tokens
        for seq_idx, sequence in enumerate(self.constraints):
            if not self.generated_chunks[seq_idx]:
                # Check if the current sequence has appeared in the generated tokens
                if contains(self.generated_tokens, sequence):
                    self.generated_chunks[seq_idx] = 1  # Mark this sequence as generated
                    print(scores[:,sequence[-1]])
                else:
                    if input_ids[0, -1].item() in sequence:
                        next_token = sequence[sequence.index(input_ids[0, -1].item())+1]
                        scores[:,next_token] = torch.abs(scores[:,next_token])+boost #boost score of next token
                    elif scores[:,sequence[0]] == torch.abs(scores[:,sequence[0]]):#boost only if poisitive similarity
                        scores[:,sequence[0]]+= boost
                    elif torch.max(scores)<0:
                        scores[:,sequence[0]]+= boost
        # Once all sequences are generated, discourage <eos> generation
        if not(all(self.generated_chunks)):
            scores[:, self.eos_token_id] -= 10  # Strongly discourage <eos> token generation

        return scores

In [135]:

# Define Required Words
required_words = ["castle", "dragon", "knight"]

def load_model(model_path=dir_path+"/fine_tuned_gpt2"):
    model = GPT2LMHeadModel.from_pretrained(model_path)
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    return model, tokenizer

# Generate Text with Enforced Word Order
def generate_story(prompt, constraints= required_words):
    loaded_model, tokenizer = load_model()
    # Tokenize input
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    attention_mask = tokenizer(prompt, return_tensors="pt", padding=True).attention_mask  # Generate attention mask
    # Tokenize constraints
    constraint_tokens = [tokenizer(seq, add_special_tokens=False).input_ids for seq in constraints]
    logits_processor = LogitsProcessorList([OrderLogitsProcessor(constraint_tokens, tokenizer.eos_token_id)])
    max_attempts = 3
    max_length = 500  # Start with a reasonable max_length
    generated_texts = {}
    for _ in range(max_attempts):  
        output = loaded_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            logits_processor=logits_processor,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id  # Explicitly set pad token
        )
        
        # Decode output
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_texts[sum([contains(generated_text, constraint) for constraint in constraints])] = generated_text
        # Check if all constraints are met
        if 3 in generated_texts.keys():
            return generated_text 
        
        # Increase max_length and retry if constraints are not met
        max_length += 100
    return generated_texts[max(generated_texts.keys())]


In [136]:
constraints = ["castle", "dragon", "knight"]
prompt  = "Once upon a time"
story = generate_story(prompt)
print(story)
save_dir = dir_path+"/generated_texts"
file_name = f"/Chekhov style story on {prompt} with constraints {constraints}.txt"
with open(save_dir+file_name, "w") as file:
    file.write(story)

tensor([-59.0673])
Once upon a time of calamity,

 sheknight came into our house, and it was evident that I was not the only one

 in the house. When she went into the kitchen I was at the head of

 the table, but I did not notice the pang of the stove, the spluttering of

 the door, the trembling of the stove, the smell of hay, the

 wind, the drowsy smell of tea, and the pang of the stove that was not

 being carried away by the cook.
I did not know the cook, but I remember she was always at

 supper, and we did not see each other, and I did not know her

 name.
I am not in the habit of staying long in the same room, said

 Vassilyev. My wife is not very good-natured, and

 we do not meet every day. When she was still a little girl she

 used to go out and go to the yard. I used to go to her in the evening,

 and she used to say, The lady who is going to take a boy to the

 school, you must give her a piece of bread and water.
She used to say to me, We must have supper and water, too