<a href="https://colab.research.google.com/github/jessiechd/RAG_Model/blob/main/0211_semantic_chunking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
!pip install -U sentence-transformers --q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m87.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m69.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m39.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Semantic chunker (SentenceTransformer)

In [36]:
import nltk
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt


class TextChunker:
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v1'):
        """Initialize the TextChunker with a specified sentence transformer model."""
        self.model = SentenceTransformer(model_name)

    def process_file(self, file_path, context_window=1, percentile_threshold=95, min_chunk_size=3):
        """
        Process a text file and split it into semantically meaningful chunks.

        Args:
            file_path: Path to the text file
            context_window: Number of sentences to consider on either side for context
            percentile_threshold: Percentile threshold for identifying breakpoints
            min_chunk_size: Minimum number of sentences in a chunk

        Returns:
            list: Semantically coherent text chunks
        """
        # Process the text file
        sentences = self._load_text(file_path)
        contextualized = self._add_context(sentences, context_window)
        embeddings = self.model.encode(contextualized)

        # Create and refine chunks
        distances = self._calculate_distances(embeddings)
        breakpoints = self._identify_breakpoints(distances, percentile_threshold)
        initial_chunks = self._create_chunks(sentences, breakpoints)

        # Merge small chunks for better coherence
        chunk_embeddings = self.model.encode(initial_chunks)
        final_chunks = self._merge_small_chunks(initial_chunks, chunk_embeddings, min_chunk_size)

        return final_chunks

    def _load_text(self, file_path):
        """Load and tokenize text from a file."""
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        return sent_tokenize(text)

    def _add_context(self, sentences, window_size):
        """Combine sentences with their neighbors for better context."""
        contextualized = []
        for i in range(len(sentences)):
            start = max(0, i - window_size)
            end = min(len(sentences), i + window_size + 1)
            context = ' '.join(sentences[start:end])
            contextualized.append(context)
        return contextualized

    def _calculate_distances(self, embeddings):
        """Calculate cosine distances between consecutive embeddings."""
        distances = []
        for i in range(len(embeddings) - 1):
            similarity = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
            distance = 1 - similarity
            distances.append(distance)
        return distances

    def _identify_breakpoints(self, distances, threshold_percentile):
        """Find natural breaking points in the text based on semantic distances."""
        threshold = np.percentile(distances, threshold_percentile)
        return [i for i, dist in enumerate(distances) if dist > threshold]

    def _create_chunks(self, sentences, breakpoints):
        """Create initial text chunks based on identified breakpoints."""
        chunks = []
        start_idx = 0

        for breakpoint in breakpoints:
            chunk = ' '.join(sentences[start_idx:breakpoint + 1])
            chunks.append(chunk)
            start_idx = breakpoint + 1

        # Add the final chunk
        final_chunk = ' '.join(sentences[start_idx:])
        chunks.append(final_chunk)

        return chunks

    def _merge_small_chunks(self, chunks, embeddings, min_size):
        """Merge small chunks with their most similar neighbor."""
        final_chunks = [chunks[0]]
        merged_embeddings = [embeddings[0]]

        for i in range(1, len(chunks) - 1):
            current_chunk_size = len(chunks[i].split('. '))

            if current_chunk_size < min_size:
                # Calculate similarities
                prev_similarity = cosine_similarity([embeddings[i]], [merged_embeddings[-1]])[0][0]
                next_similarity = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]

                if prev_similarity > next_similarity:
                    # Merge with previous chunk
                    final_chunks[-1] = f"{final_chunks[-1]} {chunks[i]}"
                    merged_embeddings[-1] = (merged_embeddings[-1] + embeddings[i]) / 2
                else:
                    # Merge with next chunk
                    chunks[i + 1] = f"{chunks[i]} {chunks[i + 1]}"
                    embeddings[i + 1] = (embeddings[i] + embeddings[i + 1]) / 2
            else:
                final_chunks.append(chunks[i])
                merged_embeddings.append(embeddings[i])

        final_chunks.append(chunks[-1])
        return final_chunks

    def evaluate_coherence(self, chunks):
        coherence_scores = []
        embeddings = self.model.encode(chunks)
        for i in range(len(embeddings) - 1):
            score = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
            coherence_scores.append(score)
        return np.mean(coherence_scores)


In [18]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [37]:
"""Example usage of the TextChunker class."""
# Initialize the chunker
chunker = TextChunker()

# Process a text file
file_path = "/content/17_qwen1.md"


In [38]:
chunks = chunker.process_file(
  file_path,
  context_window=1,
  percentile_threshold=95,
  min_chunk_size=3
)

