Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Refactorings to make testing easier
  • Loading branch information
pwittrock committed Jun 22, 2018
1 parent a8688aa commit 3ffb868
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 132 deletions.
2 changes: 2 additions & 0 deletions PROJECT
@@ -0,0 +1,2 @@
domain: k8s.io
repo: sigs.k8s.io/controller-tools
10 changes: 8 additions & 2 deletions cmd/controller-scaffold/cmd/api.go
Expand Up @@ -27,7 +27,6 @@ import (
"sigs.k8s.io/controller-tools/pkg/scaffold"
"sigs.k8s.io/controller-tools/pkg/scaffold/controller"
"sigs.k8s.io/controller-tools/pkg/scaffold/input"
"sigs.k8s.io/controller-tools/pkg/scaffold/project"
"sigs.k8s.io/controller-tools/pkg/scaffold/resource"
)

Expand Down Expand Up @@ -64,7 +63,7 @@ After the scaffold is written, api will run make on the project.
make run
`,
Run: func(cmd *cobra.Command, args []string) {
project.DieIfNoProject()
DieIfNoProject()

fmt.Println("Create Resource under pkg/apis [y/n]?")
re := yesno()
Expand Down Expand Up @@ -127,3 +126,10 @@ func init() {
rootCmd.AddCommand(APICmd)
r = resource.ForFlags(APICmd.Flags())
}

// DieIfNoProject checks to make sure the command is run from a directory containing a project file.
func DieIfNoProject() {
if _, err := os.Stat("PROJECT"); os.IsNotExist(err) {
log.Fatalf("Command must be run from a diretory containing %s", "PROJECT")
}
}
27 changes: 24 additions & 3 deletions cmd/controller-scaffold/cmd/project.go
Expand Up @@ -19,10 +19,10 @@ import (
"log"
"os"
"os/exec"

"strings"

"github.com/spf13/cobra"
flag "github.com/spf13/pflag"
"sigs.k8s.io/controller-tools/pkg/scaffold"
"sigs.k8s.io/controller-tools/pkg/scaffold/input"
"sigs.k8s.io/controller-tools/pkg/scaffold/manager"
Expand Down Expand Up @@ -102,9 +102,30 @@ controller-scaffold project --domain k8s.io --license apache2 --owner "The Kuber
func init() {
rootCmd.AddCommand(ProjectCmd)

prj = project.ForFlags(ProjectCmd.Flags())
bp = project.BoilerplateForFlags(ProjectCmd.Flags())
prj = ProjectForFlags(ProjectCmd.Flags())
bp = BoilerplateForFlags(ProjectCmd.Flags())
gopkg = &project.GopkgToml{}
mrg = &manager.Cmd{}
dkr = &manager.Dockerfile{}
}

// ProjectForFlags registers flags for Project fields and returns the Project
func ProjectForFlags(f *flag.FlagSet) *project.Project {
p := &project.Project{}
f.StringVar(&p.Domain, "domain", "k8s.io", "domain for groups")
f.StringVar(&p.Version, "project-version", "2", "project version")
f.StringVar(&p.Repo, "repo", "", "name of the github repo. "+
"defaults to the go package of the current working directory.")
return p
}

// BoilerplateForFlags registers flags for Boilerplate fields and returns the Boilerplate
func BoilerplateForFlags(f *flag.FlagSet) *project.Boilerplate {
b := &project.Boilerplate{}
f.StringVar(&b.Path, "path", "", "domain for groups")
f.StringVar(&b.License, "license", "apache2",
"license to use to boilerplate. Maybe one of apache2,none")
f.StringVar(&b.Owner, "owner", "",
"Owner to add to the copyright")
return b
}
32 changes: 27 additions & 5 deletions pkg/scaffold/input/input.go
Expand Up @@ -65,7 +65,9 @@ type Domain interface {

// SetDomain sets the domain
func (i *Input) SetDomain(d string) {
i.Domain = d
if i.Domain == "" {
i.Domain = d
}
}

// Repo allows a repo to be set on an object
Expand All @@ -76,7 +78,9 @@ type Repo interface {

// SetRepo sets the repo
func (i *Input) SetRepo(r string) {
i.Repo = r
if i.Repo == "" {
i.Repo = r
}
}

// Boilerplate allows boilerplate text to be set on an object
Expand All @@ -87,7 +91,9 @@ type Boilerplate interface {

// SetBoilerplate sets the boilerplate text
func (i *Input) SetBoilerplate(b string) {
i.Boilerplate = b
if i.Boilerplate == "" {
i.Boilerplate = b
}
}

// BoilerplatePath allows boilerplate file path to be set on an object
Expand All @@ -98,7 +104,9 @@ type BoilerplatePath interface {

// SetBoilerplatePath sets the boilerplate file path
func (i *Input) SetBoilerplatePath(bp string) {
i.BoilerplatePath = bp
if i.BoilerplatePath == "" {
i.BoilerplatePath = bp
}
}

// Version allows the project version to be set on an object
Expand All @@ -109,7 +117,9 @@ type Version interface {

// SetVersion sets the project version
func (i *Input) SetVersion(v string) {
i.Version = v
if i.Version == "" {
i.Version = v
}
}

// File is a scaffoldable file
Expand All @@ -132,3 +142,15 @@ type Options struct {
// Path is the path to the project
ProjectPath string
}

// ProjectFile is deserialized into a PROJECT file
type ProjectFile struct {
// Version is the project version - defaults to "2"
Version string `yaml:"version,omitempty"`

// Domain is the domain associated with the project and used for API groups
Domain string `yaml:"domain,omitempty"`

// Repo is the go package name of the project root
Repo string `yaml:"repo,omitempty"`
}
27 changes: 1 addition & 26 deletions pkg/scaffold/project/boilerplate.go
Expand Up @@ -18,11 +18,9 @@ package project

import (
"fmt"
"io/ioutil"
"path/filepath"
"time"

flag "github.com/spf13/pflag"
"sigs.k8s.io/controller-tools/pkg/scaffold/input"
)

Expand Down Expand Up @@ -84,28 +82,5 @@ limitations under the License.
*/`

