Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into test-certs
Browse files Browse the repository at this point in the history
  • Loading branch information
ninedraft committed Oct 2, 2022
2 parents 2f1bc88 + fb52392 commit e298209
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 66 deletions.
81 changes: 76 additions & 5 deletions gemax/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gemax
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand All @@ -20,19 +21,89 @@ import (
type Client struct {
MaxResponseSize int64
Dial func(ctx context.Context, host string, cfg *tls.Config) (net.Conn, error)
once sync.Once
// CheckRedirect specifies the policy for handling redirects.
// If CheckRedirect is not nil, the client calls it before
// following an Gemini redirect. The arguments req and via are
// the upcoming request and the requests made already, oldest
// first. If CheckRedirect returns an error, the Client's Fetch
// method returns both the previous Response (with its Body
// closed) and CheckRedirect's error.
// instead of issuing the Request req.
// As a special case, if CheckRedirect returns ErrUseLastResponse,
// then the most recent response is returned with its body
// unclosed, along with a nil error.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
CheckRedirect func(ctx context.Context, verification *urlpkg.URL, via []RedirectedRequest) error
once sync.Once
}

var (
// ErrTooManyRedirects means that server tried through too many adresses.
// Default limit is 10.
// User implementations of CheckRedirect should use this error then limiting number of redirects.
ErrTooManyRedirects = errors.New("too many redirects")
)

func (client *Client) checkRedirect(ctx context.Context, req *urlpkg.URL, via []RedirectedRequest) error {
if client.CheckRedirect != nil {
return client.CheckRedirect(ctx, req, via)
}
return defaultRedirect(ctx, req, via)
}

func defaultRedirect(_ context.Context, _ *urlpkg.URL, via []RedirectedRequest) error {
const max = 10
if len(via) < max {
return nil
}
return ErrTooManyRedirects
}

// RedirectedRequest contains executed gemini request data
// and corresponding response with closed body.
type RedirectedRequest struct {
Req *urlpkg.URL
Response *Response
}

const readerBufSize = 16 << 10

// Fetch gemini resource.
func (client *Client) Fetch(ctx context.Context, url string) (*Response, error) {
client.init()
var u, errParseURL = urlpkg.Parse(url)
if errParseURL != nil {
return nil, fmt.Errorf("parsing URL: %w", errParseURL)
//nolint:prealloc // unable to preallocate, we don't know number of redirects
var redirects []RedirectedRequest
for {
var u, errParseURL = urlpkg.Parse(url)
if errParseURL != nil {
return nil, fmt.Errorf("parsing URL: %w", errParseURL)
}
if err := client.checkRedirect(ctx, u, redirects); err != nil {
return nil, fmt.Errorf("redirect: %w", err)
}
resp, errFetch := client.fetch(ctx, url, u)
if errFetch != nil {
return resp, errFetch
}
if !isRedirect(resp.Status) {
return resp, nil
}
_ = resp.Close()
redirects = append(redirects, RedirectedRequest{
Req: u,
Response: resp,
})
url = resp.Meta
}
}

func isRedirect(code status.Code) bool {
return code == status.Redirect || code == status.RedirectPermanent
}

func (client *Client) fetch(ctx context.Context, origURL string, u *urlpkg.URL) (*Response, error) {
var host = u.Host
if strings.LastIndexByte(host, ':') < 0 {
host += ":1965"
Expand All @@ -51,7 +122,7 @@ func (client *Client) Fetch(ctx context.Context, url string) (*Response, error)
}
ctxConnDeadline(ctx, conn)

var _, errWrite = io.WriteString(conn, url+"\r\n")
var _, errWrite = io.WriteString(conn, origURL+"\r\n")
if errWrite != nil {
return nil, fmt.Errorf("sending request: %w", errWrite)
}
Expand Down
88 changes: 36 additions & 52 deletions gemax/client_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package gemax_test

import (
"bytes"
"context"
"crypto/tls"
"embed"
"errors"
"io"
"net"
"testing"

"github.com/ninedraft/gemax/gemax"
"github.com/ninedraft/gemax/gemax/internal/testaddr"
"github.com/ninedraft/gemax/gemax/internal/tester"
"github.com/ninedraft/gemax/gemax/status"
)

//go:embed testdata/client/pages/*
Expand Down Expand Up @@ -40,62 +38,48 @@ func TestClient(test *testing.T) {
test.Logf("%s", data)
}

func TestClientTLS(test *testing.T) {
var ctx, cancel = context.WithCancel(context.Background())
defer cancel()

var cert, errCert = tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem")
if errCert != nil {
test.Fatalf("loading test TLS certs: %v", errCert)
func TestClient_Redirect(test *testing.T) {
var dialer = tester.DialFS{
Prefix: "testdata/client/pages/",
FS: testClientPages,
}
var tlsCfg = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
var client = &gemax.Client{
Dial: dialer.Dial,
}

var addr = dumbServer(ctx, test, tlsCfg)

var client = &gemax.Client{}
var resp, errFetch = client.Fetch(ctx, "gemini://"+addr)
var ctx = context.Background()
var resp, errFetch = client.Fetch(ctx, "gemini://redirect1.com")
if errFetch != nil {
test.Fatal("fetching test data:", errFetch)
test.Errorf("unexpected fetch error: %v", errFetch)
return
}
if resp.Status != status.Success {
test.Fatalf("unexpected status code %v", resp.Status)
}
defer func() { _ = resp.Close() }()

var responseText, _ = io.ReadAll(resp)
if !bytes.Equal(responseText, []byte("\n# Hello world\n")) {
test.Fatalf("unexpected response: %q", responseText)
var data, errRead = io.ReadAll(resp)
if errRead != nil {
test.Errorf("unexpected error while reading response body: %v", errRead)
return
}
test.Logf("%s", data)
}

func dumbServer(ctx context.Context, test *testing.T, tlsCfg *tls.Config) (addr string) {
addr = testaddr.Addr()
var tcpListener, errListenTCP = net.Listen("tcp", addr)
if errListenTCP != nil {
test.Fatalf("starting a TCP listener: %v", errListenTCP)
func TestClient_InfiniteRedirect(test *testing.T) {
var dialer = tester.DialFS{
Prefix: "testdata/client/pages/",
FS: testClientPages,
}
test.Cleanup(func() { _ = tcpListener.Close() })

var listener = tls.NewListener(tcpListener, tlsCfg)
go func() {
<-ctx.Done()
_ = listener.Close()
}()

var testdata, errTestData = testClientPages.ReadFile("testdata/client/pages/success.com")
if errTestData != nil {
test.Fatal("reading test data:", errTestData)
var client = &gemax.Client{
Dial: dialer.Dial,
}
var ctx = context.Background()
var _, errFetch = client.Fetch(ctx, "gemini://redirect2.com")
switch {
case errors.Is(errFetch, gemax.ErrTooManyRedirects):
// ok
case errFetch != nil:
test.Fatalf("unexpected error %q", errFetch)
default:
test.Fatalf("an error is expected, got nil")
}
go func() {
defer func() { _ = listener.Close() }()

var conn, errAccept = listener.Accept()
if errAccept != nil {
test.Log("accepting test connection:", errAccept)
return
}
defer func() { _ = conn.Close() }()
_, _ = conn.Write(testdata)
}()
return addr
}
9 changes: 9 additions & 0 deletions gemax/request.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package gemax

import (
"errors"
"fmt"
"io"
"net/url"
"strings"
)

// MaxRequestSize is the maximum incoming request size in bytes.
Expand All @@ -15,6 +17,10 @@ type IncomingRequest interface {
RemoteAddr() string
}

var (
errDotPath = errors.New("dots in path are not permitted")
)

// ParseIncomingRequest constructs an IncomingRequest from bytestream
// and additional parameters (remote address for now).
func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, error) {
Expand All @@ -28,6 +34,9 @@ func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, err
if errParse != nil {
return nil, fmt.Errorf("bad request: %w", errParse)
}
if strings.Contains(parsed.Path, "/..") {
return nil, errDotPath
}
return &incomingRequest{
url: parsed,
remoteAddr: remoteAddr,
Expand Down
8 changes: 8 additions & 0 deletions gemax/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"strings"
"sync"

"github.com/ninedraft/gemax/gemax/internal/bufwriter"
Expand Down Expand Up @@ -36,6 +37,7 @@ func (rw *responseWriter) WriteStatus(code status.Code, meta string) {
if code == status.Success && meta == "" {
meta = MIMEGemtext
}
meta = metaSanitizer.Replace(meta)
_, _ = fmt.Fprintf(rw.writer, "%d %s\r\n", code, meta)
rw.status = code
rw.statusWritten = true
Expand All @@ -44,6 +46,12 @@ func (rw *responseWriter) WriteStatus(code status.Code, meta string) {
}
}

var metaSanitizer = strings.NewReplacer(
"\r\n", "\t",
"\n", "\t",
"\r", "\t",
)

func (rw *responseWriter) Write(data []byte) (int, error) {
if rw.isClosed {
return 0, io.ErrNoProgress
Expand Down
11 changes: 9 additions & 2 deletions gemax/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gemax
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -109,9 +110,15 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) {
}
}()
var req, errParseReq = ParseIncomingRequest(conn, conn.RemoteAddr().String())
var code = status.Success
switch {
case errors.Is(errParseReq, errDotPath):
code = status.PermanentFailure
case errParseReq != nil:
code = status.BadRequest
}
if errParseReq != nil {
server.logf("WARN: bad request: remote_ip=%s", conn.RemoteAddr())
const code = status.BadRequest
server.logf("WARN: bad request: remote_ip=%s, code=%s", conn.RemoteAddr(), code)
rw.WriteStatus(code, status.Text(code))
return
}
Expand Down

0 comments on commit e298209

Please sign in to comment.