# Print results
print(f"Successfully split text into {len(chunks)} chunks")

for i in range(len(chunks)):
  print(f"Chunk {i+1}: {len(chunks[i].split('. '))} sentences")

coherence_score = chunker.evaluate_coherence(chunks)
print(f"Coherence Score: {coherence_score:.4f}")

Successfully split text into 10 chunks
Chunk 1: 1 sentences
Chunk 2: 22 sentences
Chunk 3: 134 sentences
Chunk 4: 14 sentences
Chunk 5: 3 sentences
Chunk 6: 5 sentences
Chunk 7: 26 sentences
Chunk 8: 11 sentences
Chunk 9: 3 sentences
Chunk 10: 10 sentences
Coherence Score: 0.4878


In [39]:
chunks = chunker.process_file(
  file_path,
  context_window=1,
  percentile_threshold=85,
  min_chunk_size=3
)

# Print results
print(f"Successfully split text into {len(chunks)} chunks")

for i in range(len(chunks)):
  print(f"Chunk {i+1}: {len(chunks[i].split('. '))} sentences")

coherence_score = chunker.evaluate_coherence(chunks)
print(f"Coherence Score: {coherence_score:.4f}")

Successfully split text into 22 chunks
Chunk 1: 1 sentences
Chunk 2: 11 sentences
Chunk 3: 5 sentences
Chunk 4: 6 sentences
Chunk 5: 111 sentences
Chunk 6: 13 sentences
Chunk 7: 8 sentences
Chunk 8: 8 sentences
Chunk 9: 4 sentences
Chunk 10: 4 sentences
Chunk 11: 3 sentences
Chunk 12: 7 sentences
Chunk 13: 8 sentences
Chunk 14: 5 sentences
Chunk 15: 5 sentences
Chunk 16: 6 sentences
Chunk 17: 3 sentences
Chunk 18: 5 sentences
Chunk 19: 3 sentences
Chunk 20: 3 sentences
Chunk 21: 8 sentences
Chunk 22: 2 sentences
Coherence Score: 0.4071


In [40]:
chunks = chunker.process_file(
  file_path,
  context_window=2,
  percentile_threshold=95,
  min_chunk_size=3
)


# Print results
print(f"Successfully split text into {len(chunks)} chunks")

for i in range(len(chunks)):
  print(f"Chunk {i+1}: {len(chunks[i].split('. '))} sentences")

coherence_score = chunker.evaluate_coherence(chunks)
print(f"Coherence Score: {coherence_score:.4f}")

Successfully split text into 9 chunks
Chunk 1: 14 sentences
Chunk 2: 8 sentences
Chunk 3: 38 sentences
Chunk 4: 96 sentences
Chunk 5: 12 sentences
Chunk 6: 41 sentences
Chunk 7: 5 sentences
Chunk 8: 4 sentences
Chunk 9: 11 sentences
Coherence Score: 0.4426


# Semantic Chunking with added ROUGE, QA score
- added ROUGE and QA-based retrieval success score
- added dynamic vs. fixed context windows (adaptive based on similarity)
- added K-means clustering for breakpoints instead of thresholds
- improved merging based on topic similarity
- added chunks overlaps

In [29]:
!pip install rouge --q

In [43]:
import nltk
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import rouge
from collections import defaultdict


