from typing import Any, Dict, Iterator, Optional

from docling_core.transforms.chunker.base_code_chunker import CodeChunker
from docling_core.transforms.chunker.code_chunk_utils.utils import Language
from docling_core.transforms.chunker.hierarchical_chunker import (
    ChunkType,
    CodeChunk,
    CodeDocMeta,
)
from docling_core.transforms.chunker.language_code_chunkers import (
    CFunctionChunker,
    JavaFunctionChunker,
    JavaScriptFunctionChunker,
    PythonFunctionChunker,
    TypeScriptFunctionChunker,
)
from docling_core.types.doc.base import Size
from docling_core.types.doc.document import (
    CodeItem,
    DoclingDocument,
    DocumentOrigin,
    PageItem,
)
from docling_core.utils.legacy import _create_hash


class LanguageDetector:
    """Utility class for detecting programming languages from code content and file extensions."""

    @staticmethod
    def detect_from_extension(filename: Optional[str]) -> Optional[Language]:
        """Detect language from file extension."""

        if not filename:
            return None

        filename_lower = filename.lower()

        for language in Language:
            for ext in language.file_extensions():
                if filename_lower.endswith(ext):
                    return language
        return None

    @staticmethod
    def detect_from_content(code_text: str) -> Optional[Language]:
        """Detect language from code content using heuristics."""

        if not code_text:
            return None

        code_lower = code_text.lower().strip()

        if any(
            pattern in code_lower
            for pattern in [
                "def ",
                "import ",
                "from ",
                'if __name__ == "__main__"',
                "print(",
                "lambda ",
                "yield ",
                "async def",
            ]
        ) and not any(
            pattern in code_lower
            for pattern in ["public class", "private ", "protected ", "package "]
        ):
            return Language.PYTHON

        # Go-like code is detected only to bail out early: Go is not a supported
        # Language, and markers such as ":= " or "func main()" would otherwise be
        # misattributed to one of the languages checked below.
        if any(
            pattern in code_lower
            for pattern in [
                "package main",
                "func main()",
                'import "fmt"',
                'import "os"',
                "chan ",
                "interface{}",
                "go func",
                "defer ",
                ":= ",
            ]
        ) and not any(
            pattern in code_lower
            for pattern in [
                "public class",
                "import java.",
                "system.out.println",
                "extends ",
                "implements ",
            ]
        ):
            return None

        if any(
            pattern in code_lower
            for pattern in [
                "public class",
                "package ",
                "import java.",
                "public static void main",
                "extends ",
                "implements ",
                "string[]",
                "system.out.println",
            ]
        ) and not any(
            pattern in code_lower
            for pattern in ["package main", "func main()", "chan ", "interface{}"]
        ):
            return Language.JAVA

        if any(
            pattern in code_lower
            for pattern in [
                ": string",
                ": number",
                ": boolean",
                "interface ",
                "type ",
                "enum ",
                "public ",
                "private ",
                "protected ",
            ]
        ):
            return Language.TYPESCRIPT

        if any(
            pattern in code_lower
            for pattern in [
                "function ",
                "const ",
                "let ",
                "var ",
                "=>",
                "require(",
                "module.exports",
                "export ",
                "import ",
                "console.log",
            ]
        ):
            return Language.JAVASCRIPT

        if any(
            pattern in code_lower
            for pattern in [
                "#include",
                "int main(",
                "void ",
                "char ",
                "float ",
                "double ",
                "struct ",
                "#define",
                "printf(",
                "scanf(",
            ]
        ):
            return Language.C

        return None

    @staticmethod
    def detect_language(
        code_text: str, filename: Optional[str] = None
    ) -> Optional[Language]:
        """Detect language, preferring the file extension over content heuristics."""

        if filename:
            # A provided filename is treated as authoritative: if its extension
            # is not recognized, no content-based guess is attempted.
            lang = LanguageDetector.detect_from_extension(filename)
            if lang:
                return lang
            return None

        return LanguageDetector.detect_from_content(code_text)

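# Illustrative usage of the detector (hypothetical inputs; assumes ".py" is one
# of the registered extensions for Language.PYTHON):
#
#     LanguageDetector.detect_language("def f():\n    return 1", "utils.py")
#     # -> Language.PYTHON, via the ".py" extension
#     LanguageDetector.detect_language("public class Main { }")
#     # -> Language.JAVA, via the content heuristics above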

class CodeChunkingStrategyFactory:
    """Factory for creating language-specific code chunking strategies."""

    @staticmethod
    def create_chunker(language: Language, **kwargs: Any) -> CodeChunker:
        """Create a language-specific code chunker."""

        chunker_map = {
            Language.PYTHON: PythonFunctionChunker,
            Language.TYPESCRIPT: TypeScriptFunctionChunker,
            Language.JAVASCRIPT: JavaScriptFunctionChunker,
            Language.C: CFunctionChunker,
            Language.JAVA: JavaFunctionChunker,
        }

        chunker_class = chunker_map.get(language)
        if not chunker_class:
            raise ValueError(f"No chunker available for language: {language}")

        return chunker_class(**kwargs)


class DefaultCodeChunkingStrategy:
    """Default implementation of CodeChunkingStrategy that uses language detection and appropriate chunkers."""

    def __init__(self, **chunker_kwargs: Any):
        """Initialize the strategy with optional chunker parameters."""

        self.chunker_kwargs = chunker_kwargs
        self._chunker_cache: Dict[Language, CodeChunker] = {}

    def _get_chunker(self, language: Language) -> CodeChunker:
        """Get or create a chunker for the given language."""

        if language not in self._chunker_cache:
            self._chunker_cache[language] = CodeChunkingStrategyFactory.create_chunker(
                language, **self.chunker_kwargs
            )
        return self._chunker_cache[language]

    def chunk_code_item(
        self,
        code_text: str,
        language: Language,
        original_doc=None,
        original_item=None,
        **kwargs: Any,
    ) -> Iterator[CodeChunk]:
        """Chunk a single code item using the appropriate language chunker."""

        if not code_text.strip():
            return

        chunker = self._get_chunker(language)

        # The hash is always derived from the code text itself; filename and
        # mimetype fall back to generic values when no origin is available.
        binary_hash = _create_hash(code_text)
        if original_doc and original_doc.origin:
            filename = original_doc.origin.filename or "code_chunk"
            mimetype = original_doc.origin.mimetype or "text/plain"
        else:
            filename = "code_chunk"
            mimetype = "text/plain"

        if original_item and hasattr(original_item, "self_ref"):
            self_ref = original_item.self_ref
        else:
            self_ref = "#/texts/0"

        # Wrap the code in a minimal single-item DoclingDocument so the
        # language-specific chunker can process it like any other document.
        code_item = CodeItem(text=code_text, self_ref=self_ref, orig=code_text)

        doc = DoclingDocument(
            name=filename,
            texts=[code_item],
            pages={0: PageItem(page_no=0, size=Size(width=612.0, height=792.0))},
            origin=DocumentOrigin(
                filename=filename, mimetype=mimetype, binary_hash=binary_hash
            ),
        )

        yield from chunker.chunk(doc, **kwargs)

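# Illustrative usage (hypothetical snippet; assumes the language chunkers can be
# constructed with their default settings): chunkers are created lazily and
# cached per language, so repeated calls reuse the same chunker instance.
#
#     strategy = DefaultCodeChunkingStrategy()
#     for chunk in strategy.chunk_code_item("def f():\n    return 1", Language.PYTHON):
#         print(chunk.text)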

class NoOpCodeChunkingStrategy:
    """No-operation code chunking strategy that returns the original code as a single chunk."""

    def chunk_code_item(
        self,
        code_text: str,
        language: Language,
        original_doc=None,
        original_item=None,
        **kwargs: Any,
    ) -> Iterator[CodeChunk]:
        """Return the code as a single chunk without further processing."""

        if not code_text.strip():
            return

        meta = CodeDocMeta(
            chunk_type=ChunkType.CODE_BLOCK,
            start_line=1,
            end_line=len(code_text.splitlines()),
        )

        yield CodeChunk(text=code_text, meta=meta)
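Taken together, the two strategies are interchangeable at the call site. A minimal end-to-end sketch, assuming the classes above are in scope and the language chunkers run with their default settings (the sample source string is illustrative only):

```python
source = 'def greet(name):\n    return f"Hello, {name}!"\n'

language = LanguageDetector.detect_language(source, filename="greet.py")
assert language is Language.PYTHON  # assumes ".py" is registered for Python

# Function-level chunking via the language-specific chunker:
for chunk in DefaultCodeChunkingStrategy().chunk_code_item(source, language):
    print(chunk.text)

# Pass-through alternative: the whole source comes back as one chunk.
for chunk in NoOpCodeChunkingStrategy().chunk_code_item(source, language):
    print(chunk.text)
```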