In [None]:
# Standard library imports
import difflib
import re
import string
from collections import defaultdict
from types import SimpleNamespace

# Scientific computing and ML imports 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.dummy import DummyClassifier
from sklearn.metrics import (
    f1_score, 
    classification_report,
    balanced_accuracy_score,
    roc_auc_score,
    recall_score, 
    fbeta_score
)
from sklearn.model_selection import train_test_split
from scipy.stats import entropy

# NLP related imports
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
from faster_whisper import WhisperModel
from nltk.tokenize import sent_tokenize
import nltk
from cefrpy import CEFRAnalyzer
import contractions
from fast_langdetect import detect
from simalign import SentenceAligner
from typing import Optional

# Data handling and visualization
import pandas as pd
import matplotlib.pyplot as plt
import wandb

# Optional GPU acceleration package; uncomment to install it into the kernel environment.
# %pip install flash-attn

# Download the NLTK 'punkt_tab' data (this notebook later calls nltk's sent_tokenize).
nltk.download('punkt_tab')

# Seed the torch and numpy global RNGs for reproducibility.
torch.manual_seed(0)
np.random.seed(0)

# Notebook-wide compute device (GPU if available, otherwise CPU).
# No `global` statement is needed here: an assignment at module level is already global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

## Define Name & Type

In [5]:
# Base name shared by the media file and its subtitle files (without extension).
# It is empty here, so the derived file names below consist of the extension/suffix only.
NAME = ""
# Extension of the media file that is transcribed.
AUDIO_TYPE = ".mp4"

# Paths to audio and reference text files, all derived from NAME
AUDIO_FILE = NAME + AUDIO_TYPE
REFERENCE_FILE = f"{NAME}.srt"
# NOTE(review): FILTERED_SRT is not referenced in the visible part of the notebook.
FILTERED_SRT = f"{NAME}[Filtered].srt"
# Lowercase interjections, filler words and sound words.
# mark_excluded_words() matches df['word'] against this list and sets
# display=False, set_manually=True, process=True for every hit.
# A few entries occur twice (e.g. "huh-huh", "grmph"); that is harmless for membership tests.
EXCLUDED_WORDS = [
    "mm", "nah", "uh", "um", "whoa", "uhm", "uhmm", "ah", "mm-hmm", 
    "uh-huh", "uh-uh", "uh-uhm", "uh-uhmm", "uhm-h", "aw", "ugh", 
    "shh", "mmhmm", "huh", "hmm", "mmm", "oops", "oopsie", "uh-oh", 
    "whoops", "oof", "yup", "yep", "nope", "aha", "tsk", "ew", "phew", 
    "meh", "huh-uh", "huh-huh", "huh-uhm", "mhm", "oh", "hmm-m", 
    "er", "eh", "ahh", "yikes", "yawn", "ugh-ugh", "jeez", "duh", 
    "wow", "meh-meh", "uhhh", "ummm", "ugh-huh", "hmpf", "yawn-yawn", 
    "heh", "hmph", "eep", "gah", "uhp", "boo", "psst", "argh", "oi", 
    "ohh", "oh-ho", "whoa-whoa", "la", "laa", "ah-ha", "ha", "ha-ha", 
    "hahaha", "bah", "whew", "ehh", "huff", "uff", "sniff", "snort", 
    "gulp", "hic", "haah", "bleh", "blah", "bla", "mwaa", "uhuh", 
    "yah", "uhw", "eww", "ewww", "grr", "huh-huh", "haha", "shush", 
    "wha", "wham", "bam", "oooh", "aaah", "hrr", "uhhhhhh", "ummmmm", 
    "woah", "ughugh", "mm-mm", "uh-huh-huh", "erm", "grrr", "urr", 
    "yippie", "oops-a-daisy", "ouch", "eek", "zoinks", "woopsie", 
    "yeesh", "hm-mm", "uhhuh", "hrrmph", "bleugh", "rawr", "ick", 
    "whaa", "la-la", "meep", "pfft", "haaa", "ahhhhhh", "oii", "tsk-tsk", 
    "blub", "blurgh", "brr", "rrr", "oomph", "ohhhhhh", "hmmmmmm", 
    "ahhhhhhh", "guh", "ack", "zzzz", "hush", "hsh", "boo-hoo", "ho-hum", 
    "urrgh", "grumble", "murmur", "mutter", "uhhhhmm", "hah", "ah-ah", 
    "shoo", "la-la-la", "blah-blah", "tra-la", "lalala", "waah", "waaah", 
    "ooh-ooh", "uhh", "uhhhh", "erhm", "ermm", "urrggh", "aargh", 
    "hm-mm-mm", "uh-uh-uh", "uhm-uh", "hurmph", "grmph", "ha-umph", 
    "um-hum", "humph", "shhhhhh", "psssh", "whisper", "moan", "groan", 
    "ah-choo", "cough", "sneeze", "hick", "hiccup", "snore", "whaaat", 
    "doh", "hmh", "pfft-pfft", "chatter", "rumble", "buzz", "mumble", 
    "ooh-la-la", "ahem", "tut", "hrrmm", "grmph", "sigh", "gulp-gulp", 
    "oh-wow", "yeehaw", "oh-no", "ach", "achoo", "whoop", "zipp", "zzz"
]


## Create DF

In [6]:
def create_bracketless_lines(original_srt_file: str) -> list[str]:
    """
    Read an SRT file and strip bracketed annotations from its text lines.

    Content enclosed in (...), [...] or {...} is removed from subtitle text lines only.
    Timestamp lines (containing '-->'), pure cue-number lines and blank lines are
    returned unchanged, so the result has one entry per line of the file.

    The file is decoded as 'utf-8-sig' so that a leading byte-order mark (U+FEFF)
    is dropped. With plain 'utf-8' the first cue number would keep the mark, fail
    the isdigit() check and be treated as subtitle text.

    Args:
        original_srt_file: Path to the SRT file.

    Returns:
        The file's lines (line endings preserved) with bracketed content removed
        from the text lines.
    """
    bracket_pattern = re.compile(r'\[.*?\]|\(.*?\)|\{.*?\}')
    bracketless_lines = []
    with open(original_srt_file, "r", encoding="utf-8-sig") as f:
        for line in f:
            stripped = line.strip()
            # Only modify text lines (not timestamps, cue numbers or empty lines)
            if stripped and '-->' not in line and not stripped.isdigit():
                line = bracket_pattern.sub('', line)
            bracketless_lines.append(line)
    return bracketless_lines


def remove_brackets_in_text(line: str) -> str:
    """
    Return `line` with every (...), [...] and {...} group deleted.

    Matching is non-greedy, so each group ends at the first closing bracket of
    the same kind. Text outside the groups, including surrounding spaces, is kept.
    """
    bracket_groups = re.compile(r'\[.*?\]|\(.*?\)|\{.*?\}')
    return bracket_groups.sub('', line)


def clean_token(token: str) -> Optional[str]:
    """
    Normalise a single whitespace-delimited word.

    - Strips leading/trailing ASCII punctuation and typographic quotes
      (U+201C, U+201D, U+2018, U+2019)
    - Converts the result to lowercase
    - Returns None if nothing is left

    Characters inside the word (e.g. the apostrophe in "don't") are kept.
    """
    # string.punctuation already contains the ASCII quotes " and '. Curly quotes are
    # not part of it and have to be listed explicitly; they are written as escapes
    # so they cannot be silently turned into ASCII quotes by an editor.
    strip_chars = string.punctuation + '\u201c\u201d\u2018\u2019'
    token = token.strip(strip_chars).lower()
    return token if token else None


def read_srt_in_memory(srt_path: str) -> list[str]:
    """
    Read an SRT file into a list of lines, removing bracketed content from text lines.

    A line counts as text when it is not blank, does not contain '-->' (timestamp)
    and is not a pure number (cue index). Only text lines are passed through
    remove_brackets_in_text(); all other lines are returned unchanged.

    The file is decoded as 'utf-8-sig' so that a leading byte-order mark (U+FEFF)
    is dropped. With plain 'utf-8' the first cue number would keep the mark, fail
    the isdigit() check and later be tokenized as if it were subtitle text.

    Args:
        srt_path: Path to the SRT file.

    Returns:
        One string per line of the file, line endings preserved.
    """
    lines = []
    with open(srt_path, "r", encoding="utf-8-sig") as f:
        for line in f:
            stripped = line.strip()
            if stripped and '-->' not in line and not stripped.isdigit():
                # This is a "text" line -> remove brackets
                line = remove_brackets_in_text(line)
            lines.append(line)
    return lines


def custom_split_into_sentences(text: str) -> list[str]:
    """
    Split `text` into sentences while keeping sung passages together.

    The text is first cut at every musical note (♪). Whenever a note is followed
    two pieces later by another note, the two notes and the text between them are
    merged into one segment of the form "♪ Yeah, baby ♪". Every other non-empty
    piece becomes a segment of its own. Each segment is then passed through
    sent_tokenize, so a segment that contains further sentence boundaries is
    still split there.

    Returns:
        The stripped, non-empty sentences in their original order.
    """
    # The capture group keeps the ♪ separators as list elements of their own
    pieces = re.split(r'(♪)', text)
    total = len(pieces)

    segments = []
    cursor = 0
    while cursor < total:
        piece = pieces[cursor].strip()
        opens_pair = (
            piece == '♪'
            and cursor + 2 < total
            and pieces[cursor + 2].strip() == '♪'
        )
        if opens_pair:
            # e.g. ["♪", " Yeah, baby ", "♪"] -> one segment, three pieces consumed
            inner_text = pieces[cursor + 1].strip()
            segments.append(f'♪ {inner_text} ♪')
            cursor += 3
            continue
        # Unpaired piece: keep it unless it is empty
        if piece:
            segments.append(piece)
        cursor += 1

    # Apply regular sentence splitting inside every segment
    return [
        sentence
        for segment in segments
        for sentence in (candidate.strip() for candidate in sent_tokenize(segment))
        if sentence
    ]