var none = `/*
{{ if .Owner }}Copyright {{ .Year }} {{ .Owner }} {{ end }}.
{{ if .Owner }}Copyright {{ .Year }} {{ .Owner }}{{ end }}.
*/`

// BoilerplateForFlags registers flags for Boilerplate fields and returns the Boilerplate
func BoilerplateForFlags(f *flag.FlagSet) *Boilerplate {
b := &Boilerplate{}
f.StringVar(&b.Path, "path", "", "domain for groups")
f.StringVar(&b.License, "license", "apache2",
"license to use to boilerplate. Maybe one of apache2,none")
f.StringVar(&b.Owner, "owner", "",
"Owner to add to the copyright")
return b
}

// GetBoilerplate reads the boilerplate file
func GetBoilerplate(path string) (string, error) {
b, err := ioutil.ReadFile(path)
return string(b), err
}

// BoilerplatePath returns the default path to the boilerplate file
func BoilerplatePath() string {
i, _ := (&Boilerplate{}).GetInput()
return i.Path
}
20 changes: 9 additions & 11 deletions pkg/scaffold/project/gopkg.go
Expand Up @@ -35,7 +35,7 @@ type GopkgToml struct {
// ManagedHeader is the header to write after the user owned pieces and before the managed parts of the Gopkg.toml
ManagedHeader string

// DefaultUserContent is the default content to use for the user owned pieces
// DefaultGopkgUserContent is the default content to use for the user owned pieces
DefaultUserContent string

// UserContent is the content to use for the user owned pieces
Expand Down Expand Up @@ -65,12 +65,12 @@ func (g *GopkgToml) GetInput() (input.Input, error) {
g.Path = "Gopkg.toml"
}
if g.ManagedHeader == "" {
g.ManagedHeader = defaultHeader
g.ManagedHeader = DefaultGopkgHeader
}

// Set the user content to be used if the Gopkg.toml doesn't exist
if g.DefaultUserContent == "" {
g.DefaultUserContent = defaultUserContent
g.DefaultUserContent = DefaultGopkgUserContent
}

// Set the user owned content from the last Gopkg.toml file - e.g. everything before the header
Expand All @@ -81,16 +81,12 @@ func (g *GopkgToml) GetInput() (input.Input, error) {
return input.Input{}, err
}

g.Input.IfExistsAction = input.Overwrite
g.TemplateBody = depTemplate
return g.Input, nil
}

