In [1]:
import json
import os
import re
from collections import defaultdict

import yaml
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def load_book(file_path):
    """Read the book stored at ``file_path`` (UTF-8) and return its full text."""
    with open(file_path, mode='r', encoding='utf-8') as book_file:
        return book_file.read()

def split_into_chunks(text, chunk_size=1000, overlap=200):
    """Split text into overlapping word-based chunks to maintain context continuity.

    Args:
        text: Text to split. Words are whatever ``str.split()`` yields, so any
            run of whitespace is a separator.
        chunk_size: Maximum number of words per chunk. Must be positive.
        overlap: Number of words shared between consecutive chunks. Must be
            smaller than ``chunk_size``.

    Returns:
        List of chunk strings, each built by joining words with a single space.
        The list is empty when ``text`` contains no words.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is not
            smaller than ``chunk_size`` (the window would never move forward).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    words = text.split()
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(' '.join(words[start:end]))
        # Once a chunk reaches the last word, any further window would only
        # repeat words already covered by this chunk.
        if end >= len(words):
            break

    print("Nomber of chunks",len(chunks))
    return chunks

In [39]:

def _sentence_bounds(text, mention_start, mention_end):
    """Return the ``(start, end)`` slice bounds of the sentence around a mention."""
    terminators = '.!?'
    # rfind gives -1 when a terminator is absent, so the +1 falls back to index 0.
    start = max(text.rfind(mark, 0, mention_start) for mark in terminators) + 1
    closing = [
        pos
        for pos in (text.find(mark, mention_end) for mark in terminators)
        if pos != -1
    ]
    # Keep the closing punctuation; run to the end of the text if there is none.
    end = min(closing) + 1 if closing else len(text)
    return start, end


def extract_mentions(chunks, characters):
    """Extract the sentences that mention each character across text chunks.

    A mention is a case-insensitive occurrence of the character name that is
    not directly preceded or followed by a word character. For every mention
    the enclosing sentence is taken to span from just after the previous
    '.', '!' or '?' up to and including the next one (or the chunk edges when
    there is none). This is a simple heuristic: abbreviations such as "Mr."
    also end a sentence.

    Args:
        chunks: Iterable of text chunks to search.
        characters: Iterable of character names. Blank names are skipped and
            repeated names are searched once.

    Returns:
        ``defaultdict(list)`` mapping each character that was found to the
        list of stripped sentences mentioning it, in order of appearance.
        Within one chunk, a sentence is recorded once per character even when
        it contains several mentions.
    """
    character_contexts = defaultdict(list)

    # Compile each pattern once instead of once per chunk. The search runs on
    # the original chunk with IGNORECASE so match offsets always line up with
    # the text that gets sliced.
    patterns = {}
    for char in characters:
        if not char.strip():
            continue
        patterns[char] = re.compile(
            r'(?<!\w)' + re.escape(char) + r'(?!\w)', re.IGNORECASE
        )

    for chunk in chunks:
        for char, pattern in patterns.items():
            seen_spans = set()
            for match in pattern.finditer(chunk):
                span = _sentence_bounds(chunk, match.start(), match.end())
                if span in seen_spans:
                    continue
                seen_spans.add(span)
                sentence = chunk[span[0]:span[1]].strip()
                if sentence:
                    character_contexts[char].append(sentence)
    return character_contexts


In [41]:
def aggregate_contexts(character_contexts):
    """Merge each character's list of contexts into one space-separated string."""
    return {
        name: ' '.join(snippets)
        for name, snippets in character_contexts.items()
    }


In [43]:
# Build the zero-shot classification pipeline used to label characters
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Candidate role labels handed to the classifier
roles = [
    "Protagonist", "Antagonist",
    "Supporting Character", "Love Interest",
    "Foil", "Comic Relief",
    "Symbolic Character", "Neutral Character",
]


In [44]:

def classify_roles(aggregated_contexts, roles, classifier, max_length=1024):
    """Assign one role label to every character from its aggregated context.

    Each context is cut to at most ``max_length`` characters and passed to
    ``classifier`` together with ``roles`` (single-label mode). The first entry
    of the returned ``'labels'`` list is stored as the character's role. If the
    classifier call fails, the error is printed and the role is "Unknown".

    Returns:
        dict mapping character name to its role label.
    """
    assignments = {}
    for name, text in aggregated_contexts.items():
        # Cap the input length (measured in characters) before classifying
        snippet = text if len(text) <= max_length else text[:max_length]

        try:
            prediction = classifier(snippet, roles, multi_label=False)
            best_label = prediction['labels'][0]  # labels[0] is taken as the top role
        except Exception as err:
            print(f"Error classifying {name}: {err}")
            best_label = "Unknown"
        assignments[name] = best_label
    return assignments


In [59]:

def process_book(file_path, output_path, characters):
    """Complete pipeline from loading book to classifying character roles.

    Steps: load the text, split it into chunks, collect the contexts that
    mention each character, join them per character, classify each character's
    role, and write the result as JSON.

    Uses the notebook-level ``roles`` list and ``classifier`` pipeline, so the
    cell defining them must have been run first.

    Args:
        file_path: Path of the book text file to read.
        output_path: Path of the JSON file to write. Its parent directory is
            created when missing.
        characters: Character names to look for in the book.

    Returns:
        dict mapping each character that was found to its predicted role (the
        same data that is written to ``output_path``).
    """
    # Load the book
    text = load_book(file_path)

    # Split into chunks
    chunks = split_into_chunks(text)

    # Extract mentions
    extracted_contexts = extract_mentions(chunks, characters)

    # Aggregate contexts
    aggregated_contexts = aggregate_contexts(extracted_contexts)
    # Report sizes only: printing the aggregated text itself floods the cell
    # output with a large part of the book.
    print(f"Collected contexts for {len(aggregated_contexts)} character(s)")

    # Classify roles
    character_roles = classify_roles(aggregated_contexts, roles, classifier)

    # Save to JSON (dirname is '' for a bare file name, which makedirs rejects)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(character_roles, f, indent=4)

    print(f"Character roles saved to {output_path}")
    return character_roles


In [46]:
def aggregate_contexts(character_contexts):
    """Join every character's contexts into a single space-separated string.

    NOTE(review): this cell re-defines ``aggregate_contexts`` with the same
    behavior as the earlier definition in this notebook; one of the two can
    be dropped.
    """
    return dict(
        (name, ' '.join(snippets))
        for name, snippets in character_contexts.items()
    )

In [51]:
def load_character_list(file_path):
    """Load character names from a YAML file containing book title and person entities.

    The file is parsed with ``yaml.safe_load`` and the ``person_entities``
    entry of the top-level mapping is returned. The notebook passes a ``.json``
    path here; that relies on PyYAML accepting the JSON document as YAML.

    Args:
        file_path: Path of the UTF-8 encoded file to read.

    Returns:
        The value stored under ``person_entities``, or an empty list when the
        file is empty, the key is missing, or its value is null/empty.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # safe_load yields None for an empty document; treat it as "no data".
        data = yaml.safe_load(file) or {}
    # `or []` also covers a key that is present but null.
    return data.get("person_entities") or []


In [53]:
# Define file paths based on your project structure
book_name = "Agatha Christie___The Secret Adversary"  # or read dynamically as needed
# File names of derived artifacts use underscores instead of spaces.
book_slug = book_name.replace(' ', '_')

# Each directory level is passed as its own component so os.path.join inserts
# the separator of the current platform (a literal "\\" inside a component is
# only a separator on Windows).
input_file = os.path.join("..", "data", "selected_100_books", f"{book_name}.txt")
character_file = os.path.join(
    "..", "data", "processed", "ner_results", f"{book_slug}_person_entities.json"
)
output_file = os.path.join("output", f"{book_slug}_context_roles.json")

In [61]:
import yaml
import re
# Create the folder that will receive the results if it is not there yet
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Read the character names from the character file
characters = load_character_list(character_file)

# Run the full pipeline for this book
process_book(input_file, output_file, characters)

# Optional: reload the saved roles and show them
with open(output_file, 'r', encoding='utf-8') as results_file:
    data = json.load(results_file)
print(json.dumps(data, indent=4))

Nomber of chunks 95


KeyboardInterrupt: 