In [1]:
import requests
import PyPDF2
import json
import re
import nltk
import os
import numpy as np
import torch
import gradio as gr
from io import BytesIO
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import T5ForConditionalGeneration, TrainingArguments, Trainer

In [2]:
# Download NLTK tokenizer if not available
nltk.download('punkt_tab')
from nltk.tokenize import sent_tokenize

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/jbodrenko/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [3]:
# Model for Labeling the dataset with questions and answers
model_name = "tiiuae/Falcon3-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# Define file path
dataset_file = "directive_dataset.json"
directive_pdf_url = "https://eur-lex.europa.eu/legal-content/EN/TXT/PDF/?uri=CELEX:32018L1972"

# Fetch and parse the directive text
def fetch_directive_pdf(url):
    response = requests.get(url)
    if response.status_code != 200:
        raise Exception("Failed to fetch the directive PDF")
    
    pdf_file = BytesIO(response.content)
    reader = PyPDF2.PdfReader(pdf_file)
    text = "\n".join(page.extract_text() for page in reader.pages if page.extract_text())

    # Fix hyphenation and normalize spaces
    text = re.sub(r"(\w+)-\s+(\w+)", r"\1\2", text)  # Remove hyphenation
    text = re.sub(r"\s+", " ", text).strip()  # Normalize spaces
    text = re.sub(r'(\d{2}\.\d{2}\.\d{4})', r'[\1]', text)  # Wrap dates in square brackets
    text = re.sub(r'\bL\s+\d{3}/\d{2}\s+EN\b', r'[L 321/98 EN]', text)  # Wrap references

    return text

# Prepare dataset for Hugging Face tokenizers
def prepare_huggingface_dataset(text, tokenizer, max_length=500, min_length=10):
    sentences = sent_tokenize(text)
    sentence_list = []
    sent_lengths = []

    for sentence in sentences:
        tokenized_sentence = tokenizer(sentence, truncation=False, padding=False)
        sentence_length = len(tokenized_sentence['input_ids'])  # Token length

        if min_length < sentence_length <= max_length:
            sentence_list.append({"text": sentence})
            sent_lengths.append(sentence_length)
            
    sentence_lengths = {
        "avg_len": np.mean(sent_lengths),
        "max_len": max(sent_lengths),
        "min_len": min(sent_lengths), 
        "median_len": np.median(sent_lengths),
        "std_len": np.std(sent_lengths)
    }
    
    print(f"Median sentence length: {sentence_lengths['median_len']}\nAvg sentence length: {sentence_lengths['avg_len']}\nSentence length std: {sentence_lengths['std_len']}\nMax sentence length: {sentence_lengths['max_len']}\nMin sentence length: {sentence_lengths['min_len']}")

    return Dataset.from_list(sentence_list), sentence_lengths

# Load or create dataset
def load_or_create_dataset(tokenizer):
    if os.path.exists(dataset_file):
        print("Loading dataset from file...")
        with open(dataset_file, "r", encoding="utf-8") as f:
            data = json.load(f)
            dataset = Dataset.from_list(data[0])
            sentence_lengths = data[1]
            print(f"Median sentence length: {sentence_lengths['median_len']}\nAvg sentence length: {sentence_lengths['avg_len']}\nSentence length std: {sentence_lengths['std_len']}\nMax sentence length: {sentence_lengths['max_len']}\nMin sentence length: {sentence_lengths['min_len']}")
            
    else:
        print("Fetching and processing directive...")
        directive_text = fetch_directive_pdf(directive_pdf_url)
        dataset, sent_length = prepare_huggingface_dataset(directive_text, tokenizer)
        with open(dataset_file, "w", encoding="utf-8") as f:
            json.dump([dataset.to_list(),sent_length], f, indent=4, ensure_ascii=False)
    
    return dataset

In [5]:
# Load or create dataset and show basic statistics of snippet length (in tokens)
dataset = load_or_create_dataset(tokenizer)

Loading dataset from file...
Median sentence length: 56.0
Avg sentence length: 66.8248807975726
Sentence length std: 47.018184659819724
Max sentence length: 487
Min sentence length: 11


In [6]:
# Checking dataset structure
dataset

Dataset({
    features: ['text'],
    num_rows: 2307
})

In [7]:
def generate_qa(example, tokenizer, model):
    """
    Generates a structured question-answer pair from input text, ensuring proper extraction.
    """
    text = example['text']

    # Few-shot prompt for structured output
    prompt = (
        "Generate a meaningful question-answer pair from the following directive text.\n"
        f"Text: {text}\n"
        "Question:"
    )

    # Ensure padding token is correctly set
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token  

    # Tokenize input with proper padding and truncation
    inputs = tokenizer(
        prompt, 
        return_tensors="pt", 
        truncation=True, 
        padding="max_length", 
        max_length=256
    )

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'], # which tokens to ignore in input
            max_length=512,  # truncation length
            num_return_sequences=1, # produce single answer per input
            pad_token_id=tokenizer.eos_token_id # what padding token was used
        )

    # Decode and clean output text
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Improved regex to extract the first valid Q&A pair
    match = re.search(r'(?:Question|Q):\s*(.*?)\s*(?:Answer|A):\s*(.*)', output_text, re.DOTALL)

    # If result contains q-a pair in requested format
    if match:
        question = match.group(1).strip()
        answer = match.group(2).strip()

        # Clean up potential artifacts
        question = re.sub(r'^(question_\d+:|Solution:|\s*<\|assistant\|>\s*)', '', question, flags=re.IGNORECASE).strip()
        answer = re.sub(r'^(answer_\d+:)', '', answer, flags=re.IGNORECASE).strip()

        return {'question': question, 'answers': {'text': [answer]}}

    # Otherwise assume annotation result is invalid
    return {'question': None, 'answers': {'text': [None]}}