def extract_tokens_with_sentences(srt_lines: list[str]) -> list[dict]:
    """
    Turn SRT lines into a flat list of cleaned tokens, each tagged with its sentence.

    All text lines (everything except timestamp lines, cue numbers and blank lines)
    are stripped of leading dashes and concatenated, so sentences that span several
    subtitle lines are rejoined before custom_split_into_sentences() segments the
    text at ♪ pairs and sentence boundaries.

    Returns:
        One dict per word that is non-empty after clean_token(), with the keys
        'token' (cleaned word), 'original_sentence' (sentence it came from) and
        'position' (1-based index among the sentence's whitespace-separated words;
        words that clean to nothing still consume a position).
    """
    def is_text_line(raw: str) -> bool:
        content = raw.strip()
        return '-->' not in raw and not content.isdigit() and bool(content)

    # Every text line contributes its content followed by a single space
    full_text = "".join(
        raw.strip().lstrip('-').strip() + " "
        for raw in srt_lines
        if is_text_line(raw)
    )

    tokens_with_sentences = []
    for sentence in custom_split_into_sentences(full_text):
        for position, word in enumerate(sentence.split(), start=1):
            cleaned = clean_token(word)
            if cleaned:
                tokens_with_sentences.append({
                    'token': cleaned,
                    'original_sentence': sentence,
                    'position': position
                })

    return tokens_with_sentences


# Load the subtitle file (bracketed annotations removed), then tokenize it
srt_lines_in_memory = read_srt_in_memory(srt_path=REFERENCE_FILE)
original_tokens = extract_tokens_with_sentences(srt_lines=srt_lines_in_memory)

## Audio Complexity

In [None]:
# Initialize the Whisper model
# Uses the "tiny" checkpoint with int8 computation on the CPU, one thread and one worker.
# NOTE(review): the device is hard-coded to "cpu" and does not follow the notebook-wide
# `device` variable — confirm this is intentional.
AudioModel = WhisperModel(
    "tiny",
    device="cpu",
    compute_type="int8",
    cpu_threads=1,
    num_workers=1
)

# Transcribe audio with word timestamps and probabilities
# word_timestamps=True is what makes `segment.words` available to the loop that follows.
# NOTE(review): `segments` is presumably a lazy generator that can be consumed only once
# (faster-whisper) — confirm before iterating it a second time.
segments, info = AudioModel.transcribe(
    AUDIO_FILE,
    word_timestamps=True,
    beam_size=1
)

def clean_word(token):
    """
    Normalise a transcribed word so it can be compared with the subtitle tokens.

    Curly apostrophes are converted to the ASCII apostrophe, surrounding whitespace,
    ASCII punctuation and curly double quotes are stripped, remaining spaces are
    removed and the result is lowercased.
    """
    # Typographic characters are written as escapes: as literal characters they can be
    # "straightened" into ASCII quotes by an editor, which breaks the string literals.
    token = token.replace('\u2019', "'").replace('\u2018', "'")
    # Whitespace is stripped together with punctuation so that a leading space
    # does not shield leading punctuation from being removed.
    token = token.strip(string.whitespace + string.punctuation + '\u201c\u201d')
    return token.replace(" ", "").lower()

def remove_apostrophes(word):
    """
    Lowercase `word` and drop ASCII and curly apostrophes (', U+2018, U+2019),
    so that spellings such as "name's" and "names" compare equal.
    """
    return re.sub("['\u2018\u2019]", "", word).lower()

# Flatten the transcription into one record per recognised word:
# cleaned text, Whisper's word probability and the word's start time
predicted_tokens = [
    {
        'token': clean_word(word_info.word),
        'probability': word_info.probability,
        'timestep': word_info.start
    }
    for segment in segments
    for word_info in segment.words
]

# Plain token sequences (subtitle vs. transcription) used for the alignment
original_sequence = [entry['token'] for entry in original_tokens]
predicted_sequence = [entry['token'] for entry in predicted_tokens]

# Perform sequence alignment using difflib.SequenceMatcher.
# autojunk=False switches off difflib's automatic junk heuristic, which otherwise treats
# items that make up more than 1% of a sequence of 200+ items as junk. In a transcript
# those are exactly the frequent words ("the", "you", ...), so leaving the heuristic on
# weakens the alignment. The price is a slower match on very long transcripts.
matcher = difflib.SequenceMatcher(None, original_sequence, predicted_sequence, autojunk=False)


def _aligned_row(idx_orig, pred=None):
    """
    Build the result row for one original token.

    With a matched prediction the audio complexity is 1 - probability and the
    prediction's timestep is used; without one the word counts as maximally
    complex (1.0) and has no timestep.
    """
    if pred is None:
        return {
            'word': original_tokens[idx_orig]['token'],
            'audio_complexity': 1.0,
            'timestep': None
        }
    return {
        'word': original_tokens[idx_orig]['token'],
        'audio_complexity': 1 - pred['probability'],
        'timestep': pred['timestep']
    }


aligned_results = []

for tag, i1, i2, j1, j2 in matcher.get_opcodes():
    if tag == 'equal':
        # Direct 1:1 mapping between original and prediction
        aligned_results.extend(
            _aligned_row(idx_orig, predicted_tokens[idx_pred])
            for idx_orig, idx_pred in zip(range(i1, i2), range(j1, j2))
        )

    elif tag == 'replace':
        orig_joined = " ".join(original_sequence[i1:i2]).lower()
        pred_joined = " ".join(predicted_sequence[j1:j2]).lower()

        matched = (
            # Case 1: a single predicted contraction expands to the original words
            # (e.g. predicted "he's" vs. original "he is")
            (j2 - j1 == 1 and contractions.fix(predicted_sequence[j1]) == orig_joined)
            # Case 2: equal once apostrophes are ignored (e.g. "name's" vs. "names")
            or remove_apostrophes(orig_joined) == remove_apostrophes(pred_joined)
        )
        # A matched group shares the first predicted word's probability and timestep;
        # anything else falls back to the "not recognised" row.
        pred = predicted_tokens[j1] if matched else None
        aligned_results.extend(_aligned_row(idx_orig, pred) for idx_orig in range(i1, i2))

    elif tag == 'delete':
        # Words present in the original but missing in the prediction
        aligned_results.extend(_aligned_row(idx_orig) for idx_orig in range(i1, i2))

    # tag == 'insert': additional words in the prediction produce no row

# Later cells combine df with original_tokens by position, so there must be
# exactly one row per original token.
assert len(aligned_results) == len(original_tokens), "alignment lost or duplicated tokens"

# Create DataFrame from results
df = pd.DataFrame(aligned_results)
print(df)

## Filter in/out

In [None]:
def detect_tokens_non_english(original_tokens, exception_words):
    """
    Flag tokens that are neither English nor Latin.

    1) The language of every distinct sentence is detected once and cached.
    2) Tokens of sentences detected as 'en' or 'la' are never flagged.
    3) In all other sentences each token is checked on its own. It is flagged
       (display=True, set_manually=True) only if it is non-empty, not contained
       in `exception_words` and itself detected as neither 'en' nor 'la'.
       Every flagged token is reported with print().

    Args:
        original_tokens: List of dicts with the keys 'token' and 'original_sentence'.
        exception_words: Container of tokens that must never be flagged.

    Returns:
        DataFrame with the columns 'word', 'display' and 'set_manually' and exactly
        one row per entry of `original_tokens`, in the same order (index-synchronized).
    """
    allowed_languages = ("en", "la")

    # Step 1: detect the language once per distinct sentence
    sentence_language_cache = {}
    for sentence in {item['original_sentence'] for item in original_tokens}:
        cleaned_sentence = sentence.replace("\n", "").strip()
        # None marks a sentence without detectable content
        sentence_language_cache[sentence] = (
            detect(cleaned_sentence, low_memory=False)["lang"] if cleaned_sentence else None
        )

    # Word-level detections are cached as well so a repeated word is detected only once
    word_language_cache = {}

    def is_foreign_word(token_text):
        if token_text not in word_language_cache:
            word_language_cache[token_text] = detect(token_text, low_memory=False)["lang"]
        return word_language_cache[token_text] not in allowed_languages

    # Step 2: only tokens of non-en/la sentences are checked individually
    token_info = []
    for item in original_tokens:
        flagged = False
        if sentence_language_cache.get(item["original_sentence"]) not in allowed_languages:
            token_text = item["token"].strip()
            flagged = (
                bool(token_text)
                and token_text not in exception_words
                and is_foreign_word(token_text)
            )
        if flagged:
            print(f"Token '{item['token']}' in sentence '{item['original_sentence']}' is not English/Latin.")
        token_info.append({
            "word": item["token"],
            "display": flagged,
            "set_manually": flagged
        })

    return pd.DataFrame(token_info)

