# In [8]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader
import re
import string

class MovieScriptDataset(Dataset):
    """Causal-LM dataset that slices one long text into equal-length token blocks.

    The whole text is tokenised once and cut into consecutive blocks of
    ``block_size`` token ids; each block is one training example.

    Args:
        text: The full training text.
        tokenizer: Object exposing ``encode(text)`` that returns a flat
            sequence of token ids (e.g. a Hugging Face tokenizer).
        block_size: Number of tokens per example. It should not exceed the
            model's maximum context length (1024 for the base GPT-2
            checkpoints -- confirm for other models).

    Raises:
        ValueError: If ``block_size`` is not positive or ``text`` yields no
            tokens.
    """

    def __init__(self, text, tokenizer, block_size=1024):
        if block_size < 1:
            raise ValueError("block_size must be a positive integer")

        # Encode WITHOUT truncation: truncating to a single window would throw
        # away everything after the first `block_size` tokens of the script.
        # (Hugging Face tokenizers may log a sequence-length warning here; it
        # is harmless because the ids are re-chunked right below.)
        token_ids = torch.as_tensor(tokenizer.encode(text), dtype=torch.long).reshape(-1)
        total = token_ids.numel()
        if total == 0:
            raise ValueError("text produced no tokens; nothing to train on")

        usable = (total // block_size) * block_size
        if usable == 0:
            # Text shorter than one block: keep it as a single short example.
            self.input_ids = token_ids.unsqueeze(0)
        else:
            # Drop the ragged tail (< block_size tokens) so that every example
            # has the same length and the default DataLoader collate function
            # can stack them into a batch.
            self.input_ids = token_ids[:usable].view(-1, block_size)

    def __len__(self):
        """Return the number of token blocks."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return block ``idx`` as ``{"input_ids": 1-D LongTensor}``."""
        return {"input_ids": self.input_ids[idx]}

def remove_punctuation(input_text):
    """Return ``input_text`` with every character of ``string.punctuation`` deleted.

    Characters outside ``string.punctuation`` are left untouched.
    """
    # Mapping a code point to None tells str.translate to drop that character.
    drop_table = dict.fromkeys(map(ord, string.punctuation))
    return input_text.translate(drop_table)

def remove_whitespace(input_text):
    """Collapse every run of whitespace in ``input_text`` to a single space.

    Blank lines, tabs and newlines all count as whitespace; leading and
    trailing whitespace is removed entirely.
    """
    # str.split() with no separator splits on whitespace runs and discards
    # empty pieces, so re-joining with one space both collapses and trims.
    words = input_text.split()
    return " ".join(words)

def read_movie_script(file_path):
    """Read a UTF-8 text file and return its cleaned contents.

    The text is lower-cased, stripped of punctuation, and then
    whitespace-normalised into a single space-separated line.

    Args:
        file_path: Path of the script file to read.

    Returns:
        The cleaned script as one string.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        movie_script = file.read()
    # Strip punctuation BEFORE normalising whitespace: deleting a free-standing
    # mark such as " - " leaves two adjacent spaces behind, which the
    # whitespace pass must then collapse.
    return remove_whitespace(remove_punctuation(movie_script.lower()))

def fine_tune_model(model, tokenizer, train_dataset, num_train_epochs=3, batch_size=2,
                    learning_rate=5e-5, output_dir="./gpt2-finetuned"):
    """Fine-tune ``model`` on ``train_dataset`` and save the result.

    Args:
        model: Module that is callable as ``model(input_ids, labels=...)``,
            returns an object with a ``.loss`` tensor, and provides
            ``save_pretrained(path)``.
        tokenizer: Unused here; kept so existing call sites keep working.
        train_dataset: Map-style dataset whose items are dicts holding an
            ``"input_ids"`` tensor; all items must be stackable into a batch.
        num_train_epochs: Number of full passes over the dataset.
        batch_size: Examples per optimisation step.
        learning_rate: AdamW learning rate.
        output_dir: Directory passed to ``model.save_pretrained``.
    """
    # Remember the caller's mode so training does not leak into later use
    # (e.g. generation with dropout still switched on).
    was_training = model.training
    # Batches must live on the same device as the model's weights.
    device = next(model.parameters()).device

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    model.train()
    try:
        for _ in range(num_train_epochs):
            for batch in train_dataloader:
                input_ids = batch["input_ids"].to(device)
                optimizer.zero_grad()
                # The same ids are passed as inputs and as labels.
                loss = model(input_ids, labels=input_ids).loss
                loss.backward()
                optimizer.step()
    finally:
        model.train(was_training)

    model.save_pretrained(output_dir)

# Load and clean the training text.
movie_script = read_movie_script("samplerv1.txt")

# Load the pretrained model and its tokenizer.
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Fine-tune for 20 epochs with a batch size of 64.
training_dataset = MovieScriptDataset(movie_script, tokenizer)
fine_tune_model(model, tokenizer, training_dataset, 20, 64)

# Generation must run in eval mode; otherwise dropout stays active and
# perturbs every forward pass of the decoding loop.
model.eval()

# NOTE(review): the training text is lower-cased and punctuation-free, while
# this prompt is capitalised -- consider cleaning the prompt the same way.
prompt_text = "Suddenly it all went black"
input_ids = tokenizer.encode(prompt_text, return_tensors='pt')
output = model.generate(
    input_ids,
    max_length=1000,
    num_return_sequences=1,
    # GPT-2 defines no pad token; name the EOS token explicitly instead of
    # relying on generate()'s implicit fallback.
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
