Skip to content
Merged

Dev #44

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
405 changes: 405 additions & 0 deletions database/cleanup.go

Large diffs are not rendered by default.

125 changes: 75 additions & 50 deletions database/transaction.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package database

import (
"context"
"database/sql"
"fmt"
"log"
"time"
)

// TxFn represents a function that uses a transaction
Expand Down Expand Up @@ -42,58 +44,81 @@ func (db *DB) WithTransaction(fn TxFn) error {
return nil
}

// QueryRow executes a query that returns a single row and scans the result into the provided destination
func (db *DB) QueryRowSafe(query string, dest interface{}, args ...interface{}) error {
row := db.QueryRow(query, args...)
if err := row.Scan(dest); err != nil {
if err == sql.ErrNoRows {
return ErrNotFound
}
return fmt.Errorf("scan failed: %w", err)
}
return nil
}

// ExecSafe executes a statement and returns the result summary
func (db *DB) ExecSafe(query string, args ...interface{}) (sql.Result, error) {
result, err := db.Exec(query, args...)
if err != nil {
return nil, fmt.Errorf("exec failed: %w", err)
}
return result, nil
}

// CustomError types for database operations
var (
ErrNotFound = fmt.Errorf("record not found")
ErrDuplicate = fmt.Errorf("duplicate record")
ErrConstraint = fmt.Errorf("constraint violation")
)

// ExecTx executes a statement within a transaction and returns the result
func ExecTx(tx *sql.Tx, query string, args ...interface{}) (sql.Result, error) {
result, err := tx.Exec(query, args...)
if err != nil {
return nil, fmt.Errorf("exec in transaction failed: %w", err)
}
return result, nil
}

// GetRowsAffected is a helper to get rows affected from a result
func GetRowsAffected(result sql.Result) (int64, error) {
affected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
// WithTimeoutTransaction wraps a function with a transaction that has a timeout
func (db *DB) WithTimeoutTransaction(ctx context.Context, timeout time.Duration, fn TxFn) error {
// Create a context with timeout
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// Create a done channel to signal completion
done := make(chan error, 1)

// Run the transaction in a goroutine
go func() {
done <- db.WithTransaction(fn)
}()

// Wait for either context timeout or transaction completion
select {
case <-ctx.Done():
// Context timed out
return fmt.Errorf("transaction timed out after %v: %w", timeout, ctx.Err())
case err := <-done:
// Transaction completed
return err
}
return affected, nil
}

// GetLastInsertID is a helper to get last insert ID from a result
func GetLastInsertID(result sql.Result) (int64, error) {
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("failed to get last insert ID: %w", err)
}
return id, nil
// BatchTransaction executes multiple operations in a single transaction
// All operations must succeed or the transaction is rolled back
func (db *DB) BatchTransaction(operations []TxFn) error {
return db.WithTransaction(func(tx *sql.Tx) error {
for i, op := range operations {
if err := op(tx); err != nil {
return fmt.Errorf("operation %d failed: %w", i, err)
}
}
return nil
})
}

// UpdateInTransaction updates a record in a transaction
func (db *DB) UpdateInTransaction(table string, id string, updates map[string]interface{}) error {
return db.WithTransaction(func(tx *sql.Tx) error {
// Build the update statement
query := fmt.Sprintf("UPDATE %s SET ", table)
var params []interface{}

i := 0
for field, value := range updates {
if i > 0 {
query += ", "
}
query += field + " = ?"
params = append(params, value)
i++
}

// Add the WHERE clause and updated_at
query += ", updated_at = ? WHERE id = ?"
params = append(params, time.Now(), id)

// Execute the update
result, err := tx.Exec(query, params...)
if err != nil {
return fmt.Errorf("update failed: %w", err)
}

// Check if any rows were affected
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}

if rowsAffected == 0 {
return fmt.Errorf("no rows affected, record with ID %s not found", id)
}

return nil
})
}
215 changes: 113 additions & 102 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,108 +82,119 @@ func DiscoverTraefikAPI() (string, error) {
}