def mark_non_english_in_df(df: pd.DataFrame, original_tokens: list[dict], exception_words=("i", "no")) -> pd.DataFrame:
    """
    Copy the non-English flags for `original_tokens` into `df`.

    detect_tokens_non_english() produces one row per token; its 'display' and
    'set_manually' columns overwrite the columns of the same name in `df`.
    Rows are matched by position, so `df` must contain exactly one row per entry
    of `original_tokens`, in the same order.

    Args:
        df: DataFrame to update (modified in place).
        original_tokens: Token dicts the rows of `df` were built from.
        exception_words: Tokens that are never flagged. The default is a tuple so
            that no mutable object is shared between calls.

    Returns:
        The same DataFrame object, for chaining.

    Raises:
        ValueError: If `df` and `original_tokens` differ in length.
    """
    # Create DataFrame with columns based on word-by-word check
    df_language = detect_tokens_non_english(original_tokens, exception_words)

    # A label-aligned assignment would silently fill NaN or drop rows on a mismatch
    if len(df_language) != len(df):
        raise ValueError(
            f"df has {len(df)} rows but original_tokens has {len(df_language)} entries; "
            "both must be index-synchronized"
        )

    # Assign by position so the result does not depend on df's index labels
    df["display"] = df_language["display"].to_numpy()
    df["set_manually"] = df_language["set_manually"].to_numpy()

    return df

def mark_notes_in_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Flag sung passages in the DataFrame.

    Every ♪ token and every word between an opening and a closing ♪ gets
    display=True, set_manually=True and process=False. The DataFrame is modified
    in place and returned; the flagged words are printed as a summary.
    """
    inside_notes = False
    flagged_words = []

    for idx, word in df["word"].items():
        is_note = word == "♪"
        if is_note or inside_notes:
            df.at[idx, 'display'] = True
            df.at[idx, 'set_manually'] = True
            df.at[idx, 'process'] = False
            flagged_words.append(word)
        if is_note:
            # Each note symbol either opens or closes a sung passage
            inside_notes = not inside_notes

    # Print summary of marked words
    if flagged_words:
        print("\nIncluded words:")
        print(", ".join(flagged_words))

    return df

def mark_excluded_words(df: pd.DataFrame, excluded_words=None) -> pd.DataFrame:
    """
    Mark filler/interjection words in the DataFrame.

    Every row whose 'word' is contained in `excluded_words` gets display=False,
    set_manually=True and process=True. The DataFrame is modified in place.

    Args:
        df: DataFrame with the columns 'word', 'display', 'set_manually', 'process'.
        excluded_words: Iterable of words to mark. Defaults to the module-level
            EXCLUDED_WORDS list.

    Returns:
        The same DataFrame object, for chaining.
    """
    if excluded_words is None:
        excluded_words = EXCLUDED_WORDS

    mask = df['word'].isin(list(excluded_words))
    df.loc[mask, ['display', 'set_manually', 'process']] = [False, True, True]

    total_excluded = mask.sum()
    print(f"\nExcluded {total_excluded} words based on EXCLUDED_WORDS list")

    return df

def mark_numbers_in_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Mark finite numbers greater than 13 in the DataFrame.

    Every row whose 'word' parses as a finite number above 13 gets display=True,
    set_manually=True and process=True. The DataFrame is modified in place and
    returned; the marked words are printed as a summary.
    """
    marked_words = []
    for idx, word in df['word'].items():
        try:
            num = float(word)
        except (TypeError, ValueError):
            # Not a number (or not even a string) -> nothing to mark
            continue
        # float() also accepts the words "inf", "infinity" and "nan". The upper bound
        # rejects the infinities, and every comparison with NaN is False, so only
        # real, finite numbers pass.
        if 13 < num < float('inf'):
            df.at[idx, 'display'] = True
            df.at[idx, 'set_manually'] = True
            df.at[idx, 'process'] = True
            marked_words.append(word)

    if marked_words:
        print("\nIncluded numbers:")
        print(", ".join(map(str, marked_words)))
    return df

# Reset the flag columns to their defaults before any rule is applied
df['display'] = None
df['set_manually'] = False
df['process'] = True

# Apply the marking rules in a fixed order. Each step updates df in place and
# returns it, so a later rule overrides an earlier one on the rows it touches.
exception_words = ["i", "no", "so"]
df = (
    mark_non_english_in_df(df, original_tokens, exception_words)
    .pipe(mark_notes_in_df)
    .pipe(mark_excluded_words)
    .pipe(mark_numbers_in_df)
)

print(df)

In [None]:
# Print every row and column of df as plain text to inspect the flags set by the filter step
print(df.to_string())

## Translation

In [None]:
translation_model_name = "Helsinki-NLP/opus-mt-en-de"
# The pipeline takes a device index (0 = first GPU, -1 = CPU). It is stored under its
# own name so that the notebook-wide torch `device` is not overwritten with an int:
# a later `.to(device)` would fail with -1 on CPU-only machines.
translation_device = 0 if torch.cuda.is_available() else -1
translator_pipeline = pipeline(
    "translation",
    model=translation_model_name,
    device=translation_device,
    max_length=512,
    early_stopping=True
)

