Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal/color/color.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const (
TextTypeError
// TextTypeWarning is for warning messages
TextTypeWarning
// TextTypeXml indicates the content is XML
TextTypeXml
)

var typeMap map[TextType]chroma.TokenType = map[TextType]chroma.TokenType{
Expand Down Expand Up @@ -85,6 +87,10 @@ func (c *chromaColorizer) Write(w io.Writer, s string, scheme string, t TextType
if err = quick.Highlight(w, s, "transact-sql", "terminal16m", scheme); err != nil {
_, err = w.Write([]byte(s))
}
case TextTypeXml:
if err = quick.Highlight(w, s, "xml", "terminal16m", scheme); err != nil {
_, err = w.Write([]byte(s))
}
default:
tokens := chroma.Literator(chroma.Token{
Type: typeMap[t], Value: s})
Expand Down
5 changes: 5 additions & 0 deletions internal/color/color_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func TestWrite(t *testing.T) {
args: args{s: "warn", t: TextTypeWarning},
wantW: "\x1b[3mwarn\x1b[0m",
},
{
name: "XML",
args: args{s: "<node>value</node>", t: TextTypeXml},
wantW: "\x1b[1m\x1b[38;2;0;128;0m<node>\x1b[0mvalue\x1b[1m\x1b[38;2;0;128;0m</node>\x1b[0m",
},
}

for _, tt := range tests {
Expand Down
29 changes: 28 additions & 1 deletion pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ func newCommands() Commands {
action: onerrorCommand,
name: "ONERROR",
},
"XML": {
regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`),
action: xmlCommand,
name: "XML",
},
}
}

Expand Down Expand Up @@ -368,10 +373,16 @@ func listCommand(s *Sqlcmd, args []string, line uint) (err error) {
}
output := s.GetOutput()
if cmd == "color" {
sample := "select 'literal' as literal, 100 as number from [sys].[tables]"
clr := color.TextTypeTSql
if s.Format.IsXmlMode() {
sample = `<node att="attValue"/><node>value</node>`
clr = color.TextTypeXml
}
// ignoring errors since it's not critical output
for _, style := range s.colorizer.Styles() {
_, _ = output.Write([]byte(style + ": "))
_ = s.colorizer.Write(output, "select 'literal' as literal, 100 as number from [sys].[tables]", style, color.TextTypeTSql)
_ = s.colorizer.Write(output, sample, style, clr)
_, _ = output.Write([]byte(SqlcmdEol))
}
return
Expand Down Expand Up @@ -507,6 +518,22 @@ func onerrorCommand(s *Sqlcmd, args []string, line uint) error {
return nil
}

func xmlCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) != 1 || args[0] == "" {
return InvalidCommandError("XML", line)
}
params := strings.TrimSpace(args[0])
// "OFF" and "ON" are documented as the allowed values.
// ODBC sqlcmd treats any value other than "ON" the same as "OFF".
// So we will too.
if strings.EqualFold(params, "on") {
s.Format.XmlMode(true)
} else {
s.Format.XmlMode(false)
}
return nil
}

func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
var b *strings.Builder
end := len(arg)
Expand Down
2 changes: 2 additions & 0 deletions pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func TestCommandParsing(t *testing.T) {
{`:!!notepad`, "EXEC", []string{"notepad"}},
{` !! dir c:\`, "EXEC", []string{` dir c:\`}},
{`!!dir c:\`, "EXEC", []string{`dir c:\`}},
{`:XML ON `, "XML", []string{`ON `}},
}

for _, test := range commands {
Expand Down Expand Up @@ -187,6 +188,7 @@ func TestListCommandUsesColorizer(t *testing.T) {
func TestListColorPrintsStyleSamples(t *testing.T) {
vars := InitializeVariables(false)
s := New(nil, "", vars)
s.Format = NewSQLCmdDefaultFormatter(false)
// force colorizer on
s.colorizer = color.New(true)
buf := &memoryBuffer{buf: new(bytes.Buffer)}
Expand Down
50 changes: 37 additions & 13 deletions pkg/sqlcmd/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type Formatter interface {
AddMessage(string)
// AddError is called for each error encountered during batch execution
AddError(err error)
// XmlMode enables or disables XML rendering mode
XmlMode(enable bool)
// IsXmlMode returns whether XML mode is enabled
IsXmlMode() bool
}

// ControlCharacterBehavior specifies the text handling required for control characters in the output
Expand Down Expand Up @@ -77,6 +81,7 @@ type sqlCmdFormatterType struct {
format string
maxColNameLen int
colorizer color.Colorizer
xml bool
}

// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter
Expand Down Expand Up @@ -119,7 +124,7 @@ func (f *sqlCmdFormatterType) writeOut(s string, t color.TextType) {
}
}

// Stores the settings to use for processing the current batch
// BeginBatch stores the settings to use for processing the current batch
// TODO: add a third io.Writer for messages when we add -r support
func (f *sqlCmdFormatterType) BeginBatch(_ string, vars *Variables, out io.Writer, err io.Writer) {
f.out = out
Expand All @@ -138,17 +143,19 @@ func (f *sqlCmdFormatterType) EndBatch() {
func (f *sqlCmdFormatterType) BeginResultSet(cols []*sql.ColumnType) {
f.rowcount = 0
f.columnDetails, f.maxColNameLen = calcColumnDetails(cols, f.vars.MaxFixedColumnWidth(), f.vars.MaxVarColumnWidth())
if f.vars.RowsBetweenHeaders() > -1 && f.format == "horizontal" {
if f.vars.RowsBetweenHeaders() > -1 && f.format == "horizontal" && !f.xml {
f.printColumnHeadings()
}
}

// Writes a blank line to the designated output writer
// EndResultSet writes a blank line to the designated output writer
func (f *sqlCmdFormatterType) EndResultSet() {
f.writeOut(SqlcmdEol, color.TextTypeNormal)
if !f.xml {
f.writeOut(SqlcmdEol, color.TextTypeNormal)
}
}

// Writes the current row to the designated output writer
// AddRow writes the current row to the designated output writer
func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
retval := ""
values, err := f.scanRow(row)
Expand All @@ -157,7 +164,9 @@ func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
return retval
}
retval = values[0]
if f.format == "horizontal" {
if f.xml {
f.printColumnValue(retval, 0)
} else if f.format == "horizontal" {
// values are the full values, look at the displaywidth of each column and truncate accordingly
for i, v := range values {
if i > 0 {
Expand All @@ -176,7 +185,6 @@ func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
}
f.writeOut(SqlcmdEol, color.TextTypeNormal)
return retval

}

func (f *sqlCmdFormatterType) addVerticalRow(values []string) {
Expand All @@ -193,12 +201,14 @@ func (f *sqlCmdFormatterType) addVerticalRow(values []string) {
}
}

// Writes a non-error message to the designated message writer
// AddMessage writes a non-error message to the designated message writer
func (f *sqlCmdFormatterType) AddMessage(msg string) {
f.mustWriteOut(msg+SqlcmdEol, color.TextTypeWarning)
if !f.xml {
f.mustWriteOut(msg+SqlcmdEol, color.TextTypeWarning)
}
}

// Writes an error to the designated err Writer
// AddError writes an error to the designated err Writer
func (f *sqlCmdFormatterType) AddError(err error) {
print := true
b := new(strings.Builder)
Expand All @@ -217,6 +227,16 @@ func (f *sqlCmdFormatterType) AddError(err error) {
}
}

// XmlMode enables or disables XML mode
func (f *sqlCmdFormatterType) XmlMode(enable bool) {
f.xml = enable
}

// IsXmlMode returns whether XML mode is enabled
func (f *sqlCmdFormatterType) IsXmlMode() bool {
return f.xml
}

// Prints column headings based on columnDetail, variables, and command line arguments
func (f *sqlCmdFormatterType) printColumnHeadings() {
names := new(strings.Builder)
Expand Down Expand Up @@ -535,7 +555,7 @@ func (f *sqlCmdFormatterType) printColumnValue(val string, col int) {

s.WriteString(val)
r := []rune(val)
if f.format == "horizontal" {
if !f.xml && f.format == "horizontal" {
if !f.removeTrailingSpaces {
if f.vars.MaxVarColumnWidth() != 0 || !isLargeVariableType(&c.col) {
padding := c.displayWidth - min64(c.displayWidth, int64(len(r)))
Expand All @@ -551,11 +571,15 @@ func (f *sqlCmdFormatterType) printColumnValue(val string, col int) {

r = []rune(s.String())
}
if c.displayWidth > 0 && int64(len(r)) > c.displayWidth {
if !f.xml && (c.displayWidth > 0 && int64(len(r)) > c.displayWidth) {
s.Reset()
s.WriteString(string(r[:c.displayWidth]))
}
f.writeOut(s.String(), color.TextTypeCell)
clr := color.TextTypeCell
if f.xml {
clr = color.TextTypeXml
}
f.writeOut(s.String(), clr)
}

func (f *sqlCmdFormatterType) mustWriteOut(s string, t color.TextType) {
Expand Down
9 changes: 9 additions & 0 deletions pkg/sqlcmd/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,12 @@ func TestFormatterColorizer(t *testing.T) {
assert.NoError(t, err, "runSqlCmd returned error")
assert.Equal(t, "\x1b[38;2;0;128;0mname\x1b[0m"+SqlcmdEol+SqlcmdEol+"\x1b[3m(1 row affected)"+SqlcmdEol+"\x1b[0m", buf.buf.String())
}

func TestFormatterXmlMode(t *testing.T) {
s, buf := setupSqlCmdWithMemoryOutput(t)
defer buf.Close()
s.Format.XmlMode(true)
err := runSqlCmd(t, s, []string{"select name from sys.databases where name='master' for xml auto ", "GO"})
assert.NoError(t, err, "runSqlCmd returned error")
assert.Equal(t, `<sys.databases name="master"/>`+SqlcmdEol, buf.buf.String())
}