In [8]:
# Load or create subset of snippets for demo/testing purposes
subset_path = 'directive_subset.json'
model = AutoModelForCausalLM.from_pretrained(model_name) #model name is defined with the tokenizer before
def load_or_create_subset(dataset, subset_size, qa_generator, subset_path=subset_path, seed=None):
    '''Wrapper to reduce repetitive annotation work.'''
    
    subset_path = subset_path.replace('.json', f'_{subset_size}.json')
    
    if os.path.exists(subset_path):
        print(f"Loading subset from {subset_path}")
        with open(subset_path, "r", encoding="utf-8") as f:
            subset = Dataset.from_list(json.load(f))            
    else:
        print(f"Generating subset of {subset_size} snippets...")
        
        # Randomly selecting subset of text snippets from the dataset (uniform prob.)
        subset = dataset.shuffle(seed=seed).select(range(subset_size))
        
        # Generate a question-answer pair for each text snippet
        subset = subset.map(lambda example: generate_qa(example, tokenizer, model))
        
        # Save to file
        with open(subset_path, "w", encoding="utf-8") as f:
            json.dump(subset.to_list(), f, indent=4, ensure_ascii=False)
    
    return subset

In [9]:
qa_subset = load_or_create_subset(dataset=dataset, subset_size=50, qa_generator=generate_qa, seed=1)

Loading subset from directive_subset_50.json


In [10]:
# Dataset state after annotation
qa_subset

Dataset({
    features: ['text', 'question', 'answers'],
    num_rows: 50
})

In [11]:
# Keeping only snippets where annotation results are valid
valid_indices = [i for i,entry in enumerate(qa_subset) if entry['question'] is not None]
qa_subset = qa_subset.select(valid_indices)

In [12]:
# Dataset state after removing snippets with invalid annotations
qa_subset

Dataset({
    features: ['text', 'question', 'answers'],
    num_rows: 21
})

In [13]:
# Dummy Train-test split without shuffling (need to be improved)
valid_len = len(valid_indices)
train_indices = round(0.8*valid_len)
test_indices = train_indices
train_set = qa_subset.select(range(train_indices))
test_set = qa_subset.select(range(train_indices, valid_len))

In [14]:
# Train set summary
train_set

Dataset({
    features: ['text', 'question', 'answers'],
    num_rows: 17
})

In [15]:
# Test set summary
test_set

Dataset({
    features: ['text', 'question', 'answers'],
    num_rows: 4
})

In [16]:
# Viewing the dataset contents to identify potential issues (
print(test_set[0])
print(train_set[0])

{'text': 'That Annex should theref ore be deleted.', 'question': 'question for the text "That Annex should theref ore be deleted."?', 'answers': {'text': ['does that annex should theref ore be deleted?\nA:\n<|assistant|>\nDoes that Annex need to be deleted?']}}
{'text': 'National regulato ry author ities should theref ore fully reflect any opinion submitted by BEREC in their measures imposing any oblig ation on an under taking or other wise resolving the dispute in such cases.', 'question': 'question regarding the directive?', 'answers': {'text': ['is it a directive?\n<|assistant|>\nIs the directive mentioned in the text?']}}


In [17]:
# Tokenizer to convert text to number and add some attributes required by the qa model to be fine-tuned
model_checkpoint = "google/flan-t5-base"  # or "t5-small", "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function_no_context(examples):
    """
    Tokenizes question-answer pairs for training a generative model.
    The model is trained to generate answers from the given questions.
    """
    # Tokenize the questions as model input
    model_inputs = tokenizer(
        examples["question"],  
        max_length=512,
        truncation=True,
        padding="max_length"
    )

    # Extract answers - empty strings need to be dealt with
    answers_text = [ans["text"][0] if ans["text"] else "" for ans in examples["answers"]]
    
    # Tokenize answers as labels
    labels = tokenizer(
        answers_text, 
        max_length=128, 
        truncation=True,
        padding="max_length"
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [18]:
# Tokenize the annotated datasets
tokenized_train = train_set.map(preprocess_function_no_context, batched=True)
tokenized_test = test_set.map(preprocess_function_no_context, batched=True)

Map:   0%|          | 0/17 [00:00<?, ? examples/s]

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

In [19]:
# The model to fine-tune
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [20]:
# Defining training parameters
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)



In [21]:
# Setting-up the training wrapper
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,  # If you split it earlier
)

In [22]:
# Fine-tuning
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,No log,29.783703
2,No log,28.909153
3,No log,28.595795


TrainOutput(global_step=9, training_loss=25.898856268988716, metrics={'train_runtime': 116.1908, 'train_samples_per_second': 0.439, 'train_steps_per_second': 0.077, 'total_flos': 34922624974848.0, 'train_loss': 25.898856268988716, 'epoch': 3.0})

In [23]:
# Retrieving the trained model from the trainer
model = trainer.model
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [24]:
# Checking if the results make sense
def generate_answer(question):
    '''Function called by the Gradio interface in the demo.'''
    # Format input for T5 (T5 expects a 'question:' prefix)
    input_text = f"question: {question}"
    
    # Tokenize the user input
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    
    # Generate output
    outputs = model.generate(inputs.input_ids, max_length=100)
    
    # Decode and print answer
    generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return generated_answer

# Running gradio demo interface
textbox = gr.Textbox(label="Type your question here:", placeholder="What is the directive about?", lines=10)

gr.Interface(fn=generate_answer, inputs=textbox, outputs="text").launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