# Word aligner used to map English tokens to words of the German translation.
# matching_methods="mai" requests three alignment variants; the code that consumes
# the aligner reads the "inter" result.
aligner = SentenceAligner(
    model="bert",
    token_type="bpe",
    matching_methods="mai",
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

def batch_translate_and_align(all_tokens, batch_size=32, translator=None, word_aligner=None):
    """
    Translate every distinct sentence once and map each English token to the
    German word(s) aligned with it.

    Args:
        all_tokens: List of dicts with the keys 'token', 'original_sentence' and
            'position' (1-based word position inside the sentence). As a side effect
            every dict receives an 'original_index' key holding its list position.
        batch_size: Number of distinct sentences passed to the translator per call.
        translator: Callable taking a list of strings and returning a list of dicts
            with a 'translation_text' key. Defaults to the module-level
            `translator_pipeline`.
        word_aligner: Object whose get_word_aligns(src_tokens, tgt_tokens) returns a
            dict with an 'inter' list of (src_idx, tgt_idx) pairs. Defaults to the
            module-level `aligner`.

    Returns:
        One dict per input token, in input order, with the keys 'original_index',
        'english_token', 'english_sentence', 'german_translation' and
        'german_full_sentence'. Tokens without an aligned target word are
        translated on their own; special tokens (♪) are passed through.
    """
    if translator is None:
        translator = translator_pipeline
    if word_aligner is None:
        word_aligner = aligner

    SPECIAL_TOKENS = {
        "♪": "♪"
    }
    strip_punctuation = str.maketrans('', '', string.punctuation)

    # Remember the original index for restoring order later
    for i, tok in enumerate(all_tokens):
        tok["original_index"] = i

    # Group tokens by their original sentence
    sentence_to_tokens = defaultdict(list)
    for token_data in all_tokens:
        sentence_to_tokens[token_data["original_sentence"]].append(token_data)

    unique_sentences = list(sentence_to_tokens.keys())

    # Translate the distinct sentences in batches
    sentence_to_de = {}
    for start in range(0, len(unique_sentences), batch_size):
        batch_sentences = unique_sentences[start:start + batch_size]
        for sentence_en, translation in zip(batch_sentences, translator(batch_sentences)):
            sentence_to_de[sentence_en] = translation["translation_text"]

    # Fallback translations of single words, cached so each word is translated once
    single_word_cache = {}

    def translate_single_word(word):
        if word not in single_word_cache:
            single_word_cache[word] = translator([word])[0]["translation_text"]
        return single_word_cache[word]

    results = []
    for sentence_en, token_data_list in sentence_to_tokens.items():
        german_sent = sentence_to_de[sentence_en]
        tgt_tokens = german_sent.split()

        # A sentence text that occurs several times contributes one token dict per
        # occurrence. The source sequence for the aligner must contain every word
        # position only once, otherwise the tokens are duplicated and the alignment
        # indices no longer match.
        position_to_token = {}
        for token_data in token_data_list:
            position_to_token.setdefault(token_data["position"], token_data["token"])
        ordered_positions = sorted(position_to_token)
        src_tokens = [position_to_token[pos] for pos in ordered_positions]
        src_index_of = {pos: idx for idx, pos in enumerate(ordered_positions)}

        # An empty translation cannot be aligned; all tokens then use the fallback
        if tgt_tokens:
            alignment_pairs = word_aligner.get_word_aligns(src_tokens, tgt_tokens)["inter"]
        else:
            alignment_pairs = []

        # Target indices per source index, ignoring pairs that point outside the target
        targets_by_source = defaultdict(list)
        for src_idx, tgt_idx in alignment_pairs:
            if 0 <= tgt_idx < len(tgt_tokens):
                targets_by_source[src_idx].append(tgt_idx)

        for token_data in token_data_list:
            token = token_data["token"]
            if token in SPECIAL_TOKENS:
                aligned_words_cleaned = SPECIAL_TOKENS[token]
            else:
                aligned_indices = targets_by_source.get(src_index_of[token_data["position"]])
                if not aligned_indices:
                    # If no aligned word, do single-word translation
                    aligned_words_cleaned = translate_single_word(token)
                else:
                    # Join all aligned target words, stripped of punctuation
                    aligned_words_cleaned = " ".join(
                        tgt_tokens[t_i].translate(strip_punctuation)
                        for t_i in aligned_indices
                    )

            results.append({
                "original_index": token_data["original_index"],
                "english_token": token,
                "english_sentence": sentence_en,
                "german_translation": aligned_words_cleaned,
                "german_full_sentence": german_sent
            })

    # Important: restore original order
    return sorted(results, key=lambda entry: entry["original_index"])

# Translate and align all subtitle tokens, one sentence per translator call
result = batch_translate_and_align(all_tokens=original_tokens, batch_size=1)

# The records come back in original_tokens order, so they line up with df by index
df_translations = pd.DataFrame(data=result)
df['translation'] = df_translations['german_translation']


In [None]:
# Print every row and column of df as plain text to inspect the added 'translation' column
print(df.to_string())

## Word Occurrence

In [None]:
# Running count of how often each word has been seen so far;
# updated by get_normalized_occurrence() on every call
word_counts = {}


def get_normalized_occurrence(word, max=15):
    """
    Register one more occurrence of `word` and return its occurrence score.

    The score grows linearly from 0.0 (first occurrence) to 1.0 (`max`-th
    occurrence) and stays at 1.0 afterwards. A small cap keeps the resolution
    high for the first few repetitions.

    Note: the parameter `max` shadows the built-in of the same name inside this
    function; the name is kept because it is part of the call signature.
    """
    seen = word_counts.get(word, 0) + 1
    word_counts[word] = seen
    return 1.0 if seen >= max else (seen - 1) / (max - 1)

# Score only the rows flagged for processing, in row order (the score depends on
# how often the word has appeared before); other rows are left untouched
process_mask = df['process']
df.loc[process_mask, 'word_occurrence'] = df.loc[process_mask, 'word'].apply(get_normalized_occurrence)

print(df)

## Word Complexity

In [None]:
# Analyzer instance used to look up the CEFR level of a word
analyzer = CEFRAnalyzer()

# Evenly spaced complexity scores for the six CEFR level names
cefr_levels = {
    'A1': 0.0,
    'A2': 0.2,
    'B1': 0.4,
    'B2': 0.6,
    'C1': 0.8,
    'C2': 1.0
}


def get_word_complexity(word):
    """
    Return the complexity score of `word` based on its average CEFR level.

    The level's name is mapped through `cefr_levels`. Words without a level,
    level names missing from the mapping and lookups that raise an exception
    all score 0.0; an exception is additionally reported with print().
    """
    score = 0.0
    try:
        level = analyzer.get_average_word_level_CEFR(word)
        if level is not None:
            score = cefr_levels.get(level.name, 0.0)
    except Exception as e:
        print(f"Error retrieving level for word '{word}': {e}")
    return score


# Score only the rows flagged for processing
process_mask = df['process']
df.loc[process_mask, 'word_complexity'] = df.loc[process_mask, 'word'].apply(get_word_complexity)

# Output updated DataFrame
print(df)

## Sentence Complexity

In [None]:
def calculate_sentence_complexities(original_tokens):
    """
    Score every sentence with the masked language model and attach the score to its words.

    Each distinct sentence is scored once. Its raw score is divided by the score
    of a fixed, long reference sentence, so the result is relative to that
    reference and can exceed 1.

    Relies on the module-level `tokenizer_mlm` and `model_mlm`.

    Returns:
        DataFrame with one row per entry of `original_tokens` (same order) and
        the columns 'word' and 'sentence_complexity'.
    """
    def calculate_entropy(sentence):
        """Entropy (scipy.stats.entropy) over the probabilities the model assigns to the sentence's own tokens."""
        encoded = tokenizer_mlm(sentence, return_tensors='pt')
        with torch.no_grad():
            logits = model_mlm(**encoded).logits
        probabilities = torch.softmax(logits, dim=-1)

        # Probability of each actual input token at its own position
        token_probs = [
            probabilities[0, pos, token_id].item()
            for pos, token_id in enumerate(encoded.input_ids[0])
        ]
        return entropy(token_probs)

    # Score of a complex reference sentence, used as the normalisation constant
    reference_sentence = "I used to believe that technology could save us from the climate crisis, that all the big brains in the world would come up with a silver bullet to stop carbon pollution, that a clever policy would help that technology spread, and our concern about the greenhouse gases heating the planet would be a thing of the past, and we wouldn't have to worry about the polar bears anymore."
    max_entropy = calculate_entropy(reference_sentence)

    # Normalised score per distinct sentence
    sentence_complexity_cache = {
        sentence: calculate_entropy(sentence) / max_entropy
        for sentence in {item['original_sentence'] for item in original_tokens}
    }

    # One row per token, looked up from the per-sentence cache
    return pd.DataFrame([
        {
            'word': item['token'],
            'sentence_complexity': sentence_complexity_cache[item['original_sentence']]
        }
        for item in original_tokens
    ])

# Initialize tokenizer and model for Masked Language Model
# Both are module-level names that calculate_sentence_complexities() and the
# word-importance step read; they must exist before those are called.
model_id = "answerdotai/ModernBERT-base"
tokenizer_mlm = AutoTokenizer.from_pretrained(model_id)
model_mlm = AutoModelForMaskedLM.from_pretrained(model_id)

# Calculate sentence complexities and add to DataFrame
# The helper returns one row per entry of original_tokens; the column is copied by
# index, which relies on df having one row per token in the same order.
df_sentence_complexity = calculate_sentence_complexities(original_tokens)
df["sentence_complexity"] = df_sentence_complexity["sentence_complexity"]
print(df)

## Word Importance

In [None]:
def compute_word_importance_single(sentence, original_token, position, tokenizer_mlm, model_mlm):
    """
    Compute the importance of one word as the model's failure to predict it from context.

    The word's sub-tokens are replaced by the mask token, and the importance is
    1 minus the product of the probabilities the model assigns to the original
    sub-tokens at the masked positions.

    The word is located through character offsets: `position` counts
    whitespace-separated words, whereas tokenizer word ids count the tokenizer's
    own pre-tokenized words (punctuation and contractions become separate words),
    so the two numberings diverge as soon as the sentence contains punctuation.

    Args:
        sentence: Original sentence (string)
        original_token: Cleaned token (e.g. "he's"); used to exclude punctuation
            attached to the word from masking. May be empty/None.
        position: 1-based position of the word among sentence.split()
        tokenizer_mlm: Hugging Face tokenizer for Masked LM. It must support
            return_offsets_mapping (fast tokenizers).
        model_mlm: Hugging Face model for Masked LM

    Returns:
        float: Importance score between 0 and 1; 0.0 if the word cannot be located.
    """
    # Character spans of the whitespace-separated words
    word_spans = [(m.start(), m.end()) for m in re.finditer(r'\S+', sentence)]
    if not 1 <= position <= len(word_spans):
        return 0.0
    span_start, span_end = word_spans[position - 1]

    # Narrow the span to the cleaned token so attached punctuation is not masked too.
    # Skipped if lowercasing changed the length, because offsets would no longer match.
    if original_token:
        word_text = sentence[span_start:span_end]
        lowered = word_text.lower()
        offset = lowered.find(original_token)
        if offset != -1 and len(lowered) == len(word_text):
            span_end = span_start + offset + len(original_token)
            span_start = span_start + offset

    encoding = tokenizer_mlm(
        sentence,
        return_tensors='pt',
        add_special_tokens=True,
        return_offsets_mapping=True
    )
    # The offsets are only needed for the lookup below and are not a model input
    offsets = encoding.pop("offset_mapping")[0].tolist()

    # Sub-tokens overlapping the word's span; zero-width offsets (special tokens) are skipped
    subtoken_indices = [
        i for i, (tok_start, tok_end) in enumerate(offsets)
        if tok_end > tok_start and tok_start < span_end and tok_end > span_start
    ]
    if not subtoken_indices:
        return 0.0

    original_ids = encoding["input_ids"]

    # Mask the word's sub-tokens in a copy of the model inputs
    masked_input = {k: v.clone() for k, v in encoding.items() if isinstance(v, torch.Tensor)}
    for idx in subtoken_indices:
        masked_input['input_ids'][0, idx] = tokenizer_mlm.mask_token_id

    with torch.no_grad():
        outputs = model_mlm(**masked_input)
    predictions = outputs.logits  # [batch_size, seq_len, vocab_size]

    # Product of the probabilities of the original sub-tokens at the masked positions
    prob_product = 1.0
    for idx in subtoken_indices:
        softmax_probs = torch.softmax(predictions[0, idx], dim=-1)
        prob_product *= softmax_probs[original_ids[0, idx]].item()

    return 1.0 - prob_product

# Calculate word importance for each token
word_importances = []
for i, item in enumerate(original_tokens):
    # Only rows flagged for processing are scored; df row i belongs to token i
    if df.loc[i, "process"]:
        importance = compute_word_importance_single(
            sentence=item["original_sentence"],
            original_token=item["token"],
            position=item["position"],
            tokenizer_mlm=tokenizer_mlm,
            model_mlm=model_mlm
        )
    else:
        # No score for rows with process=False
        importance = None

    word_importances.append(importance)

# Store the scores as a numeric column rounded to 10 decimals (NaN for skipped rows).
# Formatting the values as strings would turn the feature into text and the
# missing values into the literal string "nan".
df["word_importance"] = pd.Series(word_importances, index=df.index, dtype="float64").round(10)

print(df)

In [None]:
# Print every row and column of df as plain text to inspect all computed feature columns
print(df.to_string())

## Import Dataset

In [None]:
# CSV files that together form the labelled dataset
# NOTE(review): 'datatset2.csv' looks like a misspelling of 'dataset2.csv' —
# confirm against the actual file name on disk before changing it.
file_names = ['dataset1.csv', 'datatset2.csv', 'dataset3.csv']

# Load every file and stack the frames under a fresh 0..n-1 index.
# From here on `df` refers to the imported dataset, not the frame built above.
dfs = [pd.read_csv(file_name) for file_name in file_names]
df = pd.concat(dfs, ignore_index=True)
df

## Random NN as Comparison

In [None]:
# Prepare data
X = df[['audio_complexity', 'word_complexity', 'sentence_complexity', 
        'word_importance', 'word_occurrence']].values
y = df['display'].astype(int).values  # Convert to integer

# Share of negative labels, i.e. the accuracy reached by always predicting class 0
random_accuracy = 1 - df['display'].mean()
print(f"Random Accuracy: {random_accuracy * 100:.2f}%")

# Split into training and test sets
# NOTE(review): the split is not used by the baselines below, which are fitted and
# scored on the full data set — confirm this matches how the model is evaluated.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Evaluate baseline models
# 1. Stratified random baseline (predicts according to the class distribution)
dummy_clf = DummyClassifier(strategy='stratified', random_state=42)
dummy_clf.fit(X, y)
y_pred = dummy_clf.predict(X)
y_pred_proba = dummy_clf.predict_proba(X)

# 2. Most frequent class baseline (always predicts the majority class)
dummy_clf_freq = DummyClassifier(strategy='most_frequent', random_state=42)
dummy_clf_freq.fit(X, y)
y_pred_freq = dummy_clf_freq.predict(X)

# Calculate metrics
# zero_division=0 keeps the value scikit-learn would report anyway (0) when a class is
# never predicted, as happens with the constant baseline, but without the warning.
print("\nBaseline Performance Metrics:")
print("-" * 30)

# F1 Scores
baseline_f1 = f1_score(y, y_pred, zero_division=0)
baseline_f1_freq = f1_score(y, y_pred_freq, zero_division=0)
print(f"F1 Score (stratified): {baseline_f1:.4f}")
print(f"F1 Score (most frequent): {baseline_f1_freq:.4f}")

# F2 Scores (recall weighted higher than precision)
f2_score_stratified = fbeta_score(y, y_pred, beta=2, zero_division=0)
f2_score_most_frequent = fbeta_score(y, y_pred_freq, beta=2, zero_division=0)
print(f"F2 Score (stratified): {f2_score_stratified:.4f}")
print(f"F2 Score (most frequent): {f2_score_most_frequent:.4f}")

# Additional metrics (stratified baseline only)
bal_acc = balanced_accuracy_score(y, y_pred)
roc_auc = roc_auc_score(y, y_pred_proba[:, 1])
print(f"Balanced Accuracy: {bal_acc:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")

# Class distribution
print(f"\nClass distribution:")
print(np.bincount(y))

# Detailed classification reports
print("\nClassification Report (stratified):")
print(classification_report(y, y_pred, zero_division=0))

print("\nClassification Report (most frequent):")
print(classification_report(y, y_pred_freq, zero_division=0))

## Model

In [None]:
class BinaryClassifier(nn.Module):
    """
    Fully connected network for binary classification.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout; the final
    layer is a single Linear unit, so forward() returns one raw logit per
    sample (no sigmoid is applied here).

    Args:
        input_features: Number of input features per sample.
        hidden_layers: Sizes of the hidden layers, in order. If None, read
            from ``wandb.config.hidden_layers``.
        dropout_rate: Dropout probability after every hidden layer. If None,
            read from ``wandb.config.dropout_rate``.
    """
    def __init__(self, input_features, hidden_layers=None, dropout_rate=None):
        super().__init__()

        # Only fall back to the wandb configuration for values that were not
        # passed in, so the model can also be built without an active run.
        if hidden_layers is None:
            hidden_layers = wandb.config.hidden_layers
        if dropout_rate is None:
            dropout_rate = wandb.config.dropout_rate

        # Create dynamic layer list
        layers = []
        current_size = input_features

        # Build hidden layers dynamically
        for hidden_size in hidden_layers:
            layers.extend([
                nn.Linear(current_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ])
            current_size = hidden_size

        # Add final output layer (single logit)
        layers.append(nn.Linear(current_size, 1))

        # Create sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits produced by the layer stack for ``x``."""
        return self.model(x)

# Custom Loss Functions
class FocalLoss(nn.Module):
    """
    Focal loss for binary classification on probabilities.

    Per sample: alpha * (1 - pt)^gamma * BCE, where pt = exp(-BCE).
    The mean over all samples is returned. ``inputs`` are passed to
    F.binary_cross_entropy, so they must already be probabilities.
    """
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        # Flatten both tensors so they share one shape
        probs, labels = inputs.view(-1), targets.view(-1)

        per_sample_bce = F.binary_cross_entropy(probs, labels, reduction='none')
        # exp(-BCE) recovers the probability assigned to the true class
        prob_of_truth = torch.exp(-per_sample_bce)
        modulating = (1 - prob_of_truth) ** self.gamma
        return (self.alpha * modulating * per_sample_bce).mean()

class FocalLossWithSigmoid(nn.Module):
    """Focal loss for raw logits: applies a sigmoid, then delegates to FocalLoss."""
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.focal = FocalLoss(gamma=gamma, alpha=alpha)

    def forward(self, inputs, targets):
        # Convert logits to probabilities before handing them to FocalLoss
        probabilities = inputs.sigmoid()
        return self.focal(probabilities, targets)

class AsymmetricLoss(nn.Module):
    """
    Asymmetric focal-style loss on probabilities.

    Positive and negative samples get separate focusing exponents
    (``gamma_pos`` / ``gamma_neg``). Probabilities are clamped to
    [clip, 1 - clip] before the logarithm is taken.

    Args:
        gamma_neg: Focusing exponent for negative samples (target == 0).
        gamma_pos: Focusing exponent for positive samples (target == 1).
        clip: Clamp margin applied to the probabilities.
    """
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05):
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip

    def forward(self, inputs, targets):
        """
        Return the mean loss over all samples.

        Raises:
            ValueError: If ``inputs`` and ``targets`` hold a different number
                of elements.
        """
        # Flatten BOTH tensors. Reshaping only the targets to (B, 1) while the
        # inputs arrive as (B,) would broadcast to a (B, B) matrix and pair
        # every prediction with every label.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        if inputs.shape != targets.shape:
            raise ValueError(
                f"inputs and targets must have the same number of elements, "
                f"got {inputs.numel()} and {targets.numel()}"
            )

        inputs = torch.clamp(inputs, self.clip, 1 - self.clip)

        # Positive samples: non-positives are replaced by 1 so log(1) = 0
        pt_pos = torch.where(targets == 1, inputs, torch.ones_like(inputs))
        loss_pos = -torch.log(pt_pos) * torch.pow(1 - pt_pos, self.gamma_pos) * targets

        # Negative samples: non-negatives are replaced by 1 so log(1) = 0
        pt_neg = torch.where(targets == 0, 1 - inputs, torch.ones_like(inputs))
        loss_neg = -torch.log(pt_neg) * torch.pow(1 - pt_neg, self.gamma_neg) * (1 - targets)

        return torch.mean(loss_pos + loss_neg)

class BCEWithSigmoid(nn.Module):
    """
    Binary cross-entropy for raw logits.

    Uses nn.BCEWithLogitsLoss, which fuses the sigmoid and the cross-entropy
    into one numerically stable operation. A separate sigmoid followed by
    nn.BCELoss saturates for large |logit| (the probability rounds to exactly
    0 or 1 and the log term is clamped), which distorts the loss value.
    """
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets):
        # inputs are logits; the sigmoid is applied inside the fused loss
        return self.bce(inputs, targets)

