11use  std:: collections:: HashMap ; 
2- use  std:: hash:: { Hash ,  Hasher } ; 
32use  std:: sync:: Arc ; 
43
54use  crate :: auth:: { AuthManager ,  Permission ,  ResourceType } ; 
@@ -19,20 +18,12 @@ use pgwire::api::stmt::QueryParser;
1918use  pgwire:: api:: stmt:: StoredStatement ; 
2019use  pgwire:: api:: { ClientInfo ,  PgWireServerHandlers ,  Type } ; 
2120use  pgwire:: error:: { PgWireError ,  PgWireResult } ; 
22- use  std:: sync:: atomic:: { AtomicU64 ,  Ordering } ; 
23- use  std:: time:: { Duration ,  Instant } ; 
24- use  tokio:: sync:: { Mutex ,  RwLock } ; 
21+ use  pgwire:: messages:: response:: TransactionStatus ; 
22+ use  tokio:: sync:: Mutex ; 
2523
2624use  arrow_pg:: datatypes:: df; 
2725use  arrow_pg:: datatypes:: { arrow_schema_to_pg_fields,  into_pg_type} ; 
2826
29- #[ derive( Debug ,  Clone ,  Copy ,  PartialEq ) ]  
30- pub  enum  TransactionState  { 
31-     None , 
32-     Active , 
33-     Failed , 
34- } 
35- 
3627/// Simple startup handler that does no authentication 
3728/// For production, use DfAuthSource with proper pgwire authentication handlers 
3829pub  struct  SimpleStartupHandler ; 
@@ -66,26 +57,12 @@ impl PgWireServerHandlers for HandlerFactory {
6657    } 
6758} 
6859
69- /// Per-connection transaction state storage 
70- /// We use a hash of both PID and secret key as the connection identifier for better uniqueness 
71- pub  type  ConnectionId  = u64 ; 
72- 
73- #[ derive( Debug ,  Clone ) ]  
74- struct  ConnectionState  { 
75-     transaction_state :  TransactionState , 
76-     last_activity :  Instant , 
77- } 
78- 
79- type  ConnectionStates  = Arc < RwLock < HashMap < ConnectionId ,  ConnectionState > > > ; 
80- 
8160/// The pgwire handler backed by a datafusion `SessionContext` 
8261pub  struct  DfSessionService  { 
8362    session_context :  Arc < SessionContext > , 
8463    parser :  Arc < Parser > , 
8564    timezone :  Arc < Mutex < String > > , 
86-     connection_states :  ConnectionStates , 
8765    auth_manager :  Arc < AuthManager > , 
88-     cleanup_counter :  AtomicU64 , 
8966} 
9067
9168impl  DfSessionService  { 
@@ -100,57 +77,10 @@ impl DfSessionService {
10077            session_context, 
10178            parser, 
10279            timezone :  Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) , 
103-             connection_states :  Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) , 
10480            auth_manager, 
105-             cleanup_counter :  AtomicU64 :: new ( 0 ) , 
106-         } 
107-     } 
108- 
109-     async  fn  get_transaction_state ( & self ,  client_id :  ConnectionId )  -> TransactionState  { 
110-         self . connection_states 
111-             . read ( ) 
112-             . await 
113-             . get ( & client_id) 
114-             . map ( |s| s. transaction_state ) 
115-             . unwrap_or ( TransactionState :: None ) 
116-     } 
117- 
118-     async  fn  update_transaction_state ( & self ,  client_id :  ConnectionId ,  new_state :  TransactionState )  { 
119-         let  mut  states = self . connection_states . write ( ) . await ; 
120- 
121-         // Update or insert state using entry API 
122-         states
123-             . entry ( client_id) 
124-             . and_modify ( |s| { 
125-                 s. transaction_state  = new_state; 
126-                 s. last_activity  = Instant :: now ( ) ; 
127-             } ) 
128-             . or_insert ( ConnectionState  { 
129-                 transaction_state :  new_state, 
130-                 last_activity :  Instant :: now ( ) , 
131-             } ) ; 
132- 
133-         // Inline cleanup every 100 operations 
134-         if  self . cleanup_counter . fetch_add ( 1 ,  Ordering :: Relaxed )  % 100  == 0  { 
135-             let  cutoff = Instant :: now ( )  - Duration :: from_secs ( 3600 ) ; 
136-             states. retain ( |_,  state| state. last_activity  > cutoff) ; 
13781        } 
13882    } 
13983
140-     fn  get_client_id < C :  ClientInfo > ( client :  & C )  -> ConnectionId  { 
141-         // Use a hash of PID, secret key, and socket address for better uniqueness 
142-         let  ( pid,  secret)  = client. pid_and_secret_key ( ) ; 
143-         let  socket_addr = client. socket_addr ( ) ; 
144- 
145-         // Create a hash of all identifying values 
146-         let  mut  hasher = std:: collections:: hash_map:: DefaultHasher :: new ( ) ; 
147-         pid. hash ( & mut  hasher) ; 
148-         secret. hash ( & mut  hasher) ; 
149-         socket_addr. hash ( & mut  hasher) ; 
150- 
151-         hasher. finish ( ) 
152-     } 
153- 
15484    /// Check if the current user has permission to execute a query 
15585     async  fn  check_query_permission < C > ( & self ,  client :  & C ,  query :  & str )  -> PgWireResult < ( ) > 
15686    where 
@@ -290,24 +220,15 @@ impl DfSessionService {
290220    where 
291221        C :  ClientInfo , 
292222    { 
293-         let  client_id = Self :: get_client_id ( client) ; 
294- 
295223        // Transaction handling based on pgwire example: 
296224        // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57 
297225        match  query_lower. trim ( )  { 
298226            "begin"  | "begin transaction"  | "begin work"  | "start transaction"  => { 
299-                 match  self . get_transaction_state ( client_id) . await  { 
300-                     TransactionState :: None  => { 
301-                         self . update_transaction_state ( client_id,  TransactionState :: Active ) 
302-                             . await ; 
303-                         Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) ) 
304-                     } 
305-                     TransactionState :: Active  => { 
306-                         // Already in transaction, PostgreSQL allows this but issues a warning 
307-                         // For simplicity, we'll just return BEGIN again 
227+                 match  client. transaction_status ( )  { 
228+                     TransactionStatus :: Idle  | TransactionStatus :: Transaction  => { 
308229                        Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) ) 
309230                    } 
310-                     TransactionState :: Failed  => { 
231+                     TransactionStatus :: Error  => { 
311232                        // Can't start new transaction from failed state 
312233                        Err ( PgWireError :: UserError ( Box :: new ( 
313234                            pgwire:: error:: ErrorInfo :: new ( 
@@ -320,27 +241,16 @@ impl DfSessionService {
320241                } 
321242            } 
322243            "commit"  | "commit transaction"  | "commit work"  | "end"  | "end transaction"  => { 
323-                 match  self . get_transaction_state ( client_id) . await  { 
324-                     TransactionState :: Active  => { 
325-                         self . update_transaction_state ( client_id,  TransactionState :: None ) 
326-                             . await ; 
327-                         Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) ) 
328-                     } 
329-                     TransactionState :: None  => { 
330-                         // PostgreSQL allows COMMIT outside transaction with warning 
244+                 match  client. transaction_status ( )  { 
245+                     TransactionStatus :: Idle  | TransactionStatus :: Transaction  => { 
331246                        Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) ) 
332247                    } 
333-                     TransactionState :: Failed  => { 
334-                         // COMMIT in failed transaction is treated as ROLLBACK 
335-                         self . update_transaction_state ( client_id,  TransactionState :: None ) 
336-                             . await ; 
248+                     TransactionStatus :: Error  => { 
337249                        Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) ) 
338250                    } 
339251                } 
340252            } 
341253            "rollback"  | "rollback transaction"  | "rollback work"  | "abort"  => { 
342-                 self . update_transaction_state ( client_id,  TransactionState :: None ) 
343-                     . await ; 
344254                Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) ) 
345255            } 
346256            _ => Ok ( None ) , 
@@ -399,7 +309,7 @@ impl SimpleQueryHandler for DfSessionService {
399309        C :  ClientInfo  + Unpin  + Send  + Sync , 
400310    { 
401311        let  query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ; 
402-         log:: debug!( "Received query: {}"  ,  query ) ;  // Log the query for debugging 
312+         log:: debug!( "Received query: {query}"  ) ;  // Log the query for debugging 
403313
404314        // Check permissions for the query (skip for SET, transaction, and SHOW statements) 
405315        if  !query_lower. starts_with ( "set" ) 
@@ -429,9 +339,9 @@ impl SimpleQueryHandler for DfSessionService {
429339            return  Ok ( vec ! [ resp] ) ; 
430340        } 
431341
432-         // Check if we're in a failed transaction and block non-transaction commands  
433-         let  client_id =  Self :: get_client_id ( client ) ; 
434-         if  self . get_transaction_state ( client_id ) . await  == TransactionState :: Failed  { 
342+         // Check if we're in a failed transaction and block non-transaction 
343+         // commands 
344+         if  client . transaction_status ( )  == TransactionStatus :: Error  { 
435345            return  Err ( PgWireError :: UserError ( Box :: new ( 
436346                pgwire:: error:: ErrorInfo :: new ( 
437347                    "ERROR" . to_string ( ) , 
@@ -447,12 +357,6 @@ impl SimpleQueryHandler for DfSessionService {
447357        let  df = match  df_result { 
448358            Ok ( df)  => df, 
449359            Err ( e)  => { 
450-                 // If we're in a transaction and a query fails, mark transaction as failed 
451-                 let  client_id = Self :: get_client_id ( client) ; 
452-                 if  self . get_transaction_state ( client_id) . await  == TransactionState :: Active  { 
453-                     self . update_transaction_state ( client_id,  TransactionState :: Failed ) 
454-                         . await ; 
455-                 } 
456360                return  Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ; 
457361            } 
458362        } ; 
@@ -557,7 +461,7 @@ impl ExtendedQueryHandler for DfSessionService {
557461            . to_lowercase ( ) 
558462            . trim ( ) 
559463            . to_string ( ) ; 
560-         log:: debug!( "Received execute extended query: {}"  ,  query ) ;  // Log for debugging 
464+         log:: debug!( "Received execute extended query: {query}"  ) ;  // Log for debugging 
561465
562466        // Check permissions for the query (skip for SET and SHOW statements) 
563467        if  !query. starts_with ( "set" )  && !query. starts_with ( "show" )  { 
@@ -580,9 +484,9 @@ impl ExtendedQueryHandler for DfSessionService {
580484            return  Ok ( resp) ; 
581485        } 
582486
583-         // Check if we're in a failed transaction and block non-transaction commands  
584-         let  client_id =  Self :: get_client_id ( client ) ; 
585-         if  self . get_transaction_state ( client_id ) . await  == TransactionState :: Failed  { 
487+         // Check if we're in a failed transaction and block non-transaction 
488+         // commands 
489+         if  client . transaction_status ( )  == TransactionStatus :: Error  { 
586490            return  Err ( PgWireError :: UserError ( Box :: new ( 
587491                pgwire:: error:: ErrorInfo :: new ( 
588492                    "ERROR" . to_string ( ) , 
@@ -605,12 +509,6 @@ impl ExtendedQueryHandler for DfSessionService {
605509        let  dataframe = match  self . session_context . execute_logical_plan ( plan) . await  { 
606510            Ok ( df)  => df, 
607511            Err ( e)  => { 
608-                 // If we're in a transaction and a query fails, mark transaction as failed 
609-                 let  client_id = Self :: get_client_id ( client) ; 
610-                 if  self . get_transaction_state ( client_id) . await  == TransactionState :: Active  { 
611-                     self . update_transaction_state ( client_id,  TransactionState :: Failed ) 
612-                         . await ; 
613-                 } 
614512                return  Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ; 
615513            } 
616514        } ; 
@@ -633,7 +531,7 @@ impl QueryParser for Parser {
633531        sql :  & str , 
634532        _types :  & [ Type ] , 
635533    )  -> PgWireResult < Self :: Statement >  { 
636-         log:: debug!( "Received parse extended query: {}"  ,  sql ) ;  // Log for debugging 
534+         log:: debug!( "Received parse extended query: {sql}"  ) ;  // Log for debugging 
637535        let  context = & self . session_context ; 
638536        let  state = context. state ( ) ; 
639537        let  logical_plan = state
@@ -654,134 +552,3 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
654552    types. sort_by ( |a,  b| a. 0 . cmp ( b. 0 ) ) ; 
655553    types. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( ) 
656554} 
657- 
658- #[ cfg( test) ]  
659- mod  tests { 
660-     use  super :: * ; 
661-     use  datafusion:: prelude:: SessionContext ; 
662- 
663-     #[ tokio:: test]  
664-     async  fn  test_transaction_isolation ( )  { 
665-         let  session_context = Arc :: new ( SessionContext :: new ( ) ) ; 
666-         let  auth_manager = Arc :: new ( AuthManager :: new ( ) ) ; 
667-         let  service = DfSessionService :: new ( session_context,  auth_manager) ; 
668- 
669-         // Simulate two different connection IDs 
670-         let  client_id_1 = 1001 ; 
671-         let  client_id_2 = 1002 ; 
672- 
673-         // Client 1 starts a transaction 
674-         service
675-             . update_transaction_state ( client_id_1,  TransactionState :: Active ) 
676-             . await ; 
677- 
678-         // Client 2 starts a transaction 
679-         service
680-             . update_transaction_state ( client_id_2,  TransactionState :: Active ) 
681-             . await ; 
682- 
683-         // Verify both have active transactions independently 
684-         { 
685-             let  states = service. connection_states . read ( ) . await ; 
686-             assert_eq ! ( 
687-                 states. get( & client_id_1) . map( |s| s. transaction_state) , 
688-                 Some ( TransactionState :: Active ) 
689-             ) ; 
690-             assert_eq ! ( 
691-                 states. get( & client_id_2) . map( |s| s. transaction_state) , 
692-                 Some ( TransactionState :: Active ) 
693-             ) ; 
694-         } 
695- 
696-         // Client 1 fails a transaction 
697-         service
698-             . update_transaction_state ( client_id_1,  TransactionState :: Failed ) 
699-             . await ; 
700- 
701-         // Verify client 1 is failed but client 2 is still active 
702-         { 
703-             let  states = service. connection_states . read ( ) . await ; 
704-             assert_eq ! ( 
705-                 states. get( & client_id_1) . map( |s| s. transaction_state) , 
706-                 Some ( TransactionState :: Failed ) 
707-             ) ; 
708-             assert_eq ! ( 
709-                 states. get( & client_id_2) . map( |s| s. transaction_state) , 
710-                 Some ( TransactionState :: Active ) 
711-             ) ; 
712-         } 
713- 
714-         // Client 1 rollback 
715-         service
716-             . update_transaction_state ( client_id_1,  TransactionState :: None ) 
717-             . await ; 
718- 
719-         // Client 2 commit 
720-         service
721-             . update_transaction_state ( client_id_2,  TransactionState :: None ) 
722-             . await ; 
723- 
724-         // Verify both are back to None state 
725-         { 
726-             let  states = service. connection_states . read ( ) . await ; 
727-             assert_eq ! ( 
728-                 states. get( & client_id_1) . map( |s| s. transaction_state) , 
729-                 Some ( TransactionState :: None ) 
730-             ) ; 
731-             assert_eq ! ( 
732-                 states. get( & client_id_2) . map( |s| s. transaction_state) , 
733-                 Some ( TransactionState :: None ) 
734-             ) ; 
735-         } 
736-     } 
737- 
738-     #[ tokio:: test]  
739-     async  fn  test_opportunistic_cleanup ( )  { 
740-         let  session_context = Arc :: new ( SessionContext :: new ( ) ) ; 
741-         let  auth_manager = Arc :: new ( AuthManager :: new ( ) ) ; 
742-         let  service = DfSessionService :: new ( session_context,  auth_manager) ; 
743- 
744-         // Add some connection states 
745-         service
746-             . update_transaction_state ( 2001 ,  TransactionState :: Active ) 
747-             . await ; 
748-         service
749-             . update_transaction_state ( 2002 ,  TransactionState :: Failed ) 
750-             . await ; 
751- 
752-         // Manually create an old connection 
753-         { 
754-             let  mut  states = service. connection_states . write ( ) . await ; 
755-             states. insert ( 
756-                 2003 , 
757-                 ConnectionState  { 
758-                     transaction_state :  TransactionState :: Active , 
759-                     last_activity :  Instant :: now ( )  - Duration :: from_secs ( 7200 ) ,  // 2 hours old 
760-                 } , 
761-             ) ; 
762-         } 
763- 
764-         // Set cleanup counter to trigger cleanup on next update (fetch_add returns old value) 
765-         service. cleanup_counter . store ( 99 ,  Ordering :: Relaxed ) ; 
766- 
767-         // First update sets counter to 100 (99 + 1) 
768-         service
769-             . update_transaction_state ( 2004 ,  TransactionState :: Active ) 
770-             . await ; 
771- 
772-         // This should trigger cleanup (counter becomes 100, 100 % 100 == 0) 
773-         service
774-             . update_transaction_state ( 2005 ,  TransactionState :: Active ) 
775-             . await ; 
776- 
777-         // Verify only the old connection was removed (cleanup is now inline, no wait needed) 
778-         { 
779-             let  states = service. connection_states . read ( ) . await ; 
780-             assert ! ( states. contains_key( & 2001 ) ) ; 
781-             assert ! ( states. contains_key( & 2002 ) ) ; 
782-             assert ! ( !states. contains_key( & 2003 ) ) ;  // Old connection should be removed 
783-             assert ! ( states. contains_key( & 2004 ) ) ; 
784-             assert ! ( states. contains_key( & 2005 ) ) ; 
785-         } 
786-     } 
787- } 
0 commit comments