1919
2020from data_diff .abcs .compiler import AbstractCompiler , Compilable
2121from data_diff .queries .extras import ApplyFuncAndNormalizeAsString , Checksum , NormalizeAsString
22+ from data_diff .schema import RawColumnInfo
2223from data_diff .utils import ArithString , is_uuid , join_iter , safezip
2324from data_diff .queries .api import Expr , table , Select , SKIP , Explain , Code , this
2425from data_diff .queries .ast_classes import (
@@ -707,27 +708,18 @@ def type_repr(self, t) -> str:
707708 datetime : "TIMESTAMP" ,
708709 }[t ]
709710
710- def _parse_type_repr (self , type_repr : str ) -> Optional [Type [ColType ]]:
711- return self .TYPE_CLASSES .get (type_repr )
712-
713- def parse_type (
714- self ,
715- table_path : DbPath ,
716- col_name : str ,
717- type_repr : str ,
718- datetime_precision : int = None ,
719- numeric_precision : int = None ,
720- numeric_scale : int = None ,
721- ) -> ColType :
711+ def parse_type (self , table_path : DbPath , info : RawColumnInfo ) -> ColType :
722712 "Parse type info as returned by the database"
723713
724- cls = self ._parse_type_repr ( type_repr )
714+ cls = self .TYPE_CLASSES . get ( info . data_type )
725715 if cls is None :
726- return UnknownColType (type_repr )
716+ return UnknownColType (info . data_type )
727717
728718 if issubclass (cls , TemporalType ):
729719 return cls (
730- precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
720+ precision = info .datetime_precision
721+ if info .datetime_precision is not None
722+ else DEFAULT_DATETIME_PRECISION ,
731723 rounds = self .ROUNDS_ON_PREC_LOSS ,
732724 )
733725
@@ -738,22 +730,22 @@ def parse_type(
738730 return cls ()
739731
740732 elif issubclass (cls , Decimal ):
741- if numeric_scale is None :
742- numeric_scale = 0 # Needed for Oracle.
743- return cls (precision = numeric_scale )
733+ if info . numeric_scale is None :
734+ return cls ( precision = 0 ) # Needed for Oracle.
735+ return cls (precision = info . numeric_scale )
744736
745737 elif issubclass (cls , Float ):
746738 # assert numeric_scale is None
747739 return cls (
748740 precision = self ._convert_db_precision_to_digits (
749- numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
741+ info . numeric_precision if info . numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
750742 )
751743 )
752744
753745 elif issubclass (cls , (JSON , Array , Struct , Text , Native_UUID )):
754746 return cls ()
755747
756- raise TypeError (f"Parsing { type_repr } returned an unknown type ' { cls } ' ." )
748+ raise TypeError (f"Parsing { info . data_type } returned an unknown type { cls !r } ." )
757749
758750 def _convert_db_precision_to_digits (self , p : int ) -> int :
759751 """Convert from binary precision, used by floats, to decimal precision."""
@@ -1018,7 +1010,7 @@ def select_table_schema(self, path: DbPath) -> str:
10181010 f"WHERE table_name = '{ name } ' AND table_schema = '{ schema } '"
10191011 )
10201012
1021- def query_table_schema (self , path : DbPath ) -> Dict [str , tuple ]:
1013+ def query_table_schema (self , path : DbPath ) -> Dict [str , RawColumnInfo ]:
10221014 """Query the table for its schema for table in 'path', and return {column_name: RawColumnInfo}
10231015 where each RawColumnInfo holds (column_name, data_type, datetime_precision?, numeric_precision?, numeric_scale?, collation_name?)
10241016
@@ -1029,7 +1021,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
10291021 if not rows :
10301022 raise RuntimeError (f"{ self .name } : Table '{ '.' .join (path )} ' does not exist, or has no columns" )
10311023
1032- d = {r [0 ]: r for r in rows }
1024+ d = {
1025+ r [0 ]: RawColumnInfo (
1026+ column_name = r [0 ],
1027+ data_type = r [1 ],
1028+ datetime_precision = r [2 ],
1029+ numeric_precision = r [3 ],
1030+ numeric_scale = r [4 ],
1031+ collation_name = r [5 ] if len (r ) > 5 else None ,
1032+ )
1033+ for r in rows
1034+ }
10331035 assert len (d ) == len (rows )
10341036 return d
10351037
@@ -1051,7 +1053,11 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
10511053 return list (res )
10521054
10531055 def _process_table_schema (
1054- self , path : DbPath , raw_schema : Dict [str , tuple ], filter_columns : Sequence [str ] = None , where : str = None
1056+ self ,
1057+ path : DbPath ,
1058+ raw_schema : Dict [str , RawColumnInfo ],
1059+ filter_columns : Sequence [str ] = None ,
1060+ where : str = None ,
10551061 ):
10561062 """Process the result of query_table_schema().
10571063
@@ -1067,7 +1073,7 @@ def _process_table_schema(
10671073 accept = {i .lower () for i in filter_columns }
10681074 filtered_schema = {name : row for name , row in raw_schema .items () if name .lower () in accept }
10691075
1070- col_dict = {row [ 0 ] : self .dialect .parse_type (path , * row ) for _name , row in filtered_schema .items ()}
1076+ col_dict = {info . column_name : self .dialect .parse_type (path , info ) for info in filtered_schema .values ()}
10711077
10721078 self ._refine_coltypes (path , col_dict , where )
10731079
@@ -1076,15 +1082,15 @@ def _process_table_schema(
10761082
10771083 def _refine_coltypes (
10781084 self , table_path : DbPath , col_dict : Dict [str , ColType ], where : Optional [str ] = None , sample_size = 64
1079- ):
1085+ ) -> Dict [ str , ColType ] :
10801086 """Refine the types in the column dict, by querying the database for a sample of their values
10811087
10821088 'where' restricts the rows to be sampled.
10831089 """
10841090
10851091 text_columns = [k for k , v in col_dict .items () if isinstance (v , Text )]
10861092 if not text_columns :
1087- return
1093+ return col_dict
10881094
10891095 fields = [Code (self .dialect .normalize_uuid (self .dialect .quote (c ), String_UUID ())) for c in text_columns ]
10901096
@@ -1116,7 +1122,9 @@ def _refine_coltypes(
11161122 )
11171123 else :
11181124 assert col_name in col_dict
1119- col_dict [col_name ] = String_VaryingAlphanum ()
1125+ col_dict [col_name ] = String_VaryingAlphanum (collation = col_dict [col_name ].collation )
1126+
1127+ return col_dict
11201128
11211129 def _normalize_table_path (self , path : DbPath ) -> DbPath :
11221130 if len (path ) == 1 :
0 commit comments