class AsymmetricLossWithSigmoid(nn.Module):
    """Asymmetric loss for raw logits: applies a sigmoid, then delegates to AsymmetricLoss."""
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05):
        super().__init__()
        self.asl = AsymmetricLoss(clip=clip, gamma_pos=gamma_pos, gamma_neg=gamma_neg)

    def forward(self, inputs, targets):
        # Convert logits to probabilities before handing them to AsymmetricLoss
        probabilities = inputs.sigmoid()
        return self.asl(probabilities, targets)

class LLoss(nn.Module):
    """
    Logarithmic loss: mean of log(1 + (inputs - targets)^2).

    ``beta`` is stored on the module but is not used in the computation.
    """
    def __init__(self, beta=1):
        super(LLoss, self).__init__()
        self.beta = beta

    def forward(self, inputs, targets):
        """
        Return the mean loss over all samples.

        Raises:
            ValueError: If ``inputs`` and ``targets`` hold a different number
                of elements.
        """
        # Flatten BOTH tensors. Reshaping only the targets to (B, 1) while the
        # inputs arrive as (B,) would broadcast the difference to (B, B).
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        if inputs.shape != targets.shape:
            raise ValueError(
                f"inputs and targets must have the same number of elements, "
                f"got {inputs.numel()} and {targets.numel()}"
            )

        L = 1 + torch.pow(inputs - targets, 2)
        L = torch.log(L)
        return torch.mean(L)