class TextChunker2:
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v1'):
        self.model = SentenceTransformer(model_name)

    def process_file(self, file_path, dynamic_window=True, min_chunk_size=3, overlap=1, num_clusters=5):
        sentences = self._load_text(file_path)
        contextualized = self._add_dynamic_context(sentences) if dynamic_window else self._add_fixed_context(sentences)
        embeddings = self.model.encode(contextualized)

        # Identify breakpoints using KMeans clustering
        breakpoints = self._identify_breakpoints_clustering(embeddings, num_clusters)
        initial_chunks = self._create_chunks(sentences, breakpoints, overlap)

        # Merge small chunks using topic modeling and similarity
        chunk_embeddings = self.model.encode(initial_chunks)
        final_chunks = self._merge_small_chunks(initial_chunks, chunk_embeddings, min_chunk_size)

        return final_chunks

    def _load_text(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        return sent_tokenize(text)

    def _add_fixed_context(self, sentences, window_size=1):
        return [' '.join(sentences[max(0, i-window_size): min(len(sentences), i+window_size+1)]) for i in range(len(sentences))]

    def _add_dynamic_context(self, sentences):
        contextualized = []
        embeddings = self.model.encode(sentences)
        for i in range(len(sentences)):
            similarities = cosine_similarity([embeddings[i]], embeddings)[0]
            closest_indices = np.argsort(-similarities)[:3]  # Select 2 most relevant neighbors
            context = ' '.join(sentences[j] for j in sorted(closest_indices))
            contextualized.append(context)
        return contextualized

    def _identify_breakpoints_clustering(self, embeddings, num_clusters):
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10).fit(embeddings)
        labels = kmeans.labels_
        return [i for i in range(1, len(labels)) if labels[i] != labels[i-1]]

    def _create_chunks(self, sentences, breakpoints, overlap):
        chunks, start_idx = [], 0
        for breakpoint in breakpoints:
            end_idx = breakpoint + 1
            chunk = ' '.join(sentences[max(0, start_idx - overlap): end_idx])
            chunks.append(chunk)
            start_idx = end_idx
        chunks.append(' '.join(sentences[max(0, start_idx - overlap):]))
        return chunks

    def _merge_small_chunks(self, chunks, embeddings, min_size):
        final_chunks, merged_embeddings = [chunks[0]], [embeddings[0]]
        for i in range(1, len(chunks) - 1):
            if len(chunks[i].split('. ')) < min_size:
                prev_sim = cosine_similarity([embeddings[i]], [merged_embeddings[-1]])[0][0]
                next_sim = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
                if prev_sim > next_sim:
                    final_chunks[-1] += ' ' + chunks[i]
                    merged_embeddings[-1] = (merged_embeddings[-1] + embeddings[i]) / 2
                else:
                    chunks[i + 1] = chunks[i] + ' ' + chunks[i + 1]
                    embeddings[i + 1] = (embeddings[i] + embeddings[i + 1]) / 2
            else:
                final_chunks.append(chunks[i])
                merged_embeddings.append(embeddings[i])
        final_chunks.append(chunks[-1])
        return final_chunks

    def evaluate_coherence(self, chunks):
        coherence_scores = []
        embeddings = self.model.encode(chunks)
        for i in range(len(embeddings) - 1):
            score = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
            coherence_scores.append(score)
        return np.mean(coherence_scores)

    def evaluate_rouge(self, original_text, chunks):
        rouge_evaluator = rouge.Rouge()
        scores = [rouge_evaluator.get_scores(chunk, original_text)[0]['rouge-1']['f'] for chunk in chunks]
        return np.mean(scores)

    def evaluate_qa_performance(self, retrieval_system, test_questions):
        correct, total = 0, len(test_questions)
        for question, expected_answer in test_questions:
            retrieved_chunk = retrieval_system.retrieve(question)
            if expected_answer in retrieved_chunk:
                correct += 1
        return correct / total


In [42]:
with open(file_path, "r", encoding="utf-8") as f:
    text = f.read()

In [44]:
class SimpleRetrieval:
    def __init__(self, chunks):
        self.chunks = chunks
        self.embeddings = chunker.model.encode(chunks)

    def retrieve(self, query):
        query_embedding = chunker.model.encode([query])
        similarities = cosine_similarity(query_embedding, self.embeddings)[0]
        best_chunk = self.chunks[np.argmax(similarities)]
        return best_chunk

retrieval_system = SimpleRetrieval(chunks)

In [48]:
test_questions = [
    ("What is this text about?", "Family medicine training in Africa"),
    ("Who is involved in the training?", "Department of Family and Emergency Medicine"),
]

In [47]:
chunker2 = TextChunker2()
chunks2 = chunker2.process_file(
    file_path,
    dynamic_window=True,
    min_chunk_size=3,
    overlap=1,
    num_clusters=3)


# Print results
print(f"Successfully split text into {len(chunks2)} chunks")

for i in range(len(chunks2)):
  print(f"Chunk {i+1}: {len(chunks2[i].split('. '))} sentences")


