Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.tinyengine.it.common.utils;

import java.util.List;
import java.util.regex.Pattern;

public class SqlIdentifierValidator {

private static final Pattern IDENTIFIER_PATTERN =
Pattern.compile("^[a-zA-Z_][a-zA-Z0-9_]*$");

private static final Pattern ORDER_TYPE_PATTERN =
Pattern.compile("^(ASC|DESC)$", Pattern.CASE_INSENSITIVE);

private SqlIdentifierValidator() {
}

public static void validate(String identifier) {
if (identifier == null || !IDENTIFIER_PATTERN.matcher(identifier).matches()) {
throw new IllegalArgumentException("Invalid SQL identifier: " + identifier);
}
}

public static void validateAll(List<String> identifiers) {
if (identifiers == null) {
return;
}
identifiers.forEach(SqlIdentifierValidator::validate);
}

public static void validateOrderType(String orderType) {
if (orderType == null || !ORDER_TYPE_PATTERN.matcher(orderType).matches()) {
throw new IllegalArgumentException("Invalid order type: " + orderType);
}
}
}
24 changes: 16 additions & 8 deletions base/src/main/java/com/tinyengine/it/dynamic/dto/DynamicQuery.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.tinyengine.it.dynamic.dto;

import jakarta.validation.constraints.Pattern;
import lombok.Data;

