1919
2020from data_diff .abcs .compiler import AbstractCompiler , Compilable
2121from data_diff .queries .extras import ApplyFuncAndNormalizeAsString , Checksum , NormalizeAsString
22- from data_diff .utils import ArithString , is_uuid , join_iter , safezip
22+ from data_diff .schema import RawColumnInfo
23+ from data_diff .utils import ArithString , ArithUUID , is_uuid , join_iter , safezip
2324from data_diff .queries .api import Expr , table , Select , SKIP , Explain , Code , this
2425from data_diff .queries .ast_classes import (
2526 Alias ,
@@ -248,6 +249,9 @@ def _compile(self, compiler: Compiler, elem) -> str:
248249 return self .timestamp_value (elem )
249250 elif isinstance (elem , bytes ):
250251 return f"b'{ elem .decode ()} '"
252+ elif isinstance (elem , ArithUUID ):
253+ s = f"'{ elem .uuid } '"
254+ return s .upper () if elem .uppercase else s .lower () if elem .lowercase else s
251255 elif isinstance (elem , ArithString ):
252256 return f"'{ elem } '"
253257 assert False , elem
@@ -681,8 +685,10 @@ def _constant_value(self, v):
681685 return f"'{ v } '"
682686 elif isinstance (v , datetime ):
683687 return self .timestamp_value (v )
684- elif isinstance (v , UUID ):
688+ elif isinstance (v , UUID ): # probably no longer used, superseded by ArithUUID
685689 return f"'{ v } '"
690+ elif isinstance (v , ArithUUID ):
691+ return f"'{ v .uuid } '"
686692 elif isinstance (v , decimal .Decimal ):
687693 return str (v )
688694 elif isinstance (v , bytearray ):
@@ -708,27 +714,18 @@ def type_repr(self, t) -> str:
708714 datetime : "TIMESTAMP" ,
709715 }[t ]
710716
711- def _parse_type_repr (self , type_repr : str ) -> Optional [Type [ColType ]]:
712- return self .TYPE_CLASSES .get (type_repr )
713-
714- def parse_type (
715- self ,
716- table_path : DbPath ,
717- col_name : str ,
718- type_repr : str ,
719- datetime_precision : int = None ,
720- numeric_precision : int = None ,
721- numeric_scale : int = None ,
722- ) -> ColType :
717+ def parse_type (self , table_path : DbPath , info : RawColumnInfo ) -> ColType :
723718 "Parse type info as returned by the database"
724719
725- cls = self ._parse_type_repr ( type_repr )
720+ cls = self .TYPE_CLASSES . get ( info . data_type )
726721 if cls is None :
727- return UnknownColType (type_repr )
722+ return UnknownColType (info . data_type )
728723
729724 if issubclass (cls , TemporalType ):
730725 return cls (
731- precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
726+ precision = info .datetime_precision
727+ if info .datetime_precision is not None
728+ else DEFAULT_DATETIME_PRECISION ,
732729 rounds = self .ROUNDS_ON_PREC_LOSS ,
733730 )
734731
@@ -739,22 +736,22 @@ def parse_type(
739736 return cls ()
740737
741738 elif issubclass (cls , Decimal ):
742- if numeric_scale is None :
743- numeric_scale = 0 # Needed for Oracle.
744- return cls (precision = numeric_scale )
739+ if info . numeric_scale is None :
740+ return cls ( precision = 0 ) # Needed for Oracle.
741+ return cls (precision = info . numeric_scale )
745742
746743 elif issubclass (cls , Float ):
747744 # assert numeric_scale is None
748745 return cls (
749746 precision = self ._convert_db_precision_to_digits (
750- numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
747+ info . numeric_precision if info . numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
751748 )
752749 )
753750
754751 elif issubclass (cls , (JSON , Array , Struct , Text , Native_UUID )):
755752 return cls ()
756753
757- raise TypeError (f"Parsing { type_repr } returned an unknown type ' { cls } ' ." )
754+ raise TypeError (f"Parsing { info . data_type } returned an unknown type { cls !r } ." )
758755
759756 def _convert_db_precision_to_digits (self , p : int ) -> int :
760757 """Convert from binary precision, used by floats, to decimal precision."""
@@ -1019,7 +1016,7 @@ def select_table_schema(self, path: DbPath) -> str:
10191016 f"WHERE table_name = '{ name } ' AND table_schema = '{ schema } '"
10201017 )
10211018
1022- def query_table_schema (self , path : DbPath ) -> Dict [str , tuple ]:
1019+ def query_table_schema (self , path : DbPath ) -> Dict [str , RawColumnInfo ]:
10231020 """Query the schema of the table at 'path', and return {column_name: RawColumnInfo}
10241021 where each RawColumnInfo holds (column_name, data_type, datetime_precision?, numeric_precision?, numeric_scale?, collation_name?)
10251022
@@ -1030,7 +1027,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
10301027 if not rows :
10311028 raise RuntimeError (f"{ self .name } : Table '{ '.' .join (path )} ' does not exist, or has no columns" )
10321029
1033- d = {r [0 ]: r for r in rows }
1030+ d = {
1031+ r [0 ]: RawColumnInfo (
1032+ column_name = r [0 ],
1033+ data_type = r [1 ],
1034+ datetime_precision = r [2 ],
1035+ numeric_precision = r [3 ],
1036+ numeric_scale = r [4 ],
1037+ collation_name = r [5 ] if len (r ) > 5 else None ,
1038+ )
1039+ for r in rows
1040+ }
10341041 assert len (d ) == len (rows )
10351042 return d
10361043
@@ -1052,7 +1059,11 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
10521059 return list (res )
10531060
10541061 def _process_table_schema (
1055- self , path : DbPath , raw_schema : Dict [str , tuple ], filter_columns : Sequence [str ] = None , where : str = None
1062+ self ,
1063+ path : DbPath ,
1064+ raw_schema : Dict [str , RawColumnInfo ],
1065+ filter_columns : Sequence [str ] = None ,
1066+ where : str = None ,
10561067 ):
10571068 """Process the result of query_table_schema().
10581069
@@ -1068,7 +1079,7 @@ def _process_table_schema(
10681079 accept = {i .lower () for i in filter_columns }
10691080 filtered_schema = {name : row for name , row in raw_schema .items () if name .lower () in accept }
10701081
1071- col_dict = {row [ 0 ] : self .dialect .parse_type (path , * row ) for _name , row in filtered_schema .items ()}
1082+ col_dict = {info . column_name : self .dialect .parse_type (path , info ) for info in filtered_schema .values ()}
10721083
10731084 self ._refine_coltypes (path , col_dict , where )
10741085
@@ -1077,15 +1088,15 @@ def _process_table_schema(
10771088
10781089 def _refine_coltypes (
10791090 self , table_path : DbPath , col_dict : Dict [str , ColType ], where : Optional [str ] = None , sample_size = 64
1080- ):
1091+ ) -> Dict [ str , ColType ] :
10811092 """Refine the types in the column dict, by querying the database for a sample of their values
10821093
10831094 'where' restricts the rows to be sampled.
10841095 """
10851096
10861097 text_columns = [k for k , v in col_dict .items () if isinstance (v , Text )]
10871098 if not text_columns :
1088- return
1099+ return col_dict
10891100
10901101 fields = [Code (self .dialect .normalize_uuid (self .dialect .quote (c ), String_UUID ())) for c in text_columns ]
10911102
@@ -1105,7 +1116,10 @@ def _refine_coltypes(
11051116 )
11061117 else :
11071118 assert col_name in col_dict
1108- col_dict [col_name ] = String_UUID ()
1119+ col_dict [col_name ] = String_UUID (
1120+ lowercase = all (s == s .lower () for s in uuid_samples ),
1121+ uppercase = all (s == s .upper () for s in uuid_samples ),
1122+ )
11091123 continue
11101124
11111125 if self .SUPPORTS_ALPHANUMS : # Anything but MySQL (so far)
@@ -1117,7 +1131,9 @@ def _refine_coltypes(
11171131 )
11181132 else :
11191133 assert col_name in col_dict
1120- col_dict [col_name ] = String_VaryingAlphanum ()
1134+ col_dict [col_name ] = String_VaryingAlphanum (collation = col_dict [col_name ].collation )
1135+
1136+ return col_dict
11211137
11221138 def _normalize_table_path (self , path : DbPath ) -> DbPath :
11231139 if len (path ) == 1 :
0 commit comments