class MLoss(nn.Module):
    """
    Bounded absolute-error loss: mean of 1 - exp(-beta * |inputs - targets|).

    Args:
        beta: Scale applied to the absolute error inside the exponential.
    """
    def __init__(self, beta=1):
        super(MLoss, self).__init__()
        self.beta = beta

    def forward(self, inputs, targets):
        """
        Return the mean loss over all samples.

        Raises:
            ValueError: If ``inputs`` and ``targets`` hold a different number
                of elements.
        """
        # Flatten BOTH tensors. Reshaping only the targets to (B, 1) while the
        # inputs arrive as (B,) would broadcast the difference to (B, B).
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        if inputs.shape != targets.shape:
            raise ValueError(
                f"inputs and targets must have the same number of elements, "
                f"got {inputs.numel()} and {targets.numel()}"
            )

        M = torch.abs(inputs - targets)
        M = 1 - torch.exp(-self.beta * M)
        return torch.mean(M)

# Loss Function Factory
def get_loss_function(loss_name):
    """
    Build the loss module selected by ``loss_name`` and move it to ``device``.

    Supported names: "BCE", "BCEWithLogits", "Focal", "Asymmetric", "L", "M".
    Hyperparameters are read from ``wandb.config.loss_params``. The
    "BCEWithLogits" variant weights positives by the negative/positive ratio
    of the global ``y_train``.

    NOTE(review): "L" and "M" are returned without a sigmoid wrapper, so they
    operate directly on whatever the model outputs - confirm this is intended.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
    """
    if loss_name == "BCE":
        return BCEWithSigmoid().to(device)

    if loss_name == "BCEWithLogits":
        # Ratio of negatives to positives in the training labels
        pos_weight = torch.tensor([(y_train == 0).sum() / (y_train == 1).sum()]).to(device)
        return nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device)

    if loss_name == "Focal":
        focal = FocalLossWithSigmoid(
            alpha=wandb.config.loss_params["focal_alpha"],
            gamma=wandb.config.loss_params["focal_gamma"],
        )
        return focal.to(device)

    if loss_name == "Asymmetric":
        asymmetric = AsymmetricLossWithSigmoid(
            gamma_neg=wandb.config.loss_params["asymmetric_gamma_neg"],
            gamma_pos=wandb.config.loss_params["asymmetric_gamma_pos"],
            clip=wandb.config.loss_params["asymmetric_clip"],
        )
        return asymmetric.to(device)

    if loss_name == "L":
        return LLoss(beta=wandb.config.loss_params["l_beta"]).to(device)

    if loss_name == "M":
        return MLoss(beta=wandb.config.loss_params["m_beta"]).to(device)

    raise ValueError(f"Unbekannte Loss-Funktion: {loss_name}")

def get_optimizer(model, optimizer_name):
    """
    Create the optimizer selected by ``optimizer_name`` for ``model``.

    Supported names: "AdamW", "Adam", "RMSprop". Learning rate and weight
    decay are read from ``wandb.config``.

    Raises:
        ValueError: If ``optimizer_name`` is not supported.
    """
    supported = (
        ("AdamW", torch.optim.AdamW),
        ("Adam", torch.optim.Adam),
        ("RMSprop", torch.optim.RMSprop),
    )
    for name, optimizer_cls in supported:
        if optimizer_name == name:
            return optimizer_cls(
                model.parameters(),
                lr=wandb.config.learning_rate,
                weight_decay=wandb.config.weight_decay,
            )

    raise ValueError(f"Optimizer {optimizer_name} nicht unterstützt")

def train_epoch(model, train_loader, criterion, optimizer):
    """
    Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network that returns one logit per sample.
        train_loader: Iterable of (features, labels) batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Tuple ``(mean_batch_loss, {"f1": ..., "f2": ...})``. The metrics are
        computed from the predictions collected during the pass, thresholding
        sigmoid(logit) at 0.5.
    """
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()

        # Flatten to (batch,). Unlike squeeze(), reshape(-1) keeps the batch
        # axis even when a batch holds a single sample, so the logits always
        # line up with y_batch and stay iterable below.
        logits = model(X_batch).reshape(-1)
        loss = criterion(logits, y_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        predictions.extend((torch.sigmoid(logits.detach()) > 0.5).cpu().numpy())
        true_labels.extend(y_batch.cpu().numpy())

    # Compute both metrics over the whole epoch
    f1 = f1_score(true_labels, predictions)
    f2 = fbeta_score(true_labels, predictions, beta=2.0)

    return total_loss / len(train_loader), {"f1": f1, "f2": f2}

def calculate_recall_pos(all_targets, all_preds):
    """Return the recall of the positive class (label 1)."""
    positive_recall = recall_score(y_true=all_targets, y_pred=all_preds, pos_label=1)
    return positive_recall

def evaluate(model, val_loader, criterion=None, is_final=False, is_test=False):
    """
    Evaluate ``model`` on ``val_loader``.

    Args:
        model: Network that returns one logit per sample.
        val_loader: Iterable of (features, labels) batches.
        criterion: Loss called as ``criterion(logits, labels)``. Required when
            ``is_final`` is False; ignored otherwise.
        is_final: If True, print F1, F2, balanced accuracy and ROC-AUC and
            return them instead of the loss.
        is_test: Only used when ``is_final`` is True. If False, the final
            metrics are also logged to wandb.

    Returns:
        ``(f1, f2, balanced_accuracy, roc_auc)`` when ``is_final`` is True,
        otherwise ``(mean_batch_loss, {"f1": ..., "f2": ...})``.

    Raises:
        ValueError: If ``is_final`` is False and no ``criterion`` was given.
    """
    if not is_final and criterion is None:
        raise ValueError("criterion is required when is_final=False")

    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            # Flatten to (batch,). Unlike squeeze(), reshape(-1) keeps the
            # batch axis even when a batch holds a single sample.
            logits = model(X_batch).reshape(-1)
            if not is_final:
                loss = criterion(logits, y_batch)
                total_loss += loss.item()

            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(y_batch.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    f1 = f1_score(all_targets, all_preds)
    f2 = fbeta_score(all_targets, all_preds, beta=2.0)

    if is_final:
        bacc = balanced_accuracy_score(all_targets, all_preds)
        # ROC-AUC is computed from the probabilities, not the thresholded labels
        auc = roc_auc_score(all_targets, all_probs)

        print('\nFinale Evaluierung:')
        print(f'F1-Score: {f1:.4f}')
        print(f'F2-Score: {f2:.4f}')
        print(f'Balanced Accuracy: {bacc:.4f}')
        print(f'ROC-AUC: {auc:.4f}')

        if not is_test:
            wandb.log({
                "final_f1": f1,
                "final_f2": f2,
                "final_balanced_accuracy": bacc,
                "final_roc_auc": auc
            })

        return f1, f2, bacc, auc

    avg_loss = total_loss / len(val_loader)
    return avg_loss, {"f1": f1, "f2": f2}


def train_model(model, train_loader, val_loader, epochs=None, patience=None):
    """
    Train ``model`` with early stopping on the validation F2 score.

    After every epoch the validation F2 is compared with the best value so
    far. On improvement the weights are written to 'best_model.pth' (and
    uploaded via wandb.save); otherwise the patience counter is increased.
    Training stops once the counter reaches ``patience``. Afterwards the best
    checkpoint is loaded, loss/F2 curves are plotted and a classification
    report on ``val_loader`` is printed.

    Args:
        model: Network to train.
        train_loader: Training batches.
        val_loader: Validation batches.
        epochs: Maximum number of epochs; defaults to ``wandb.config.epochs``.
        patience: Early-stopping patience; defaults to
            ``wandb.config.early_stopping_patience``.

    Returns:
        The model with the best checkpoint's weights loaded.
    """
    epochs = epochs if epochs is not None else wandb.config.epochs
    patience = patience if patience is not None else wandb.config.early_stopping_patience

    criterion = get_loss_function(wandb.config.loss_function)
    optimizer = get_optimizer(model, wandb.config.optimizer)

    # Start below every reachable F2 value so the first epoch always writes a
    # checkpoint. Starting at 0 would skip saving when F2 never rises above 0,
    # and the load below would then pick up a stale or missing file.
    best_val_f2 = float('-inf')
    patience_counter = 0

    train_losses = []
    val_losses = []
    train_f2s = []
    val_f2s = []

    for epoch in range(epochs):
        train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, is_final=False)

        val_f1 = val_metrics["f1"]
        val_f2 = val_metrics["f2"]

        print(f'Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f} | Val F2: {val_f2:.4f}')

        wandb.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_f1': val_f1,
            'val_f2': val_f2
        })

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_f2s.append(train_metrics["f2"])
        val_f2s.append(val_f2)

        # Checkpoint selection is based on the validation F2 score
        if val_f2 > best_val_f2:
            best_val_f2 = val_f2
            torch.save(model.state_dict(), 'best_model.pth')
            wandb.save('best_model.pth')
            print(f'Epoch {epoch:03d}: New best model saved with F2 = {val_f2:.4f}')
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

    # Load best model
    model.load_state_dict(torch.load('best_model.pth'))

    # Plot training curves
    fig, (ax_loss, ax_f2) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(train_losses, label='Train Loss')
    ax_loss.plot(val_losses, label='Val Loss')
    ax_loss.set_title('Loss Curves')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()

    # These curves show the F2 score (not recall), so label them as such
    ax_f2.plot(train_f2s, label='Train FBeta')
    ax_f2.plot(val_f2s, label='Val FBeta')
    ax_f2.set_title('F2 Scores')
    ax_f2.set_xlabel('Epoch')
    ax_f2.set_ylabel('F2 Score')
    ax_f2.legend()

    fig.tight_layout()
    plt.show()

    # Final evaluation of the best checkpoint on the validation batches
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # move batch to the device
            # reshape(-1) keeps the batch axis even for a single-sample batch
            logits = model(X_batch).reshape(-1)
            predictions.extend((torch.sigmoid(logits) > 0.5).cpu().numpy())
            true_labels.extend(y_batch.cpu().numpy())

    print("\nFinal Classification Report:")
    print(classification_report(true_labels, predictions))

    return model

