diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..032c1bd1 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +* text=auto + diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index e18f537f..a28ef0ed 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -1,20 +1,20 @@ -name: golangci-lint -on: - push: - branches: - - main - pull_request: -jobs: - golangci-pr: - name: lint-pr-changes - runs-on: ubuntu-latest - steps: - - uses: actions/setup-go@v3 - with: - go-version: 1.18 - - uses: actions/checkout@v3 - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - version: latest - only-new-issues: true +name: golangci-lint +on: + push: + branches: + - main + pull_request: +jobs: + golangci-pr: + name: lint-pr-changes + runs-on: ubuntu-latest + steps: + - uses: actions/setup-go@v3 + with: + go-version: 1.18 + - uses: actions/checkout@v3 + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest + only-new-issues: true diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 776b310b..589234cb 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -1,53 +1,53 @@ -variables: - # AZURE_CLIENT_SECRET and SQLPASSWORD must be defined as secret variables in the pipeline. - # AZURE_TENANT_ID and AZURE_CLIENT_ID are not expected to be secret variables, just regular variables - AZURECLIENTSECRET: $(AZURE_CLIENT_SECRET) - PASSWORD: $(SQLPASSWORD) -pool: - vmImage: 'ubuntu-latest' - -steps: - - template: include-install-go-tools.yml - - - task: Docker@2 - displayName: 'Run SQL 2017 docker image' - inputs: - command: run - arguments: '-m 2GB -e ACCEPT_EULA=1 -d --name sql2017 -p:1433:1433 -e SA_PASSWORD=$(PASSWORD) mcr.microsoft.com/mssql/server:2017-latest' - - - template: include-runtests-linux.yml - parameters: - RunName: 'SQL2017' - SQLCMDUSER: sa - SQLPASSWORD: $(PASSWORD) - - - template: include-runtests-linux.yml - parameters: - RunName: 'SQLDB' - # AZURESERVER must be defined as a variable in the pipeline - SQLCMDSERVER: $(AZURESERVER) - AZURECLIENTSECRET: $(AZURECLIENTSECRET) - - - task: Palmmedia.reportgenerator.reportgenerator-build-release-task.reportgenerator@4 - displayName: Merge coverage data - inputs: - reports: '**/*.coverage.xml"' # REQUIRED # The coverage reports that should be parsed (separated by semicolon). Globbing is supported. - targetdir: 'coverage' # REQUIRED # The directory where the generated report should be saved. - reporttypes: 'HtmlInline_AzurePipelines;Cobertura' # The output formats and scope (separated by semicolon) Values: Badges, Clover, Cobertura, CsvSummary, Html, HtmlChart, HtmlInline, HtmlInline_AzurePipelines, HtmlInline_AzurePipelines_Dark, HtmlSummary, JsonSummary, Latex, LatexSummary, lcov, MarkdownSummary, MHtml, PngChart, SonarQube, TeamCitySummary, TextSummary, Xml, XmlSummary - sourcedirs: '$(Build.SourcesDirectory)' # Optional directories which contain the corresponding source code (separated by semicolon). The source directories are used if coverage report contains classes without path information. - verbosity: 'Info' # The verbosity level of the log messages. Values: Verbose, Info, Warning, Error, Off - tag: '$(build.buildnumber)_#$(build.buildid)_$(Build.SourceBranchName)' # Optional tag or build version. - - task: PublishCodeCoverageResults@1 - inputs: - codeCoverageTool: Cobertura - pathToSources: '$(Build.SourcesDirectory)' - summaryFileLocation: $(Build.SourcesDirectory)/coverage/*.xml - reportDirectory: $(Build.SourcesDirectory)/coverage - failIfCoverageEmpty: true - condition: always() - continueOnError: true - env: - disable.coverage.autogenerate: 'true' - - - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 - displayName: ‘Component Detection’ +variables: + # AZURE_CLIENT_SECRET and SQLPASSWORD must be defined as secret variables in the pipeline. + # AZURE_TENANT_ID and AZURE_CLIENT_ID are not expected to be secret variables, just regular variables + AZURECLIENTSECRET: $(AZURE_CLIENT_SECRET) + PASSWORD: $(SQLPASSWORD) +pool: + vmImage: 'ubuntu-latest' + +steps: + - template: include-install-go-tools.yml + + - task: Docker@2 + displayName: 'Run SQL 2017 docker image' + inputs: + command: run + arguments: '-m 2GB -e ACCEPT_EULA=1 -d --name sql2017 -p:1433:1433 -e SA_PASSWORD=$(PASSWORD) mcr.microsoft.com/mssql/server:2017-latest' + + - template: include-runtests-linux.yml + parameters: + RunName: 'SQL2017' + SQLCMDUSER: sa + SQLPASSWORD: $(PASSWORD) + + - template: include-runtests-linux.yml + parameters: + RunName: 'SQLDB' + # AZURESERVER must be defined as a variable in the pipeline + SQLCMDSERVER: $(AZURESERVER) + AZURECLIENTSECRET: $(AZURECLIENTSECRET) + + - task: Palmmedia.reportgenerator.reportgenerator-build-release-task.reportgenerator@4 + displayName: Merge coverage data + inputs: + reports: '**/*.coverage.xml"' # REQUIRED # The coverage reports that should be parsed (separated by semicolon). Globbing is supported. + targetdir: 'coverage' # REQUIRED # The directory where the generated report should be saved. + reporttypes: 'HtmlInline_AzurePipelines;Cobertura' # The output formats and scope (separated by semicolon) Values: Badges, Clover, Cobertura, CsvSummary, Html, HtmlChart, HtmlInline, HtmlInline_AzurePipelines, HtmlInline_AzurePipelines_Dark, HtmlSummary, JsonSummary, Latex, LatexSummary, lcov, MarkdownSummary, MHtml, PngChart, SonarQube, TeamCitySummary, TextSummary, Xml, XmlSummary + sourcedirs: '$(Build.SourcesDirectory)' # Optional directories which contain the corresponding source code (separated by semicolon). The source directories are used if coverage report contains classes without path information. + verbosity: 'Info' # The verbosity level of the log messages. Values: Verbose, Info, Warning, Error, Off + tag: '$(build.buildnumber)_#$(build.buildid)_$(Build.SourceBranchName)' # Optional tag or build version. + - task: PublishCodeCoverageResults@1 + inputs: + codeCoverageTool: Cobertura + pathToSources: '$(Build.SourcesDirectory)' + summaryFileLocation: $(Build.SourcesDirectory)/coverage/*.xml + reportDirectory: $(Build.SourcesDirectory)/coverage + failIfCoverageEmpty: true + condition: always() + continueOnError: true + env: + disable.coverage.autogenerate: 'true' + + - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 + displayName: ‘Component Detection’ diff --git a/.pipelines/include-install-go-tools.yml b/.pipelines/include-install-go-tools.yml index 53aeb1f2..3e33cf47 100644 --- a/.pipelines/include-install-go-tools.yml +++ b/.pipelines/include-install-go-tools.yml @@ -1,36 +1,36 @@ -steps: - - task: GoTool@0 - inputs: - version: '1.18' - - task: Go@0 - displayName: 'Go: get dependencies' - inputs: - command: 'get' - arguments: '-d' - workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' - - - - task: Go@0 - displayName: 'Go: install gotest.tools/gotestsum' - inputs: - command: 'custom' - customCommand: 'install' - arguments: 'gotest.tools/gotestsum@latest' - workingDirectory: '$(System.DefaultWorkingDirectory)' - - - task: Go@0 - displayName: 'Go: install github.com/axw/gocov/gocov' - inputs: - command: 'custom' - customCommand: 'install' - arguments: 'github.com/axw/gocov/gocov@latest' - workingDirectory: '$(System.DefaultWorkingDirectory)' - - - task: Go@0 - displayName: 'Go: install github.com/axw/gocov/gocov' - inputs: - command: 'custom' - customCommand: 'install' - arguments: 'github.com/AlekSi/gocov-xml@latest' - workingDirectory: '$(System.DefaultWorkingDirectory)' +steps: + - task: GoTool@0 + inputs: + version: '1.18' + - task: Go@0 + displayName: 'Go: get dependencies' + inputs: + command: 'get' + arguments: '-d' + workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' + + + - task: Go@0 + displayName: 'Go: install gotest.tools/gotestsum' + inputs: + command: 'custom' + customCommand: 'install' + arguments: 'gotest.tools/gotestsum@latest' + workingDirectory: '$(System.DefaultWorkingDirectory)' + + - task: Go@0 + displayName: 'Go: install github.com/axw/gocov/gocov' + inputs: + command: 'custom' + customCommand: 'install' + arguments: 'github.com/axw/gocov/gocov@latest' + workingDirectory: '$(System.DefaultWorkingDirectory)' + + - task: Go@0 + displayName: 'Go: install github.com/axw/gocov/gocov' + inputs: + command: 'custom' + customCommand: 'install' + arguments: 'github.com/AlekSi/gocov-xml@latest' + workingDirectory: '$(System.DefaultWorkingDirectory)' \ No newline at end of file diff --git a/.pipelines/include-runtests-linux.yml b/.pipelines/include-runtests-linux.yml index 2735a10e..e1aa0d42 100644 --- a/.pipelines/include-runtests-linux.yml +++ b/.pipelines/include-runtests-linux.yml @@ -1,46 +1,46 @@ -parameters: -- name: RunName - type: string -- name: SQLCMDUSER - type: string - default: '' -- name: SQLPASSWORD - type: string - default: '' -- name: AZURECLIENTSECRET - type: string - default: '' -- name: SQLCMDSERVER - type: string - default: . -- name: SQLCMDDBNAME - type: string - default: '' -steps: - - script: | - ~/go/bin/gotestsum --junitfile "${{ parameters.RunName }}.testresults.xml" -- ./... -coverprofile="${{ parameters.RunName }}.coverage.txt" -covermode count - ~/go/bin/gocov convert "${{ parameters.RunName }}.coverage.txt" > "${{ parameters.RunName }}.coverage.json" - ~/go/bin/gocov-xml < "${{ parameters.RunName }}.coverage.json" > ${{ parameters.RunName }}.coverage.xml - mkdir -p coverage - workingDirectory: '$(Build.SourcesDirectory)' - displayName: 'run tests' - env: - SQLPASSWORD: ${{ parameters.SQLPASSWORD }} - SQLCMDUSER: ${{ parameters.SQLCMDUSER }} - SQLCMDPASSWORD: ${{ parameters.SQLPASSWORD }} - AZURE_TENANT_ID: $(AZURE_TENANT_ID) - AZURE_CLIENT_ID: $(AZURE_CLIENT_ID) - AZURE_CLIENT_SECRET: ${{ parameters.AZURECLIENTSECRET }} - SQLCMDSERVER: ${{ parameters.SQLCMDSERVER }} - SQLCMDDBNAME: ${{ parameters.SQLCMDDBNAME }} - continueOnError: true - - - task: PublishTestResults@2 - displayName: "Publish junit-style results" - inputs: - testResultsFiles: '${{ parameters.RunName }}.testresults.xml' - testResultsFormat: JUnit - searchFolder: '$(Build.SourcesDirectory)' - testRunTitle: '${{ parameters.RunName }} - $(Build.SourceBranchName)' - failTaskOnFailedTests: true - condition: always() +parameters: +- name: RunName + type: string +- name: SQLCMDUSER + type: string + default: '' +- name: SQLPASSWORD + type: string + default: '' +- name: AZURECLIENTSECRET + type: string + default: '' +- name: SQLCMDSERVER + type: string + default: . +- name: SQLCMDDBNAME + type: string + default: '' +steps: + - script: | + ~/go/bin/gotestsum --junitfile "${{ parameters.RunName }}.testresults.xml" -- ./... -coverprofile="${{ parameters.RunName }}.coverage.txt" -covermode count + ~/go/bin/gocov convert "${{ parameters.RunName }}.coverage.txt" > "${{ parameters.RunName }}.coverage.json" + ~/go/bin/gocov-xml < "${{ parameters.RunName }}.coverage.json" > ${{ parameters.RunName }}.coverage.xml + mkdir -p coverage + workingDirectory: '$(Build.SourcesDirectory)' + displayName: 'run tests' + env: + SQLPASSWORD: ${{ parameters.SQLPASSWORD }} + SQLCMDUSER: ${{ parameters.SQLCMDUSER }} + SQLCMDPASSWORD: ${{ parameters.SQLPASSWORD }} + AZURE_TENANT_ID: $(AZURE_TENANT_ID) + AZURE_CLIENT_ID: $(AZURE_CLIENT_ID) + AZURE_CLIENT_SECRET: ${{ parameters.AZURECLIENTSECRET }} + SQLCMDSERVER: ${{ parameters.SQLCMDSERVER }} + SQLCMDDBNAME: ${{ parameters.SQLCMDDBNAME }} + continueOnError: true + + - task: PublishTestResults@2 + displayName: "Publish junit-style results" + inputs: + testResultsFiles: '${{ parameters.RunName }}.testresults.xml' + testResultsFormat: JUnit + searchFolder: '$(Build.SourcesDirectory)' + testRunTitle: '${{ parameters.RunName }} - $(Build.SourceBranchName)' + failTaskOnFailedTests: true + condition: always() diff --git a/SUPPORT.md b/SUPPORT.md index dc72f0e5..8b05616f 100644 --- a/SUPPORT.md +++ b/SUPPORT.md @@ -1,25 +1,25 @@ -# TODO: The maintainer of this repo has not yet edited this file - -**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? - -- **No CSS support:** Fill out this template with information about how to file issues and get help. -- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). -- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. - -*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* - -# Support - -## How to file issues and get help - -This project uses GitHub Issues to track bugs and feature requests. Please search the existing -issues before filing new issues to avoid duplicates. For new issues, file your bug or -feature request as a new Issue. - -For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE -FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER -CHANNEL. WHERE WILL YOU HELP PEOPLE?**. - -## Microsoft Support Policy - -Support for this **PROJECT or PRODUCT** is limited to the resources listed above. +# TODO: The maintainer of this repo has not yet edited this file + +**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? + +- **No CSS support:** Fill out this template with information about how to file issues and get help. +- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). +- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. + +*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* + +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. + +For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE +FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER +CHANNEL. WHERE WILL YOU HELP PEOPLE?**. + +## Microsoft Support Policy + +Support for this **PROJECT or PRODUCT** is limited to the resources listed above. diff --git a/build/azure-pipelines/build-common.yml b/build/azure-pipelines/build-common.yml index 62e91668..3ad9659b 100644 --- a/build/azure-pipelines/build-common.yml +++ b/build/azure-pipelines/build-common.yml @@ -1,70 +1,70 @@ -parameters: -- name: OS - type: string - default: -- name: Arch - type: string - default: -- name: ArtifactName - type: string -- name: VersionTag - type: string - default: $(Build.BuildNumber) - -steps: -- task: GoTool@0 - inputs: - version: '1.18' - goBin: $(Build.SourcesDirectory) - -- task: Go@0 - displayName: 'Go install go-winres' - inputs: - command: 'custom' - customCommand: 'install' - arguments: 'github.com/tc-hib/go-winres@latest' - workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' - env: - GOBIN: $(Build.SourcesDirectory) - -- task: CmdLine@2 - displayName: 'generate version resource' - inputs: - script: $(Build.SourcesDirectory)/go-winres make --file-version git-tag --product-version git-tag - workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' - -- task: Go@0 - displayName: 'Go: get dependencies' - inputs: - command: 'get' - arguments: '-d' - workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' - env: - GOOS: ${{ parameters.OS }} - GOARCH: ${{ parameters.Arch }} - GOBIN: $(Build.SourcesDirectory) - -- task: Go@0 - displayName: 'Go: build sqlcmd' - inputs: - command: 'build' - arguments: '-o $(Build.BinariesDirectory) -ldflags="-X main.version=${{ parameters.VersionTag }}"' - workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' - env: - GOOS: ${{ parameters.OS }} - GOARCH: ${{ parameters.Arch }} - GOBIN: $(Build.SourcesDirectory) - CGO_ENABLED: 0 # Enables Docker image based off 'scratch' - -- task: CopyFiles@2 - inputs: - TargetFolder: '$(Build.ArtifactStagingDirectory)' - SourceFolder: '$(Build.BinariesDirectory)' - Contents: '**' - -- task: PublishPipelineArtifact@1 - displayName: 'Publish binary' - inputs: - targetPath: $(Build.ArtifactStagingDirectory) - artifactName: 'Sqlcmd${{ parameters.ArtifactName }}' - +parameters: +- name: OS + type: string + default: +- name: Arch + type: string + default: +- name: ArtifactName + type: string +- name: VersionTag + type: string + default: $(Build.BuildNumber) + +steps: +- task: GoTool@0 + inputs: + version: '1.18' + goBin: $(Build.SourcesDirectory) + +- task: Go@0 + displayName: 'Go install go-winres' + inputs: + command: 'custom' + customCommand: 'install' + arguments: 'github.com/tc-hib/go-winres@latest' + workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' + env: + GOBIN: $(Build.SourcesDirectory) + +- task: CmdLine@2 + displayName: 'generate version resource' + inputs: + script: $(Build.SourcesDirectory)/go-winres make --file-version git-tag --product-version git-tag + workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' + +- task: Go@0 + displayName: 'Go: get dependencies' + inputs: + command: 'get' + arguments: '-d' + workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' + env: + GOOS: ${{ parameters.OS }} + GOARCH: ${{ parameters.Arch }} + GOBIN: $(Build.SourcesDirectory) + +- task: Go@0 + displayName: 'Go: build sqlcmd' + inputs: + command: 'build' + arguments: '-o $(Build.BinariesDirectory) -ldflags="-X main.version=${{ parameters.VersionTag }}"' + workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' + env: + GOOS: ${{ parameters.OS }} + GOARCH: ${{ parameters.Arch }} + GOBIN: $(Build.SourcesDirectory) + CGO_ENABLED: 0 # Enables Docker image based off 'scratch' + +- task: CopyFiles@2 + inputs: + TargetFolder: '$(Build.ArtifactStagingDirectory)' + SourceFolder: '$(Build.BinariesDirectory)' + Contents: '**' + +- task: PublishPipelineArtifact@1 + displayName: 'Publish binary' + inputs: + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: 'Sqlcmd${{ parameters.ArtifactName }}' + diff --git a/build/azure-pipelines/build-product.yml b/build/azure-pipelines/build-product.yml index c028c4ae..f9a6a75a 100644 --- a/build/azure-pipelines/build-product.yml +++ b/build/azure-pipelines/build-product.yml @@ -1,197 +1,197 @@ -trigger: - tags: - include: - - v* - -pr: none - -parameters: - - name: PushToGithub - default: true - type: boolean - displayName: Push packages to github - -stages: - - stage: Compile - displayName: Compile sqlcmd on all supported platforms - jobs: - - job: Sqlcmd - strategy: - matrix: - linux: - imageName: 'ubuntu-latest' - artifact: LinuxAmd64 - os: - arch: - mac: - imageName: 'macOS-latest' - artifact: DarwinAmd64 - os: - arch: - windows: - imageName: 'windows-latest' - artifact: WindowsAmd64 - os: - arch: - linuxArm: - imageName: 'ubuntu-latest' - artifact: LinuxArm64 - os: - arch: arm64 - windowsArm: - imageName: 'windows-latest' - artifact: WindowsArm - os: - arch: arm - linuxs390x: - imageName: 'ubuntu-latest' - artifact: LinuxS390x - os: - arch: s390x - pool: - vmImage: $(imageName) - steps: - - template: build-tag.yml - - script: | - echo $(getVersion.VERSION_TAG) - - template: build-common.yml - parameters: - OS: $(os) - Arch: $(arch) - ArtifactName: $(artifact) - VersionTag: $(getVersion.VERSION_TAG) - - - stage: CreatePackages - displayName: Create packages to publish - jobs: - - job: Sign_and_pack - pool: - vmImage: 'windows-latest' - steps: - - template: build-tag.yml - - task: DownloadPipelineArtifact@2 - inputs: - buildType: 'current' - targetPath: '$(Pipeline.Workspace)' - - task: EsrpCodeSigning@1 - displayName: Sign Windows binary - inputs: - ConnectedServiceName: 'Code Signing' - FolderPath: '$(Pipeline.Workspace)' - Pattern: 'sqlcmd.exe' - signConfigType: 'inlineSignParams' - SessionTimeout: '600' - MaxConcurrency: '5' - MaxRetryAttempts: '5' - inlineOperation: | - [ - { - "keyCode": "CP-230012", - "operationSetCode": "SigntoolSign", - "parameters": [ - { - "parameterName": "OpusName", - "parameterValue": "go-sqlcmd" - }, - { - "parameterName": "OpusInfo", - "parameterValue": "https://github.com/microsoft/go-sqlcmd" - }, - { - "parameterName": "PageHash", - "parameterValue": "/NPH" - }, - { - "parameterName": "FileDigest", - "parameterValue": "/fd sha256" - }, - { - "parameterName": "TimeStamp", - "parameterValue": "/tr \"http://rfc3161.gtm.corp.microsoft.com/TSS/HttpTspServer\" /td sha256" - } - ], - "toolName": "signtool.exe", - "toolVersion": "6.2.9304.0" - }, - { - "keyCode": "CP-230012", - "operationSetCode": "SigntoolVerify", - "parameters": [ - { - "parameterName": "VerifyAll", - "parameterValue": "/all" - } - ], - "toolName": "signtool.exe", - "toolVersion": "6.2.9304.0" - } - ] - - task: ArchiveFiles@2 - displayName: Zip Windows amd64 binary - inputs: - rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdWindowsAmd64\Sqlcmd.exe' - includeRootFolder: false - archiveType: 'zip' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-windows-x64.zip' - - - task: ArchiveFiles@2 - displayName: Zip Windows arm binary - inputs: - rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdWindowsArm\Sqlcmd.exe' - includeRootFolder: false - archiveType: 'zip' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-windows-arm.zip' - - - task: ArchiveFiles@2 - displayName: Tar Linux amd64 binary - inputs: - rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdLinuxAmd64' - includeRootFolder: false - archiveType: 'tar' - tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-linux-x64.tar.bz2' - - - task: ArchiveFiles@2 - displayName: Tar Darwin binary - inputs: - rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdDarwinAmd64' - includeRootFolder: false - archiveType: 'tar' - tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-darwin-x64.tar.bz2' - - - task: ArchiveFiles@2 - displayName: Tar Linux arm64 binary - inputs: - rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdLinuxArm64' - includeRootFolder: false - archiveType: 'tar' - tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-linux-arm64.tar.bz2' - - - task: ArchiveFiles@2 - displayName: Tar Linux s390x binary - inputs: - rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdLinuxS390x' - includeRootFolder: false - archiveType: 'tar' - tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-linux-s390x.tar.bz2' - - - task: PublishPipelineArtifact@1 - displayName: 'Publish release archives' - inputs: - targetPath: $(Build.ArtifactStagingDirectory) - artifactName: SqlcmdRelease - - - task: GitHubRelease@1 - condition: eq('${{ parameters.PushToGithub}}', 'true') - inputs: - gitHubConnection: 'gosqlcmd_github' - repositoryName: '$(Build.Repository.Name)' - action: 'create' - target: '$(Build.SourceVersion)' - tagSource: 'userSpecifiedTag' - tag: '$(getVersion.VERSION_TAG)' - changeLogCompareToRelease: 'lastFullRelease' - changeLogType: 'commitBased' +trigger: + tags: + include: + - v* + +pr: none + +parameters: + - name: PushToGithub + default: true + type: boolean + displayName: Push packages to github + +stages: + - stage: Compile + displayName: Compile sqlcmd on all supported platforms + jobs: + - job: Sqlcmd + strategy: + matrix: + linux: + imageName: 'ubuntu-latest' + artifact: LinuxAmd64 + os: + arch: + mac: + imageName: 'macOS-latest' + artifact: DarwinAmd64 + os: + arch: + windows: + imageName: 'windows-latest' + artifact: WindowsAmd64 + os: + arch: + linuxArm: + imageName: 'ubuntu-latest' + artifact: LinuxArm64 + os: + arch: arm64 + windowsArm: + imageName: 'windows-latest' + artifact: WindowsArm + os: + arch: arm + linuxs390x: + imageName: 'ubuntu-latest' + artifact: LinuxS390x + os: + arch: s390x + pool: + vmImage: $(imageName) + steps: + - template: build-tag.yml + - script: | + echo $(getVersion.VERSION_TAG) + - template: build-common.yml + parameters: + OS: $(os) + Arch: $(arch) + ArtifactName: $(artifact) + VersionTag: $(getVersion.VERSION_TAG) + + - stage: CreatePackages + displayName: Create packages to publish + jobs: + - job: Sign_and_pack + pool: + vmImage: 'windows-latest' + steps: + - template: build-tag.yml + - task: DownloadPipelineArtifact@2 + inputs: + buildType: 'current' + targetPath: '$(Pipeline.Workspace)' + - task: EsrpCodeSigning@1 + displayName: Sign Windows binary + inputs: + ConnectedServiceName: 'Code Signing' + FolderPath: '$(Pipeline.Workspace)' + Pattern: 'sqlcmd.exe' + signConfigType: 'inlineSignParams' + SessionTimeout: '600' + MaxConcurrency: '5' + MaxRetryAttempts: '5' + inlineOperation: | + [ + { + "keyCode": "CP-230012", + "operationSetCode": "SigntoolSign", + "parameters": [ + { + "parameterName": "OpusName", + "parameterValue": "go-sqlcmd" + }, + { + "parameterName": "OpusInfo", + "parameterValue": "https://github.com/microsoft/go-sqlcmd" + }, + { + "parameterName": "PageHash", + "parameterValue": "/NPH" + }, + { + "parameterName": "FileDigest", + "parameterValue": "/fd sha256" + }, + { + "parameterName": "TimeStamp", + "parameterValue": "/tr \"http://rfc3161.gtm.corp.microsoft.com/TSS/HttpTspServer\" /td sha256" + } + ], + "toolName": "signtool.exe", + "toolVersion": "6.2.9304.0" + }, + { + "keyCode": "CP-230012", + "operationSetCode": "SigntoolVerify", + "parameters": [ + { + "parameterName": "VerifyAll", + "parameterValue": "/all" + } + ], + "toolName": "signtool.exe", + "toolVersion": "6.2.9304.0" + } + ] + - task: ArchiveFiles@2 + displayName: Zip Windows amd64 binary + inputs: + rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdWindowsAmd64\Sqlcmd.exe' + includeRootFolder: false + archiveType: 'zip' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-windows-x64.zip' + + - task: ArchiveFiles@2 + displayName: Zip Windows arm binary + inputs: + rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdWindowsArm\Sqlcmd.exe' + includeRootFolder: false + archiveType: 'zip' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-windows-arm.zip' + + - task: ArchiveFiles@2 + displayName: Tar Linux amd64 binary + inputs: + rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdLinuxAmd64' + includeRootFolder: false + archiveType: 'tar' + tarCompression: 'bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-linux-x64.tar.bz2' + + - task: ArchiveFiles@2 + displayName: Tar Darwin binary + inputs: + rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdDarwinAmd64' + includeRootFolder: false + archiveType: 'tar' + tarCompression: 'bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-darwin-x64.tar.bz2' + + - task: ArchiveFiles@2 + displayName: Tar Linux arm64 binary + inputs: + rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdLinuxArm64' + includeRootFolder: false + archiveType: 'tar' + tarCompression: 'bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-linux-arm64.tar.bz2' + + - task: ArchiveFiles@2 + displayName: Tar Linux s390x binary + inputs: + rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdLinuxS390x' + includeRootFolder: false + archiveType: 'tar' + tarCompression: 'bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(getVersion.VERSION_TAG)-linux-s390x.tar.bz2' + + - task: PublishPipelineArtifact@1 + displayName: 'Publish release archives' + inputs: + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: SqlcmdRelease + + - task: GitHubRelease@1 + condition: eq('${{ parameters.PushToGithub}}', 'true') + inputs: + gitHubConnection: 'gosqlcmd_github' + repositoryName: '$(Build.Repository.Name)' + action: 'create' + target: '$(Build.SourceVersion)' + tagSource: 'userSpecifiedTag' + tag: '$(getVersion.VERSION_TAG)' + changeLogCompareToRelease: 'lastFullRelease' + changeLogType: 'commitBased' diff --git a/cmd/sqlcmd/main_test.go b/cmd/sqlcmd/main_test.go index a6aa2a38..6a42ecfe 100644 --- a/cmd/sqlcmd/main_test.go +++ b/cmd/sqlcmd/main_test.go @@ -4,6 +4,7 @@ package main import ( "os" + "path/filepath" "runtime" "strings" "testing" @@ -189,24 +190,24 @@ func TestUnicodeOutput(t *testing.T) { func TestUnicodeInput(t *testing.T) { testfiles := []string{ - `testdata/selectutf8.txt`, - `testdata/selectutf8_bom.txt`, - `testdata/selectunicode_BE.txt`, - `testdata/selectunicode_LE.txt`, + filepath.Join(`testdata`, `selectutf8.txt`), + filepath.Join(`testdata`, `selectutf8_bom.txt`), + filepath.Join(`testdata`, `selectunicode_BE.txt`), + filepath.Join(`testdata`, `selectunicode_LE.txt`), } for _, test := range testfiles { for _, unicodeOutput := range []bool{true, false} { var outfile string if unicodeOutput { - outfile = `testdata/unicodeout_linux.txt` + outfile = filepath.Join(`testdata`, `unicodeout_linux.txt`) if runtime.GOOS == "windows" { - outfile = `testdata/unicodeout.txt` + outfile = filepath.Join(`testdata`, `unicodeout.txt`) } } else { outfile = `testdata/utf8out_linux.txt` if runtime.GOOS == "windows" { - outfile = `testdata/utf8out.txt` + outfile = filepath.Join(`testdata`, `utf8out.txt`) } } o, err := os.CreateTemp("", "sqlcmdmain") @@ -226,10 +227,12 @@ func TestUnicodeInput(t *testing.T) { assert.NoError(t, err, "run") assert.Equal(t, 0, exitCode, "exitCode") bytes, err := os.ReadFile(o.Name()) + s := strings.ReplaceAll(string(bytes), sqlcmd.SqlcmdEol, "\n") // Normalize Eols for cross plat if assert.NoError(t, err, "os.ReadFile") { expectedBytes, err := os.ReadFile(outfile) + expectedS := strings.ReplaceAll(string(expectedBytes), sqlcmd.SqlcmdEol, "\n") // Normalize Eols for cross plat if assert.NoErrorf(t, err, "Unable to open %s", outfile) { - assert.Equalf(t, expectedBytes, bytes, "input file: <%s> output bytes should match <%s>", test, outfile) + assert.Equalf(t, expectedS, s, "input file: <%s> output bytes should match <%s>", test, outfile) } } } @@ -263,8 +266,8 @@ func TestQueryAndExit(t *testing.T) { } // Test to verify fix for issue: https://github.com/microsoft/go-sqlcmd/issues/98 -// 1. Verify when -b is passed in (ExitOnError), we don't always get an error (even when input is good) -// 2, Verify when the input is actually bad, we do get an error +// 1. Verify when -b is passed in (ExitOnError), we don't always get an error (even when input is good) +// 2, Verify when the input is actually bad, we do get an error func TestExitOnError(t *testing.T) { args = newArguments() args.InputFile = []string{"testdata/select100.sql"} @@ -320,7 +323,7 @@ func TestAzureAuth(t *testing.T) { func TestMissingInputFile(t *testing.T) { args = newArguments() - args.InputFile = []string{"testdata/missingFile.sql"} + args.InputFile = []string{filepath.Join("testdata", "missingFile.sql")} if canTestAzureAuth() { args.UseAad = true diff --git a/cmd/sqlcmd/testdata/select100.sql b/cmd/sqlcmd/testdata/select100.sql index 1b87fa39..718c071f 100644 --- a/cmd/sqlcmd/testdata/select100.sql +++ b/cmd/sqlcmd/testdata/select100.sql @@ -1 +1 @@ -select 100 +select 100 diff --git a/pkg/sqlcmd/azure_auth.go b/pkg/sqlcmd/azure_auth.go index 5d924390..f554a181 100644 --- a/pkg/sqlcmd/azure_auth.go +++ b/pkg/sqlcmd/azure_auth.go @@ -1,55 +1,55 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "database/sql/driver" - "fmt" - "net/url" - "os" - - "github.com/microsoft/go-mssqldb/azuread" -) - -const ( - NotSpecified = "NotSpecified" - SqlPassword = "SqlPassword" - sqlClientId = "a94f9c62-97fe-4d19-b06d-472bed8d2bcf" -) - -func getSqlClientId() string { - if clientId := os.Getenv("SQLCMDCLIENTID"); clientId != "" { - return clientId - } - return sqlClientId -} - -func GetTokenBasedConnection(connstr string, authenticationMethod string) (driver.Connector, error) { - - connectionUrl, err := url.Parse(connstr) - if err != nil { - return nil, err - } - - query := connectionUrl.Query() - query.Set("fedauth", authenticationMethod) - query.Set("applicationclientid", getSqlClientId()) - switch authenticationMethod { - case azuread.ActiveDirectoryServicePrincipal, azuread.ActiveDirectoryApplication: - query.Set("clientcertpath", os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH")) - case azuread.ActiveDirectoryInteractive: - loginTimeout := query.Get("connection timeout") - loginTimeoutSeconds := 0 - if loginTimeout != "" { - _, _ = fmt.Sscanf(loginTimeout, "%d", &loginTimeoutSeconds) - } - // AAD interactive needs minutes at minimum - if loginTimeoutSeconds > 0 && loginTimeoutSeconds < 120 { - query.Set("connection timeout", "120") - } - } - - connectionUrl.RawQuery = query.Encode() - return azuread.NewConnector(connectionUrl.String()) -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "database/sql/driver" + "fmt" + "net/url" + "os" + + "github.com/microsoft/go-mssqldb/azuread" +) + +const ( + NotSpecified = "NotSpecified" + SqlPassword = "SqlPassword" + sqlClientId = "a94f9c62-97fe-4d19-b06d-472bed8d2bcf" +) + +func getSqlClientId() string { + if clientId := os.Getenv("SQLCMDCLIENTID"); clientId != "" { + return clientId + } + return sqlClientId +} + +func GetTokenBasedConnection(connstr string, authenticationMethod string) (driver.Connector, error) { + + connectionUrl, err := url.Parse(connstr) + if err != nil { + return nil, err + } + + query := connectionUrl.Query() + query.Set("fedauth", authenticationMethod) + query.Set("applicationclientid", getSqlClientId()) + switch authenticationMethod { + case azuread.ActiveDirectoryServicePrincipal, azuread.ActiveDirectoryApplication: + query.Set("clientcertpath", os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH")) + case azuread.ActiveDirectoryInteractive: + loginTimeout := query.Get("connection timeout") + loginTimeoutSeconds := 0 + if loginTimeout != "" { + _, _ = fmt.Sscanf(loginTimeout, "%d", &loginTimeoutSeconds) + } + // AAD interactive needs minutes at minimum + if loginTimeoutSeconds > 0 && loginTimeoutSeconds < 120 { + query.Set("connection timeout", "120") + } + } + + connectionUrl.RawQuery = query.Encode() + return azuread.NewConnector(connectionUrl.String()) +} diff --git a/pkg/sqlcmd/batch.go b/pkg/sqlcmd/batch.go index 9afaf812..7b8082e5 100644 --- a/pkg/sqlcmd/batch.go +++ b/pkg/sqlcmd/batch.go @@ -1,263 +1,263 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -const minCapIncrease = 512 - -// lineend is the slice to use when appending a line. -var lineend = []rune(SqlcmdEol) - -// Batch provides the query text to run -type Batch struct { - // read provides the next chunk of runes - read batchScan - // Buffer is the current batch text - Buffer []rune - // Length is the length of the statement - Length int - // raw is the unprocessed runes - raw []rune - // rawlen is the number of unprocessed runes - rawlen int - // quote indicates currently processing a quoted string - quote rune - // comment is the state of multi-line comment processing - comment bool - // batchline is the 1-based index of the next line. - // Used for the prompt in interactive mode - batchline int - // linecount is the total number of batch lines processed in the session - linecount uint - // varmap tracks the location of expandable variables for the entire batch - varmap map[int]string - // linevarmap tracks the location of expandable variables on the current line - linevarmap map[int]string - // cmd is the set of Commands available - cmd Commands -} - -type batchScan func() (string, error) - -// NewBatch creates a Batch which converts runes provided by reader into SQL batches -func NewBatch(reader batchScan, cmd Commands) *Batch { - b := &Batch{ - read: reader, - cmd: cmd, - } - b.Reset(nil) - return b -} - -// String returns the current SQL batch text -func (b *Batch) String() string { - return string(b.Buffer) -} - -// Reset clears the current batch text and replaces it with new runes -func (b *Batch) Reset(r []rune) { - b.Buffer, b.Length = nil, 0 - b.quote = 0 - b.comment = false - b.batchline = 1 - if r != nil { - b.raw, b.rawlen = r, len(r) - } else { - b.rawlen = 0 - } - b.varmap = make(map[int]string) -} - -// Next processes the next chunk of input and sets the Batch state accordingly. -// If the input contains a command to run, Next returns the Command and its -// parameters. -// Upon exit from Next, the caller can use the State method to determine if -// it represents a runnable SQL batch text. -func (b *Batch) Next() (*Command, []string, error) { - b.linevarmap = nil - var err error - var i int - if b.rawlen == 0 { - s, err := b.read() - if err != nil { - return nil, nil, err - } - b.raw = []rune(s) - b.rawlen = len(b.raw) - } - - var command *Command - var args []string - var ok bool - var scannedCommand bool - b.linecount++ -parse: - for ; i < b.rawlen; i++ { - c, next := b.raw[i], grab(b.raw, i+1, b.rawlen) - switch { - // we're in a quoted string - case b.quote != 0: - i, ok, err = b.readString(b.raw, i, b.rawlen, b.quote, b.linecount) - if err != nil { - break parse - } - if ok { - b.quote = 0 - } - // inside a multiline comment - case b.comment: - i, ok = readMultilineComment(b.raw, i, b.rawlen) - b.comment = !ok - // start of a string - case c == '\'' || c == '"': - b.quote = c - // inline sql comment, skip to end of line - case c == '-' && next == '-': - i = b.rawlen - // start a multi-line comment - case c == '/' && next == '*': - b.comment = true - i++ - // continue processing quoted string or multiline comment - case b.quote != 0 || b.comment: - - // Handle variable references - case c == '$' && next == '(': - vi, ok := readVariableReference(b.raw, i+2, b.rawlen) - if ok { - b.addVariableLocation(i, string(b.raw[i+2:vi])) - i = vi - - } else { - err = syntaxError(b.linecount) - break parse - } - // Commands have to be alone on the line - case !scannedCommand && b.cmd != nil: - var cend int - scannedCommand = true - command, args, cend = readCommand(b.cmd, b.raw, i, b.rawlen) - if command != nil { - // remove the command from raw - b.raw = append(b.raw[:i], b.raw[cend:]...) - break parse - } - } - } - if err == nil { - i = min(i, b.rawlen) - empty := isEmptyLine(b.raw, 0, i) - appendLine := true - if !b.comment && command != nil && empty { - appendLine = false - } - if appendLine { - // any variables on the line need to be added to the global map - inc := 0 - if b.Length > 0 { - inc = len(lineend) - } - if b.linevarmap != nil { - for v := range b.linevarmap { - b.varmap[v+b.Length+inc] = b.linevarmap[v] - } - } - // log.Printf(">> appending: `%s`", string(r[st:i])) - b.append(b.raw[:i], lineend) - b.batchline++ - } - b.raw = b.raw[i:] - b.rawlen = len(b.raw) - } else { - b.Reset(nil) - } - return command, args, err -} - -// append appends r to b.Buffer separated by sep when b.Buffer is not already empty. -// -// Dynamically grows b.Buf as necessary to accommodate r and the separator. -// Specifically, when b.Buf is not empty, b.Buf will grow by increments of -// MinCapIncrease. -// -// After a call to append, b.Len will be len(b.Buf)+len(sep)+len(r). Call Reset -// to reset the Buf. -func (b *Batch) append(r, sep []rune) { - rlen := len(r) - // initial - if b.Buffer == nil { - b.Buffer, b.Length = r, rlen - return - } - blen, seplen := b.Length, len(sep) - tlen := blen + rlen + seplen - // grow - if bcap := cap(b.Buffer); tlen > bcap { - n := tlen + 2*rlen - n += minCapIncrease - (n % minCapIncrease) - z := make([]rune, blen, n) - copy(z, b.Buffer) - b.Buffer = z - } - b.Buffer = b.Buffer[:tlen] - copy(b.Buffer[blen:], sep) - copy(b.Buffer[blen+seplen:], r) - b.Length = tlen -} - -// State returns a string representing the state of statement parsing. -// * Is in the middle of a multi-line comment -// - Has a non-empty batch ready to run -// = Is empty -// ' " Is in the middle of a multi-line quoted string -func (b *Batch) State() string { - switch { - case b.quote != 0: - return string(b.quote) - case b.comment: - return "*" - case b.Length != 0: - return "-" - } - return "=" -} - -// readString seeks to the end of a string returning the position and whether -// or not the string's end was found. -// -// If the string's terminator was not found, then the result will be the passed -// end. -// An error is returned if the string contains a malformed variable reference -func (b *Batch) readString(r []rune, i, end int, quote rune, line uint) (int, bool, error) { - var prev, c, next rune - for ; i < end; i++ { - c, next = r[i], grab(r, i+1, end) - switch { - case c == '$' && next == '(': - vl, ok := readVariableReference(r, i+2, end) - if ok { - b.addVariableLocation(i, string(r[i+2:vl])) - i = vl - - } else { - return i, false, syntaxError(line) - } - case quote == '\'' && c == '\'' && next == '\'': - i++ - continue - case quote == '\'' && c == '\'' && prev != '\'', - quote == '"' && c == '"': - return i, true, nil - } - prev = c - } - return end, false, nil -} - -// addVariableLocation is called for each variable on the current line -func (b *Batch) addVariableLocation(i int, v string) { - if b.linevarmap == nil { - b.linevarmap = make(map[int]string) - } - b.linevarmap[i] = v -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +const minCapIncrease = 512 + +// lineend is the slice to use when appending a line. +var lineend = []rune(SqlcmdEol) + +// Batch provides the query text to run +type Batch struct { + // read provides the next chunk of runes + read batchScan + // Buffer is the current batch text + Buffer []rune + // Length is the length of the statement + Length int + // raw is the unprocessed runes + raw []rune + // rawlen is the number of unprocessed runes + rawlen int + // quote indicates currently processing a quoted string + quote rune + // comment is the state of multi-line comment processing + comment bool + // batchline is the 1-based index of the next line. + // Used for the prompt in interactive mode + batchline int + // linecount is the total number of batch lines processed in the session + linecount uint + // varmap tracks the location of expandable variables for the entire batch + varmap map[int]string + // linevarmap tracks the location of expandable variables on the current line + linevarmap map[int]string + // cmd is the set of Commands available + cmd Commands +} + +type batchScan func() (string, error) + +// NewBatch creates a Batch which converts runes provided by reader into SQL batches +func NewBatch(reader batchScan, cmd Commands) *Batch { + b := &Batch{ + read: reader, + cmd: cmd, + } + b.Reset(nil) + return b +} + +// String returns the current SQL batch text +func (b *Batch) String() string { + return string(b.Buffer) +} + +// Reset clears the current batch text and replaces it with new runes +func (b *Batch) Reset(r []rune) { + b.Buffer, b.Length = nil, 0 + b.quote = 0 + b.comment = false + b.batchline = 1 + if r != nil { + b.raw, b.rawlen = r, len(r) + } else { + b.rawlen = 0 + } + b.varmap = make(map[int]string) +} + +// Next processes the next chunk of input and sets the Batch state accordingly. +// If the input contains a command to run, Next returns the Command and its +// parameters. +// Upon exit from Next, the caller can use the State method to determine if +// it represents a runnable SQL batch text. +func (b *Batch) Next() (*Command, []string, error) { + b.linevarmap = nil + var err error + var i int + if b.rawlen == 0 { + s, err := b.read() + if err != nil { + return nil, nil, err + } + b.raw = []rune(s) + b.rawlen = len(b.raw) + } + + var command *Command + var args []string + var ok bool + var scannedCommand bool + b.linecount++ +parse: + for ; i < b.rawlen; i++ { + c, next := b.raw[i], grab(b.raw, i+1, b.rawlen) + switch { + // we're in a quoted string + case b.quote != 0: + i, ok, err = b.readString(b.raw, i, b.rawlen, b.quote, b.linecount) + if err != nil { + break parse + } + if ok { + b.quote = 0 + } + // inside a multiline comment + case b.comment: + i, ok = readMultilineComment(b.raw, i, b.rawlen) + b.comment = !ok + // start of a string + case c == '\'' || c == '"': + b.quote = c + // inline sql comment, skip to end of line + case c == '-' && next == '-': + i = b.rawlen + // start a multi-line comment + case c == '/' && next == '*': + b.comment = true + i++ + // continue processing quoted string or multiline comment + case b.quote != 0 || b.comment: + + // Handle variable references + case c == '$' && next == '(': + vi, ok := readVariableReference(b.raw, i+2, b.rawlen) + if ok { + b.addVariableLocation(i, string(b.raw[i+2:vi])) + i = vi + + } else { + err = syntaxError(b.linecount) + break parse + } + // Commands have to be alone on the line + case !scannedCommand && b.cmd != nil: + var cend int + scannedCommand = true + command, args, cend = readCommand(b.cmd, b.raw, i, b.rawlen) + if command != nil { + // remove the command from raw + b.raw = append(b.raw[:i], b.raw[cend:]...) + break parse + } + } + } + if err == nil { + i = min(i, b.rawlen) + empty := isEmptyLine(b.raw, 0, i) + appendLine := true + if !b.comment && command != nil && empty { + appendLine = false + } + if appendLine { + // any variables on the line need to be added to the global map + inc := 0 + if b.Length > 0 { + inc = len(lineend) + } + if b.linevarmap != nil { + for v := range b.linevarmap { + b.varmap[v+b.Length+inc] = b.linevarmap[v] + } + } + // log.Printf(">> appending: `%s`", string(r[st:i])) + b.append(b.raw[:i], lineend) + b.batchline++ + } + b.raw = b.raw[i:] + b.rawlen = len(b.raw) + } else { + b.Reset(nil) + } + return command, args, err +} + +// append appends r to b.Buffer separated by sep when b.Buffer is not already empty. +// +// Dynamically grows b.Buf as necessary to accommodate r and the separator. +// Specifically, when b.Buf is not empty, b.Buf will grow by increments of +// MinCapIncrease. +// +// After a call to append, b.Len will be len(b.Buf)+len(sep)+len(r). Call Reset +// to reset the Buf. +func (b *Batch) append(r, sep []rune) { + rlen := len(r) + // initial + if b.Buffer == nil { + b.Buffer, b.Length = r, rlen + return + } + blen, seplen := b.Length, len(sep) + tlen := blen + rlen + seplen + // grow + if bcap := cap(b.Buffer); tlen > bcap { + n := tlen + 2*rlen + n += minCapIncrease - (n % minCapIncrease) + z := make([]rune, blen, n) + copy(z, b.Buffer) + b.Buffer = z + } + b.Buffer = b.Buffer[:tlen] + copy(b.Buffer[blen:], sep) + copy(b.Buffer[blen+seplen:], r) + b.Length = tlen +} + +// State returns a string representing the state of statement parsing. +// * Is in the middle of a multi-line comment +// - Has a non-empty batch ready to run +// = Is empty +// ' " Is in the middle of a multi-line quoted string +func (b *Batch) State() string { + switch { + case b.quote != 0: + return string(b.quote) + case b.comment: + return "*" + case b.Length != 0: + return "-" + } + return "=" +} + +// readString seeks to the end of a string returning the position and whether +// or not the string's end was found. +// +// If the string's terminator was not found, then the result will be the passed +// end. +// An error is returned if the string contains a malformed variable reference +func (b *Batch) readString(r []rune, i, end int, quote rune, line uint) (int, bool, error) { + var prev, c, next rune + for ; i < end; i++ { + c, next = r[i], grab(r, i+1, end) + switch { + case c == '$' && next == '(': + vl, ok := readVariableReference(r, i+2, end) + if ok { + b.addVariableLocation(i, string(r[i+2:vl])) + i = vl + + } else { + return i, false, syntaxError(line) + } + case quote == '\'' && c == '\'' && next == '\'': + i++ + continue + case quote == '\'' && c == '\'' && prev != '\'', + quote == '"' && c == '"': + return i, true, nil + } + prev = c + } + return end, false, nil +} + +// addVariableLocation is called for each variable on the current line +func (b *Batch) addVariableLocation(i int, v string) { + if b.linevarmap == nil { + b.linevarmap = make(map[int]string) + } + b.linevarmap[i] = v +} diff --git a/pkg/sqlcmd/batch_test.go b/pkg/sqlcmd/batch_test.go index 3128ec77..a323c1fb 100644 --- a/pkg/sqlcmd/batch_test.go +++ b/pkg/sqlcmd/batch_test.go @@ -1,223 +1,223 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "io" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestBatchNext(t *testing.T) { - tests := []struct { - s string - stmts []string - cmds []string - state string - }{ - {"", nil, nil, "="}, - {"select 1", []string{"select 1"}, nil, "-"}, - {"select $(x)\nquit", []string{"select $(x)"}, []string{"QUIT"}, "-"}, - {"select '$ (X' \nquite", []string{"select '$ (X' " + SqlcmdEol + "quite"}, nil, "-"}, - {":list\n:reset\n", nil, []string{"LIST", "RESET"}, "="}, - {"select 1\n:list\nselect 2", []string{"select 1" + SqlcmdEol + "select 2"}, []string{"LIST"}, "-"}, - {"select '1\n", []string{"select '1" + SqlcmdEol + ""}, nil, "'"}, - {"select 1 /* comment\nGO", []string{"select 1 /* comment" + SqlcmdEol + "GO"}, nil, "*"}, - {"select '1\n00' \n/* comm\nent*/\nGO 4", []string{"select '1" + SqlcmdEol + "00' " + SqlcmdEol + "/* comm" + SqlcmdEol + "ent*/"}, []string{"GO"}, "-"}, - {"$(x) $(y) 100\nquit", []string{"$(x) $(y) 100"}, []string{"QUIT"}, "-"}, - {"select 1\n:list", []string{"select 1"}, []string{"LIST"}, "-"}, - {"select 1\n:reset", []string{"select 1"}, []string{"RESET"}, "-"}, - {"select 1\n:exit()", []string{"select 1"}, []string{"EXIT"}, "-"}, - {"select 1\n:exit (select 10)", []string{"select 1"}, []string{"EXIT"}, "-"}, - {"select 1\n:exit", []string{"select 1"}, []string{"EXIT"}, "-"}, - } - for _, test := range tests { - b := NewBatch(sp(test.s, "\n"), newCommands()) - var stmts, cmds []string - loop: - for { - cmd, _, err := b.Next() - switch { - case err == io.EOF: - // if we get EOF before a command we will try to run - // whatever is in the buffer - if s := b.String(); s != "" { - stmts = append(stmts, s) - } - break loop - case err != nil: - t.Fatalf("test %s did not expect error, got: %v", test.s, err) - } - if cmd != nil { - cmds = append(cmds, cmd.name) - } - } - assert.Equal(t, test.stmts, stmts, "Statements for %s", test.s) - assert.Equal(t, test.state, b.State(), "State for %s", test.s) - assert.Equal(t, test.cmds, cmds, "Commands for %s", test.s) - b.Reset(nil) - assert.Zero(t, b.Length, "Length after Reset") - assert.Zero(t, len(b.Buffer), "len(Buffer) after Reset") - assert.Zero(t, b.quote, "quote after Reset") - assert.False(t, b.comment, "comment after Reset") - assert.Equal(t, "=", b.State(), "State() after Reset") - } -} - -func sp(a, sep string) func() (string, error) { - s := strings.Split(a, sep) - return func() (string, error) { - if len(s) > 0 { - z := s[0] - s = s[1:] - return z, nil - } - return "", io.EOF - } -} - -func TestBatchNextErrOnInvalidVariable(t *testing.T) { - tests := []string{ - "select $(x", - "$((x", - "alter $( x)", - } - for _, test := range tests { - b := NewBatch(sp(test, "\n"), newCommands()) - cmd, _, err := b.Next() - assert.Nil(t, cmd, "cmd for "+test) - assert.Equal(t, uint(1), b.linecount, "linecount should increment on a variable syntax error") - assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 1", "expected err for %s", test) - } -} - -func TestReadString(t *testing.T) { - tests := []struct { - // input string - s string - // index to start inside s - i int - // expected return string - exp string - // expected return bool - ok bool - }{ - {`'`, 0, ``, false}, - {` '`, 1, ``, false}, - {`'str' `, 0, `'str'`, true}, - {` 'str' `, 1, `'str'`, true}, - {`"str"`, 0, `"str"`, true}, - {`'str''str'`, 0, `'str''str'`, true}, - {` 'str''str' `, 1, `'str''str'`, true}, - {` "str''str" `, 1, `"str''str"`, true}, - // escaped \" aren't allowed in strings, so the second " would be next - // double quoted string - {`"str\""`, 0, `"str\"`, true}, - {` "str\"" `, 1, `"str\"`, true}, - {`'str\'`, 0, `'str\'`, true}, - {`''''`, 0, `''''`, true}, - {` '''' `, 1, `''''`, true}, - {`''''''`, 0, `''''''`, true}, - {` '''''' `, 1, `''''''`, true}, - {`'''`, 0, ``, false}, - {` ''' `, 1, ``, false}, - {`'''''`, 0, ``, false}, - {` ''''' `, 1, ``, false}, - {`"st'r"`, 0, `"st'r"`, true}, - {` "st'r" `, 1, `"st'r"`, true}, - {`"st''r"`, 0, `"st''r"`, true}, - {` "st''r" `, 1, `"st''r"`, true}, - {`'$(v)'`, 0, `'$(v)'`, true}, - {`'var $(var1) var2 $(var2)'`, 0, `'var $(var1) var2 $(var2)'`, true}, - {`'var $(var1) $`, 0, `'var $(var1) $`, false}, - } - b := NewBatch(nil, newCommands()) - - for _, test := range tests { - r := []rune(test.s) - c, end := rune(strings.TrimSpace(test.s)[0]), len(r) - if c != '\'' && c != '"' { - t.Fatalf("test %+v incorrect!", test) - } - pos, ok, err := b.readString(r, test.i+1, end, c, uint(0)) - assert.NoErrorf(t, err, "should be no error for %s", test) - assert.Equal(t, test.ok, ok, "test %+v ok", test) - if !ok { - continue - } - assert.Equal(t, c, r[pos], "test %+v last character") - v := string(r[test.i : pos+1]) - assert.Equal(t, test.exp, v, "test %+v returned string", test) - } -} - -func TestReadStringMalformedVariable(t *testing.T) { - tests := []string{ - "'select $(x'", - "' $((x'", - "'alter $( x)", - } - b := NewBatch(nil, newCommands()) - for _, test := range tests { - r := []rune(test) - _, ok, err := b.readString(r, 1, len(test), '\'', 10) - assert.Falsef(t, ok, "ok for %s", test) - assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 10", "expected err for %s", test) - } -} - -func TestReadStringVarmap(t *testing.T) { - type mapTest struct { - s string - m map[int]string - } - tests := []mapTest{ - {`'var $(var1) var2 $(var2)'`, map[int]string{5: "var1", 18: "var2"}}, - {`'var $(va_1) var2 $(va-2)'`, map[int]string{5: "va_1", 18: "va-2"}}, - } - for _, test := range tests { - b := NewBatch(nil, newCommands()) - b.linevarmap = make(map[int]string) - i, ok, err := b.readString([]rune(test.s), 1, len(test.s), '\'', 0) - assert.Truef(t, ok, "ok returned by readString for %s", test.s) - assert.NoErrorf(t, err, "readString for %s", test.s) - assert.Equal(t, len(test.s)-1, i, "index returned by readString for %s", test.s) - assert.Equalf(t, test.m, b.linevarmap, "linevarmap after readString %s", test.s) - } -} - -func TestBatchNextVarMap(t *testing.T) { - type mapTest struct { - s string - m map[int]string - } - tests := []mapTest{ - {"'var $(var1)\nvar2 $(var2)\n'", map[int]string{5: "var1", 17 + len(SqlcmdEol): "var2"}}, - {"$(var1) select $(var2)\nselect 100\nselect '$(var3)'", map[int]string{ - 0: "var1", - 15: "var2", - 40 + 2*len(SqlcmdEol): "var3"}, - }, - } -loop: - for _, test := range tests { - var err error - b := NewBatch(sp(test.s, "\n"), newCommands()) - for { - _, _, err = b.Next() - if err == io.EOF { - assert.Equalf(t, test.m, b.varmap, "varmap after Next %s. Batch:%s", test.s, escapeeol(b.String())) - break loop - } else { - assert.NoErrorf(t, err, "Should have no error from Next") - } - } - } -} - -func escapeeol(s string) string { - return strings.Replace(strings.Replace(s, "\n", `\n`, -1), "\r", `\r`, -1) -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBatchNext(t *testing.T) { + tests := []struct { + s string + stmts []string + cmds []string + state string + }{ + {"", nil, nil, "="}, + {"select 1", []string{"select 1"}, nil, "-"}, + {"select $(x)\nquit", []string{"select $(x)"}, []string{"QUIT"}, "-"}, + {"select '$ (X' \nquite", []string{"select '$ (X' " + SqlcmdEol + "quite"}, nil, "-"}, + {":list\n:reset\n", nil, []string{"LIST", "RESET"}, "="}, + {"select 1\n:list\nselect 2", []string{"select 1" + SqlcmdEol + "select 2"}, []string{"LIST"}, "-"}, + {"select '1\n", []string{"select '1" + SqlcmdEol + ""}, nil, "'"}, + {"select 1 /* comment\nGO", []string{"select 1 /* comment" + SqlcmdEol + "GO"}, nil, "*"}, + {"select '1\n00' \n/* comm\nent*/\nGO 4", []string{"select '1" + SqlcmdEol + "00' " + SqlcmdEol + "/* comm" + SqlcmdEol + "ent*/"}, []string{"GO"}, "-"}, + {"$(x) $(y) 100\nquit", []string{"$(x) $(y) 100"}, []string{"QUIT"}, "-"}, + {"select 1\n:list", []string{"select 1"}, []string{"LIST"}, "-"}, + {"select 1\n:reset", []string{"select 1"}, []string{"RESET"}, "-"}, + {"select 1\n:exit()", []string{"select 1"}, []string{"EXIT"}, "-"}, + {"select 1\n:exit (select 10)", []string{"select 1"}, []string{"EXIT"}, "-"}, + {"select 1\n:exit", []string{"select 1"}, []string{"EXIT"}, "-"}, + } + for _, test := range tests { + b := NewBatch(sp(test.s, "\n"), newCommands()) + var stmts, cmds []string + loop: + for { + cmd, _, err := b.Next() + switch { + case err == io.EOF: + // if we get EOF before a command we will try to run + // whatever is in the buffer + if s := b.String(); s != "" { + stmts = append(stmts, s) + } + break loop + case err != nil: + t.Fatalf("test %s did not expect error, got: %v", test.s, err) + } + if cmd != nil { + cmds = append(cmds, cmd.name) + } + } + assert.Equal(t, test.stmts, stmts, "Statements for %s", test.s) + assert.Equal(t, test.state, b.State(), "State for %s", test.s) + assert.Equal(t, test.cmds, cmds, "Commands for %s", test.s) + b.Reset(nil) + assert.Zero(t, b.Length, "Length after Reset") + assert.Zero(t, len(b.Buffer), "len(Buffer) after Reset") + assert.Zero(t, b.quote, "quote after Reset") + assert.False(t, b.comment, "comment after Reset") + assert.Equal(t, "=", b.State(), "State() after Reset") + } +} + +func sp(a, sep string) func() (string, error) { + s := strings.Split(a, sep) + return func() (string, error) { + if len(s) > 0 { + z := s[0] + s = s[1:] + return z, nil + } + return "", io.EOF + } +} + +func TestBatchNextErrOnInvalidVariable(t *testing.T) { + tests := []string{ + "select $(x", + "$((x", + "alter $( x)", + } + for _, test := range tests { + b := NewBatch(sp(test, "\n"), newCommands()) + cmd, _, err := b.Next() + assert.Nil(t, cmd, "cmd for "+test) + assert.Equal(t, uint(1), b.linecount, "linecount should increment on a variable syntax error") + assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 1", "expected err for %s", test) + } +} + +func TestReadString(t *testing.T) { + tests := []struct { + // input string + s string + // index to start inside s + i int + // expected return string + exp string + // expected return bool + ok bool + }{ + {`'`, 0, ``, false}, + {` '`, 1, ``, false}, + {`'str' `, 0, `'str'`, true}, + {` 'str' `, 1, `'str'`, true}, + {`"str"`, 0, `"str"`, true}, + {`'str''str'`, 0, `'str''str'`, true}, + {` 'str''str' `, 1, `'str''str'`, true}, + {` "str''str" `, 1, `"str''str"`, true}, + // escaped \" aren't allowed in strings, so the second " would be next + // double quoted string + {`"str\""`, 0, `"str\"`, true}, + {` "str\"" `, 1, `"str\"`, true}, + {`'str\'`, 0, `'str\'`, true}, + {`''''`, 0, `''''`, true}, + {` '''' `, 1, `''''`, true}, + {`''''''`, 0, `''''''`, true}, + {` '''''' `, 1, `''''''`, true}, + {`'''`, 0, ``, false}, + {` ''' `, 1, ``, false}, + {`'''''`, 0, ``, false}, + {` ''''' `, 1, ``, false}, + {`"st'r"`, 0, `"st'r"`, true}, + {` "st'r" `, 1, `"st'r"`, true}, + {`"st''r"`, 0, `"st''r"`, true}, + {` "st''r" `, 1, `"st''r"`, true}, + {`'$(v)'`, 0, `'$(v)'`, true}, + {`'var $(var1) var2 $(var2)'`, 0, `'var $(var1) var2 $(var2)'`, true}, + {`'var $(var1) $`, 0, `'var $(var1) $`, false}, + } + b := NewBatch(nil, newCommands()) + + for _, test := range tests { + r := []rune(test.s) + c, end := rune(strings.TrimSpace(test.s)[0]), len(r) + if c != '\'' && c != '"' { + t.Fatalf("test %+v incorrect!", test) + } + pos, ok, err := b.readString(r, test.i+1, end, c, uint(0)) + assert.NoErrorf(t, err, "should be no error for %s", test) + assert.Equal(t, test.ok, ok, "test %+v ok", test) + if !ok { + continue + } + assert.Equal(t, c, r[pos], "test %+v last character") + v := string(r[test.i : pos+1]) + assert.Equal(t, test.exp, v, "test %+v returned string", test) + } +} + +func TestReadStringMalformedVariable(t *testing.T) { + tests := []string{ + "'select $(x'", + "' $((x'", + "'alter $( x)", + } + b := NewBatch(nil, newCommands()) + for _, test := range tests { + r := []rune(test) + _, ok, err := b.readString(r, 1, len(test), '\'', 10) + assert.Falsef(t, ok, "ok for %s", test) + assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 10", "expected err for %s", test) + } +} + +func TestReadStringVarmap(t *testing.T) { + type mapTest struct { + s string + m map[int]string + } + tests := []mapTest{ + {`'var $(var1) var2 $(var2)'`, map[int]string{5: "var1", 18: "var2"}}, + {`'var $(va_1) var2 $(va-2)'`, map[int]string{5: "va_1", 18: "va-2"}}, + } + for _, test := range tests { + b := NewBatch(nil, newCommands()) + b.linevarmap = make(map[int]string) + i, ok, err := b.readString([]rune(test.s), 1, len(test.s), '\'', 0) + assert.Truef(t, ok, "ok returned by readString for %s", test.s) + assert.NoErrorf(t, err, "readString for %s", test.s) + assert.Equal(t, len(test.s)-1, i, "index returned by readString for %s", test.s) + assert.Equalf(t, test.m, b.linevarmap, "linevarmap after readString %s", test.s) + } +} + +func TestBatchNextVarMap(t *testing.T) { + type mapTest struct { + s string + m map[int]string + } + tests := []mapTest{ + {"'var $(var1)\nvar2 $(var2)\n'", map[int]string{5: "var1", 17 + len(SqlcmdEol): "var2"}}, + {"$(var1) select $(var2)\nselect 100\nselect '$(var3)'", map[int]string{ + 0: "var1", + 15: "var2", + 40 + 2*len(SqlcmdEol): "var3"}, + }, + } +loop: + for _, test := range tests { + var err error + b := NewBatch(sp(test.s, "\n"), newCommands()) + for { + _, _, err = b.Next() + if err == io.EOF { + assert.Equalf(t, test.m, b.varmap, "varmap after Next %s. Batch:%s", test.s, escapeeol(b.String())) + break loop + } else { + assert.NoErrorf(t, err, "Should have no error from Next") + } + } + } +} + +func escapeeol(s string) string { + return strings.Replace(strings.Replace(s, "\n", `\n`, -1), "\r", `\r`, -1) +} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 2a4c1b5b..9d2b3926 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -1,507 +1,507 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "fmt" - "os" - "regexp" - "sort" - "strconv" - "strings" - - "github.com/alecthomas/kong" - "golang.org/x/text/encoding/unicode" - "golang.org/x/text/transform" -) - -// Command defines a sqlcmd action which can be intermixed with the SQL batch -// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands -type Command struct { - // regex must include at least one group if it has parameters - // Will be matched using FindStringSubmatch - regex *regexp.Regexp - // The function that implements the command. Third parameter is the line number - action func(*Sqlcmd, []string, uint) error - // Name of the command - name string - // whether the command is a system command - isSystem bool -} - -// Commands is the set of sqlcmd command implementations -type Commands map[string]*Command - -func newCommands() Commands { - // Commands is the set of Command implementations - return map[string]*Command{ - "EXIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT(?:[ \t]*(\(?.*\)?$)|$)`), - action: exitCommand, - name: "EXIT", - }, - "QUIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), - action: quitCommand, - name: "QUIT", - }, - "GO": { - regex: regexp.MustCompile(batchTerminatorRegex("GO")), - action: goCommand, - name: "GO", - }, - "OUT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), - action: outCommand, - name: "OUT", - }, - "ERROR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), - action: errorCommand, - name: "ERROR", - }, "READFILE": { - regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), - action: readFileCommand, - name: "READFILE", - }, - "SETVAR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), - action: setVarCommand, - name: "SETVAR", - }, - "LISTVAR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), - action: listVarCommand, - name: "LISTVAR", - }, - "RESET": { - regex: regexp.MustCompile(`(?im)^[ \t]*:RESET(?:[ \t]+(.*$)|$)`), - action: resetCommand, - name: "RESET", - }, - "LIST": { - regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), - action: listCommand, - name: "LIST", - }, - "CONNECT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), - action: connectCommand, - name: "CONNECT", - }, - "EXEC": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), - action: execCommand, - name: "EXEC", - isSystem: true, - }, - "EDIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), - action: editCommand, - name: "EDIT", - isSystem: true, - }, - } -} - -// DisableSysCommands disables the ED and :!! commands. -// When exitOnCall is true, running those commands will exit the process. -func (c Commands) DisableSysCommands(exitOnCall bool) { - f := warnDisabled - if exitOnCall { - f = errorDisabled - } - for _, cmd := range c { - if cmd.isSystem { - cmd.action = f - } - } -} - -func (c Commands) matchCommand(line string) (*Command, []string) { - for _, cmd := range c { - matchedCommand := cmd.regex.FindStringSubmatch(line) - if matchedCommand != nil { - return cmd, matchedCommand[1:] - } - } - return nil, nil -} - -func warnDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - return nil -} - -func errorDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - s.Exitcode = 1 - return ErrExitRequested -} - -func batchTerminatorRegex(terminator string) string { - return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) -} - -// SetBatchTerminator attempts to set the batch terminator to the given value -// Returns an error if the new value is not usable in the regex -func (c Commands) SetBatchTerminator(terminator string) error { - cmd := c["GO"] - regex, err := regexp.Compile(batchTerminatorRegex(terminator)) - if err != nil { - return err - } - cmd.regex = regex - return nil -} - -// exitCommand has 3 modes. -// With no (), it just exits without running any query -// With () it runs whatever batch is in the buffer then exits -// With any text between () it runs the text as a query then exits -func exitCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return ErrExitRequested - } - params := strings.TrimSpace(args[0]) - if params == "" { - return ErrExitRequested - } - if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { - return InvalidCommandError("EXIT", line) - } - // First we run the current batch - query := s.batch.String() - if query != "" { - query = s.getRunnableQuery(query) - if exitCode, err := s.runQuery(query); err != nil { - s.Exitcode = exitCode - return ErrExitRequested - } - } - query = strings.TrimSpace(params[1 : len(params)-1]) - s.batch.Reset([]rune(query)) - _, _, err := s.batch.Next() - if err != nil { - return err - } - query = s.batch.String() - if s.batch.String() != "" { - query = s.getRunnableQuery(query) - s.Exitcode, _ = s.runQuery(query) - } - return ErrExitRequested -} - -// quitCommand immediately exits the program without running any more batches -func quitCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("QUIT", line) - } - return ErrExitRequested -} - -// goCommand runs the current batch the number of times specified -func goCommand(s *Sqlcmd, args []string, line uint) error { - // default to 1 execution - n := 1 - var err error - if len(args) > 0 { - cnt := strings.TrimSpace(args[0]) - if cnt != "" { - if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { - return err - } - _, err = fmt.Sscanf(cnt, "%d", &n) - } - } - if err != nil || n < 1 { - return InvalidCommandError("GO", line) - } - query := s.batch.String() - if query == "" { - return nil - } - query = s.getRunnableQuery(query) - for i := 0; i < n; i++ { - if retcode, err := s.runQuery(query); err != nil { - s.Exitcode = retcode - return err - } - } - s.batch.Reset(nil) - return nil -} - -// outCommand changes the output writer to use a file -func outCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) - } - switch { - case strings.EqualFold(args[0], "stdout"): - s.SetOutput(os.Stdout) - case strings.EqualFold(args[0], "stderr"): - s.SetOutput(os.Stderr) - default: - o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - if s.UnicodeOutputFile { - // ODBC sqlcmd doesn't write a BOM but we will. - // Maybe the endian-ness should be configurable. - win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) - encoder := transform.NewWriter(o, win16le.NewEncoder()) - s.SetOutput(encoder) - } else { - s.SetOutput(o) - } - } - return nil -} - -// errorCommand changes the error writer to use a file -func errorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) - } - switch { - case strings.EqualFold(args[0], "stderr"): - s.SetError(os.Stderr) - case strings.EqualFold(args[0], "stdout"): - s.SetError(os.Stdout) - default: - o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - s.SetError(o) - } - return nil -} - -func readFileCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 { - return InvalidCommandError(":R", line) - } - fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) - return s.IncludeFile(fileName, false) -} - -// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables -func setVarCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 || args[0] == "" { - return InvalidCommandError(":SETVAR", line) - } - - varname := args[0] - val := "" - // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value - // in some very unexpected cases. This version will require the space. - sp := strings.IndexRune(args[0], ' ') - if sp > -1 { - val = strings.TrimSpace(varname[sp:]) - varname = varname[:sp] - } - if err := s.vars.Setvar(varname, val); err != nil { - switch e := err.(type) { - case *VariableError: - return e - default: - return InvalidCommandError(":SETVAR", line) - } - } - return nil -} - -// listVarCommand prints the set of Sqlcmd scripting variables. -// Builtin values are printed first, followed by user-set values in sorted order. -func listVarCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("LISTVAR", line) - } - - vars := s.vars.All() - keys := make([]string, 0, len(vars)) - for k := range vars { - if !contains(builtinVariables, k) { - keys = append(keys, k) - } - } - sort.Strings(keys) - keys = append(builtinVariables, keys...) - for _, k := range keys { - fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) - } - return nil -} - -// resetCommand resets the statement cache -func resetCommand(s *Sqlcmd, args []string, line uint) error { - if s.batch != nil { - s.batch.Reset(nil) - } - - return nil -} - -// listCommand displays statements currently in the statement cache -func listCommand(s *Sqlcmd, args []string, line uint) error { - if s.batch != nil && s.batch.String() != "" { - fmt.Fprintf(s.GetOutput(), `%s%s`, []byte(s.batch.String()), SqlcmdEol) - } - - return nil -} - -type connectData struct { - Server string `arg:""` - Database string `short:"D"` - Username string `short:"U"` - Password string `short:"P"` - LoginTimeout string `short:"l"` - AuthenticationMethod string `short:"G"` -} - -func connectCommand(s *Sqlcmd, args []string, line uint) error { - - if len(args) == 0 { - return InvalidCommandError("CONNECT", line) - } - cmdLine := strings.TrimSpace(args[0]) - if cmdLine == "" { - return InvalidCommandError("CONNECT", line) - } - arguments := &connectData{} - parser, err := kong.New(arguments) - if err != nil { - return InvalidCommandError("CONNECT", line) - } - - // Fields removes extra whitespace. - // Note :connect doesn't support passwords with spaces - if _, err = parser.Parse(strings.Fields(cmdLine)); err != nil { - return InvalidCommandError("CONNECT", line) - } - - connect := *s.Connect - connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false) - connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false) - connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false) - timeout, _ := resolveArgumentVariables(s, []rune(arguments.LoginTimeout), false) - if timeout != "" { - if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { - if timeoutSeconds < 0 { - return InvalidCommandError("CONNECT", line) - } - connect.LoginTimeoutSeconds = int(timeoutSeconds) - } - } - connect.AuthenticationMethod = arguments.AuthenticationMethod - // If no user name is provided we switch to integrated auth - _ = s.ConnectDb(&connect, s.lineIo == nil) - // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option - return nil -} - -func execCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("EXEC", line) - } - cmdLine := strings.TrimSpace(args[0]) - if cmdLine == "" { - return InvalidCommandError("EXEC", line) - } - if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { - return err - } else { - cmd := sysCommand(cmdLine) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - _ = cmd.Run() - } - return nil -} - -func editCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("ED", line) - } - file, err := os.CreateTemp("", "sq*.sql") - if err != nil { - return err - } - fileName := file.Name() - defer os.Remove(fileName) - text := s.batch.String() - if s.batch.State() == "-" { - text = fmt.Sprintf("%s%s", text, SqlcmdEol) - } - _, err = file.WriteString(text) - if err != nil { - return err - } - file.Close() - cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - err = cmd.Run() - if err != nil { - return err - } - wasEcho := s.echoFileLines - s.echoFileLines = true - s.batch.Reset(nil) - _ = s.IncludeFile(fileName, false) - s.echoFileLines = wasEcho - return nil -} - -func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { - var b *strings.Builder - end := len(arg) - for i := 0; i < end; { - c, next := arg[i], grab(arg, i+1, end) - switch { - case c == '$' && next == '(': - vl, ok := readVariableReference(arg, i+2, end) - if ok { - varName := string(arg[i+2 : vl]) - val, ok := s.resolveVariable(varName) - if ok { - if b == nil { - b = new(strings.Builder) - b.Grow(len(arg)) - b.WriteString(string(arg[0:i])) - } - b.WriteString(val) - } else { - if failOnUnresolved { - return "", UndefinedVariable(varName) - } - s.WriteError(s.GetError(), UndefinedVariable(varName)) - if b != nil { - b.WriteString(string(arg[i : vl+1])) - } - } - i += ((vl - i) + 1) - } else { - if b != nil { - b.WriteString("$(") - } - i += 2 - } - default: - if b != nil { - b.WriteRune(c) - } - i++ - } - } - if b == nil { - return string(arg), nil - } - return b.String(), nil -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "fmt" + "os" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/alecthomas/kong" + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +// Command defines a sqlcmd action which can be intermixed with the SQL batch +// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands +type Command struct { + // regex must include at least one group if it has parameters + // Will be matched using FindStringSubmatch + regex *regexp.Regexp + // The function that implements the command. Third parameter is the line number + action func(*Sqlcmd, []string, uint) error + // Name of the command + name string + // whether the command is a system command + isSystem bool +} + +// Commands is the set of sqlcmd command implementations +type Commands map[string]*Command + +func newCommands() Commands { + // Commands is the set of Command implementations + return map[string]*Command{ + "EXIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT(?:[ \t]*(\(?.*\)?$)|$)`), + action: exitCommand, + name: "EXIT", + }, + "QUIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), + action: quitCommand, + name: "QUIT", + }, + "GO": { + regex: regexp.MustCompile(batchTerminatorRegex("GO")), + action: goCommand, + name: "GO", + }, + "OUT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), + action: outCommand, + name: "OUT", + }, + "ERROR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), + action: errorCommand, + name: "ERROR", + }, "READFILE": { + regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), + action: readFileCommand, + name: "READFILE", + }, + "SETVAR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), + action: setVarCommand, + name: "SETVAR", + }, + "LISTVAR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), + action: listVarCommand, + name: "LISTVAR", + }, + "RESET": { + regex: regexp.MustCompile(`(?im)^[ \t]*:RESET(?:[ \t]+(.*$)|$)`), + action: resetCommand, + name: "RESET", + }, + "LIST": { + regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), + action: listCommand, + name: "LIST", + }, + "CONNECT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), + action: connectCommand, + name: "CONNECT", + }, + "EXEC": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), + action: execCommand, + name: "EXEC", + isSystem: true, + }, + "EDIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), + action: editCommand, + name: "EDIT", + isSystem: true, + }, + } +} + +// DisableSysCommands disables the ED and :!! commands. +// When exitOnCall is true, running those commands will exit the process. +func (c Commands) DisableSysCommands(exitOnCall bool) { + f := warnDisabled + if exitOnCall { + f = errorDisabled + } + for _, cmd := range c { + if cmd.isSystem { + cmd.action = f + } + } +} + +func (c Commands) matchCommand(line string) (*Command, []string) { + for _, cmd := range c { + matchedCommand := cmd.regex.FindStringSubmatch(line) + if matchedCommand != nil { + return cmd, matchedCommand[1:] + } + } + return nil, nil +} + +func warnDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + return nil +} + +func errorDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + s.Exitcode = 1 + return ErrExitRequested +} + +func batchTerminatorRegex(terminator string) string { + return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) +} + +// SetBatchTerminator attempts to set the batch terminator to the given value +// Returns an error if the new value is not usable in the regex +func (c Commands) SetBatchTerminator(terminator string) error { + cmd := c["GO"] + regex, err := regexp.Compile(batchTerminatorRegex(terminator)) + if err != nil { + return err + } + cmd.regex = regex + return nil +} + +// exitCommand has 3 modes. +// With no (), it just exits without running any query +// With () it runs whatever batch is in the buffer then exits +// With any text between () it runs the text as a query then exits +func exitCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return ErrExitRequested + } + params := strings.TrimSpace(args[0]) + if params == "" { + return ErrExitRequested + } + if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { + return InvalidCommandError("EXIT", line) + } + // First we run the current batch + query := s.batch.String() + if query != "" { + query = s.getRunnableQuery(query) + if exitCode, err := s.runQuery(query); err != nil { + s.Exitcode = exitCode + return ErrExitRequested + } + } + query = strings.TrimSpace(params[1 : len(params)-1]) + s.batch.Reset([]rune(query)) + _, _, err := s.batch.Next() + if err != nil { + return err + } + query = s.batch.String() + if s.batch.String() != "" { + query = s.getRunnableQuery(query) + s.Exitcode, _ = s.runQuery(query) + } + return ErrExitRequested +} + +// quitCommand immediately exits the program without running any more batches +func quitCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("QUIT", line) + } + return ErrExitRequested +} + +// goCommand runs the current batch the number of times specified +func goCommand(s *Sqlcmd, args []string, line uint) error { + // default to 1 execution + n := 1 + var err error + if len(args) > 0 { + cnt := strings.TrimSpace(args[0]) + if cnt != "" { + if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { + return err + } + _, err = fmt.Sscanf(cnt, "%d", &n) + } + } + if err != nil || n < 1 { + return InvalidCommandError("GO", line) + } + query := s.batch.String() + if query == "" { + return nil + } + query = s.getRunnableQuery(query) + for i := 0; i < n; i++ { + if retcode, err := s.runQuery(query); err != nil { + s.Exitcode = retcode + return err + } + } + s.batch.Reset(nil) + return nil +} + +// outCommand changes the output writer to use a file +func outCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("OUT", line) + } + switch { + case strings.EqualFold(args[0], "stdout"): + s.SetOutput(os.Stdout) + case strings.EqualFold(args[0], "stderr"): + s.SetOutput(os.Stderr) + default: + o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + if s.UnicodeOutputFile { + // ODBC sqlcmd doesn't write a BOM but we will. + // Maybe the endian-ness should be configurable. + win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) + encoder := transform.NewWriter(o, win16le.NewEncoder()) + s.SetOutput(encoder) + } else { + s.SetOutput(o) + } + } + return nil +} + +// errorCommand changes the error writer to use a file +func errorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("OUT", line) + } + switch { + case strings.EqualFold(args[0], "stderr"): + s.SetError(os.Stderr) + case strings.EqualFold(args[0], "stdout"): + s.SetError(os.Stdout) + default: + o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + s.SetError(o) + } + return nil +} + +func readFileCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 { + return InvalidCommandError(":R", line) + } + fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) + return s.IncludeFile(fileName, false) +} + +// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables +func setVarCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 || args[0] == "" { + return InvalidCommandError(":SETVAR", line) + } + + varname := args[0] + val := "" + // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value + // in some very unexpected cases. This version will require the space. + sp := strings.IndexRune(args[0], ' ') + if sp > -1 { + val = strings.TrimSpace(varname[sp:]) + varname = varname[:sp] + } + if err := s.vars.Setvar(varname, val); err != nil { + switch e := err.(type) { + case *VariableError: + return e + default: + return InvalidCommandError(":SETVAR", line) + } + } + return nil +} + +// listVarCommand prints the set of Sqlcmd scripting variables. +// Builtin values are printed first, followed by user-set values in sorted order. +func listVarCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("LISTVAR", line) + } + + vars := s.vars.All() + keys := make([]string, 0, len(vars)) + for k := range vars { + if !contains(builtinVariables, k) { + keys = append(keys, k) + } + } + sort.Strings(keys) + keys = append(builtinVariables, keys...) + for _, k := range keys { + fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) + } + return nil +} + +// resetCommand resets the statement cache +func resetCommand(s *Sqlcmd, args []string, line uint) error { + if s.batch != nil { + s.batch.Reset(nil) + } + + return nil +} + +// listCommand displays statements currently in the statement cache +func listCommand(s *Sqlcmd, args []string, line uint) error { + if s.batch != nil && s.batch.String() != "" { + fmt.Fprintf(s.GetOutput(), `%s%s`, []byte(s.batch.String()), SqlcmdEol) + } + + return nil +} + +type connectData struct { + Server string `arg:""` + Database string `short:"D"` + Username string `short:"U"` + Password string `short:"P"` + LoginTimeout string `short:"l"` + AuthenticationMethod string `short:"G"` +} + +func connectCommand(s *Sqlcmd, args []string, line uint) error { + + if len(args) == 0 { + return InvalidCommandError("CONNECT", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { + return InvalidCommandError("CONNECT", line) + } + arguments := &connectData{} + parser, err := kong.New(arguments) + if err != nil { + return InvalidCommandError("CONNECT", line) + } + + // Fields removes extra whitespace. + // Note :connect doesn't support passwords with spaces + if _, err = parser.Parse(strings.Fields(cmdLine)); err != nil { + return InvalidCommandError("CONNECT", line) + } + + connect := *s.Connect + connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false) + connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false) + connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false) + timeout, _ := resolveArgumentVariables(s, []rune(arguments.LoginTimeout), false) + if timeout != "" { + if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { + if timeoutSeconds < 0 { + return InvalidCommandError("CONNECT", line) + } + connect.LoginTimeoutSeconds = int(timeoutSeconds) + } + } + connect.AuthenticationMethod = arguments.AuthenticationMethod + // If no user name is provided we switch to integrated auth + _ = s.ConnectDb(&connect, s.lineIo == nil) + // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option + return nil +} + +func execCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("EXEC", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { + return InvalidCommandError("EXEC", line) + } + if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { + return err + } else { + cmd := sysCommand(cmdLine) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + _ = cmd.Run() + } + return nil +} + +func editCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("ED", line) + } + file, err := os.CreateTemp("", "sq*.sql") + if err != nil { + return err + } + fileName := file.Name() + defer os.Remove(fileName) + text := s.batch.String() + if s.batch.State() == "-" { + text = fmt.Sprintf("%s%s", text, SqlcmdEol) + } + _, err = file.WriteString(text) + if err != nil { + return err + } + file.Close() + cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + err = cmd.Run() + if err != nil { + return err + } + wasEcho := s.echoFileLines + s.echoFileLines = true + s.batch.Reset(nil) + _ = s.IncludeFile(fileName, false) + s.echoFileLines = wasEcho + return nil +} + +func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { + var b *strings.Builder + end := len(arg) + for i := 0; i < end; { + c, next := arg[i], grab(arg, i+1, end) + switch { + case c == '$' && next == '(': + vl, ok := readVariableReference(arg, i+2, end) + if ok { + varName := string(arg[i+2 : vl]) + val, ok := s.resolveVariable(varName) + if ok { + if b == nil { + b = new(strings.Builder) + b.Grow(len(arg)) + b.WriteString(string(arg[0:i])) + } + b.WriteString(val) + } else { + if failOnUnresolved { + return "", UndefinedVariable(varName) + } + s.WriteError(s.GetError(), UndefinedVariable(varName)) + if b != nil { + b.WriteString(string(arg[i : vl+1])) + } + } + i += ((vl - i) + 1) + } else { + if b != nil { + b.WriteString("$(") + } + i += 2 + } + default: + if b != nil { + b.WriteRune(c) + } + i++ + } + } + if b == nil { + return string(arg), nil + } + return b.String(), nil +} diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 584867ba..c5aa4feb 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -1,302 +1,302 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "bytes" - "fmt" - "os" - "strings" - "testing" - - "github.com/microsoft/go-mssqldb/azuread" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestQuitCommand(t *testing.T) { - s := &Sqlcmd{} - err := quitCommand(s, nil, 1) - require.ErrorIs(t, err, ErrExitRequested) - err = quitCommand(s, []string{"extra parameters"}, 2) - require.Error(t, err, "Quit should error out with extra parameters") - assert.NotErrorIs(t, err, ErrExitRequested, "Error with extra arguments") -} - -func TestCommandParsing(t *testing.T) { - type commandTest struct { - line string - cmd string - args []string - } - c := newCommands() - commands := []commandTest{ - {"quite", "", nil}, - {"quit", "QUIT", []string{""}}, - {":QUIT\n", "QUIT", []string{""}}, - {" QUIT \n", "QUIT", []string{""}}, - {"quit extra\n", "QUIT", []string{"extra"}}, - {`:Out c:\folder\file`, "OUT", []string{`c:\folder\file`}}, - {` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}}, - {`:Setvar A1 "some value" `, "SETVAR", []string{`A1 "some value" `}}, - {` :Listvar`, "LISTVAR", []string{""}}, - {`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}}, - {`:EXIT ( )`, "EXIT", []string{"( )"}}, - {`EXIT `, "EXIT", []string{""}}, - {`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}}, - {`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}}, - {`:!! notepad`, "EXEC", []string{" notepad"}}, - {`:!!notepad`, "EXEC", []string{"notepad"}}, - {` !! dir c:\`, "EXEC", []string{` dir c:\`}}, - {`!!dir c:\`, "EXEC", []string{`dir c:\`}}, - } - - for _, test := range commands { - cmd, args := c.matchCommand(test.line) - if test.cmd != "" { - if assert.NotNil(t, cmd, "No command found for `%s`", test.line) { - assert.Equal(t, test.cmd, cmd.name, "Incorrect command for `%s`", test.line) - assert.Equal(t, test.args, args, "Incorrect arguments for `%s`", test.line) - } - } else { - assert.Nil(t, cmd, "Unexpected match for %s", test.line) - } - } -} - -func TestCustomBatchSeparator(t *testing.T) { - c := newCommands() - err := c.SetBatchTerminator("me!") - if assert.NoError(t, err, "SetBatchTerminator should succeed") { - cmd, args := c.matchCommand(" me! 5 \n") - if assert.NotNil(t, cmd, "matchCommand didn't find GO for custom batch separator") { - assert.Equal(t, "GO", cmd.name, "command name") - assert.Equal(t, "5", strings.TrimSpace(args[0]), "go argument") - } - } -} - -func TestVarCommands(t *testing.T) { - vars := InitializeVariables(false) - s := New(nil, "", vars) - buf := &memoryBuffer{buf: new(bytes.Buffer)} - s.SetOutput(buf) - err := setVarCommand(s, []string{"ABC 100"}, 1) - assert.NoError(t, err, "setVarCommand ABC 100") - err = setVarCommand(s, []string{"XYZ 200"}, 2) - assert.NoError(t, err, "setVarCommand XYZ 200") - err = listVarCommand(s, []string{""}, 3) - assert.NoError(t, err, "listVarCommand") - s.SetOutput(nil) - varmap := s.vars.All() - o := buf.buf.String() - t.Logf("Listvar output:\n'%s'", o) - output := strings.Split(o, SqlcmdEol) - for i, v := range builtinVariables { - line := strings.Split(output[i], " = ") - assert.Equalf(t, v, line[0], "unexpected variable printed at index %d", i) - val := strings.Trim(line[1], `"`) - assert.Equalf(t, varmap[v], val, "Unexpected value for variable %s", v) - } - assert.Equalf(t, `ABC = "100"`, output[len(output)-3], "Penultimate non-empty line should be ABC") - assert.Equalf(t, `XYZ = "200"`, output[len(output)-2], "Last non-empty line should be XYZ") - assert.Equalf(t, "", output[len(output)-1], "Last line should be empty") - -} - -// memoryBuffer has both Write and Close methods for use as io.WriteCloser -type memoryBuffer struct { - buf *bytes.Buffer -} - -func (b *memoryBuffer) Write(p []byte) (n int, err error) { - return b.buf.Write(p) -} - -func (b *memoryBuffer) Close() error { - return nil -} - -func TestResetCommand(t *testing.T) { - var err error - - // setup a test sqlcmd - vars := InitializeVariables(false) - s := New(nil, "", vars) - buf := &memoryBuffer{buf: new(bytes.Buffer)} - s.SetOutput(buf) - - // insert a test batch - s.batch.Reset([]rune("select 1")) - _, _, err = s.batch.Next() - assert.NoError(t, err, "Inserting test batch") - assert.Equal(t, s.batch.batchline, int(2), "Batch line updated after test batch insert") - - // execute reset command and validate results - err = resetCommand(s, nil, 1) - assert.Equal(t, s.batch.batchline, int(1), "Batch line not reset properly") - assert.NoError(t, err, "Executing :reset command") -} - -func TestListCommand(t *testing.T) { - var err error - - // setup a test sqlcmd - vars := InitializeVariables(false) - s := New(nil, "", vars) - buf := &memoryBuffer{buf: new(bytes.Buffer)} - s.SetOutput(buf) - - // insert test batch - s.batch.Reset([]rune("select 1")) - _, _, err = s.batch.Next() - assert.NoError(t, err, "Inserting test batch") - - // execute list command and verify results - err = listCommand(s, nil, 1) - assert.NoError(t, err, "Executing :list command") - s.SetOutput(nil) - o := buf.buf.String() - assert.Equal(t, o, "select 1"+SqlcmdEol, ":list output not equal to batch") -} - -func TestConnectCommand(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - prompted := false - s.lineIo = &testConsole{ - OnPasswordPrompt: func(prompt string) ([]byte, error) { - prompted = true - return []byte{}, nil - }, - } - err := connectCommand(s, []string{"someserver -U someuser"}, 1) - assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure") - assert.True(t, prompted, "connectCommand with user name and no password should prompt for password") - assert.NotEqual(t, "someserver", s.Connect.ServerName, "On connection failure, sqlCmd.Connect does not copy inputs") - - err = connectCommand(s, []string{}, 2) - assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error") - c := newConnect(t) - - authenticationMethod := "" - password := "" - username := "" - if canTestAzureAuth() { - authenticationMethod = "-G " + azuread.ActiveDirectoryDefault - } - if c.Password != "" { - password = "-P " + c.Password - } - if c.UserName != "" { - username = "-U " + c.UserName - } - s.vars.Set("servername", c.ServerName) - s.vars.Set("to", "111") - buf.buf.Reset() - err = connectCommand(s, []string{fmt.Sprintf("$(servername) %s %s %s -l $(to)", username, password, authenticationMethod)}, 3) - if assert.NoError(t, err, "connectCommand with valid parameters should not return an error") { - // not using assert to avoid printing passwords in the log - assert.NotContains(t, buf.buf.String(), "$(servername)", "ConnectDB should have succeeded") - if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password || s.Connect.LoginTimeoutSeconds != 111 { - t.Fatalf("After connect, sqlCmd.Connect is not updated %+v", s.Connect) - } - } -} - -func TestErrorCommand(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - file, err := os.CreateTemp("", "sqlcmderr") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(file.Name()) - fileName := file.Name() - _ = file.Close() - err = errorCommand(s, []string{""}, 1) - assert.EqualError(t, err, InvalidCommandError("OUT", 1).Error(), "errorCommand with empty file name") - err = errorCommand(s, []string{fileName}, 1) - assert.NoError(t, err, "errorCommand") - // Only some error kinds go to the error output - err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"}) - assert.NoError(t, err, "runSqlCmd") - s.SetError(nil) - errText, err := os.ReadFile(file.Name()) - if assert.NoError(t, err, "ReadFile") { - assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents") - } -} - -func TestResolveArgumentVariables(t *testing.T) { - type argTest struct { - arg string - val string - err string - } - - args := []argTest{ - {"$(var1)", "var1val", ""}, - {"$(var1", "$(var1", ""}, - {`C:\folder\$(var1)\$(var2)\$(var1)\file.sql`, `C:\folder\var1val\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."}, - {`C:\folder\$(var1\$(var2)\$(var1)\file.sql`, `C:\folder\$(var1\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."}, - } - vars := InitializeVariables(false) - s := New(nil, "", vars) - s.vars.Set("var1", "var1val") - buf := &memoryBuffer{buf: new(bytes.Buffer)} - defer buf.Close() - s.SetError(buf) - for _, test := range args { - actual, _ := resolveArgumentVariables(s, []rune(test.arg), false) - assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg) - assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg) - buf.buf.Reset() - } - actual, err := resolveArgumentVariables(s, []rune("$(var1)$(var2)"), true) - if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") { - assert.Empty(t, actual, "fail on unresolved variable") - } -} - -func TestExecCommand(t *testing.T) { - vars := InitializeVariables(false) - s := New(nil, "", vars) - s.vars.Set("var1", "hello") - buf := &memoryBuffer{buf: new(bytes.Buffer)} - defer buf.Close() - s.SetOutput(buf) - err := execCommand(s, []string{`echo $(var1)`}, 1) - if assert.NoError(t, err, "execCommand with valid arguments") { - assert.Equal(t, buf.buf.String(), "hello"+SqlcmdEol, "echo output should be in sqlcmd output") - } -} - -func TestDisableSysCommandBlocksExec(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.Cmd.DisableSysCommands(false) - c := []string{"set nocount on", ":!! echo hello", "select 100", "go"} - err := runSqlCmd(t, s, c) - if assert.NoError(t, err, ":!! with warning should not raise error") { - assert.Contains(t, buf.buf.String(), ErrCommandsDisabled.Error()+SqlcmdEol+"100"+SqlcmdEol) - assert.Equal(t, 0, s.Exitcode, "ExitCode after warning") - } - buf.buf.Reset() - s.Cmd.DisableSysCommands(true) - err = runSqlCmd(t, s, c) - if assert.NoError(t, err, ":!! with error should not return error") { - assert.Contains(t, buf.buf.String(), ErrCommandsDisabled.Error()+SqlcmdEol) - assert.NotContains(t, buf.buf.String(), "100", "query should not run when syscommand disabled") - assert.Equal(t, 1, s.Exitcode, "ExitCode after error") - } -} - -func TestEditCommand(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.vars.Set(SQLCMDEDITOR, "echo select 5000> ") - c := []string{"set nocount on", "go", "select 100", ":ed", "go"} - err := runSqlCmd(t, s, c) - if assert.NoError(t, err, ":ed should not raise error") { - assert.Equal(t, "1> select 5000"+SqlcmdEol+"5000"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output from query after :ed command") - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "bytes" + "fmt" + "os" + "strings" + "testing" + + "github.com/microsoft/go-mssqldb/azuread" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQuitCommand(t *testing.T) { + s := &Sqlcmd{} + err := quitCommand(s, nil, 1) + require.ErrorIs(t, err, ErrExitRequested) + err = quitCommand(s, []string{"extra parameters"}, 2) + require.Error(t, err, "Quit should error out with extra parameters") + assert.NotErrorIs(t, err, ErrExitRequested, "Error with extra arguments") +} + +func TestCommandParsing(t *testing.T) { + type commandTest struct { + line string + cmd string + args []string + } + c := newCommands() + commands := []commandTest{ + {"quite", "", nil}, + {"quit", "QUIT", []string{""}}, + {":QUIT\n", "QUIT", []string{""}}, + {" QUIT \n", "QUIT", []string{""}}, + {"quit extra\n", "QUIT", []string{"extra"}}, + {`:Out c:\folder\file`, "OUT", []string{`c:\folder\file`}}, + {` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}}, + {`:Setvar A1 "some value" `, "SETVAR", []string{`A1 "some value" `}}, + {` :Listvar`, "LISTVAR", []string{""}}, + {`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}}, + {`:EXIT ( )`, "EXIT", []string{"( )"}}, + {`EXIT `, "EXIT", []string{""}}, + {`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}}, + {`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}}, + {`:!! notepad`, "EXEC", []string{" notepad"}}, + {`:!!notepad`, "EXEC", []string{"notepad"}}, + {` !! dir c:\`, "EXEC", []string{` dir c:\`}}, + {`!!dir c:\`, "EXEC", []string{`dir c:\`}}, + } + + for _, test := range commands { + cmd, args := c.matchCommand(test.line) + if test.cmd != "" { + if assert.NotNil(t, cmd, "No command found for `%s`", test.line) { + assert.Equal(t, test.cmd, cmd.name, "Incorrect command for `%s`", test.line) + assert.Equal(t, test.args, args, "Incorrect arguments for `%s`", test.line) + } + } else { + assert.Nil(t, cmd, "Unexpected match for %s", test.line) + } + } +} + +func TestCustomBatchSeparator(t *testing.T) { + c := newCommands() + err := c.SetBatchTerminator("me!") + if assert.NoError(t, err, "SetBatchTerminator should succeed") { + cmd, args := c.matchCommand(" me! 5 \n") + if assert.NotNil(t, cmd, "matchCommand didn't find GO for custom batch separator") { + assert.Equal(t, "GO", cmd.name, "command name") + assert.Equal(t, "5", strings.TrimSpace(args[0]), "go argument") + } + } +} + +func TestVarCommands(t *testing.T) { + vars := InitializeVariables(false) + s := New(nil, "", vars) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + err := setVarCommand(s, []string{"ABC 100"}, 1) + assert.NoError(t, err, "setVarCommand ABC 100") + err = setVarCommand(s, []string{"XYZ 200"}, 2) + assert.NoError(t, err, "setVarCommand XYZ 200") + err = listVarCommand(s, []string{""}, 3) + assert.NoError(t, err, "listVarCommand") + s.SetOutput(nil) + varmap := s.vars.All() + o := buf.buf.String() + t.Logf("Listvar output:\n'%s'", o) + output := strings.Split(o, SqlcmdEol) + for i, v := range builtinVariables { + line := strings.Split(output[i], " = ") + assert.Equalf(t, v, line[0], "unexpected variable printed at index %d", i) + val := strings.Trim(line[1], `"`) + assert.Equalf(t, varmap[v], val, "Unexpected value for variable %s", v) + } + assert.Equalf(t, `ABC = "100"`, output[len(output)-3], "Penultimate non-empty line should be ABC") + assert.Equalf(t, `XYZ = "200"`, output[len(output)-2], "Last non-empty line should be XYZ") + assert.Equalf(t, "", output[len(output)-1], "Last line should be empty") + +} + +// memoryBuffer has both Write and Close methods for use as io.WriteCloser +type memoryBuffer struct { + buf *bytes.Buffer +} + +func (b *memoryBuffer) Write(p []byte) (n int, err error) { + return b.buf.Write(p) +} + +func (b *memoryBuffer) Close() error { + return nil +} + +func TestResetCommand(t *testing.T) { + var err error + + // setup a test sqlcmd + vars := InitializeVariables(false) + s := New(nil, "", vars) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + + // insert a test batch + s.batch.Reset([]rune("select 1")) + _, _, err = s.batch.Next() + assert.NoError(t, err, "Inserting test batch") + assert.Equal(t, s.batch.batchline, int(2), "Batch line updated after test batch insert") + + // execute reset command and validate results + err = resetCommand(s, nil, 1) + assert.Equal(t, s.batch.batchline, int(1), "Batch line not reset properly") + assert.NoError(t, err, "Executing :reset command") +} + +func TestListCommand(t *testing.T) { + var err error + + // setup a test sqlcmd + vars := InitializeVariables(false) + s := New(nil, "", vars) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + + // insert test batch + s.batch.Reset([]rune("select 1")) + _, _, err = s.batch.Next() + assert.NoError(t, err, "Inserting test batch") + + // execute list command and verify results + err = listCommand(s, nil, 1) + assert.NoError(t, err, "Executing :list command") + s.SetOutput(nil) + o := buf.buf.String() + assert.Equal(t, o, "select 1"+SqlcmdEol, ":list output not equal to batch") +} + +func TestConnectCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + prompted := false + s.lineIo = &testConsole{ + OnPasswordPrompt: func(prompt string) ([]byte, error) { + prompted = true + return []byte{}, nil + }, + } + err := connectCommand(s, []string{"someserver -U someuser"}, 1) + assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure") + assert.True(t, prompted, "connectCommand with user name and no password should prompt for password") + assert.NotEqual(t, "someserver", s.Connect.ServerName, "On connection failure, sqlCmd.Connect does not copy inputs") + + err = connectCommand(s, []string{}, 2) + assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error") + c := newConnect(t) + + authenticationMethod := "" + password := "" + username := "" + if canTestAzureAuth() { + authenticationMethod = "-G " + azuread.ActiveDirectoryDefault + } + if c.Password != "" { + password = "-P " + c.Password + } + if c.UserName != "" { + username = "-U " + c.UserName + } + s.vars.Set("servername", c.ServerName) + s.vars.Set("to", "111") + buf.buf.Reset() + err = connectCommand(s, []string{fmt.Sprintf("$(servername) %s %s %s -l $(to)", username, password, authenticationMethod)}, 3) + if assert.NoError(t, err, "connectCommand with valid parameters should not return an error") { + // not using assert to avoid printing passwords in the log + assert.NotContains(t, buf.buf.String(), "$(servername)", "ConnectDB should have succeeded") + if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password || s.Connect.LoginTimeoutSeconds != 111 { + t.Fatalf("After connect, sqlCmd.Connect is not updated %+v", s.Connect) + } + } +} + +func TestErrorCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + file, err := os.CreateTemp("", "sqlcmderr") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(file.Name()) + fileName := file.Name() + _ = file.Close() + err = errorCommand(s, []string{""}, 1) + assert.EqualError(t, err, InvalidCommandError("OUT", 1).Error(), "errorCommand with empty file name") + err = errorCommand(s, []string{fileName}, 1) + assert.NoError(t, err, "errorCommand") + // Only some error kinds go to the error output + err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"}) + assert.NoError(t, err, "runSqlCmd") + s.SetError(nil) + errText, err := os.ReadFile(file.Name()) + if assert.NoError(t, err, "ReadFile") { + assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents") + } +} + +func TestResolveArgumentVariables(t *testing.T) { + type argTest struct { + arg string + val string + err string + } + + args := []argTest{ + {"$(var1)", "var1val", ""}, + {"$(var1", "$(var1", ""}, + {`C:\folder\$(var1)\$(var2)\$(var1)\file.sql`, `C:\folder\var1val\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."}, + {`C:\folder\$(var1\$(var2)\$(var1)\file.sql`, `C:\folder\$(var1\$(var2)\var1val\file.sql`, "Sqlcmd: Error: 'var2' scripting variable not defined."}, + } + vars := InitializeVariables(false) + s := New(nil, "", vars) + s.vars.Set("var1", "var1val") + buf := &memoryBuffer{buf: new(bytes.Buffer)} + defer buf.Close() + s.SetError(buf) + for _, test := range args { + actual, _ := resolveArgumentVariables(s, []rune(test.arg), false) + assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg) + assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg) + buf.buf.Reset() + } + actual, err := resolveArgumentVariables(s, []rune("$(var1)$(var2)"), true) + if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") { + assert.Empty(t, actual, "fail on unresolved variable") + } +} + +func TestExecCommand(t *testing.T) { + vars := InitializeVariables(false) + s := New(nil, "", vars) + s.vars.Set("var1", "hello") + buf := &memoryBuffer{buf: new(bytes.Buffer)} + defer buf.Close() + s.SetOutput(buf) + err := execCommand(s, []string{`echo $(var1)`}, 1) + if assert.NoError(t, err, "execCommand with valid arguments") { + assert.Equal(t, buf.buf.String(), "hello"+SqlcmdEol, "echo output should be in sqlcmd output") + } +} + +func TestDisableSysCommandBlocksExec(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.Cmd.DisableSysCommands(false) + c := []string{"set nocount on", ":!! echo hello", "select 100", "go"} + err := runSqlCmd(t, s, c) + if assert.NoError(t, err, ":!! with warning should not raise error") { + assert.Contains(t, buf.buf.String(), ErrCommandsDisabled.Error()+SqlcmdEol+"100"+SqlcmdEol) + assert.Equal(t, 0, s.Exitcode, "ExitCode after warning") + } + buf.buf.Reset() + s.Cmd.DisableSysCommands(true) + err = runSqlCmd(t, s, c) + if assert.NoError(t, err, ":!! with error should not return error") { + assert.Contains(t, buf.buf.String(), ErrCommandsDisabled.Error()+SqlcmdEol) + assert.NotContains(t, buf.buf.String(), "100", "query should not run when syscommand disabled") + assert.Equal(t, 1, s.Exitcode, "ExitCode after error") + } +} + +func TestEditCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.vars.Set(SQLCMDEDITOR, "echo select 5000> ") + c := []string{"set nocount on", "go", "select 100", ":ed", "go"} + err := runSqlCmd(t, s, c) + if assert.NoError(t, err, ":ed should not raise error") { + assert.Equal(t, "1> select 5000"+SqlcmdEol+"5000"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output from query after :ed command") + } +} diff --git a/pkg/sqlcmd/errors.go b/pkg/sqlcmd/errors.go index f88ca051..f6bc1c01 100644 --- a/pkg/sqlcmd/errors.go +++ b/pkg/sqlcmd/errors.go @@ -1,156 +1,156 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "errors" - "fmt" - "strings" -) - -// ErrorPrefix is the prefix for all sqlcmd-generated errors -const ErrorPrefix = "Sqlcmd: Error: " - -// WarningPrefix is the prefix for all sqlcmd-generated warnings -const WarningPrefix = "Sqlcmd: Warning: " - -// Common Sqlcmd error messages -const ErrCmdDisabled = "ED and !! commands, startup script, and environment variables are disabled" - -type SqlcmdError interface { - error - IsSqlcmdErr() bool -} - -type CommonSqlcmdErr struct { - message string -} - -func (e *CommonSqlcmdErr) Error() string { - return e.message -} - -func (e *CommonSqlcmdErr) IsSqlcmdErr() bool { - return true -} - -// ArgumentError is related to command line switch validation not handled by kong -type ArgumentError struct { - Parameter string - Rule string -} - -func (e *ArgumentError) Error() string { - return ErrorPrefix + e.Rule -} - -func (e *ArgumentError) IsSqlcmdErr() bool { - return true -} - -// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format -var InvalidServerName = ArgumentError{ - Parameter: "server", - Rule: "server must be of the form [tcp]:server[[/instance]|[,port]]", -} - -// VariableError is an error about scripting variables -type VariableError struct { - Variable string - MessageFormat string -} - -func (e *VariableError) Error() string { - return ErrorPrefix + fmt.Sprintf(e.MessageFormat, e.Variable) -} - -func (e *VariableError) IsSqlcmdErr() bool { - return true -} - -// ReadOnlyVariable indicates the user tried to set a value to a read-only variable -func ReadOnlyVariable(variable string) *VariableError { - return &VariableError{ - Variable: variable, - MessageFormat: "The scripting variable: '%s' is read-only", - } -} - -// UndefinedVariable indicates the user tried to reference an undefined variable -func UndefinedVariable(variable string) *VariableError { - return &VariableError{ - Variable: variable, - MessageFormat: "'%s' scripting variable not defined.", - } -} - -// InvalidVariableValue indicates the variable was set to an invalid value -func InvalidVariableValue(variable string, value string) *VariableError { - return &VariableError{ - Variable: variable, - MessageFormat: "The environment variable: '%s' has invalid value: '" + strings.ReplaceAll(value, `%`, `%%`) + "'.", - } -} - -// CommandError indicates syntax errors for specific sqlcmd commands -type CommandError struct { - Command string - LineNumber uint -} - -func (e *CommandError) Error() string { - return ErrorPrefix + fmt.Sprintf("Syntax error at line %d near command '%s'.", e.LineNumber, e.Command) -} - -func (e *CommandError) IsSqlcmdErr() bool { - return true -} - -// InvalidCommandError creates a SQLCmdCommandError -func InvalidCommandError(command string, lineNumber uint) *CommandError { - return &CommandError{ - Command: command, - LineNumber: lineNumber, - } -} - -type FileError struct { - err error - path string -} - -func (e *FileError) Error() string { - return e.err.Error() -} - -func (e *FileError) IsSqlcmdErr() bool { - return true -} - -// InvalidFileError indicates a file could not be opened -func InvalidFileError(err error, filepath string) error { - return &FileError{ - err: errors.New(ErrorPrefix + " Error occurred while opening or operating on file " + filepath + " (Reason: " + err.Error() + ")."), - path: filepath, - } -} - -type SyntaxError struct { - err error -} - -func (e *SyntaxError) Error() string { - return e.err.Error() -} - -func (e *SyntaxError) IsSqlcmdErr() bool { - return true -} - -// SyntaxError indicates a malformed sqlcmd statement -func syntaxError(lineNumber uint) SqlcmdError { - return &SyntaxError{ - err: fmt.Errorf("%sSyntax error at line %d", ErrorPrefix, lineNumber), - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "errors" + "fmt" + "strings" +) + +// ErrorPrefix is the prefix for all sqlcmd-generated errors +const ErrorPrefix = "Sqlcmd: Error: " + +// WarningPrefix is the prefix for all sqlcmd-generated warnings +const WarningPrefix = "Sqlcmd: Warning: " + +// Common Sqlcmd error messages +const ErrCmdDisabled = "ED and !! commands, startup script, and environment variables are disabled" + +type SqlcmdError interface { + error + IsSqlcmdErr() bool +} + +type CommonSqlcmdErr struct { + message string +} + +func (e *CommonSqlcmdErr) Error() string { + return e.message +} + +func (e *CommonSqlcmdErr) IsSqlcmdErr() bool { + return true +} + +// ArgumentError is related to command line switch validation not handled by kong +type ArgumentError struct { + Parameter string + Rule string +} + +func (e *ArgumentError) Error() string { + return ErrorPrefix + e.Rule +} + +func (e *ArgumentError) IsSqlcmdErr() bool { + return true +} + +// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format +var InvalidServerName = ArgumentError{ + Parameter: "server", + Rule: "server must be of the form [tcp]:server[[/instance]|[,port]]", +} + +// VariableError is an error about scripting variables +type VariableError struct { + Variable string + MessageFormat string +} + +func (e *VariableError) Error() string { + return ErrorPrefix + fmt.Sprintf(e.MessageFormat, e.Variable) +} + +func (e *VariableError) IsSqlcmdErr() bool { + return true +} + +// ReadOnlyVariable indicates the user tried to set a value to a read-only variable +func ReadOnlyVariable(variable string) *VariableError { + return &VariableError{ + Variable: variable, + MessageFormat: "The scripting variable: '%s' is read-only", + } +} + +// UndefinedVariable indicates the user tried to reference an undefined variable +func UndefinedVariable(variable string) *VariableError { + return &VariableError{ + Variable: variable, + MessageFormat: "'%s' scripting variable not defined.", + } +} + +// InvalidVariableValue indicates the variable was set to an invalid value +func InvalidVariableValue(variable string, value string) *VariableError { + return &VariableError{ + Variable: variable, + MessageFormat: "The environment variable: '%s' has invalid value: '" + strings.ReplaceAll(value, `%`, `%%`) + "'.", + } +} + +// CommandError indicates syntax errors for specific sqlcmd commands +type CommandError struct { + Command string + LineNumber uint +} + +func (e *CommandError) Error() string { + return ErrorPrefix + fmt.Sprintf("Syntax error at line %d near command '%s'.", e.LineNumber, e.Command) +} + +func (e *CommandError) IsSqlcmdErr() bool { + return true +} + +// InvalidCommandError creates a SQLCmdCommandError +func InvalidCommandError(command string, lineNumber uint) *CommandError { + return &CommandError{ + Command: command, + LineNumber: lineNumber, + } +} + +type FileError struct { + err error + path string +} + +func (e *FileError) Error() string { + return e.err.Error() +} + +func (e *FileError) IsSqlcmdErr() bool { + return true +} + +// InvalidFileError indicates a file could not be opened +func InvalidFileError(err error, filepath string) error { + return &FileError{ + err: errors.New(ErrorPrefix + " Error occurred while opening or operating on file " + filepath + " (Reason: " + err.Error() + ")."), + path: filepath, + } +} + +type SyntaxError struct { + err error +} + +func (e *SyntaxError) Error() string { + return e.err.Error() +} + +func (e *SyntaxError) IsSqlcmdErr() bool { + return true +} + +// SyntaxError indicates a malformed sqlcmd statement +func syntaxError(lineNumber uint) SqlcmdError { + return &SyntaxError{ + err: fmt.Errorf("%sSyntax error at line %d", ErrorPrefix, lineNumber), + } +} diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 9de7430a..b44e07e9 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -1,663 +1,663 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "database/sql" - "fmt" - "io" - "strings" - "time" - - "github.com/google/uuid" - mssql "github.com/microsoft/go-mssqldb" -) - -const ( - defaultMaxDisplayWidth = 1024 * 1024 - maxPadWidth = 8000 -) - -// Formatter defines methods to process query output -type Formatter interface { - // BeginBatch is called before the query runs - BeginBatch(query string, vars *Variables, out io.Writer, err io.Writer) - // EndBatch is the last function called during batch execution and signals the end of the batch - EndBatch() - // BeginResultSet is called when a new result set is encountered - BeginResultSet([]*sql.ColumnType) - // EndResultSet is called after all rows in a result set have been processed - EndResultSet() - // AddRow is called for each row in a result set. It returns the value of the first column - AddRow(*sql.Rows) string - // AddMessage is called for every information message returned by the server during the batch - AddMessage(string) - // AddError is called for each error encountered during batch execution - AddError(err error) -} - -// ControlCharacterBehavior specifies the text handling required for control characters in the output -type ControlCharacterBehavior int - -const ( - // ControlIgnore preserves control characters in the output - ControlIgnore ControlCharacterBehavior = iota - // ControlReplace replaces control characters with spaces, 1 space per character - ControlReplace - // ControlRemove removes control characters from the output - ControlRemove - // ControlReplaceConsecutive replaces multiple consecutive control characters with a single space - ControlReplaceConsecutive -) - -type columnDetail struct { - displayWidth int64 - leftJustify bool - zeroesAfterDecimal bool - col sql.ColumnType - precision int - scale int -} - -// The default formatter based on the native sqlcmd style -// It supports both horizontal (default) and vertical layout for results. -// Both vertical and horizontal layouts respect column widths set by SQLCMD variables. -type sqlCmdFormatterType struct { - out io.Writer - err io.Writer - vars *Variables - colsep string - removeTrailingSpaces bool - ccb ControlCharacterBehavior - columnDetails []columnDetail - rowcount int - writepos int64 - format string - maxColNameLen int -} - -// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter -func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool) Formatter { - return &sqlCmdFormatterType{ - removeTrailingSpaces: removeTrailingSpaces, - format: "horizontal", - } -} - -// Adds the given string to the current line, wrapping it based on the screen width setting -func (f *sqlCmdFormatterType) writeOut(s string) { - w := f.vars.ScreenWidth() - if w == 0 { - f.mustWriteOut(s) - return - } - - r := []rune(s) - for i := 0; true; { - if i == len(r) { - f.mustWriteOut(string(r)) - return - } else if f.writepos == w { - f.mustWriteOut(string(r[:i])) - f.mustWriteOut(SqlcmdEol) - r = []rune(string(r[i:])) - f.writepos = 0 - i = 0 - } else { - c := r[i] - if c != '\r' && c != '\n' { - f.writepos++ - } else { - f.writepos = 0 - } - i++ - } - } -} - -// Stores the settings to use for processing the current batch -// TODO: add a third io.Writer for messages when we add -r support -func (f *sqlCmdFormatterType) BeginBatch(_ string, vars *Variables, out io.Writer, err io.Writer) { - f.out = out - f.err = err - f.vars = vars - f.colsep = vars.ColumnSeparator() - f.format = vars.Format() -} - -func (f *sqlCmdFormatterType) EndBatch() { -} - -// Calculate the widths for each column and print the column names -// Since sql.ColumnType only provides sizes for variable length types we will -// base our numbers for most types on https://docs.microsoft.com/sql/odbc/reference/appendixes/column-size -func (f *sqlCmdFormatterType) BeginResultSet(cols []*sql.ColumnType) { - f.rowcount = 0 - f.columnDetails, f.maxColNameLen = calcColumnDetails(cols, f.vars.MaxFixedColumnWidth(), f.vars.MaxVarColumnWidth()) - if f.vars.RowsBetweenHeaders() > -1 && f.format == "horizontal" { - f.printColumnHeadings() - } -} - -// Writes a blank line to the designated output writer -func (f *sqlCmdFormatterType) EndResultSet() { - f.writeOut(SqlcmdEol) -} - -// Writes the current row to the designated output writer -func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string { - retval := "" - values, err := f.scanRow(row) - if err != nil { - f.mustWriteErr(err.Error()) - return retval - } - retval = values[0] - if f.format == "horizontal" { - // values are the full values, look at the displaywidth of each column and truncate accordingly - for i, v := range values { - if i > 0 { - f.writeOut(f.vars.ColumnSeparator()) - } - f.printColumnValue(v, i) - } - f.rowcount++ - gap := f.vars.RowsBetweenHeaders() - if gap > 0 && (int64(f.rowcount)%gap == 0) { - f.writeOut(SqlcmdEol) - f.printColumnHeadings() - } - } else { - f.addVerticalRow(values) - } - f.writeOut(SqlcmdEol) - return retval - -} - -func (f *sqlCmdFormatterType) addVerticalRow(values []string) { - for i, v := range values { - if f.vars.RowsBetweenHeaders() > -1 { - builder := new(strings.Builder) - name := f.columnDetails[i].col.Name() - builder.WriteString(name) - builder = padRight(builder, int64(f.maxColNameLen-len(name)+1), " ") - f.writeOut(builder.String()) - } - f.printColumnValue(v, i) - f.writeOut(SqlcmdEol) - } -} - -// Writes a non-error message to the designated message writer -func (f *sqlCmdFormatterType) AddMessage(msg string) { - f.mustWriteOut(msg + SqlcmdEol) -} - -// Writes an error to the designated err Writer -func (f *sqlCmdFormatterType) AddError(err error) { - print := true - b := new(strings.Builder) - msg := err.Error() - switch e := (err).(type) { - case mssql.Error: - if print = f.vars.ErrorLevel() <= 0 || e.Class >= uint8(f.vars.ErrorLevel()); print { - b.WriteString(fmt.Sprintf("Msg %d, Level %d, State %d, Server %s, Line %d%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol)) - msg = strings.TrimPrefix(msg, "mssql: ") - } - } - if print { - b.WriteString(msg) - b.WriteString(SqlcmdEol) - f.mustWriteErr(fitToScreen(b, f.vars.ScreenWidth()).String()) - } -} - -// Prints column headings based on columnDetail, variables, and command line arguments -func (f *sqlCmdFormatterType) printColumnHeadings() { - names := new(strings.Builder) - sep := new(strings.Builder) - - var leftPad, rightPad int64 - for i, c := range f.columnDetails { - rightPad = 0 - nameLen := int64(len([]rune(c.col.Name()))) - if f.removeTrailingSpaces { - if nameLen == 0 { - // special case for unnamed columns when using -W - // print a single - - rightPad = 1 - sep = padRight(sep, 1, "-") - } else { - sep = padRight(sep, nameLen, "-") - } - } else { - length := min64(c.displayWidth, maxPadWidth) - if nameLen < length { - rightPad = length - nameLen - } - sep = padRight(sep, length, "-") - } - names = padRight(names, leftPad, " ") - names.WriteString(c.col.Name()[:min64(nameLen, c.displayWidth)]) - names = padRight(names, rightPad, " ") - if i != len(f.columnDetails)-1 { - names.WriteString(f.colsep) - sep.WriteString(f.colsep) - } - } - names.WriteString(SqlcmdEol) - sep.WriteString(SqlcmdEol) - names = fitToScreen(names, f.vars.ScreenWidth()) - sep = fitToScreen(sep, f.vars.ScreenWidth()) - f.mustWriteOut(names.String()) - f.mustWriteOut(sep.String()) -} - -// Wraps the input string every width characters when width > 0 -// When width == 0 returns the input Builder -// When width > 0 returns a new Builder containing the wrapped string -func fitToScreen(s *strings.Builder, width int64) *strings.Builder { - str := s.String() - runes := []rune(str) - if width == 0 || int64(len(runes)) < width { - return s - } - - line := new(strings.Builder) - line.Grow(len(str)) - var c int64 - for i, r := range runes { - if c == width { - // We have printed a line's worth - // if the next character is not part of a carriage return write our Eol - if (SqlcmdEol == "\r\n" && (i == len(runes)-1 || (i < len(runes)-1 && string(runes[i:i+2]) != SqlcmdEol))) || (SqlcmdEol == "\n" && r != '\n') { - line.WriteString(SqlcmdEol) - c = 0 - } - } - line.WriteRune(r) - if r == '\n' { - c = 0 - // we are assuming \r is a non-printed character - // The likelihood of a \r not being followed by \n is low - } else if r == '\r' && SqlcmdEol == "\r\n" { - c = 0 - } else { - c++ - } - } - return line -} - -// Given the array of driver-provided columnType values and the sqlcmd size limits, -// Return an array of columnDetail objects describing the output format for each column. -// Return the length of the longest column name. -func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) ([]columnDetail, int) { - columnDetails := make([]columnDetail, len(cols)) - maxNameLen := 0 - for i, c := range cols { - length, _ := c.Length() - nameLen := int64(len([]rune(c.Name()))) - if nameLen > int64(maxNameLen) { - maxNameLen = int(nameLen) - } - columnDetails[i].col = *c - columnDetails[i].leftJustify = true - columnDetails[i].zeroesAfterDecimal = false - p, s, ok := c.DecimalSize() - if ok { - columnDetails[i].precision = int(p) - columnDetails[i].scale = int(s) - } - if length == 0 { - columnDetails[i].displayWidth = defaultMaxDisplayWidth - } else { - columnDetails[i].displayWidth = length - } - typeName := c.DatabaseTypeName() - - switch typeName { - // Types with 0 size from sql.ColumnType - case "BIT": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(1, nameLen) - case "TINYINT": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(3, nameLen) - case "SMALLINT": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(6, nameLen) - case "INT": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(11, nameLen) - case "BIGINT": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(21, nameLen) - case "REAL", "SMALLMONEY": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(14, nameLen) - columnDetails[i].zeroesAfterDecimal = true - case "FLOAT", "MONEY": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(24, nameLen) - columnDetails[i].zeroesAfterDecimal = true - case "DECIMAL": - columnDetails[i].leftJustify = false - d, _, ok := c.DecimalSize() - // maybe panic on !ok? - if !ok { - d = 24 - } - columnDetails[i].displayWidth = max64(d+2, nameLen) - columnDetails[i].zeroesAfterDecimal = true - case "DATE": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(16, nameLen) - case "DATETIME": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(23, nameLen) - case "SMALLDATETIME": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(19, nameLen) - columnDetails[i].zeroesAfterDecimal = true - case "DATETIME2": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(38, nameLen) - columnDetails[i].zeroesAfterDecimal = true - case "TIME": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(16, nameLen) - case "DATETIMEOFFSET": - columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(45, nameLen) - case "UNIQUEIDENTIFIER": - columnDetails[i].displayWidth = max64(36, nameLen) - // Types that can be fixed or variable - case "VARCHAR": - if length > 8000 { - columnDetails[i].displayWidth = variable - } else { - if fixed > 0 { - length = min64(fixed, length) - } - columnDetails[i].displayWidth = max64(length, nameLen) - } - case "NVARCHAR": - if length > 4000 { - columnDetails[i].displayWidth = variable - } else { - if fixed > 0 { - length = min64(fixed, length) - } - columnDetails[i].displayWidth = max64(length, nameLen) - } - case "VARBINARY": - if length <= 8000 { - if fixed > 0 { - length = min64(fixed, length) - } - columnDetails[i].displayWidth = max64(length, nameLen) - } else { - columnDetails[i].displayWidth = variable - } - case "SQL_VARIANT": - if fixed > 0 { - columnDetails[i].displayWidth = min64(fixed, 8000) - } else { - columnDetails[i].displayWidth = 8000 - } - // Fixed length types - case "CHAR", "NCHAR": - if fixed > 0 { - length = min64(fixed, length) - } - columnDetails[i].displayWidth = max64(length, nameLen) - // Variable length types - // TODO: Fix BINARY once we have a driver with fix for https://github.com/denisenkom/go-mssqldb/issues/685 - case "XML", "TEXT", "NTEXT", "IMAGE", "BINARY": - columnDetails[i].displayWidth = variable - default: - columnDetails[i].displayWidth = length - } - // When max var length is 0 we don't print column headers and print every value with unlimited width - if variable == 0 { - columnDetails[i].displayWidth = 0 - } - } - return columnDetails, maxNameLen -} - -// scanRow fetches the next row and converts each value to the appropriate string representation -func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { - r := make([]interface{}, len(f.columnDetails)) - for i := range r { - r[i] = new(interface{}) - } - if err := rows.Scan(r...); err != nil { - return nil, err - } - row := make([]string, len(f.columnDetails)) - for n, z := range r { - j := z.(*interface{}) - if *j == nil { - row[n] = "NULL" - } else { - switch x := (*j).(type) { - case []byte: - if isBinaryDataType(&f.columnDetails[n].col) { - row[n] = decodeBinary(x) - } else if f.columnDetails[n].col.DatabaseTypeName() == "UNIQUEIDENTIFIER" { - // Unscramble the guid - // see https://github.com/denisenkom/go-mssqldb/issues/56 - x[0], x[1], x[2], x[3] = x[3], x[2], x[1], x[0] - x[4], x[5] = x[5], x[4] - x[6], x[7] = x[7], x[6] - if guid, err := uuid.FromBytes(x); err == nil { - row[n] = guid.String() - } else { - // this should never happen - row[n] = uuid.New().String() - } - } else { - row[n] = string(x) - } - case string: - row[n] = x - case time.Time: - // Go lacks any way to get the user's preferred time format or even the system default - switch f.columnDetails[n].col.DatabaseTypeName() { - case "DATE": - row[n] = x.Format("2006-01-02") - case "DATETIME": - row[n] = x.Format(dateTimeFormatString(3, false)) - case "DATETIME2": - row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, false)) - case "SMALLDATETIME": - row[n] = x.Format(dateTimeFormatString(0, false)) - case "DATETIMEOFFSET": - row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, true)) - case "TIME": - format := "15:04:05" - if f.columnDetails[n].scale > 0 { - format = fmt.Sprintf("%s.%0*d", format, f.columnDetails[n].scale, 0) - } - row[n] = x.Format(format) - default: - row[n] = x.Format(time.RFC3339) - } - case fmt.Stringer: - row[n] = x.String() - // not sure why go-mssql reports bit as bool - case bool: - if x { - row[n] = "1" - } else { - row[n] = "0" - } - default: - var err error - if row[n], err = fmt.Sprintf("%v", x), nil; err != nil { - return nil, err - } - } - } - } - return row, nil -} - -func dateTimeFormatString(scale int, addOffset bool) string { - format := `2006-01-02 15:04:05` - if scale > 0 { - format = fmt.Sprintf("%s.%0*d", format, scale, 0) - } - if addOffset { - format += " -07:00" - } - return format -} - -// Prints the final version of a cell based on formatting variables and command line parameters -func (f *sqlCmdFormatterType) printColumnValue(val string, col int) { - c := f.columnDetails[col] - s := new(strings.Builder) - if isNeedingControlCharacterTreatment(&c.col) { - val = applyControlCharacterBehavior(val, f.ccb) - } - - if isNeedingHexPrefix(&c.col) { - val = "0x" + val - } - - s.WriteString(val) - r := []rune(val) - if f.format == "horizontal" { - if !f.removeTrailingSpaces { - if f.vars.MaxVarColumnWidth() != 0 || !isLargeVariableType(&c.col) { - padding := c.displayWidth - min64(c.displayWidth, int64(len(r))) - if padding > 0 { - if c.leftJustify { - s = padRight(s, padding, " ") - } else { - s = padLeft(s, padding, " ") - } - } - } - } - - r = []rune(s.String()) - } - if c.displayWidth > 0 && int64(len(r)) > c.displayWidth { - s.Reset() - s.WriteString(string(r[:c.displayWidth])) - } - f.writeOut(s.String()) -} - -func (f *sqlCmdFormatterType) mustWriteOut(s string) { - _, err := f.out.Write([]byte(s)) - if err != nil { - panic(err) - } -} - -func (f *sqlCmdFormatterType) mustWriteErr(s string) { - _, err := f.err.Write([]byte(s)) - if err != nil { - panic(err) - } -} - -func isLargeVariableType(col *sql.ColumnType) bool { - l, _ := col.Length() - switch col.DatabaseTypeName() { - - case "VARCHAR", "VARBINARY": - return l > 8000 - case "NVARCHAR": - return l > 4000 - case "XML", "TEXT", "NTEXT", "IMAGE": - return true - } - return false -} - -func isNeedingControlCharacterTreatment(col *sql.ColumnType) bool { - switch col.DatabaseTypeName() { - case "CHAR", "VARCHAR", "TEXT", "NTEXT", "NCHAR", "NVARCHAR", "XML": - return true - } - return false -} -func isBinaryDataType(col *sql.ColumnType) bool { - switch col.DatabaseTypeName() { - case "BINARY", "VARBINARY": - return true - } - return false -} - -func isNeedingHexPrefix(col *sql.ColumnType) bool { - return isBinaryDataType(col) // || col.DatabaseTypeName() == "UDT" -} - -func isControlChar(r rune) bool { - c := int(r) - return c == 0x7f || (c >= 0 && c <= 0x1f) -} - -func applyControlCharacterBehavior(val string, ccb ControlCharacterBehavior) string { - if ccb == ControlIgnore { - return val - } - b := new(strings.Builder) - r := []rune(val) - if ccb == ControlReplace { - for _, l := range r { - if isControlChar(l) { - b.WriteRune(' ') - } else { - b.WriteRune(l) - } - } - } else { - for i := 0; i < len(r); { - if !isControlChar(r[i]) { - b.WriteRune(r[i]) - i++ - } else { - for ; i < len(r) && isControlChar(r[i]); i++ { - } - if ccb == ControlReplaceConsecutive { - b.WriteRune(' ') - } - } - } - } - return b.String() -} - -// Per https://docs.microsoft.com/sql/odbc/reference/appendixes/sql-to-c-binary -var hexDigits = []rune{'A', 'B', 'C', 'D', 'E', 'F'} - -func decodeBinary(b []byte) string { - - s := new(strings.Builder) - s.Grow(len(b) * 2) - for _, ch := range b { - b1 := ch >> 4 - b2 := ch & 0x0f - if b1 >= 10 { - s.WriteRune(hexDigits[b1-10]) - } else { - s.WriteRune(rune('0' + b1)) - } - if b2 >= 10 { - s.WriteRune(hexDigits[b2-10]) - } else { - s.WriteRune(rune('0' + b2)) - } - } - return s.String() -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "database/sql" + "fmt" + "io" + "strings" + "time" + + "github.com/google/uuid" + mssql "github.com/microsoft/go-mssqldb" +) + +const ( + defaultMaxDisplayWidth = 1024 * 1024 + maxPadWidth = 8000 +) + +// Formatter defines methods to process query output +type Formatter interface { + // BeginBatch is called before the query runs + BeginBatch(query string, vars *Variables, out io.Writer, err io.Writer) + // EndBatch is the last function called during batch execution and signals the end of the batch + EndBatch() + // BeginResultSet is called when a new result set is encountered + BeginResultSet([]*sql.ColumnType) + // EndResultSet is called after all rows in a result set have been processed + EndResultSet() + // AddRow is called for each row in a result set. It returns the value of the first column + AddRow(*sql.Rows) string + // AddMessage is called for every information message returned by the server during the batch + AddMessage(string) + // AddError is called for each error encountered during batch execution + AddError(err error) +} + +// ControlCharacterBehavior specifies the text handling required for control characters in the output +type ControlCharacterBehavior int + +const ( + // ControlIgnore preserves control characters in the output + ControlIgnore ControlCharacterBehavior = iota + // ControlReplace replaces control characters with spaces, 1 space per character + ControlReplace + // ControlRemove removes control characters from the output + ControlRemove + // ControlReplaceConsecutive replaces multiple consecutive control characters with a single space + ControlReplaceConsecutive +) + +type columnDetail struct { + displayWidth int64 + leftJustify bool + zeroesAfterDecimal bool + col sql.ColumnType + precision int + scale int +} + +// The default formatter based on the native sqlcmd style +// It supports both horizontal (default) and vertical layout for results. +// Both vertical and horizontal layouts respect column widths set by SQLCMD variables. +type sqlCmdFormatterType struct { + out io.Writer + err io.Writer + vars *Variables + colsep string + removeTrailingSpaces bool + ccb ControlCharacterBehavior + columnDetails []columnDetail + rowcount int + writepos int64 + format string + maxColNameLen int +} + +// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter +func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool) Formatter { + return &sqlCmdFormatterType{ + removeTrailingSpaces: removeTrailingSpaces, + format: "horizontal", + } +} + +// Adds the given string to the current line, wrapping it based on the screen width setting +func (f *sqlCmdFormatterType) writeOut(s string) { + w := f.vars.ScreenWidth() + if w == 0 { + f.mustWriteOut(s) + return + } + + r := []rune(s) + for i := 0; true; { + if i == len(r) { + f.mustWriteOut(string(r)) + return + } else if f.writepos == w { + f.mustWriteOut(string(r[:i])) + f.mustWriteOut(SqlcmdEol) + r = []rune(string(r[i:])) + f.writepos = 0 + i = 0 + } else { + c := r[i] + if c != '\r' && c != '\n' { + f.writepos++ + } else { + f.writepos = 0 + } + i++ + } + } +} + +// Stores the settings to use for processing the current batch +// TODO: add a third io.Writer for messages when we add -r support +func (f *sqlCmdFormatterType) BeginBatch(_ string, vars *Variables, out io.Writer, err io.Writer) { + f.out = out + f.err = err + f.vars = vars + f.colsep = vars.ColumnSeparator() + f.format = vars.Format() +} + +func (f *sqlCmdFormatterType) EndBatch() { +} + +// Calculate the widths for each column and print the column names +// Since sql.ColumnType only provides sizes for variable length types we will +// base our numbers for most types on https://docs.microsoft.com/sql/odbc/reference/appendixes/column-size +func (f *sqlCmdFormatterType) BeginResultSet(cols []*sql.ColumnType) { + f.rowcount = 0 + f.columnDetails, f.maxColNameLen = calcColumnDetails(cols, f.vars.MaxFixedColumnWidth(), f.vars.MaxVarColumnWidth()) + if f.vars.RowsBetweenHeaders() > -1 && f.format == "horizontal" { + f.printColumnHeadings() + } +} + +// Writes a blank line to the designated output writer +func (f *sqlCmdFormatterType) EndResultSet() { + f.writeOut(SqlcmdEol) +} + +// Writes the current row to the designated output writer +func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string { + retval := "" + values, err := f.scanRow(row) + if err != nil { + f.mustWriteErr(err.Error()) + return retval + } + retval = values[0] + if f.format == "horizontal" { + // values are the full values, look at the displaywidth of each column and truncate accordingly + for i, v := range values { + if i > 0 { + f.writeOut(f.vars.ColumnSeparator()) + } + f.printColumnValue(v, i) + } + f.rowcount++ + gap := f.vars.RowsBetweenHeaders() + if gap > 0 && (int64(f.rowcount)%gap == 0) { + f.writeOut(SqlcmdEol) + f.printColumnHeadings() + } + } else { + f.addVerticalRow(values) + } + f.writeOut(SqlcmdEol) + return retval + +} + +func (f *sqlCmdFormatterType) addVerticalRow(values []string) { + for i, v := range values { + if f.vars.RowsBetweenHeaders() > -1 { + builder := new(strings.Builder) + name := f.columnDetails[i].col.Name() + builder.WriteString(name) + builder = padRight(builder, int64(f.maxColNameLen-len(name)+1), " ") + f.writeOut(builder.String()) + } + f.printColumnValue(v, i) + f.writeOut(SqlcmdEol) + } +} + +// Writes a non-error message to the designated message writer +func (f *sqlCmdFormatterType) AddMessage(msg string) { + f.mustWriteOut(msg + SqlcmdEol) +} + +// Writes an error to the designated err Writer +func (f *sqlCmdFormatterType) AddError(err error) { + print := true + b := new(strings.Builder) + msg := err.Error() + switch e := (err).(type) { + case mssql.Error: + if print = f.vars.ErrorLevel() <= 0 || e.Class >= uint8(f.vars.ErrorLevel()); print { + b.WriteString(fmt.Sprintf("Msg %d, Level %d, State %d, Server %s, Line %d%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol)) + msg = strings.TrimPrefix(msg, "mssql: ") + } + } + if print { + b.WriteString(msg) + b.WriteString(SqlcmdEol) + f.mustWriteErr(fitToScreen(b, f.vars.ScreenWidth()).String()) + } +} + +// Prints column headings based on columnDetail, variables, and command line arguments +func (f *sqlCmdFormatterType) printColumnHeadings() { + names := new(strings.Builder) + sep := new(strings.Builder) + + var leftPad, rightPad int64 + for i, c := range f.columnDetails { + rightPad = 0 + nameLen := int64(len([]rune(c.col.Name()))) + if f.removeTrailingSpaces { + if nameLen == 0 { + // special case for unnamed columns when using -W + // print a single - + rightPad = 1 + sep = padRight(sep, 1, "-") + } else { + sep = padRight(sep, nameLen, "-") + } + } else { + length := min64(c.displayWidth, maxPadWidth) + if nameLen < length { + rightPad = length - nameLen + } + sep = padRight(sep, length, "-") + } + names = padRight(names, leftPad, " ") + names.WriteString(c.col.Name()[:min64(nameLen, c.displayWidth)]) + names = padRight(names, rightPad, " ") + if i != len(f.columnDetails)-1 { + names.WriteString(f.colsep) + sep.WriteString(f.colsep) + } + } + names.WriteString(SqlcmdEol) + sep.WriteString(SqlcmdEol) + names = fitToScreen(names, f.vars.ScreenWidth()) + sep = fitToScreen(sep, f.vars.ScreenWidth()) + f.mustWriteOut(names.String()) + f.mustWriteOut(sep.String()) +} + +// Wraps the input string every width characters when width > 0 +// When width == 0 returns the input Builder +// When width > 0 returns a new Builder containing the wrapped string +func fitToScreen(s *strings.Builder, width int64) *strings.Builder { + str := s.String() + runes := []rune(str) + if width == 0 || int64(len(runes)) < width { + return s + } + + line := new(strings.Builder) + line.Grow(len(str)) + var c int64 + for i, r := range runes { + if c == width { + // We have printed a line's worth + // if the next character is not part of a carriage return write our Eol + if (SqlcmdEol == "\r\n" && (i == len(runes)-1 || (i < len(runes)-1 && string(runes[i:i+2]) != SqlcmdEol))) || (SqlcmdEol == "\n" && r != '\n') { + line.WriteString(SqlcmdEol) + c = 0 + } + } + line.WriteRune(r) + if r == '\n' { + c = 0 + // we are assuming \r is a non-printed character + // The likelihood of a \r not being followed by \n is low + } else if r == '\r' && SqlcmdEol == "\r\n" { + c = 0 + } else { + c++ + } + } + return line +} + +// Given the array of driver-provided columnType values and the sqlcmd size limits, +// Return an array of columnDetail objects describing the output format for each column. +// Return the length of the longest column name. +func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) ([]columnDetail, int) { + columnDetails := make([]columnDetail, len(cols)) + maxNameLen := 0 + for i, c := range cols { + length, _ := c.Length() + nameLen := int64(len([]rune(c.Name()))) + if nameLen > int64(maxNameLen) { + maxNameLen = int(nameLen) + } + columnDetails[i].col = *c + columnDetails[i].leftJustify = true + columnDetails[i].zeroesAfterDecimal = false + p, s, ok := c.DecimalSize() + if ok { + columnDetails[i].precision = int(p) + columnDetails[i].scale = int(s) + } + if length == 0 { + columnDetails[i].displayWidth = defaultMaxDisplayWidth + } else { + columnDetails[i].displayWidth = length + } + typeName := c.DatabaseTypeName() + + switch typeName { + // Types with 0 size from sql.ColumnType + case "BIT": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(1, nameLen) + case "TINYINT": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(3, nameLen) + case "SMALLINT": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(6, nameLen) + case "INT": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(11, nameLen) + case "BIGINT": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(21, nameLen) + case "REAL", "SMALLMONEY": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(14, nameLen) + columnDetails[i].zeroesAfterDecimal = true + case "FLOAT", "MONEY": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(24, nameLen) + columnDetails[i].zeroesAfterDecimal = true + case "DECIMAL": + columnDetails[i].leftJustify = false + d, _, ok := c.DecimalSize() + // maybe panic on !ok? + if !ok { + d = 24 + } + columnDetails[i].displayWidth = max64(d+2, nameLen) + columnDetails[i].zeroesAfterDecimal = true + case "DATE": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(16, nameLen) + case "DATETIME": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(23, nameLen) + case "SMALLDATETIME": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(19, nameLen) + columnDetails[i].zeroesAfterDecimal = true + case "DATETIME2": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(38, nameLen) + columnDetails[i].zeroesAfterDecimal = true + case "TIME": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(16, nameLen) + case "DATETIMEOFFSET": + columnDetails[i].leftJustify = false + columnDetails[i].displayWidth = max64(45, nameLen) + case "UNIQUEIDENTIFIER": + columnDetails[i].displayWidth = max64(36, nameLen) + // Types that can be fixed or variable + case "VARCHAR": + if length > 8000 { + columnDetails[i].displayWidth = variable + } else { + if fixed > 0 { + length = min64(fixed, length) + } + columnDetails[i].displayWidth = max64(length, nameLen) + } + case "NVARCHAR": + if length > 4000 { + columnDetails[i].displayWidth = variable + } else { + if fixed > 0 { + length = min64(fixed, length) + } + columnDetails[i].displayWidth = max64(length, nameLen) + } + case "VARBINARY": + if length <= 8000 { + if fixed > 0 { + length = min64(fixed, length) + } + columnDetails[i].displayWidth = max64(length, nameLen) + } else { + columnDetails[i].displayWidth = variable + } + case "SQL_VARIANT": + if fixed > 0 { + columnDetails[i].displayWidth = min64(fixed, 8000) + } else { + columnDetails[i].displayWidth = 8000 + } + // Fixed length types + case "CHAR", "NCHAR": + if fixed > 0 { + length = min64(fixed, length) + } + columnDetails[i].displayWidth = max64(length, nameLen) + // Variable length types + // TODO: Fix BINARY once we have a driver with fix for https://github.com/denisenkom/go-mssqldb/issues/685 + case "XML", "TEXT", "NTEXT", "IMAGE", "BINARY": + columnDetails[i].displayWidth = variable + default: + columnDetails[i].displayWidth = length + } + // When max var length is 0 we don't print column headers and print every value with unlimited width + if variable == 0 { + columnDetails[i].displayWidth = 0 + } + } + return columnDetails, maxNameLen +} + +// scanRow fetches the next row and converts each value to the appropriate string representation +func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { + r := make([]interface{}, len(f.columnDetails)) + for i := range r { + r[i] = new(interface{}) + } + if err := rows.Scan(r...); err != nil { + return nil, err + } + row := make([]string, len(f.columnDetails)) + for n, z := range r { + j := z.(*interface{}) + if *j == nil { + row[n] = "NULL" + } else { + switch x := (*j).(type) { + case []byte: + if isBinaryDataType(&f.columnDetails[n].col) { + row[n] = decodeBinary(x) + } else if f.columnDetails[n].col.DatabaseTypeName() == "UNIQUEIDENTIFIER" { + // Unscramble the guid + // see https://github.com/denisenkom/go-mssqldb/issues/56 + x[0], x[1], x[2], x[3] = x[3], x[2], x[1], x[0] + x[4], x[5] = x[5], x[4] + x[6], x[7] = x[7], x[6] + if guid, err := uuid.FromBytes(x); err == nil { + row[n] = guid.String() + } else { + // this should never happen + row[n] = uuid.New().String() + } + } else { + row[n] = string(x) + } + case string: + row[n] = x + case time.Time: + // Go lacks any way to get the user's preferred time format or even the system default + switch f.columnDetails[n].col.DatabaseTypeName() { + case "DATE": + row[n] = x.Format("2006-01-02") + case "DATETIME": + row[n] = x.Format(dateTimeFormatString(3, false)) + case "DATETIME2": + row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, false)) + case "SMALLDATETIME": + row[n] = x.Format(dateTimeFormatString(0, false)) + case "DATETIMEOFFSET": + row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, true)) + case "TIME": + format := "15:04:05" + if f.columnDetails[n].scale > 0 { + format = fmt.Sprintf("%s.%0*d", format, f.columnDetails[n].scale, 0) + } + row[n] = x.Format(format) + default: + row[n] = x.Format(time.RFC3339) + } + case fmt.Stringer: + row[n] = x.String() + // not sure why go-mssql reports bit as bool + case bool: + if x { + row[n] = "1" + } else { + row[n] = "0" + } + default: + var err error + if row[n], err = fmt.Sprintf("%v", x), nil; err != nil { + return nil, err + } + } + } + } + return row, nil +} + +func dateTimeFormatString(scale int, addOffset bool) string { + format := `2006-01-02 15:04:05` + if scale > 0 { + format = fmt.Sprintf("%s.%0*d", format, scale, 0) + } + if addOffset { + format += " -07:00" + } + return format +} + +// Prints the final version of a cell based on formatting variables and command line parameters +func (f *sqlCmdFormatterType) printColumnValue(val string, col int) { + c := f.columnDetails[col] + s := new(strings.Builder) + if isNeedingControlCharacterTreatment(&c.col) { + val = applyControlCharacterBehavior(val, f.ccb) + } + + if isNeedingHexPrefix(&c.col) { + val = "0x" + val + } + + s.WriteString(val) + r := []rune(val) + if f.format == "horizontal" { + if !f.removeTrailingSpaces { + if f.vars.MaxVarColumnWidth() != 0 || !isLargeVariableType(&c.col) { + padding := c.displayWidth - min64(c.displayWidth, int64(len(r))) + if padding > 0 { + if c.leftJustify { + s = padRight(s, padding, " ") + } else { + s = padLeft(s, padding, " ") + } + } + } + } + + r = []rune(s.String()) + } + if c.displayWidth > 0 && int64(len(r)) > c.displayWidth { + s.Reset() + s.WriteString(string(r[:c.displayWidth])) + } + f.writeOut(s.String()) +} + +func (f *sqlCmdFormatterType) mustWriteOut(s string) { + _, err := f.out.Write([]byte(s)) + if err != nil { + panic(err) + } +} + +func (f *sqlCmdFormatterType) mustWriteErr(s string) { + _, err := f.err.Write([]byte(s)) + if err != nil { + panic(err) + } +} + +func isLargeVariableType(col *sql.ColumnType) bool { + l, _ := col.Length() + switch col.DatabaseTypeName() { + + case "VARCHAR", "VARBINARY": + return l > 8000 + case "NVARCHAR": + return l > 4000 + case "XML", "TEXT", "NTEXT", "IMAGE": + return true + } + return false +} + +func isNeedingControlCharacterTreatment(col *sql.ColumnType) bool { + switch col.DatabaseTypeName() { + case "CHAR", "VARCHAR", "TEXT", "NTEXT", "NCHAR", "NVARCHAR", "XML": + return true + } + return false +} +func isBinaryDataType(col *sql.ColumnType) bool { + switch col.DatabaseTypeName() { + case "BINARY", "VARBINARY": + return true + } + return false +} + +func isNeedingHexPrefix(col *sql.ColumnType) bool { + return isBinaryDataType(col) // || col.DatabaseTypeName() == "UDT" +} + +func isControlChar(r rune) bool { + c := int(r) + return c == 0x7f || (c >= 0 && c <= 0x1f) +} + +func applyControlCharacterBehavior(val string, ccb ControlCharacterBehavior) string { + if ccb == ControlIgnore { + return val + } + b := new(strings.Builder) + r := []rune(val) + if ccb == ControlReplace { + for _, l := range r { + if isControlChar(l) { + b.WriteRune(' ') + } else { + b.WriteRune(l) + } + } + } else { + for i := 0; i < len(r); { + if !isControlChar(r[i]) { + b.WriteRune(r[i]) + i++ + } else { + for ; i < len(r) && isControlChar(r[i]); i++ { + } + if ccb == ControlReplaceConsecutive { + b.WriteRune(' ') + } + } + } + } + return b.String() +} + +// Per https://docs.microsoft.com/sql/odbc/reference/appendixes/sql-to-c-binary +var hexDigits = []rune{'A', 'B', 'C', 'D', 'E', 'F'} + +func decodeBinary(b []byte) string { + + s := new(strings.Builder) + s.Grow(len(b) * 2) + for _, ch := range b { + b1 := ch >> 4 + b2 := ch & 0x0f + if b1 >= 10 { + s.WriteRune(hexDigits[b1-10]) + } else { + s.WriteRune(rune('0' + b1)) + } + if b2 >= 10 { + s.WriteRune(hexDigits[b2-10]) + } else { + s.WriteRune(rune('0' + b2)) + } + } + return s.String() +} diff --git a/pkg/sqlcmd/format_darwin.go b/pkg/sqlcmd/format_darwin.go index fc16bf8c..f9bc89cf 100644 --- a/pkg/sqlcmd/format_darwin.go +++ b/pkg/sqlcmd/format_darwin.go @@ -1,7 +1,7 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -// SqlcmdEol is the end-of-line marker for sqlcmd output -const SqlcmdEol = "\n" +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +// SqlcmdEol is the end-of-line marker for sqlcmd output +const SqlcmdEol = "\n" diff --git a/pkg/sqlcmd/format_linux.go b/pkg/sqlcmd/format_linux.go index fc16bf8c..f9bc89cf 100644 --- a/pkg/sqlcmd/format_linux.go +++ b/pkg/sqlcmd/format_linux.go @@ -1,7 +1,7 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -// SqlcmdEol is the end-of-line marker for sqlcmd output -const SqlcmdEol = "\n" +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +// SqlcmdEol is the end-of-line marker for sqlcmd output +const SqlcmdEol = "\n" diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index 39224500..48387ebc 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -1,140 +1,140 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "context" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFitToScreen(t *testing.T) { - type fitTest struct { - width int64 - raw string - fit string - } - - tests := []fitTest{ - {0, "this is a string", "this is a string"}, - {9, "12345678", "12345678"}, - {9, "123456789", "123456789"}, - {9, "123456789A", "123456789" + SqlcmdEol + "A"}, - {9, "123456789" + SqlcmdEol, "123456789" + SqlcmdEol}, - {9, "12345678" + SqlcmdEol + "9A", "12345678" + SqlcmdEol + "9A"}, - {9, "123456789\rA", "123456789" + SqlcmdEol + "\rA"}, - } - - for _, test := range tests { - - line := new(strings.Builder) - line.WriteString(test.raw) - t.Log(test.raw) - f := fitToScreen(line, test.width).String() - assert.Equal(t, test.fit, f, "Mismatched fit for raw string: '%s'", test.raw) - } -} - -func TestCalcColumnDetails(t *testing.T) { - type colTest struct { - fixed int64 - variable int64 - query string - details []columnDetail - max int - } - - tests := []colTest{ - {8, 8, - "select 100 as '123456789ABC', getdate() as '987654321', 'string' as col1", - []columnDetail{ - {leftJustify: false, displayWidth: 12}, - {leftJustify: false, displayWidth: 23}, - {leftJustify: true, displayWidth: 6}, - }, - 12, - }, - } - - db, err := ConnectDb(t) - if assert.NoError(t, err, "ConnectDB failed") { - defer db.Close() - for x, test := range tests { - rows, err := db.QueryContext(context.Background(), test.query) - if assert.NoError(t, err, "Query failed: %s", test.query) { - defer rows.Close() - cols, err := rows.ColumnTypes() - if assert.NoError(t, err, "ColumnTypes failed:%s", test.query) { - actual, max := calcColumnDetails(cols, test.fixed, test.variable) - for i, a := range actual { - if test.details[i].displayWidth != a.displayWidth || - test.details[i].leftJustify != a.leftJustify || - test.details[i].zeroesAfterDecimal != a.zeroesAfterDecimal { - assert.Failf(t, "", "[%d] Incorrect test details for column [%s] in query '%s':%+v", x, cols[i].Name(), test.query, a) - } - assert.Equal(t, test.max, max, "[%d] Max column name length incorrect", x) - } - } - } - } - } -} - -func TestControlCharacterBehavior(t *testing.T) { - type ccbTest struct { - raw string - replaced string - removed string - consecutivereplaced string - } - - tests := []ccbTest{ - {"no control", "no control", "no control", "no control"}, - {string(rune(1)) + "tabs\t\treturns\r\n\r\n", " tabs returns ", "tabsreturns", " tabs returns "}, - } - - for _, test := range tests { - s := applyControlCharacterBehavior(test.raw, ControlReplace) - assert.Equalf(t, test.replaced, s, "Incorrect Replaced for '%s'", test.raw) - s = applyControlCharacterBehavior(test.raw, ControlRemove) - assert.Equalf(t, test.removed, s, "Incorrect Remove for '%s'", test.raw) - s = applyControlCharacterBehavior(test.raw, ControlReplaceConsecutive) - assert.Equalf(t, test.consecutivereplaced, s, "Incorrect ReplaceConsecutive for '%s'", test.raw) - } -} - -func TestDecodeBinary(t *testing.T) { - type decodeTest struct { - b []byte - s string - } - - tests := []decodeTest{ - {[]byte("123456ABCDEF"), "313233343536414243444546"}, - {[]byte{0x12, 0x34, 0x56}, "123456"}, - } - for _, test := range tests { - a := decodeBinary(test.b) - assert.Equalf(t, test.s, a, "Incorrect decoded binary string for %v", test.b) - } -} - -func BenchmarkDecodeBinary(b *testing.B) { - b.ReportAllocs() - bytes := make([]byte, 10000) - for i := 0; i < 10000; i++ { - bytes[i] = byte(i % 0xff) - } - b.ResetTimer() - - for i := 0; i < b.N; i++ { - s := decodeBinary(bytes) - if len(s) != 20000 { - b.Fatalf("Incorrect length of returned string. Should be 20k, was %d", len(s)) - } - } - -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFitToScreen(t *testing.T) { + type fitTest struct { + width int64 + raw string + fit string + } + + tests := []fitTest{ + {0, "this is a string", "this is a string"}, + {9, "12345678", "12345678"}, + {9, "123456789", "123456789"}, + {9, "123456789A", "123456789" + SqlcmdEol + "A"}, + {9, "123456789" + SqlcmdEol, "123456789" + SqlcmdEol}, + {9, "12345678" + SqlcmdEol + "9A", "12345678" + SqlcmdEol + "9A"}, + {9, "123456789\rA", "123456789" + SqlcmdEol + "\rA"}, + } + + for _, test := range tests { + + line := new(strings.Builder) + line.WriteString(test.raw) + t.Log(test.raw) + f := fitToScreen(line, test.width).String() + assert.Equal(t, test.fit, f, "Mismatched fit for raw string: '%s'", test.raw) + } +} + +func TestCalcColumnDetails(t *testing.T) { + type colTest struct { + fixed int64 + variable int64 + query string + details []columnDetail + max int + } + + tests := []colTest{ + {8, 8, + "select 100 as '123456789ABC', getdate() as '987654321', 'string' as col1", + []columnDetail{ + {leftJustify: false, displayWidth: 12}, + {leftJustify: false, displayWidth: 23}, + {leftJustify: true, displayWidth: 6}, + }, + 12, + }, + } + + db, err := ConnectDb(t) + if assert.NoError(t, err, "ConnectDB failed") { + defer db.Close() + for x, test := range tests { + rows, err := db.QueryContext(context.Background(), test.query) + if assert.NoError(t, err, "Query failed: %s", test.query) { + defer rows.Close() + cols, err := rows.ColumnTypes() + if assert.NoError(t, err, "ColumnTypes failed:%s", test.query) { + actual, max := calcColumnDetails(cols, test.fixed, test.variable) + for i, a := range actual { + if test.details[i].displayWidth != a.displayWidth || + test.details[i].leftJustify != a.leftJustify || + test.details[i].zeroesAfterDecimal != a.zeroesAfterDecimal { + assert.Failf(t, "", "[%d] Incorrect test details for column [%s] in query '%s':%+v", x, cols[i].Name(), test.query, a) + } + assert.Equal(t, test.max, max, "[%d] Max column name length incorrect", x) + } + } + } + } + } +} + +func TestControlCharacterBehavior(t *testing.T) { + type ccbTest struct { + raw string + replaced string + removed string + consecutivereplaced string + } + + tests := []ccbTest{ + {"no control", "no control", "no control", "no control"}, + {string(rune(1)) + "tabs\t\treturns\r\n\r\n", " tabs returns ", "tabsreturns", " tabs returns "}, + } + + for _, test := range tests { + s := applyControlCharacterBehavior(test.raw, ControlReplace) + assert.Equalf(t, test.replaced, s, "Incorrect Replaced for '%s'", test.raw) + s = applyControlCharacterBehavior(test.raw, ControlRemove) + assert.Equalf(t, test.removed, s, "Incorrect Remove for '%s'", test.raw) + s = applyControlCharacterBehavior(test.raw, ControlReplaceConsecutive) + assert.Equalf(t, test.consecutivereplaced, s, "Incorrect ReplaceConsecutive for '%s'", test.raw) + } +} + +func TestDecodeBinary(t *testing.T) { + type decodeTest struct { + b []byte + s string + } + + tests := []decodeTest{ + {[]byte("123456ABCDEF"), "313233343536414243444546"}, + {[]byte{0x12, 0x34, 0x56}, "123456"}, + } + for _, test := range tests { + a := decodeBinary(test.b) + assert.Equalf(t, test.s, a, "Incorrect decoded binary string for %v", test.b) + } +} + +func BenchmarkDecodeBinary(b *testing.B) { + b.ReportAllocs() + bytes := make([]byte, 10000) + for i := 0; i < 10000; i++ { + bytes[i] = byte(i % 0xff) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s := decodeBinary(bytes) + if len(s) != 20000 { + b.Fatalf("Incorrect length of returned string. Should be 20k, was %d", len(s)) + } + } + +} diff --git a/pkg/sqlcmd/format_windows.go b/pkg/sqlcmd/format_windows.go index 9d7b6f32..7d22ae36 100644 --- a/pkg/sqlcmd/format_windows.go +++ b/pkg/sqlcmd/format_windows.go @@ -1,7 +1,7 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -// SqlcmdEol is the end-of-line marker for sqlcmd output -const SqlcmdEol = "\r\n" +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +// SqlcmdEol is the end-of-line marker for sqlcmd output +const SqlcmdEol = "\r\n" diff --git a/pkg/sqlcmd/parse.go b/pkg/sqlcmd/parse.go index e9192c78..f2549c7a 100644 --- a/pkg/sqlcmd/parse.go +++ b/pkg/sqlcmd/parse.go @@ -1,101 +1,101 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "strings" - "unicode" -) - -// grab grabs i from r, or returns 0 if i >= end. -func grab(r []rune, i, end int) rune { - if i < end { - return r[i] - } - return 0 -} - -// findNonSpace finds first non space rune in r, returning end if not found. -func findNonSpace(r []rune, i, end int) (int, bool) { - for ; i < end; i++ { - if !isSpaceOrControl(r[i]) { - return i, true - } - } - return i, false -} - -// isEmptyLine returns true when r is empty or composed of only whitespace. -func isEmptyLine(r []rune, i, end int) bool { - _, ok := findNonSpace(r, i, end) - return !ok -} - -// readMultilineComment finds the end of a multiline comment (ie, '*/'). -func readMultilineComment(r []rune, i, end int) (int, bool) { - i++ - for ; i < end; i++ { - if r[i-1] == '*' && r[i] == '/' { - return i, true - } - } - return end, false -} - -// readCommand reads to the next control character to find -// a command in the string. Command regexes constrain matches -// to the beginning of the string, and all commands consume -// an entire line. -func readCommand(c Commands, r []rune, i, end int) (*Command, []string, int) { - for ; i < end; i++ { - next := grab(r, i, end) - if next == 0 || unicode.IsControl(next) { - break - } - } - cmd, args := c.matchCommand(string(r[:i])) - return cmd, args, i -} - -// readVariableReference returns the index of the end of the variable reference or false if it's not a valid identifier -func readVariableReference(r []rune, i int, end int) (int, bool) { - for ; i < end; i++ { - if r[i] == ')' { - return i, true - } - if (r[i] >= 'a' && r[i] <= 'z') || (r[i] >= 'A' && r[i] <= 'Z') || (r[i] >= '0' && r[i] <= '9') || strings.ContainsRune(validVariableRunes, r[i]) { - continue - } - break - } - return 0, false -} - -func max64(a, b int64) int64 { - if a > b { - return a - } - return b -} - -// min returns the minimum of a, b. -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func min64(a, b int64) int64 { - if a < b { - return a - } - return b -} - -// isSpaceOrControl is a special test for either a space or a control (ie, \b) -// characters. -func isSpaceOrControl(r rune) bool { - return unicode.IsSpace(r) || unicode.IsControl(r) -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "strings" + "unicode" +) + +// grab grabs i from r, or returns 0 if i >= end. +func grab(r []rune, i, end int) rune { + if i < end { + return r[i] + } + return 0 +} + +// findNonSpace finds first non space rune in r, returning end if not found. +func findNonSpace(r []rune, i, end int) (int, bool) { + for ; i < end; i++ { + if !isSpaceOrControl(r[i]) { + return i, true + } + } + return i, false +} + +// isEmptyLine returns true when r is empty or composed of only whitespace. +func isEmptyLine(r []rune, i, end int) bool { + _, ok := findNonSpace(r, i, end) + return !ok +} + +// readMultilineComment finds the end of a multiline comment (ie, '*/'). +func readMultilineComment(r []rune, i, end int) (int, bool) { + i++ + for ; i < end; i++ { + if r[i-1] == '*' && r[i] == '/' { + return i, true + } + } + return end, false +} + +// readCommand reads to the next control character to find +// a command in the string. Command regexes constrain matches +// to the beginning of the string, and all commands consume +// an entire line. +func readCommand(c Commands, r []rune, i, end int) (*Command, []string, int) { + for ; i < end; i++ { + next := grab(r, i, end) + if next == 0 || unicode.IsControl(next) { + break + } + } + cmd, args := c.matchCommand(string(r[:i])) + return cmd, args, i +} + +// readVariableReference returns the index of the end of the variable reference or false if it's not a valid identifier +func readVariableReference(r []rune, i int, end int) (int, bool) { + for ; i < end; i++ { + if r[i] == ')' { + return i, true + } + if (r[i] >= 'a' && r[i] <= 'z') || (r[i] >= 'A' && r[i] <= 'Z') || (r[i] >= '0' && r[i] <= '9') || strings.ContainsRune(validVariableRunes, r[i]) { + continue + } + break + } + return 0, false +} + +func max64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +// min returns the minimum of a, b. +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +// isSpaceOrControl is a special test for either a space or a control (ie, \b) +// characters. +func isSpaceOrControl(r rune) bool { + return unicode.IsSpace(r) || unicode.IsControl(r) +} diff --git a/pkg/sqlcmd/parse_test.go b/pkg/sqlcmd/parse_test.go index 143d5f85..9b809c52 100644 --- a/pkg/sqlcmd/parse_test.go +++ b/pkg/sqlcmd/parse_test.go @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index e2e1f60b..81b66633 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -1,592 +1,592 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "bytes" - "database/sql" - "fmt" - "io" - "os" - "os/user" - "strings" - "testing" - - "github.com/microsoft/go-mssqldb/azuread" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" -) - -const oneRowAffected = "(1 row affected)" - -func TestConnectionStringFromSqlCmd(t *testing.T) { - type connectionStringTest struct { - settings *ConnectSettings - connectionString string - } - - pwd := uuid.New().String() - - commands := []connectionStringTest{ - - {&ConnectSettings{}, "sqlserver://."}, - { - &ConnectSettings{TrustServerCertificate: true, WorkstationName: "mystation", Database: "somedatabase"}, - "sqlserver://.?database=somedatabase&trustservercertificate=true&workstation+id=mystation", - }, - { - &ConnectSettings{WorkstationName: "mystation", Encrypt: "false", Database: "somedatabase"}, - "sqlserver://.?database=somedatabase&encrypt=false&workstation+id=mystation", - }, - { - &ConnectSettings{TrustServerCertificate: true, Password: pwd, ServerName: `someserver\instance`, Database: "somedatabase", UserName: "someuser"}, - fmt.Sprintf("sqlserver://someuser:%s@someserver/instance?database=somedatabase&trustservercertificate=true", pwd), - }, - { - &ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true, Password: pwd, ServerName: `tcp:someserver,1045`, UserName: "someuser"}, - "sqlserver://someserver:1045?trustservercertificate=true", - }, - { - &ConnectSettings{ServerName: `tcp:someserver,1045`}, - "sqlserver://someserver:1045", - }, - { - &ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd}, - fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd), - }, - } - - for i, test := range commands { - - connectionString, err := test.settings.ConnectionString() - if assert.NoError(t, err, "Unexpected error from [%d] %+v", i, test.settings) { - assert.Equal(t, test.connectionString, connectionString, "Wrong connection string from [%d]: %+v", i, test.settings) - } - } -} - -/* The following tests require a working SQL instance and rely on SqlCmd environment variables -to manage the initial connection string. The default connection when no environment variables are -set will be to localhost using Windows auth. - -*/ -func TestSqlCmdConnectDb(t *testing.T) { - v := InitializeVariables(true) - s := &Sqlcmd{vars: v} - s.Connect = newConnect(t) - err := s.ConnectDb(nil, false) - if assert.NoError(t, err, "ConnectDb should succeed") { - sqlcmduser := os.Getenv(SQLCMDUSER) - if sqlcmduser == "" { - u, _ := user.Current() - sqlcmduser = u.Username - } - assert.Equal(t, sqlcmduser, s.vars.SQLCmdUser(), "SQLCMDUSER variable should match connected user") - } -} - -func ConnectDb(t testing.TB) (*sql.Conn, error) { - v := InitializeVariables(true) - s := &Sqlcmd{vars: v} - s.Connect = newConnect(t) - err := s.ConnectDb(nil, false) - return s.db, err -} - -func TestSqlCmdQueryAndExit(t *testing.T) { - s, file := setupSqlcmdWithFileOutput(t) - defer os.Remove(file.Name()) - s.Query = "select $(X" - err := s.Run(true, false) - if assert.NoError(t, err, "s.Run(once = true)") { - s.SetOutput(nil) - bytes, err := os.ReadFile(file.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1"+SqlcmdEol, string(bytes), "Incorrect output from Run") - } - } -} - -// Simulate :r command -func TestIncludeFileNoExecutions(t *testing.T) { - s, file := setupSqlcmdWithFileOutput(t) - defer os.Remove(file.Name()) - dataPath := "testdata" + string(os.PathSeparator) - err := s.IncludeFile(dataPath+"singlebatchnogo.sql", false) - s.SetOutput(nil) - if assert.NoError(t, err, "IncludeFile singlebatchnogo.sql false") { - assert.Equal(t, "-", s.batch.State(), "s.batch.State() after IncludeFile singlebatchnogo.sql false") - assert.Equal(t, "select 100 as num"+SqlcmdEol+"select 'string' as title", s.batch.String(), "s.batch.String() after IncludeFile singlebatchnogo.sql false") - bytes, err := os.ReadFile(file.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "", string(bytes), "Incorrect output from Run") - } - file, err = os.CreateTemp("", "sqlcmdout") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(file.Name()) - s.SetOutput(file) - // The second file has a go so it will execute all statements before it - err = s.IncludeFile(dataPath+"twobatchnoendinggo.sql", false) - if assert.NoError(t, err, "IncludeFile twobatchnoendinggo.sql false") { - assert.Equal(t, "-", s.batch.State(), "s.batch.State() after IncludeFile twobatchnoendinggo.sql false") - assert.Equal(t, "select 'string' as title", s.batch.String(), "s.batch.String() after IncludeFile twobatchnoendinggo.sql false") - s.SetOutput(nil) - bytes, err := os.ReadFile(file.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run") - } - } - } -} - -// Simulate -i command line usage -func TestIncludeFileProcessAll(t *testing.T) { - s, file := setupSqlcmdWithFileOutput(t) - defer os.Remove(file.Name()) - dataPath := "testdata" + string(os.PathSeparator) - err := s.IncludeFile(dataPath+"twobatchwithgo.sql", true) - s.SetOutput(nil) - if assert.NoError(t, err, "IncludeFile twobatchwithgo.sql true") { - assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile twobatchwithgo.sql true") - assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile twobatchwithgo.sql true") - bytes, err := os.ReadFile(file.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run") - } - file, err = os.CreateTemp("", "sqlcmdout") - defer os.Remove(file.Name()) - assert.NoError(t, err, "os.CreateTemp") - s.SetOutput(file) - err = s.IncludeFile(dataPath+"twobatchnoendinggo.sql", true) - if assert.NoError(t, err, "IncludeFile twobatchnoendinggo.sql true") { - assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile twobatchnoendinggo.sql true") - assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile twobatchnoendinggo.sql true") - bytes, err := os.ReadFile(file.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run") - } - } - } -} - -func TestIncludeFileWithVariables(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - dataPath := "testdata" + string(os.PathSeparator) - err := s.IncludeFile(dataPath+"variablesnogo.sql", true) - if assert.NoError(t, err, "IncludeFile variablesnogo.sql true") { - assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile variablesnogo.sql true") - assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile variablesnogo.sql true") - s.SetOutput(nil) - o := buf.buf.String() - assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol, o) - } -} - -func TestGetRunnableQuery(t *testing.T) { - v := InitializeVariables(false) - v.Set("var1", "v1") - v.Set("var2", "variable2") - - type test struct { - raw string - q string - } - tests := []test{ - {"$(var1)", "v1"}, - {"$ (var2)", "$ (var2)"}, - {"select '$(VAR1) $(VAR2)' as c", "select 'v1 variable2' as c"}, - {" $(VAR1) ' $(VAR2) ' as $(VAR1)", " v1 ' variable2 ' as v1"}, - } - s := New(nil, "", v) - for _, test := range tests { - s.batch.Reset([]rune(test.raw)) - _, _, _ = s.batch.Next() - s.Connect.DisableVariableSubstitution = false - t.Log(test.raw) - r := s.getRunnableQuery(test.raw) - assert.Equalf(t, test.q, r, `runnableQuery for "%s"`, test.raw) - s.Connect.DisableVariableSubstitution = true - r = s.getRunnableQuery(test.raw) - assert.Equalf(t, test.raw, r, `runnableQuery without variable subs for "%s"`, test.raw) - } -} - -func TestExitInitialQuery(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - _ = s.vars.Setvar("var1", "1200") - s.Query = "EXIT(SELECT '$(var1)', 2100)" - err := s.Run(true, false) - if assert.NoError(t, err, "s.Run(once = true)") { - s.SetOutput(nil) - o := buf.buf.String() - assert.Equal(t, "1200 2100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, o, "Output") - assert.Equal(t, 1200, s.Exitcode, "ExitCode") - } - -} - -func TestExitCodeSetOnError(t *testing.T) { - s, _ := setupSqlCmdWithMemoryOutput(t) - s.Connect.ErrorSeverityLevel = 12 - retcode, err := s.runQuery("RAISERROR (N'Testing!' , 11, 1)") - assert.NoError(t, err, "!ExitOnError 11") - assert.Equal(t, -101, retcode, "Raiserror below ErrorSeverityLevel") - retcode, err = s.runQuery("RAISERROR (N'Testing!' , 14, 1)") - assert.NoError(t, err, "!ExitOnError 14") - assert.Equal(t, 14, retcode, "Raiserror above ErrorSeverityLevel") - s.Connect.ExitOnError = true - retcode, err = s.runQuery("RAISERROR (N'Testing!' , 11, 1)") - assert.NoError(t, err, "ExitOnError and Raiserror below ErrorSeverityLevel") - assert.Equal(t, -101, retcode, "Raiserror below ErrorSeverityLevel") - retcode, err = s.runQuery("RAISERROR (N'Testing!' , 14, 1)") - assert.ErrorIs(t, err, ErrExitRequested, "ExitOnError and Raiserror above ErrorSeverityLevel") - assert.Equal(t, 14, retcode, "ExitOnError and Raiserror above ErrorSeverityLevel") - s.Connect.ErrorSeverityLevel = 0 - retcode, err = s.runQuery("RAISERROR (N'Testing!' , 11, 1)") - assert.ErrorIs(t, err, ErrExitRequested, "ExitOnError and ErrorSeverityLevel = 0, Raiserror above 10") - assert.Equal(t, 1, retcode, "ExitOnError and ErrorSeverityLevel = 0, Raiserror above 10") - retcode, err = s.runQuery("RAISERROR (N'Testing!' , 5, 1)") - assert.NoError(t, err, "ExitOnError and ErrorSeverityLevel = 0, Raiserror below 10") - assert.Equal(t, -101, retcode, "ExitOnError and ErrorSeverityLevel = 0, Raiserror below 10") - retcode, err = s.runQuery("RAISERROR (15001, 10, 127)") - assert.ErrorIs(t, err, ErrExitRequested, "RAISERROR with state 127") - assert.Equal(t, 15001, retcode, "RAISERROR (15001, 10, 127)") -} - -func TestSqlCmdExitOnError(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.Connect.ExitOnError = true - err := runSqlCmd(t, s, []string{"select 1", "GO", ":setvar", "select 2", "GO"}) - o := buf.buf.String() - assert.EqualError(t, err, "Sqlcmd: Error: Syntax error at line 3 near command ':SETVAR'.", "Run should return an error") - assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"Sqlcmd: Error: Syntax error at line 3 near command ':SETVAR'."+SqlcmdEol, o, "Only first select should run") - assert.Equal(t, 1, s.Exitcode, "s.ExitCode for a syntax error") - - s, buf = setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.Connect.ExitOnError = true - s.Connect.ErrorSeverityLevel = 15 - s.vars.Set(SQLCMDERRORLEVEL, "14") - err = runSqlCmd(t, s, []string{"raiserror(N'13', 13, 1)", "GO", "raiserror(N'14', 14, 1)", "GO", "raiserror(N'15', 15, 1)", "GO", "SELECT 'nope'", "GO"}) - o = buf.buf.String() - assert.NotContains(t, o, "Level 13", "Level 13 should be filtered from the output") - assert.NotContains(t, o, "nope", "Last select should not be run") - assert.Contains(t, o, "Level 14", "Level 14 should be in the output") - assert.Contains(t, o, "Level 15", "Level 15 should be in the output") - assert.Equal(t, 15, s.Exitcode, "s.ExitCode for a syntax error") - assert.NoError(t, err, "Run should not return an error for a SQL error") -} - -func TestSqlCmdSetErrorLevel(t *testing.T) { - s, _ := setupSqlCmdWithMemoryOutput(t) - s.Connect.ErrorSeverityLevel = 15 - err := runSqlCmd(t, s, []string{"select bad as bad", "GO", "select 1", "GO"}) - assert.NoError(t, err, "runSqlCmd should have no error") - assert.Equal(t, 16, s.Exitcode, "Select error should be the exit code") -} - -type testConsole struct { - PromptText string - OnPasswordPrompt func(prompt string) ([]byte, error) - OnReadLine func() (string, error) -} - -func (tc *testConsole) Readline() (string, error) { - return tc.OnReadLine() -} - -func (tc *testConsole) ReadPassword(prompt string) ([]byte, error) { - return tc.OnPasswordPrompt(prompt) -} - -func (tc *testConsole) SetPrompt(s string) { - tc.PromptText = s -} - -func (tc *testConsole) Close() { - -} - -func TestPromptForPasswordNegative(t *testing.T) { - prompted := false - console := &testConsole{ - OnPasswordPrompt: func(prompt string) ([]byte, error) { - assert.Equal(t, "Password:", prompt, "Incorrect password prompt") - prompted = true - return []byte{}, nil - }, - OnReadLine: func() (string, error) { - assert.Fail(t, "ReadLine should not be called") - return "", nil - }, - } - v := InitializeVariables(true) - s := New(console, "", v) - s.Connect.UserName = "someuser" - err := s.ConnectDb(nil, false) - assert.True(t, prompted, "Password prompt not shown for SQL auth") - assert.Error(t, err, "ConnectDb") - prompted = false - s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword - err = s.ConnectDb(nil, false) - assert.True(t, prompted, "Password prompt not shown for AD Password auth") - assert.Error(t, err, "ConnectDb") - prompted = false -} - -func TestPromptForPasswordPositive(t *testing.T) { - prompted := false - c := newConnect(t) - if c.Password == "" { - // See if azure variables are set for activedirectoryserviceprincipal - c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID") - c.Password = os.Getenv("AZURE_CLIENT_SECRET") - c.AuthenticationMethod = azuread.ActiveDirectoryServicePrincipal - if c.Password == "" { - t.Skip("No password available") - } - } - password := c.Password - c.Password = "" - console := &testConsole{ - OnPasswordPrompt: func(prompt string) ([]byte, error) { - assert.Equal(t, "Password:", prompt, "Incorrect password prompt") - prompted = true - return []byte(password), nil - }, - OnReadLine: func() (string, error) { - assert.Fail(t, "ReadLine should not be called") - return "", nil - }, - } - v := InitializeVariables(true) - s := New(console, "", v) - // attempt without password prompt - err := s.ConnectDb(c, true) - assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password") - assert.Error(t, err, "ConnectDb with nopw==true and no password provided") - err = s.ConnectDb(c, false) - assert.True(t, prompted, "ConnectDb with !nopw should prompt for password") - assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt") - if s.Connect.Password != password { - t.Fatal(t, err, "Password not stored in the connection") - } -} - -func TestVerticalLayoutNoColumns(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.vars.Set(SQLCMDFORMAT, "vert") - _, err := s.runQuery("SELECT 100 as 'column1', 2000 as 'col2', 300") - assert.NoError(t, err, "runQuery failed") - assert.Equal(t, - "100"+SqlcmdEol+"2000"+SqlcmdEol+"300"+SqlcmdEol+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, - buf.buf.String(), "Query without column headers") -} - -func TestSelectGuidColumn(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - _, err := s.runQuery("select convert(uniqueidentifier, N'3ddba21e-ff0f-4d24-90b4-f355864d7865')") - assert.NoError(t, err, "runQuery failed") - assert.Equal(t, "3ddba21e-ff0f-4d24-90b4-f355864d7865"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "select a uniqueidentifier should work") -} - -func TestSelectNullGuidColumn(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - _, err := s.runQuery("select convert(uniqueidentifier,null)") - assert.NoError(t, err, "runQuery failed") - assert.Equal(t, "NULL"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "select a null uniqueidentifier should work") -} - -func TestVerticalLayoutWithColumns(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.vars.Set(SQLCMDFORMAT, "vert") - s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") - _, err := s.runQuery("SELECT 100 as 'column1', 2000 as 'col2', 300") - assert.NoError(t, err, "runQuery failed") - assert.Equal(t, - "column1 100"+SqlcmdEol+"col2 2000"+SqlcmdEol+" 300"+SqlcmdEol+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, - buf.buf.String(), "Query without column headers") - -} - -func TestSqlCmdDefersToPrintError(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - s.PrintError = func(msg string, severity uint8) bool { - return severity > 10 - } - err := runSqlCmd(t, s, []string{"PRINT 'this has severity 10'", "RAISERROR (N'Testing!' , 11, 1)", "GO"}) - if assert.NoError(t, err, "runSqlCmd failed") { - assert.Equal(t, "this has severity 10"+SqlcmdEol, buf.buf.String(), "Errors should be filtered by s.PrintError") - } -} - -func TestSqlCmdMaintainsConnectionBetweenBatches(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - err := runSqlCmd(t, s, []string{"CREATE TABLE #tmp1 (col1 int)", "insert into #tmp1 values (1)", "GO", "select * from #tmp1", "drop table #tmp1", "GO"}) - if assert.NoError(t, err, "runSqlCmd failed") { - assert.Equal(t, oneRowAffected+SqlcmdEol+"1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "Sqlcmd uses the same connection for all queries") - } -} - -func TestDateTimeFormats(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - defer buf.Close() - err := s.IncludeFile(`testdata/selectdates.sql`, true) - if assert.NoError(t, err, "selectdates.sql") { - assert.Equal(t, - `2022-03-05 14:01:02.000 2021-01-02 11:06:02.2000 2021-05-05 00:00:00.000000 +00:00 2019-01-11 13:00:00 14:01:02.0000000 2011-02-03`+SqlcmdEol+SqlcmdEol, - buf.buf.String(), - "Unexpected date format output") - - } -} - -func TestQueryServerPropertyReturnsColumnName(t *testing.T) { - s, buf := setupSqlCmdWithMemoryOutput(t) - s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "100") - defer buf.Close() - err := runSqlCmd(t, s, []string{"select SERVERPROPERTY('EngineEdition') AS DatabaseEngineEdition", "GO"}) - if assert.NoError(t, err, "select should succeed") { - assert.Contains(t, buf.buf.String(), "DatabaseEngineEdition", "Column name missing from output") - } -} - -func TestSqlCmdOutputAndError(t *testing.T) { - s, outfile, errfile := setupSqlcmdWithFileErrorOutput(t) - defer os.Remove(outfile.Name()) - defer os.Remove(errfile.Name()) - s.Query = "select $(X" - err := s.Run(true, false) - if assert.NoError(t, err, "s.Run(once = true)") { - bytes, err := os.ReadFile(errfile.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1"+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution") - } - } - s.Query = "select '1'" - err = s.Run(true, false) - if assert.NoError(t, err, "s.Run(once = true)") { - bytes, err := os.ReadFile(outfile.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for query execution") - } - } - - s, outfile, errfile = setupSqlcmdWithFileErrorOutput(t) - defer os.Remove(outfile.Name()) - defer os.Remove(errfile.Name()) - dataPath := "testdata" + string(os.PathSeparator) - err = s.IncludeFile(dataPath+"testerrorredirection.sql", false) - if assert.NoError(t, err, "IncludeFile testerrorredirection.sql false") { - bytes, err := os.ReadFile(outfile.Name()) - if assert.NoError(t, err, "os.ReadFile outfile") { - assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for sql file execution in outfile") - } - bytes, err = os.ReadFile(errfile.Name()) - if assert.NoError(t, err, "os.ReadFile errfile") { - assert.Equal(t, "Sqlcmd: Error: Syntax error at line 3"+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile") - } - } -} - -// runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input -func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error { - t.Helper() - i := 0 - s.batch.read = func() (string, error) { - if i < len(lines) { - index := i - i++ - return lines[index], nil - } - return "", io.EOF - } - return s.Run(false, false) -} - -func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { - t.Helper() - v := InitializeVariables(true) - v.Set(SQLCMDMAXVARTYPEWIDTH, "0") - s := New(nil, "", v) - s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true) - buf := &memoryBuffer{buf: new(bytes.Buffer)} - s.SetOutput(buf) - err := s.ConnectDb(nil, true) - assert.NoError(t, err, "s.ConnectDB") - return s, buf -} - -func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { - t.Helper() - v := InitializeVariables(true) - v.Set(SQLCMDMAXVARTYPEWIDTH, "0") - s := New(nil, "", v) - s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true) - file, err := os.CreateTemp("", "sqlcmdout") - assert.NoError(t, err, "os.CreateTemp") - s.SetOutput(file) - err = s.ConnectDb(nil, true) - if err != nil { - os.Remove(file.Name()) - } - assert.NoError(t, err, "s.ConnectDB") - return s, file -} - -func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) { - t.Helper() - v := InitializeVariables(true) - v.Set(SQLCMDMAXVARTYPEWIDTH, "0") - s := New(nil, "", v) - s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true) - outfile, err := os.CreateTemp("", "sqlcmdout") - assert.NoError(t, err, "os.CreateTemp") - errfile, err := os.CreateTemp("", "sqlcmderr") - assert.NoError(t, err, "os.CreateTemp") - s.SetOutput(outfile) - s.SetError(errfile) - err = s.ConnectDb(nil, true) - if err != nil { - os.Remove(outfile.Name()) - os.Remove(errfile.Name()) - } - assert.NoError(t, err, "s.ConnectDB") - return s, outfile, errfile -} - -// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set -func canTestAzureAuth() bool { - server := os.Getenv(SQLCMDSERVER) - userName := os.Getenv(SQLCMDUSER) - return strings.Contains(server, ".database.windows.net") && userName == "" -} - -func newConnect(t testing.TB) *ConnectSettings { - t.Helper() - connect := ConnectSettings{ - UserName: os.Getenv(SQLCMDUSER), - Database: os.Getenv(SQLCMDDBNAME), - ServerName: os.Getenv(SQLCMDSERVER), - Password: os.Getenv(SQLCMDPASSWORD), - } - if canTestAzureAuth() { - t.Log("Using ActiveDirectoryDefault") - connect.AuthenticationMethod = azuread.ActiveDirectoryDefault - } - return &connect -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "bytes" + "database/sql" + "fmt" + "io" + "os" + "os/user" + "strings" + "testing" + + "github.com/microsoft/go-mssqldb/azuread" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +const oneRowAffected = "(1 row affected)" + +func TestConnectionStringFromSqlCmd(t *testing.T) { + type connectionStringTest struct { + settings *ConnectSettings + connectionString string + } + + pwd := uuid.New().String() + + commands := []connectionStringTest{ + + {&ConnectSettings{}, "sqlserver://."}, + { + &ConnectSettings{TrustServerCertificate: true, WorkstationName: "mystation", Database: "somedatabase"}, + "sqlserver://.?database=somedatabase&trustservercertificate=true&workstation+id=mystation", + }, + { + &ConnectSettings{WorkstationName: "mystation", Encrypt: "false", Database: "somedatabase"}, + "sqlserver://.?database=somedatabase&encrypt=false&workstation+id=mystation", + }, + { + &ConnectSettings{TrustServerCertificate: true, Password: pwd, ServerName: `someserver\instance`, Database: "somedatabase", UserName: "someuser"}, + fmt.Sprintf("sqlserver://someuser:%s@someserver/instance?database=somedatabase&trustservercertificate=true", pwd), + }, + { + &ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true, Password: pwd, ServerName: `tcp:someserver,1045`, UserName: "someuser"}, + "sqlserver://someserver:1045?trustservercertificate=true", + }, + { + &ConnectSettings{ServerName: `tcp:someserver,1045`}, + "sqlserver://someserver:1045", + }, + { + &ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd}, + fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd), + }, + } + + for i, test := range commands { + + connectionString, err := test.settings.ConnectionString() + if assert.NoError(t, err, "Unexpected error from [%d] %+v", i, test.settings) { + assert.Equal(t, test.connectionString, connectionString, "Wrong connection string from [%d]: %+v", i, test.settings) + } + } +} + +/* The following tests require a working SQL instance and rely on SqlCmd environment variables +to manage the initial connection string. The default connection when no environment variables are +set will be to localhost using Windows auth. + +*/ +func TestSqlCmdConnectDb(t *testing.T) { + v := InitializeVariables(true) + s := &Sqlcmd{vars: v} + s.Connect = newConnect(t) + err := s.ConnectDb(nil, false) + if assert.NoError(t, err, "ConnectDb should succeed") { + sqlcmduser := os.Getenv(SQLCMDUSER) + if sqlcmduser == "" { + u, _ := user.Current() + sqlcmduser = u.Username + } + assert.Equal(t, sqlcmduser, s.vars.SQLCmdUser(), "SQLCMDUSER variable should match connected user") + } +} + +func ConnectDb(t testing.TB) (*sql.Conn, error) { + v := InitializeVariables(true) + s := &Sqlcmd{vars: v} + s.Connect = newConnect(t) + err := s.ConnectDb(nil, false) + return s.db, err +} + +func TestSqlCmdQueryAndExit(t *testing.T) { + s, file := setupSqlcmdWithFileOutput(t) + defer os.Remove(file.Name()) + s.Query = "select $(X" + err := s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + s.SetOutput(nil) + bytes, err := os.ReadFile(file.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1"+SqlcmdEol, string(bytes), "Incorrect output from Run") + } + } +} + +// Simulate :r command +func TestIncludeFileNoExecutions(t *testing.T) { + s, file := setupSqlcmdWithFileOutput(t) + defer os.Remove(file.Name()) + dataPath := "testdata" + string(os.PathSeparator) + err := s.IncludeFile(dataPath+"singlebatchnogo.sql", false) + s.SetOutput(nil) + if assert.NoError(t, err, "IncludeFile singlebatchnogo.sql false") { + assert.Equal(t, "-", s.batch.State(), "s.batch.State() after IncludeFile singlebatchnogo.sql false") + assert.Equal(t, "select 100 as num"+SqlcmdEol+"select 'string' as title", s.batch.String(), "s.batch.String() after IncludeFile singlebatchnogo.sql false") + bytes, err := os.ReadFile(file.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "", string(bytes), "Incorrect output from Run") + } + file, err = os.CreateTemp("", "sqlcmdout") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(file.Name()) + s.SetOutput(file) + // The second file has a go so it will execute all statements before it + err = s.IncludeFile(dataPath+"twobatchnoendinggo.sql", false) + if assert.NoError(t, err, "IncludeFile twobatchnoendinggo.sql false") { + assert.Equal(t, "-", s.batch.State(), "s.batch.State() after IncludeFile twobatchnoendinggo.sql false") + assert.Equal(t, "select 'string' as title", s.batch.String(), "s.batch.String() after IncludeFile twobatchnoendinggo.sql false") + s.SetOutput(nil) + bytes, err := os.ReadFile(file.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run") + } + } + } +} + +// Simulate -i command line usage +func TestIncludeFileProcessAll(t *testing.T) { + s, file := setupSqlcmdWithFileOutput(t) + defer os.Remove(file.Name()) + dataPath := "testdata" + string(os.PathSeparator) + err := s.IncludeFile(dataPath+"twobatchwithgo.sql", true) + s.SetOutput(nil) + if assert.NoError(t, err, "IncludeFile twobatchwithgo.sql true") { + assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile twobatchwithgo.sql true") + assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile twobatchwithgo.sql true") + bytes, err := os.ReadFile(file.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run") + } + file, err = os.CreateTemp("", "sqlcmdout") + defer os.Remove(file.Name()) + assert.NoError(t, err, "os.CreateTemp") + s.SetOutput(file) + err = s.IncludeFile(dataPath+"twobatchnoendinggo.sql", true) + if assert.NoError(t, err, "IncludeFile twobatchnoendinggo.sql true") { + assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile twobatchnoendinggo.sql true") + assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile twobatchnoendinggo.sql true") + bytes, err := os.ReadFile(file.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run") + } + } + } +} + +func TestIncludeFileWithVariables(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + dataPath := "testdata" + string(os.PathSeparator) + err := s.IncludeFile(dataPath+"variablesnogo.sql", true) + if assert.NoError(t, err, "IncludeFile variablesnogo.sql true") { + assert.Equal(t, "=", s.batch.State(), "s.batch.State() after IncludeFile variablesnogo.sql true") + assert.Equal(t, "", s.batch.String(), "s.batch.String() after IncludeFile variablesnogo.sql true") + s.SetOutput(nil) + o := buf.buf.String() + assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol, o) + } +} + +func TestGetRunnableQuery(t *testing.T) { + v := InitializeVariables(false) + v.Set("var1", "v1") + v.Set("var2", "variable2") + + type test struct { + raw string + q string + } + tests := []test{ + {"$(var1)", "v1"}, + {"$ (var2)", "$ (var2)"}, + {"select '$(VAR1) $(VAR2)' as c", "select 'v1 variable2' as c"}, + {" $(VAR1) ' $(VAR2) ' as $(VAR1)", " v1 ' variable2 ' as v1"}, + } + s := New(nil, "", v) + for _, test := range tests { + s.batch.Reset([]rune(test.raw)) + _, _, _ = s.batch.Next() + s.Connect.DisableVariableSubstitution = false + t.Log(test.raw) + r := s.getRunnableQuery(test.raw) + assert.Equalf(t, test.q, r, `runnableQuery for "%s"`, test.raw) + s.Connect.DisableVariableSubstitution = true + r = s.getRunnableQuery(test.raw) + assert.Equalf(t, test.raw, r, `runnableQuery without variable subs for "%s"`, test.raw) + } +} + +func TestExitInitialQuery(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + _ = s.vars.Setvar("var1", "1200") + s.Query = "EXIT(SELECT '$(var1)', 2100)" + err := s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + s.SetOutput(nil) + o := buf.buf.String() + assert.Equal(t, "1200 2100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, o, "Output") + assert.Equal(t, 1200, s.Exitcode, "ExitCode") + } + +} + +func TestExitCodeSetOnError(t *testing.T) { + s, _ := setupSqlCmdWithMemoryOutput(t) + s.Connect.ErrorSeverityLevel = 12 + retcode, err := s.runQuery("RAISERROR (N'Testing!' , 11, 1)") + assert.NoError(t, err, "!ExitOnError 11") + assert.Equal(t, -101, retcode, "Raiserror below ErrorSeverityLevel") + retcode, err = s.runQuery("RAISERROR (N'Testing!' , 14, 1)") + assert.NoError(t, err, "!ExitOnError 14") + assert.Equal(t, 14, retcode, "Raiserror above ErrorSeverityLevel") + s.Connect.ExitOnError = true + retcode, err = s.runQuery("RAISERROR (N'Testing!' , 11, 1)") + assert.NoError(t, err, "ExitOnError and Raiserror below ErrorSeverityLevel") + assert.Equal(t, -101, retcode, "Raiserror below ErrorSeverityLevel") + retcode, err = s.runQuery("RAISERROR (N'Testing!' , 14, 1)") + assert.ErrorIs(t, err, ErrExitRequested, "ExitOnError and Raiserror above ErrorSeverityLevel") + assert.Equal(t, 14, retcode, "ExitOnError and Raiserror above ErrorSeverityLevel") + s.Connect.ErrorSeverityLevel = 0 + retcode, err = s.runQuery("RAISERROR (N'Testing!' , 11, 1)") + assert.ErrorIs(t, err, ErrExitRequested, "ExitOnError and ErrorSeverityLevel = 0, Raiserror above 10") + assert.Equal(t, 1, retcode, "ExitOnError and ErrorSeverityLevel = 0, Raiserror above 10") + retcode, err = s.runQuery("RAISERROR (N'Testing!' , 5, 1)") + assert.NoError(t, err, "ExitOnError and ErrorSeverityLevel = 0, Raiserror below 10") + assert.Equal(t, -101, retcode, "ExitOnError and ErrorSeverityLevel = 0, Raiserror below 10") + retcode, err = s.runQuery("RAISERROR (15001, 10, 127)") + assert.ErrorIs(t, err, ErrExitRequested, "RAISERROR with state 127") + assert.Equal(t, 15001, retcode, "RAISERROR (15001, 10, 127)") +} + +func TestSqlCmdExitOnError(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.Connect.ExitOnError = true + err := runSqlCmd(t, s, []string{"select 1", "GO", ":setvar", "select 2", "GO"}) + o := buf.buf.String() + assert.EqualError(t, err, "Sqlcmd: Error: Syntax error at line 3 near command ':SETVAR'.", "Run should return an error") + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"Sqlcmd: Error: Syntax error at line 3 near command ':SETVAR'."+SqlcmdEol, o, "Only first select should run") + assert.Equal(t, 1, s.Exitcode, "s.ExitCode for a syntax error") + + s, buf = setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.Connect.ExitOnError = true + s.Connect.ErrorSeverityLevel = 15 + s.vars.Set(SQLCMDERRORLEVEL, "14") + err = runSqlCmd(t, s, []string{"raiserror(N'13', 13, 1)", "GO", "raiserror(N'14', 14, 1)", "GO", "raiserror(N'15', 15, 1)", "GO", "SELECT 'nope'", "GO"}) + o = buf.buf.String() + assert.NotContains(t, o, "Level 13", "Level 13 should be filtered from the output") + assert.NotContains(t, o, "nope", "Last select should not be run") + assert.Contains(t, o, "Level 14", "Level 14 should be in the output") + assert.Contains(t, o, "Level 15", "Level 15 should be in the output") + assert.Equal(t, 15, s.Exitcode, "s.ExitCode for a syntax error") + assert.NoError(t, err, "Run should not return an error for a SQL error") +} + +func TestSqlCmdSetErrorLevel(t *testing.T) { + s, _ := setupSqlCmdWithMemoryOutput(t) + s.Connect.ErrorSeverityLevel = 15 + err := runSqlCmd(t, s, []string{"select bad as bad", "GO", "select 1", "GO"}) + assert.NoError(t, err, "runSqlCmd should have no error") + assert.Equal(t, 16, s.Exitcode, "Select error should be the exit code") +} + +type testConsole struct { + PromptText string + OnPasswordPrompt func(prompt string) ([]byte, error) + OnReadLine func() (string, error) +} + +func (tc *testConsole) Readline() (string, error) { + return tc.OnReadLine() +} + +func (tc *testConsole) ReadPassword(prompt string) ([]byte, error) { + return tc.OnPasswordPrompt(prompt) +} + +func (tc *testConsole) SetPrompt(s string) { + tc.PromptText = s +} + +func (tc *testConsole) Close() { + +} + +func TestPromptForPasswordNegative(t *testing.T) { + prompted := false + console := &testConsole{ + OnPasswordPrompt: func(prompt string) ([]byte, error) { + assert.Equal(t, "Password:", prompt, "Incorrect password prompt") + prompted = true + return []byte{}, nil + }, + OnReadLine: func() (string, error) { + assert.Fail(t, "ReadLine should not be called") + return "", nil + }, + } + v := InitializeVariables(true) + s := New(console, "", v) + s.Connect.UserName = "someuser" + err := s.ConnectDb(nil, false) + assert.True(t, prompted, "Password prompt not shown for SQL auth") + assert.Error(t, err, "ConnectDb") + prompted = false + s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword + err = s.ConnectDb(nil, false) + assert.True(t, prompted, "Password prompt not shown for AD Password auth") + assert.Error(t, err, "ConnectDb") + prompted = false +} + +func TestPromptForPasswordPositive(t *testing.T) { + prompted := false + c := newConnect(t) + if c.Password == "" { + // See if azure variables are set for activedirectoryserviceprincipal + c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID") + c.Password = os.Getenv("AZURE_CLIENT_SECRET") + c.AuthenticationMethod = azuread.ActiveDirectoryServicePrincipal + if c.Password == "" { + t.Skip("No password available") + } + } + password := c.Password + c.Password = "" + console := &testConsole{ + OnPasswordPrompt: func(prompt string) ([]byte, error) { + assert.Equal(t, "Password:", prompt, "Incorrect password prompt") + prompted = true + return []byte(password), nil + }, + OnReadLine: func() (string, error) { + assert.Fail(t, "ReadLine should not be called") + return "", nil + }, + } + v := InitializeVariables(true) + s := New(console, "", v) + // attempt without password prompt + err := s.ConnectDb(c, true) + assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password") + assert.Error(t, err, "ConnectDb with nopw==true and no password provided") + err = s.ConnectDb(c, false) + assert.True(t, prompted, "ConnectDb with !nopw should prompt for password") + assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt") + if s.Connect.Password != password { + t.Fatal(t, err, "Password not stored in the connection") + } +} + +func TestVerticalLayoutNoColumns(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.vars.Set(SQLCMDFORMAT, "vert") + _, err := s.runQuery("SELECT 100 as 'column1', 2000 as 'col2', 300") + assert.NoError(t, err, "runQuery failed") + assert.Equal(t, + "100"+SqlcmdEol+"2000"+SqlcmdEol+"300"+SqlcmdEol+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, + buf.buf.String(), "Query without column headers") +} + +func TestSelectGuidColumn(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + _, err := s.runQuery("select convert(uniqueidentifier, N'3ddba21e-ff0f-4d24-90b4-f355864d7865')") + assert.NoError(t, err, "runQuery failed") + assert.Equal(t, "3ddba21e-ff0f-4d24-90b4-f355864d7865"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "select a uniqueidentifier should work") +} + +func TestSelectNullGuidColumn(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + _, err := s.runQuery("select convert(uniqueidentifier,null)") + assert.NoError(t, err, "runQuery failed") + assert.Equal(t, "NULL"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "select a null uniqueidentifier should work") +} + +func TestVerticalLayoutWithColumns(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.vars.Set(SQLCMDFORMAT, "vert") + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") + _, err := s.runQuery("SELECT 100 as 'column1', 2000 as 'col2', 300") + assert.NoError(t, err, "runQuery failed") + assert.Equal(t, + "column1 100"+SqlcmdEol+"col2 2000"+SqlcmdEol+" 300"+SqlcmdEol+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, + buf.buf.String(), "Query without column headers") + +} + +func TestSqlCmdDefersToPrintError(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.PrintError = func(msg string, severity uint8) bool { + return severity > 10 + } + err := runSqlCmd(t, s, []string{"PRINT 'this has severity 10'", "RAISERROR (N'Testing!' , 11, 1)", "GO"}) + if assert.NoError(t, err, "runSqlCmd failed") { + assert.Equal(t, "this has severity 10"+SqlcmdEol, buf.buf.String(), "Errors should be filtered by s.PrintError") + } +} + +func TestSqlCmdMaintainsConnectionBetweenBatches(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + err := runSqlCmd(t, s, []string{"CREATE TABLE #tmp1 (col1 int)", "insert into #tmp1 values (1)", "GO", "select * from #tmp1", "drop table #tmp1", "GO"}) + if assert.NoError(t, err, "runSqlCmd failed") { + assert.Equal(t, oneRowAffected+SqlcmdEol+"1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "Sqlcmd uses the same connection for all queries") + } +} + +func TestDateTimeFormats(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + err := s.IncludeFile(`testdata/selectdates.sql`, true) + if assert.NoError(t, err, "selectdates.sql") { + assert.Equal(t, + `2022-03-05 14:01:02.000 2021-01-02 11:06:02.2000 2021-05-05 00:00:00.000000 +00:00 2019-01-11 13:00:00 14:01:02.0000000 2011-02-03`+SqlcmdEol+SqlcmdEol, + buf.buf.String(), + "Unexpected date format output") + + } +} + +func TestQueryServerPropertyReturnsColumnName(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "100") + defer buf.Close() + err := runSqlCmd(t, s, []string{"select SERVERPROPERTY('EngineEdition') AS DatabaseEngineEdition", "GO"}) + if assert.NoError(t, err, "select should succeed") { + assert.Contains(t, buf.buf.String(), "DatabaseEngineEdition", "Column name missing from output") + } +} + +func TestSqlCmdOutputAndError(t *testing.T) { + s, outfile, errfile := setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + s.Query = "select $(X" + err := s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1"+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution") + } + } + s.Query = "select '1'" + err = s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for query execution") + } + } + + s, outfile, errfile = setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + dataPath := "testdata" + string(os.PathSeparator) + err = s.IncludeFile(dataPath+"testerrorredirection.sql", false) + if assert.NoError(t, err, "IncludeFile testerrorredirection.sql false") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile outfile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for sql file execution in outfile") + } + bytes, err = os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile errfile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 3"+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile") + } + } +} + +// runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input +func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error { + t.Helper() + i := 0 + s.batch.read = func() (string, error) { + if i < len(lines) { + index := i + i++ + return lines[index], nil + } + return "", io.EOF + } + return s.Run(false, false) +} + +func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { + t.Helper() + v := InitializeVariables(true) + v.Set(SQLCMDMAXVARTYPEWIDTH, "0") + s := New(nil, "", v) + s.Connect = newConnect(t) + s.Format = NewSQLCmdDefaultFormatter(true) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + err := s.ConnectDb(nil, true) + assert.NoError(t, err, "s.ConnectDB") + return s, buf +} + +func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { + t.Helper() + v := InitializeVariables(true) + v.Set(SQLCMDMAXVARTYPEWIDTH, "0") + s := New(nil, "", v) + s.Connect = newConnect(t) + s.Format = NewSQLCmdDefaultFormatter(true) + file, err := os.CreateTemp("", "sqlcmdout") + assert.NoError(t, err, "os.CreateTemp") + s.SetOutput(file) + err = s.ConnectDb(nil, true) + if err != nil { + os.Remove(file.Name()) + } + assert.NoError(t, err, "s.ConnectDB") + return s, file +} + +func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) { + t.Helper() + v := InitializeVariables(true) + v.Set(SQLCMDMAXVARTYPEWIDTH, "0") + s := New(nil, "", v) + s.Connect = newConnect(t) + s.Format = NewSQLCmdDefaultFormatter(true) + outfile, err := os.CreateTemp("", "sqlcmdout") + assert.NoError(t, err, "os.CreateTemp") + errfile, err := os.CreateTemp("", "sqlcmderr") + assert.NoError(t, err, "os.CreateTemp") + s.SetOutput(outfile) + s.SetError(errfile) + err = s.ConnectDb(nil, true) + if err != nil { + os.Remove(outfile.Name()) + os.Remove(errfile.Name()) + } + assert.NoError(t, err, "s.ConnectDB") + return s, outfile, errfile +} + +// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set +func canTestAzureAuth() bool { + server := os.Getenv(SQLCMDSERVER) + userName := os.Getenv(SQLCMDUSER) + return strings.Contains(server, ".database.windows.net") && userName == "" +} + +func newConnect(t testing.TB) *ConnectSettings { + t.Helper() + connect := ConnectSettings{ + UserName: os.Getenv(SQLCMDUSER), + Database: os.Getenv(SQLCMDDBNAME), + ServerName: os.Getenv(SQLCMDSERVER), + Password: os.Getenv(SQLCMDPASSWORD), + } + if canTestAzureAuth() { + t.Log("Using ActiveDirectoryDefault") + connect.AuthenticationMethod = azuread.ActiveDirectoryDefault + } + return &connect +} diff --git a/pkg/sqlcmd/testdata/singlebatchnogo.sql b/pkg/sqlcmd/testdata/singlebatchnogo.sql index 8d4c9bf2..29b68f16 100644 --- a/pkg/sqlcmd/testdata/singlebatchnogo.sql +++ b/pkg/sqlcmd/testdata/singlebatchnogo.sql @@ -1,2 +1,2 @@ -select 100 as num -select 'string' as title +select 100 as num +select 'string' as title diff --git a/pkg/sqlcmd/testdata/twobatchnoendinggo.sql b/pkg/sqlcmd/testdata/twobatchnoendinggo.sql index 0be51209..90c4d289 100644 --- a/pkg/sqlcmd/testdata/twobatchnoendinggo.sql +++ b/pkg/sqlcmd/testdata/twobatchnoendinggo.sql @@ -1,3 +1,3 @@ -select 100 as num -go -select 'string' as title +select 100 as num +go +select 'string' as title diff --git a/pkg/sqlcmd/testdata/twobatchwithgo.sql b/pkg/sqlcmd/testdata/twobatchwithgo.sql index 26554bc9..58439168 100644 --- a/pkg/sqlcmd/testdata/twobatchwithgo.sql +++ b/pkg/sqlcmd/testdata/twobatchwithgo.sql @@ -1,4 +1,4 @@ -select 100 as num -GO -select 'string' as title -GO +select 100 as num +GO +select 'string' as title +GO diff --git a/pkg/sqlcmd/util.go b/pkg/sqlcmd/util.go index fc0fb08a..79f75b67 100644 --- a/pkg/sqlcmd/util.go +++ b/pkg/sqlcmd/util.go @@ -1,73 +1,73 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "strconv" - "strings" -) - -// splitServer extracts connection parameters from a server name input -func splitServer(serverName string) (string, string, uint64, error) { - instance := "" - port := uint64(0) - if strings.HasPrefix(serverName, "tcp:") { - if len(serverName) == 4 { - return "", "", 0, &InvalidServerName - } - serverName = serverName[4:] - } - serverNameParts := strings.Split(serverName, ",") - if len(serverNameParts) > 2 { - return "", "", 0, &InvalidServerName - } - if len(serverNameParts) == 2 { - var err error - port, err = strconv.ParseUint(serverNameParts[1], 10, 16) - if err != nil { - return "", "", 0, &InvalidServerName - } - serverName = serverNameParts[0] - } else { - serverNameParts = strings.Split(serverName, "\\") - if len(serverNameParts) > 2 { - return "", "", 0, &InvalidServerName - } - if len(serverNameParts) == 2 { - instance = serverNameParts[1] - serverName = serverNameParts[0] - } - } - return serverName, instance, port, nil -} - -// padRight appends c instances of s to builder -func padRight(builder *strings.Builder, c int64, s string) *strings.Builder { - var i int64 - for ; i < c; i++ { - builder.WriteString(s) - } - return builder -} - -// padLeft prepends c instances of s to builder -func padLeft(builder *strings.Builder, c int64, s string) *strings.Builder { - newBuilder := new(strings.Builder) - newBuilder.Grow(builder.Len()) - var i int64 - for ; i < c; i++ { - newBuilder.WriteString(s) - } - newBuilder.WriteString(builder.String()) - return newBuilder -} - -func contains(arr []string, s string) bool { - for _, a := range arr { - if a == s { - return true - } - } - return false -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "strconv" + "strings" +) + +// splitServer extracts connection parameters from a server name input +func splitServer(serverName string) (string, string, uint64, error) { + instance := "" + port := uint64(0) + if strings.HasPrefix(serverName, "tcp:") { + if len(serverName) == 4 { + return "", "", 0, &InvalidServerName + } + serverName = serverName[4:] + } + serverNameParts := strings.Split(serverName, ",") + if len(serverNameParts) > 2 { + return "", "", 0, &InvalidServerName + } + if len(serverNameParts) == 2 { + var err error + port, err = strconv.ParseUint(serverNameParts[1], 10, 16) + if err != nil { + return "", "", 0, &InvalidServerName + } + serverName = serverNameParts[0] + } else { + serverNameParts = strings.Split(serverName, "\\") + if len(serverNameParts) > 2 { + return "", "", 0, &InvalidServerName + } + if len(serverNameParts) == 2 { + instance = serverNameParts[1] + serverName = serverNameParts[0] + } + } + return serverName, instance, port, nil +} + +// padRight appends c instances of s to builder +func padRight(builder *strings.Builder, c int64, s string) *strings.Builder { + var i int64 + for ; i < c; i++ { + builder.WriteString(s) + } + return builder +} + +// padLeft prepends c instances of s to builder +func padLeft(builder *strings.Builder, c int64, s string) *strings.Builder { + newBuilder := new(strings.Builder) + newBuilder.Grow(builder.Len()) + var i int64 + for ; i < c; i++ { + newBuilder.WriteString(s) + } + newBuilder.WriteString(builder.String()) + return newBuilder +} + +func contains(arr []string, s string) bool { + for _, a := range arr { + if a == s { + return true + } + } + return false +} diff --git a/pkg/sqlcmd/variables.go b/pkg/sqlcmd/variables.go index ebf6b1b5..e86a5b5e 100644 --- a/pkg/sqlcmd/variables.go +++ b/pkg/sqlcmd/variables.go @@ -1,335 +1,335 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "fmt" - "os" - "strings" - "unicode" -) - -// Variables provides set and get of sqlcmd scripting variables -type Variables map[string]string - -// Built-in scripting variables -const ( - SQLCMDDBNAME = "SQLCMDDBNAME" - SQLCMDINI = "SQLCMDINI" - SQLCMDPACKETSIZE = "SQLCMDPACKETSIZE" - SQLCMDPASSWORD = "SQLCMDPASSWORD" - SQLCMDSERVER = "SQLCMDSERVER" - SQLCMDUSER = "SQLCMDUSER" - SQLCMDWORKSTATION = "SQLCMDWORKSTATION" - SQLCMDLOGINTIMEOUT = "SQLCMDLOGINTIMEOUT" - SQLCMDSTATTIMEOUT = "SQLCMDSTATTIMEOUT" - SQLCMDHEADERS = "SQLCMDHEADERS" - SQLCMDCOLSEP = "SQLCMDCOLSEP" - SQLCMDCOLWIDTH = "SQLCMDCOLWIDTH" - SQLCMDERRORLEVEL = "SQLCMDERRORLEVEL" - SQLCMDFORMAT = "SQLCMDFORMAT" - SQLCMDMAXVARTYPEWIDTH = "SQLCMDMAXVARTYPEWIDTH" - SQLCMDMAXFIXEDTYPEWIDTH = "SQLCMDMAXFIXEDTYPEWIDTH" - SQLCMDEDITOR = "SQLCMDEDITOR" - SQLCMDUSEAAD = "SQLCMDUSEAAD" -) - -// builtinVariables are the predefined SQLCMD variables. Their values are printed first by :listvar -var builtinVariables = []string{ - SQLCMDCOLSEP, - SQLCMDCOLWIDTH, - SQLCMDDBNAME, - SQLCMDEDITOR, - SQLCMDERRORLEVEL, - SQLCMDFORMAT, - SQLCMDHEADERS, - SQLCMDINI, - SQLCMDLOGINTIMEOUT, - SQLCMDMAXFIXEDTYPEWIDTH, - SQLCMDMAXVARTYPEWIDTH, - SQLCMDPACKETSIZE, - SQLCMDSERVER, - SQLCMDSTATTIMEOUT, - SQLCMDUSEAAD, - SQLCMDUSER, - SQLCMDWORKSTATION, -} - -// readonlyVariables are variables that can't be changed via :setvar -var readOnlyVariables = []string{ - SQLCMDDBNAME, - SQLCMDINI, - SQLCMDPACKETSIZE, - SQLCMDSERVER, - SQLCMDUSER, - SQLCMDWORKSTATION, -} - -func (v Variables) checkReadOnly(key string) error { - currentValue, hasValue := v[key] - if hasValue { - for _, variable := range readOnlyVariables { - if variable == key && currentValue != "" { - return ReadOnlyVariable(key) - } - } - } - return nil -} - -// Set sets or adds the value in the map. -func (v Variables) Set(name, value string) { - key := strings.ToUpper(name) - v[key] = value -} - -// Get returns the value of the named variable -// To distinguish an empty value from an unset value use the bool return value -func (v Variables) Get(name string) (string, bool) { - key := strings.ToUpper(name) - s, ok := v[key] - return s, ok -} - -// Unset removes the value from the map -func (v Variables) Unset(name string) { - key := strings.ToUpper(name) - delete(v, key) -} - -// All returns a copy of the current variables -func (v Variables) All() map[string]string { - return map[string]string(v) -} - -// SQLCmdUser returns the SQLCMDUSER variable value -func (v Variables) SQLCmdUser() string { - return v[SQLCMDUSER] -} - -// SQLCmdServer returns the server connection parameters derived from the SQLCMDSERVER variable value -func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, err error) { - serverName = v[SQLCMDSERVER] - return splitServer(serverName) -} - -// SQLCmdDatabase returns the SQLCMDDBNAME variable value -func (v Variables) SQLCmdDatabase() string { - return v[SQLCMDDBNAME] -} - -// UseAad returns whether the SQLCMDUSEAAD variable value is set to "true" -func (v Variables) UseAad() bool { - return strings.EqualFold(v[SQLCMDUSEAAD], "true") -} - -// ColumnSeparator is the value of SQLCMDCOLSEP variable. It can have 0 or 1 characters -func (v Variables) ColumnSeparator() string { - sep := v[SQLCMDCOLSEP] - if len(sep) > 1 { - return sep[:1] - } - return sep -} - -// MaxFixedColumnWidth is the value of SQLCMDMAXFIXEDTYPEWIDTH variable. -// When non-zero, it limits the width of columns for types CHAR, NCHAR, NVARCHAR, VARCHAR, VARBINARY, VARIANT -func (v Variables) MaxFixedColumnWidth() int64 { - w := v[SQLCMDMAXFIXEDTYPEWIDTH] - return mustValue(w) -} - -// MaxVarColumnWidth is the value of SQLCMDMAXVARTYPEWIDTH variable. -// When non-zero, it limits the width of columns for (max) versions of CHAR, NCHAR, VARBINARY. -// It also limits the width of xml, UDT, text, ntext, and image -func (v Variables) MaxVarColumnWidth() int64 { - w := v[SQLCMDMAXVARTYPEWIDTH] - return mustValue(w) -} - -// ScreenWidth is the value of SQLCMDCOLWIDTH variable. -// It tells the formatter how many characters wide to limit all screen output. -func (v Variables) ScreenWidth() int64 { - w := v[SQLCMDCOLWIDTH] - return mustValue(w) -} - -// RowsBetweenHeaders is the value of SQLCMDHEADERS variable. -// When MaxVarColumnWidth() is 0, it returns -1 -func (v Variables) RowsBetweenHeaders() int64 { - if v.MaxVarColumnWidth() == 0 { - return -1 - } - h := mustValue(v[SQLCMDHEADERS]) - return h -} - -// ErrorLevel controls the minimum level of errors that are printed -func (v Variables) ErrorLevel() int64 { - return mustValue(v[SQLCMDERRORLEVEL]) -} - -// Format is the name of the results format -func (v Variables) Format() string { - switch v[SQLCMDFORMAT] { - case "vert", "vertical": - return "vertical" - } - return "horizontal" -} - -// StartupScriptFile is the path to the file that contains the startup script -func (v Variables) StartupScriptFile() string { - return v[SQLCMDINI] -} - -// TextEditor is the query editor application launched by the :ED command -func (v Variables) TextEditor() string { - return v[SQLCMDEDITOR] -} - -func mustValue(val string) int64 { - var n int64 - _, err := fmt.Sscanf(val, "%d", &n) - if err == nil { - return n - } - panic(err) -} - -// defaultVariables defines variables that cannot be removed from the map, only reset -// to their default values. -var defaultVariables = Variables{ - SQLCMDCOLSEP: " ", - SQLCMDCOLWIDTH: "0", - SQLCMDEDITOR: defaultEditor, - SQLCMDERRORLEVEL: "0", - SQLCMDHEADERS: "0", - SQLCMDLOGINTIMEOUT: "30", - SQLCMDMAXFIXEDTYPEWIDTH: "0", - SQLCMDMAXVARTYPEWIDTH: "256", - SQLCMDSTATTIMEOUT: "0", -} - -// InitializeVariables initializes variables with default values. -// When fromEnvironment is true, then loads from the runtime environment -func InitializeVariables(fromEnvironment bool) *Variables { - variables := Variables{ - SQLCMDCOLSEP: defaultVariables[SQLCMDCOLSEP], - SQLCMDCOLWIDTH: defaultVariables[SQLCMDCOLWIDTH], - SQLCMDDBNAME: "", - SQLCMDEDITOR: defaultVariables[SQLCMDEDITOR], - SQLCMDERRORLEVEL: defaultVariables[SQLCMDERRORLEVEL], - SQLCMDHEADERS: defaultVariables[SQLCMDHEADERS], - SQLCMDINI: "", - SQLCMDLOGINTIMEOUT: defaultVariables[SQLCMDLOGINTIMEOUT], - SQLCMDMAXFIXEDTYPEWIDTH: defaultVariables[SQLCMDMAXFIXEDTYPEWIDTH], - SQLCMDMAXVARTYPEWIDTH: defaultVariables[SQLCMDMAXVARTYPEWIDTH], - SQLCMDPACKETSIZE: "4096", - SQLCMDSERVER: "", - SQLCMDSTATTIMEOUT: defaultVariables[SQLCMDSTATTIMEOUT], - SQLCMDUSER: "", - SQLCMDUSEAAD: "", - } - hostname, _ := os.Hostname() - variables.Set(SQLCMDWORKSTATION, hostname) - - if fromEnvironment { - for v := range variables.All() { - envVar, ok := os.LookupEnv(v) - if ok { - variables.Set(v, envVar) - } - } - } - return &variables -} - -// Setvar implements the :Setvar command -// TODO: Add validation functions for the variables. -func (variables *Variables) Setvar(name, value string) error { - err := ValidIdentifier(name) - if err == nil { - if err = variables.checkReadOnly(name); err != nil { - err = ReadOnlyVariable(name) - } - } - if err != nil { - return err - } - if value == "" { - if _, ok := variables.Get(name); !ok { - return UndefinedVariable(name) - } - if def, ok := defaultVariables.Get(name); ok { - value = def - } else { - variables.Unset(name) - return nil - } - } else { - value, err = ParseValue(value) - } - if err != nil { - return err - } - variables.Set(name, value) - return nil -} - -const validVariableRunes = "_-" - -// ValidIdentifier determines if a given string can be used as a variable name -func ValidIdentifier(name string) error { - - first := true - for _, c := range name { - if !unicode.IsLetter(c) && (first || (!unicode.IsDigit(c) && !strings.ContainsRune(validVariableRunes, c))) { - return fmt.Errorf("Invalid variable identifier %s", name) - } - first = false - } - return nil -} - -// ParseValue returns the string to use as the variable value -// If the string contains a space or a quote, it must be delimited by quotes and literal quotes -// within the value must be escaped by another quote -// "this has a quote "" in it" is valid -// "this has a quote" in it" is not valid -func ParseValue(val string) (string, error) { - quoted := val[0] == '"' - err := fmt.Errorf("Invalid variable value %s", val) - if !quoted { - if strings.ContainsAny(val, "\t\n\r ") { - return "", err - } - return val, nil - } - if len(val) == 1 || val[len(val)-1] != '"' { - return "", err - } - - b := new(strings.Builder) - quoted = false - r := []rune(val) -loop: - for i := 1; i < len(r)-1; i++ { - switch { - case quoted && r[i] == '"': - b.WriteRune('"') - quoted = false - case quoted && r[i] != '"': - break loop - case !quoted && r[i] == '"': - quoted = true - default: - b.WriteRune(r[i]) - } - } - if quoted { - return "", err - } - return b.String(), nil -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "fmt" + "os" + "strings" + "unicode" +) + +// Variables provides set and get of sqlcmd scripting variables +type Variables map[string]string + +// Built-in scripting variables +const ( + SQLCMDDBNAME = "SQLCMDDBNAME" + SQLCMDINI = "SQLCMDINI" + SQLCMDPACKETSIZE = "SQLCMDPACKETSIZE" + SQLCMDPASSWORD = "SQLCMDPASSWORD" + SQLCMDSERVER = "SQLCMDSERVER" + SQLCMDUSER = "SQLCMDUSER" + SQLCMDWORKSTATION = "SQLCMDWORKSTATION" + SQLCMDLOGINTIMEOUT = "SQLCMDLOGINTIMEOUT" + SQLCMDSTATTIMEOUT = "SQLCMDSTATTIMEOUT" + SQLCMDHEADERS = "SQLCMDHEADERS" + SQLCMDCOLSEP = "SQLCMDCOLSEP" + SQLCMDCOLWIDTH = "SQLCMDCOLWIDTH" + SQLCMDERRORLEVEL = "SQLCMDERRORLEVEL" + SQLCMDFORMAT = "SQLCMDFORMAT" + SQLCMDMAXVARTYPEWIDTH = "SQLCMDMAXVARTYPEWIDTH" + SQLCMDMAXFIXEDTYPEWIDTH = "SQLCMDMAXFIXEDTYPEWIDTH" + SQLCMDEDITOR = "SQLCMDEDITOR" + SQLCMDUSEAAD = "SQLCMDUSEAAD" +) + +// builtinVariables are the predefined SQLCMD variables. Their values are printed first by :listvar +var builtinVariables = []string{ + SQLCMDCOLSEP, + SQLCMDCOLWIDTH, + SQLCMDDBNAME, + SQLCMDEDITOR, + SQLCMDERRORLEVEL, + SQLCMDFORMAT, + SQLCMDHEADERS, + SQLCMDINI, + SQLCMDLOGINTIMEOUT, + SQLCMDMAXFIXEDTYPEWIDTH, + SQLCMDMAXVARTYPEWIDTH, + SQLCMDPACKETSIZE, + SQLCMDSERVER, + SQLCMDSTATTIMEOUT, + SQLCMDUSEAAD, + SQLCMDUSER, + SQLCMDWORKSTATION, +} + +// readonlyVariables are variables that can't be changed via :setvar +var readOnlyVariables = []string{ + SQLCMDDBNAME, + SQLCMDINI, + SQLCMDPACKETSIZE, + SQLCMDSERVER, + SQLCMDUSER, + SQLCMDWORKSTATION, +} + +func (v Variables) checkReadOnly(key string) error { + currentValue, hasValue := v[key] + if hasValue { + for _, variable := range readOnlyVariables { + if variable == key && currentValue != "" { + return ReadOnlyVariable(key) + } + } + } + return nil +} + +// Set sets or adds the value in the map. +func (v Variables) Set(name, value string) { + key := strings.ToUpper(name) + v[key] = value +} + +// Get returns the value of the named variable +// To distinguish an empty value from an unset value use the bool return value +func (v Variables) Get(name string) (string, bool) { + key := strings.ToUpper(name) + s, ok := v[key] + return s, ok +} + +// Unset removes the value from the map +func (v Variables) Unset(name string) { + key := strings.ToUpper(name) + delete(v, key) +} + +// All returns a copy of the current variables +func (v Variables) All() map[string]string { + return map[string]string(v) +} + +// SQLCmdUser returns the SQLCMDUSER variable value +func (v Variables) SQLCmdUser() string { + return v[SQLCMDUSER] +} + +// SQLCmdServer returns the server connection parameters derived from the SQLCMDSERVER variable value +func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, err error) { + serverName = v[SQLCMDSERVER] + return splitServer(serverName) +} + +// SQLCmdDatabase returns the SQLCMDDBNAME variable value +func (v Variables) SQLCmdDatabase() string { + return v[SQLCMDDBNAME] +} + +// UseAad returns whether the SQLCMDUSEAAD variable value is set to "true" +func (v Variables) UseAad() bool { + return strings.EqualFold(v[SQLCMDUSEAAD], "true") +} + +// ColumnSeparator is the value of SQLCMDCOLSEP variable. It can have 0 or 1 characters +func (v Variables) ColumnSeparator() string { + sep := v[SQLCMDCOLSEP] + if len(sep) > 1 { + return sep[:1] + } + return sep +} + +// MaxFixedColumnWidth is the value of SQLCMDMAXFIXEDTYPEWIDTH variable. +// When non-zero, it limits the width of columns for types CHAR, NCHAR, NVARCHAR, VARCHAR, VARBINARY, VARIANT +func (v Variables) MaxFixedColumnWidth() int64 { + w := v[SQLCMDMAXFIXEDTYPEWIDTH] + return mustValue(w) +} + +// MaxVarColumnWidth is the value of SQLCMDMAXVARTYPEWIDTH variable. +// When non-zero, it limits the width of columns for (max) versions of CHAR, NCHAR, VARBINARY. +// It also limits the width of xml, UDT, text, ntext, and image +func (v Variables) MaxVarColumnWidth() int64 { + w := v[SQLCMDMAXVARTYPEWIDTH] + return mustValue(w) +} + +// ScreenWidth is the value of SQLCMDCOLWIDTH variable. +// It tells the formatter how many characters wide to limit all screen output. +func (v Variables) ScreenWidth() int64 { + w := v[SQLCMDCOLWIDTH] + return mustValue(w) +} + +// RowsBetweenHeaders is the value of SQLCMDHEADERS variable. +// When MaxVarColumnWidth() is 0, it returns -1 +func (v Variables) RowsBetweenHeaders() int64 { + if v.MaxVarColumnWidth() == 0 { + return -1 + } + h := mustValue(v[SQLCMDHEADERS]) + return h +} + +// ErrorLevel controls the minimum level of errors that are printed +func (v Variables) ErrorLevel() int64 { + return mustValue(v[SQLCMDERRORLEVEL]) +} + +// Format is the name of the results format +func (v Variables) Format() string { + switch v[SQLCMDFORMAT] { + case "vert", "vertical": + return "vertical" + } + return "horizontal" +} + +// StartupScriptFile is the path to the file that contains the startup script +func (v Variables) StartupScriptFile() string { + return v[SQLCMDINI] +} + +// TextEditor is the query editor application launched by the :ED command +func (v Variables) TextEditor() string { + return v[SQLCMDEDITOR] +} + +func mustValue(val string) int64 { + var n int64 + _, err := fmt.Sscanf(val, "%d", &n) + if err == nil { + return n + } + panic(err) +} + +// defaultVariables defines variables that cannot be removed from the map, only reset +// to their default values. +var defaultVariables = Variables{ + SQLCMDCOLSEP: " ", + SQLCMDCOLWIDTH: "0", + SQLCMDEDITOR: defaultEditor, + SQLCMDERRORLEVEL: "0", + SQLCMDHEADERS: "0", + SQLCMDLOGINTIMEOUT: "30", + SQLCMDMAXFIXEDTYPEWIDTH: "0", + SQLCMDMAXVARTYPEWIDTH: "256", + SQLCMDSTATTIMEOUT: "0", +} + +// InitializeVariables initializes variables with default values. +// When fromEnvironment is true, then loads from the runtime environment +func InitializeVariables(fromEnvironment bool) *Variables { + variables := Variables{ + SQLCMDCOLSEP: defaultVariables[SQLCMDCOLSEP], + SQLCMDCOLWIDTH: defaultVariables[SQLCMDCOLWIDTH], + SQLCMDDBNAME: "", + SQLCMDEDITOR: defaultVariables[SQLCMDEDITOR], + SQLCMDERRORLEVEL: defaultVariables[SQLCMDERRORLEVEL], + SQLCMDHEADERS: defaultVariables[SQLCMDHEADERS], + SQLCMDINI: "", + SQLCMDLOGINTIMEOUT: defaultVariables[SQLCMDLOGINTIMEOUT], + SQLCMDMAXFIXEDTYPEWIDTH: defaultVariables[SQLCMDMAXFIXEDTYPEWIDTH], + SQLCMDMAXVARTYPEWIDTH: defaultVariables[SQLCMDMAXVARTYPEWIDTH], + SQLCMDPACKETSIZE: "4096", + SQLCMDSERVER: "", + SQLCMDSTATTIMEOUT: defaultVariables[SQLCMDSTATTIMEOUT], + SQLCMDUSER: "", + SQLCMDUSEAAD: "", + } + hostname, _ := os.Hostname() + variables.Set(SQLCMDWORKSTATION, hostname) + + if fromEnvironment { + for v := range variables.All() { + envVar, ok := os.LookupEnv(v) + if ok { + variables.Set(v, envVar) + } + } + } + return &variables +} + +// Setvar implements the :Setvar command +// TODO: Add validation functions for the variables. +func (variables *Variables) Setvar(name, value string) error { + err := ValidIdentifier(name) + if err == nil { + if err = variables.checkReadOnly(name); err != nil { + err = ReadOnlyVariable(name) + } + } + if err != nil { + return err + } + if value == "" { + if _, ok := variables.Get(name); !ok { + return UndefinedVariable(name) + } + if def, ok := defaultVariables.Get(name); ok { + value = def + } else { + variables.Unset(name) + return nil + } + } else { + value, err = ParseValue(value) + } + if err != nil { + return err + } + variables.Set(name, value) + return nil +} + +const validVariableRunes = "_-" + +// ValidIdentifier determines if a given string can be used as a variable name +func ValidIdentifier(name string) error { + + first := true + for _, c := range name { + if !unicode.IsLetter(c) && (first || (!unicode.IsDigit(c) && !strings.ContainsRune(validVariableRunes, c))) { + return fmt.Errorf("Invalid variable identifier %s", name) + } + first = false + } + return nil +} + +// ParseValue returns the string to use as the variable value +// If the string contains a space or a quote, it must be delimited by quotes and literal quotes +// within the value must be escaped by another quote +// "this has a quote "" in it" is valid +// "this has a quote" in it" is not valid +func ParseValue(val string) (string, error) { + quoted := val[0] == '"' + err := fmt.Errorf("Invalid variable value %s", val) + if !quoted { + if strings.ContainsAny(val, "\t\n\r ") { + return "", err + } + return val, nil + } + if len(val) == 1 || val[len(val)-1] != '"' { + return "", err + } + + b := new(strings.Builder) + quoted = false + r := []rune(val) +loop: + for i := 1; i < len(r)-1; i++ { + switch { + case quoted && r[i] == '"': + b.WriteRune('"') + quoted = false + case quoted && r[i] != '"': + break loop + case !quoted && r[i] == '"': + quoted = true + default: + b.WriteRune(r[i]) + } + } + if quoted { + return "", err + } + return b.String(), nil +} diff --git a/pkg/sqlcmd/variables_test.go b/pkg/sqlcmd/variables_test.go index 45b8ae5b..cf4de871 100644 --- a/pkg/sqlcmd/variables_test.go +++ b/pkg/sqlcmd/variables_test.go @@ -1,116 +1,116 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestBasicVariableOperations(t *testing.T) { - variables := Variables{ - "var1": "val1", - } - variables.Set("var2", "val2") - assert.Contains(t, variables, "VAR2", "Set should add a capitalized key") - all := variables.All() - keys := make([]string, 0, len(all)) - for k := range all { - keys = append(keys, k) - } - assert.ElementsMatch(t, []string{"var1", "VAR2"}, keys, "All returns every key") - assert.Equal(t, "val2", all["VAR2"], "VAR2 set value") - -} - -func TestSetvarFailsForReadOnlyVariables(t *testing.T) { - variables := Variables{} - variables.Set("SQLCMDDBNAME", "somedatabase") - err := variables.Setvar("SQLCMDDBNAME", "newdatabase") - assert.Error(t, err, "setting a readonly variable fails") - assert.Equal(t, "somedatabase", variables.SQLCmdDatabase(), "readonly variable shouldn't be changed by Setvar") -} - -func TestEnvironmentVariablesAsInput(t *testing.T) { - os.Setenv("SQLCMDSERVER", "someserver") - defer os.Unsetenv("SQLCMDSERVER") - os.Setenv("x", "somevalue") - defer os.Unsetenv("x") - vars := InitializeVariables(true).All() - assert.Equal(t, "someserver", vars["SQLCMDSERVER"], "InitializeVariables should read a valid environment variable from the known list") - _, ok := vars["x"] - assert.False(t, ok, "InitializeVariables should skip variables not in the known list") -} - -func TestSqlServerSplitsName(t *testing.T) { - vars := Variables{ - SQLCMDSERVER: `tcp:someserver\someinstance`, - } - serverName, instance, port, err := vars.SQLCmdServer() - if assert.NoError(t, err, "tcp:server\\someinstance") { - assert.Equal(t, "someserver", serverName, "server name for instance") - assert.Equal(t, uint64(0), port, "port for instance") - assert.Equal(t, "someinstance", instance, "instance for instance") - } - vars = Variables{ - SQLCMDSERVER: `tcp:someserver,1111`, - } - serverName, instance, port, err = vars.SQLCmdServer() - if assert.NoError(t, err, "tcp:server,1111") { - assert.Equal(t, "someserver", serverName, "server name for port number") - assert.Equal(t, uint64(1111), port, "port for port number") - assert.Equal(t, "", instance, "instance for port number") - } -} - -func TestParseValue(t *testing.T) { - type test struct { - raw string - val string - valid bool - } - tests := []test{ - {`""`, "", true}, - {`"`, "", false}, - {`"""`, "", false}, - {`no quotes`, "", false}, - {`"is quoted"`, "is quoted", true}, - {`" " single quote "`, "", false}, - {`" "" escaped quotes "" "`, ` " escaped quotes " `, true}, - } - - for _, tst := range tests { - v, err := ParseValue(tst.raw) - if tst.valid { - if assert.NoErrorf(t, err, "Unexpected error for value %s", tst.raw) { - assert.Equalf(t, tst.val, v, "Incorrect parsed value for %s", tst.raw) - } - } else { - assert.Errorf(t, err, "Expected error for %s", tst.raw) - } - } -} - -func TestValidIdentifier(t *testing.T) { - type test struct { - raw string - valid bool - } - tests := []test{ - {"1A", false}, - {"A1", true}, - {"A+", false}, - {"A-_b", true}, - } - for _, tst := range tests { - err := ValidIdentifier(tst.raw) - if tst.valid { - assert.NoErrorf(t, err, "%s is valid", tst.raw) - } else { - assert.Errorf(t, err, "%s is invalid", tst.raw) - } - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBasicVariableOperations(t *testing.T) { + variables := Variables{ + "var1": "val1", + } + variables.Set("var2", "val2") + assert.Contains(t, variables, "VAR2", "Set should add a capitalized key") + all := variables.All() + keys := make([]string, 0, len(all)) + for k := range all { + keys = append(keys, k) + } + assert.ElementsMatch(t, []string{"var1", "VAR2"}, keys, "All returns every key") + assert.Equal(t, "val2", all["VAR2"], "VAR2 set value") + +} + +func TestSetvarFailsForReadOnlyVariables(t *testing.T) { + variables := Variables{} + variables.Set("SQLCMDDBNAME", "somedatabase") + err := variables.Setvar("SQLCMDDBNAME", "newdatabase") + assert.Error(t, err, "setting a readonly variable fails") + assert.Equal(t, "somedatabase", variables.SQLCmdDatabase(), "readonly variable shouldn't be changed by Setvar") +} + +func TestEnvironmentVariablesAsInput(t *testing.T) { + os.Setenv("SQLCMDSERVER", "someserver") + defer os.Unsetenv("SQLCMDSERVER") + os.Setenv("x", "somevalue") + defer os.Unsetenv("x") + vars := InitializeVariables(true).All() + assert.Equal(t, "someserver", vars["SQLCMDSERVER"], "InitializeVariables should read a valid environment variable from the known list") + _, ok := vars["x"] + assert.False(t, ok, "InitializeVariables should skip variables not in the known list") +} + +func TestSqlServerSplitsName(t *testing.T) { + vars := Variables{ + SQLCMDSERVER: `tcp:someserver\someinstance`, + } + serverName, instance, port, err := vars.SQLCmdServer() + if assert.NoError(t, err, "tcp:server\\someinstance") { + assert.Equal(t, "someserver", serverName, "server name for instance") + assert.Equal(t, uint64(0), port, "port for instance") + assert.Equal(t, "someinstance", instance, "instance for instance") + } + vars = Variables{ + SQLCMDSERVER: `tcp:someserver,1111`, + } + serverName, instance, port, err = vars.SQLCmdServer() + if assert.NoError(t, err, "tcp:server,1111") { + assert.Equal(t, "someserver", serverName, "server name for port number") + assert.Equal(t, uint64(1111), port, "port for port number") + assert.Equal(t, "", instance, "instance for port number") + } +} + +func TestParseValue(t *testing.T) { + type test struct { + raw string + val string + valid bool + } + tests := []test{ + {`""`, "", true}, + {`"`, "", false}, + {`"""`, "", false}, + {`no quotes`, "", false}, + {`"is quoted"`, "is quoted", true}, + {`" " single quote "`, "", false}, + {`" "" escaped quotes "" "`, ` " escaped quotes " `, true}, + } + + for _, tst := range tests { + v, err := ParseValue(tst.raw) + if tst.valid { + if assert.NoErrorf(t, err, "Unexpected error for value %s", tst.raw) { + assert.Equalf(t, tst.val, v, "Incorrect parsed value for %s", tst.raw) + } + } else { + assert.Errorf(t, err, "Expected error for %s", tst.raw) + } + } +} + +func TestValidIdentifier(t *testing.T) { + type test struct { + raw string + valid bool + } + tests := []test{ + {"1A", false}, + {"A1", true}, + {"A+", false}, + {"A-_b", true}, + } + for _, tst := range tests { + err := ValidIdentifier(tst.raw) + if tst.valid { + assert.NoErrorf(t, err, "%s is valid", tst.raw) + } else { + assert.Errorf(t, err, "%s is invalid", tst.raw) + } + } +} diff --git a/testdata/sql.txt b/testdata/sql.txt index b98b2681..7519b04f 100644 --- a/testdata/sql.txt +++ b/testdata/sql.txt @@ -1,3 +1,3 @@ -select 1 as col1 -go - +select 1 as col1 +go +