Skip to content

Commit

Permalink
Added new mode (-p)
Browse files Browse the repository at this point in the history
Package mode:  When invoked in package mode, counterfeiter
will generate an interface and shim implementation from a
package in your GOPATH.  Counterfeiter finds the public
methods in the package <source-path> and adds those method
signatures to the generated interface <interface-name>.

Signed-off-by: Luke Woydziak <luke.woydziak@emc.com>
  • Loading branch information
Julian Hjortshoj authored and tjarratt committed Oct 14, 2016
1 parent d245d78 commit c2f4a41
Show file tree
Hide file tree
Showing 15 changed files with 1,258 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -21,6 +21,7 @@ _testmain.go

*.exe
*.test
*.iml

arguments/argumentsfakes/
fixtures/aliased_package/aliased_packagefakes/
Expand Down
10 changes: 9 additions & 1 deletion arguments/flags.go
@@ -1,6 +1,8 @@
package arguments

import "flag"
import (
"flag"
)

var (
fakeNameFlag = flag.String(
Expand All @@ -14,4 +16,10 @@ var (
"",
"The file or directory to which the generated fake will be written",
)

packageFlag = flag.Bool(
"p",
false,
"whether or not to generate a package shim",
)
)
61 changes: 58 additions & 3 deletions arguments/parser.go
@@ -1,8 +1,10 @@
package arguments

import (
"path"
"path/filepath"
"regexp"
"runtime"
"strings"
"unicode"

Expand Down Expand Up @@ -30,6 +32,14 @@ func NewArgumentParser(
}

func (argParser *argumentParser) ParseArguments(args ...string) ParsedArguments {
if *packageFlag {
return argParser.parsePackageArgs(args...)
} else {
return argParser.parseInterfaceArgs(args...)
}
}

func (argParser *argumentParser) parseInterfaceArgs(args ...string) ParsedArguments {
var interfaceName string
var outputPathFlagValue string
var rootDestinationDir string
Expand Down Expand Up @@ -59,9 +69,10 @@ func (argParser *argumentParser) ParseArguments(args ...string) ParsedArguments
packageName := restrictToValidPackageName(filepath.Base(filepath.Dir(outputPath)))

return ParsedArguments{
SourcePackageDir: sourcePackageDir,
ImportPath: importPath,
OutputPath: outputPath,
GenerateInterfaceAndShimFromPackageDirectory: false,
SourcePackageDir: sourcePackageDir,
OutputPath: outputPath,
ImportPath: importPath,

InterfaceName: interfaceName,
DestinationPackageName: packageName,
Expand All @@ -71,6 +82,30 @@ func (argParser *argumentParser) ParseArguments(args ...string) ParsedArguments
}
}

func (argParser *argumentParser) parsePackageArgs(args ...string) ParsedArguments {
dir := argParser.getPackageDir(args[0])

packageName := path.Base(dir) + "shim"

var outputPath string
if *outputPathFlag != "" {
// TODO: sensible checking of dirs and symlinks
outputPath = *outputPathFlag
} else {
outputPath = path.Join(argParser.currentWorkingDir(), packageName)
}

return ParsedArguments{
GenerateInterfaceAndShimFromPackageDirectory: true,
SourcePackageDir: dir,
OutputPath: outputPath,

DestinationPackageName: packageName,

PrintToStdOut: any(args, "-"),
}
}

type argumentParser struct {
ui terminal.UI
failHandler FailHandler
Expand All @@ -80,6 +115,8 @@ type argumentParser struct {
}

type ParsedArguments struct {
GenerateInterfaceAndShimFromPackageDirectory bool

SourcePackageDir string // abs path to the dir containing the interface to fake
ImportPath string // import path to the package containing the interface to fake
OutputPath string // path to write the fake file to
Expand Down Expand Up @@ -129,6 +166,24 @@ func packageNameForPath(pathToPackage string) string {
return packageName + "fakes"
}

func (argParser *argumentParser) getPackageDir(arg string) string {
if filepath.IsAbs(arg) {
return arg
}

pathToCheck := path.Join(runtime.GOROOT(), "src", arg)

stat, err := argParser.fileStatReader(pathToCheck)
if err != nil {
argParser.failHandler("No such file or directory '%s'", arg)
}
if !stat.IsDir() {
argParser.failHandler("No such file or directory '%s'", arg) // TODO: for now?
}

return pathToCheck
}

func (argParser *argumentParser) getSourceDir(arg string) string {
if !filepath.IsAbs(arg) {
arg = filepath.Join(argParser.currentWorkingDir(), arg)
Expand Down
28 changes: 26 additions & 2 deletions arguments/parser_test.go
@@ -1,14 +1,15 @@
package arguments_test
package arguments

import (
"errors"
"os"
"path"
"path/filepath"
"runtime"
"time"

"github.com/maxbrunsfeld/counterfeiter/terminal/terminalfakes"

. "github.com/maxbrunsfeld/counterfeiter/arguments"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
Expand Down Expand Up @@ -39,6 +40,7 @@ var _ = Describe("parsing arguments", func() {
})

BeforeEach(func() {
*packageFlag = false
failWasCalled = false
fail = func(_ string, _ ...interface{}) { failWasCalled = true }
cwd = func() string {
Expand All @@ -55,6 +57,28 @@ var _ = Describe("parsing arguments", func() {
}
})

Describe("when the -p flag is provided", func() {
BeforeEach(func() {
args = []string{"os"}
*packageFlag = true
})

It("doesn't parse extraneous arguments", func() {
Expect(parsedArgs.InterfaceName).To(Equal(""))
Expect(parsedArgs.FakeImplName).To(Equal(""))
})

Context("given a stdlib package", func() {
It("sets arguments as expected", func() {
Expect(parsedArgs.SourcePackageDir).To(Equal(path.Join(runtime.GOROOT(), "src/os")))
Expect(parsedArgs.OutputPath).To(Equal(path.Join(cwd(), "osshim")))
Expect(parsedArgs.DestinationPackageName).To(Equal("osshim"))
})
})

Context("given a relative path to a path to a package", func() {})
})

Describe("when a single argument is provided", func() {
BeforeEach(func() {
args = []string{"someonesinterfaces.AnInterface"}
Expand Down
43 changes: 43 additions & 0 deletions fixtures/packagegen/apackage/apackage.go
@@ -0,0 +1,43 @@
package ostest

import (
"fmt"
"os"
"time"
)

func FindProcess(pid int) (*os.Process, error) {
return os.FindProcess(pid)
}

func Hostname() (name string, err error) {
return os.Hostname()
}

func Expand(s string, mapping func(string) string) string {
return os.Expand(s, mapping)
}

func Clearenv() {
os.Clearenv()
}

func Environ() []string {
return os.Environ()
}

func Chtimes(name string, atime time.Time, mtime time.Time) error {
return os.Chtimes(name, atime, mtime)
}

func MkdirAll(path string, perm os.FileMode) error {
return os.MkdirAll(path, perm)
}

func Exit(code int) {
os.Exit(code)
}

func Fictional(lol ...string) {
fmt.Printf("%#v", lol)
}
19 changes: 19 additions & 0 deletions fixtures/packagegen/package_gen.go
@@ -0,0 +1,19 @@
// This file was generated by counterfeiter
package osshim

import (
"os"
"time"
)

type Os interface {
FindProcess(pid int) (*os.Process, error)
Hostname() (name string, err error)
Expand(s string, mapping func(string) string) string
Clearenv()
Environ() []string
Chtimes(name string, atime time.Time, mtime time.Time) error
MkdirAll(path string, perm os.FileMode) error
Exit(code int)
Fictional(lol ...string)
}

0 comments on commit c2f4a41

Please sign in to comment.