diff --git a/IHP/IDE/CodeGen/MigrationGenerator.hs b/IHP/IDE/CodeGen/MigrationGenerator.hs index 7cd5e53e9..87ca63348 100644 --- a/IHP/IDE/CodeGen/MigrationGenerator.hs +++ b/IHP/IDE/CodeGen/MigrationGenerator.hs @@ -368,7 +368,7 @@ normalizeStatement StatementCreateTable { unsafeGetCreateTable = table } = State (normalizedTable, normalizeTableRest) = normalizeTable table normalizeStatement AddConstraint { tableName, constraint, deferrable, deferrableType } = [ AddConstraint { tableName, constraint = normalizeConstraint tableName constraint, deferrable, deferrableType } ] normalizeStatement CreateEnumType { name, values } = [ CreateEnumType { name = Text.toLower name, values = map Text.toLower values } ] -normalizeStatement CreatePolicy { name, action, tableName, using, check } = [ CreatePolicy { name, tableName, using = normalizeExpression <$> using, check = normalizeExpression <$> check, action = normalizePolicyAction action } ] +normalizeStatement CreatePolicy { name, action, tableName, using, check } = [ CreatePolicy { name, tableName, using = (unqualifyExpression tableName . normalizeExpression) <$> using, check = (unqualifyExpression tableName . normalizeExpression) <$> check, action = normalizePolicyAction action } ] normalizeStatement CreateIndex { columns, indexType, .. } = [ CreateIndex { columns = map normalizeIndexColumn columns, indexType = normalizeIndexType indexType, .. } ] normalizeStatement CreateFunction { .. } = [ CreateFunction { orReplace = False, language = Text.toUpper language, functionBody = removeIndentation $ normalizeNewLines functionBody, .. } ] normalizeStatement otherwise = [otherwise] @@ -506,6 +506,36 @@ normalizeExpression (SelectExpression Select { columns, from, whereClause, alias normalizeExpression (DotExpression a b) = DotExpression (normalizeExpression a) b normalizeExpression (ExistsExpression a) = ExistsExpression (normalizeExpression a) +-- | Replaces @table.field@ with just @field@ +-- +-- >>> unqualifyExpression "servers" (sql "SELECT * FROM servers WHERE servers.is_public") +-- sql "SELECT * FROM servers WHERE is_public" +-- +unqualifyExpression :: Text -> Expression -> Expression +unqualifyExpression scope expression = unqualifyExpression expression + where + unqualifyExpression e@(TextExpression {}) = e + unqualifyExpression e@(VarExpression {}) = e + unqualifyExpression (CallExpression function args) = CallExpression function (map unqualifyExpression args) + unqualifyExpression (NotEqExpression a b) = NotEqExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (EqExpression a b) = EqExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (AndExpression a b) = AndExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (IsExpression a b) = IsExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (NotExpression a) = NotExpression (unqualifyExpression a) + unqualifyExpression (OrExpression a b) = OrExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (LessThanExpression a b) = LessThanExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (LessThanOrEqualToExpression a b) = LessThanOrEqualToExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (GreaterThanExpression a b) = GreaterThanExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (GreaterThanOrEqualToExpression a b) = GreaterThanOrEqualToExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression e@(DoubleExpression {}) = e + unqualifyExpression e@(IntExpression {}) = e + unqualifyExpression (ConcatenationExpression a b) = ConcatenationExpression (unqualifyExpression a) (unqualifyExpression b) + unqualifyExpression (TypeCastExpression a b) = TypeCastExpression (unqualifyExpression a) b + unqualifyExpression (SelectExpression Select { columns, from, whereClause, alias }) = SelectExpression Select { columns = (unqualifyExpression <$> columns), from = from, whereClause = unqualifyExpression whereClause, alias } + unqualifyExpression (ExistsExpression a) = ExistsExpression (unqualifyExpression a) + unqualifyExpression (DotExpression (VarExpression scope') b) | scope == scope' = VarExpression b + unqualifyExpression (DotExpression a b) = DotExpression (unqualifyExpression a) b + resolveAlias :: Maybe Text -> Expression -> Expression -> Expression resolveAlias (Just alias) fromExpression expression = diff --git a/Test/IDE/CodeGeneration/MigrationGenerator.hs b/Test/IDE/CodeGeneration/MigrationGenerator.hs index 2f85daa77..89748d40a 100644 --- a/Test/IDE/CodeGeneration/MigrationGenerator.hs +++ b/Test/IDE/CodeGeneration/MigrationGenerator.hs @@ -1076,10 +1076,10 @@ END;$$ language PLPGSQL;|] diffSchemas targetSchema actualSchema `shouldBe` migration - it "should not detect changes in complex policies" do + it "should normalize qualified identifiers in policy expressions" do -- https://github.com/digitallyinduced/ihp/issues/1480 let targetSchema = sql $ cs [plain| - CREATE POLICY "Users can manage servers they have access to" ON servers USING (user_id = ihp_user_id() OR (EXISTS (SELECT 1 FROM public.user_server_access WHERE user_server_access.user_id = user_id AND user_server_access.server_id = servers.id))) WITH CHECK (user_id = ihp_user_id()); + CREATE POLICY "Users can manage servers they have access to" ON servers USING (servers.user_id = ihp_user_id() OR (EXISTS (SELECT 1 FROM public.user_server_access WHERE user_server_access.user_id = ihp_user_id() AND user_server_access.server_id = servers.id))); |] let actualSchema = sql $ cs [plain| -- @@ -1088,8 +1088,7 @@ END;$$ language PLPGSQL;|] CREATE POLICY "Users can manage servers they have access to" ON public.servers USING (((user_id = public.ihp_user_id()) OR (EXISTS ( SELECT 1 FROM public.user_server_access - WHERE ((user_server_access.user_id = user_server_access.user_id) AND (user_server_access.server_id = servers.id)))))) WITH CHECK ((user_id = public.ihp_user_id())); - + WHERE ((user_server_access.user_id = public.ihp_user_id()) AND (user_server_access.server_id = servers.id)))))); |] let migration = sql [i| |]