Successfully split text into 61 chunks
Chunk 1: 4 sentences
Chunk 2: 3 sentences
Chunk 3: 6 sentences
Chunk 4: 5 sentences
Chunk 5: 4 sentences
Chunk 6: 13 sentences
Chunk 7: 3 sentences
Chunk 8: 6 sentences
Chunk 9: 5 sentences
Chunk 10: 6 sentences
Chunk 11: 8 sentences
Chunk 12: 4 sentences
Chunk 13: 3 sentences
Chunk 14: 3 sentences
Chunk 15: 3 sentences
Chunk 16: 3 sentences
Chunk 17: 5 sentences
Chunk 18: 5 sentences
Chunk 19: 5 sentences
Chunk 20: 4 sentences
Chunk 21: 6 sentences
Chunk 22: 4 sentences
Chunk 23: 3 sentences
Chunk 24: 4 sentences
Chunk 25: 4 sentences
Chunk 26: 4 sentences
Chunk 27: 3 sentences
Chunk 28: 8 sentences
Chunk 29: 3 sentences
Chunk 30: 11 sentences
Chunk 31: 43 sentences
Chunk 32: 6 sentences
Chunk 33: 5 sentences
Chunk 34: 7 sentences
Chunk 35: 3 sentences
Chunk 36: 4 sentences
Chunk 37: 5 sentences
Chunk 38: 4 sentences
Chunk 39: 6 sentences
Chunk 40: 4 sentences
Chunk 41: 6 sentences
Chunk 42: 3 sentences
Chunk 43: 4 sentences
Chunk 44: 4 sentences

ValueError: too many values to unpack (expected 2)

In [49]:
coherence_score = chunker2.evaluate_coherence(chunks2)
rouge_score = chunker2.evaluate_rouge(text, chunks2)
qa_accuracy = chunker2.evaluate_qa_performance(retrieval_system, test_questions)

print(f"Coherence Score: {coherence_score:.4f}")
print(f"ROUGE Score: {rouge_score:.4f}")
print(f"QA Accuracy: {qa_accuracy * 100:.2f}%")

Coherence Score: 0.6663
ROUGE Score: 0.0771
QA Accuracy: 0.00%


In [50]:
chunker2 = TextChunker2()
chunks2 = chunker2.process_file(
    file_path,
    dynamic_window=True,
    min_chunk_size=20,
    overlap=5,
    num_clusters=3)

# Print results
print(f"Successfully split text into {len(chunks2)} chunks")

for i in range(len(chunks2)):
  print(f"Chunk {i+1}: {len(chunks2[i].split('. '))} sentences")

Successfully split text into 21 chunks
Chunk 1: 9 sentences
Chunk 2: 36 sentences
Chunk 3: 20 sentences
Chunk 4: 62 sentences
Chunk 5: 21 sentences
Chunk 6: 43 sentences
Chunk 7: 27 sentences
Chunk 8: 62 sentences
Chunk 9: 34 sentences
Chunk 10: 51 sentences
Chunk 11: 23 sentences
Chunk 12: 22 sentences
Chunk 13: 26 sentences
Chunk 14: 35 sentences
Chunk 15: 26 sentences
Chunk 16: 39 sentences
Chunk 17: 23 sentences
Chunk 18: 35 sentences
Chunk 19: 36 sentences
Chunk 20: 34 sentences
Chunk 21: 8 sentences


In [52]:
coherence_score = chunker2.evaluate_coherence(chunks2)
rouge_score = chunker2.evaluate_rouge(text, chunks2)
qa_accuracy = chunker2.evaluate_qa_performance(retrieval_system, test_questions)

print(f"Coherence Score: {coherence_score:.4f}")
print(f"ROUGE Score: {rouge_score:.4f}")
print(f"QA Accuracy: {qa_accuracy * 100:.2f}%")

Coherence Score: 0.7323
ROUGE Score: 0.1941
QA Accuracy: 0.00%


In [53]:
chunker2 = TextChunker2()
chunks2 = chunker2.process_file(
    file_path,
    dynamic_window=True,
    min_chunk_size=40,
    overlap=5,
    num_clusters=3)

# Print results
print(f"Successfully split text into {len(chunks2)} chunks")

for i in range(len(chunks2)):
  print(f"Chunk {i+1}: {len(chunks2[i].split('. '))} sentences")

Successfully split text into 11 chunks
Chunk 1: 9 sentences
Chunk 2: 56 sentences
Chunk 3: 62 sentences
Chunk 4: 64 sentences
Chunk 5: 43 sentences
Chunk 6: 80 sentences
Chunk 7: 51 sentences
Chunk 8: 45 sentences
Chunk 9: 61 sentences
Chunk 10: 159 sentences
Chunk 11: 41 sentences


In [54]:
coherence_score = chunker2.evaluate_coherence(chunks2)
rouge_score = chunker2.evaluate_rouge(text, chunks2)
qa_accuracy = chunker2.evaluate_qa_performance(retrieval_system, test_questions)

print(f"Coherence Score: {coherence_score:.4f}")
print(f"ROUGE Score: {rouge_score:.4f}")
print(f"QA Accuracy: {qa_accuracy * 100:.2f}%")

Coherence Score: 0.6473
ROUGE Score: 0.2869
QA Accuracy: 0.00%
