@@ -17,8 +17,6 @@ use crate::{extract_spans, Diagnostic};
1717use once_cell:: sync:: Lazy ;
1818use proc_macro2:: { Span , TokenStream } ;
1919use quote:: quote;
20- use rustpython_codegen as codegen;
21- use rustpython_compiler:: compile;
2220use rustpython_compiler_core:: { CodeObject , FrozenModule , Mode } ;
2321use std:: {
2422 collections:: HashMap ,
@@ -51,15 +49,25 @@ struct CompilationSource {
5149 span : ( Span , Span ) ,
5250}
5351
52+ pub trait Compiler {
53+ fn compile (
54+ & self ,
55+ source : & str ,
56+ mode : Mode ,
57+ module_name : String ,
58+ ) -> Result < CodeObject , Box < dyn std:: error:: Error > > ;
59+ }
60+
5461impl CompilationSource {
5562 fn compile_string < D : std:: fmt:: Display , F : FnOnce ( ) -> D > (
5663 & self ,
5764 source : & str ,
5865 mode : Mode ,
5966 module_name : String ,
67+ compiler : & dyn Compiler ,
6068 origin : F ,
6169 ) -> Result < CodeObject , Diagnostic > {
62- compile ( source, mode, module_name, codegen :: CompileOpts :: default ( ) ) . map_err ( |err| {
70+ compiler . compile ( source, mode, module_name) . map_err ( |err| {
6371 Diagnostic :: spans_error (
6472 self . span ,
6573 format ! ( "Python compile error from {}: {}" , origin( ) , err) ,
@@ -71,21 +79,30 @@ impl CompilationSource {
7179 & self ,
7280 mode : Mode ,
7381 module_name : String ,
82+ compiler : & dyn Compiler ,
7483 ) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
7584 match & self . kind {
76- CompilationSourceKind :: Dir ( rel_path) => {
77- self . compile_dir ( & CARGO_MANIFEST_DIR . join ( rel_path) , String :: new ( ) , mode)
78- }
85+ CompilationSourceKind :: Dir ( rel_path) => self . compile_dir (
86+ & CARGO_MANIFEST_DIR . join ( rel_path) ,
87+ String :: new ( ) ,
88+ mode,
89+ compiler,
90+ ) ,
7991 _ => Ok ( hashmap ! {
8092 module_name. clone( ) => FrozenModule {
81- code: self . compile_single( mode, module_name) ?,
93+ code: self . compile_single( mode, module_name, compiler ) ?,
8294 package: false ,
8395 } ,
8496 } ) ,
8597 }
8698 }
8799
88- fn compile_single ( & self , mode : Mode , module_name : String ) -> Result < CodeObject , Diagnostic > {
100+ fn compile_single (
101+ & self ,
102+ mode : Mode ,
103+ module_name : String ,
104+ compiler : & dyn Compiler ,
105+ ) -> Result < CodeObject , Diagnostic > {
89106 match & self . kind {
90107 CompilationSourceKind :: File ( rel_path) => {
91108 let path = CARGO_MANIFEST_DIR . join ( rel_path) ;
@@ -95,10 +112,10 @@ impl CompilationSource {
95112 format ! ( "Error reading file {:?}: {}" , path, err) ,
96113 )
97114 } ) ?;
98- self . compile_string ( & source, mode, module_name, || rel_path. display ( ) )
115+ self . compile_string ( & source, mode, module_name, compiler , || rel_path. display ( ) )
99116 }
100117 CompilationSourceKind :: SourceCode ( code) => {
101- self . compile_string ( & textwrap:: dedent ( code) , mode, module_name, || {
118+ self . compile_string ( & textwrap:: dedent ( code) , mode, module_name, compiler , || {
102119 "string literal"
103120 } )
104121 }
@@ -113,6 +130,7 @@ impl CompilationSource {
113130 path : & Path ,
114131 parent : String ,
115132 mode : Mode ,
133+ compiler : & dyn Compiler ,
116134 ) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
117135 let mut code_map = HashMap :: new ( ) ;
118136 let paths = fs:: read_dir ( path)
@@ -144,6 +162,7 @@ impl CompilationSource {
144162 format ! ( "{}.{}" , parent, file_name)
145163 } ,
146164 mode,
165+ compiler,
147166 ) ?) ;
148167 } else if file_name. ends_with ( ".py" ) {
149168 let stem = path. file_stem ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
@@ -163,7 +182,7 @@ impl CompilationSource {
163182 format ! ( "Error reading file {:?}: {}" , path, err) ,
164183 )
165184 } ) ?;
166- self . compile_string ( & source, mode, module_name. clone ( ) , || {
185+ self . compile_string ( & source, mode, module_name. clone ( ) , compiler , || {
167186 path. strip_prefix ( & * CARGO_MANIFEST_DIR )
168187 . ok ( )
169188 . unwrap_or ( & path)
@@ -239,35 +258,28 @@ impl PyCompileInput {
239258 Some ( ident) => ident,
240259 None => continue ,
241260 } ;
261+ let check_str = || match & name_value. lit {
262+ Lit :: Str ( s) => Ok ( s) ,
263+ _ => Err ( err_span ! ( name_value. lit, "{ident} must be a string" ) ) ,
264+ } ;
242265 if ident == "mode" {
243- match & name_value. lit {
244- Lit :: Str ( s) => match s. value ( ) . parse ( ) {
245- Ok ( mode_val) => mode = Some ( mode_val) ,
246- Err ( e) => bail_span ! ( s, "{}" , e) ,
247- } ,
248- _ => bail_span ! ( name_value. lit, "mode must be a string" ) ,
266+ let s = check_str ( ) ?;
267+ match s. value ( ) . parse ( ) {
268+ Ok ( mode_val) => mode = Some ( mode_val) ,
269+ Err ( e) => bail_span ! ( s, "{}" , e) ,
249270 }
250271 } else if ident == "module_name" {
251- module_name = Some ( match & name_value. lit {
252- Lit :: Str ( s) => s. value ( ) ,
253- _ => bail_span ! ( name_value. lit, "module_name must be string" ) ,
254- } )
272+ module_name = Some ( check_str ( ) ?. value ( ) )
255273 } else if ident == "source" {
256274 assert_source_empty ( & source) ?;
257- let code = match & name_value. lit {
258- Lit :: Str ( s) => s. value ( ) ,
259- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
260- } ;
275+ let code = check_str ( ) ?. value ( ) ;
261276 source = Some ( CompilationSource {
262277 kind : CompilationSourceKind :: SourceCode ( code) ,
263278 span : extract_spans ( & name_value) . unwrap ( ) ,
264279 } ) ;
265280 } else if ident == "file" {
266281 assert_source_empty ( & source) ?;
267- let path = match & name_value. lit {
268- Lit :: Str ( s) => PathBuf :: from ( s. value ( ) ) ,
269- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
270- } ;
282+ let path = check_str ( ) ?. value ( ) . into ( ) ;
271283 source = Some ( CompilationSource {
272284 kind : CompilationSourceKind :: File ( path) ,
273285 span : extract_spans ( & name_value) . unwrap ( ) ,
@@ -278,19 +290,13 @@ impl PyCompileInput {
278290 }
279291
280292 assert_source_empty ( & source) ?;
281- let path = match & name_value. lit {
282- Lit :: Str ( s) => PathBuf :: from ( s. value ( ) ) ,
283- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
284- } ;
293+ let path = check_str ( ) ?. value ( ) . into ( ) ;
285294 source = Some ( CompilationSource {
286295 kind : CompilationSourceKind :: Dir ( path) ,
287296 span : extract_spans ( & name_value) . unwrap ( ) ,
288297 } ) ;
289298 } else if ident == "crate_name" {
290- let name = match & name_value. lit {
291- Lit :: Str ( s) => s. parse ( ) ?,
292- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
293- } ;
299+ let name = check_str ( ) ?. parse ( ) ?;
294300 crate_name = Some ( name) ;
295301 }
296302 }
@@ -351,12 +357,17 @@ struct PyCompileArgs {
351357 crate_name : syn:: Path ,
352358}
353359
354- pub fn impl_py_compile ( input : TokenStream ) -> Result < TokenStream , Diagnostic > {
360+ pub fn impl_py_compile (
361+ input : TokenStream ,
362+ compiler : & dyn Compiler ,
363+ ) -> Result < TokenStream , Diagnostic > {
355364 let input: PyCompileInput = parse2 ( input) ?;
356365 let args = input. parse ( false ) ?;
357366
358367 let crate_name = args. crate_name ;
359- let code = args. source . compile_single ( args. mode , args. module_name ) ?;
368+ let code = args
369+ . source
370+ . compile_single ( args. mode , args. module_name , compiler) ?;
360371
361372 let bytes = code. to_bytes ( ) ;
362373 let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
@@ -369,12 +380,15 @@ pub fn impl_py_compile(input: TokenStream) -> Result<TokenStream, Diagnostic> {
369380 Ok ( output)
370381}
371382
372- pub fn impl_py_freeze ( input : TokenStream ) -> Result < TokenStream , Diagnostic > {
383+ pub fn impl_py_freeze (
384+ input : TokenStream ,
385+ compiler : & dyn Compiler ,
386+ ) -> Result < TokenStream , Diagnostic > {
373387 let input: PyCompileInput = parse2 ( input) ?;
374388 let args = input. parse ( true ) ?;
375389
376390 let crate_name = args. crate_name ;
377- let code_map = args. source . compile ( args. mode , args. module_name ) ?;
391+ let code_map = args. source . compile ( args. mode , args. module_name , compiler ) ?;
378392
379393 let data =
380394 rustpython_compiler_core:: frozen_lib:: encode_lib ( code_map. iter ( ) . map ( |( k, v) | ( & * * k, v) ) ) ;
0 commit comments