Skip to content

Commit

Permalink
[DRAFT] add protection from recursion
Browse files Browse the repository at this point in the history
Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>
  • Loading branch information
denis-tingaikin committed Feb 25, 2024
1 parent 3afde47 commit 54e8061
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 40 deletions.
46 changes: 23 additions & 23 deletions analyzer.go
Expand Up @@ -17,6 +17,7 @@
package goheader

import (
"context"
"fmt"
"go/ast"
"os"
Expand Down Expand Up @@ -48,34 +49,29 @@ func (t *Target) ModTime() (time.Time, error) {
}

type Analyzer struct {
values map[string]Value
values context.Context
template string
}

func (a *Analyzer) processPerTargetValues(target *Target) error {
a.values["mod-year"] = a.values["year"]
a.values["mod-year-range"] = a.values["year-range"]
func (a *Analyzer) processPerTargetValues(target *Target) context.Context {
var ctx = a.values
ctx = context.WithValue(ctx, "mod-year", ctx.Value("year"))
ctx = context.WithValue(ctx, "mod-year-range", ctx.Value("year-range"))

if t, err := target.ModTime(); err == nil {
a.values["mod-year"] = &ConstValue{RawValue: fmt.Sprint(t.Year())}
a.values["mod-year-range"] = &RegexpValue{RawValue: `((20\d\d\-{{mod-year}})|({{mod-year}}))`}
ctx = context.WithValue(ctx, "mod-year", &ConstValue{RawValue: fmt.Sprint(t.Year())})
ctx = context.WithValue(ctx, "mod-year-range", &RegexpValue{RawValue: `((20\d\d\-{{mod-year}})|({{mod-year}}))`})
}

for _, v := range a.values {
if err := v.Calculate(a.values); err != nil {
return err
}
}
return nil
return ctx
}

func (a *Analyzer) Analyze(target *Target) (i Issue) {
if a.template == "" {
return NewIssue("Missed template for check")
}

if err := a.processPerTargetValues(target); err != nil {
return &issue{msg: err.Error()}
}
var values = a.processPerTargetValues(target)

file := target.File
var header string
Expand All @@ -94,7 +90,7 @@ func (a *Analyzer) Analyze(target *Target) (i Issue) {
if i == nil {
return
}
fix, ok := a.generateFix(i, file, header)
fix, ok := a.generateFix(values, i, file, header)
if !ok {
return
}
Expand All @@ -111,10 +107,14 @@ func (a *Analyzer) Analyze(target *Target) (i Issue) {
templateCh := t.Peek()
if templateCh == '{' {
name := a.readField(t)
if a.values[name] == nil {
if values.Value(name) == nil {
return NewIssue(fmt.Sprintf("Template has unknown value: %v", name))
}
if i := a.values[name].Read(s); i != nil {
var v = values.Value(name).(Value)
if err := v.Calculate(values); err != nil {
return &issue{location: t.location, msg: err.Error()}
}
if i := v.Read(s); i != nil {
return i
}
continue
Expand Down Expand Up @@ -158,27 +158,27 @@ func (a *Analyzer) readField(reader *Reader) string {
}

func New(options ...Option) *Analyzer {
a := &Analyzer{values: make(map[string]Value)}
a := &Analyzer{values: context.Background()}
for _, o := range options {
o.apply(a)
}
return a
}

func (a *Analyzer) generateFix(i Issue, file *ast.File, header string) (Fix, bool) {
func (a *Analyzer) generateFix(values context.Context, i Issue, file *ast.File, header string) (Fix, bool) {
var expect string
t := NewReader(a.template)
for !t.Done() {
ch := t.Peek()
if ch == '{' {
f := a.values[a.readField(t)]
f := values.Value(a.readField(t))
if f == nil {
return Fix{}, false
}
if f.Calculate(a.values) != nil {
if f.(Value).Calculate(values) != nil {
return Fix{}, false
}
expect += f.Get()
expect += f.(Value).Get()
continue
}

Expand Down
60 changes: 60 additions & 0 deletions analyzer_test.go
Expand Up @@ -48,6 +48,66 @@ func header(header string) *goheader.Target {
Path: os.TempDir(),
}
}

func TestAnalyzer_Analyze6(t *testing.T) {
a := goheader.New(
goheader.WithTemplate("A {{ some-value }} B"),
goheader.WithValues(map[string]goheader.Value{
"SOME-VALUE": &goheader.ConstValue{
RawValue: "{{ some-value }}",
},
}),
)
var issue = a.Analyze(header("A {{ SOME-VALUE }} B"))
require.NotNil(t, issue)
require.Contains(t, issue.Message(), "recursion detected")
issue = a.Analyze(header("A {{ SOME-VALUE }} C"))
require.NotNil(t, issue)
require.Contains(t, issue.Message(), "recursion detected")
}

func TestAnalyzer_Analyze7(t *testing.T) {
a := goheader.New(
goheader.WithTemplate("A {{ some-value1 }} B"),
goheader.WithValues(map[string]goheader.Value{
"SOME-VALUE1": &goheader.ConstValue{
RawValue: "{{ some-value1 }}",
},
"SOME-VALUE2": &goheader.ConstValue{
RawValue: "{{ some-value2 }}",
},
}),
)
var issue = a.Analyze(header("A {{ SOME-VALUE }} B"))
require.NotNil(t, issue)
require.Contains(t, issue.Message(), "recursion detected")
issue = a.Analyze(header("A {{ SOME-VALUE }} C"))
require.NotNil(t, issue)
require.Contains(t, issue.Message(), "recursion detected")
}
func TestAnalyzer_Analyze8(t *testing.T) {
a := goheader.New(
goheader.WithTemplate("A {{ some-value3 }} B"),
goheader.WithValues(map[string]goheader.Value{
"SOME-VALUE1": &goheader.ConstValue{
RawValue: "{{ some-value2 }}",
},
"SOME-VALUE2": &goheader.ConstValue{
RawValue: "{{ some-value3 }}",
},
"SOME-VALUE3": &goheader.ConstValue{
RawValue: "{{ some-value1 }}",
},
}),
)
var issue = a.Analyze(header("A {{ SOME-VALUE }} B"))
require.NotNil(t, issue)
require.Contains(t, issue.Message(), "recursion detected")
issue = a.Analyze(header("A {{ SOME-VALUE }} C"))
require.NotNil(t, issue)
require.Contains(t, issue.Message(), "recursion detected")
}

func TestAnalyzer_YearRangeValue_ShouldWorkWithComplexVariables(t *testing.T) {
var conf goheader.Configuration
var vals, err = conf.GetValues()
Expand Down
10 changes: 6 additions & 4 deletions option.go
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022 Denis Tingaikin
// Copyright (c) 2020-2024 Denis Tingaikin
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -16,7 +16,10 @@

package goheader

import "strings"
import (
"context"
"strings"
)

type Option interface {
apply(*Analyzer)
Expand All @@ -30,9 +33,8 @@ func (f applyAnalyzerOptionFunc) apply(a *Analyzer) {

func WithValues(values map[string]Value) Option {
return applyAnalyzerOptionFunc(func(a *Analyzer) {
a.values = make(map[string]Value)
for k, v := range values {
a.values[strings.ToLower(k)] = v
a.values = context.WithValue(a.values, strings.ToLower(k), v)
}
})
}
Expand Down
38 changes: 25 additions & 13 deletions value.go
Expand Up @@ -17,14 +17,17 @@
package goheader

import (
"context"
"errors"
"fmt"
"regexp"
"strings"
)

const maxRecursionLevel = 15

type Calculable interface {
Calculate(map[string]Value) error
Calculate(context.Context) error
Get() string
Raw() string
}
Expand All @@ -34,23 +37,32 @@ type Value interface {
Read(*Reader) Issue
}

func calculateValue(calculable Calculable, values map[string]Value) (string, error) {
sb := strings.Builder{}
r := calculable.Raw()
var endIndex int
var startIndex int
func calculateValue(ctx context.Context, calculable Calculable) (string, error) {
var (
sb = strings.Builder{}
r = calculable.Raw()
startIndex, endIndex int
)
var level int
if v := ctx.Value("level"); v != nil {
level = v.(int)
}
if level > maxRecursionLevel {
return "", errors.New("recursion detected")
}
ctx = context.WithValue(ctx, "level", level+1)
for startIndex = strings.Index(r, "{{"); startIndex >= 0; startIndex = strings.Index(r, "{{") {
_, _ = sb.WriteString(r[:startIndex])
endIndex = strings.Index(r, "}}")
if endIndex < 0 {
return "", errors.New("missed value ending")
}
subVal := strings.ToLower(strings.TrimSpace(r[startIndex+2 : endIndex]))
if val := values[subVal]; val != nil {
if err := val.Calculate(values); err != nil {
if val := ctx.Value(subVal); val != nil {
if err := val.(Value).Calculate(ctx); err != nil {
return "", err
}
sb.WriteString(val.Get())
sb.WriteString(val.(Value).Get())
} else {
return "", fmt.Errorf("unknown value name %v", subVal)
}
Expand All @@ -65,8 +77,8 @@ type ConstValue struct {
RawValue, Value string
}

func (c *ConstValue) Calculate(values map[string]Value) error {
v, err := calculateValue(c, values)
func (c *ConstValue) Calculate(ctx context.Context) error {
v, err := calculateValue(ctx, c)
if err != nil {
return err
}
Expand Down Expand Up @@ -109,8 +121,8 @@ type RegexpValue struct {
RawValue, Value string
}

func (r *RegexpValue) Calculate(values map[string]Value) error {
v, err := calculateValue(r, values)
func (r *RegexpValue) Calculate(ctx context.Context) error {
v, err := calculateValue(ctx, r)
if err != nil {
return err
}
Expand Down

0 comments on commit 54e8061

Please sign in to comment.