func main() {
log.Println("Starting Middleware Manager...")

var debug bool
flag.BoolVar(&debug, "debug", false, "Enable debug mode")
flag.Parse()

cfg := loadConfiguration(debug)

if os.Getenv("TRAEFIK_API_URL") == "" {
if discoveredURL, err := DiscoverTraefikAPI(); err == nil && discoveredURL != "" {
log.Printf("Auto-discovered Traefik API URL: %s", discoveredURL)
cfg.TraefikAPIURL = discoveredURL
}
}

db, err := database.InitDB(cfg.DBPath)
if err != nil {
log.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()

configDir := cfg.ConfigDir
if err := config.EnsureConfigDirectory(configDir); err != nil {
log.Printf("Warning: Failed to create config directory: %v", err)
}

if err := config.SaveTemplateFile(configDir); err != nil {
log.Printf("Warning: Failed to save default middleware templates: %v", err)
}

if err := config.LoadDefaultTemplates(db); err != nil {
log.Printf("Warning: Failed to load default middleware templates: %v", err)
}

if err := config.SaveTemplateServicesFile(configDir); err != nil {
log.Printf("Warning: Failed to save default service templates: %v", err)
}

if err := config.LoadDefaultServiceTemplates(db); err != nil {
log.Printf("Warning: Failed to load default service templates: %v", err)
}

configManager, err := services.NewConfigManager(filepath.Join(configDir, "config.json"))
if err != nil {
log.Fatalf("Failed to initialize config manager: %v", err)
}

configManager.EnsureDefaultDataSources(cfg.PangolinAPIURL, cfg.TraefikAPIURL)

stopChan := make(chan struct{})

resourceWatcher, err := services.NewResourceWatcher(db, configManager)
if err != nil {
log.Fatalf("Failed to create resource watcher: %v", err)
}
go resourceWatcher.Start(cfg.CheckInterval)

configGenerator := services.NewConfigGenerator(db, cfg.TraefikConfDir, configManager)
go configGenerator.Start(cfg.GenerateInterval)

serverConfig := api.ServerConfig{
Port: cfg.Port,
UIPath: cfg.UIPath,
Debug: cfg.Debug,
AllowCORS: cfg.AllowCORS,
CORSOrigin: cfg.CORSOrigin,
}

server := api.NewServer(db.DB, serverConfig, configManager, cfg.TraefikStaticConfigPath, cfg.PluginsJSONURL)
go func() {
if err := server.Start(); err != nil {
log.Printf("Server error: %v", err)
close(stopChan)
}
}()

signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)

serviceWatcher, err := services.NewServiceWatcher(db, configManager)
if err != nil {
log.Printf("Warning: Failed to create service watcher: %v", err)
serviceWatcher = nil
} else {
go serviceWatcher.Start(cfg.ServiceInterval)
}

select {
case <-signalChan:
log.Println("Received shutdown signal")
case <-stopChan:
log.Println("Received stop signal from server")
}

log.Println("Shutting down...")
resourceWatcher.Stop()
if serviceWatcher != nil {
serviceWatcher.Stop()
}
configGenerator.Stop()
server.Stop()
log.Println("Middleware Manager stopped")
log.Println("Starting Middleware Manager...")

var debug bool
flag.BoolVar(&debug, "debug", false, "Enable debug mode")
flag.Parse()

cfg := loadConfiguration(debug)

if os.Getenv("TRAEFIK_API_URL") == "" {
if discoveredURL, err := DiscoverTraefikAPI(); err == nil && discoveredURL != "" {
log.Printf("Auto-discovered Traefik API URL: %s", discoveredURL)
cfg.TraefikAPIURL = discoveredURL
}
}

db, err := database.InitDB(cfg.DBPath)
if err != nil {
log.Fatalf("Failed to initialize database: %v", err)
}
defer db.Close()

configDir := cfg.ConfigDir
if err := config.EnsureConfigDirectory(configDir); err != nil {
log.Printf("Warning: Failed to create config directory: %v", err)
}

if err := config.SaveTemplateFile(configDir); err != nil {
log.Printf("Warning: Failed to save default middleware templates: %v", err)
}

if err := config.LoadDefaultTemplates(db); err != nil {
log.Printf("Warning: Failed to load default middleware templates: %v", err)
}

if err := config.SaveTemplateServicesFile(configDir); err != nil {
log.Printf("Warning: Failed to save default service templates: %v", err)
}

if err := config.LoadDefaultServiceTemplates(db); err != nil {
log.Printf("Warning: Failed to load default service templates: %v", err)
}

// Run comprehensive database cleanup on startup
log.Println("Performing full database cleanup...")
cleanupOpts := database.DefaultCleanupOptions()
cleanupOpts.LogLevel = 2 // More verbose logging during startup

if err := db.PerformFullCleanup(cleanupOpts); err != nil {
log.Printf("Warning: Database cleanup encountered issues: %v", err)
} else {
log.Println("Database cleanup completed successfully")
}

configManager, err := services.NewConfigManager(filepath.Join(configDir, "config.json"))
if err != nil {
log.Fatalf("Failed to initialize config manager: %v", err)
}

configManager.EnsureDefaultDataSources(cfg.PangolinAPIURL, cfg.TraefikAPIURL)

stopChan := make(chan struct{})

resourceWatcher, err := services.NewResourceWatcher(db, configManager)
if err != nil {
log.Fatalf("Failed to create resource watcher: %v", err)
}
go resourceWatcher.Start(cfg.CheckInterval)

configGenerator := services.NewConfigGenerator(db, cfg.TraefikConfDir, configManager)
go configGenerator.Start(cfg.GenerateInterval)

serverConfig := api.ServerConfig{
Port: cfg.Port,
UIPath: cfg.UIPath,
Debug: cfg.Debug,
AllowCORS: cfg.AllowCORS,
CORSOrigin: cfg.CORSOrigin,
}

server := api.NewServer(db.DB, serverConfig, configManager, cfg.TraefikStaticConfigPath, cfg.PluginsJSONURL)
go func() {
if err := server.Start(); err != nil {
log.Printf("Server error: %v", err)
close(stopChan)
}
}()

signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)

serviceWatcher, err := services.NewServiceWatcher(db, configManager)
if err != nil {
log.Printf("Warning: Failed to create service watcher: %v", err)
serviceWatcher = nil
} else {
go serviceWatcher.Start(cfg.ServiceInterval)
}

select {
case <-signalChan:
log.Println("Received shutdown signal")
case <-stopChan:
log.Println("Received stop signal from server")
}

log.Println("Shutting down...")
resourceWatcher.Stop()
if serviceWatcher != nil {
serviceWatcher.Stop()
}
configGenerator.Stop()
server.Stop()
log.Println("Middleware Manager stopped")
}

func loadConfiguration(debug bool) Configuration {
Expand Down
Loading