def prepare_data():
    """
    Build the training and validation DataLoaders from the global ``df``.

    Performs a stratified 80/20 split, converts both parts to float tensors on
    ``device`` and publishes the results as globals: ``X_train``, ``y_train``,
    ``train_loader`` and ``val_loader``. When ``wandb.config.weighted_sampler``
    is set, training batches are drawn with inverse-class-frequency weights.
    """
    global X_train, y_train, train_loader, val_loader

    # Prepare features and labels
    X = df[['audio_complexity', 'word_complexity', 'sentence_complexity', 'word_importance', 'word_occurrence']].values
    y = df['display'].astype(int).values

    # Split into training and validation; stratify keeps the class ratio in both parts
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    # Convert to tensors (required by PyTorch) and move them to the device right away
    X_train = torch.FloatTensor(X_train).to(device)
    X_val = torch.FloatTensor(X_val).to(device)
    y_train = torch.FloatTensor(y_train).to(device)
    y_val = torch.FloatTensor(y_val).to(device)

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)

    # Weighted sampling for imbalanced classes
    if wandb.config.weighted_sampler:
        # Integer labels on the CPU, computed once. int64 (long) is used
        # because it is accepted as a tensor index by every PyTorch release.
        train_labels = y_train.long().cpu()

        # Samples per class, e.g. labels [0, 1, 1] give counts [1, 2]
        class_counts = np.bincount(train_labels.numpy())
        # Inverse frequency: the rarer class gets the larger weight
        weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
        # One weight per training sample, looked up by its class label
        samples_weights = weights[train_labels]

        sampler = WeightedRandomSampler(
            weights=samples_weights,
            num_samples=len(samples_weights),
            # Drawing with replacement lets the few minority samples be used several times
            replacement=True
        )

        # DataLoader with the weighted sampler
        train_loader = DataLoader(
            train_dataset,
            batch_size=wandb.config.batch_size,
            sampler=sampler,
            drop_last=True
        )
    else:
        # DataLoader without weighted sampling
        train_loader = DataLoader(
            train_dataset,
            batch_size=wandb.config.batch_size,
            shuffle=True,
            drop_last=True
        )

    # Validation DataLoader: no weighted sampling and no shuffling, it is only used for evaluation.
    # NOTE(review): drop_last=True leaves the final incomplete batch out of
    # every validation metric - confirm this is intended.
    val_loader = DataLoader(val_dataset, batch_size=wandb.config.batch_size, shuffle=False, drop_last=True)



def start_training():
    """
    Run one complete training run: start a wandb run, prepare the data, train
    the model, evaluate the best checkpoint on the validation loader and
    finish the wandb run.

    The ``config`` passed to wandb.init holds the default hyperparameters for
    a stand-alone run. The function takes no arguments, so it can also be
    passed as the ``function`` of ``wandb.agent``.
    """
    # Use the GPU when one is available
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Initialise wandb
    wandb.init(
        project="",
        config={
            "learning_rate": 0.004252027091067649,
            "architecture": "MLP",
            "dataset": "3boys",
            "epochs": 150,
            "batch_size": 16,
            "optimizer": "AdamW",  # Choices: 'AdamW', 'Adam', 'RMSprop'
            "weight_decay": 0.0344653149484944,
            "loss_function": "BCE",  # Choices: ['BCE', 'BCEWithLogits', 'Focal', 'Asymmetric', 'L', 'M']
            "loss_params": {
                "focal_alpha": 1,
                "focal_gamma": 2,
                "asymmetric_gamma_neg": 4,
                "asymmetric_gamma_pos": 1,
                "asymmetric_clip": 0.05,
                "l_beta": 1,
                "m_beta": 1
            },
            "weighted_sampler": True,
            "hidden_layers": [64, 128, 64],
            "dropout_rate": 0.14279874595232844,
            "early_stopping_patience": 20,
            "beta_score": 2.0
        }
    )

    # try/finally: always close the wandb run, even when training fails
    try:
        # Prepare data (sets the globals X_train, y_train, train_loader, val_loader)
        prepare_data()

        # Train model. The batches live on `device`, so the model has to be
        # moved there as well; otherwise the forward pass fails on a GPU.
        num_features = X_train.shape[1]
        model = BinaryClassifier(input_features=num_features).to(device)
        train_model(model, train_loader, val_loader)

        # Final evaluation of the best checkpoint
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
        model.eval()
        evaluate(model, val_loader, is_final=True, is_test=False)
    finally:
        wandb.finish()


# Run a single training run with the default configuration defined in start_training()
start_training()

## Hyperparameter Sweep (GPU Base Model)

In [None]:
# Bayesian hyperparameter search that maximises the validation F2 score
# ('val_f2') logged by the training loop.
sweep_config = dict(
    method='bayes',
    metric=dict(name='val_f2', goal='maximize'),
    parameters=dict(
        learning_rate=dict(min=0.0001, max=0.01),
        batch_size=dict(values=[8, 16, 32, 64]),
        optimizer=dict(values=['AdamW', 'Adam', 'RMSprop']),
        weight_decay=dict(min=0.001, max=0.1),
        # Not part of this sweep: 'Focal', 'Asymmetric', 'L', 'M'
        loss_function=dict(values=['BCE', 'BCEWithLogits']),
        weighted_sampler=dict(values=[True, False]),
        hidden_layers=dict(values=[
            [16, 8], [8, 16], [32, 16], [64, 32],
            [8, 4, 2], [16, 8, 4], [32, 16, 8], [64, 32, 16],
            [128, 64, 32], [256, 128, 64],
            [64, 128, 64], [32, 64, 32], [16, 32, 16], [8, 16, 8],
            [4, 8, 8, 4], [4, 8, 16, 8], [8, 16, 16, 8],
            [8, 16, 32, 16, 8],
        ]),
        dropout_rate=dict(min=0.1, max=0.8),
        # Fixed (not searched) settings
        epochs=dict(value=150),
        early_stopping_patience=dict(value=10),
    ),
)

# Register the sweep and run 150 trials; each trial calls start_training()
project_name = ""
sweep_id = wandb.sweep(sweep_config, project=project_name)
wandb.agent(sweep_id, function=start_training, count=150)

## Evaluation

In [None]:
# Evaluate the checkpoint stored in 'model.pth' on the validation loader.
# NOTE(review): this loads 'model.pth', while training writes
# 'best_model.pth' - confirm which checkpoint is meant.
model = BinaryClassifier(input_features=5).to(device)
model.load_state_dict(torch.load('model.pth', map_location=device))

model.eval()

# Prints F1, F2, balanced accuracy and ROC-AUC; is_test=True skips wandb logging
evaluate(model, val_loader, is_final=True, is_test=True)

# Collect thresholded predictions for the per-class report
predictions = []
true_labels = []
with torch.no_grad():
    for X_batch, y_batch in val_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        outputs = model(X_batch)
        # reshape(-1) instead of squeeze(): keeps a 1-D result even when the
        # batch holds a single sample, so extend() can always iterate over it
        predictions.extend((torch.sigmoid(outputs) > 0.5).reshape(-1).cpu().numpy())
        true_labels.extend(y_batch.cpu().numpy())

print("\nFinal Classification Report:")
print(classification_report(true_labels, predictions))



## Predict Display Flags

In [None]:
# Rebuild the network (5 input features), restore the weights saved in
# 'best_model.pth' and switch to inference mode.
model = BinaryClassifier(input_features=5)
model = model.to(device)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model = model.eval()

def predict_with_bias(df, bias=0.0, batch_size=64):
    """
    Predict a display flag for every row of ``df`` with the global ``model``.

    ``bias`` is added to the model output before the sigmoid, so a positive
    value yields more True flags and a negative value fewer.

    Args:
        df: Frame holding the five feature columns read below.
        bias: Offset added to the model output before the sigmoid.
        batch_size: Number of rows passed to the model at once.

    Returns:
        List with one boolean per row, or None if any error occurred (the
        error message is printed).
    """
    feature_columns = ['audio_complexity', 'word_complexity', 'sentence_complexity',
                       'word_importance', 'word_occurrence']
    flags = []

    try:
        feature_matrix = df[feature_columns].astype('float32').values
        feature_tensor = torch.tensor(feature_matrix, dtype=torch.float32).to(device)

        # Run the model chunk by chunk without tracking gradients
        with torch.no_grad():
            for offset in range(0, len(feature_tensor), batch_size):
                logits = model(feature_tensor[offset:offset + batch_size])

                # Shift the logits by the requested bias
                logits += bias

                decisions = torch.sigmoid(logits) > 0.5
                flags.extend(decisions.cpu().numpy().flatten())

        return flags

    except Exception as e:
        print(f"Fehler bei der Vorhersage: {e}")
        return None

