diff --git a/README.md b/README.md index f3c05dc..225cdaf 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,39 @@ # gpuaudit -Scan your AWS account for GPU waste and get actionable recommendations to cut your cloud spend. +Scan your cloud for GPU waste and get actionable recommendations to cut your spend. ``` -$ gpuaudit scan --profile ml-prod +$ gpuaudit scan --skip-eks - GPU Fleet Summary - Total GPU instances: 14 - Total monthly GPU spend: $47,832 - Estimated monthly waste: $18,240 (38%) + Found 38 GPU nodes across 47 nodes in gpu-cluster - CRITICAL (3 instances, $8,940/mo potential savings) + gpuaudit — GPU Cost Audit for AWS + Account: 123456789012 | Regions: us-east-1 | Duration: 4.2s - i-0a1b2c3d4e g5.12xlarge (4x A10G) $4,380/mo Idle — no activity for 18 days → terminate - i-9f8e7d6c5b p4d.24xlarge (8x A100) $23,652/mo Idle — <1% CPU for 6 days → terminate - sagemaker:asr ml.g6.48xlarge (8x L40S) $9,490/mo GPU util avg 8% → downsize to ml.g5.xlarge + ┌──────────────────────────────────────────────────────────┐ + │ GPU Fleet Summary │ + ├──────────────────────────────────────────────────────────┤ + │ Total GPU instances: 38 │ + │ Total monthly GPU spend: $127450 │ + │ Estimated monthly waste: $18200 ( 14%) │ + └──────────────────────────────────────────────────────────┘ + + CRITICAL — 3 instance(s), $15400/mo potential savings + + Instance Type Monthly Signal Recommendation + ──────────────────────────────────── ────────────────────────── ──────── ──────────────── ────────────────────────────────────────────── + gpu-cluster/ip-10-15-255-248 g6e.16xlarge (1× L40S) $ 6752 idle Node up 13 days with 0 GPU pods scheduled. + gpu-cluster/ip-10-22-250-15 g6e.16xlarge (1× L40S) $ 6752 idle Node up 1 days with 0 GPU pods scheduled. + ... 
``` +## What it scans + +- **EC2** — GPU instances (g4dn, g5, g6, g6e, p4d, p4de, p5, inf2, trn1) with CloudWatch metrics +- **SageMaker** — Endpoints with GPU utilization and invocation metrics +- **EKS** — Managed GPU node groups via the AWS EKS API +- **Kubernetes** — GPU nodes and pod allocation via the Kubernetes API (Karpenter, self-managed, any CNI) + ## What it detects - **Idle GPU instances** — running but doing nothing (low CPU + near-zero network for 24+ hours) @@ -25,6 +42,7 @@ $ gpuaudit scan --profile ml-prod - **Stale instances** — non-production instances running 90+ days - **SageMaker low utilization** — endpoints with <10% GPU utilization - **SageMaker oversized** — endpoints using <30% GPU memory on multi-GPU instances +- **K8s unallocated GPUs** — nodes with GPU capacity but no pods requesting GPUs ## Install @@ -36,7 +54,7 @@ Or build from source: ```bash git clone https://github.com/gpuaudit/cli.git -cd gpuaudit +cd cli go build -o gpuaudit ./cmd/gpuaudit ``` @@ -49,20 +67,155 @@ gpuaudit scan # Specific profile and region gpuaudit scan --profile production --region us-east-1 +# Kubernetes cluster scan (uses KUBECONFIG or ~/.kube/config) +gpuaudit scan --skip-eks + +# Specific kubeconfig and context +gpuaudit scan --kubeconfig ~/.kube/config --kube-context gpu-cluster + # JSON output for automation -gpuaudit scan --format json --output report.json +gpuaudit scan --format json -o report.json -# Markdown for docs/PRs -gpuaudit scan --format markdown +# Compare two scans to see what changed +gpuaudit diff old-report.json new-report.json # Slack Block Kit payload (pipe to webhook) -gpuaudit scan --format slack --output - | curl -X POST -H 'Content-Type: application/json' -d @- $SLACK_WEBHOOK +gpuaudit scan --format slack -o - | \ + curl -X POST -H 'Content-Type: application/json' -d @- $SLACK_WEBHOOK -# Skip CloudWatch metrics (faster, less accurate) -gpuaudit scan --skip-metrics - -# Skip SageMaker scanning +# Skip specific scanners +gpuaudit 
scan --skip-metrics # faster, less accurate gpuaudit scan --skip-sagemaker +gpuaudit scan --skip-eks # skip AWS EKS API (use --skip-k8s for Kubernetes API) +gpuaudit scan --skip-k8s +``` + +## Comparing scans + +Save scan results as JSON, then diff them later: + +```bash +gpuaudit scan --format json -o scan-apr-08.json +# ... time passes, changes happen ... +gpuaudit scan --format json -o scan-apr-15.json +gpuaudit diff scan-apr-08.json scan-apr-15.json +``` + +``` + gpuaudit diff — 2026-04-08 12:00 UTC → 2026-04-15 12:00 UTC + + ┌──────────────────────────────────────────────────────────┐ + │ Cost Delta │ + ├──────────────────────────────────────────────────────────┤ + │ Monthly spend: $142000 → $127450 (-$14550) │ + │ Estimated waste: $31000 → $18200 (-$12800) │ + │ Instances: 45 → 38 (-9 removed, +2 added) │ + └──────────────────────────────────────────────────────────┘ + + REMOVED — 9 instance(s), -$16200/mo + ... +``` + +Matches instances by ID. Reports added, removed, and changed instances with per-field diffs (instance type, pricing model, cost, state, GPU allocation, waste severity). + +## Multi-Account Scanning + +Scan multiple AWS accounts in a single invocation using STS AssumeRole. + +### Prerequisites + +Deploy a read-only IAM role (`gpuaudit-reader`) to each target account. See [Cross-Account Role Setup](#cross-account-role-setup) below. 
+ +### Usage + +```bash +# Scan specific accounts +gpuaudit scan --targets 111111111111,222222222222 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Cross-Account Role Setup + +#### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + policy = file("gpuaudit-policy.json") # from: gpuaudit iam-policy > gpuaudit-policy.json +} +``` + +Deploy to all accounts using Terraform workspaces or CloudFormation StackSets. Before applying, edit the generated `gpuaudit-policy.json`: delete the leading `//` comment line (it makes the file invalid JSON) and drop the `GPUAuditCrossAccount` and `GPUAuditOrganizations` statements, which only the account running gpuaudit needs. 
+ +#### CloudFormation StackSet + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Parameters: + ManagementAccountId: + Type: String +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" ``` ## IAM permissions @@ -73,7 +226,7 @@ gpuaudit is read-only. It never modifies your infrastructure. Generate the minim gpuaudit iam-policy ``` -This outputs a JSON policy requiring only `Describe*`, `List*`, `Get*` permissions for EC2, SageMaker, CloudWatch, Cost Explorer, and Pricing APIs. +For Kubernetes scanning, gpuaudit needs `get`/`list` on `nodes` and `pods` cluster-wide. 
## GPU pricing reference @@ -83,8 +236,7 @@ gpuaudit pricing # Filter by GPU model gpuaudit pricing --gpu H100 -gpuaudit pricing --gpu A10G -gpuaudit pricing --gpu T4 +gpuaudit pricing --gpu L4 ``` ## Output formats @@ -92,18 +244,17 @@ gpuaudit pricing --gpu T4 | Format | Flag | Use case | |---|---|---| | Table | `--format table` (default) | Terminal viewing | -| JSON | `--format json` | Automation, CI/CD pipelines | +| JSON | `--format json` | Automation, CI/CD, `gpuaudit diff` | | Markdown | `--format markdown` | PRs, wikis, docs | | Slack | `--format slack` | Slack webhook integration | ## How it works -1. **Discovery** — Scans EC2 and SageMaker across multiple regions for GPU instance families (g4dn, g5, g6, g6e, p4d, p4de, p5, inf2, trn1) +1. **Discovery** — Scans EC2, SageMaker, EKS node groups, and Kubernetes API across multiple regions for GPU resources 2. **Metrics** — Collects 7-day CloudWatch metrics: CPU, network I/O for EC2; GPU utilization, GPU memory, invocations for SageMaker -3. **Analysis** — Applies 6 waste detection rules with severity levels (critical/warning) -4. **Recommendations** — Generates specific actions (terminate, downsize, switch pricing) with estimated monthly savings - -Regions scanned by default: us-east-1, us-east-2, us-west-2, eu-west-1, eu-west-2, eu-central-1, ap-southeast-1, ap-northeast-1, ap-south-1. +3. **K8s allocation** — Lists pods requesting `nvidia.com/gpu` resources and maps them to nodes +4. **Analysis** — Applies 7 waste detection rules with severity levels (critical/warning/info) +5. 
**Recommendations** — Generates specific actions (terminate, downsize, switch pricing) with estimated monthly savings ## Project structure @@ -113,21 +264,22 @@ gpuaudit/ ├── internal/ │ ├── models/ Core data types (GPUInstance, WasteSignal, Recommendation) │ ├── pricing/ Bundled GPU pricing database (40+ instance types) -│ ├── analysis/ Waste detection rules engine -│ ├── output/ Formatters (table, JSON, markdown, Slack) -│ └── providers/aws/ EC2, SageMaker, CloudWatch, scanner orchestrator +│ ├── analysis/ Waste detection rules engine (7 rules) +│ ├── diff/ Scan comparison logic +│ ├── output/ Formatters (table, JSON, markdown, Slack, diff) +│ └── providers/ +│ ├── aws/ EC2, SageMaker, EKS, CloudWatch, Cost Explorer +│ └── k8s/ Kubernetes API GPU node/pod discovery └── LICENSE Apache 2.0 ``` ## Roadmap -- [ ] AWS Cost Explorer integration (actual vs projected spend) -- [ ] EKS GPU pod discovery +- [ ] DCGM GPU metrics via Kubernetes (actual GPU utilization, not just allocation) - [ ] SageMaker training job analysis - [ ] Multi-account (AWS Organizations) scanning - [ ] GCP + Azure support - [ ] GitHub Action for scheduled scans -- [ ] Historical scan comparison (`gpuaudit diff`) ## License diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index ce8d61e..9232ad9 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -13,12 +13,16 @@ import ( "github.com/spf13/cobra" - "github.com/gpuaudit/cli/internal/models" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + "github.com/gpuaudit/cli/internal/analysis" - awsprovider "github.com/gpuaudit/cli/internal/providers/aws" - k8sprovider "github.com/gpuaudit/cli/internal/providers/k8s" + "github.com/gpuaudit/cli/internal/diff" + "github.com/gpuaudit/cli/internal/models" "github.com/gpuaudit/cli/internal/output" "github.com/gpuaudit/cli/internal/pricing" + awsprovider "github.com/gpuaudit/cli/internal/providers/aws" + k8sprovider 
"github.com/gpuaudit/cli/internal/providers/k8s" ) var version = "dev" @@ -49,10 +53,28 @@ var ( scanSkipCosts bool scanKubeconfig string scanKubeContext string + scanPromURL string + scanPromEndpoint string scanExcludeTags []string scanMinUptimeDays int + scanTargets []string + scanRole string + scanExternalID string + scanOrg bool + scanSkipSelf bool ) +// --- diff command --- + +var diffFormat string + +var diffCmd = &cobra.Command{ + Use: "diff ", + Short: "Compare two scan results and show what changed", + Args: cobra.ExactArgs(2), + RunE: runDiff, +} + var scanCmd = &cobra.Command{ Use: "scan", Short: "Scan AWS account for GPU waste", @@ -71,18 +93,37 @@ func init() { scanCmd.Flags().BoolVar(&scanSkipCosts, "skip-costs", false, "Skip Cost Explorer data enrichment") scanCmd.Flags().StringVar(&scanKubeconfig, "kubeconfig", "", "Path to kubeconfig file (default: ~/.kube/config)") scanCmd.Flags().StringVar(&scanKubeContext, "kube-context", "", "Kubernetes context to use (default: current context)") + scanCmd.Flags().StringVar(&scanPromURL, "prom-url", "", "Prometheus URL for GPU metrics on EC2 and K8s (e.g., https://prometheus.corp.example.com)") + scanCmd.Flags().StringVar(&scanPromEndpoint, "prom-endpoint", "", "In-cluster Prometheus service as namespace/service:port (e.g., monitoring/prometheus:9090)") scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") + scanCmd.Flags().StringSliceVar(&scanTargets, "targets", nil, "Account IDs to scan (comma-separated)") + scanCmd.Flags().StringVar(&scanRole, "role", "", "IAM role name to assume in each target") + scanCmd.Flags().StringVar(&scanExternalID, "external-id", "", "STS external ID for cross-account role assumption") + scanCmd.Flags().BoolVar(&scanOrg, "org", false, "Auto-discover all accounts from AWS Organizations") + 
scanCmd.Flags().BoolVar(&scanSkipSelf, "skip-self", false, "Exclude the caller's own account from the scan") + scanCmd.MarkFlagsMutuallyExclusive("targets", "org") + + diffCmd.Flags().StringVar(&diffFormat, "format", "table", "Output format: table, json") rootCmd.AddCommand(scanCmd) + rootCmd.AddCommand(diffCmd) rootCmd.AddCommand(pricingCmd) rootCmd.AddCommand(iamPolicyCmd) rootCmd.AddCommand(versionCmd) } func runScan(cmd *cobra.Command, args []string) error { + if scanPromURL != "" && scanPromEndpoint != "" { + return fmt.Errorf("--prom-url and --prom-endpoint are mutually exclusive") + } + ctx := context.Background() + if (len(scanTargets) > 0 || scanOrg) && scanRole == "" { + return fmt.Errorf("--role is required when using --targets or --org") + } + opts := awsprovider.DefaultScanOptions() opts.Profile = scanProfile opts.Regions = scanRegions @@ -92,9 +133,17 @@ func runScan(cmd *cobra.Command, args []string) error { opts.SkipCosts = scanSkipCosts opts.ExcludeTags = parseExcludeTags(scanExcludeTags) opts.MinUptimeDays = scanMinUptimeDays - + opts.Targets = scanTargets + opts.Role = scanRole + opts.ExternalID = scanExternalID + opts.PromURL = scanPromURL + opts.OrgScan = scanOrg + opts.SkipSelf = scanSkipSelf + + awsAvailable := true result, err := awsprovider.Scan(ctx, opts) if err != nil { + awsAvailable = false if scanSkipK8s { return fmt.Errorf("scan failed: %w", err) } @@ -106,13 +155,18 @@ func runScan(cmd *cobra.Command, args []string) error { // Kubernetes API scan if !scanSkipK8s { k8sOpts := k8sprovider.ScanOptions{ - Kubeconfig: scanKubeconfig, - Context: scanKubeContext, + Kubeconfig: scanKubeconfig, + Context: scanKubeContext, + PromURL: scanPromURL, + PromEndpoint: scanPromEndpoint, } k8sInstances, err := k8sprovider.Scan(ctx, k8sOpts) if err != nil { fmt.Fprintf(os.Stderr, " warning: Kubernetes scan failed: %v\n", err) } else if len(k8sInstances) > 0 { + if !scanSkipMetrics { + enrichK8sGPUMetrics(ctx, k8sInstances, k8sOpts, opts, awsAvailable) 
+ } analysis.AnalyzeAll(k8sInstances) result.Instances = append(result.Instances, k8sInstances...) result.Summary = awsprovider.BuildSummary(result.Instances) @@ -144,6 +198,40 @@ func runScan(cmd *cobra.Command, args []string) error { return nil } +func runDiff(cmd *cobra.Command, args []string) error { + old, err := loadScanResult(args[0]) + if err != nil { + return fmt.Errorf("loading old scan: %w", err) + } + new, err := loadScanResult(args[1]) + if err != nil { + return fmt.Errorf("loading new scan: %w", err) + } + + result := diff.Compare(old, new) + + switch strings.ToLower(diffFormat) { + case "json": + return output.FormatDiffJSON(os.Stdout, result) + default: + output.FormatDiffTable(os.Stdout, result) + } + + return nil +} + +func loadScanResult(path string) (*models.ScanResult, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var result models.ScanResult + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + return &result, nil +} + // --- pricing command --- var pricingGPU string @@ -222,6 +310,7 @@ var iamPolicyCmd = &cobra.Command{ "ec2:DescribeInstances", "ec2:DescribeInstanceTypes", "ec2:DescribeRegions", + "ec2:DescribeSpotPriceHistory", }, "Resource": "*", }, @@ -273,8 +362,21 @@ var iamPolicyCmd = &cobra.Command{ }, "Resource": "*", }, + { + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader", + }, + { + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*", + }, }, } + fmt.Fprintln(os.Stdout, "// The last two statements (CrossAccount, Organizations) are only needed for --targets or --org scanning.") enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") enc.Encode(policy) @@ -300,3 +402,60 @@ func parseExcludeTags(raw []string) map[string]string { } return tags } + +func enrichK8sGPUMetrics(ctx 
context.Context, instances []models.GPUInstance, k8sOpts k8sprovider.ScanOptions, awsOpts awsprovider.ScanOptions, awsAvailable bool) { + // Source 1: CloudWatch Container Insights (skip if AWS creds unavailable) + if awsAvailable && len(instances) > 0 && instances[0].ClusterName != "" { + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if awsOpts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(awsOpts.Profile)) + } + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) + if err == nil { + region := instances[0].Region + if region == "" { + region = "us-east-1" + } + cfg.Region = region + cwClient := cloudwatch.NewFromConfig(cfg) + fmt.Fprintf(os.Stderr, " Enriching K8s GPU metrics via CloudWatch Container Insights...\n") + awsprovider.EnrichK8sGPUMetrics(ctx, cwClient, instances, instances[0].ClusterName, awsprovider.DefaultMetricWindow) + } + } + + // Source 2: DCGM exporter scrape + remaining := 0 + for _, inst := range instances { + if inst.Source == models.SourceK8sNode && inst.AvgGPUUtilization == nil { + remaining++ + } + } + if remaining > 0 { + client, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + k8sprovider.EnrichDCGMMetrics(ctx, client, instances) + } + } + + // Source 3: Prometheus query + remaining = 0 + for _, inst := range instances { + if inst.Source == models.SourceK8sNode && inst.AvgGPUUtilization == nil { + remaining++ + } + } + if remaining > 0 && (k8sOpts.PromURL != "" || k8sOpts.PromEndpoint != "") { + var client k8sprovider.K8sClient + if k8sOpts.PromEndpoint != "" { + c, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + client = c + } + } + promOpts := k8sprovider.PrometheusOptions{ + URL: k8sOpts.PromURL, + Endpoint: k8sOpts.PromEndpoint, + } + k8sprovider.EnrichPrometheusMetrics(ctx, client, instances, promOpts) + } +} diff --git a/docs/specs/2026-04-18-multi-target-scanning-design.md 
b/docs/specs/2026-04-18-multi-target-scanning-design.md new file mode 100644 index 0000000..9c2fd34 --- /dev/null +++ b/docs/specs/2026-04-18-multi-target-scanning-design.md @@ -0,0 +1,374 @@ +# Multi-Target Scanning + +**Date:** April 18, 2026 +**Status:** Draft + +--- + +## Summary + +Add the ability to scan multiple AWS accounts (and eventually GCP projects / Azure subscriptions) in a single `gpuaudit scan` invocation. Uses STS AssumeRole to obtain credentials for each target, scans them all in parallel, and merges results into a single flat output with per-target sub-summaries. + +Zero breaking changes — existing single-account behavior is the default. + +--- + +## CLI Interface + +### New flags on `gpuaudit scan` + +| Flag | Type | Description | +|------|------|-------------| +| `--targets` | `[]string` | Comma-separated list of account IDs to scan | +| `--role` | `string` | IAM role name to assume in each target (required with `--targets` or `--org`) | +| `--org` | `bool` | Auto-discover all accounts from AWS Organizations | +| `--external-id` | `string` | STS external ID for cross-account role assumption (optional) | +| `--skip-self` | `bool` | Exclude the caller's own account from the scan | + +### Constraints + +- `--targets` and `--org` are mutually exclusive. +- `--role` is required when `--targets` or `--org` is set. +- No `--targets` or `--org` means scan the caller's account only (current behavior, no changes). +- The caller's own account is included by default unless `--skip-self` is set. 
+ +### Examples + +```bash +# Current behavior (unchanged) +gpuaudit scan + +# Scan 3 specific accounts +gpuaudit scan --targets 111111111111,222222222222,333333333333 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Org scan, exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID for extra security +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Flag naming rationale + +Flags use provider-neutral names (`--targets` not `--accounts`, `--role` not `--assume-role`) so that when GCP and Azure support lands, the same flags work: targets are project IDs or subscription IDs, role is a service account or principal name. No renaming, no backward-compatibility concerns. + +--- + +## Architecture + +### New file: `internal/providers/aws/multiaccount.go` + +Contains: + +- `Target` struct: `{AccountID string, Config aws.Config}` +- `ResolveTargets(ctx, cfg, opts) ([]Target, []TargetError)`: + - No `--targets`/`--org`: returns caller's account with existing config. + - `--targets`: calls `sts:AssumeRole` for each account ID, returns credentials. Failed assumptions are collected as `TargetError`, not fatal. + - `--org`: calls `organizations:ListAccounts`, filters to active accounts, then assumes role in each. + - Caller's own account is included (with original config, no AssumeRole needed) unless `--skip-self`. +- `TargetError` struct: `{AccountID string, Err error}` + +### Changes to `ScanOptions` + +```go +type ScanOptions struct { + // ... existing fields ... 
+ Targets []string // account IDs to scan + Role string // role name to assume + ExternalID string // STS external ID + OrgScan bool // auto-discover from Organizations + SkipSelf bool // exclude caller's account +} +``` + +### Changes to `Scan()` + +Current flow: +``` +load config → get account ID → scan regions in parallel → merge → analyze → output +``` + +New flow: +``` +load config → ResolveTargets() → for each target (parallel): + for each region (parallel): + scanRegion(ctx, target.Config, target.AccountID, region, opts) +→ merge all instances into flat list +→ filter, analyze, enrich (unchanged) +→ BuildSummary (global + per-target sub-summaries) +→ output +``` + +All targets are scanned in parallel. Within each target, all regions are scanned in parallel (same as today). + +### Error handling: best-effort + +- `ResolveTargets` returns both successful targets and a list of `TargetError`s. +- Scan continues for all resolvable targets. +- Per-region errors within a target are handled as today (warn and continue). +- Target-level errors are surfaced in the output (see Output section). +- Exit code: 0 if at least one target was scanned (failed targets are reported as warnings, not as a failure); non-zero only if every target failed. + +### Unchanged components + +- Analysis rules — operate per-instance, already provider-agnostic. +- Diff command — matches by `instance_id`, globally unique across accounts. +- `GPUInstance` model — already has `AccountID` field. +- Pricing database — account-independent. 
+ +--- + +## Model Changes + +### `ScanResult` + +```go +type ScanResult struct { + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` // caller's account (kept for backward compat) + Targets []string `json:"targets,omitempty"` // NEW: all scanned target IDs + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` // NEW: per-target breakdown + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` // NEW: failed targets +} + +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} +``` + +New fields use `omitempty` — single-account scans produce identical JSON to today. + +--- + +## Output Changes + +### Table + +When multiple targets are present, two additions: + +1. **"By Target" summary table** after the global summary: + +``` + By Target + ┌──────────────┬───────────┬───────────┬───────────┬───────┐ + │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │ + ├──────────────┼───────────┼───────────┼───────────┼───────┤ + │ 111111111111 │ 31 │ $142,000 │ $38,000 │ 27% │ + │ 222222222222 │ 12 │ $35,400 │ $4,200 │ 12% │ + └──────────────┴───────────┴───────────┴───────────┴───────┘ +``` + +2. **"Target" column** in instance detail tables. + +Single-target scans look identical to today. + +### JSON + +New `targets`, `target_summaries`, and `target_errors` fields as shown in the model above. Omitted when empty. 
+ +### Markdown + +Per-target summary section added when multiple targets present. + +### Slack + +Per-target summary block added when multiple targets present. + +### Errors + +When targets fail, a warnings section appears in all formats: + +``` + Warnings + ✗ 444444444444 — AssumeRole failed: AccessDenied + ✗ 555555555555 — role "gpuaudit-reader" not found in account +``` + +--- + +## IAM Policy Updates + +### `gpuaudit iam-policy` additions + +Add two new statements to the generated policy: + +```json +{ + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader" +}, +{ + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*" +} +``` + +These are printed as a separate "Multi-Account Permissions" section in the `iam-policy` output, with a comment explaining they're only needed for `--targets` or `--org` scanning. Always included in the output — users can ignore them if they only scan a single account. + +--- + +## Cross-Account Role Setup + +### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +variable "external_id" { + description = "External ID for AssumeRole (optional but recommended)" + type = string + default = "" +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + Condition = var.external_id != "" ? 
{ + StringEquals = { "sts:ExternalId" = var.external_id } + } : {} + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Sid = "EC2ReadOnly" + Effect = "Allow" + Action = ["ec2:DescribeInstances", "ec2:DescribeInstanceTypes", "ec2:DescribeRegions"] + Resource = "*" + }, + { + Sid = "SageMakerReadOnly" + Effect = "Allow" + Action = ["sagemaker:ListEndpoints", "sagemaker:DescribeEndpoint", "sagemaker:DescribeEndpointConfig"] + Resource = "*" + }, + { + Sid = "EKSReadOnly" + Effect = "Allow" + Action = ["eks:ListClusters", "eks:ListNodegroups", "eks:DescribeNodegroup"] + Resource = "*" + }, + { + Sid = "CloudWatchReadOnly" + Effect = "Allow" + Action = ["cloudwatch:GetMetricData", "cloudwatch:GetMetricStatistics", "cloudwatch:ListMetrics"] + Resource = "*" + }, + { + Sid = "CostExplorerReadOnly" + Effect = "Allow" + Action = ["ce:GetCostAndUsage", "ce:GetReservationUtilization", "ce:GetSavingsPlansUtilization"] + Resource = "*" + }, + { + Sid = "PricingReadOnly" + Effect = "Allow" + Action = ["pricing:GetProducts"] + Resource = "*" + } + ] + }) +} +``` + +### CloudFormation (for StackSet deployment across all accounts) + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Description: gpuaudit cross-account reader role + +Parameters: + ManagementAccountId: + Type: String + Description: Account ID where gpuaudit runs + ExternalId: + Type: String + Description: External ID for AssumeRole + Default: "" + +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - 
ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` + +Recommended deployment: use CloudFormation StackSets to deploy the role to all member accounts from the management account. Note: the template declares an `ExternalId` parameter, but the trust policy does not reference it yet; add an `sts:ExternalId` condition (as in the Terraform example) before relying on `--external-id`. + +--- + +## Testing + +- **Unit tests for `ResolveTargets`**: mock STS and Organizations clients, verify correct target list for each mode (explicit, org, skip-self, mixed failures). +- **Unit tests for `BuildSummary`**: verify per-target summaries compute correctly with instances from multiple accounts. +- **Unit tests for output formatters**: verify "By Target" table and Target column appear only when multiple targets present. +- **Integration test pattern**: test the full `Scan` flow with mocked AWS clients for 2-3 accounts, verify merged output. diff --git a/docs/specs/2026-04-19-k8s-gpu-metrics-design.md b/docs/specs/2026-04-19-k8s-gpu-metrics-design.md new file mode 100644 index 0000000..ed8c3d7 --- /dev/null +++ b/docs/specs/2026-04-19-k8s-gpu-metrics-design.md @@ -0,0 +1,149 @@ +# K8s GPU Metrics Collection + +## Goal + +Collect GPU utilization metrics for Kubernetes GPU nodes discovered by gpuaudit, using a per-node fallback chain of three sources: CloudWatch Container Insights, DCGM exporter scrape, and Prometheus query. Enable utilization-based waste detection for K8s GPU nodes (currently limited to allocation-based detection only). + +## Architecture + +Three metrics sources, tried in priority order **per node** (stop at the first source that returns data for a given node): + +1. **CloudWatch Container Insights** — AWS API call, no in-cluster access needed beyond what we already have. +2. 
**DCGM exporter scrape** — probe port 9400 on dcgm-exporter pods via K8s API proxy. +3. **Prometheus query** — query a user-configured Prometheus endpoint for historical GPU metrics. + +All three populate the same existing fields: `GPUInstance.AvgGPUUtilization` and `GPUInstance.AvgGPUMemUtilization`. + +## Data Flow + +``` +1. AWS scan → ScanResult (EC2, SageMaker, EKS) +2. K8s scan → []GPUInstance (nodes + allocation) +3. Enrich K8s GPU metrics (fallback chain): + a. CloudWatch Container Insights (if AWS creds available, !skipMetrics) + b. DCGM scrape via K8s API proxy (for nodes still missing metrics) + c. Prometheus query (for remaining nodes, if --prom-url or --prom-endpoint set) +4. AnalyzeAll on K8s instances +5. Merge into result +``` + +Steps 3a through 3c each skip nodes that already have `AvgGPUUtilization` populated by a prior step. + +## Source 1: CloudWatch Container Insights + +Requires the CloudWatch Observability EKS add-on to be installed in the cluster. If not installed, the query returns empty (not an error) and we fall through. + +**Metrics queried:** +- `node_gpu_utilization` (Average) — maps to `AvgGPUUtilization` +- `node_gpu_memory_utilization` (Average) — maps to `AvgGPUMemUtilization` + +**Namespace:** `ContainerInsights` + +**Dimensions:** `ClusterName` + `InstanceId` + +**Implementation:** New function `EnrichK8sGPUMetrics(ctx, client CloudWatchClient, instances []GPUInstance, clusterName string, window MetricWindow)` in `internal/providers/aws/cloudwatch.go`, following the same pattern as `EnrichEC2Metrics` and `EnrichSageMakerMetrics`. + +**Prerequisites per node:** The node must have an EC2 instance ID (extracted from `providerID`). Non-AWS nodes are skipped for this source. + +**Wiring:** Called from `main.go` after the K8s scan returns instances, passing the CloudWatch client from the AWS config. Only called when AWS credentials are available and `!skipMetrics`. 
+ +## Source 2: DCGM Exporter Scrape + +Auto-detected, no user configuration needed. + +**Discovery:** List pods across all namespaces matching labels `app=nvidia-dcgm-exporter` or `app.kubernetes.io/name=dcgm-exporter`. If no pods found, log `"DCGM exporter not detected, skipping"` and fall through to Prometheus. + +**Scraping:** For each GPU node still missing metrics, find the dcgm-exporter pod on that node (match by `pod.Spec.NodeName`), then scrape `/metrics` on port 9400 via the K8s API proxy (`ProxyGet`). + +**Metrics parsed:** +- `DCGM_FI_DEV_GPU_UTIL` — maps to `AvgGPUUtilization` +- `DCGM_FI_DEV_MEM_COPY_UTIL` — maps to `AvgGPUMemUtilization` + +These are point-in-time values, not historical averages. The analysis rule's confidence (0.85 vs 0.9) accounts for this lower fidelity. + +**Prometheus text format parsing:** Use `prometheus/common/expfmt` to parse the scrape response. + +**K8s client extension:** Add `ProxyGet(ctx, namespace, podName, port, path string) ([]byte, error)` to the `K8sClient` interface. Wraps `clientset.CoreV1().Pods(ns).ProxyGet()`. + +**Stderr output:** +``` + Probing DCGM exporter on GPU nodes... + DCGM: got GPU metrics for 3 of 5 remaining nodes +``` + +## Source 3: Prometheus Query + +Only attempted when `--prom-url` or `--prom-endpoint` is provided. No auto-discovery. + +**CLI flags:** +- `--prom-url` — full URL to a Prometheus-compatible API (e.g., `https://prometheus.corp.example.com`, AMP endpoint, Grafana Cloud). Hit directly via HTTP. +- `--prom-endpoint` — in-cluster service as `namespace/service:port` (e.g., `monitoring/prometheus:9090`). Proxied through the K8s API server. + +These flags are mutually exclusive. Error if both are set. + +**Query:** Batch all remaining nodes into one PromQL query: +``` +avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"node1|node2|..."}[7d]) +``` +And similarly for `DCGM_FI_DEV_MEM_COPY_UTIL`. + +**API:** HTTP GET to `/api/v1/query`, parse the standard Prometheus JSON response. 
No client library — plain `net/http` for direct URLs, K8s API proxy for in-cluster endpoints. + +**Stderr output:** +``` + Querying Prometheus at monitoring/prometheus:9090... + Prometheus: got GPU metrics for 2 of 3 remaining nodes +``` + +## Analysis Rule + +New rule `ruleK8sLowGPUUtil` in `internal/analysis/rules.go`: + +- **Source filter:** `SourceK8sNode` only +- **Guard:** `AvgGPUUtilization != nil` (skip nodes where no metrics were collected) +- **Threshold:** average GPU utilization < 10% +- **Signal type:** `low_utilization` +- **Severity:** Critical +- **Confidence:** 0.85 +- **Recommendation:** "GPU utilization averaging X%. Consider bin-packing more workloads, downsizing, or removing from the node pool." +- **Savings estimate:** `MonthlyCost * 0.8` (same rough estimate as SageMaker equivalent) + +**Interplay with `ruleK8sUnallocatedGPU`:** Both rules can fire on the same node. Unallocated detects zero pod scheduling (allocation-based). Low-util detects pods that are scheduled but barely using the GPU (utilization-based). Different problems, different fixes. 
+ +## File Changes + +- **Modify:** `internal/providers/aws/cloudwatch.go` — add `EnrichK8sGPUMetrics()` +- **Create:** `internal/providers/k8s/metrics.go` — DCGM scraping, Prometheus querying, fallback orchestration +- **Create:** `internal/providers/k8s/metrics_test.go` — tests for DCGM and Prometheus paths +- **Modify:** `internal/providers/k8s/discover.go` — extend `K8sClient` interface with `ProxyGet` (DCGM pod discovery uses existing `ListPods` with label selector) +- **Modify:** `internal/providers/k8s/scanner.go` — wire metrics enrichment into the K8s scan, accept new options +- **Modify:** `internal/analysis/rules.go` — add `ruleK8sLowGPUUtil` +- **Modify:** `internal/analysis/rules_test.go` — tests for the new rule +- **Modify:** `cmd/gpuaudit/main.go` — add `--prom-url` and `--prom-endpoint` flags, wire CloudWatch enrichment for K8s instances + +## Error Handling + +- **CloudWatch returns empty:** Not an error. Container Insights add-on probably not installed. Fall through to DCGM. +- **No EC2 instance ID on a node:** Skip CW enrichment for that node (non-AWS or providerID not set). +- **No dcgm-exporter pods found:** Log on stderr, fall through to Prometheus. +- **DCGM scrape fails for a node:** Warn on stderr, continue with other nodes. Don't fail the scan. +- **Prometheus endpoint unreachable:** Warn on stderr, continue without metrics for remaining nodes. +- **Both `--prom-url` and `--prom-endpoint` set:** Return an error at flag validation time. + +## New Dependencies + +- `prometheus/common/expfmt` — for parsing Prometheus text format from DCGM exporter scrapes. Small, well-established library. + +## IAM Policy + +No new IAM permissions required. `EnrichK8sGPUMetrics` uses the existing `cloudwatch:GetMetricData` permission already in the IAM policy output. + +## RBAC + +The K8s API proxy calls (`ProxyGet` to pods) require the `pods/proxy` resource permission. 
For DCGM scraping: +``` +- apiGroups: [""] + resources: ["pods/proxy"] + verbs: ["get"] +``` +This should be documented and added to any RBAC guide. diff --git a/docs/superpowers/plans/2026-04-18-multi-target-scanning.md b/docs/superpowers/plans/2026-04-18-multi-target-scanning.md new file mode 100644 index 0000000..ebde5e3 --- /dev/null +++ b/docs/superpowers/plans/2026-04-18-multi-target-scanning.md @@ -0,0 +1,1537 @@ +# Multi-Target Scanning Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Enable gpuaudit to scan multiple AWS accounts in a single invocation via STS AssumeRole, with optional Organizations auto-discovery. + +**Architecture:** New `multiaccount.go` handles target resolution (explicit list or Organizations API) and credential assumption. The existing `Scan()` function is refactored to accept multiple targets and scan them all in parallel. Output formatters gain per-target summary sections when multiple targets are present. All new fields use `omitempty` so single-account scans produce identical output to today. 
+ +**Tech Stack:** Go 1.24, AWS SDK v2 (STS, Organizations), cobra CLI, standard library testing + +--- + +## File Map + +| File | Action | Responsibility | +|------|--------|---------------| +| `internal/providers/aws/multiaccount.go` | Create | `Target` struct, `ResolveTargets()`, `TargetError` type, STS AssumeRole + Organizations list | +| `internal/providers/aws/multiaccount_test.go` | Create | Tests for `ResolveTargets()` with mock STS/Org clients | +| `internal/models/models.go` | Modify | Add `TargetSummary`, `TargetErrorInfo` types; add new fields to `ScanResult` | +| `internal/providers/aws/scanner.go` | Modify | Refactor `Scan()` to use `ResolveTargets()` and scan all targets in parallel | +| `cmd/gpuaudit/main.go` | Modify | Add `--targets`, `--role`, `--org`, `--external-id`, `--skip-self` flags; wire into `ScanOptions` | +| `internal/providers/aws/summary.go` | Create | Extract `BuildSummary` from scanner.go, add `BuildTargetSummaries()` | +| `internal/providers/aws/summary_test.go` | Create | Tests for per-target summary computation | +| `internal/output/table.go` | Modify | Add "By Target" summary table and "Target" column when multiple targets | +| `internal/output/markdown.go` | Modify | Add per-target summary section when multiple targets | +| `internal/output/slack.go` | Modify | Add per-target summary block when multiple targets | +| `go.mod` | Modify | Add `organizations` SDK dependency | + +--- + +### Task 1: Add model types for multi-target results + +**Files:** +- Modify: `internal/models/models.go` + +- [ ] **Step 1: Add `TargetSummary` and `TargetErrorInfo` types and new `ScanResult` fields** + +Add to `internal/models/models.go` after the `ScanSummary` struct: + +```go +// TargetSummary provides per-target aggregate statistics. 
+type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +// TargetErrorInfo describes a target that failed to scan. +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} +``` + +Add three new fields to `ScanResult`: + +```go +type ScanResult struct { + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` + Targets []string `json:"targets,omitempty"` + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` +} +``` + +- [ ] **Step 2: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success (new types are additive, omitempty means no output change) + +- [ ] **Step 3: Run existing tests to confirm nothing broke** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 4: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/models/models.go +git commit -m "Add TargetSummary and TargetErrorInfo model types for multi-target scanning" +``` + +--- + +### Task 2: Extract `BuildSummary` and add `BuildTargetSummaries` + +**Files:** +- Create: `internal/providers/aws/summary.go` +- Create: `internal/providers/aws/summary_test.go` +- Modify: `internal/providers/aws/scanner.go` (remove `BuildSummary` — it moves to summary.go) + +- [ ] **Step 1: Write the failing test for `BuildTargetSummaries`** + +Create `internal/providers/aws/summary_test.go`: + 
+```go +package aws + +import ( + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestBuildTargetSummaries_MultipleAccounts(t *testing.T) { + instances := []models.GPUInstance{ + { + AccountID: "111111111111", + MonthlyCost: 1000, + EstimatedSavings: 500, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityCritical}}, + }, + { + AccountID: "111111111111", + MonthlyCost: 2000, + EstimatedSavings: 0, + }, + { + AccountID: "222222222222", + MonthlyCost: 3000, + EstimatedSavings: 1000, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityWarning}}, + }, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 2 { + t.Fatalf("expected 2 target summaries, got %d", len(summaries)) + } + + // Find each target + var s1, s2 *models.TargetSummary + for i := range summaries { + switch summaries[i].Target { + case "111111111111": + s1 = &summaries[i] + case "222222222222": + s2 = &summaries[i] + } + } + + if s1 == nil || s2 == nil { + t.Fatal("missing target summaries") + } + + if s1.TotalInstances != 2 { + t.Errorf("acct1: expected 2 instances, got %d", s1.TotalInstances) + } + if s1.TotalMonthlyCost != 3000 { + t.Errorf("acct1: expected $3000 cost, got $%.0f", s1.TotalMonthlyCost) + } + if s1.TotalEstimatedWaste != 500 { + t.Errorf("acct1: expected $500 waste, got $%.0f", s1.TotalEstimatedWaste) + } + if s1.CriticalCount != 1 { + t.Errorf("acct1: expected 1 critical, got %d", s1.CriticalCount) + } + + if s2.TotalInstances != 1 { + t.Errorf("acct2: expected 1 instance, got %d", s2.TotalInstances) + } + if s2.WarningCount != 1 { + t.Errorf("acct2: expected 1 warning, got %d", s2.WarningCount) + } +} + +func TestBuildTargetSummaries_SingleAccount(t *testing.T) { + instances := []models.GPUInstance{ + {AccountID: "111111111111", MonthlyCost: 1000}, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } +} + +func 
TestBuildTargetSummaries_Empty(t *testing.T) { + summaries := BuildTargetSummaries(nil) + + if len(summaries) != 0 { + t.Fatalf("expected 0 summaries for nil input, got %d", len(summaries)) + } +} +``` + +- [ ] **Step 2: Run the test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestBuildTargetSummaries -v` +Expected: FAIL (function not defined) + +- [ ] **Step 3: Create `summary.go` with `BuildSummary` (moved from scanner.go) and `BuildTargetSummaries`** + +Create `internal/providers/aws/summary.go`: + +```go +package aws + +import ( + "sort" + + "github.com/gpuaudit/cli/internal/models" +) + +// BuildSummary computes aggregate statistics for a set of GPU instances. +func BuildSummary(instances []models.GPUInstance) models.ScanSummary { + s := models.ScanSummary{ + TotalInstances: len(instances), + } + + for _, inst := range instances { + s.TotalMonthlyCost += inst.MonthlyCost + s.TotalEstimatedWaste += inst.EstimatedSavings + + maxSeverity := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSeverity = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { + maxSeverity = models.SeverityWarning + } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { + maxSeverity = models.SeverityInfo + } + } + + switch maxSeverity { + case models.SeverityCritical: + s.CriticalCount++ + case models.SeverityWarning: + s.WarningCount++ + case models.SeverityInfo: + s.InfoCount++ + default: + s.HealthyCount++ + } + } + + if s.TotalMonthlyCost > 0 { + s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 + } + + return s +} + +// BuildTargetSummaries computes per-target breakdowns from a flat instance list. 
+func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary { + if len(instances) == 0 { + return nil + } + + byTarget := make(map[string][]models.GPUInstance) + for _, inst := range instances { + byTarget[inst.AccountID] = append(byTarget[inst.AccountID], inst) + } + + summaries := make([]models.TargetSummary, 0, len(byTarget)) + for target, insts := range byTarget { + ts := models.TargetSummary{ + Target: target, + TotalInstances: len(insts), + } + for _, inst := range insts { + ts.TotalMonthlyCost += inst.MonthlyCost + ts.TotalEstimatedWaste += inst.EstimatedSavings + + maxSev := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSev = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { + maxSev = models.SeverityWarning + } + } + switch maxSev { + case models.SeverityCritical: + ts.CriticalCount++ + case models.SeverityWarning: + ts.WarningCount++ + } + } + if ts.TotalMonthlyCost > 0 { + ts.WastePercent = (ts.TotalEstimatedWaste / ts.TotalMonthlyCost) * 100 + } + summaries = append(summaries, ts) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].TotalMonthlyCost > summaries[j].TotalMonthlyCost + }) + + return summaries +} +``` + +- [ ] **Step 4: Remove `BuildSummary` from `scanner.go` (keep `matchesExcludeTags`)** + +In `internal/providers/aws/scanner.go`, delete the `BuildSummary` function (lines 235-272) and keep `matchesExcludeTags`. `BuildSummary` is now in `summary.go`. No import changes needed since both files are in the same package. + +- [ ] **Step 5: Run tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...
-v` +Expected: all pass, including the new `TestBuildTargetSummaries_*` tests + +- [ ] **Step 6: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/summary.go internal/providers/aws/summary_test.go internal/providers/aws/scanner.go +git commit -m "Extract BuildSummary to summary.go and add BuildTargetSummaries" +``` + +--- + +### Task 3: Implement `ResolveTargets` with STS AssumeRole + +**Files:** +- Create: `internal/providers/aws/multiaccount.go` +- Create: `internal/providers/aws/multiaccount_test.go` +- Modify: `go.mod` (add organizations dependency) + +- [ ] **Step 1: Add the Organizations SDK dependency** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go get github.com/aws/aws-sdk-go-v2/service/organizations` + +- [ ] **Step 2: Write failing tests for `ResolveTargets`** + +Create `internal/providers/aws/multiaccount_test.go`: + +```go +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +type mockSTSClient struct { + identity *sts.GetCallerIdentityOutput + roles map[string]*sts.AssumeRoleOutput // keyed by account ID + failAccts map[string]error +} + +func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return m.identity, nil +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Extract account ID from ARN: arn:aws:iam::<account-id>:role/<role-name> + arn := aws.ToString(params.RoleArn) + // Simple extraction: find the account ID between the 4th and 5th colons + acct := "" + colons := 0 + for i, c := range arn { + if c == ':' { + colons++ + if colons == 4 { + rest := arn[i+1:] + for j, r := range rest
{ + if r == ':' { + acct = rest[:j] + break + } + } + break + } + } + } + if err, ok := m.failAccts[acct]; ok { + return nil, err + } + if out, ok := m.roles[acct]; ok { + return out, nil + } + return nil, fmt.Errorf("no role for account %s", acct) +} + +type mockOrgClient struct { + accounts []orgtypes.Account +} + +func (m *mockOrgClient) ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) { + return &organizations.ListAccountsOutput{Accounts: m.accounts}, nil +} + +func TestResolveTargets_NoTargets_ReturnsSelf(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, ScanOptions{}) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target, got %d", len(targets)) + } + if targets[0].AccountID != "999999999999" { + t.Errorf("expected account 999999999999, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_ExplicitTargets(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + "222222222222": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK2"), SecretAccessKey: aws.String("SK2"), SessionToken: aws.String("ST2"), + }}, + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "gpuaudit-reader", + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + // 2 explicit + self = 3 + 
if len(targets) != 3 { + t.Fatalf("expected 3 targets (2 explicit + self), got %d", len(targets)) + } +} + +func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111"}, + Role: "gpuaudit-reader", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (skip self), got %d", len(targets)) + } + if targets[0].AccountID != "111111111111" { + t.Errorf("expected 111111111111, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_PartialFailure(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + failAccts: map[string]error{ + "222222222222": fmt.Errorf("AccessDenied"), + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "gpuaudit-reader", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if errs[0].AccountID != "222222222222" { + t.Errorf("expected error for 222222222222, got %s", errs[0].AccountID) + } + if len(targets) != 1 { + t.Fatalf("expected 1 successful target, got %d", len(targets)) + } +} + +func 
TestResolveTargets_OrgDiscovery(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + } + + orgClient := &mockOrgClient{ + accounts: []orgtypes.Account{ + {Id: aws.String("999999999999"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("111111111111"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("333333333333"), Status: orgtypes.AccountStatusSuspended}, + }, + } + + opts := ScanOptions{ + OrgScan: true, + Role: "gpuaudit-reader", + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, orgClient, opts) + + // 999 (self, no assume) + 111 (assumed) = 2 targets; 333 is suspended so skipped + // Note: 999 is self so not assumed; 111 is assumed successfully + if len(targets) != 2 { + t.Fatalf("expected 2 targets (self + 1 active non-self), got %d", len(targets)) + } + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } +} +``` + +- [ ] **Step 3: Run test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestResolveTargets -v` +Expected: FAIL (function and types not defined) + +- [ ] **Step 4: Implement `multiaccount.go`** + +Create `internal/providers/aws/multiaccount.go`: + +```go +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) + +// Target represents a resolved scan target with its credentials. 
+type Target struct { + AccountID string + Config aws.Config +} + +// TargetError records a target that failed credential resolution. +type TargetError struct { + AccountID string + Err error +} + +// STSClient is the subset of the STS API we need. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// OrgClient is the subset of the Organizations API we need. +type OrgClient interface { + ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) +} + +// ResolveTargets determines which accounts to scan and obtains credentials for each. +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + } + selfAccount := aws.ToString(identity.Account) + + // No multi-target flags: return self only + if len(opts.Targets) == 0 && !opts.OrgScan { + return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + } + + // Determine account IDs to scan + var accountIDs []string + if opts.OrgScan { + discovered, err := discoverOrgAccounts(ctx, orgClient) + if err != nil { + return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", err)}} + } + accountIDs = discovered + } else { + accountIDs = opts.Targets + } + + var targets []Target + var targetErrors []TargetError + + // Include self unless skipped + if !opts.SkipSelf { + targets = append(targets, Target{AccountID: selfAccount, Config: baseCfg}) + } + + // Assume role in 
each non-self account + for _, acctID := range accountIDs { + if acctID == selfAccount { + continue // already included as self (or skipped) + } + + cfg, err := assumeRole(ctx, baseCfg, stsClient, acctID, opts.Role, opts.ExternalID) + if err != nil { + targetErrors = append(targetErrors, TargetError{AccountID: acctID, Err: err}) + continue + } + targets = append(targets, Target{AccountID: acctID, Config: cfg}) + } + + return targets, targetErrors +} + +func discoverOrgAccounts(ctx context.Context, client OrgClient) ([]string, error) { + var accounts []string + var nextToken *string + + for { + out, err := client.ListAccounts(ctx, &organizations.ListAccountsInput{ + NextToken: nextToken, + }) + if err != nil { + return nil, err + } + for _, acct := range out.Accounts { + if acct.Status == orgtypes.AccountStatusActive { + accounts = append(accounts, aws.ToString(acct.Id)) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + return accounts, nil +} + +func assumeRole(ctx context.Context, baseCfg aws.Config, stsClient STSClient, accountID, roleName, externalID string) (aws.Config, error) { + roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleName) + + input := &sts.AssumeRoleInput{ + RoleArn: &roleArn, + RoleSessionName: aws.String("gpuaudit"), + } + if externalID != "" { + input.ExternalId = &externalID + } + + out, err := stsClient.AssumeRole(ctx, input) + if err != nil { + return aws.Config{}, fmt.Errorf("AssumeRole %s: %w", roleArn, err) + } + + creds := out.Credentials + cfg := baseCfg.Copy() + cfg.Credentials = credentials.NewStaticCredentialsProvider( + aws.ToString(creds.AccessKeyId), + aws.ToString(creds.SecretAccessKey), + aws.ToString(creds.SessionToken), + ) + + return cfg, nil +} +``` + +- [ ] **Step 5: Fix the test import — add `ststypes` import** + +The tests reference `ststypes.Credentials`. Add this import to `multiaccount_test.go`: + +```go +import ( + // ... existing imports ... 
+ ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) +``` + +- [ ] **Step 6: Run tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestResolveTargets -v` +Expected: all pass + +- [ ] **Step 7: Run full test suite** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 8: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/multiaccount.go internal/providers/aws/multiaccount_test.go go.mod go.sum +git commit -m "Add ResolveTargets with STS AssumeRole and Organizations discovery" +``` + +--- + +### Task 4: Refactor `Scan()` for multi-target parallel scanning + +**Files:** +- Modify: `internal/providers/aws/scanner.go` + +- [ ] **Step 1: Add multi-target fields to `ScanOptions`** + +In `internal/providers/aws/scanner.go`, add to the `ScanOptions` struct: + +```go +type ScanOptions struct { + Profile string + Regions []string + MetricWindow MetricWindow + SkipMetrics bool + SkipSageMaker bool + SkipEKS bool + SkipCosts bool + ExcludeTags map[string]string + MinUptimeDays int + + // Multi-target options + Targets []string + Role string + ExternalID string + OrgScan bool + SkipSelf bool +} +``` + +- [ ] **Step 2: Refactor `Scan()` to use `ResolveTargets` and scan all targets in parallel** + +Replace the `Scan` function in `scanner.go` with: + +```go +func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { + start := time.Now() + + // Load AWS config + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if opts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(opts.Profile)) + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) 
+ if err != nil { + return nil, fmt.Errorf("loading AWS config: %w", err) + } + + // Resolve targets + stsClient := sts.NewFromConfig(cfg) + var orgClient OrgClient + if opts.OrgScan { + orgClient = organizations.NewFromConfig(cfg) + } + + targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + if len(targets) == 0 { + return nil, fmt.Errorf("no scannable targets resolved") + } + + // Report target errors + for _, te := range targetErrors { + fmt.Fprintf(os.Stderr, " warning: target %s: %v\n", te.AccountID, te.Err) + } + + fmt.Fprintf(os.Stderr, " Scanning %d target(s)...\n", len(targets)) + + // Determine regions to scan + regions := opts.Regions + if len(regions) == 0 { + regions, err = getGPURegions(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("listing regions: %w", err) + } + } + + fmt.Fprintf(os.Stderr, " Scanning %d regions per target for GPU instances...\n", len(regions)) + + // Scan all targets in parallel + type targetResult struct { + accountID string + instances []models.GPUInstance + regions []string + } + + resultsCh := make(chan targetResult, len(targets)) + var wg sync.WaitGroup + + for _, target := range targets { + wg.Add(1) + go func(t Target) { + defer wg.Done() + instances, scannedRegions := scanTarget(ctx, t, regions, opts) + resultsCh <- targetResult{ + accountID: t.AccountID, + instances: instances, + regions: scannedRegions, + } + }(target) + } + + go func() { + wg.Wait() + close(resultsCh) + }() + + var allInstances []models.GPUInstance + regionSet := make(map[string]bool) + callerAccount := "" + if len(targets) > 0 { + callerAccount = targets[0].AccountID + } + + for res := range resultsCh { + allInstances = append(allInstances, res.instances...) 
+ for _, r := range res.regions { + regionSet[r] = true + } + } + + var scannedRegions []string + for r := range regionSet { + scannedRegions = append(scannedRegions, r) + } + + // Filter by excluded tags + if len(opts.ExcludeTags) > 0 { + filtered := allInstances[:0] + excluded := 0 + for _, inst := range allInstances { + if matchesExcludeTags(inst.Tags, opts.ExcludeTags) { + excluded++ + continue + } + filtered = append(filtered, inst) + } + allInstances = filtered + if excluded > 0 { + fmt.Fprintf(os.Stderr, " Excluded %d instance(s) by tag filter.\n", excluded) + } + } + + // Run analysis + analysis.AnalyzeAll(allInstances) + + // Suppress signals below minimum uptime threshold + if opts.MinUptimeDays > 0 { + minHours := float64(opts.MinUptimeDays) * 24 + for i := range allInstances { + inst := &allInstances[i] + if inst.UptimeHours >= minHours { + continue + } + inst.WasteSignals = nil + inst.Recommendations = nil + inst.EstimatedSavings = 0 + } + } + + // Build summaries + summary := BuildSummary(allInstances) + + result := &models.ScanResult{ + Timestamp: start, + AccountID: callerAccount, + Regions: scannedRegions, + ScanDuration: time.Since(start).Round(time.Millisecond).String(), + Instances: allInstances, + Summary: summary, + } + + // Add multi-target metadata + if len(targets) > 1 || len(targetErrors) > 0 { + for _, t := range targets { + result.Targets = append(result.Targets, t.AccountID) + } + result.TargetSummaries = BuildTargetSummaries(allInstances) + for _, te := range targetErrors { + result.TargetErrors = append(result.TargetErrors, models.TargetErrorInfo{ + Target: te.AccountID, + Error: te.Err.Error(), + }) + } + } + + return result, nil +} + +// scanTarget scans all regions for a single target account. 
+func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string) { + type regionResult struct { + region string + instances []models.GPUInstance + err error + } + + results := make(chan regionResult, len(regions)) + var wg sync.WaitGroup + + for _, region := range regions { + wg.Add(1) + go func(r string) { + defer wg.Done() + instances, err := scanRegion(ctx, target.Config, target.AccountID, r, opts) + results <- regionResult{region: r, instances: instances, err: err} + }(region) + } + + go func() { + wg.Wait() + close(results) + }() + + var allInstances []models.GPUInstance + var scannedRegions []string + + for res := range results { + if res.err != nil { + fmt.Fprintf(os.Stderr, " warning: %s/%s: %v\n", target.AccountID, res.region, res.err) + continue + } + if len(res.instances) > 0 { + allInstances = append(allInstances, res.instances...) + scannedRegions = append(scannedRegions, res.region) + } + } + + // Enrich with Cost Explorer data (per-target, since CE is account-scoped) + if !opts.SkipCosts && len(allInstances) > 0 { + ceClient := costexplorer.NewFromConfig(target.Config) + if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { + fmt.Fprintf(os.Stderr, " warning: %s cost enrichment: %v\n", target.AccountID, err) + } + } + + return allInstances, scannedRegions +} +``` + +- [ ] **Step 3: Add the organizations import to scanner.go** + +Add to the import block: + +```go +"github.com/aws/aws-sdk-go-v2/service/organizations" +``` + +- [ ] **Step 4: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 5: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 6: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/scanner.go +git commit -m "Refactor Scan() for parallel multi-target scanning" +``` + +--- + +### Task 
5: Wire CLI flags into scan command + +**Files:** +- Modify: `cmd/gpuaudit/main.go` + +- [ ] **Step 1: Add flag variables and register flags** + +Add the new flag variables alongside the existing scan flags: + +```go +var ( + // ... existing flags ... + scanTargets []string + scanRole string + scanExternalID string + scanOrg bool + scanSkipSelf bool +) +``` + +In the `init()` function, add after the existing `scanCmd.Flags` calls: + +```go +scanCmd.Flags().StringSliceVar(&scanTargets, "targets", nil, "Account IDs to scan (comma-separated)") +scanCmd.Flags().StringVar(&scanRole, "role", "", "IAM role name to assume in each target") +scanCmd.Flags().StringVar(&scanExternalID, "external-id", "", "STS external ID for cross-account role assumption") +scanCmd.Flags().BoolVar(&scanOrg, "org", false, "Auto-discover all accounts from AWS Organizations") +scanCmd.Flags().BoolVar(&scanSkipSelf, "skip-self", false, "Exclude the caller's own account from the scan") +scanCmd.MarkFlagsMutuallyExclusive("targets", "org") +``` + +- [ ] **Step 2: Wire flags into `ScanOptions` in `runScan`** + +In the `runScan` function, add the new fields to the opts construction: + +```go +opts := awsprovider.DefaultScanOptions() +opts.Profile = scanProfile +opts.Regions = scanRegions +opts.SkipMetrics = scanSkipMetrics +opts.SkipSageMaker = scanSkipSageMaker +opts.SkipEKS = scanSkipEKS +opts.SkipCosts = scanSkipCosts +opts.ExcludeTags = parseExcludeTags(scanExcludeTags) +opts.MinUptimeDays = scanMinUptimeDays +opts.Targets = scanTargets +opts.Role = scanRole +opts.ExternalID = scanExternalID +opts.OrgScan = scanOrg +opts.SkipSelf = scanSkipSelf +``` + +- [ ] **Step 3: Add validation — `--role` required with `--targets` or `--org`** + +Add at the top of `runScan`, before creating opts: + +```go +if (len(scanTargets) > 0 || scanOrg) && scanRole == "" { + return fmt.Errorf("--role is required when using --targets or --org") +} +``` + +- [ ] **Step 4: Verify build passes** + +Run: `cd 
/Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 5: Verify CLI help shows new flags** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --help` +Expected: new flags visible in help text + +- [ ] **Step 6: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 7: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add --targets, --role, --org, --external-id, --skip-self flags to scan command" +``` + +--- + +### Task 6: Update table formatter for multi-target output + +**Files:** +- Modify: `internal/output/table.go` + +- [ ] **Step 1: Add "By Target" summary table to `FormatTable`** + +In `internal/output/table.go`, add a new function and call it from `FormatTable` after the summary box: + +```go +func printTargetSummary(w io.Writer, result *models.ScanResult) { + if len(result.TargetSummaries) < 2 { + return + } + + fmt.Fprintf(w, " By Target\n") + fmt.Fprintf(w, " ┌──────────────┬───────────┬───────────┬───────────┬───────┐\n") + fmt.Fprintf(w, " │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │\n") + fmt.Fprintf(w, " ├──────────────┼───────────┼───────────┼───────────┼───────┤\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, " │ %-12s │ %9d │ $%8.0f │ $%8.0f │ %4.0f%% │\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintf(w, " └──────────────┴───────────┴───────────┴───────────┴───────┘\n\n") + + // Target errors + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, " Warnings\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, " ✗ %s — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } +} +``` + +In `FormatTable`, add the call after the summary box and before the "No GPU instances" check: + +```go +// ... after the summary box closing line ... 
+ +printTargetSummary(w, result) + +if s.TotalInstances == 0 { +``` + +- [ ] **Step 2: Add "Target" column to `printInstanceTable` when multi-target** + +Modify `printInstanceTable` to accept and use target info. Since the formatter can't tell whether a scan is multi-target from the instance slice alone, derive a `multiTarget` flag from the result and pass it in: + +Change the call sites in `FormatTable` from: +```go +printInstanceTable(w, critical) +``` +to: +```go +multiTarget := len(result.TargetSummaries) > 1 +printInstanceTable(w, critical, multiTarget) +``` + +Update `printInstanceTable`: + +```go +func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget bool) { + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s %10s %-16s %s\n", + "Instance", "Target", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 14), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } else { + fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", + "Instance", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } + + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..." 
+ } + + signal := "" + if len(inst.WasteSignals) > 0 { + signal = inst.WasteSignals[0].Type + } + + rec := "" + if len(inst.Recommendations) > 0 { + rec = inst.Recommendations[0].Description + } + + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s $%9.0f %-16s %s\n", + name, inst.AccountID, typeDesc, inst.MonthlyCost, signal, rec) + } else { + fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", + name, typeDesc, inst.MonthlyCost, signal, rec) + } + } + fmt.Fprintln(w) +} +``` + +- [ ] **Step 3: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 4: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/table.go +git commit -m "Add per-target summary table and target column to table formatter" +``` + +--- + +### Task 7: Update markdown and Slack formatters for multi-target output + +**Files:** +- Modify: `internal/output/markdown.go` +- Modify: `internal/output/slack.go` + +- [ ] **Step 1: Add per-target section to markdown formatter** + +In `internal/output/markdown.go`, add after the Summary table (after the `s.HealthyCount` line and before the "No GPU instances" check): + +```go +// Per-target breakdown +if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "## By Target\n\n") + fmt.Fprintf(w, "| Target | Instances | Spend/mo | Waste/mo | Waste |\n") + fmt.Fprintf(w, "|---|---|---|---|---|\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, "| %s | %d | $%.0f | $%.0f | %.0f%% |\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintln(w) +} + +if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, "## Warnings\n\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, "- **%s** — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) +} +``` + +Also add a "Target" 
column to the Findings table when multi-target. Change the table header and row formatting: + +```go +if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "| Instance | Target | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|---|\n") +} else { + fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|\n") +} + +for _, inst := range result.Instances { + // ... existing name/signal/rec/savings formatting ... + + if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "| %s | %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.AccountID, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } else { + fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } +} +``` + +- [ ] **Step 2: Add per-target block to Slack formatter** + +In `internal/output/slack.go`, in `FormatSlack`, add after the summary block and divider: + +```go +// Per-target breakdown +if len(result.TargetSummaries) > 1 { + lines := []string{"*By Target*"} + for _, ts := range result.TargetSummaries { + lines = append(lines, fmt.Sprintf("• `%s` — %d instances, $%.0f/mo spend, $%.0f/mo waste (%.0f%%)", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + blocks = append(blocks, map[string]any{"type": "divider"}) +} + +// Target errors +if len(result.TargetErrors) > 0 { + lines := []string{":warning: *Target Warnings*"} + for _, te := range result.TargetErrors { + lines = append(lines, fmt.Sprintf("• `%s` — %s", te.Target, te.Error)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) +} +``` + +- [ ] **Step 3: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && 
go build ./...` +Expected: success + +- [ ] **Step 4: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/markdown.go internal/output/slack.go +git commit -m "Add per-target summaries to markdown and Slack formatters" +``` + +--- + +### Task 8: Update `iam-policy` command + +**Files:** +- Modify: `cmd/gpuaudit/main.go` + +- [ ] **Step 1: Add cross-account and Organizations statements to `iam-policy` output** + +In `cmd/gpuaudit/main.go`, in the `iamPolicyCmd` Run function, add two new statements to the policy `Statement` slice: + +```go +{ + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader", +}, +{ + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*", +}, +``` + +Add a comment before encoding: + +```go +fmt.Fprintln(os.Stdout, "// The last two statements (CrossAccount, Organizations) are only needed for --targets or --org scanning.") +``` + +- [ ] **Step 2: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 3: Verify output looks correct** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit iam-policy` +Expected: JSON policy with the two new statements appended + +- [ ] **Step 4: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add cross-account and Organizations permissions to iam-policy output" +``` + +--- + +### Task 9: Update README with multi-target documentation + +**Files:** +- Modify: `README.md` + +- [ ] **Step 1: Add multi-account scanning section to README** + +Add a new section after the existing usage documentation: + +```markdown +## Multi-Account Scanning + +Scan multiple AWS accounts in a single 
invocation using STS AssumeRole. + +### Prerequisites + +Deploy a read-only IAM role (`gpuaudit-reader`) to each target account. See [Cross-Account Role Setup](#cross-account-role-setup) below. + +### Usage + +```bash +# Scan specific accounts +gpuaudit scan --targets 111111111111,222222222222 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Cross-Account Role Setup + +#### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + policy = file("gpuaudit-policy.json") # from: gpuaudit iam-policy > gpuaudit-policy.json +} +``` + +Deploy to all accounts using Terraform workspaces or CloudFormation StackSets. 
+ +#### CloudFormation StackSet + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Parameters: + ManagementAccountId: + Type: String +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` +``` + +- [ ] **Step 2: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add README.md +git commit -m "Add multi-account scanning docs to README" +``` + +--- + +### Task 10: End-to-end verification + +**Files:** None (verification only) + +- [ ] **Step 1: Run full test suite** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./... 
-v` +Expected: all pass + +- [ ] **Step 2: Run go vet** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go vet ./...` +Expected: no issues + +- [ ] **Step 3: Verify CLI help** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --help` +Expected: all new flags visible (--targets, --role, --org, --external-id, --skip-self) + +- [ ] **Step 4: Verify mutual exclusivity** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --targets 111 --org --role test 2>&1` +Expected: error about mutually exclusive flags + +- [ ] **Step 5: Verify --role validation** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --targets 111 2>&1` +Expected: error "--role is required when using --targets or --org" + +- [ ] **Step 6: Verify single-account scan still works (no regression)** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --skip-metrics --skip-sagemaker --skip-eks --skip-k8s --skip-costs 2>&1` +Expected: runs normally, output unchanged from before this feature diff --git a/docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md b/docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md new file mode 100644 index 0000000..14c7f2c --- /dev/null +++ b/docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md @@ -0,0 +1,1394 @@ +# K8s GPU Metrics Collection Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Collect GPU utilization metrics for Kubernetes GPU nodes via a per-node fallback chain (CloudWatch Container Insights → DCGM exporter → Prometheus), and add a utilization-based waste detection rule. 
+ +**Architecture:** Three metrics sources tried in priority order per node, all populating the existing `AvgGPUUtilization` and `AvgGPUMemUtilization` fields on `GPUInstance`. A new analysis rule `ruleK8sLowGPUUtil` flags nodes with GPU utilization < 10%. The fallback chain is wired in `main.go` between K8s discovery and analysis. + +**Tech Stack:** Go, AWS SDK v2 (CloudWatch), client-go (K8s API proxy), prometheus/common/expfmt (Prometheus text parsing), net/http (Prometheus API) + +--- + +## File Structure + +| File | Responsibility | +|------|---------------| +| `internal/providers/aws/cloudwatch.go` | Add `EnrichK8sGPUMetrics()` — CloudWatch Container Insights queries | +| `internal/providers/aws/cloudwatch_test.go` | Tests for `EnrichK8sGPUMetrics()` (new file) | +| `internal/providers/k8s/discover.go` | Extend `K8sClient` interface with `ProxyGet` | +| `internal/providers/k8s/scanner.go` | Extend `ScanOptions` with Prometheus config, export `BuildClientPublic` | +| `internal/providers/k8s/metrics.go` | DCGM scraping, Prometheus querying, fallback orchestration (new file) | +| `internal/providers/k8s/metrics_test.go` | Tests for DCGM and Prometheus paths (new file) | +| `internal/analysis/rules.go` | Add `ruleK8sLowGPUUtil` | +| `internal/analysis/rules_test.go` | Tests for new rule | +| `cmd/gpuaudit/main.go` | Add `--prom-url`, `--prom-endpoint` flags; wire CW enrichment for K8s instances | + +--- + +### Task 1: CloudWatch Container Insights Enrichment + +**Files:** +- Create: `internal/providers/aws/cloudwatch_test.go` +- Modify: `internal/providers/aws/cloudwatch.go:60-80` + +This task adds `EnrichK8sGPUMetrics()` following the exact same pattern as the existing `EnrichEC2Metrics()` and `EnrichSageMakerMetrics()` functions. It queries the `ContainerInsights` namespace for `node_gpu_utilization` and `node_gpu_memory_utilization`. 
+ +- [ ] **Step 1: Write the failing tests** + +Create `internal/providers/aws/cloudwatch_test.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + cwtypes "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" + + "github.com/gpuaudit/cli/internal/models" +) + +type mockCloudWatchClient struct { + output *cloudwatch.GetMetricDataOutput + err error +} + +func (m *mockCloudWatchClient) GetMetricData(ctx context.Context, params *cloudwatch.GetMetricDataInput, optFns ...func(*cloudwatch.Options)) (*cloudwatch.GetMetricDataOutput, error) { + if m.err != nil { + return nil, m.err + } + return m.output, nil +} + +func TestEnrichK8sGPUMetrics_PopulatesUtilization(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{ + MetricDataResults: []cwtypes.MetricDataResult{ + {Id: aws.String("gpu_util_i_abc123"), Values: []float64{45.0, 50.0, 55.0}}, + {Id: aws.String("gpu_mem_i_abc123"), Values: []float64{30.0, 35.0, 40.0}}, + }, + }, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "ml-cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected avg GPU util 50.0, got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 35.0 { + t.Errorf("expected avg GPU mem util 35.0, got %f", *instances[0].AvgGPUMemUtilization) + } +} + +func TestEnrichK8sGPUMetrics_SkipsNonK8sNodes(t *testing.T) { + client := &mockCloudWatchClient{ + 
output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-ec2", Source: models.SourceEC2}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for non-K8s instance") + } +} + +func TestEnrichK8sGPUMetrics_SkipsNodesWithoutInstanceID(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "node-hostname", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for node without EC2 instance ID") + } +} + +func TestEnrichK8sGPUMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + AvgGPUUtilization: &gpuUtil, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Errorf("expected existing value 75.0 to be preserved, got %f", *instances[0].AvgGPUUtilization) + } +} + +func TestEnrichK8sGPUMetrics_HandlesAPIError(t *testing.T) { + client := &mockCloudWatchClient{ + err: fmt.Errorf("access denied"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-abc123", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util after API error") + } +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `go test ./internal/providers/aws/ -run TestEnrichK8sGPUMetrics -v` +Expected: FAIL — 
`EnrichK8sGPUMetrics` not defined + +- [ ] **Step 3: Implement EnrichK8sGPUMetrics** + +Add to `internal/providers/aws/cloudwatch.go`, after the `EnrichSageMakerMetrics` function (after line 80): + +```go +// EnrichK8sGPUMetrics populates GPU utilization metrics on K8s nodes using CloudWatch Container Insights. +func EnrichK8sGPUMetrics(ctx context.Context, client CloudWatchClient, instances []models.GPUInstance, clusterName string, window MetricWindow) { + type nodeRef struct { + index int + instanceID string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode { + continue + } + if inst.AvgGPUUtilization != nil { + continue + } + if !strings.HasPrefix(inst.InstanceID, "i-") { + continue + } + nodes = append(nodes, nodeRef{index: i, instanceID: inst.InstanceID}) + } + if len(nodes) == 0 { + return + } + + now := time.Now() + start := now.Add(-window.Duration) + + clusterDim := cwtypes.Dimension{ + Name: aws.String("ClusterName"), + Value: aws.String(clusterName), + } + + for _, node := range nodes { + instanceDim := cwtypes.Dimension{ + Name: aws.String("InstanceId"), + Value: aws.String(node.instanceID), + } + + safeID := strings.ReplaceAll(node.instanceID, "-", "_") + + queries := []cwtypes.MetricDataQuery{ + metricQuery2("gpu_util_"+safeID, "ContainerInsights", "node_gpu_utilization", "Average", window.Period, clusterDim, instanceDim), + metricQuery2("gpu_mem_"+safeID, "ContainerInsights", "node_gpu_memory_utilization", "Average", window.Period, clusterDim, instanceDim), + } + + results, err := fetchMetrics(ctx, client, queries, start, now) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Container Insights metrics unavailable for %s: %v\n", node.instanceID, err) + continue + } + + instances[node.index].AvgGPUUtilization = results["gpu_util_"+safeID] + instances[node.index].AvgGPUMemUtilization = results["gpu_mem_"+safeID] + } +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + 
+Run: `go test ./internal/providers/aws/ -run TestEnrichK8sGPUMetrics -v` +Expected: PASS (all 5 tests) + +- [ ] **Step 5: Run full test suite** + +Run: `go test ./...` +Expected: All tests pass + +- [ ] **Step 6: Commit** + +```bash +git add internal/providers/aws/cloudwatch.go internal/providers/aws/cloudwatch_test.go +git commit -m "Add EnrichK8sGPUMetrics for CloudWatch Container Insights GPU metrics" +``` + +--- + +### Task 2: Extend K8sClient Interface with ProxyGet + +**Files:** +- Modify: `internal/providers/k8s/discover.go:24-27` +- Modify: `internal/providers/k8s/scanner.go:91-101` +- Modify: `internal/providers/k8s/discover_test.go:19-30` + +This task adds `ProxyGet` to the `K8sClient` interface and updates the mock and wrapper. This is needed for both DCGM scraping (Task 3) and Prometheus in-cluster queries (Task 4). + +- [ ] **Step 1: Add ProxyGet to the K8sClient interface** + +In `internal/providers/k8s/discover.go`, change the `K8sClient` interface (lines 24-27) from: + +```go +type K8sClient interface { + ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) + ListPods(ctx context.Context, namespace string, opts metav1.ListOptions) (*corev1.PodList, error) +} +``` + +to: + +```go +type K8sClient interface { + ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) + ListPods(ctx context.Context, namespace string, opts metav1.ListOptions) (*corev1.PodList, error) + ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) +} +``` + +- [ ] **Step 2: Implement ProxyGet on k8sClientWrapper** + +In `internal/providers/k8s/scanner.go`, add this method after the `ListPods` method (after line 101): + +```go +func (w *k8sClientWrapper) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + return w.clientset.CoreV1().Pods(namespace).ProxyGet("http", podName, port, path, nil).DoRaw(ctx) +} +``` + +- [ ] **Step 3: Add ProxyGet to the mock in 
tests** + +In `internal/providers/k8s/discover_test.go`, change the `mockK8sClient` struct (lines 19-22) from: + +```go +type mockK8sClient struct { + nodes *corev1.NodeList + pods *corev1.PodList +} +``` + +to: + +```go +type mockK8sClient struct { + nodes *corev1.NodeList + pods *corev1.PodList + proxyData map[string][]byte + proxyErr error +} +``` + +And add the method after `ListPods` (after line 30): + +```go +func (m *mockK8sClient) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + if m.proxyErr != nil { + return nil, m.proxyErr + } + key := fmt.Sprintf("%s/%s:%s%s", namespace, podName, port, path) + if data, ok := m.proxyData[key]; ok { + return data, nil + } + return nil, fmt.Errorf("no mock data for %s", key) +} +``` + +- [ ] **Step 4: Run tests to verify nothing is broken** + +Run: `go test ./internal/providers/k8s/ -v` +Expected: All existing tests pass + +- [ ] **Step 5: Commit** + +```bash +git add internal/providers/k8s/discover.go internal/providers/k8s/scanner.go internal/providers/k8s/discover_test.go +git commit -m "Add ProxyGet to K8sClient interface for pod API proxy" +``` + +--- + +### Task 3: DCGM Exporter Scraping + +**Files:** +- Create: `internal/providers/k8s/metrics.go` +- Create: `internal/providers/k8s/metrics_test.go` + +This task implements DCGM exporter auto-discovery and metric scraping. It discovers dcgm-exporter pods by label, matches them to GPU nodes, scrapes `/metrics` on port 9400, and parses `DCGM_FI_DEV_GPU_UTIL` and `DCGM_FI_DEV_MEM_COPY_UTIL`. + +- [ ] **Step 1: Add the `prometheus/common` dependency** + +Run: `go get github.com/prometheus/common@latest` + +This will also pull in `github.com/prometheus/client_model` (needed for `dto.MetricFamily`). + +- [ ] **Step 2: Write the failing tests** + +Create `internal/providers/k8s/metrics_test.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "context" + "fmt" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +func dcgmPod(name, namespace, nodeName string) corev1.Pod { + return corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/name": "dcgm-exporter", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + } +} + +const sampleDCGMMetrics = `# HELP DCGM_FI_DEV_GPU_UTIL GPU utilization. +# TYPE DCGM_FI_DEV_GPU_UTIL gauge +DCGM_FI_DEV_GPU_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 42.0 +DCGM_FI_DEV_GPU_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 38.0 +# HELP DCGM_FI_DEV_MEM_COPY_UTIL GPU memory utilization. +# TYPE DCGM_FI_DEV_MEM_COPY_UTIL gauge +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 55.0 +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 60.0 +` + +func TestEnrichDCGMMetrics_PopulatesUtilization(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 40.0 { + t.Errorf("expected avg GPU util 40.0 (average 
of 42 and 38), got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 57.5 { + t.Errorf("expected avg GPU mem util 57.5 (average of 55 and 60), got %f", *instances[0].AvgGPUMemUtilization) + } + if enriched != 1 { + t.Errorf("expected 1 enriched node, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Error("should not overwrite existing utilization") + } + if enriched != 0 { + t.Errorf("expected 0 enriched nodes, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_NoDCGMPods(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{Items: []corev1.Pod{}}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil when no DCGM pods") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_HandlesScrapeError(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyErr: fmt.Errorf("connection refused"), + } 
+ instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil after scrape error") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestParseDCGMMetrics(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte(sampleDCGMMetrics)) + + if gpuUtil == nil { + t.Fatal("expected gpu util") + } + if *gpuUtil != 40.0 { + t.Errorf("expected 40.0, got %f", *gpuUtil) + } + if memUtil == nil { + t.Fatal("expected mem util") + } + if *memUtil != 57.5 { + t.Errorf("expected 57.5, got %f", *memUtil) + } +} + +func TestParseDCGMMetrics_EmptyInput(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte("")) + if gpuUtil != nil || memUtil != nil { + t.Error("expected nil for empty input") + } +} +``` + +- [ ] **Step 3: Run tests to verify they fail** + +Run: `go test ./internal/providers/k8s/ -run "TestEnrichDCGM|TestParseDCGM" -v` +Expected: FAIL — functions not defined + +- [ ] **Step 4: Implement DCGM metrics enrichment** + +Create `internal/providers/k8s/metrics.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "bytes" + "context" + "fmt" + "os" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +// EnrichDCGMMetrics discovers dcgm-exporter pods and scrapes GPU metrics for K8s nodes +// that don't already have AvgGPUUtilization populated. Returns the number of nodes enriched. 
+func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance) int {
+	// Index (by node name) every K8s node that still lacks GPU utilization.
+	needsMetrics := make(map[string]int)
+	for i := range instances {
+		inst := &instances[i]
+		if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil {
+			continue
+		}
+		needsMetrics[inst.InstanceID] = i
+	}
+	if len(needsMetrics) == 0 {
+		return 0
+	}
+
+	dcgmPods, err := findDCGMPods(ctx, client)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, " warning: could not list DCGM exporter pods: %v\n", err)
+		return 0
+	}
+	if len(dcgmPods) == 0 {
+		fmt.Fprintf(os.Stderr, " DCGM exporter not detected, skipping\n")
+		return 0
+	}
+
+	fmt.Fprintf(os.Stderr, " Probing DCGM exporter on GPU nodes...\n")
+
+	// Snapshot the total now: nodes are removed from needsMetrics as they
+	// are enriched, so len(needsMetrics) shrinks during the loop.
+	pending := len(needsMetrics)
+	enriched := 0
+	for _, pod := range dcgmPods {
+		idx, ok := needsMetrics[pod.Spec.NodeName]
+		if !ok {
+			continue
+		}
+
+		data, err := client.ProxyGet(ctx, pod.Namespace, pod.Name, "9400", "/metrics")
+		if err != nil {
+			fmt.Fprintf(os.Stderr, " warning: DCGM scrape failed for %s: %v\n", pod.Spec.NodeName, err)
+			continue
+		}
+
+		gpuUtil, memUtil := parseDCGMMetrics(data)
+		if gpuUtil == nil {
+			continue
+		}
+		instances[idx].AvgGPUUtilization = gpuUtil
+		instances[idx].AvgGPUMemUtilization = memUtil
+		enriched++
+		// More than one running exporter pod can be scheduled on the same
+		// node. Drop the node once enriched so it is neither scraped again
+		// nor counted twice; a node whose scrape failed stays in the map and
+		// is retried with the next exporter pod found on it.
+		delete(needsMetrics, pod.Spec.NodeName)
+	}
+
+	fmt.Fprintf(os.Stderr, " DCGM: got GPU metrics for %d of %d remaining nodes\n", enriched, pending)
+	return enriched
+}
+
+// findDCGMPods lists running dcgm-exporter pods, passing an empty namespace to
+// ListPods. It first selects on app.kubernetes.io/name=dcgm-exporter and only
+// falls back to app=nvidia-dcgm-exporter when the first selector matches no
+// pods at all (running or not).
+func findDCGMPods(ctx context.Context, client K8sClient) ([]corev1.Pod, error) {
+	podList, err := client.ListPods(ctx, "", metav1.ListOptions{
+		LabelSelector: "app.kubernetes.io/name=dcgm-exporter",
+	})
+	if err != nil {
+		return nil, err
+	}
+	if len(podList.Items) > 0 {
+		return runningPods(podList.Items), nil
+	}
+
+	podList, err = client.ListPods(ctx, "", metav1.ListOptions{
+		LabelSelector: "app=nvidia-dcgm-exporter",
+	})
+	if err != nil {
+		return nil, err
+	}
+	return runningPods(podList.Items), nil
+}
+
+// runningPods returns the pods whose phase is Running, preserving order.
+// The result is nil when none qualify.
+func runningPods(pods []corev1.Pod) []corev1.Pod {
+	var result []corev1.Pod
+	for _, p := range pods {
+		if p.Status.Phase == corev1.PodRunning {
+			result = append(result, p)
+		}
+	}
+	return result
+}
+
+// parseDCGMMetrics parses a Prometheus text-format payload and returns the
+// average DCGM_FI_DEV_GPU_UTIL and DCGM_FI_DEV_MEM_COPY_UTIL across all
+// reported samples. A value is nil when its metric family is absent; both are
+// nil when the payload cannot be parsed.
+func parseDCGMMetrics(data []byte) (gpuUtil, memUtil *float64) {
+	parser := expfmt.TextParser{}
+	families, err := parser.TextToMetricFamilies(bytes.NewReader(data))
+	if err != nil {
+		return nil, nil
+	}
+
+	gpuUtil = avgMetricValue(families["DCGM_FI_DEV_GPU_UTIL"])
+	memUtil = avgMetricValue(families["DCGM_FI_DEV_MEM_COPY_UTIL"])
+	return gpuUtil, memUtil
+}
+
+// avgMetricValue returns the mean of all gauge samples in family, or nil when
+// the family is nil or carries no gauge values. Non-gauge samples are ignored.
+func avgMetricValue(family *dto.MetricFamily) *float64 {
+	if family == nil || len(family.Metric) == 0 {
+		return nil
+	}
+	sum := 0.0
+	count := 0
+	for _, m := range family.Metric {
+		if m.Gauge != nil && m.Gauge.Value != nil {
+			sum += *m.Gauge.Value
+			count++
+		}
+	}
+	if count == 0 {
+		return nil
+	}
+	avg := sum / float64(count)
+	return &avg
+}
+```
+
+- [ ] **Step 5: Run tests to verify they pass**
+
+Run: `go test ./internal/providers/k8s/ -run "TestEnrichDCGM|TestParseDCGM" -v`
+Expected: PASS (all 6 tests)
+
+- [ ] **Step 6: Run full test suite**
+
+Run: `go test ./...`
+Expected: All tests pass
+
+- [ ] **Step 7: Commit**
+
+```bash
+git add internal/providers/k8s/metrics.go internal/providers/k8s/metrics_test.go go.mod go.sum
+git commit -m "Add DCGM exporter scraping for K8s GPU metrics"
+```
+
+---
+
+### Task 4: Prometheus Query Enrichment
+
+**Files:**
+- Modify: `internal/providers/k8s/metrics.go`
+- Modify: `internal/providers/k8s/metrics_test.go`
+
+This task adds the Prometheus query path — the third fallback. It supports both direct URL (`--prom-url`) and in-cluster service endpoint (`--prom-endpoint`), querying `avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"..."}[7d])`.
+
+- [ ] **Step 1: Write the failing tests**
+
+Add to `internal/providers/k8s/metrics_test.go`:
+
+```go
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+)
+```
+
+Add these test functions:
+
+```go
+func TestEnrichPrometheusMetrics_PopulatesFromDirectURL(t *testing.T) {
+	promResponse := `{
+		"status": "success",
+		"data": {
+			"resultType": "vector",
+			"result": [
+				{"metric": {"node": "i-node1"}, "value": [1700000000, "65.5"]},
+				{"metric": {"node": "i-node2"}, "value": [1700000000, "30.0"]}
+			]
+		}
+	}`
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/api/v1/query" {
+			t.Errorf("unexpected path: %s", r.URL.Path)
+		}
+		// EnrichPrometheusMetrics issues two queries against the same server:
+		// one for GPU utilization and one for GPU memory-copy utilization.
+		// Both are legitimate, so accept either metric name.
+		query := r.URL.Query().Get("query")
+		if !strings.Contains(query, "DCGM_FI_DEV_GPU_UTIL") &&
+			!strings.Contains(query, "DCGM_FI_DEV_MEM_COPY_UTIL") {
+			t.Errorf("unexpected query: %s", query)
+		}
+		w.Header().Set("Content-Type", "application/json")
+		w.Write([]byte(promResponse))
+	}))
+	defer server.Close()
+
+	instances := []models.GPUInstance{
+		{InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"},
+		{InstanceID: "i-node2", Source: models.SourceK8sNode, Name: "cluster/i-node2"},
+	}
+	opts := PrometheusOptions{URL: server.URL}
+
+	enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts)
+
+	if enriched != 2 {
+		t.Errorf("expected 2 enriched, got %d", enriched)
+	}
+	if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 65.5 {
+		t.Errorf("expected node1 GPU util 65.5, got %v", instances[0].AvgGPUUtilization)
+	}
+	if instances[1].AvgGPUUtilization == nil || *instances[1].AvgGPUUtilization != 30.0 {
+		t.Errorf("expected node2 GPU util 30.0, got %v", instances[1].AvgGPUUtilization)
+	}
+}
+
+func TestEnrichPrometheusMetrics_SkipsAlreadyEnriched(t *testing.T) {
+	gpuUtil := 80.0
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		
w.Write([]byte(`{"status":"success","data":{"resultType":"vector","result":[]}}`)) + })) + defer server.Close() + + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_NoOptions(t *testing.T) { + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, PrometheusOptions{}) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_InClusterEndpoint(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "50.0"]} + ] + } + }` + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{}, + proxyData: map[string][]byte{ + "monitoring/prometheus:9090/api/v1/query": []byte(promResponse), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + opts := PrometheusOptions{Endpoint: "monitoring/prometheus:9090"} + + enriched := EnrichPrometheusMetrics(context.Background(), client, instances, opts) + + if enriched != 1 { + t.Errorf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected 50.0, got %v", instances[0].AvgGPUUtilization) + } +} + +func TestParsePrometheusEndpoint(t *testing.T) { + tests := []struct { + input string + namespace string + service string + port string + wantErr bool + }{ + {"monitoring/prometheus:9090", "monitoring", "prometheus", "9090", false}, + {"kube-system/thanos-query:10902", 
"kube-system", "thanos-query", "10902", false}, + {"invalid", "", "", "", true}, + {"ns/svc", "", "", "", true}, + } + for _, tt := range tests { + ns, svc, port, err := parsePrometheusEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parsePrometheusEndpoint(%q): err=%v, wantErr=%v", tt.input, err, tt.wantErr) + continue + } + if ns != tt.namespace || svc != tt.service || port != tt.port { + t.Errorf("parsePrometheusEndpoint(%q) = (%q,%q,%q), want (%q,%q,%q)", + tt.input, ns, svc, port, tt.namespace, tt.service, tt.port) + } + } +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `go test ./internal/providers/k8s/ -run "TestEnrichPrometheus|TestParsePrometheus" -v` +Expected: FAIL — functions not defined + +- [ ] **Step 3: Implement Prometheus metrics enrichment** + +Add to `internal/providers/k8s/metrics.go` (additional imports at the top): + +```go +import ( + "encoding/json" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) +``` + +Add these types and functions: + +```go +// PrometheusOptions configures how to reach a Prometheus-compatible API. +type PrometheusOptions struct { + URL string + Endpoint string +} + +// EnrichPrometheusMetrics queries a Prometheus endpoint for GPU utilization metrics +// for K8s nodes that don't already have AvgGPUUtilization populated. 
+func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance, opts PrometheusOptions) int {
+	if opts.URL == "" && opts.Endpoint == "" {
+		return 0
+	}
+
+	// Collect the K8s nodes that still lack GPU utilization, remembering
+	// where each one lives in instances.
+	type nodeRef struct {
+		index int
+		name  string
+	}
+	var nodes []nodeRef
+	for i := range instances {
+		inst := &instances[i]
+		if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil {
+			continue
+		}
+		nodes = append(nodes, nodeRef{index: i, name: inst.InstanceID})
+	}
+	if len(nodes) == 0 {
+		return 0
+	}
+
+	source := opts.URL
+	if source == "" {
+		source = opts.Endpoint
+	}
+	fmt.Fprintf(os.Stderr, " Querying Prometheus at %s...\n", source)
+
+	nodeNames := make([]string, len(nodes))
+	for i, n := range nodes {
+		nodeNames[i] = n.name
+	}
+	nodeRegex := strings.Join(nodeNames, "|")
+
+	// DCGM exports one series per GPU. Aggregate by node on the server so a
+	// multi-GPU node yields a single averaged value instead of whichever GPU
+	// series happens to be parsed last.
+	gpuResults := queryPrometheus(ctx, client, opts,
+		fmt.Sprintf(`avg by (node) (avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"%s"}[7d]))`, nodeRegex))
+	memResults := queryPrometheus(ctx, client, opts,
+		fmt.Sprintf(`avg by (node) (avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL{node=~"%s"}[7d]))`, nodeRegex))
+
+	enriched := 0
+	for _, node := range nodes {
+		val, ok := gpuResults[node.name]
+		if !ok {
+			continue
+		}
+		instances[node.index].AvgGPUUtilization = &val
+		// Memory utilization is optional; a failed or empty memory query
+		// leaves memResults nil and the field untouched.
+		if memVal, ok := memResults[node.name]; ok {
+			instances[node.index].AvgGPUMemUtilization = &memVal
+		}
+		enriched++
+	}
+
+	fmt.Fprintf(os.Stderr, " Prometheus: got GPU metrics for %d of %d remaining nodes\n", enriched, len(nodes))
+	return enriched
+}
+
+// queryPrometheus runs an instant query through the direct URL when one is
+// configured, otherwise through the in-cluster endpoint. It returns per-node
+// values, or nil (after printing a warning) when the request fails.
+func queryPrometheus(ctx context.Context, client K8sClient, opts PrometheusOptions, query string) map[string]float64 {
+	var data []byte
+	var err error
+
+	if opts.URL != "" {
+		data, err = queryPrometheusHTTP(ctx, opts.URL, query)
+	} else {
+		data, err = queryPrometheusProxy(ctx, client, opts.Endpoint, query)
+	}
+	if err != nil {
+		fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err)
+		return nil
+	}
+
+	return parsePrometheusResponse(data)
+}
+
+// queryPrometheusHTTP issues GET {baseURL}/api/v1/query?query=... and returns
+// the raw response body. A non-2xx status is reported as an error so that an
+// auth failure or gateway error is not silently treated as "no data".
+func queryPrometheusHTTP(ctx context.Context, baseURL, query string) ([]byte, error) {
+	u := fmt.Sprintf("%s/api/v1/query?query=%s", strings.TrimRight(baseURL, "/"), url.QueryEscape(query))
+	req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
+	if err != nil {
+		return nil, err
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		return nil, fmt.Errorf("unexpected HTTP status %s from %s", resp.Status, baseURL)
+	}
+	return io.ReadAll(resp.Body)
+}
+
+// queryPrometheusProxy runs the query against an in-cluster Prometheus
+// service given as namespace/service:port, going through client.ProxyGet.
+//
+// NOTE(review): ProxyGet receives a pod name in the DCGM path and a service
+// name here; confirm the implementation can proxy to services.
+func queryPrometheusProxy(ctx context.Context, client K8sClient, endpoint, query string) ([]byte, error) {
+	// Callers pass a nil client when building the Kubernetes client failed;
+	// invoking a method on the nil interface would panic.
+	if client == nil {
+		return nil, fmt.Errorf("no Kubernetes client available to reach in-cluster endpoint %q", endpoint)
+	}
+	ns, svc, port, err := parsePrometheusEndpoint(endpoint)
+	if err != nil {
+		return nil, err
+	}
+	path := fmt.Sprintf("/api/v1/query?query=%s", url.QueryEscape(query))
+	return client.ProxyGet(ctx, ns, svc, port, path)
+}
+
+// parsePrometheusEndpoint splits "namespace/service:port" into its parts.
+// It returns an error when the namespace, service, or port is missing.
+func parsePrometheusEndpoint(endpoint string) (namespace, service, port string, err error) {
+	slashIdx := strings.Index(endpoint, "/")
+	if slashIdx < 1 {
+		return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint)
+	}
+	namespace = endpoint[:slashIdx]
+	rest := endpoint[slashIdx+1:]
+	colonIdx := strings.LastIndex(rest, ":")
+	// colonIdx < 1 rejects a missing colon or empty service name;
+	// colonIdx == len(rest)-1 rejects an empty port ("ns/svc:").
+	if colonIdx < 1 || colonIdx == len(rest)-1 {
+		return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint)
+	}
+	service = rest[:colonIdx]
+	port = rest[colonIdx+1:]
+	return namespace, service, port, nil
+}
+
+// parsePrometheusResponse decodes an instant-query response and returns the
+// sample value keyed by the "node" label. It returns nil when the payload is
+// not valid JSON or its status is not "success"; results without a node label
+// or with an unparsable value are skipped.
+func parsePrometheusResponse(data []byte) map[string]float64 {
+	var resp struct {
+		Status string `json:"status"`
+		Data   struct {
+			ResultType string `json:"resultType"`
+			Result     []struct {
+				Metric map[string]string `json:"metric"`
+				Value  []json.RawMessage `json:"value"`
+			} `json:"result"`
+		} `json:"data"`
+	}
+	if err := json.Unmarshal(data, &resp); err != nil {
+		return nil
+	}
+	if resp.Status != "success" {
+		return nil
+	}
+
+	results := make(map[string]float64)
+	for _, r := range resp.Data.Result {
+		node := r.Metric["node"]
+		if node == "" || len(r.Value) < 2 {
+			continue
+		}
+		// Value is [timestamp, "value"]; the sample is a JSON string.
+		var valStr string
+		if err := json.Unmarshal(r.Value[1], &valStr); err != nil {
+			continue
+		}
+		val, err := strconv.ParseFloat(valStr, 64)
+		if err != nil {
+			continue
+		}
+		results[node] = val
+	}
+	return results
+}
+```
+
+- [ ] **Step 4: Run tests to verify they pass**
+
+Run: `go test ./internal/providers/k8s/ -run "TestEnrichPrometheus|TestParsePrometheus" -v`
+Expected: PASS (all 5 tests)
+
+- [ ] **Step 5: Run full test suite**
+
+Run: `go test ./...`
+Expected: All tests pass
+
+- [ ] **Step 6: Commit**
+
+```bash
+git add internal/providers/k8s/metrics.go internal/providers/k8s/metrics_test.go
+git commit -m "Add Prometheus query enrichment for K8s GPU metrics"
+```
+
+---
+
+### Task 5: K8s Low GPU Utilization Analysis Rule
+
+**Files:**
+- Modify: `internal/analysis/rules.go`
+- Modify: `internal/analysis/rules_test.go`
+
+- [ ] **Step 1: Write the failing tests**
+
+Add to `internal/analysis/rules_test.go`:
+
+```go
+func TestRuleK8sLowGPUUtil_FlagsLowUtilization(t *testing.T) {
+	inst := models.GPUInstance{
+		InstanceID:        "i-node1",
+		Source:            models.SourceK8sNode,
+		State:             "ready",
+		InstanceType:      "g5.xlarge",
+		GPUModel:          "A10G",
+		GPUCount:          1,
+		GPUAllocated:      1,
+		MonthlyCost:       734,
+		AvgGPUUtilization: ptr(3.5),
+	}
+
+	ruleK8sLowGPUUtil(&inst)
+
+	if len(inst.WasteSignals) != 1 {
+		t.Fatalf("expected 1 signal, got %d", len(inst.WasteSignals))
+	}
+	if inst.WasteSignals[0].Type != "low_utilization" {
+		t.Errorf("expected low_utilization, got %s", inst.WasteSignals[0].Type)
+	}
+	if inst.WasteSignals[0].Severity != models.SeverityCritical {
+		t.Errorf("expected critical, got %s", inst.WasteSignals[0].Severity)
+	}
+	if inst.WasteSignals[0].Confidence != 0.85 {
+		t.Errorf("expected confidence 0.85, got %f", inst.WasteSignals[0].Confidence)
+	}
+	if len(inst.Recommendations) != 1 {
+		t.Fatalf("expected 1 recommendation, got %d", len(inst.Recommendations))
+	}
+	if inst.Recommendations[0].MonthlySavings != 734*0.8 {
+		t.Errorf("expected savings %.0f, got %f", 734*0.8, inst.Recommendations[0].MonthlySavings)
+	}
+}
+
+func 
TestRuleK8sLowGPUUtil_SkipsNonK8s(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceEC2, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for EC2 instance") + } +} + +func TestRuleK8sLowGPUUtil_SkipsNoMetrics(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals when metrics unavailable") + } +} + +func TestRuleK8sLowGPUUtil_SkipsHighUtilization(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + AvgGPUUtilization: ptr(45.0), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for well-utilized GPU") + } +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `go test ./internal/analysis/ -run TestRuleK8sLowGPUUtil -v` +Expected: FAIL — `ruleK8sLowGPUUtil` not defined + +- [ ] **Step 3: Implement the rule** + +In `internal/analysis/rules.go`, add `ruleK8sLowGPUUtil` to the rules slice inside `analyzeInstance()` (line 23-31). The full slice should be: + +```go + rules := []func(*models.GPUInstance){ + ruleIdle, + ruleOversizedGPU, + rulePricingMismatch, + ruleStale, + ruleSageMakerLowUtil, + ruleSageMakerOversized, + ruleK8sUnallocatedGPU, + ruleSpotEligible, + ruleK8sLowGPUUtil, + } +``` + +Then add the rule function at the end of the file: + +```go +// Rule 9: K8s GPU node with low GPU utilization (requires DCGM/CW/Prometheus metrics). 
+func ruleK8sLowGPUUtil(inst *models.GPUInstance) { + if inst.Source != models.SourceK8sNode { + return + } + if inst.AvgGPUUtilization == nil { + return + } + if *inst.AvgGPUUtilization >= 10 { + return + } + + inst.WasteSignals = append(inst.WasteSignals, models.WasteSignal{ + Type: "low_utilization", + Severity: models.SeverityCritical, + Confidence: 0.85, + Evidence: fmt.Sprintf("K8s GPU node utilization averaging %.1f%%. GPUs are allocated but barely used.", *inst.AvgGPUUtilization), + }) + inst.Recommendations = append(inst.Recommendations, models.Recommendation{ + Action: models.ActionDownsize, + Description: fmt.Sprintf("GPU utilization averaging %.1f%%. Consider bin-packing more workloads, downsizing, or removing from the node pool.", *inst.AvgGPUUtilization), + CurrentMonthlyCost: inst.MonthlyCost, + RecommendedMonthlyCost: inst.MonthlyCost * 0.2, + MonthlySavings: inst.MonthlyCost * 0.8, + SavingsPercent: 80, + Risk: models.RiskMedium, + }) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `go test ./internal/analysis/ -run TestRuleK8sLowGPUUtil -v` +Expected: PASS (all 4 tests) + +- [ ] **Step 5: Run full test suite** + +Run: `go test ./...` +Expected: All tests pass + +- [ ] **Step 6: Commit** + +```bash +git add internal/analysis/rules.go internal/analysis/rules_test.go +git commit -m "Add ruleK8sLowGPUUtil for utilization-based K8s GPU waste detection" +``` + +--- + +### Task 6: Wire Everything into CLI and Scan Flow + +**Files:** +- Modify: `cmd/gpuaudit/main.go` +- Modify: `internal/providers/k8s/scanner.go` + +This task adds the `--prom-url` and `--prom-endpoint` CLI flags, passes them through to the K8s scan, wires CloudWatch Container Insights enrichment, and orchestrates the fallback chain in `main.go`. 
+ +- [ ] **Step 1: Extend K8s ScanOptions** + +In `internal/providers/k8s/scanner.go`, change the `ScanOptions` struct (lines 20-23) from: + +```go +type ScanOptions struct { + Kubeconfig string + Context string +} +``` + +to: + +```go +type ScanOptions struct { + Kubeconfig string + Context string + PromURL string + PromEndpoint string +} +``` + +- [ ] **Step 2: Export BuildClient** + +Add to `internal/providers/k8s/scanner.go` after the existing `buildClient` function: + +```go +func BuildClientPublic(kubeconfigPath, contextName string) (K8sClient, string, error) { + return buildClient(kubeconfigPath, contextName) +} +``` + +- [ ] **Step 3: Add CLI flags** + +In `cmd/gpuaudit/main.go`, add the flag variables after `scanKubeContext` (around line 51): + +```go + scanPromURL string + scanPromEndpoint string +``` + +Add the flag registrations inside the first `init()` function, after the `--kube-context` flag (after line 73): + +```go + scanCmd.Flags().StringVar(&scanPromURL, "prom-url", "", "Prometheus URL for GPU metrics (e.g., https://prometheus.corp.example.com)") + scanCmd.Flags().StringVar(&scanPromEndpoint, "prom-endpoint", "", "In-cluster Prometheus service as namespace/service:port (e.g., monitoring/prometheus:9090)") +``` + +- [ ] **Step 4: Add flag validation and wiring in runScan** + +In `cmd/gpuaudit/main.go`, in the `runScan` function, add validation after `ctx := context.Background()` (line 84): + +```go + if scanPromURL != "" && scanPromEndpoint != "" { + return fmt.Errorf("--prom-url and --prom-endpoint are mutually exclusive") + } +``` + +Then modify the K8s scan section. 
Replace the block starting with `// Kubernetes API scan` (around lines 107-119) with: + +```go + // Kubernetes API scan + if !scanSkipK8s { + k8sOpts := k8sprovider.ScanOptions{ + Kubeconfig: scanKubeconfig, + Context: scanKubeContext, + PromURL: scanPromURL, + PromEndpoint: scanPromEndpoint, + } + k8sInstances, err := k8sprovider.Scan(ctx, k8sOpts) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Kubernetes scan failed: %v\n", err) + } else if len(k8sInstances) > 0 { + if !scanSkipMetrics { + enrichK8sGPUMetrics(ctx, k8sInstances, k8sOpts, opts) + } + analysis.AnalyzeAll(k8sInstances) + result.Instances = append(result.Instances, k8sInstances...) + result.Summary = awsprovider.BuildSummary(result.Instances) + } + } +``` + +- [ ] **Step 5: Add the enrichK8sGPUMetrics helper function** + +Add this function at the bottom of `cmd/gpuaudit/main.go`: + +```go +func enrichK8sGPUMetrics(ctx context.Context, instances []models.GPUInstance, k8sOpts k8sprovider.ScanOptions, awsOpts awsprovider.ScanOptions) { + // Source 1: CloudWatch Container Insights + if len(instances) > 0 && instances[0].ClusterName != "" { + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if awsOpts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(awsOpts.Profile)) + } + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) 
+ if err == nil { + region := instances[0].Region + if region == "" { + region = "us-east-1" + } + cfg.Region = region + cwClient := cloudwatch.NewFromConfig(cfg) + fmt.Fprintf(os.Stderr, " Enriching K8s GPU metrics via CloudWatch Container Insights...\n") + awsprovider.EnrichK8sGPUMetrics(ctx, cwClient, instances, instances[0].ClusterName, awsprovider.DefaultMetricWindow) + + enriched := 0 + for _, inst := range instances { + if inst.AvgGPUUtilization != nil { + enriched++ + } + } + fmt.Fprintf(os.Stderr, " CloudWatch: got GPU metrics for %d of %d nodes\n", enriched, len(instances)) + } + } + + // Count remaining + remaining := 0 + for _, inst := range instances { + if inst.AvgGPUUtilization == nil { + remaining++ + } + } + + // Source 2: DCGM exporter scrape + if remaining > 0 { + client, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + k8sprovider.EnrichDCGMMetrics(ctx, client, instances) + } + + remaining = 0 + for _, inst := range instances { + if inst.AvgGPUUtilization == nil { + remaining++ + } + } + } + + // Source 3: Prometheus query + if remaining > 0 && (k8sOpts.PromURL != "" || k8sOpts.PromEndpoint != "") { + var client k8sprovider.K8sClient + if k8sOpts.PromEndpoint != "" { + c, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + client = c + } + } + promOpts := k8sprovider.PrometheusOptions{ + URL: k8sOpts.PromURL, + Endpoint: k8sOpts.PromEndpoint, + } + k8sprovider.EnrichPrometheusMetrics(ctx, client, instances, promOpts) + } +} +``` + +You will need to add the `"github.com/aws/aws-sdk-go-v2/service/cloudwatch"` import to `main.go` if it's not already present. + +- [ ] **Step 6: Run build and full test suite** + +Run: `go build ./... 
&& go test ./...` +Expected: Build succeeds, all tests pass + +- [ ] **Step 7: Commit** + +```bash +git add cmd/gpuaudit/main.go internal/providers/k8s/scanner.go +git commit -m "Wire K8s GPU metrics fallback chain into CLI scan flow" +``` diff --git a/go.mod b/go.mod index b86d582..9b28a73 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,18 @@ module github.com/gpuaudit/cli go 1.24.0 require ( - github.com/aws/aws-sdk-go-v2 v1.41.5 + github.com/aws/aws-sdk-go-v2 v1.41.6 github.com/aws/aws-sdk-go-v2/config v1.32.14 + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 github.com/aws/aws-sdk-go-v2/service/costexplorer v1.63.6 github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 github.com/aws/aws-sdk-go-v2/service/eks v1.82.0 + github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2 github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 + github.com/prometheus/client_model v0.6.2 + github.com/prometheus/common v0.67.5 github.com/spf13/cobra v1.10.2 k8s.io/api v0.32.3 k8s.io/apimachinery v0.32.3 @@ -18,17 +22,16 @@ require ( ) require ( - github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect - 
github.com/aws/smithy-go v1.24.2 // indirect + github.com/aws/smithy-go v1.25.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -39,7 +42,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -52,13 +55,14 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/term v0.25.0 // indirect - golang.org/x/text v0.19.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/term v0.38.0 // indirect + golang.org/x/text v0.32.0 // indirect golang.org/x/time v0.7.0 // indirect - google.golang.org/protobuf v1.35.1 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index c4d6139..67088e7 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,15 @@ -github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= -github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2 v1.41.6 h1:1AX0AthnBQzMx1vbmir3Y4WsnJgiydmnJjiLu+LvXOg= +github.com/aws/aws-sdk-go-v2 v1.41.6/go.mod h1:dy0UzBIfwSeot4grGvY1AqFWN5zgziMmWGzysDnHFcQ= 
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo= github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 h1:GmLa5Kw1ESqtFpXsx5MmC84QWa/ZrLZvlJGa2y+4kcQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22/go.mod h1:6sW9iWm9DK9YRpRGga/qzrzNLgKpT2cIxb7Vo2eNOp0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 h1:dY4kWZiSaXIzxnKlj17nHnBcXXBfac6UlsAx2qL6XrU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22/go.mod h1:KIpEUx0JuRZLO7U6cbV204cWAEco2iC3l061IxlwLtI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 h1:ud2A364lLBkhGAC7oYw/1xg9BF4acwJC+qdLykxy83o= @@ -24,6 +24,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhL github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod 
h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2 h1:2TDersSNowBwSRTrnD0LxLilpr6Dr5coXwVsWO7f2rw= +github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2/go.mod h1:UMm4MKZDJMbuJZF5QOJBsVRMLeKiEXAgCXFpocWPDFo= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 h1:5jLvLVu20tlFgVOsX+ns4jNVzoUWP36AQc5sAvNJSMI= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0/go.mod h1:zsRrjJIfG9a9b3VRU+uPa3dX5fqgI+zKMXD4tbIlbdA= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= @@ -34,8 +36,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6f github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= -github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= -github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/aws/smithy-go v1.25.0 h1:Sz/XJ64rwuiKtB6j98nDIPyYrV1nVNJ4YU74gttcl5U= +github.com/aws/smithy-go v1.25.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -65,8 +67,8 @@ github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6 github.com/google/gnostic-models v0.6.8 
h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -107,6 +109,10 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -121,12 +127,14 @@ github.com/stretchr/testify v1.3.0/go.mod 
h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -137,38 +145,38 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod 
h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= -golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.19.0 
h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= -golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index f91bcbe..93c139e 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -28,6 +28,8 @@ func analyzeInstance(inst *models.GPUInstance) { ruleSageMakerLowUtil, ruleSageMakerOversized, ruleK8sUnallocatedGPU, + ruleSpotEligible, + ruleK8sLowGPUUtil, } for _, rule := range rules { rule(inst) @@ -317,11 +319,11 @@ func ruleK8sUnallocatedGPU(inst *models.GPUInstance) { Type: "idle", Severity: models.SeverityCritical, Confidence: 0.9, - Evidence: fmt.Sprintf("GPU node has %d GPU(s) but no pods requesting GPUs for %.0f+ hours.", inst.GPUCount, inst.UptimeHours), + Evidence: fmt.Sprintf("GPU node has %d GPU(s) but no pods requesting GPUs. Node up for %d days.", inst.GPUCount, int(inst.UptimeHours/24)), }) inst.Recommendations = append(inst.Recommendations, models.Recommendation{ Action: models.ActionTerminate, - Description: fmt.Sprintf("No GPU pods scheduled on this node for %d days. Remove from node pool or scale down.", int(inst.UptimeHours/24)), + Description: fmt.Sprintf("Node up %d days with 0 GPU pods scheduled. Remove from node pool or scale down.", int(inst.UptimeHours/24)), CurrentMonthlyCost: inst.MonthlyCost, MonthlySavings: inst.MonthlyCost, SavingsPercent: 100, @@ -347,3 +349,79 @@ func ruleK8sUnallocatedGPU(inst *models.GPUInstance) { }) } } + +// Rule 8: On-demand instance eligible for Spot pricing. 
+func ruleSpotEligible(inst *models.GPUInstance) { + if inst.PricingModel != "on-demand" { + return + } + if inst.UptimeHours < 24 { + return + } + if inst.SpotHourlyCost == nil { + return + } + if inst.HourlyCost <= 0 { + return + } + + spotHourly := *inst.SpotHourlyCost + savingsPercent := ((inst.HourlyCost - spotHourly) / inst.HourlyCost) * 100 + if savingsPercent <= 0 { + return + } + + monthlySavings := (inst.HourlyCost - spotHourly) * 730 + spotMonthlyCost := spotHourly * 730 + + // Higher savings → higher confidence + confidence := 0.35 + (savingsPercent / 120) + if confidence > 0.95 { + confidence = 0.95 + } + + inst.WasteSignals = append(inst.WasteSignals, models.WasteSignal{ + Type: "spot_eligible", + Severity: models.SeverityInfo, + Confidence: confidence, + Evidence: fmt.Sprintf("Spot pricing available at $%.3f/hr vs $%.3f/hr on-demand (%.0f%% savings).", spotHourly, inst.HourlyCost, savingsPercent), + }) + inst.Recommendations = append(inst.Recommendations, models.Recommendation{ + Action: models.ActionChangePricing, + Description: fmt.Sprintf("Spot pricing available at $%.2f/hr (%.0f%% savings). Spot instances may be interrupted — suitable for fault-tolerant workloads.", spotHourly, savingsPercent), + CurrentMonthlyCost: inst.MonthlyCost, + RecommendedMonthlyCost: spotMonthlyCost, + MonthlySavings: monthlySavings, + SavingsPercent: savingsPercent, + Risk: models.RiskHigh, + }) +} + +// Rule 9: K8s GPU node with low GPU utilization (requires DCGM/CW/Prometheus metrics). +func ruleK8sLowGPUUtil(inst *models.GPUInstance) { + if inst.Source != models.SourceK8sNode { + return + } + if inst.AvgGPUUtilization == nil { + return + } + if *inst.AvgGPUUtilization >= 10 { + return + } + + inst.WasteSignals = append(inst.WasteSignals, models.WasteSignal{ + Type: "low_utilization", + Severity: models.SeverityCritical, + Confidence: 0.85, + Evidence: fmt.Sprintf("K8s GPU node utilization averaging %.1f%% over the past 7 days. 
GPUs are allocated but barely used.", *inst.AvgGPUUtilization), + }) + inst.Recommendations = append(inst.Recommendations, models.Recommendation{ + Action: models.ActionDownsize, + Description: fmt.Sprintf("GPU utilization averaging %.1f%% over the past 7 days. Consider bin-packing more workloads, downsizing, or removing from the node pool.", *inst.AvgGPUUtilization), + CurrentMonthlyCost: inst.MonthlyCost, + RecommendedMonthlyCost: inst.MonthlyCost * 0.2, + MonthlySavings: inst.MonthlyCost * 0.8, + SavingsPercent: 80, + Risk: models.RiskMedium, + }) +} diff --git a/internal/analysis/rules_test.go b/internal/analysis/rules_test.go index d8d264d..80bfa6d 100644 --- a/internal/analysis/rules_test.go +++ b/internal/analysis/rules_test.go @@ -259,3 +259,192 @@ func TestAnalyzeAll_ComputesSavings(t *testing.T) { t.Errorf("expected no signals for healthy instance, got %d", len(instances[1].WasteSignals)) } } + +func TestRuleSpotEligible_FlagsOnDemandWithSpotPrice(t *testing.T) { + spotPrice := 0.556 + inst := models.GPUInstance{ + InstanceID: "i-test", + Source: models.SourceEC2, + PricingModel: "on-demand", + UptimeHours: 48, + HourlyCost: 1.006, + MonthlyCost: 1.006 * 730, + SpotHourlyCost: &spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 1 { + t.Fatalf("expected 1 signal, got %d", len(inst.WasteSignals)) + } + if inst.WasteSignals[0].Type != "spot_eligible" { + t.Errorf("expected spot_eligible, got %s", inst.WasteSignals[0].Type) + } + if inst.WasteSignals[0].Severity != models.SeverityInfo { + t.Errorf("expected info severity, got %s", inst.WasteSignals[0].Severity) + } + if len(inst.Recommendations) != 1 { + t.Fatalf("expected 1 recommendation, got %d", len(inst.Recommendations)) + } + if inst.Recommendations[0].Action != models.ActionChangePricing { + t.Errorf("expected change_pricing, got %s", inst.Recommendations[0].Action) + } + expectedSavings := (1.006 - 0.556) * 730 + diff := inst.Recommendations[0].MonthlySavings - expectedSavings 
+ if diff < -0.01 || diff > 0.01 { + t.Errorf("expected savings %.2f, got %.2f", expectedSavings, inst.Recommendations[0].MonthlySavings) + } +} + +func TestRuleSpotEligible_SkipsSpotInstances(t *testing.T) { + spotPrice := 0.556 + inst := models.GPUInstance{ + PricingModel: "spot", + UptimeHours: 48, + SpotHourlyCost: &spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for spot instance, got %d", len(inst.WasteSignals)) + } +} + +func TestRuleSpotEligible_SkipsRecentInstances(t *testing.T) { + spotPrice := 0.556 + inst := models.GPUInstance{ + PricingModel: "on-demand", + UptimeHours: 12, + SpotHourlyCost: &spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for recent instance, got %d", len(inst.WasteSignals)) + } +} + +func TestRuleSpotEligible_SkipsWhenNoSpotPrice(t *testing.T) { + inst := models.GPUInstance{ + PricingModel: "on-demand", + UptimeHours: 48, + SpotHourlyCost: nil, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals when spot price unavailable, got %d", len(inst.WasteSignals)) + } +} + +func TestRuleSpotEligible_ConfidenceScalesWithSavings(t *testing.T) { + tests := []struct { + name string + onDemand float64 + spotPrice float64 + minConfidence float64 + }{ + {"large_savings_60pct", 1.0, 0.4, 0.85}, + {"moderate_savings_40pct", 1.0, 0.6, 0.65}, + {"small_savings_20pct", 1.0, 0.8, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inst := models.GPUInstance{ + PricingModel: "on-demand", + UptimeHours: 48, + HourlyCost: tt.onDemand, + MonthlyCost: tt.onDemand * 730, + SpotHourlyCost: &tt.spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) == 0 { + t.Fatal("expected signal") + } + if inst.WasteSignals[0].Confidence < tt.minConfidence { + t.Errorf("expected confidence >= %.2f, got %.2f", tt.minConfidence, 
inst.WasteSignals[0].Confidence) + } + }) + } +} + +func TestRuleK8sLowGPUUtil_FlagsLowUtilization(t *testing.T) { + inst := models.GPUInstance{ + InstanceID: "i-node1", + Source: models.SourceK8sNode, + State: "ready", + InstanceType: "g5.xlarge", + GPUModel: "A10G", + GPUCount: 1, + GPUAllocated: 1, + MonthlyCost: 734, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 1 { + t.Fatalf("expected 1 signal, got %d", len(inst.WasteSignals)) + } + if inst.WasteSignals[0].Type != "low_utilization" { + t.Errorf("expected low_utilization, got %s", inst.WasteSignals[0].Type) + } + if inst.WasteSignals[0].Severity != models.SeverityCritical { + t.Errorf("expected critical, got %s", inst.WasteSignals[0].Severity) + } + if inst.WasteSignals[0].Confidence != 0.85 { + t.Errorf("expected confidence 0.85, got %f", inst.WasteSignals[0].Confidence) + } + if len(inst.Recommendations) != 1 { + t.Fatalf("expected 1 recommendation, got %d", len(inst.Recommendations)) + } + if inst.Recommendations[0].MonthlySavings != 734*0.8 { + t.Errorf("expected savings %.0f, got %f", 734*0.8, inst.Recommendations[0].MonthlySavings) + } +} + +func TestRuleK8sLowGPUUtil_SkipsNonK8s(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceEC2, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for EC2 instance") + } +} + +func TestRuleK8sLowGPUUtil_SkipsNoMetrics(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals when metrics unavailable") + } +} + +func TestRuleK8sLowGPUUtil_SkipsHighUtilization(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + AvgGPUUtilization: ptr(45.0), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no 
signals for well-utilized GPU") + } +} diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 0000000..7d74430 --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,171 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package diff compares two scan results and reports what changed. +package diff + +import ( + "fmt" + + "github.com/gpuaudit/cli/internal/models" +) + +// DiffResult holds the comparison between two scan results. +type DiffResult struct { + OldTimestamp string `json:"old_timestamp"` + NewTimestamp string `json:"new_timestamp"` + Added []models.GPUInstance `json:"added,omitempty"` + Removed []models.GPUInstance `json:"removed,omitempty"` + Changed []InstanceDiff `json:"changed,omitempty"` + UnchangedCount int `json:"unchanged_count"` + CostSummary CostDelta `json:"cost_summary"` +} + +// InstanceDiff describes what changed for a single instance between scans. +type InstanceDiff struct { + InstanceID string `json:"instance_id"` + Old models.GPUInstance `json:"old"` + New models.GPUInstance `json:"new"` + CostDelta float64 `json:"cost_delta"` + Changes []string `json:"changes"` +} + +// CostDelta summarizes the financial impact of changes between scans. +type CostDelta struct { + OldTotalMonthlyCost float64 `json:"old_total_monthly_cost"` + NewTotalMonthlyCost float64 `json:"new_total_monthly_cost"` + CostChange float64 `json:"cost_change"` + OldTotalWaste float64 `json:"old_total_waste"` + NewTotalWaste float64 `json:"new_total_waste"` + WasteChange float64 `json:"waste_change"` + AddedCost float64 `json:"added_cost"` + RemovedSavings float64 `json:"removed_savings"` +} + +// Compare computes the diff between two scan results, matching instances by ID. 
+func Compare(old, new *models.ScanResult) *DiffResult { + oldMap := make(map[string]models.GPUInstance, len(old.Instances)) + for _, inst := range old.Instances { + oldMap[inst.InstanceID] = inst + } + + newMap := make(map[string]models.GPUInstance, len(new.Instances)) + for _, inst := range new.Instances { + newMap[inst.InstanceID] = inst + } + + result := &DiffResult{ + OldTimestamp: old.Timestamp.Format("2006-01-02 15:04 UTC"), + NewTimestamp: new.Timestamp.Format("2006-01-02 15:04 UTC"), + } + + // Find removed and changed + for id, oldInst := range oldMap { + newInst, exists := newMap[id] + if !exists { + result.Removed = append(result.Removed, oldInst) + continue + } + changes := diffInstance(oldInst, newInst) + if len(changes) > 0 { + result.Changed = append(result.Changed, InstanceDiff{ + InstanceID: id, + Old: oldInst, + New: newInst, + CostDelta: newInst.MonthlyCost - oldInst.MonthlyCost, + Changes: changes, + }) + } else { + result.UnchangedCount++ + } + } + + // Find added + for id, newInst := range newMap { + if _, exists := oldMap[id]; !exists { + result.Added = append(result.Added, newInst) + } + } + + // Cost summary + result.CostSummary = computeCostDelta(old, new, result) + + return result +} + +func diffInstance(old, new models.GPUInstance) []string { + var changes []string + + if old.InstanceType != new.InstanceType { + changes = append(changes, fmt.Sprintf("Instance type: %s → %s", old.InstanceType, new.InstanceType)) + } + if old.PricingModel != new.PricingModel { + changes = append(changes, fmt.Sprintf("Pricing: %s → %s", old.PricingModel, new.PricingModel)) + } + if old.MonthlyCost != new.MonthlyCost { + delta := new.MonthlyCost - old.MonthlyCost + changes = append(changes, fmt.Sprintf("Cost: $%.0f → $%.0f (%s/mo)", old.MonthlyCost, new.MonthlyCost, fmtDelta(delta))) + } + if old.State != new.State { + changes = append(changes, fmt.Sprintf("State: %s → %s", old.State, new.State)) + } + if old.GPUAllocated != new.GPUAllocated { + changes = 
append(changes, fmt.Sprintf("GPU allocated: %d → %d", old.GPUAllocated, new.GPUAllocated)) + } + if maxSeverityStr(old.WasteSignals) != maxSeverityStr(new.WasteSignals) { + oldSev := maxSeverityStr(old.WasteSignals) + newSev := maxSeverityStr(new.WasteSignals) + if oldSev == "" { + oldSev = "(none)" + } + if newSev == "" { + newSev = "(none)" + } + changes = append(changes, fmt.Sprintf("Severity: %s → %s", oldSev, newSev)) + } + + return changes +} + +func maxSeverityStr(signals []models.WasteSignal) string { + max := models.Severity("") + for _, s := range signals { + if s.Severity == models.SeverityCritical { + return string(models.SeverityCritical) + } + if s.Severity == models.SeverityWarning { + max = models.SeverityWarning + } + if s.Severity == models.SeverityInfo && max == "" { + max = models.SeverityInfo + } + } + return string(max) +} + +func fmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +func computeCostDelta(old, new *models.ScanResult, diff *DiffResult) CostDelta { + cd := CostDelta{ + OldTotalMonthlyCost: old.Summary.TotalMonthlyCost, + NewTotalMonthlyCost: new.Summary.TotalMonthlyCost, + CostChange: new.Summary.TotalMonthlyCost - old.Summary.TotalMonthlyCost, + OldTotalWaste: old.Summary.TotalEstimatedWaste, + NewTotalWaste: new.Summary.TotalEstimatedWaste, + WasteChange: new.Summary.TotalEstimatedWaste - old.Summary.TotalEstimatedWaste, + } + + for _, inst := range diff.Added { + cd.AddedCost += inst.MonthlyCost + } + for _, inst := range diff.Removed { + cd.RemovedSavings += inst.MonthlyCost + } + + return cd +} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go new file mode 100644 index 0000000..35d4f1f --- /dev/null +++ b/internal/diff/diff_test.go @@ -0,0 +1,219 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package diff + +import ( + "testing" + "time" + + "github.com/gpuaudit/cli/internal/models" +) + +func scanResult(instances ...models.GPUInstance) *models.ScanResult { + return &models.ScanResult{ + Timestamp: time.Date(2026, 4, 8, 12, 0, 0, 0, time.UTC), + Instances: instances, + Summary: models.ScanSummary{ + TotalInstances: len(instances), + TotalMonthlyCost: sumMonthlyCost(instances), + TotalEstimatedWaste: sumWaste(instances), + }, + } +} + +func sumMonthlyCost(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.MonthlyCost + } + return total +} + +func sumWaste(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.EstimatedSavings + } + return total +} + +func inst(id string, monthlyCost float64) models.GPUInstance { + return models.GPUInstance{ + InstanceID: id, + InstanceType: "g6e.16xlarge", + GPUModel: "L40S", + GPUCount: 1, + MonthlyCost: monthlyCost, + HourlyCost: monthlyCost / 730, + State: "ready", + Source: models.SourceK8sNode, + PricingModel: "on-demand", + } +} + +func TestCompare_AddedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 1 { + t.Fatalf("expected 1 added, got %d", len(result.Added)) + } + if result.Added[0].InstanceID != "i-bbb" { + t.Errorf("expected added instance i-bbb, got %s", result.Added[0].InstanceID) + } + if result.CostSummary.AddedCost != 3000 { + t.Errorf("expected added cost 3000, got %.0f", result.CostSummary.AddedCost) + } +} + +func TestCompare_RemovedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750)) + + result := Compare(old, new) + + if len(result.Removed) != 1 { + t.Fatalf("expected 1 removed, got %d", len(result.Removed)) + } + if 
result.Removed[0].InstanceID != "i-bbb" { + t.Errorf("expected removed instance i-bbb, got %s", result.Removed[0].InstanceID) + } + if result.CostSummary.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", result.CostSummary.RemovedSavings) + } +} + +func TestCompare_CostChange(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 4200)) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + if result.Changed[0].CostDelta != -2550 { + t.Errorf("expected cost delta -2550, got %.0f", result.Changed[0].CostDelta) + } + found := false + for _, c := range result.Changed[0].Changes { + if c == "Cost: $6750 → $4200 (-$2550/mo)" { + found = true + } + } + if !found { + t.Errorf("expected cost change string, got %v", result.Changed[0].Changes) + } +} + +func TestCompare_AllFieldChanges(t *testing.T) { + oldInst := inst("i-aaa", 6750) + oldInst.InstanceType = "g6e.16xlarge" + oldInst.PricingModel = "on-demand" + oldInst.State = "ready" + oldInst.GPUAllocated = 0 + oldInst.WasteSignals = []models.WasteSignal{{Severity: models.SeverityCritical}} + + newInst := inst("i-aaa", 4200) + newInst.InstanceType = "g6e.12xlarge" + newInst.PricingModel = "reserved" + newInst.State = "not-ready" + newInst.GPUAllocated = 2 + newInst.WasteSignals = nil + + old := scanResult(oldInst) + new := scanResult(newInst) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + + changes := result.Changed[0].Changes + expected := []string{ + "Instance type: g6e.16xlarge → g6e.12xlarge", + "Pricing: on-demand → reserved", + "Cost: $6750 → $4200 (-$2550/mo)", + "State: ready → not-ready", + "GPU allocated: 0 → 2", + "Severity: critical → (none)", + } + if len(changes) != len(expected) { + t.Fatalf("expected %d changes, got %d: %v", len(expected), len(changes), changes) + } + for i, exp 
:= range expected { + if changes[i] != exp { + t.Errorf("change[%d]: expected %q, got %q", i, exp, changes[i]) + } + } +} + +func TestCompare_UnchangedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 0 { + t.Errorf("expected 0 added, got %d", len(result.Added)) + } + if len(result.Removed) != 0 { + t.Errorf("expected 0 removed, got %d", len(result.Removed)) + } + if len(result.Changed) != 0 { + t.Errorf("expected 0 changed, got %d", len(result.Changed)) + } + if result.UnchangedCount != 2 { + t.Errorf("expected 2 unchanged, got %d", result.UnchangedCount) + } +} + +func TestCompare_CostSummary(t *testing.T) { + oldA := inst("i-aaa", 6750) + oldA.EstimatedSavings = 6750 + oldB := inst("i-bbb", 3000) + + newA := inst("i-aaa", 6750) + newA.EstimatedSavings = 6750 + newC := inst("i-ccc", 2000) + + old := scanResult(oldA, oldB) + new := scanResult(newA, newC) + + result := Compare(old, new) + + cs := result.CostSummary + if cs.OldTotalMonthlyCost != 9750 { + t.Errorf("expected old total 9750, got %.0f", cs.OldTotalMonthlyCost) + } + if cs.NewTotalMonthlyCost != 8750 { + t.Errorf("expected new total 8750, got %.0f", cs.NewTotalMonthlyCost) + } + if cs.CostChange != -1000 { + t.Errorf("expected cost change -1000, got %.0f", cs.CostChange) + } + if cs.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", cs.RemovedSavings) + } + if cs.AddedCost != 2000 { + t.Errorf("expected added cost 2000, got %.0f", cs.AddedCost) + } +} + +func TestCompare_EmptyScans(t *testing.T) { + old := scanResult() + new := scanResult() + + result := Compare(old, new) + + if len(result.Added) != 0 || len(result.Removed) != 0 || len(result.Changed) != 0 { + t.Errorf("expected no changes for empty scans") + } + if result.UnchangedCount != 0 { + t.Errorf("expected 0 unchanged, got %d", result.UnchangedCount) + } +} 
diff --git a/internal/models/models.go b/internal/models/models.go index 0fd6557..153ecec 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -57,6 +57,9 @@ type GPUInstance struct { Name string `json:"name"` // from Name tag or endpoint name Tags map[string]string `json:"tags,omitempty"` + // Network (populated for EC2) + PrivateDnsName string `json:"private_dns_name,omitempty"` + // GPU hardware InstanceType string `json:"instance_type"` GPUModel string `json:"gpu_model"` @@ -66,6 +69,7 @@ type GPUInstance struct { // Kubernetes (populated for k8s-node source) ClusterName string `json:"cluster_name,omitempty"` + K8sNodeName string `json:"k8s_node_name,omitempty"` GPUAllocated int `json:"gpu_allocated,omitempty"` // State @@ -85,10 +89,11 @@ type GPUInstance struct { InvocationCount *int64 `json:"invocation_count,omitempty"` // Cost - PricingModel string `json:"pricing_model"` // on-demand, spot, reserved, savings-plan - HourlyCost float64 `json:"hourly_cost"` - MonthlyCost float64 `json:"monthly_cost"` - MTDCost *float64 `json:"mtd_cost,omitempty"` + PricingModel string `json:"pricing_model"` // on-demand, spot, reserved, savings-plan + HourlyCost float64 `json:"hourly_cost"` + MonthlyCost float64 `json:"monthly_cost"` + SpotHourlyCost *float64 `json:"spot_hourly_cost,omitempty"` + MTDCost *float64 `json:"mtd_cost,omitempty"` // Analysis results (populated by analysis engine) WasteSignals []WasteSignal `json:"waste_signals,omitempty"` @@ -98,7 +103,7 @@ type GPUInstance struct { // WasteSignal represents a detected waste indicator on a GPU instance. 
type WasteSignal struct { - Type string `json:"type"` // idle, low_utilization, oversized_gpu, pricing_mismatch, stale, low_invocations + Type string `json:"type"` // idle, low_utilization, oversized_gpu, pricing_mismatch, stale, low_invocations, spot_eligible Severity Severity `json:"severity"` Confidence float64 `json:"confidence"` // 0.0 - 1.0 Evidence string `json:"evidence"` @@ -117,12 +122,15 @@ type Recommendation struct { // ScanResult holds the complete output of a gpuaudit scan. type ScanResult struct { - Timestamp time.Time `json:"timestamp"` - AccountID string `json:"account_id"` - Regions []string `json:"regions"` - ScanDuration string `json:"scan_duration"` - Instances []GPUInstance `json:"instances"` - Summary ScanSummary `json:"summary"` + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` + Targets []string `json:"targets,omitempty"` + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` } // ScanSummary provides aggregate statistics for a scan. @@ -137,5 +145,39 @@ type ScanSummary struct { HealthyCount int `json:"healthy_count"` } +// TargetSummary provides per-target aggregate statistics. +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +// TargetErrorInfo describes a target that failed to scan. +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} + +// MaxSeverity returns the highest severity among the given waste signals. 
+func MaxSeverity(signals []WasteSignal) Severity { + max := Severity("") + for _, s := range signals { + if s.Severity == SeverityCritical { + return SeverityCritical + } + if s.Severity == SeverityWarning { + max = SeverityWarning + } + if s.Severity == SeverityInfo && max == "" { + max = SeverityInfo + } + } + return max +} + // Ptr is a convenience helper for creating pointer values in tests and literals. func Ptr[T any](v T) *T { return &v } diff --git a/internal/output/diff.go b/internal/output/diff.go new file mode 100644 index 0000000..db2f7c9 --- /dev/null +++ b/internal/output/diff.go @@ -0,0 +1,132 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package output + +import ( + "encoding/json" + "fmt" + "io" + "sort" + "strings" + + "github.com/gpuaudit/cli/internal/diff" + "github.com/gpuaudit/cli/internal/models" +) + +// FormatDiffTable writes a human-readable diff report. +func FormatDiffTable(w io.Writer, d *diff.DiffResult) { + fmt.Fprintf(w, "\n gpuaudit diff — %s → %s\n\n", d.OldTimestamp, d.NewTimestamp) + + cs := d.CostSummary + + oldCount := len(d.Removed) + len(d.Changed) + d.UnchangedCount + newCount := len(d.Added) + len(d.Changed) + d.UnchangedCount + + // Cost summary box + boxWidth := 58 // inner width between │ markers + boxLine := strings.Repeat("─", boxWidth) + fmt.Fprintf(w, " ┌%s┐\n", boxLine) + writeBoxLine(w, "Cost Delta", boxWidth) + fmt.Fprintf(w, " ├%s┤\n", boxLine) + writeBoxLine(w, fmt.Sprintf("Monthly spend: $%-9.0f → $%-9.0f (%s)", + cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, diffFmtDelta(cs.CostChange)), boxWidth) + writeBoxLine(w, fmt.Sprintf("Estimated waste: $%-9.0f → $%-9.0f (%s)", + cs.OldTotalWaste, cs.NewTotalWaste, diffFmtDelta(cs.WasteChange)), boxWidth) + writeBoxLine(w, fmt.Sprintf("Instances: %d → %d (-%d removed, +%d added)", + oldCount, newCount, len(d.Removed), len(d.Added)), boxWidth) + fmt.Fprintf(w, " └%s┘\n", boxLine) + + // Removed + if 
len(d.Removed) > 0 { + sortInstancesByCost(d.Removed) + fmt.Fprintf(w, "\n REMOVED — %d instance(s), -$%.0f/mo\n\n", len(d.Removed), cs.RemovedSavings) + printDiffInstanceTable(w, d.Removed) + } + + // Added + if len(d.Added) > 0 { + sortInstancesByCost(d.Added) + fmt.Fprintf(w, "\n ADDED — %d instance(s), +$%.0f/mo\n\n", len(d.Added), cs.AddedCost) + printDiffInstanceTable(w, d.Added) + } + + // Changed + if len(d.Changed) > 0 { + fmt.Fprintf(w, "\n CHANGED — %d instance(s)\n\n", len(d.Changed)) + fmt.Fprintf(w, " %-36s %s\n", "Instance", "Change") + fmt.Fprintf(w, " %s %s\n", strings.Repeat("─", 36), strings.Repeat("─", 50)) + for _, c := range d.Changed { + name := c.New.Name + if name == "" { + name = c.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + for i, change := range c.Changes { + if i == 0 { + fmt.Fprintf(w, " %-36s %s\n", name, change) + } else { + fmt.Fprintf(w, " %-36s %s\n", "", change) + } + } + } + fmt.Fprintln(w) + } + + // Unchanged + if d.UnchangedCount > 0 { + fmt.Fprintf(w, " UNCHANGED — %d instance(s)\n\n", d.UnchangedCount) + } +} + +func printDiffInstanceTable(w io.Writer, instances []models.GPUInstance) { + fmt.Fprintf(w, " %-36s %-26s %10s\n", "Instance", "Type", "Monthly") + fmt.Fprintf(w, " %s %s %s\n", + strings.Repeat("─", 36), strings.Repeat("─", 26), strings.Repeat("─", 10)) + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..." 
+ } + fmt.Fprintf(w, " %-36s %-26s $%9.0f\n", name, typeDesc, inst.MonthlyCost) + } +} + +func sortInstancesByCost(instances []models.GPUInstance) { + sort.Slice(instances, func(i, j int) bool { + return instances[i].MonthlyCost > instances[j].MonthlyCost + }) +} + +func writeBoxLine(w io.Writer, content string, width int) { + // Pad content to fill the box width (with 2-char margin on each side) + inner := width - 4 // 2 spaces on each side + if len(content) > inner { + content = content[:inner] + } + fmt.Fprintf(w, " │ %-*s │\n", inner, content) +} + +func diffFmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +// FormatDiffJSON writes the diff result as pretty-printed JSON. +func FormatDiffJSON(w io.Writer, d *diff.DiffResult) error { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(d) +} diff --git a/internal/output/markdown.go b/internal/output/markdown.go index 13290bb..58995c5 100644 --- a/internal/output/markdown.go +++ b/internal/output/markdown.go @@ -31,14 +31,41 @@ func FormatMarkdown(w io.Writer, result *models.ScanResult) { fmt.Fprintf(w, "| Warning | %d |\n", s.WarningCount) fmt.Fprintf(w, "| Healthy | %d |\n\n", s.HealthyCount) + // Per-target breakdown + if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "## By Target\n\n") + fmt.Fprintf(w, "| Target | Instances | Spend/mo | Waste/mo | Waste |\n") + fmt.Fprintf(w, "|---|---|---|---|---|\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, "| %s | %d | $%.0f | $%.0f | %.0f%% |\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintln(w) + } + + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, "## Warnings\n\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, "- **%s** — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } + if s.TotalInstances == 0 { fmt.Fprintf(w, "No GPU instances found.\n") return } 
fmt.Fprintf(w, "## Findings\n\n") - fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") - fmt.Fprintf(w, "|---|---|---|---|---|---|\n") + multiTarget := len(result.TargetSummaries) > 1 + if multiTarget { + fmt.Fprintf(w, "| Instance | Target | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|---|\n") + } else { + fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|\n") + } for _, inst := range result.Instances { name := inst.Name @@ -61,8 +88,14 @@ func FormatMarkdown(w io.Writer, result *models.ScanResult) { savings = fmt.Sprintf("$%.0f/mo", inst.EstimatedSavings) } - fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", - name, inst.InstanceType, inst.GPUCount, inst.GPUModel, - inst.MonthlyCost, signal, savings, rec) + if multiTarget { + fmt.Fprintf(w, "| %s | %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.AccountID, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } else { + fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } } } diff --git a/internal/output/slack.go b/internal/output/slack.go index 530afe7..f8fc334 100644 --- a/internal/output/slack.go +++ b/internal/output/slack.go @@ -34,6 +34,27 @@ func FormatSlack(w io.Writer, result *models.ScanResult) error { blocks = append(blocks, map[string]any{"type": "divider"}) + // Per-target breakdown + if len(result.TargetSummaries) > 1 { + lines := []string{"*By Target*"} + for _, ts := range result.TargetSummaries { + lines = append(lines, fmt.Sprintf("• `%s` — %d instances, $%.0f/mo spend, $%.0f/mo waste (%.0f%%)", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent)) + } + blocks = append(blocks, 
slackSection(strings.Join(lines, "\n"))) + blocks = append(blocks, map[string]any{"type": "divider"}) + } + + // Target errors + if len(result.TargetErrors) > 0 { + lines := []string{":warning: *Target Warnings*"} + for _, te := range result.TargetErrors { + lines = append(lines, fmt.Sprintf("• `%s` — %s", te.Target, te.Error)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + } + // Critical findings critical, warning, _ := groupBySeverity(result.Instances) diff --git a/internal/output/table.go b/internal/output/table.go index 3f73232..6052fe8 100644 --- a/internal/output/table.go +++ b/internal/output/table.go @@ -34,6 +34,8 @@ func FormatTable(w io.Writer, result *models.ScanResult) { fmt.Fprintf(w, " │ Estimated monthly waste: $%-10.0f (%4.0f%%) │\n", s.TotalEstimatedWaste, s.WastePercent) fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n\n") + printTargetSummary(w, result) + if s.TotalInstances == 0 { fmt.Fprintf(w, " No GPU instances found.\n\n") return @@ -42,14 +44,16 @@ func FormatTable(w io.Writer, result *models.ScanResult) { // Group instances by severity critical, warning, healthy := groupBySeverity(result.Instances) + multiTarget := len(result.TargetSummaries) > 1 + if len(critical) > 0 { fmt.Fprintf(w, " CRITICAL — %d instance(s), $%.0f/mo potential savings\n\n", len(critical), sumSavings(critical)) - printInstanceTable(w, critical) + printInstanceTable(w, critical, multiTarget) } if len(warning) > 0 { fmt.Fprintf(w, " WARNING — %d instance(s), $%.0f/mo potential savings\n\n", len(warning), sumSavings(warning)) - printInstanceTable(w, warning) + printInstanceTable(w, warning, multiTarget) } if len(healthy) > 0 { @@ -57,17 +61,54 @@ func FormatTable(w io.Writer, result *models.ScanResult) { } } -func printInstanceTable(w io.Writer, instances []models.GPUInstance) { - // Header - fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", - "Instance", "Type", "Monthly", "Signal", "Recommendation") - 
fmt.Fprintf(w, " %s %s %s %s %s\n", - strings.Repeat("─", 36), - strings.Repeat("─", 26), - strings.Repeat("─", 10), - strings.Repeat("─", 16), - strings.Repeat("─", 50), - ) +func printTargetSummary(w io.Writer, result *models.ScanResult) { + if len(result.TargetSummaries) < 2 { + return + } + + fmt.Fprintf(w, " By Target\n") + fmt.Fprintf(w, " ┌──────────────┬───────────┬───────────┬───────────┬───────┐\n") + fmt.Fprintf(w, " │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │\n") + fmt.Fprintf(w, " ├──────────────┼───────────┼───────────┼───────────┼───────┤\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, " │ %-12s │ %9d │ $%8.0f │ $%8.0f │ %4.0f%% │\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintf(w, " └──────────────┴───────────┴───────────┴───────────┴───────┘\n\n") + + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, " Warnings\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, " ✗ %s — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } +} + +func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget bool) { + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s %10s %-16s %s\n", + "Instance", "Target", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 14), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } else { + fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", + "Instance", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } for _, inst := range instances { name := inst.Name @@ -87,6 +128,9 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance) { signal := "" if len(inst.WasteSignals) > 0 { 
signal = inst.WasteSignals[0].Type + if inst.AvgGPUUtilization != nil { + signal += fmt.Sprintf(" [GPU %.0f%%]", *inst.AvgGPUUtilization) + } } rec := "" @@ -94,16 +138,20 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance) { rec = inst.Recommendations[0].Description } - fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", - name, typeDesc, inst.MonthlyCost, signal, rec) + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s $%9.0f %-16s %s\n", + name, inst.AccountID, typeDesc, inst.MonthlyCost, signal, rec) + } else { + fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", + name, typeDesc, inst.MonthlyCost, signal, rec) + } } fmt.Fprintln(w) } func groupBySeverity(instances []models.GPUInstance) (critical, warning, healthy []models.GPUInstance) { for _, inst := range instances { - maxSev := maxSeverity(inst.WasteSignals) - switch maxSev { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: critical = append(critical, inst) case models.SeverityWarning: @@ -125,22 +173,6 @@ func groupBySeverity(instances []models.GPUInstance) (critical, warning, healthy return } -func maxSeverity(signals []models.WasteSignal) models.Severity { - max := models.Severity("") - for _, s := range signals { - if s.Severity == models.SeverityCritical { - return models.SeverityCritical - } - if s.Severity == models.SeverityWarning { - max = models.SeverityWarning - } - if s.Severity == models.SeverityInfo && max == "" { - max = models.SeverityInfo - } - } - return max -} - func sumSavings(instances []models.GPUInstance) float64 { total := 0.0 for _, inst := range instances { diff --git a/internal/prometheus/query.go b/internal/prometheus/query.go new file mode 100644 index 0000000..063cf09 --- /dev/null +++ b/internal/prometheus/query.go @@ -0,0 +1,74 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package prometheus + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) + +// QueryHTTP executes a PromQL instant query against a Prometheus-compatible HTTP API +// and returns a map from the given labelName to its metric value. +func QueryHTTP(ctx context.Context, baseURL, query, labelName string) (map[string]float64, error) { + u := fmt.Sprintf("%s/api/v1/query?query=%s", strings.TrimRight(baseURL, "/"), url.QueryEscape(query)) + req, err := http.NewRequestWithContext(ctx, "GET", u, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return ParseResponse(data, labelName) +} + +// ParseResponse extracts metric values from a Prometheus API JSON response, +// keyed by the given label name. +func ParseResponse(data []byte, labelName string) (map[string]float64, error) { + var resp struct { + Status string `json:"status"` + Data struct { + ResultType string `json:"resultType"` + Result []struct { + Metric map[string]string `json:"metric"` + Value []json.RawMessage `json:"value"` + } `json:"result"` + } `json:"data"` + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + if resp.Status != "success" { + return nil, fmt.Errorf("query returned status %q", resp.Status) + } + + results := make(map[string]float64) + for _, r := range resp.Data.Result { + key := r.Metric[labelName] + if key == "" || len(r.Value) < 2 { + continue + } + var valStr string + if err := json.Unmarshal(r.Value[1], &valStr); err != nil { + continue + } + val, err := strconv.ParseFloat(valStr, 64) + if err != nil { + continue + } + results[key] = val + } + return results, nil +} diff --git a/internal/prometheus/query_test.go b/internal/prometheus/query_test.go new 
file mode 100644 index 0000000..6849b94 --- /dev/null +++ b/internal/prometheus/query_test.go @@ -0,0 +1,98 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package prometheus + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestParseResponse_ExtractsByLabel(t *testing.T) { + data := []byte(`{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"Hostname": "ip-10-0-1-1"}, "value": [1700000000, "45.2"]}, + {"metric": {"Hostname": "ip-10-0-1-2"}, "value": [1700000000, "12.8"]} + ] + } + }`) + + results, err := ParseResponse(data, "Hostname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + if results["ip-10-0-1-1"] != 45.2 { + t.Errorf("expected 45.2, got %f", results["ip-10-0-1-1"]) + } + if results["ip-10-0-1-2"] != 12.8 { + t.Errorf("expected 12.8, got %f", results["ip-10-0-1-2"]) + } +} + +func TestParseResponse_SkipsMissingLabel(t *testing.T) { + data := []byte(`{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"other": "value"}, "value": [1700000000, "45.2"]} + ] + } + }`) + + results, err := ParseResponse(data, "Hostname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } +} + +func TestParseResponse_ErrorStatus(t *testing.T) { + data := []byte(`{"status": "error", "errorType": "bad_data", "error": "parse error"}`) + + _, err := ParseResponse(data, "node") + if err == nil { + t.Error("expected error for non-success status") + } +} + +func TestQueryHTTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/query" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + query := r.URL.Query().Get("query") + if query == 
"" { + t.Error("expected query parameter") + } + fmt.Fprintf(w, `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "host1"}, "value": [1700000000, "55.5"]} + ] + } + }`) + })) + defer srv.Close() + + results, err := QueryHTTP(context.Background(), srv.URL, "up", "node") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if results["host1"] != 55.5 { + t.Errorf("expected 55.5, got %f", results["host1"]) + } +} diff --git a/internal/providers/aws/cloudwatch.go b/internal/providers/aws/cloudwatch.go index 819261c..ab06d3e 100644 --- a/internal/providers/aws/cloudwatch.go +++ b/internal/providers/aws/cloudwatch.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "os" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -79,6 +80,69 @@ func EnrichSageMakerMetrics(ctx context.Context, client CloudWatchClient, instan return nil } +// EnrichK8sGPUMetrics populates GPU utilization metrics on K8s nodes using CloudWatch Container Insights. +func EnrichK8sGPUMetrics(ctx context.Context, client CloudWatchClient, instances []models.GPUInstance, clusterName string, window MetricWindow) { + type nodeRef struct { + index int + instanceID string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode { + continue + } + if inst.AvgGPUUtilization != nil { + continue + } + if !strings.HasPrefix(inst.InstanceID, "i-") { + continue + } + nodes = append(nodes, nodeRef{index: i, instanceID: inst.InstanceID}) + } + if len(nodes) == 0 { + return + } + + now := time.Now() + start := now.Add(-window.Duration) + + clusterDim := cwtypes.Dimension{ + Name: aws.String("ClusterName"), + Value: aws.String(clusterName), + } + + enriched := 0 + for _, node := range nodes { + instanceDim := cwtypes.Dimension{ + Name: aws.String("InstanceId"), + Value: aws.String(node.instanceID), + } + + safeID := strings.ReplaceAll(node.instanceID, "-", "_") + + queries := 
[]cwtypes.MetricDataQuery{ + metricQuery2("gpu_util_"+safeID, "ContainerInsights", "node_gpu_utilization", "Average", window.Period, clusterDim, instanceDim), + metricQuery2("gpu_mem_"+safeID, "ContainerInsights", "node_gpu_memory_utilization", "Average", window.Period, clusterDim, instanceDim), + } + + results, err := fetchMetrics(ctx, client, queries, start, now) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Container Insights metrics unavailable: %v\n", err) + break + } + + if results["gpu_util_"+safeID] != nil { + instances[node.index].AvgGPUUtilization = results["gpu_util_"+safeID] + instances[node.index].AvgGPUMemUtilization = results["gpu_mem_"+safeID] + enriched++ + } + } + if enriched > 0 { + fmt.Fprintf(os.Stderr, " CloudWatch: got GPU metrics for %d of %d nodes\n", enriched, len(nodes)) + } +} + func getEC2Metrics(ctx context.Context, client CloudWatchClient, instanceID string, window MetricWindow) (map[string]*float64, error) { now := time.Now() start := now.Add(-window.Duration) diff --git a/internal/providers/aws/cloudwatch_test.go b/internal/providers/aws/cloudwatch_test.go new file mode 100644 index 0000000..6dd1d8f --- /dev/null +++ b/internal/providers/aws/cloudwatch_test.go @@ -0,0 +1,125 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + cwtypes "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" + + "github.com/gpuaudit/cli/internal/models" +) + +type mockCloudWatchClient struct { + output *cloudwatch.GetMetricDataOutput + err error +} + +func (m *mockCloudWatchClient) GetMetricData(ctx context.Context, params *cloudwatch.GetMetricDataInput, optFns ...func(*cloudwatch.Options)) (*cloudwatch.GetMetricDataOutput, error) { + if m.err != nil { + return nil, m.err + } + return m.output, nil +} + +func TestEnrichK8sGPUMetrics_PopulatesUtilization(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{ + MetricDataResults: []cwtypes.MetricDataResult{ + {Id: aws.String("gpu_util_i_abc123"), Values: []float64{45.0, 50.0, 55.0}}, + {Id: aws.String("gpu_mem_i_abc123"), Values: []float64{30.0, 35.0, 40.0}}, + }, + }, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "ml-cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected avg GPU util 50.0, got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 35.0 { + t.Errorf("expected avg GPU mem util 35.0, got %f", *instances[0].AvgGPUMemUtilization) + } +} + +func TestEnrichK8sGPUMetrics_SkipsNonK8sNodes(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-ec2", Source: models.SourceEC2}, + } + + 
EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for non-K8s instance") + } +} + +func TestEnrichK8sGPUMetrics_SkipsNodesWithoutInstanceID(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "node-hostname", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for node without EC2 instance ID") + } +} + +func TestEnrichK8sGPUMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + AvgGPUUtilization: &gpuUtil, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Errorf("expected existing value 75.0 to be preserved, got %f", *instances[0].AvgGPUUtilization) + } +} + +func TestEnrichK8sGPUMetrics_HandlesAPIError(t *testing.T) { + client := &mockCloudWatchClient{ + err: fmt.Errorf("access denied"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-abc123", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util after API error") + } +} diff --git a/internal/providers/aws/ec2.go b/internal/providers/aws/ec2.go index 0fa6738..cb82a46 100644 --- a/internal/providers/aws/ec2.go +++ b/internal/providers/aws/ec2.go @@ -97,13 +97,14 @@ func ec2InstanceToGPU(inst ec2types.Instance, accountID, region string) *models. 
// TODO: detect RI/SP coverage via Cost Explorer return &models.GPUInstance{ - InstanceID: aws.ToString(inst.InstanceId), - Source: models.SourceEC2, - AccountID: accountID, - Region: region, - Name: name, - Tags: tags, - InstanceType: instanceType, + InstanceID: aws.ToString(inst.InstanceId), + Source: models.SourceEC2, + AccountID: accountID, + Region: region, + Name: name, + Tags: tags, + PrivateDnsName: aws.ToString(inst.PrivateDnsName), + InstanceType: instanceType, GPUModel: spec.GPUModel, GPUCount: spec.GPUCount, GPUVRAMGiB: spec.GPUVRAMGiB, diff --git a/internal/providers/aws/multiaccount.go b/internal/providers/aws/multiaccount.go new file mode 100644 index 0000000..298c475 --- /dev/null +++ b/internal/providers/aws/multiaccount.go @@ -0,0 +1,161 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +// Target represents a resolved scan target with its credentials. +type Target struct { + AccountID string + Config aws.Config +} + +// TargetError records a target that failed credential resolution. +type TargetError struct { + AccountID string + Err error +} + +// STSClient is the subset of the STS API we need. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// OrgClient is the subset of the Organizations API we need. 
+type OrgClient interface { + ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) +} + +// ResolveTargets determines which accounts to scan and obtains credentials for each. +// +// Behaviour: +// - No --targets/--org: returns self only (uses baseCfg, no AssumeRole) +// - --targets + --role: AssumeRole for each, self included by default +// - --org + --role: ListAccounts, filter Active, AssumeRole for non-self accounts +// - --skip-self: exclude caller's account +// - Self account is never AssumeRole'd — uses original credentials +// - Failed AssumeRole calls are collected as TargetError, not fatal +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) (selfAccount string, targets []Target, targetErrors []TargetError) { + // Identify the caller's own account. + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return "", nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + } + selfAccount = aws.ToString(identity.Account) + + // Determine the list of account IDs to scan. + var accountIDs []string + + switch { + case opts.OrgScan: + activeAccounts, listErr := listActiveOrgAccounts(ctx, orgClient) + if listErr != nil { + return selfAccount, nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} + } + accountIDs = activeAccounts + case len(opts.Targets) > 0: + // Always include self unless it is already in the list or --skip-self is set. + seen := make(map[string]bool) + for _, id := range opts.Targets { + if !seen[id] { + accountIDs = append(accountIDs, id) + seen[id] = true + } + } + if !seen[selfAccount] && !opts.SkipSelf { + // Prepend self so it appears first. + accountIDs = append([]string{selfAccount}, accountIDs...) 
+ } + default: + // No multi-target flags — scan self only. + return selfAccount, []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + } + + for _, acctID := range accountIDs { + if opts.SkipSelf && acctID == selfAccount { + continue + } + + if acctID == selfAccount { + // Self: use original credentials, no AssumeRole. + targets = append(targets, Target{AccountID: selfAccount, Config: baseCfg}) + continue + } + + // AssumeRole into the target account. + cfg, assumeErr := assumeRole(ctx, baseCfg, stsClient, acctID, opts.Role, opts.ExternalID) + if assumeErr != nil { + targetErrors = append(targetErrors, TargetError{AccountID: acctID, Err: assumeErr}) + continue + } + targets = append(targets, Target{AccountID: acctID, Config: cfg}) + } + + return selfAccount, targets, targetErrors +} + +// assumeRole assumes a role in the given account and returns an aws.Config +// with the temporary credentials. +func assumeRole(ctx context.Context, baseCfg aws.Config, stsClient STSClient, accountID, roleName, externalID string) (aws.Config, error) { + roleARN := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleName) + + input := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleARN), + RoleSessionName: aws.String("gpuaudit"), + } + if externalID != "" { + input.ExternalId = aws.String(externalID) + } + + result, err := stsClient.AssumeRole(ctx, input) + if err != nil { + return aws.Config{}, fmt.Errorf("AssumeRole %s: %w", roleARN, err) + } + + creds := result.Credentials + cfg := baseCfg.Copy() + cfg.Credentials = credentials.NewStaticCredentialsProvider( + aws.ToString(creds.AccessKeyId), + aws.ToString(creds.SecretAccessKey), + aws.ToString(creds.SessionToken), + ) + + return cfg, nil +} + +// listActiveOrgAccounts returns the account IDs of all active accounts in the organization. 
+func listActiveOrgAccounts(ctx context.Context, orgClient OrgClient) ([]string, error) { + var accountIDs []string + var nextToken *string + + for { + out, err := orgClient.ListAccounts(ctx, &organizations.ListAccountsInput{ + NextToken: nextToken, + }) + if err != nil { + return nil, err + } + for _, acct := range out.Accounts { + if acct.Status == orgtypes.AccountStatusActive { + accountIDs = append(accountIDs, aws.ToString(acct.Id)) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + return accountIDs, nil +} diff --git a/internal/providers/aws/multiaccount_test.go b/internal/providers/aws/multiaccount_test.go new file mode 100644 index 0000000..bc2ba11 --- /dev/null +++ b/internal/providers/aws/multiaccount_test.go @@ -0,0 +1,304 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) + +// --- Mock STS client --- + +type mockSTSClient struct { + callerAccount string + assumeResults map[string]*sts.AssumeRoleOutput // accountID -> output + assumeErrors map[string]error // accountID -> error +} + +func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return &sts.GetCallerIdentityOutput{ + Account: aws.String(m.callerAccount), + }, nil +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Extract account ID from the role ARN: arn:aws:iam:::role/ + arn := aws.ToString(params.RoleArn) + // Simple parse: find the account between the 
4th and 5th colons + accountID := parseAccountFromARN(arn) + + if err, ok := m.assumeErrors[accountID]; ok { + return nil, err + } + if out, ok := m.assumeResults[accountID]; ok { + return out, nil + } + return nil, fmt.Errorf("no mock configured for account %s", accountID) +} + +func parseAccountFromARN(arn string) string { + // arn:aws:iam::123456789012:role/name + colons := 0 + start := 0 + for i, c := range arn { + if c == ':' { + colons++ + if colons == 4 { + start = i + 1 + } + if colons == 5 { + return arn[start:i] + } + } + } + return "" +} + +// --- Mock Org client --- + +type mockOrgClient struct { + accounts []orgtypes.Account + err error +} + +func (m *mockOrgClient) ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) { + if m.err != nil { + return nil, m.err + } + return &organizations.ListAccountsOutput{Accounts: m.accounts}, nil +} + +// Helper to build a successful AssumeRole result with dummy credentials. 
+func assumeRoleOK(accountID string) *sts.AssumeRoleOutput { + exp := time.Now().Add(1 * time.Hour) + return &sts.AssumeRoleOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AKID-" + accountID), + SecretAccessKey: aws.String("SECRET-" + accountID), + SessionToken: aws.String("TOKEN-" + accountID), + Expiration: &exp, + }, + } +} + +func TestResolveTargets_NoTargets_ReturnsSelfOnly(t *testing.T) { + stsClient := &mockSTSClient{callerAccount: "111111111111"} + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{} + + self, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + if self != "111111111111" { + t.Errorf("expected self account 111111111111, got %s", self) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (self), got %d", len(targets)) + } + if targets[0].AccountID != "111111111111" { + t.Errorf("expected account 111111111111, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_ExplicitTargets_ReturnsSelfPlusAssumed(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + "333333333333": assumeRoleOK("333333333333"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222", "333333333333"}, + Role: "AuditRole", + } + + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d", len(errs)) + } + // Self + 2 explicit targets = 3 + if len(targets) != 3 { + t.Fatalf("expected 3 targets, got %d", len(targets)) + } + + // Verify self is included + found := false + for _, tgt := range targets { + if tgt.AccountID == "111111111111" { + found = true + break + } + } + if !found { + t.Error("self account 111111111111 not 
found in targets") + } + + // Verify assumed targets + for _, acct := range []string{"222222222222", "333333333333"} { + found := false + for _, tgt := range targets { + if tgt.AccountID == acct { + found = true + break + } + } + if !found { + t.Errorf("account %s not found in targets", acct) + } + } +} + +func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222"}, + Role: "AuditRole", + SkipSelf: true, + } + + self, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d", len(errs)) + } + if self != "111111111111" { + t.Errorf("expected self account 111111111111, got %s", self) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (no self), got %d", len(targets)) + } + if targets[0].AccountID != "222222222222" { + t.Errorf("expected account 222222222222, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_PartialFailure(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + assumeErrors: map[string]error{ + "333333333333": fmt.Errorf("access denied"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222", "333333333333"}, + Role: "AuditRole", + } + + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + // Self + 222 succeeded, 333 failed + if len(targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(targets)) + } + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if errs[0].AccountID != "333333333333" { + 
t.Errorf("expected error for 333333333333, got %s", errs[0].AccountID) + } +} + +func TestResolveTargets_OrgDiscovery(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + "444444444444": assumeRoleOK("444444444444"), + }, + } + orgClient := &mockOrgClient{ + accounts: []orgtypes.Account{ + {Id: aws.String("111111111111"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("222222222222"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("333333333333"), Status: orgtypes.AccountStatusSuspended}, + {Id: aws.String("444444444444"), Status: orgtypes.AccountStatusActive}, + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + OrgScan: true, + Role: "AuditRole", + } + + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + // Active accounts: 111 (self), 222, 444. Suspended 333 is filtered. + if len(targets) != 3 { + t.Fatalf("expected 3 targets (self + 2 active non-self), got %d", len(targets)) + } + + // Verify suspended account is excluded + for _, tgt := range targets { + if tgt.AccountID == "333333333333" { + t.Error("suspended account 333333333333 should be excluded") + } + } + + // Verify self is included + found := false + for _, tgt := range targets { + if tgt.AccountID == "111111111111" { + found = true + break + } + } + if !found { + t.Error("self account 111111111111 not found in targets") + } +} + +func TestResolveTargets_SelfInExplicitTargets_NotAssumed(t *testing.T) { + // If the caller's own account appears in --targets, it should use baseCfg (no AssumeRole). 
+ stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + // No AssumeRole result for self — it should not be called + assumeErrors: map[string]error{ + "111111111111": fmt.Errorf("should not assume role for self"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "AuditRole", + } + + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + // Self (from targets list, no duplicate) + 222 + if len(targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(targets)) + } +} diff --git a/internal/providers/aws/prometheus.go b/internal/providers/aws/prometheus.go new file mode 100644 index 0000000..bef722f --- /dev/null +++ b/internal/providers/aws/prometheus.go @@ -0,0 +1,130 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/gpuaudit/cli/internal/models" + prom "github.com/gpuaudit/cli/internal/prometheus" +) + +// EnrichEC2PrometheusGPUMetrics queries a Prometheus endpoint for DCGM GPU metrics +// on EC2 instances that don't already have AvgGPUUtilization populated. +// It matches Prometheus results to EC2 instances via private DNS hostname. 
+func EnrichEC2PrometheusGPUMetrics(ctx context.Context, promURL string, instances []models.GPUInstance) int { + if promURL == "" { + return 0 + } + + type instRef struct { + index int + hostname string + ip string + } + var refs []instRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceEC2 || inst.State != "running" { + continue + } + if inst.AvgGPUUtilization != nil { + continue + } + if inst.PrivateDnsName == "" { + continue + } + hostname := strings.SplitN(inst.PrivateDnsName, ".", 2)[0] + ip := extractIPFromDNS(inst.PrivateDnsName) + refs = append(refs, instRef{index: i, hostname: hostname, ip: ip}) + } + if len(refs) == 0 { + return 0 + } + + // Build lookup maps: hostname → index, ip → index + hostnameToIdx := make(map[string]int, len(refs)) + ipToIdx := make(map[string]int, len(refs)) + for _, ref := range refs { + hostnameToIdx[ref.hostname] = ref.index + if ref.ip != "" { + ipToIdx[ref.ip] = ref.index + } + } + + fmt.Fprintf(os.Stderr, " Querying Prometheus at %s for EC2 GPU metrics...\n", promURL) + + // Query GPU utilization — get all DCGM metrics and match locally. + // DCGM exporter labels vary by setup: "Hostname" for host identity, + // "instance" for scrape target (ip:port). 
+ gpuByHostname, err := prom.QueryHTTP(ctx, promURL, + `avg by (Hostname) (avg_over_time(DCGM_FI_DEV_GPU_UTIL[7d]))`, "Hostname") + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus EC2 GPU query failed: %v\n", err) + return 0 + } + + memByHostname, _ := prom.QueryHTTP(ctx, promURL, + `avg by (Hostname) (avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL[7d]))`, "Hostname") + + enriched := 0 + + // First pass: match by Hostname label (short hostname like "ip-10-22-249-234") + for _, ref := range refs { + if val, ok := gpuByHostname[ref.hostname]; ok { + instances[ref.index].AvgGPUUtilization = &val + if memVal, ok := memByHostname[ref.hostname]; ok { + instances[ref.index].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + + // Second pass: try matching by instance label (ip:port) for instances still missing metrics + instanceSeriesCount := 0 + if enriched < len(refs) { + gpuByInstance, err := prom.QueryHTTP(ctx, promURL, + `avg by (instance) (avg_over_time(DCGM_FI_DEV_GPU_UTIL[7d]))`, "instance") + if err == nil { + instanceSeriesCount = len(gpuByInstance) + memByInstance, _ := prom.QueryHTTP(ctx, promURL, + `avg by (instance) (avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL[7d]))`, "instance") + + for instanceLabel, val := range gpuByInstance { + ip := strings.SplitN(instanceLabel, ":", 2)[0] + idx, ok := ipToIdx[ip] + if !ok || instances[idx].AvgGPUUtilization != nil { + continue + } + v := val + instances[idx].AvgGPUUtilization = &v + if memVal, ok := memByInstance[instanceLabel]; ok { + instances[idx].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + } + + if enriched > 0 { + fmt.Fprintf(os.Stderr, " Prometheus: matched %d of %d EC2 instances\n", enriched, len(refs)) + } else { + fmt.Fprintf(os.Stderr, " Prometheus: matched 0 of %d EC2 instances (server returned %d Hostname series, %d instance series)\n", + len(refs), len(gpuByHostname), instanceSeriesCount) + } + return enriched +} + +// extractIPFromDNS extracts the IP address from an EC2 private 
DNS name. +// e.g., "ip-10-22-249-234.ec2.internal" → "10.22.249.234" +func extractIPFromDNS(dnsName string) string { + hostname := strings.SplitN(dnsName, ".", 2)[0] + if !strings.HasPrefix(hostname, "ip-") { + return "" + } + return strings.ReplaceAll(hostname[3:], "-", ".") +} diff --git a/internal/providers/aws/prometheus_test.go b/internal/providers/aws/prometheus_test.go new file mode 100644 index 0000000..0096349 --- /dev/null +++ b/internal/providers/aws/prometheus_test.go @@ -0,0 +1,174 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestExtractIPFromDNS(t *testing.T) { + tests := []struct { + dns string + wantIP string + }{ + {"ip-10-22-249-234.ec2.internal", "10.22.249.234"}, + {"ip-172-31-0-5.us-west-2.compute.internal", "172.31.0.5"}, + {"custom-hostname.ec2.internal", ""}, + {"", ""}, + } + for _, tt := range tests { + got := extractIPFromDNS(tt.dns) + if got != tt.wantIP { + t.Errorf("extractIPFromDNS(%q) = %q, want %q", tt.dns, got, tt.wantIP) + } + } +} + +func TestEnrichEC2PrometheusGPUMetrics_MatchesByHostname(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query().Get("query") + if strings.Contains(query, "GPU_UTIL") { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"Hostname": "ip-10-0-1-100"}, "value": [1700000000, "72.5"]} + ]} + }`) + } else { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"Hostname": "ip-10-0-1-100"}, "value": [1700000000, "45.0"]} + ]} + }`) + } + })) + defer srv.Close() + + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: 
"ip-10-0-1-100.ec2.internal", + }, + { + InstanceID: "i-def456", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-200.ec2.internal", + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + + if enriched != 1 { + t.Fatalf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 72.5 { + t.Errorf("expected GPU util 72.5, got %v", instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil || *instances[0].AvgGPUMemUtilization != 45.0 { + t.Errorf("expected GPU mem util 45.0, got %v", instances[0].AvgGPUMemUtilization) + } + if instances[1].AvgGPUUtilization != nil { + t.Error("expected no GPU util for unmatched instance") + } +} + +func TestEnrichEC2PrometheusGPUMetrics_FallsBackToInstanceLabel(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query().Get("query") + if strings.Contains(query, "Hostname") { + // No results by hostname + fmt.Fprintf(w, `{"status": "success", "data": {"resultType": "vector", "result": []}}`) + } else if strings.Contains(query, "instance") && strings.Contains(query, "GPU_UTIL") { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"instance": "10.0.1.100:9400"}, "value": [1700000000, "88.0"]} + ]} + }`) + } else { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"instance": "10.0.1.100:9400"}, "value": [1700000000, "60.0"]} + ]} + }`) + } + })) + defer srv.Close() + + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + + if enriched != 1 { + t.Fatalf("expected 1 enriched, got %d", enriched) + } + 
if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 88.0 { + t.Errorf("expected GPU util 88.0, got %v", instances[0].AvgGPUUtilization) + } +} + +func TestEnrichEC2PrometheusGPUMetrics_SkipsAlreadyEnriched(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("should not query Prometheus when all instances already have metrics") + fmt.Fprintf(w, `{"status": "success", "data": {"resultType": "vector", "result": []}}`) + })) + defer srv.Close() + + gpuUtil := 50.0 + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + AvgGPUUtilization: &gpuUtil, + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichEC2PrometheusGPUMetrics_SkipsNonEC2(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("should not query Prometheus for non-EC2 instances") + fmt.Fprintf(w, `{"status": "success", "data": {"resultType": "vector", "result": []}}`) + })) + defer srv.Close() + + instances := []models.GPUInstance{ + { + InstanceID: "node-1", + Source: models.SourceK8sNode, + State: "ready", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index d8d5921..908c4f5 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -16,6 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/costexplorer" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/eks" + 
"github.com/aws/aws-sdk-go-v2/service/organizations" "github.com/aws/aws-sdk-go-v2/service/sagemaker" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -34,6 +35,16 @@ type ScanOptions struct { SkipCosts bool ExcludeTags map[string]string MinUptimeDays int + + // Prometheus + PromURL string + + // Multi-target options + Targets []string + Role string + ExternalID string + OrgScan bool + SkipSelf bool } // DefaultScanOptions returns sensible defaults. @@ -43,7 +54,7 @@ func DefaultScanOptions() ScanOptions { } } -// Scan performs a full GPU audit of the AWS account. +// Scan performs a full GPU audit across one or more AWS accounts. func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { start := time.Now() @@ -58,13 +69,23 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { return nil, fmt.Errorf("loading AWS config: %w", err) } - // Get account ID + // Resolve targets (accounts to scan) stsClient := sts.NewFromConfig(cfg) - identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return nil, fmt.Errorf("getting caller identity: %w", err) + + var orgClient OrgClient + if opts.OrgScan { + orgClient = organizations.NewFromConfig(cfg) + } + + callerAccount, targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + + // Print target errors to stderr and check for fatal failure + for _, te := range targetErrors { + fmt.Fprintf(os.Stderr, " warning: failed to resolve target %s: %v\n", te.AccountID, te.Err) + } + if len(targets) == 0 { + return nil, fmt.Errorf("no scannable targets resolved (errors: %d)", len(targetErrors)) } - accountID := aws.ToString(identity.Account) // Determine regions to scan regions := opts.Regions @@ -75,46 +96,50 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { } } - fmt.Fprintf(os.Stderr," Scanning %d regions for GPU instances...\n", len(regions)) + if len(targets) > 1 { + fmt.Fprintf(os.Stderr, " 
Scanning %d accounts across %d regions for GPU instances...\n", len(targets), len(regions)) + } else { + fmt.Fprintf(os.Stderr, " Scanning %d regions for GPU instances...\n", len(regions)) + } - // Scan all regions concurrently - type regionResult struct { - region string + // Scan all targets in parallel + type targetResult struct { instances []models.GPUInstance - err error + regions []string } - results := make(chan regionResult, len(regions)) + targetResults := make(chan targetResult, len(targets)) var wg sync.WaitGroup - for _, region := range regions { + for _, t := range targets { wg.Add(1) - go func(r string) { + go func(target Target) { defer wg.Done() - instances, err := scanRegion(ctx, cfg, accountID, r, opts) - results <- regionResult{region: r, instances: instances, err: err} - }(region) + instances, scannedRegions := scanTarget(ctx, target, regions, opts) + targetResults <- targetResult{instances: instances, regions: scannedRegions} + }(t) } go func() { wg.Wait() - close(results) + close(targetResults) }() var allInstances []models.GPUInstance - var scannedRegions []string + regionSet := make(map[string]bool) - for res := range results { - if res.err != nil { - fmt.Fprintf(os.Stderr," warning: error scanning %s: %v\n", res.region, res.err) - continue - } - if len(res.instances) > 0 { - allInstances = append(allInstances, res.instances...) - scannedRegions = append(scannedRegions, res.region) + for res := range targetResults { + allInstances = append(allInstances, res.instances...) 
+ for _, r := range res.regions { + regionSet[r] = true } } + var scannedRegions []string + for r := range regionSet { + scannedRegions = append(scannedRegions, r) + } + // Filter by excluded tags if len(opts.ExcludeTags) > 0 { filtered := allInstances[:0] @@ -132,14 +157,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { } } - // Enrich with Cost Explorer data (account-level, not per-region) - if !opts.SkipCosts && len(allInstances) > 0 { - ceClient := costexplorer.NewFromConfig(cfg) - if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { - fmt.Fprintf(os.Stderr," warning: could not enrich cost data: %v\n", err) - } - } - // Run analysis analysis.AnalyzeAll(allInstances) @@ -160,14 +177,87 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { // Build summary summary := BuildSummary(allInstances) - return &models.ScanResult{ + result := &models.ScanResult{ Timestamp: start, - AccountID: accountID, + AccountID: callerAccount, Regions: scannedRegions, ScanDuration: time.Since(start).Round(time.Millisecond).String(), Instances: allInstances, Summary: summary, - }, nil + } + + // Populate multi-target metadata when multiple targets are involved + isMultiTarget := len(targets) > 1 || len(targetErrors) > 0 + if isMultiTarget { + for _, t := range targets { + result.Targets = append(result.Targets, t.AccountID) + } + result.TargetSummaries = BuildTargetSummaries(allInstances) + for _, te := range targetErrors { + result.TargetErrors = append(result.TargetErrors, models.TargetErrorInfo{ + Target: te.AccountID, + Error: te.Err.Error(), + }) + } + } + + return result, nil +} + +// scanTarget scans all regions for a single target account, including +// Cost Explorer enrichment (which is account-scoped). 
+func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string) { + type regionResult struct { + region string + instances []models.GPUInstance + err error + } + + results := make(chan regionResult, len(regions)) + var wg sync.WaitGroup + + for _, region := range regions { + wg.Add(1) + go func(r string) { + defer wg.Done() + instances, err := scanRegion(ctx, target.Config, target.AccountID, r, opts) + results <- regionResult{region: r, instances: instances, err: err} + }(region) + } + + go func() { + wg.Wait() + close(results) + }() + + var allInstances []models.GPUInstance + var scannedRegions []string + + for res := range results { + if res.err != nil { + fmt.Fprintf(os.Stderr, " warning: error scanning %s in account %s: %v\n", res.region, target.AccountID, res.err) + continue + } + if len(res.instances) > 0 { + allInstances = append(allInstances, res.instances...) + scannedRegions = append(scannedRegions, res.region) + } + } + + // Enrich with Cost Explorer data (account-scoped) + if !opts.SkipCosts && len(allInstances) > 0 { + ceClient := costexplorer.NewFromConfig(target.Config) + if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { + fmt.Fprintf(os.Stderr, " warning: could not enrich cost data for account %s: %v\n", target.AccountID, err) + } + } + + // Enrich EC2 GPU metrics from Prometheus (for instances missing GPU utilization) + if !opts.SkipMetrics && opts.PromURL != "" && len(allInstances) > 0 { + EnrichEC2PrometheusGPUMetrics(ctx, opts.PromURL, allInstances) + } + + return allInstances, scannedRegions } func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, opts ScanOptions) ([]models.GPUInstance, error) { @@ -188,6 +278,7 @@ func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, o if err := EnrichEC2Metrics(ctx, cwClient, ec2Instances, opts.MetricWindow); err != nil { fmt.Fprintf(os.Stderr, " warning: could not enrich EC2 metrics 
in %s: %v\n", region, err) } + EnrichSpotPrices(ctx, ec2Client, ec2Instances) } allInstances = append(allInstances, ec2Instances...) } @@ -231,46 +322,6 @@ func getGPURegions(ctx context.Context, cfg aws.Config) ([]string, error) { }, nil } -// BuildSummary computes aggregate statistics for a set of GPU instances. -func BuildSummary(instances []models.GPUInstance) models.ScanSummary { - s := models.ScanSummary{ - TotalInstances: len(instances), - } - - for _, inst := range instances { - s.TotalMonthlyCost += inst.MonthlyCost - s.TotalEstimatedWaste += inst.EstimatedSavings - - maxSeverity := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity == models.SeverityCritical { - maxSeverity = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { - maxSeverity = models.SeverityWarning - } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { - maxSeverity = models.SeverityInfo - } - } - - switch maxSeverity { - case models.SeverityCritical: - s.CriticalCount++ - case models.SeverityWarning: - s.WarningCount++ - case models.SeverityInfo: - s.InfoCount++ - default: - s.HealthyCount++ - } - } - - if s.TotalMonthlyCost > 0 { - s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 - } - - return s -} - func matchesExcludeTags(instanceTags map[string]string, excludes map[string]string) bool { for k, v := range excludes { if instanceTags[k] == v { diff --git a/internal/providers/aws/spot.go b/internal/providers/aws/spot.go new file mode 100644 index 0000000..d8ddcd6 --- /dev/null +++ b/internal/providers/aws/spot.go @@ -0,0 +1,89 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "os" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + "github.com/gpuaudit/cli/internal/models" +) + +// SpotPriceClient is the subset of the EC2 API needed for spot price lookups. +type SpotPriceClient interface { + DescribeSpotPriceHistory(ctx context.Context, params *ec2.DescribeSpotPriceHistoryInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error) +} + +// EnrichSpotPrices fetches current spot prices for EC2 GPU instances and +// populates SpotHourlyCost on each instance where spot is available. +func EnrichSpotPrices(ctx context.Context, client SpotPriceClient, instances []models.GPUInstance) { + // Collect unique EC2 instance types. + typeSet := make(map[string]bool) + for _, inst := range instances { + if inst.Source == models.SourceEC2 { + typeSet[inst.InstanceType] = true + } + } + if len(typeSet) == 0 { + return + } + + instanceTypes := make([]ec2types.InstanceType, 0, len(typeSet)) + for t := range typeSet { + instanceTypes = append(instanceTypes, ec2types.InstanceType(t)) + } + + input := &ec2.DescribeSpotPriceHistoryInput{ + InstanceTypes: instanceTypes, + ProductDescriptions: []string{"Linux/UNIX"}, + StartTime: aws.Time(time.Now().Add(-1 * time.Hour)), + } + + out, err := client.DescribeSpotPriceHistory(ctx, input) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: could not fetch spot prices: %v\n", err) + return + } + + // Take the most recent price per instance type. The API returns entries + // per (type, AZ) sorted newest-first. We collapse across AZs — spot prices + // within a region are typically within a few percent. A 1-hour window with + // a handful of GPU types fits well within a single API page (1000 entries). 
+ latestPrice := make(map[string]float64) + for _, sp := range out.SpotPriceHistory { + itype := string(sp.InstanceType) + if _, seen := latestPrice[itype]; seen { + continue + } + price, err := strconv.ParseFloat(aws.ToString(sp.SpotPrice), 64) + if err != nil { + continue + } + latestPrice[itype] = price + } + + // Populate SpotHourlyCost on matching instances and correct cost for + // instances already running as spot. + for i := range instances { + if instances[i].Source != models.SourceEC2 { + continue + } + price, ok := latestPrice[instances[i].InstanceType] + if !ok { + continue + } + instances[i].SpotHourlyCost = &price + if instances[i].PricingModel == "spot" { + instances[i].HourlyCost = price + instances[i].MonthlyCost = price * 730 + } + } +} diff --git a/internal/providers/aws/spot_test.go b/internal/providers/aws/spot_test.go new file mode 100644 index 0000000..55c62f9 --- /dev/null +++ b/internal/providers/aws/spot_test.go @@ -0,0 +1,153 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + "github.com/gpuaudit/cli/internal/models" +) + +type mockSpotPriceClient struct { + prices []ec2types.SpotPrice + err error +} + +func (m *mockSpotPriceClient) DescribeSpotPriceHistory(ctx context.Context, params *ec2.DescribeSpotPriceHistoryInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error) { + if m.err != nil { + return nil, m.err + } + return &ec2.DescribeSpotPriceHistoryOutput{ + SpotPriceHistory: m.prices, + }, nil +} + +func TestEnrichSpotPrices_PopulatesSpotCost(t *testing.T) { + client := &mockSpotPriceClient{ + prices: []ec2types.SpotPrice{ + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.556"), + Timestamp: aws.Time(time.Now()), + }, + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.500"), + Timestamp: aws.Time(time.Now().Add(-1 * time.Hour)), + }, + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-1", InstanceType: "g5.xlarge", Source: models.SourceEC2}, + {InstanceID: "i-2", InstanceType: "g5.2xlarge", Source: models.SourceEC2}, + } + + EnrichSpotPrices(context.Background(), client, instances) + + if instances[0].SpotHourlyCost == nil { + t.Fatal("expected spot price for g5.xlarge") + } + if *instances[0].SpotHourlyCost != 0.556 { + t.Errorf("expected 0.556, got %f", *instances[0].SpotHourlyCost) + } + if instances[1].SpotHourlyCost != nil { + t.Error("expected nil spot price for g5.2xlarge (not in API response)") + } +} + +func TestEnrichSpotPrices_SkipsNonEC2(t *testing.T) { + client := &mockSpotPriceClient{ + prices: []ec2types.SpotPrice{ + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.556"), + Timestamp: aws.Time(time.Now()), + }, + }, + } + instances := 
[]models.GPUInstance{ + {InstanceID: "ep-1", InstanceType: "ml.g5.xlarge", Source: models.SourceSageMakerEndpoint}, + } + + EnrichSpotPrices(context.Background(), client, instances) + + if instances[0].SpotHourlyCost != nil { + t.Error("expected nil spot price for SageMaker instance") + } +} + +func TestEnrichSpotPrices_HandlesAPIError(t *testing.T) { + client := &mockSpotPriceClient{ + err: fmt.Errorf("access denied"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-1", InstanceType: "g5.xlarge", Source: models.SourceEC2}, + } + + EnrichSpotPrices(context.Background(), client, instances) + + if instances[0].SpotHourlyCost != nil { + t.Error("expected nil spot price after API error") + } +} + +func TestEnrichSpotPrices_EmptyInstances(t *testing.T) { + client := &mockSpotPriceClient{} + EnrichSpotPrices(context.Background(), client, nil) + EnrichSpotPrices(context.Background(), client, []models.GPUInstance{}) +} + +func TestEnrichSpotPrices_CorrectsCostForSpotInstances(t *testing.T) { + client := &mockSpotPriceClient{ + prices: []ec2types.SpotPrice{ + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.556"), + Timestamp: aws.Time(time.Now()), + }, + }, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-spot", + InstanceType: "g5.xlarge", + Source: models.SourceEC2, + PricingModel: "spot", + HourlyCost: 1.006, // on-demand price (wrong for spot) + MonthlyCost: 1.006 * 730, + }, + { + InstanceID: "i-ondemand", + InstanceType: "g5.xlarge", + Source: models.SourceEC2, + PricingModel: "on-demand", + HourlyCost: 1.006, + MonthlyCost: 1.006 * 730, + }, + } + + EnrichSpotPrices(context.Background(), client, instances) + + // Spot instance should have corrected cost + if instances[0].HourlyCost != 0.556 { + t.Errorf("spot instance hourly cost: expected 0.556, got %f", instances[0].HourlyCost) + } + expectedMonthlyCost := 0.556 * 730 + const epsilon = 0.0001 + if instances[0].MonthlyCost < expectedMonthlyCost-epsilon || 
instances[0].MonthlyCost > expectedMonthlyCost+epsilon { + t.Errorf("spot instance monthly cost: expected %f, got %f", expectedMonthlyCost, instances[0].MonthlyCost) + } + + // On-demand instance should keep original cost + if instances[1].HourlyCost != 1.006 { + t.Errorf("on-demand instance hourly cost should be unchanged, got %f", instances[1].HourlyCost) + } +} diff --git a/internal/providers/aws/summary.go b/internal/providers/aws/summary.go new file mode 100644 index 0000000..5c6f715 --- /dev/null +++ b/internal/providers/aws/summary.go @@ -0,0 +1,80 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "sort" + + "github.com/gpuaudit/cli/internal/models" +) + +// BuildSummary computes aggregate statistics for a set of GPU instances. +func BuildSummary(instances []models.GPUInstance) models.ScanSummary { + s := models.ScanSummary{ + TotalInstances: len(instances), + } + + for _, inst := range instances { + s.TotalMonthlyCost += inst.MonthlyCost + s.TotalEstimatedWaste += inst.EstimatedSavings + + switch models.MaxSeverity(inst.WasteSignals) { + case models.SeverityCritical: + s.CriticalCount++ + case models.SeverityWarning: + s.WarningCount++ + case models.SeverityInfo: + s.InfoCount++ + default: + s.HealthyCount++ + } + } + + if s.TotalMonthlyCost > 0 { + s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 + } + + return s +} + +// BuildTargetSummaries computes per-target breakdowns from a flat instance list. 
+func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary { + if len(instances) == 0 { + return nil + } + + byTarget := make(map[string][]models.GPUInstance) + for _, inst := range instances { + byTarget[inst.AccountID] = append(byTarget[inst.AccountID], inst) + } + + summaries := make([]models.TargetSummary, 0, len(byTarget)) + for target, insts := range byTarget { + ts := models.TargetSummary{ + Target: target, + TotalInstances: len(insts), + } + for _, inst := range insts { + ts.TotalMonthlyCost += inst.MonthlyCost + ts.TotalEstimatedWaste += inst.EstimatedSavings + + switch models.MaxSeverity(inst.WasteSignals) { + case models.SeverityCritical: + ts.CriticalCount++ + case models.SeverityWarning: + ts.WarningCount++ + } + } + if ts.TotalMonthlyCost > 0 { + ts.WastePercent = (ts.TotalEstimatedWaste / ts.TotalMonthlyCost) * 100 + } + summaries = append(summaries, ts) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].TotalMonthlyCost > summaries[j].TotalMonthlyCost + }) + + return summaries +} diff --git a/internal/providers/aws/summary_test.go b/internal/providers/aws/summary_test.go new file mode 100644 index 0000000..24702ec --- /dev/null +++ b/internal/providers/aws/summary_test.go @@ -0,0 +1,92 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestBuildTargetSummaries_MultipleAccounts(t *testing.T) { + instances := []models.GPUInstance{ + { + AccountID: "111111111111", + MonthlyCost: 1000, + EstimatedSavings: 500, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityCritical}}, + }, + { + AccountID: "111111111111", + MonthlyCost: 2000, + EstimatedSavings: 0, + }, + { + AccountID: "222222222222", + MonthlyCost: 3000, + EstimatedSavings: 1000, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityWarning}}, + }, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 2 { + t.Fatalf("expected 2 target summaries, got %d", len(summaries)) + } + + var s1, s2 *models.TargetSummary + for i := range summaries { + switch summaries[i].Target { + case "111111111111": + s1 = &summaries[i] + case "222222222222": + s2 = &summaries[i] + } + } + + if s1 == nil || s2 == nil { + t.Fatal("missing target summaries") + } + + if s1.TotalInstances != 2 { + t.Errorf("acct1: expected 2 instances, got %d", s1.TotalInstances) + } + if s1.TotalMonthlyCost != 3000 { + t.Errorf("acct1: expected $3000 cost, got $%.0f", s1.TotalMonthlyCost) + } + if s1.TotalEstimatedWaste != 500 { + t.Errorf("acct1: expected $500 waste, got $%.0f", s1.TotalEstimatedWaste) + } + if s1.CriticalCount != 1 { + t.Errorf("acct1: expected 1 critical, got %d", s1.CriticalCount) + } + + if s2.TotalInstances != 1 { + t.Errorf("acct2: expected 1 instance, got %d", s2.TotalInstances) + } + if s2.WarningCount != 1 { + t.Errorf("acct2: expected 1 warning, got %d", s2.WarningCount) + } +} + +func TestBuildTargetSummaries_SingleAccount(t *testing.T) { + instances := []models.GPUInstance{ + {AccountID: "111111111111", MonthlyCost: 1000}, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } +} + +func 
TestBuildTargetSummaries_Empty(t *testing.T) { + summaries := BuildTargetSummaries(nil) + + if len(summaries) != 0 { + t.Fatalf("expected 0 summaries for nil input, got %d", len(summaries)) + } +} diff --git a/internal/providers/k8s/discover.go b/internal/providers/k8s/discover.go index 6df9ef0..e3316c0 100644 --- a/internal/providers/k8s/discover.go +++ b/internal/providers/k8s/discover.go @@ -24,6 +24,7 @@ const gpuResourceName corev1.ResourceName = "nvidia.com/gpu" type K8sClient interface { ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) ListPods(ctx context.Context, namespace string, opts metav1.ListOptions) (*corev1.PodList, error) + ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) } // DiscoverGPUNodes finds Kubernetes nodes with GPU capacity and reports their allocation. @@ -163,6 +164,7 @@ func nodeToGPUInstance(node corev1.Node, gpuPods []corev1.Pod, clusterName strin Name: fmt.Sprintf("%s/%s", clusterName, hostname), Tags: tags, ClusterName: clusterName, + K8sNodeName: node.Name, GPUAllocated: gpuAllocated, InstanceType: instanceType, GPUModel: gpuModel, diff --git a/internal/providers/k8s/discover_test.go b/internal/providers/k8s/discover_test.go index 9d0cff1..016c9df 100644 --- a/internal/providers/k8s/discover_test.go +++ b/internal/providers/k8s/discover_test.go @@ -17,8 +17,10 @@ import ( ) type mockK8sClient struct { - nodes *corev1.NodeList - pods *corev1.PodList + nodes *corev1.NodeList + pods *corev1.PodList + proxyData map[string][]byte + proxyErr error } func (m *mockK8sClient) ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) { @@ -29,6 +31,17 @@ func (m *mockK8sClient) ListPods(ctx context.Context, namespace string, opts met return m.pods, nil } +func (m *mockK8sClient) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + if m.proxyErr != nil { + return nil, m.proxyErr + } + key := 
fmt.Sprintf("%s/%s:%s%s", namespace, podName, port, path) + if data, ok := m.proxyData[key]; ok { + return data, nil + } + return nil, fmt.Errorf("no mock data for %s", key) +} + func gpuNode(name, instanceType string, gpuCount int, ready bool, created time.Time) corev1.Node { readyStatus := corev1.ConditionFalse if ready { diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go new file mode 100644 index 0000000..02bb259 --- /dev/null +++ b/internal/providers/k8s/metrics.go @@ -0,0 +1,264 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "bytes" + "context" + "fmt" + "net/url" + "os" + "strings" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" + "github.com/prometheus/common/model" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" + prom "github.com/gpuaudit/cli/internal/prometheus" +) + +// EnrichDCGMMetrics discovers dcgm-exporter pods and scrapes GPU metrics for K8s nodes +// that don't already have AvgGPUUtilization populated. Returns the number of nodes enriched. 
+func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance) int { + needsMetrics := make(map[string]int) + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { + continue + } + key := inst.K8sNodeName + if key == "" { + key = inst.InstanceID + } + needsMetrics[key] = i + } + if len(needsMetrics) == 0 { + return 0 + } + + dcgmPods, err := findDCGMPods(ctx, client) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: could not list DCGM exporter pods: %v\n", err) + return 0 + } + if len(dcgmPods) == 0 { + fmt.Fprintf(os.Stderr, " DCGM exporter not detected, skipping\n") + return 0 + } + + fmt.Fprintf(os.Stderr, " Probing DCGM exporter on GPU nodes...\n") + + enriched := 0 + scrapeErrors := 0 + for _, pod := range dcgmPods { + idx, ok := needsMetrics[pod.Spec.NodeName] + if !ok { + continue + } + + data, err := client.ProxyGet(ctx, pod.Namespace, pod.Name, "9400", "/metrics") + if err != nil { + scrapeErrors++ + if scrapeErrors == 1 { + fmt.Fprintf(os.Stderr, " warning: DCGM scrape failed: %v\n", err) + } + if scrapeErrors >= 3 { + fmt.Fprintf(os.Stderr, " warning: DCGM scrape failing consistently, skipping remaining nodes\n") + break + } + continue + } + + gpuUtil, memUtil := parseDCGMMetrics(data) + if gpuUtil != nil { + instances[idx].AvgGPUUtilization = gpuUtil + instances[idx].AvgGPUMemUtilization = memUtil + enriched++ + scrapeErrors = 0 + } + } + + if enriched > 0 { + fmt.Fprintf(os.Stderr, " DCGM: got GPU metrics for %d of %d remaining nodes\n", enriched, len(needsMetrics)) + } + return enriched +} + +func findDCGMPods(ctx context.Context, client K8sClient) ([]corev1.Pod, error) { + podList, err := client.ListPods(ctx, "", metav1.ListOptions{ + LabelSelector: "app.kubernetes.io/name=dcgm-exporter", + }) + if err != nil { + return nil, err + } + if len(podList.Items) > 0 { + return runningPods(podList.Items), nil + } + + podList, err = 
client.ListPods(ctx, "", metav1.ListOptions{ + LabelSelector: "app=nvidia-dcgm-exporter", + }) + if err != nil { + return nil, err + } + return runningPods(podList.Items), nil +} + +func runningPods(pods []corev1.Pod) []corev1.Pod { + var result []corev1.Pod + for _, p := range pods { + if p.Status.Phase == corev1.PodRunning { + result = append(result, p) + } + } + return result +} + +func parseDCGMMetrics(data []byte) (gpuUtil, memUtil *float64) { + parser := expfmt.NewTextParser(model.LegacyValidation) + families, err := parser.TextToMetricFamilies(bytes.NewReader(data)) + if err != nil { + return nil, nil + } + + gpuUtil = avgMetricValue(families["DCGM_FI_DEV_GPU_UTIL"]) + memUtil = avgMetricValue(families["DCGM_FI_DEV_MEM_COPY_UTIL"]) + return gpuUtil, memUtil +} + +func avgMetricValue(family *dto.MetricFamily) *float64 { + if family == nil || len(family.Metric) == 0 { + return nil + } + sum := 0.0 + count := 0 + for _, m := range family.Metric { + if m.Gauge != nil && m.Gauge.Value != nil { + sum += *m.Gauge.Value + count++ + } + } + if count == 0 { + return nil + } + avg := sum / float64(count) + return &avg +} + +// PrometheusOptions configures how to reach a Prometheus-compatible API. +type PrometheusOptions struct { + URL string + Endpoint string +} + +// EnrichPrometheusMetrics queries a Prometheus endpoint for GPU utilization metrics +// for K8s nodes that don't already have AvgGPUUtilization populated. 
+func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance, opts PrometheusOptions) int { + if opts.URL == "" && opts.Endpoint == "" { + return 0 + } + + type nodeRef struct { + index int + name string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { + continue + } + name := inst.K8sNodeName + if name == "" { + name = inst.InstanceID + } + nodes = append(nodes, nodeRef{index: i, name: name}) + } + if len(nodes) == 0 { + return 0 + } + + source := opts.URL + if source == "" { + source = opts.Endpoint + } + fmt.Fprintf(os.Stderr, " Querying Prometheus at %s...\n", source) + + nodeNames := make([]string, len(nodes)) + for i, n := range nodes { + nodeNames[i] = n.name + } + nodeRegex := strings.Join(nodeNames, "|") + + gpuResults := queryPrometheus(ctx, client, opts, + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"%s"}[7d])`, nodeRegex), "node") + memResults := queryPrometheus(ctx, client, opts, + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL{node=~"%s"}[7d])`, nodeRegex), "node") + + enriched := 0 + for _, node := range nodes { + if val, ok := gpuResults[node.name]; ok { + instances[node.index].AvgGPUUtilization = &val + if memVal, ok := memResults[node.name]; ok { + instances[node.index].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + + fmt.Fprintf(os.Stderr, " Prometheus: got GPU metrics for %d of %d remaining nodes\n", enriched, len(nodes)) + return enriched +} + +func queryPrometheus(ctx context.Context, client K8sClient, opts PrometheusOptions, query, labelName string) map[string]float64 { + if opts.URL != "" { + results, err := prom.QueryHTTP(ctx, opts.URL, query, labelName) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err) + return nil + } + return results + } + + data, err := queryPrometheusProxy(ctx, client, opts.Endpoint, query) + if err != nil { + 
fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err) + return nil + } + results, err := prom.ParseResponse(data, labelName) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus response parse failed: %v\n", err) + return nil + } + return results +} + +func queryPrometheusProxy(ctx context.Context, client K8sClient, endpoint, query string) ([]byte, error) { + ns, svc, port, err := parsePrometheusEndpoint(endpoint) + if err != nil { + return nil, err + } + path := fmt.Sprintf("/api/v1/query?query=%s", url.QueryEscape(query)) + return client.ProxyGet(ctx, ns, svc, port, path) +} + +func parsePrometheusEndpoint(endpoint string) (namespace, service, port string, err error) { + slashIdx := strings.Index(endpoint, "/") + if slashIdx < 1 { + return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint) + } + namespace = endpoint[:slashIdx] + rest := endpoint[slashIdx+1:] + colonIdx := strings.LastIndex(rest, ":") + if colonIdx < 1 { + return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint) + } + service = rest[:colonIdx] + port = rest[colonIdx+1:] + return namespace, service, port, nil +} + diff --git a/internal/providers/k8s/metrics_test.go b/internal/providers/k8s/metrics_test.go new file mode 100644 index 0000000..4d7e851 --- /dev/null +++ b/internal/providers/k8s/metrics_test.go @@ -0,0 +1,314 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +func dcgmPod(name, namespace, nodeName string) corev1.Pod { + return corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/name": "dcgm-exporter", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + } +} + +const sampleDCGMMetrics = `# HELP DCGM_FI_DEV_GPU_UTIL GPU utilization. +# TYPE DCGM_FI_DEV_GPU_UTIL gauge +DCGM_FI_DEV_GPU_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 42.0 +DCGM_FI_DEV_GPU_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 38.0 +# HELP DCGM_FI_DEV_MEM_COPY_UTIL GPU memory utilization. 
+# TYPE DCGM_FI_DEV_MEM_COPY_UTIL gauge +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 55.0 +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 60.0 +` + +func TestEnrichDCGMMetrics_PopulatesUtilization(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "ip-10-22-1-100.ec2.internal"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-abc123", K8sNodeName: "ip-10-22-1-100.ec2.internal", Source: models.SourceK8sNode, Name: "cluster/ip-10-22-1-100"}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 40.0 { + t.Errorf("expected avg GPU util 40.0 (average of 42 and 38), got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 57.5 { + t.Errorf("expected avg GPU mem util 57.5 (average of 55 and 60), got %f", *instances[0].AvgGPUMemUtilization) + } + if enriched != 1 { + t.Errorf("expected 1 enriched node, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-abc123", K8sNodeName: "node1", Source: 
models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Error("should not overwrite existing utilization") + } + if enriched != 0 { + t.Errorf("expected 0 enriched nodes, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_NoDCGMPods(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{Items: []corev1.Pod{}}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil when no DCGM pods") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_HandlesScrapeError(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "node1"), + }, + }, + proxyErr: fmt.Errorf("connection refused"), + } + instances := []models.GPUInstance{ + {InstanceID: "node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil after scrape error") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestParseDCGMMetrics(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte(sampleDCGMMetrics)) + + if gpuUtil == nil { + t.Fatal("expected gpu util") + } + if *gpuUtil != 40.0 { + t.Errorf("expected 40.0, got %f", *gpuUtil) + } + if memUtil == nil { + t.Fatal("expected mem util") + } + if *memUtil != 57.5 { + t.Errorf("expected 57.5, got %f", *memUtil) + } +} + +func TestParseDCGMMetrics_EmptyInput(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte("")) + if gpuUtil != nil || memUtil != nil { + 
t.Error("expected nil for empty input") + } +} + +func TestEnrichPrometheusMetrics_PopulatesFromDirectURL(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "65.5"]}, + {"metric": {"node": "i-node2"}, "value": [1700000000, "30.0"]} + ] + } + }` + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/query" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + query := r.URL.Query().Get("query") + if !strings.Contains(query, "DCGM_FI_DEV_GPU_UTIL") && !strings.Contains(query, "DCGM_FI_DEV_MEM_COPY_UTIL") { + t.Errorf("unexpected query: %s", query) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(promResponse)) + })) + defer server.Close() + + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + {InstanceID: "i-node2", Source: models.SourceK8sNode, Name: "cluster/i-node2"}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 2 { + t.Errorf("expected 2 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 65.5 { + t.Errorf("expected node1 GPU util 65.5, got %v", instances[0].AvgGPUUtilization) + } + if instances[1].AvgGPUUtilization == nil || *instances[1].AvgGPUUtilization != 30.0 { + t.Errorf("expected node2 GPU util 30.0, got %v", instances[1].AvgGPUUtilization) + } +} + +func TestEnrichPrometheusMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 80.0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"success","data":{"resultType":"vector","result":[]}}`)) + })) + defer server.Close() + + instances := 
[]models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_NoOptions(t *testing.T) { + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, PrometheusOptions{}) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_InClusterEndpoint(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "50.0"]} + ] + } + }` + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + opts := PrometheusOptions{Endpoint: "monitoring/prometheus:9090"} + + // Use a custom client that returns promResponse for any ProxyGet to monitoring/prometheus + customClient := &promMockClient{response: []byte(promResponse)} + + enriched := EnrichPrometheusMetrics(context.Background(), customClient, instances, opts) + + if enriched != 1 { + t.Errorf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected 50.0, got %v", instances[0].AvgGPUUtilization) + } +} + +// promMockClient is a specialized mock that always returns a fixed response for ProxyGet. 
+type promMockClient struct { + mockK8sClient + response []byte +} + +func (m *promMockClient) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + return m.response, nil +} + +func TestParsePrometheusEndpoint(t *testing.T) { + tests := []struct { + input string + namespace string + service string + port string + wantErr bool + }{ + {"monitoring/prometheus:9090", "monitoring", "prometheus", "9090", false}, + {"kube-system/thanos-query:10902", "kube-system", "thanos-query", "10902", false}, + {"invalid", "", "", "", true}, + {"ns/svc", "", "", "", true}, + } + for _, tt := range tests { + ns, svc, port, err := parsePrometheusEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parsePrometheusEndpoint(%q): err=%v, wantErr=%v", tt.input, err, tt.wantErr) + continue + } + if ns != tt.namespace || svc != tt.service || port != tt.port { + t.Errorf("parsePrometheusEndpoint(%q) = (%q,%q,%q), want (%q,%q,%q)", + tt.input, ns, svc, port, tt.namespace, tt.service, tt.port) + } + } +} diff --git a/internal/providers/k8s/scanner.go b/internal/providers/k8s/scanner.go index 67634f3..c35ef88 100644 --- a/internal/providers/k8s/scanner.go +++ b/internal/providers/k8s/scanner.go @@ -19,8 +19,10 @@ import ( // ScanOptions controls Kubernetes GPU scanning. type ScanOptions struct { - Kubeconfig string - Context string + Kubeconfig string + Context string + PromURL string + PromEndpoint string } // Scan discovers GPU nodes in Kubernetes clusters accessible via kubeconfig. @@ -47,6 +49,11 @@ func Scan(ctx context.Context, opts ScanOptions) ([]models.GPUInstance, error) { return instances, nil } +// BuildClientPublic builds a K8s client and returns the cluster name. 
+func BuildClientPublic(kubeconfigPath, contextName string) (K8sClient, string, error) { + return buildClient(kubeconfigPath, contextName) +} + func buildClient(kubeconfigPath, contextName string) (K8sClient, string, error) { loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() if kubeconfigPath != "" { @@ -100,6 +107,10 @@ func (w *k8sClientWrapper) ListPods(ctx context.Context, namespace string, opts return w.clientset.CoreV1().Pods(namespace).List(ctx, opts) } +func (w *k8sClientWrapper) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + return w.clientset.CoreV1().Pods(namespace).ProxyGet("http", podName, port, path, nil).DoRaw(ctx) +} + func defaultKubeconfig() string { home, err := os.UserHomeDir() if err != nil {