import java.util.List;
Expand All @@ -8,12 +9,19 @@
@Data
public class DynamicQuery {

private String nameEn; // 表名
private String nameCh; // 表中文名
private List<String> fields; // 查询字段
private Map<String, Object> params; // 查询条件
private Integer currentPage = 1; // 页码
private Integer pageSize = 10; // 每页大小
private String orderBy; // 排序字段
private String orderType = "ASC"; // 排序方式
@Pattern(regexp = "^[a-zA-Z_][a-zA-Z0-9_]*$", message = "模型名称格式不正确")
private String nameEn;
private String nameCh;
private List<
@Pattern(regexp = "^[a-zA-Z_][a-zA-Z0-9_]*$", message = "字段名格式不正确")
String> fields;
private Map<String, Object> params;
private Integer currentPage = 1;
private Integer pageSize = 10;

@Pattern(regexp = "^[a-zA-Z_][a-zA-Z0-9_]*$", message = "排序字段格式不正确")
private String orderBy;

@Pattern(regexp = "^(?i)(ASC|DESC)$", message = "排序方式只能是 ASC 或 DESC")
private String orderType = "ASC";
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import com.tinyengine.it.dynamic.dto.DynamicUpdate;
import com.tinyengine.it.model.dto.ParametersDto;
import com.tinyengine.it.model.entity.Model;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import com.tinyengine.it.common.utils.SqlIdentifierValidator;
import com.tinyengine.it.service.material.ModelService;
import org.springframework.context.annotation.Lazy;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
Expand All @@ -31,12 +33,26 @@

@Service
@Slf4j
@RequiredArgsConstructor
public class DynamicModelService {

private static final Set<String> SYSTEM_FIELDS = Set.of(
"id", "created_at", "updated_at", "deleted_at", "created_by", "updated_by"
);

private final JdbcTemplate jdbcTemplate;
private final NamedParameterJdbcTemplate namedParameterJdbcTemplate;
private final LoginUserContext loginUserContext;
private final ModelService modelService;

public DynamicModelService(JdbcTemplate jdbcTemplate,
NamedParameterJdbcTemplate namedParameterJdbcTemplate,
LoginUserContext loginUserContext,
@Lazy ModelService modelService) {
this.jdbcTemplate = jdbcTemplate;
this.namedParameterJdbcTemplate = namedParameterJdbcTemplate;
this.loginUserContext = loginUserContext;
this.modelService = modelService;
}


/**
Expand Down Expand Up @@ -182,6 +198,17 @@ public List<Map<String, Object>> dynamicQuery(String tableName,
String orderBy,
Integer limit) {

SqlIdentifierValidator.validate(tableName);
SqlIdentifierValidator.validateAll(fields);
if (conditions != null && !conditions.isEmpty()) {
for (String key : conditions.keySet()) {
SqlIdentifierValidator.validate(key);
}
}
if (orderBy != null && !orderBy.isEmpty()) {
SqlIdentifierValidator.validate(orderBy.replaceAll("(?i)\\s+(ASC|DESC)$", ""));
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// 1. 构建SQL
StringBuilder sql = new StringBuilder("SELECT ");

Expand Down Expand Up @@ -267,18 +294,19 @@ public Long count(String tableName, Map<String, Object> conditions) {
* 分页查询
*/
public Map<String, Object> queryWithPage(DynamicQuery dto) {
String tableName = getTableName( dto.getNameEn());
String tableName = getTableName(dto.getNameEn());
List<String> fields = dto.getFields();
Map<String, Object> conditions = dto.getParams();
String orderBy = dto.getOrderBy();
Integer pageNum = dto.getCurrentPage();
Integer pageSize = dto.getPageSize();

validateQueryFields(dto);

// 计算分页
Integer limit = null;
if (pageNum != null && pageSize != null) {
limit = pageSize;
// 如果需要偏移量,可以在这里处理
}

// 执行查询
Expand All @@ -292,6 +320,60 @@ public Map<String, Object> queryWithPage(DynamicQuery dto) {

return result;
}

private Set<String> getAllowedFields(String nameEn) {
List<Model> modelList = modelService.getModelByEnName(nameEn);
if (modelList == null || modelList.isEmpty()) {
return Collections.emptySet();
}
Model model = modelList.get(0);
Set<String> allowed = new HashSet<>(SYSTEM_FIELDS);
if (model.getParameters() != null) {
for (Object param : model.getParameters()) {
String prop = extractProp(param);
if (prop != null) {
allowed.add(prop);
}
}
}
return allowed;
}

@SuppressWarnings("unchecked")
private String extractProp(Object param) {
if (param instanceof ParametersDto) {
return ((ParametersDto) param).getProp();
}
if (param instanceof Map) {
Object value = ((Map<String, Object>) param).get("prop");
return value != null ? value.toString() : null;
}
return null;
}

private void validateQueryFields(DynamicQuery dto) {
Set<String> allowedFields = getAllowedFields(dto.getNameEn());

if (dto.getFields() != null && !dto.getFields().isEmpty()) {
for (String field : dto.getFields()) {
SqlIdentifierValidator.validate(field);
if (!allowedFields.contains(field)) {
throw new IllegalArgumentException("不允许的字段: " + field);
}
}
}

if (dto.getOrderBy() != null && !dto.getOrderBy().isEmpty()) {
SqlIdentifierValidator.validate(dto.getOrderBy());
if (!allowedFields.contains(dto.getOrderBy())) {
throw new IllegalArgumentException("不允许的排序字段: " + dto.getOrderBy());
}
}

if (dto.getOrderType() != null) {
SqlIdentifierValidator.validateOrderType(dto.getOrderType());
}
}
private Object convertValueByType(Object value, String fieldType, String columnName) {
try {
switch (fieldType) {
Expand Down Expand Up @@ -525,6 +607,9 @@ public Map<String, Object> createData(DynamicInsert dataDto) {

String tableName = getTableName(dataDto.getNameEn());
Map<String, Object> record = new HashMap<>(dataDto.getParams());
for (String col : record.keySet()) {
SqlIdentifierValidator.validate(col);
}
String userId = loginUserContext.getLoginUserId();
// 添加系统字段
record.put("created_by",userId);
Expand Down Expand Up @@ -606,6 +691,9 @@ public Map<String,Object> updateDateById(DynamicUpdate dto) {
}
Long id = Long.parseLong(params1.get("id").toString());
Map<String, Object> updateFields = dto.getData();
for (String col : updateFields.keySet()) {
SqlIdentifierValidator.validate(col);
}
String tableName = getTableName(modelId);
StringBuilder sql = new StringBuilder("UPDATE " + tableName + " SET ");
List<Object> params = new ArrayList<>();
Expand Down
Loading
Loading