Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions github/git_trees.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ func (t TreeEntry) String() string {
return Stringify(t)
}

// treeEntryWithFileDelete is used internally to delete a file whose
// Content and SHA fields are empty. It does this by removing the "omitempty"
// tag modifier on the SHA field which causes the GitHub API to receive
// {"sha":null} and thereby delete the file.
type treeEntryWithFileDelete struct {
SHA *string `json:"sha"`
Path *string `json:"path,omitempty"`
Mode *string `json:"mode,omitempty"`
Type *string `json:"type,omitempty"`
Size *int `json:"size,omitempty"`
Content *string `json:"content,omitempty"`
URL *string `json:"url,omitempty"`
}

func (t *TreeEntry) MarshalJSON() ([]byte, error) {
if t.SHA == nil && t.Content == nil {
return json.Marshal(struct {
Expand Down Expand Up @@ -102,8 +116,8 @@ func (s *GitService) GetTree(ctx context.Context, owner string, repo string, sha

// createTree represents the body of a CreateTree request.
type createTree struct {
BaseTree string `json:"base_tree,omitempty"`
Entries []TreeEntry `json:"tree"`
BaseTree string `json:"base_tree,omitempty"`
Entries []interface{} `json:"tree"`
}

// CreateTree creates a new tree in a repository. If both a tree and a nested
Expand All @@ -114,9 +128,24 @@ type createTree struct {
func (s *GitService) CreateTree(ctx context.Context, owner string, repo string, baseTree string, entries []TreeEntry) (*Tree, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/git/trees", owner, repo)

newEntries := make([]interface{}, 0, len(entries))
for _, entry := range entries {
if entry.Content == nil && entry.SHA == nil {
newEntries = append(newEntries, treeEntryWithFileDelete{
Path: entry.Path,
Mode: entry.Mode,
Type: entry.Type,
Size: entry.Size,
URL: entry.URL,
})
continue
}
newEntries = append(newEntries, entry)
}

body := &createTree{
BaseTree: baseTree,
Entries: entries,
Entries: newEntries,
}
req, err := s.client.NewRequest("POST", u, body)
if err != nil {
Expand Down
48 changes: 23 additions & 25 deletions github/git_trees_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
package github

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"testing"
Expand Down Expand Up @@ -68,17 +69,16 @@ func TestGitService_CreateTree(t *testing.T) {
}

mux.HandleFunc("/repos/o/r/git/trees", func(w http.ResponseWriter, r *http.Request) {
v := new(createTree)
json.NewDecoder(r.Body).Decode(v)
got, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}

testMethod(t, r, "POST")

want := &createTree{
BaseTree: "b",
Entries: input,
}
if !reflect.DeepEqual(v, want) {
t.Errorf("Git.CreateTree request body: %+v, want %+v", v, want)
want := []byte(`{"base_tree":"b","tree":[{"sha":"7c258a9869f33c1e1e1f74fbb32f07c86cb5a75b","path":"file.rb","mode":"100644","type":"blob"}]}` + "\n")
if !bytes.Equal(got, want) {
t.Errorf("Git.CreateTree request body: %s, want %s", got, want)
}

fmt.Fprint(w, `{
Expand Down Expand Up @@ -132,17 +132,16 @@ func TestGitService_CreateTree_Content(t *testing.T) {
}

mux.HandleFunc("/repos/o/r/git/trees", func(w http.ResponseWriter, r *http.Request) {
v := new(createTree)
json.NewDecoder(r.Body).Decode(v)
got, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}

testMethod(t, r, "POST")

want := &createTree{
BaseTree: "b",
Entries: input,
}
if !reflect.DeepEqual(v, want) {
t.Errorf("Git.CreateTree request body: %+v, want %+v", v, want)
want := []byte(`{"base_tree":"b","tree":[{"path":"content.md","mode":"100644","content":"file content"}]}` + "\n")
if !bytes.Equal(got, want) {
t.Errorf("Git.CreateTree request body: %s, want %s", got, want)
}

fmt.Fprint(w, `{
Expand Down Expand Up @@ -198,17 +197,16 @@ func TestGitService_CreateTree_Delete(t *testing.T) {
}

mux.HandleFunc("/repos/o/r/git/trees", func(w http.ResponseWriter, r *http.Request) {
v := new(createTree)
json.NewDecoder(r.Body).Decode(v)
got, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}

testMethod(t, r, "POST")

want := &createTree{
BaseTree: "b",
Entries: input,
}
if !reflect.DeepEqual(v, want) {
t.Errorf("Git.CreateTree request body: %+v, want %+v", v, want)
want := []byte(`{"base_tree":"b","tree":[{"sha":null,"path":"content.md","mode":"100644"}]}` + "\n")
if !bytes.Equal(got, want) {
t.Errorf("Git.CreateTree request body: %s, want %s", got, want)
}

fmt.Fprint(w, `{
Expand Down