Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import (
"log"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
)

type config struct {
socketPath string
powershellPath string
foreground bool
verbose bool
stop bool
Expand All @@ -35,10 +36,26 @@ func defaultSocketPath() string {
return filepath.Join(home, ".ssh", "wsl2-ssh-agent.sock")
}

func powershellPath() string {
path, err := exec.LookPath("powershell.exe")
if err != nil {
path := "/mnt/c/Windows/System32/WindowsPowerShell/v1.0/powershell.exe"
_, err := os.Stat(path)
if err == nil {
return path
} else {
return ""
}

}
return path
}

func newConfig() *config {
c := &config{}

flag.StringVar(&c.socketPath, "socket", defaultSocketPath(), "a path of UNIX domain socket to listen")
flag.StringVar(&c.powershellPath, "powershell-path", powershellPath(), "a path of Windows PowerShell (powershell.exe)")
flag.BoolVar(&c.foreground, "foreground", false, "run in foreground mode")
flag.BoolVar(&c.verbose, "verbose", false, "verbose mode")
flag.StringVar(&c.logFile, "log", "", "a file path to write the log")
Expand All @@ -52,10 +69,15 @@ func newConfig() *config {

flag.Parse()

if c.powershellPath == "" {
fmt.Printf("powershell.exe not found, use the -powershell-path to customize the path.\n")
os.Exit(1)
}

return c
}

func (c *config) start() (context.Context, bool) {
func (c *config) start() (context.Context) {
if c.version {
fmt.Printf("wsl2-ssh-agent %s\n", version)
os.Exit(0)
Expand Down Expand Up @@ -110,13 +132,7 @@ func (c *config) start() (context.Context, bool) {
signal.Ignore(syscall.SIGPIPE)
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

// check if ssh-agent.exe is older than 8.9
ignoreOpenSSHExtensions := strings.Compare(getWinSshVersion(), "OpenSSH_for_Windows_8.9") == -1
if ignoreOpenSSHExtensions {
log.Printf("ssh-agent.exe seems to be old; ignore OpenSSH extension messages")
}

return ctx, ignoreOpenSSHExtensions
return ctx
}

func (c *config) setupLogFile() {
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package main
func main() {
c := newConfig()

ctx, ignoreOpenSSHExtensions := c.start()
ctx := c.start()

s := newServer(c.socketPath, ignoreOpenSSHExtensions)
s := newServer(c.socketPath, c.powershellPath)

s.run(ctx)
}
37 changes: 2 additions & 35 deletions repeater.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package main

import (
"bytes"
"context"
_ "embed"
"fmt"
"io"
"log"
"os/exec"
"strings"
"time"
)

Expand All @@ -28,11 +26,11 @@ var waitTimes []time.Duration = []time.Duration{
}

// invoke PowerShell.exe and run
func newRepeater(ctx context.Context) (*repeater, error) {
func newRepeater(ctx context.Context, powershell string) (*repeater, error) {
for i, limit := range waitTimes {
log.Printf("invoking [W] in PowerShell.exe%s", trial(i))

cmd := exec.Command("PowerShell.exe", "-Command", "-")
cmd := exec.Command(powershell, "-Command", "-")
in, err := cmd.StdinPipe()
if err != nil {
continue
Expand Down Expand Up @@ -96,37 +94,6 @@ func (rep *repeater) terminate() {
terminate(rep.cmd)
}

func getWinSshVersion() string {
for i, limit := range waitTimes {
ctx, cancel := context.WithTimeout(context.Background(), limit)
defer cancel()

log.Printf("check the version of ssh.exe%s", trial(i))

cmd := exec.CommandContext(ctx, "ssh.exe", "-V")

var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

err := cmd.Run()

if err != nil {
log.Printf("failed to invoke ssh.exe: %s", err)
continue
}

version := strings.TrimSuffix(stderr.String(), "\r\n")

log.Printf("the version of ssh.exe: %#v", version)
return version
}

log.Printf("failed to check the version of ssh.exe")

return ""
}

func trial(i int) string {
if i == 0 {
return ""
Expand Down
60 changes: 41 additions & 19 deletions repeater.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,68 @@ Function Log($msg) {
$host.ui.WriteErrorLine("[W] $date $msg")
}

Function RelayMessage($from, $to, $buf, $arrow) {
Function ReadMessage($stream) {
$buf = New-Object byte[] 4
$offset = 0
while ($offset -lt 4) {
$n = $from.Read($buf, $offset, 4 - $offset);
if ($n -eq 0) { exit }
$n = $stream.Read($buf, $offset, 4 - $offset);
if ($n -eq 0) {
break
}
$offset += $n;
}
$len = (($buf[0] * 256 + $buf[1]) * 256 + $buf[2]) * 256 + $buf[3] + 4
Log "[L] $arrow [W] $arrow ssh-agent.exe ($len B)"
$len
while ($offset -lt $len) {
$n = $from.Read($buf, $offset, [Math]::Min($len, $buf.Length) - $offset)
if ($n -eq 0) { exit }
$offset += $n
$to.Write($buf, 0, $offset)
$len -= $offset
$offset = 0
if ($offset -eq 4) {
$len = (($buf[0] * 256 + $buf[1]) * 256 + $buf[2]) * 256 + $buf[3] + 4
[Array]::Resize([ref]$buf, $len)
while ($offset -lt $buf.Length) {
$n = $stream.Read($buf, $offset, $buf.Length - $offset)
if ($n -eq 0) {
break
}
$offset += $n
}
}
[Array]::Resize([ref]$buf, $offset)
return $buf
}

Function MainLoop {
Try {
$buf = New-Object byte[] 8192
$ignoreOpenSSHExtensions = $false
Try {
$sshAgentVersion = (Get-Command -CommandType Application ssh-agent.exe -ErrorAction Stop)[0].Version
$ignoreOpenSSHExtensions = ($sshAgentVersion.Major -le 8 -and $sshAgentVersion.Minor -lt 9)
Log "ssh-agent.exe version: $($sshAgentVersion.ToString()) (ignoreOpenSSHExtensions: $ignoreOpenSSHExtensions)"
}
Catch {
$ignoreOpenSSHExtensions = $true
}

$ssh_client_in = [console]::OpenStandardInput()
$ssh_client_out = [console]::OpenStandardOutput()

$ver = $PSVersionTable["PSVersion"]
$ssh_client_out.WriteByte(0xff)
Log "ready: PSVersion $ver"

$buf[0] = 0xff
$ssh_client_out.Write($buf, 0, 1)

while ($true) {
Try {
$null = $ssh_client_in.Read((New-Object byte[] 1), 0, 0)
$buf = ReadMessage $ssh_client_in
if ($ignoreOpenSSHExtensions -and $buf.Length -gt 4 -and $buf[4] -eq 0x1b) {
$buf = [byte[]](0, 0, 0, 1, 6)
$ssh_client_out.Write($buf, 0, $buf.Length)
Log "[W] return dummy for OpenSSH ext."
Continue
}
$ssh_agent = New-Object System.IO.Pipes.NamedPipeClientStream ".", "openssh-ssh-agent", InOut
$ssh_agent.Connect()
Log "[W] named pipe: connected"
$len = RelayMessage $ssh_client_in $ssh_agent $buf "->"
$len = RelayMessage $ssh_agent $ssh_client_out $buf "<-"
$ssh_agent.Write($buf, 0, $buf.Length)
Log "[L] -> [W] -> ssh-agent.exe ($($buf.Length) B)"
$buf = ReadMessage $ssh_agent
$ssh_client_out.Write($buf, 0, $buf.Length)
Log "[L] <- [W] <- ssh-agent.exe ($($buf.Length) B)"
}
Finally {
if ($null -ne $ssh_agent) {
Expand Down
27 changes: 0 additions & 27 deletions repeater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ loop do
end
`

const dummySsh = `#!/usr/bin/ruby
$stderr << "Hello\r\n"
`

func setupDummyEnv(t *testing.T) string {
t.Helper()
log.SetOutput(io.Discard)
Expand Down Expand Up @@ -93,26 +89,3 @@ func TestRepeaterNormal(t *testing.T) {

rep.terminate()
}

func TestSshVersionNoSsh(t *testing.T) {
setupDummyEnv(t)

s := getWinSshVersion()
if s != "" {
t.Errorf("getWinSshVersion should fail")
}
}

func TestSshVersionNormal(t *testing.T) {
tmpDir := setupDummyEnv(t)

err := os.WriteFile(filepath.Join(tmpDir, "ssh.exe"), []byte(dummySsh), 0777)
if err != nil {
t.Fatal(err)
}

s := getWinSshVersion()
if s != "Hello" {
t.Errorf("getWinSshVersion does not work well: %#v", s)
}
}
34 changes: 6 additions & 28 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ import (

type server struct {
listener net.Listener
ignoreOpenSSHExtensions bool
powershellPath string
}

func newServer(path string, ignoreOpenSSHExtensions bool) *server {
listener, err := net.Listen("unix", path)
func newServer(socketPath string, powershellPath string) *server {
listener, err := net.Listen("unix", socketPath)
if err != nil {
log.Fatal(err)
}
log.Printf("start listening on %s", path)
log.Printf("start listening on %s", socketPath)

return &server{listener, ignoreOpenSSHExtensions}
return &server{listener, powershellPath}
}

type request struct {
Expand Down Expand Up @@ -91,7 +91,7 @@ func (s *server) server(ctx context.Context, cancel func(), requestQueue chan re

for {
// invoke PowerShell.exe
rep, err := newRepeater(ctx)
rep, err := newRepeater(ctx, s.powershellPath)
if err != nil {
return
}
Expand Down Expand Up @@ -180,22 +180,6 @@ func (s *server) client(wg *sync.WaitGroup, ctx context.Context, sshClient net.C
}
log.Printf("ssh -> [L] (%d B)", len(req))

if s.ignoreOpenSSHExtensions && req[4] == 0x1b /* SSH_AGENTC_EXTENSION */ {
// This is OpenSSH's extension message since OpenSSH 8.9.
//
// ref: https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent
//
// If we pass this message to ssh-agent.exe in OpenSSH 8.6, the connection is closed.
// So we need to drop this message and send a dummy success response.
log.Printf("ssh <- [L] (5 B) <dummy for OpenSSH ext.>")
err := replyDummySuccess(sshClient, 0)
if err != nil {
log.Printf("failed to write to ssh: %s", err)
break
}
continue
}

requestQueue <- request{data: req, resultChannel: resChan}
resp, ok := <-resChan
if !ok {
Expand Down Expand Up @@ -241,9 +225,3 @@ func readMessage(from io.Reader) ([]byte, error) {

return append(header, body...), nil
}

func replyDummySuccess(client io.ReadWriter, len int64) error {
buf := []byte{0, 0, 0, 1, 6 /* SSH_AGENT_SUCCESS */}
_, err := client.Write(buf)
return err
}