From cd8f0a064a95ba065747425269732d8fa52d54bf Mon Sep 17 00:00:00 2001 From: orenzhang <41963680+OrenZhang@users.noreply.github.com> Date: Sun, 17 May 2026 14:16:41 +0800 Subject: [PATCH] feat(risk): add risk block dialog and related components --- config.example.yaml | 10 + frontend/app/layout.tsx | 2 + .../common/risk/risk-block-dialog.tsx | 59 +++++ .../components/common/risk/risk-info-box.tsx | 38 +++ .../common/risk/risk-warning-toast.tsx | 20 ++ frontend/lib/services/core/api-client.ts | 77 +++++- frontend/lib/services/core/errors.ts | 5 +- internal/apps/oauth/middlewares.go | 6 + internal/apps/oauth/risk.go | 225 ++++++++++++++++++ internal/config/model.go | 34 ++- internal/router/router.go | 2 +- 11 files changed, 462 insertions(+), 16 deletions(-) create mode 100644 frontend/components/common/risk/risk-block-dialog.tsx create mode 100644 frontend/components/common/risk/risk-info-box.tsx create mode 100644 frontend/components/common/risk/risk-warning-toast.tsx create mode 100644 internal/apps/oauth/risk.go diff --git a/config.example.yaml b/config.example.yaml index 9d5aba04..91176135 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -133,6 +133,16 @@ worker: linuxDo: api_key: "" +# OpenAPI Risk +openapi_risk: + enabled: false + base_url: "https://audit.example.com" + username: "" + password: "" + cache_ttl_seconds: 3600 + prompt_risk_levels: [] + block_risk_levels: [] + # OpenTelemetry otel: sampling_rate: 0.1 # 采样率 0.0-1.0 diff --git a/frontend/app/layout.tsx b/frontend/app/layout.tsx index f2f2f369..fb32b7be 100644 --- a/frontend/app/layout.tsx +++ b/frontend/app/layout.tsx @@ -5,6 +5,7 @@ import { ThemeProvider } from "@/components/layout/theme-provider"; import { CustomThemeProvider } from "@/lib/theme"; import { BellRingProvider } from "@/contexts/bell-ring-context"; import { NotificationSettingsProvider } from "@/contexts/notification-settings-context"; +import { RiskBlockDialog } from "@/components/common/risk/risk-block-dialog"; import "./globals.css"; const inter = Inter({ @@ -54,6 +55,7 @@ export default function RootLayout({ {children} + diff --git a/frontend/components/common/risk/risk-block-dialog.tsx b/frontend/components/common/risk/risk-block-dialog.tsx new file mode 100644 index 00000000..6306b0fb --- /dev/null +++ b/frontend/components/common/risk/risk-block-dialog.tsx @@ -0,0 +1,59 @@ +"use client" + +import { useEffect, useState } from "react" +import { AlertTriangle } from "lucide-react" +import { RiskInfo, RiskInfoBox } from "@/components/common/risk/risk-info-box" +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog" + +const RISK_BLOCKED_EVENT = "credit-risk-blocked" + +function isRiskInfo(value: unknown): value is RiskInfo { + if (!value || typeof value !== "object") return false + const riskInfo = value as Partial + return typeof riskInfo.risk_level === "string" && Array.isArray(riskInfo.risk_labels) +} + +export function RiskBlockDialog() { + const [riskInfo, setRiskInfo] = useState(null) + + useEffect(() => { + const handleRiskBlocked = (event: Event) => { + const detail = (event as CustomEvent).detail + if (!isRiskInfo(detail)) return + setRiskInfo(detail) + } + + window.addEventListener(RISK_BLOCKED_EVENT, handleRiskBlocked) + return () => window.removeEventListener(RISK_BLOCKED_EVENT, handleRiskBlocked) + }, []) + + return ( + + event.preventDefault()} + onPointerDownOutside={(event) => event.preventDefault()} + onInteractOutside={(event) => event.preventDefault()} + className="sm:max-w-md" + > + +
+ +
+ 账号存在风险 + + 因触发风控,账号暂时无法使用需登录的功能 + +
+ + +
+
+ ) +} diff --git a/frontend/components/common/risk/risk-info-box.tsx b/frontend/components/common/risk/risk-info-box.tsx new file mode 100644 index 00000000..b2e05df2 --- /dev/null +++ b/frontend/components/common/risk/risk-info-box.tsx @@ -0,0 +1,38 @@ +import { cn } from "@/lib/utils" +import { Badge } from "@/components/ui/badge" + +export interface RiskInfo { + risk_level: string + risk_labels: string[] +} + +export function RiskInfoBox({ + riskInfo, + label = "风险等级", + labelClassName, +}: { + riskInfo: RiskInfo | null + label?: string + labelClassName?: string +}) { + return ( +
+
+ {label} + {riskInfo?.risk_level || "未知"} +
+ +
+ {riskInfo?.risk_labels.length ? ( + riskInfo.risk_labels.map(label => ( + + {label} + + )) + ) : ( + 暂无风险标签详情 + )} +
+
+ ) +} diff --git a/frontend/components/common/risk/risk-warning-toast.tsx b/frontend/components/common/risk/risk-warning-toast.tsx new file mode 100644 index 00000000..ade26329 --- /dev/null +++ b/frontend/components/common/risk/risk-warning-toast.tsx @@ -0,0 +1,20 @@ +import { toast } from "sonner" +import { RiskInfo, RiskInfoBox } from "@/components/common/risk/risk-info-box" + +const RISK_TOAST_ID = "credit-risk-warning" + +export function showRiskWarningToast(riskInfo: RiskInfo) { + toast.custom( + () => ( +
+ +
+ ), + { + id: RISK_TOAST_ID, + duration: Infinity, + closeButton: false, + dismissible: false, + }, + ) +} diff --git a/frontend/lib/services/core/api-client.ts b/frontend/lib/services/core/api-client.ts index cfe4ca13..30f1262d 100644 --- a/frontend/lib/services/core/api-client.ts +++ b/frontend/lib/services/core/api-client.ts @@ -1,5 +1,6 @@ import axios, { AxiosError, AxiosResponse, CancelTokenSource, InternalAxiosRequestConfig } from 'axios'; import { toast } from 'sonner'; +import { showRiskWarningToast } from '@/components/common/risk/risk-warning-toast'; import { apiConfig } from './config'; import { ApiErrorBase, @@ -36,6 +37,63 @@ const cancelTokens = new Map(); */ const pendingRequests = new Map>>(); +const RISK_LEVEL_HEADER = 'x-credit-risk-level'; +const RISK_LABELS_HEADER = 'x-credit-risk-labels'; +const RISK_BLOCKED_CODE = 'RISK_BLOCKED'; +const RISK_BLOCKED_EVENT = 'credit-risk-blocked'; + +interface RiskInfo { + risk_level: string; + risk_labels: string[]; +} + +function decodeRiskLabels(value?: string): string[] { + if (!value || typeof window === 'undefined') return []; + + try { + const binary = window.atob(value); + const bytes = Uint8Array.from(binary, char => char.charCodeAt(0)); + const json = new TextDecoder().decode(bytes); + const labels = JSON.parse(json); + return Array.isArray(labels) ? labels.filter((label): label is string => typeof label === 'string') : []; + } catch { + return []; + } +} + +function riskInfoFromHeaders(headers: AxiosResponse['headers']): RiskInfo | null { + const riskLevel = headers[RISK_LEVEL_HEADER]; + if (typeof riskLevel !== 'string' || !riskLevel) return null; + + const riskLabelsHeader = headers[RISK_LABELS_HEADER]; + return { + risk_level: riskLevel, + risk_labels: typeof riskLabelsHeader === 'string' ? decodeRiskLabels(riskLabelsHeader) : [], + }; +} + +function riskInfoFromDetails(details: unknown): RiskInfo | null { + if (!details || typeof details !== 'object') return null; + + const riskLevel = 'risk_level' in details ? (details as { risk_level?: unknown }).risk_level : undefined; + const riskLabels = 'risk_labels' in details ? (details as { risk_labels?: unknown }).risk_labels : undefined; + if (typeof riskLevel !== 'string' || !riskLevel) return null; + + return { + risk_level: riskLevel, + risk_labels: Array.isArray(riskLabels) ? riskLabels.filter((label): label is string => typeof label === 'string') : [], + }; +} + +function showRiskWarning(riskInfo: RiskInfo): void { + showRiskWarningToast(riskInfo); +} + +function showRiskBlockedDialog(riskInfo: RiskInfo): void { + if (typeof window === 'undefined') return; + window.dispatchEvent(new CustomEvent(RISK_BLOCKED_EVENT, { detail: riskInfo })); +} + /** * 生成请求的唯一键 * 包含方法、URL 和请求数据的哈希,确保不同参数的请求不会被误取消 @@ -99,6 +157,12 @@ apiClient.interceptors.response.use( const requestKey = getRequestKey(response.config); cancelTokens.delete(requestKey); pendingRequests.delete(requestKey); + + const riskInfo = riskInfoFromHeaders(response.headers); + if (riskInfo) { + showRiskWarning(riskInfo); + } + return response; }, (error: AxiosError) => { @@ -122,8 +186,19 @@ apiClient.interceptors.response.use( /* 403 权限不足错误 */ if (error.response?.status === 403) { + if (error.response.data?.error_code === RISK_BLOCKED_CODE) { + const riskInfo = riskInfoFromDetails(error.response.data.details) || riskInfoFromHeaders(error.response.headers); + if (riskInfo) { + showRiskBlockedDialog(riskInfo); + } + + return Promise.reject( + new ForbiddenError(error.response.data?.error_msg || '账号存在风险', RISK_BLOCKED_CODE, error.response.data?.details), + ); + } + return Promise.reject( - new ForbiddenError(error.response.data?.error_msg || '权限不足,请过盾后重试'), + new ForbiddenError(error.response.data?.error_msg || '权限不足,请过盾后重试', error.response.data?.error_code, error.response.data?.details), ); } diff --git a/frontend/lib/services/core/errors.ts b/frontend/lib/services/core/errors.ts index e98c52f4..83c2b3c0 100644 --- a/frontend/lib/services/core/errors.ts +++ b/frontend/lib/services/core/errors.ts @@ -55,8 +55,8 @@ export class UnauthorizedError extends ApiErrorBase { * 权限不足错误 (403) */ export class ForbiddenError extends ApiErrorBase { - constructor(message = '权限不足') { - super(message, 'FORBIDDEN', 403); + constructor(message = '权限不足', code = 'FORBIDDEN', details?: unknown) { + super(message, code, 403, details); this.name = 'ForbiddenError'; Object.setPrototypeOf(this, ForbiddenError.prototype); } @@ -103,4 +103,3 @@ export class ValidationError extends ApiErrorBase { export function isCancelError(error: unknown): boolean { return error !== null && typeof error === 'object' && ('__CANCEL__' in error && (error as { __CANCEL__?: boolean }).__CANCEL__ === true || ('message' in error && (error as { message?: string }).message === '请求已被取消')); } - diff --git a/internal/apps/oauth/middlewares.go b/internal/apps/oauth/middlewares.go index e510e6a6..c0fc013c 100644 --- a/internal/apps/oauth/middlewares.go +++ b/internal/apps/oauth/middlewares.go @@ -82,6 +82,12 @@ func LoginRequired() gin.HandlerFunc { // set user info util.SetToContext(c, UserObjKey, &user) + if risk, ok := checkOpenAPIUserRisk(ctx, user.ID); ok { + if blocked := applyOpenAPIUserRisk(c, risk); blocked { + return + } + } + // next c.Next() } diff --git a/internal/apps/oauth/risk.go b/internal/apps/oauth/risk.go new file mode 100644 index 00000000..42f87a91 --- /dev/null +++ b/internal/apps/oauth/risk.go @@ -0,0 +1,225 @@ +/* +Copyright 2025 linux.do + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/linux-do/credit/internal/config" + "github.com/linux-do/credit/internal/db" + "github.com/linux-do/credit/internal/logger" + "github.com/linux-do/credit/internal/util" + "github.com/redis/go-redis/v9" +) + +const ( + openAPIRiskCacheKeyFormat = "openapi_risk:user:%d" + minOpenAPIRiskCacheTTL = time.Hour + + riskLevelHeader = "X-Credit-Risk-Level" + riskLabelsHeader = "X-Credit-Risk-Labels" + exposeHeader = "Access-Control-Expose-Headers" + + riskBlockedCode = "RISK_BLOCKED" + riskBlockedMsg = "账号存在风险" +) + +type openAPIUserRiskItem struct { + Label string `json:"label"` + Value string `json:"value"` +} + +type openAPIUserRiskResponse struct { + Risky bool `json:"risky"` + RiskLevel string `json:"risk_level"` + Risks []openAPIUserRiskItem `json:"risks"` +} + +type riskBlockDetails struct { + RiskLevel string `json:"risk_level"` + RiskLabels []string `json:"risk_labels"` +} + +func checkOpenAPIUserRisk(ctx context.Context, userID uint64) (*openAPIUserRiskResponse, bool) { + cfg := config.Config.OpenAPIRisk + if !cfg.Enabled || strings.TrimSpace(cfg.BaseURL) == "" { + return nil, false + } + if db.Redis == nil { + logger.ErrorF(ctx, "[OpenAPIRisk] redis is not initialized, skip risk check") + return nil, false + } + + cacheKey := fmt.Sprintf(openAPIRiskCacheKeyFormat, userID) + var cached openAPIUserRiskResponse + if err := db.GetJSON(ctx, cacheKey, &cached); err == nil { + return &cached, true + } else if err != nil && !errors.Is(err, redis.Nil) { + logger.ErrorF(ctx, "[OpenAPIRisk] read cache failed, skip risk check: %v", err) + return nil, false + } + + risk, err := fetchOpenAPIUserRisk(ctx, userID) + if err != nil { + logger.ErrorF(ctx, "[OpenAPIRisk] fetch user risk failed, skip risk check: %v", err) + return nil, false + } + + if err := db.SetJSON(ctx, cacheKey, risk, openAPIRiskCacheTTL()); err != nil { + logger.ErrorF(ctx, "[OpenAPIRisk] write cache failed, skip risk check: %v", err) + return nil, false + } + + return risk, true +} + +func fetchOpenAPIUserRisk(ctx context.Context, userID uint64) (*openAPIUserRiskResponse, error) { + cfg := config.Config.OpenAPIRisk + endpoint := fmt.Sprintf( + "%s/api/open/v1/risk/users/%d", + strings.TrimRight(cfg.BaseURL, "/"), + userID, + ) + + headers := map[string]string{ + "Accept": "application/json", + } + if cfg.Username != "" || cfg.Password != "" { + token := base64.StdEncoding.EncodeToString([]byte(cfg.Username + ":" + cfg.Password)) + headers["Authorization"] = "Basic " + token + } + + resp, err := util.Request(ctx, http.MethodGet, endpoint, nil, headers, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var risk openAPIUserRiskResponse + if err := json.NewDecoder(resp.Body).Decode(&risk); err != nil { + return nil, fmt.Errorf("decode response failed: %w", err) + } + + return &risk, nil +} + +func openAPIRiskCacheTTL() time.Duration { + ttl := time.Duration(config.Config.OpenAPIRisk.CacheTTLSeconds) * time.Second + if ttl < minOpenAPIRiskCacheTTL { + return minOpenAPIRiskCacheTTL + } + return ttl +} + +func applyOpenAPIUserRisk(c *gin.Context, risk *openAPIUserRiskResponse) bool { + if risk == nil || !risk.Risky { + return false + } + + labels := riskLabels(risk) + cfg := config.Config.OpenAPIRisk + if containsString(cfg.BlockRiskLevels, risk.RiskLevel) { + setRiskHeaders(c, risk.RiskLevel, labels) + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ + "error_code": riskBlockedCode, + "error_msg": riskBlockedMsg, + "details": riskBlockDetails{ + RiskLevel: risk.RiskLevel, + RiskLabels: labels, + }, + }) + return true + } + + if containsString(cfg.PromptRiskLevels, risk.RiskLevel) { + setRiskHeaders(c, risk.RiskLevel, labels) + } + + return false +} + +func setRiskHeaders(c *gin.Context, riskLevel string, labels []string) { + labelsJSON, err := json.Marshal(labels) + if err != nil { + logger.ErrorF(c.Request.Context(), "[OpenAPIRisk] marshal risk labels failed: %v", err) + return + } + + c.Header(riskLevelHeader, riskLevel) + c.Header(riskLabelsHeader, base64.StdEncoding.EncodeToString(labelsJSON)) + appendExposeHeaders(c, riskLevelHeader, riskLabelsHeader) +} + +func appendExposeHeaders(c *gin.Context, names ...string) { + existing := c.Writer.Header().Get(exposeHeader) + exposed := make([]string, 0, len(names)+1) + if existing != "" { + exposed = append(exposed, strings.Split(existing, ",")...) + } + exposed = append(exposed, names...) + + seen := make(map[string]struct{}, len(exposed)) + normalized := make([]string, 0, len(exposed)) + for _, header := range exposed { + header = strings.TrimSpace(header) + if header == "" { + continue + } + key := strings.ToLower(header) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + normalized = append(normalized, header) + } + + c.Header(exposeHeader, strings.Join(normalized, ", ")) +} + +func riskLabels(risk *openAPIUserRiskResponse) []string { + labels := make([]string, 0, len(risk.Risks)) + for _, item := range risk.Risks { + label := strings.TrimSpace(item.Label) + if label == "" { + continue + } + labels = append(labels, label) + } + return labels +} + +func containsString(values []string, target string) bool { + target = strings.TrimSpace(target) + for _, value := range values { + if strings.TrimSpace(value) == target { + return true + } + } + return false +} diff --git a/internal/config/model.go b/internal/config/model.go index ce478f9b..f8fac649 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -19,17 +19,18 @@ package config import "time" type configModel struct { - App appConfig `mapstructure:"app"` - OAuth2 OAuth2Config `mapstructure:"oauth2"` - Database databaseConfig `mapstructure:"database"` - Redis redisConfig `mapstructure:"redis"` - Log logConfig `mapstructure:"log"` - Scheduler schedulerConfig `mapstructure:"scheduler"` - Worker workerConfig `mapstructure:"worker"` - ClickHouse clickHouseConfig `mapstructure:"clickhouse"` - LinuxDo linuxDoConfig `mapstructure:"linuxdo"` - Otel otelConfig `mapstructure:"otel"` - S3 s3Config `mapstructure:"s3"` + App appConfig `mapstructure:"app"` + OAuth2 OAuth2Config `mapstructure:"oauth2"` + Database databaseConfig `mapstructure:"database"` + Redis redisConfig `mapstructure:"redis"` + Log logConfig `mapstructure:"log"` + Scheduler schedulerConfig `mapstructure:"scheduler"` + Worker workerConfig `mapstructure:"worker"` + ClickHouse clickHouseConfig `mapstructure:"clickhouse"` + LinuxDo linuxDoConfig `mapstructure:"linuxdo"` + OpenAPIRisk openAPIRiskConfig `mapstructure:"openapi_risk"` + Otel otelConfig `mapstructure:"otel"` + S3 s3Config `mapstructure:"s3"` } // appConfig 应用基本配置 @@ -180,6 +181,17 @@ type linuxDoConfig struct { ApiKey string `mapstructure:"api_key"` } +// openAPIRiskConfig OpenAPI 用户风险配置 +type openAPIRiskConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password" json:"-"` + CacheTTLSeconds int `mapstructure:"cache_ttl_seconds"` + PromptRiskLevels []string `mapstructure:"prompt_risk_levels"` + BlockRiskLevels []string `mapstructure:"block_risk_levels"` +} + // otelConfig OpenTelemetry 配置 type otelConfig struct { SamplingRate float64 `mapstructure:"sampling_rate"` diff --git a/internal/router/router.go b/internal/router/router.go index 3091b6c9..4d60a9df 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -131,7 +131,7 @@ func Serve() { // OAuth apiV1Router.GET("/oauth/login", oauth.GetLoginURL) - apiV1Router.GET("/oauth/logout", oauth.LoginRequired(), oauth.Logout) + apiV1Router.GET("/oauth/logout", oauth.Logout) apiV1Router.POST("/oauth/callback", oauth.Callback) apiV1Router.GET("/oauth/user-info", oauth.LoginRequired(), oauth.UserInfo)