func (g *GopkgToml) getUserContent(b []byte) (string, error) {
if len(b) == 0 {
// Use the default user lines
return g.DefaultUserContent, nil
}

// Keep the users lines
scanner := bufio.NewScanner(bytes.NewReader(b))
userLines := []string{}
Expand All @@ -111,11 +107,13 @@ func (g *GopkgToml) getUserContent(b []byte) (string, error) {
return strings.Join(userLines, "\n"), nil
}

const defaultHeader = "# STANZAS BELOW ARE GENERATED AND MAY BE WRITTEN - DO NOT MODIFY BELOW THIS LINE."
// DefaultGopkgHeader is the default header used to separate user managed lines and controller-manager managed lines
const DefaultGopkgHeader = "# STANZAS BELOW ARE GENERATED AND MAY BE WRITTEN - DO NOT MODIFY BELOW THIS LINE."

const defaultUserContent = `required = [
// DefaultGopkgUserContent is the default user managed lines to provide.
const DefaultGopkgUserContent = `required = [
"github.com/emicklei/go-restful",
"github.com/onsi/ginkgo", # for test framework
"github.com/onsi/ginkgo", # for test framework
"github.com/onsi/gomega", # for test matchers
"k8s.io/client-go/plugin/pkg/client/auth/gcp", # for development against gcp
"k8s.io/code-generator/cmd/deepcopy-gen", # for go generate
Expand Down
83 changes: 19 additions & 64 deletions pkg/scaffold/project/project.go
Expand Up @@ -19,13 +19,10 @@ package project
import (
"fmt"
"go/build"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"

flag "github.com/spf13/pflag"
"gopkg.in/yaml.v2"
"sigs.k8s.io/controller-tools/pkg/scaffold/input"
)
Expand All @@ -35,16 +32,9 @@ var _ input.File = &Project{}
// Project scaffolds the PROJECT file with project metadata
type Project struct {
// Path is the output file location - defaults to PROJECT
Path string `yaml:",omitempty"`
Path string

// Version is the project version - defaults to "2"
Version string `yaml:"version,omitempty"`

// Domain is the domain associated with the project and used for API groups
Domain string `yaml:"domain,omitempty"`

// Repo is the go package name of the project root
Repo string `yaml:"repo,omitempty"`
input.ProjectFile
}

// GetInput implements input.File
Expand All @@ -53,87 +43,52 @@ func (c *Project) GetInput() (input.Input, error) {
c.Path = "PROJECT"
}
if c.Repo == "" {
r, err := c.defaultRepo()
r, err := c.repoFromGopathAndWd(os.Getenv("GOPATH"), os.Getwd)
if err != nil {
return input.Input{}, err
}
c.Repo = r
}

out, err := yaml.Marshal(c)
out, err := yaml.Marshal(c.ProjectFile)
if err != nil {
return input.Input{}, err
}

return input.Input{
Path: c.Path,
TemplateBody: string(out),
Repo: c.Repo,
Version: c.Version,
Domain: c.Domain,
}, nil
}

func (Project) defaultRepo() (string, error) {
func (Project) repoFromGopathAndWd(gopath string, getwd func() (string, error)) (string, error) {
// Assume the working dir is the root of the repo
wd, err := os.Getwd()
wd, err := getwd()
if err != nil {
log.Fatal(err)
return "", err
}

// Strip the GOPATH from the working dir to get the go package of the repo
gopath := os.Getenv("GOPATH")
if len(gopath) == 0 {
gopath = build.Default.GOPATH
}
goSrc := filepath.Join(gopath, "src")

// Make sure the GOPATH is set and the working dir is under the GOPATH
if !strings.HasPrefix(filepath.Dir(wd), goSrc) {
return "", fmt.Errorf("kubebuilder must be run from the project root under $GOPATH/src/<package>. "+
"\nCurrent GOPATH=%s. \nCurrent directory=%s", gopath, wd)
}

// Prune the base path from the go package for the repo
repo := strings.Replace(wd, fmt.Sprintf("%s%s", goSrc, string(filepath.Separator)), "", 1)

// Make sure the prune did what it was supposed to
if strings.Contains(repo, goSrc) {
return "", fmt.Errorf("could not parse go package for repo: %s", repo)
}
return repo, err
}

// GetProject reads the project file and deserializes it into a Project
func GetProject(path string) (Project, error) {
in, err := ioutil.ReadFile(path)
if err != nil {
return Project{}, err
}
p := Project{}
err = yaml.Unmarshal(in, &p)
if err != nil {
return Project{}, err
return "", fmt.Errorf("working directory must be a project directory under "+
"$GOPATH/src/<project-package>\n- GOPATH=%s\n- WD=%s", gopath, wd)
}
return p, nil
}

// ForFlags registers flags for Project fields and returns the Project
func ForFlags(f *flag.FlagSet) *Project {
p := &Project{}
f.StringVar(&p.Domain, "domain", "k8s.io", "domain for groups")
f.StringVar(&p.Version, "project-version", "2", "project version")
f.StringVar(&p.Repo, "repo", "", "name of the github repo. "+
"defaults to the go package of the current working directory.")
return p
}

// Path returns the default location for the PROJECT file
func Path() string {
i, _ := (&Project{}).GetInput()
return i.Path
}

// DieIfNoProject checks to make sure the command is run from a directory containing a project file.
func DieIfNoProject() {
if _, err := os.Stat(Path()); os.IsNotExist(err) {
log.Fatalf("Command must be run from a diretory containing %s", Path())
// Figure out the repo name by removing $GOPATH/src from the working directory - e.g.
// '$GOPATH/src/kubernetes-sigs/controller-tools' becomes 'kubernetes-sigs/controller-tools'
repo := ""
for wd != goSrc {
repo = filepath.Join(filepath.Base(wd), repo)
wd = filepath.Dir(wd)
}
return repo, nil
}

0 comments on commit 3ffb868

Please sign in to comment.