# Predict the display flag only for rows that were not set by hand;
# manually set rows keep their label.
auto_rows = ~df['set_manually']
predicted_display = predict_with_bias(df.loc[auto_rows], bias=0.0)

# predict_with_bias returns None when it failed. Stop here instead of
# overwriting the existing flags with None.
if predicted_display is None:
    raise RuntimeError("predict_with_bias failed; the 'display' column was left unchanged")

df.loc[auto_rows, 'display'] = predicted_display

# To shift the decision, pass a non-zero bias above:
#   bias > 0 (e.g. 0.1)  -> more rows get display=True
#   bias < 0 (e.g. -0.1) -> more rows get display=False

print(f"Percentage displayed in df (no bias): {(df['display'].sum() / len(df)) * 100:.2f}%")


In [None]:
# Print every row and column of the frame as plain text
full_table_text = df.to_string()
print(full_table_text)

## Generating Subtitles

In [None]:
def remove_punctuation_not_between_letters(text: str) -> str:
    """
    Removes runs of the punctuation marks ? ; . , = - that are not enclosed
    by alphanumeric characters on both sides.

    Punctuation inside a word or number (e.g. "co-op", "3.5") is kept, while
    leading, trailing or free-standing punctuation is removed. Letters and
    digits are matched Unicode-aware, so a hyphen next to a non-ASCII letter
    such as "Ö" is kept as well.
    """
    # [^\W_] matches any Unicode letter or digit (\w without the underscore)
    return re.sub(r'(?<![^\W_])[?;.,=-]+|[?;.,=-]+(?![^\W_])', '', text)

def seconds_to_srt_timestamp(seconds: float) -> Optional[str]:
    """
    Converts time in seconds to SRT timestamp format "HH:MM:SS,mmm".

    Returns None when ``seconds`` is missing (NaN/None). Negative values are
    clamped to zero.
    """
    if pd.isna(seconds):
        return None

    # Round once on the total number of milliseconds. Rounding the fractional
    # part separately can yield 1000 ms (e.g. 1.9996 s -> "00:00:01,1000").
    total_millis = max(0, int(round(seconds * 1000)))

    hours, remainder = divmod(total_millis, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

def count_tokens(block_lines: list[str]) -> int:
    """
    Counts the words in a subtitle block for which clean_token() returns a
    truthy value.

    The first two lines of the block (number and timestamp) are skipped.
    """
    return sum(
        1
        for text_line in block_lines[2:]
        for word in text_line.split()
        if clean_token(word)
    )

def merge_and_clean_placeholders(words: list[str], placeholder: str) -> list[str]:
    """
    Collapses runs of consecutive placeholders into one and drops
    placeholders at the start and end of the sequence.
    """
    # Pass 1: keep a placeholder only if the previously kept token is not one
    collapsed = []
    for token in words:
        repeats_placeholder = (
            bool(collapsed)
            and token == placeholder
            and collapsed[-1] == placeholder
        )
        if not repeats_placeholder:
            collapsed.append(token)

    # Pass 2: narrow the window until neither end is a placeholder
    start, stop = 0, len(collapsed)
    while start < stop and collapsed[start] == placeholder:
        start += 1
    while stop > start and collapsed[stop - 1] == placeholder:
        stop -= 1
    return collapsed[start:stop]

def create_srt_file(
    srt_lines: list[str],
    df: pd.DataFrame,
    new_srt_file: str,
    original_timesteps: bool = False,
    languages: list[str] = ["en", "de"]
) -> None:
    """
    Writes a filtered SRT file to ``new_srt_file``.

    The input lines are grouped into subtitle blocks (a blank line ends a
    block). Each block is handed to _write_block together with the index of
    its first word in ``df``; that index advances by the block's token count.

    - original_timesteps: If True, keeps original timestamps;
                         otherwise shifts start time to first displayed word
    - languages: List with "en" (original) and/or "de" (translation).
                 If both included, writes two blocks.
    """
    placeholder = '•'
    display_flags = df["display"].tolist()
    next_row = 0  # index in df of the first word of the current block
    pending = []  # lines of the block collected so far

    with open(new_srt_file, 'w', encoding='utf-8') as outfile:
        for raw_line in srt_lines:
            if raw_line.strip():
                pending.append(raw_line)
                continue

            # Blank line: flush the collected block (if any), then copy the separator
            if pending:
                _write_block(pending, display_flags, next_row, df,
                             outfile, placeholder, languages, original_timesteps)
                next_row += count_tokens(pending)
                pending = []
            outfile.write('\n')

        # Input may end without a trailing blank line
        if pending:
            _write_block(pending, display_flags, next_row, df,
                         outfile, placeholder, languages, original_timesteps)

def _write_block(block_lines: list[str],
                display_flags: list[bool],
                start_idx: int,
                df: pd.DataFrame,
                outfile,
                placeholder: str,
                languages: list[str],
                original_timesteps: bool):
    """
    Writes one filtered subtitle block to ``outfile``, once per language:
      - original only (en)
      - translation only (de)
      - both (en + de)

    Only words whose display flag is set are kept. A run of skipped words
    between two kept words is replaced by a single ``placeholder``. A language
    whose filtered line ends up empty is not written.

    Timestamp either kept (original_timesteps=True)
    or set to first displayed word (original_timesteps=False)

    Args:
        block_lines: Lines of the block: number, timestamp, then text lines.
        display_flags: One flag per word of the whole file.
        start_idx: Index in ``display_flags`` / ``df`` of the block's first word.
        df: Frame providing the 'timestep' and 'translation' columns.
        outfile: Open text file to write to.
        placeholder: Marker inserted where words were skipped.
        languages: Languages to write ("en" and/or "de").
        original_timesteps: Whether to keep the original start time.
    """
    # A block needs at least its number line and its timestamp line
    if len(block_lines) < 2:
        return

    # 1) Extract number and timestamp
    subtitle_number = block_lines[0].strip()
    timestamp_line = block_lines[1].strip()
    text_lines = block_lines[2:]
    token_count = count_tokens(block_lines)

    # 2) Find the first displayed word (needed when original_timesteps=False)
    first_displayed_idx = None
    for current_idx in range(start_idx, start_idx + token_count):
        if current_idx < len(display_flags) and display_flags[current_idx]:
            first_displayed_idx = current_idx
            break

    # Only the end time of the original timestamp is reused
    try:
        _, original_end = timestamp_line.split(' --> ')
    except ValueError:
        print(f"Warnung: Ungültiges Timestamp-Format in Block {subtitle_number}.")
        original_end = "00:00:00,000"

    # Move the start time to the first displayed word, if requested
    if not original_timesteps and first_displayed_idx is not None:
        timestep = df.loc[first_displayed_idx, 'timestep']
        if not pd.isna(timestep):
            new_start_time = seconds_to_srt_timestamp(timestep)
            if new_start_time:
                timestamp_line = f"{new_start_time} --> {original_end}"

    # 3) Collect the kept tokens per language. A placeholder is inserted when
    #    at least one word was skipped since the last kept word.
    filtered_tokens_per_lang = {lang: [] for lang in languages}
    df_index = start_idx
    # All languages skip the same words, so a single flag is sufficient
    skip_occurred = False

    for line in text_lines:
        for w in line.split():
            # Tokens without a cleaned form (e.g. pure punctuation) have no
            # row in df: ignore them and do not advance the index
            if clean_token(w) is None:
                continue

            if df_index < len(display_flags) and display_flags[df_index]:
                # Mark the gap left by previously skipped words
                if skip_occurred:
                    for lang in languages:
                        filtered_tokens_per_lang[lang].append(placeholder)
                    skip_occurred = False

                # Keep the word
                if "en" in languages:
                    filtered_tokens_per_lang["en"].append(w)
                if "de" in languages:
                    translation = df.loc[df_index, "translation"]
                    # A missing translation (None/NaN) would crash the join
                    # below, so it is left out of the translated line
                    if not pd.isna(translation):
                        filtered_tokens_per_lang["de"].append(str(translation))
            else:
                # Word is left out: remember the gap
                skip_occurred = True
            df_index += 1

    # 4) Merge duplicate placeholders and strip stray punctuation, giving one
    #    text line per language
    final_line_per_lang = {}
    for lang in languages:
        merged = merge_and_clean_placeholders(filtered_tokens_per_lang[lang], placeholder)
        final_line_per_lang[lang] = remove_punctuation_not_between_letters(" ".join(merged))

    # 5) Output: one block per language that still has text
    for lang in languages:
        single_line = final_line_per_lang[lang]
        if single_line:
            outfile.write(f"{subtitle_number}\n")
            outfile.write(f"{timestamp_line}\n")
            outfile.write(f"{single_line}\n\n")


# Write the filtered subtitle file: only words flagged in df['display'] are
# kept, in both the original ("en") and the translated ("de") version, and
# each block starts at its first displayed word.
display_flags = df['display'].tolist()

create_srt_file(
    srt_lines_in_memory,
    df,
    FILTERED_SRT,
    original_timesteps=False,
    languages=["en", "de"],
)

print("Filtered SRT written to: {}".format(FILTERED_SRT))
