diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..501443e --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +.git +.obj +tags +test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..557637d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.obj diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..3b7c5a7 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "deps/WireGuard"] + path = deps/WireGuard + url = git://git.zx2c4.com/WireGuard diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..57bc88a --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..312ac7f --- /dev/null +++ b/Makefile @@ -0,0 +1,88 @@ +.PHONY: all build clean run run-dbg run-vgd docker run-docker run-docker-dbg install-deps + +SO = .obj/whcore.so +SRC_C = $(wildcard src/core/*.c) +OBJ_C = $(patsubst src/core/%.c,.obj/%.o,$(SRC_C)) + +EMBED_WG_PATH = deps/WireGuard/contrib/examples/embeddable-wg-library + +CC=gcc +MINIMAL_CFLAGS=-Wall -fPIC +DEBUG?=n + +ifeq ($(DEBUG), y) + MINIMAL_CFLAGS+=-g +else + MINIMAL_CFLAGS+=-O2 +endif + +CFLAGS=$(MINIMAL_CFLAGS) -Wextra +WG_EMBED_CFLAGS=$(MINIMAL_CFLAGS) +LDFLAGS=-lsodium -lpthread -lpcap -lminiupnpc + + +all: build + +$(EMBED_WG_PATH)/wireguard.c: +$(EMBED_WG_PATH)/wireguard.h: + +src/core/wireguard.c: $(EMBED_WG_PATH)/wireguard.c + cp $< $@ + +src/core/wireguard.h: $(EMBED_WG_PATH)/wireguard.h + cp $< $@ + +install-deps: \ + src/core/wireguard.c \ + src/core/wireguard.h + +build: $(SO) + +$(SO): install-deps $(OBJ_C) + $(CC) -shared -o $@ $(OBJ_C) $(LDFLAGS) +ifeq ($(DEBUG), n) + strip $@ +endif + @ls -lh $@ + +.obj/wireguard.o: src/core/wireguard.c + $(CC) -c $< -o $@ $(WG_EMBED_CFLAGS) + +.obj/%.o: src/core/%.c + @mkdir -p .obj + $(CC) -c $< -o $@ $(CFLAGS) + +clean: + rm -f $(SO) $(OBJ_C) + +run: all + lua src/cli.lua + +run-dbg: all + gdb -ex run -args ./.obj/lua-dbg src/cli.lua + +run-vgd: + valgrind --track-origins=yes /usr/bin/env lua src/cli.lua + +docker: + docker build -t wirehub/wh -f docker/Dockerfile . + +docker-sandbox: + docker build --target builder -t wirehub/builder -f docker/Dockerfile . + docker build -t wirehub/sandbox -f docker/Dockerfile.sandbox . + +docker-root1: docker + docker build --no-cache=true -t wirehub/root1 -f docker/Dockerfile.root1 . + +run-docker: + docker run -it --rm --cap-add NET_ADMIN --cap-add SYS_ADMIN --cap-add SYS_PTRACE wirehub /bin/sh + +run-sandbox: + docker run -it --rm --cap-add NET_ADMIN --cap-add SYS_ADMIN --cap-add SYS_PTRACE -v "$(shell pwd):/root/wh" wirehub/sandbox /bin/bash + + +run-sandbox-nomount: + docker run -it --rm --cap-add NET_ADMIN --cap-add SYS_ADMIN --cap-add SYS_PTRACE wirehub/sandbox /bin/bash + +run-root1: + docker run -d --cap-add NET_ADMIN --network=host --name wh-root1 wirehub/root1 diff --git a/README.md b/README.md new file mode 100644 index 0000000..e2c2fb0 --- /dev/null +++ b/README.md @@ -0,0 +1,154 @@ +# WireHub + +WireHub (in a shell, *wh*) is a simple, small, peer-to-peer, decentralized, +extensible VPN. It goes through NATs. It uses [WireGuard tunnels][wireguard] and +provides distributed peer discovery & routing capabilities, NAT trasversal, +extendable name resolving, ... + +It is written in C and Lua and is <10KLOC. + +⚠️ **Not ready for production!** This is still a work-in-progress. It still +requires some work to be clean and secure. The current code is provided for +testing only. + +## Features + +- **Easy management of networks**: a network is defined by a single configuration + file which lists trusted peers. + +- **Decentralized discovery**: WireHub peers form a [Kademilia + DHT][kademilia] network which is the by-default + discovery mechanism to find new peers. [Sybil attack][sybil] is mitigated with + a configurable Proof-of-Work parameter; + +- **Peer-to-Peer communication**: WireHub go through NATs using ([UPnP + IGD][igd]) to map new ports on compatible routers, or using [UDP Hole + Punching][udp-hole-punching]. + +- **Relay communication**: if a P2P communication cannot be established, network + traffic is relayed through trusted relayed servers, or at the very least peers + from the community of WireHub nodes. + +## Dependencies + +- [Libpcap][libpcap] +- [Libsodium][libsodium] +- [Lua][lua] +- [WireGuard][wireguard] +- optionally, [Docker][docker] + +## Requirements + +- Linux or Docker +- WireGuard + +## Quickstart with Docker + +You can test WireHub with Docker with the image [`wirehub/wh`][wh-docker]. +There's a playground container [`wirehub/sandbox`][sandbox-docker] which is more +comfortable to use (auto-completion enabled, debug tooling, live troubleshooting +ready, ...). + +``` +$ docker run -it wirehub/sandbox --cap-add NET_ADMIN wirehub /bin/bash +``` + +Make sure WireHub is installed. + +``` +$ wh help +Usage: wh [] + +[...] +``` + +Set up the minimal configuration for the `public` network. + +``` +$ curl https://gawenr.keybase.pub/wirehub/bootstrap-unstable | wh setconf public +``` + +An example configuration for the network `public` looks like this: + +``` +# Example configuration for WireHub public network + +[Network] +Name = public +Namespace = public +Workbits = 8 + +[Peer] +# Trust = no +Bootstrap = yes +PublicKey = P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w +Endpoint = 51.15.227.165:62096 +``` + +Starts a peer for network `public`. + +``` +$ wh up public +``` + +You can make sure WireHub is running. + +``` +$ wh +interface gOVQwCSUxK, network public, node <> + public key: gOVQwCSUxKUhUrkUSF0aDvssDfWVrrnm47ZMp5GJtDg +``` + +Here to see all peers. + +``` +$ wh show gOVQwCSUxK all +interface gOVQwCSUxK, network public, node <> + public key: gOVQwCSUxKUhUrkUSF0aDvssDfWVrrnm47ZMp5GJtDg + + peers + ◒ BB_O_4Qxzw: 1.2.3.4:55329 (bucket:1) + ◒ C4mfi1ltU9: 1.2.3.4:46276 (bucket:1) + ◒ Dng_TaMHei: 1.2.3.4:6465 (bucket:1) + ◒ GjIX1RdmDj: 1.2.3.4:53850 (bucket:1) + ◒ G9qk6znNL5: 1.2.3.4:4523 (bucket:1) + ◒ J_RXehMJiw: 1.2.3.4:13962 (bucket:1) + ◒ PgjYqFfsyS: 1.2.3.4:39582 (bucket:1) + ● P17zMwXJFb: 51.15.227.165:62096 (bucket:1) + [...] +``` + +### Current limitations + +- **Untrusted cryptography**: even if WireHub basics cryptographic routines are + based on the trusted [Libsodium][libsodium], the WireHub cryptographic + architecture has not been audited yet. If you're interested to contribute on + this part, help is very welcome! + +- **Still panic**: still quite rough to use. Do not expect the daemon to be stable; + +- **For a relayed peer, only one relay is used**: the traffic is not distributed + yet between several relays, which makes a single point of failure of WireHub + relay mechanisms; + +- **Only IPv4**: implemeting IPv6 requires some additional work; + +- and related to WireGuard, which is still under active development. + +### Future + +- **Zero-configuration networking with IPv6 [ORCHID][orchid] addresses**: every + peer has an allocated IP address (see `wh orchid`); + +[kademilia]: https://en.wikipedia.org/wiki/Kademlia +[libsodium]: https://download.libsodium.org/doc/ +[lua]: https://www.lua.org/ +[orchid]: https://datatracker.ietf.org/doc/rfc4843/ +[wh-docker]: https://hub.docker.com/r/wirehub/wh/ +[sandbox-docker]: https://hub.docker.com/r/wirehub/sandbox/ +[wireguard]: https://www.wireguard.com/ +[sybil]: https://en.wikipedia.org/wiki/Sybil_attack +[pow]: https://en.wikipedia.org/wiki/Proof-of-work_system +[igd]: https://en.wikipedia.org/wiki/Internet_Gateway_Device_Protocol +[udp-hole-punching]: https://en.wikipedia.org/wiki/UDP_hole_punching + diff --git a/READMORE.md b/READMORE.md new file mode 100644 index 0000000..14f9591 --- /dev/null +++ b/READMORE.md @@ -0,0 +1,73 @@ +# Principles + +A WireHub **peer** is a network node running WireHub. Each peer has a Curve25519 +private key, used as a proof of its identity in the network. Its network address +is its Curve25519 public key. The human-readable version of a peer address if +encoded with Base64 (e.g. `P17zMwXJF..._KX3w`). + +Peers form a Kademilia DHT. The distance function `XOR` is used to make peers +close to each other aware of themselves. By requesting consecutively the closest +peers of one which is being looked up, peers are able to find the IPv6 or IPv6 +public address of another peer decentralizedly. Central servers may be provided, +but the network keeps working if they are unreachable. + +Peers can form a **private networks**. A private network sets a list of +**trusted peers** which has each a private IP address and optionally a hostname. +Application's network traffic of trusted peers are sent through WireGuard +tunnels. + +A private network is defined by a single configuration file, like so + +``` +[Network] +Name = jgl +Namespace = public +Workbits = 8 +SubNetwork = 10.0.42.1/24 + +[Peer] +# Trust = no +Bootstrap = yes +Name = bootstrap +PublicKey = P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w +Endpoint = 123.45.67.89:62096 + +[Peer] +Trust = yes +Name = 1.relay +IP = 10.0.42.1 +PublicKey = ZvuWjYZPQL7NGBZKXsB7zJgqVpY3zG_h-8ALBE3QHTM + +[Peer] +Trust = yes +Name = 2.relay +IP = 10.0.42.2 +PublicKey = vpeUTmuhSM44waVt0iquAd3E-GvjZ6kvKPHCuMymaks + +... +``` + +WireHub try to go through NATs to establish a peer-to-peer communication. If not +possible, application's network traffic is relayed through relay peers. Trusted +relay peers SHOULD BE preferred (TODO). + +## Public network + +XXX Every peer keep a list of other seen (non-trusted) peers + +## Advanced + +### Sybil attacks + +WireHub provides a Proof-of-Work mechanism to mitigate Sybil attack. Each private +network sets the field `WorkBit`. **Work bits** are the count of MSB bits set to +zero of a Blake2b hash of the Curve25519 public key. Any peer which does not +have the necessary amount of work bits will be rejected. + +The bigger the work bits, the more mitigated the sybil attack is. Each added +work bit multiply by 2 the complexity of generating a new identity. + +## Getting starteg + +The CLI tool to set up WireHub is `wh`. Make sure to enable auto-completion. + diff --git a/config/public b/config/public new file mode 100644 index 0000000..c0d498e --- /dev/null +++ b/config/public @@ -0,0 +1,11 @@ +[Network] +Name = public +Namespace = public +Workbits = 8 + +[Peer] +# Trust = no +Bootstrap = yes +PublicKey = P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w +Endpoint = bootstrap.wirehub.io:62096 + diff --git a/contrib/micronet/.gitignore b/contrib/micronet/.gitignore new file mode 100644 index 0000000..1746e32 --- /dev/null +++ b/contrib/micronet/.gitignore @@ -0,0 +1,2 @@ +bin +obj diff --git a/contrib/micronet/Makefile b/contrib/micronet/Makefile new file mode 100644 index 0000000..da9fa47 --- /dev/null +++ b/contrib/micronet/Makefile @@ -0,0 +1,35 @@ +.PHONY: all clean run-server vgd-server + +MICRONET=bin/micronet +SRC=$(wildcard src/*.c) obj/server.lua.c +HDR=$(wildcard src/*.h) + +CFLAGS=-Wall -Wextra -g -llua + +all: $(MICRONET) + +/dev/net/tun: + mkdir /dev/net + mknod /dev/net/tun c 10 200 + +obj/server.lua.c: src/server.lua scripts/file2buf.py + @mkdir -p obj + luac -o obj/server.luac src/server.lua + scripts/file2buf.py obj/server.luac _luacode_server > obj/server.lua.c + +$(MICRONET): $(SRC) $(HDR) + @mkdir -p bin + $(CC) -o $(MICRONET) $(SRC) $(CFLAGS) + +clean: + rm -rf obj bin + +run-server: $(MICRONET) /dev/net/tun + $(MICRONET) server examples/conf.lua + +gdb-server: $(MICRONET) /dev/net/tun + gdb -ex run -args $(MICRONET) server examples/conf.lua + +vgd-server: $(MICRONET) /dev/net/tun + valgrind $(MICRONET) server examples/conf.lua + diff --git a/contrib/micronet/README.md b/contrib/micronet/README.md new file mode 100644 index 0000000..0ac4fa5 --- /dev/null +++ b/contrib/micronet/README.md @@ -0,0 +1,51 @@ +# micronet + +`micronet` is a small software to simulate IP networks. A network topology is +defined in a configuration file. A server is run and relay the traffic between +the peers. Each peer runs a client which initiate a TUN IP tunnel on which all IP +traffic is routed. + +It is used to test WireHub in a simulated Internet on one single machine. +Containers are spawned with WireHub running, and micronet routes the network +traffic between containers. + +Configuration files language is a DSL over Lua. For example, + +``` +-- Initiate a WAN +W = wan() + +-- Initiate a public peer, with IP 51.15.227.165. It will act as the WireHub's +-- bootstrap node +M(W | peer{up_ip=subnet('51.15.227.165', 0)}) + +-- Initiate another public peer, with IP 1.1.1.1 +M(W | peer{up_ip=subnet("1.1.1.1", 0)}) + +-- Initiate a peer behind a full-cone NAT whose IP is 1.1.1.2 +M(W | nat{up_ip=subnet('1.1.1.2', 0), mode=NAT_FULL_CONE} | peer()) +``` + +## Features + +- **NATs**: symmetric, full-cone, restricted-cone and restricted-port NATs are + supported; + +- **ICMP echo and echo reply**: used for network pings; + +- **UDP**: used by WireHub; + +- **Extensible network componenets with Lua**: network components can be + customized in Lua (see the NAT component). + +## TODO + +- TCP through NAT: TCP traffic going through NAT is currently not supported + and no effort was done to make it work, as not required by WireHub. + +- Hop simulation: Currently, TTL of IP packets are not decremented. A + traceroute will report always one hop. + +- UPnP support + +- Simulated latency and packet drops diff --git a/contrib/micronet/examples/client.sh b/contrib/micronet/examples/client.sh new file mode 100755 index 0000000..e0d0409 --- /dev/null +++ b/contrib/micronet/examples/client.sh @@ -0,0 +1,4 @@ +#!/bin/sh +make /dev/net/tun +make +UNET_SERVERNAME=172.17.0.1 ./bin/micronet client $1 diff --git a/contrib/micronet/examples/conf.lua b/contrib/micronet/examples/conf.lua new file mode 100644 index 0000000..897187b --- /dev/null +++ b/contrib/micronet/examples/conf.lua @@ -0,0 +1,10 @@ +W = wan() + +M(W | peer{up_ip=subnet('51.15.227.165', 0)}) -- root + +M(W | peer{up_ip=subnet("1.1.1.1", 0)}) + +HomeLan = W | nat{up_ip=subnet('1.1.1.2', 0)} +M(HomeLan | peer()) +M(HomeLan | peer()) + diff --git a/contrib/micronet/scripts/file2buf.py b/contrib/micronet/scripts/file2buf.py new file mode 100755 index 0000000..bcaae89 --- /dev/null +++ b/contrib/micronet/scripts/file2buf.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import os +import sys + +MAX = 8 + +fpath = sys.argv[1] +name = sys.argv[2] + +with open(fpath, "rb") as fh: + sys.stdout.write("char %s[] = {" % (name,) ) + + i = 0 + while True: + if i > 0: + sys.stdout.write(", ") + + if i % MAX == 0: + sys.stdout.write("\n\t") + + c = fh.read(1) + + if not c: + sys.stdout.write("\n") + break + + sys.stdout.write("0x%.2x" % (ord(c), )) + + i = i + 1 + + print("};") + print("") + print("unsigned int %s_sz = %s;" % (name, i)) + print("") + diff --git a/contrib/micronet/src/client.c b/contrib/micronet/src/client.c new file mode 100644 index 0000000..26a996a --- /dev/null +++ b/contrib/micronet/src/client.c @@ -0,0 +1,375 @@ +#include +#include +#include +#include +#include +#include "common.h" + +#define UNET_DEFAULT_SERVERNAME "micronet.server" +#define UNET_ENV_SERVERNAME "UNET_SERVERNAME" + +#define UNET_DEFAULT_IFNAME "micronet" +#define UNET_ENV_IFNAME "UNET_IFNAME" + +static char subnet[64] = { 0 }; +static char gateway[INET_ADDRSTRLEN+1]; +static char tun_name[IFNAMSIZ]; +static const char* server_port = UNET_STR(DEFAULT_SERVER_PORT); +static int mtu; +static int node_id; +static int server_fd; +static int tun_fd; +static struct addrinfo* server_addr; + +static inline uint32_t _subnet_mask(int cidr) { + assert(0 <= cidr && cidr <= 32); + if (cidr == 32) { + return 0xffffffff; + } + + return ((1 << cidr)-1) << (32-cidr); +} + +static int get_server_addr(void) { + const char* server_name = getenv(UNET_ENV_SERVERNAME); + if (!server_name) server_name = UNET_DEFAULT_SERVERNAME; + assert(server_name); + + struct addrinfo hint; + memset(&hint, '\0', sizeof hint); + + hint.ai_family = PF_UNSPEC; + hint.ai_socktype = SOCK_DGRAM; + hint.ai_flags = 0; + + int ret; + if ((ret = getaddrinfo(server_name, server_port, &hint, &server_addr)) < 0) { + fprintf(stderr, "unknown host '%s':'%s': %s\n", server_name, server_port, gai_strerror(ret)); + return -errno; + } + + int sock = socket(server_addr->ai_family, server_addr->ai_socktype, server_addr->ai_protocol); + + if (sock < 0) { + ERROR("socket"); + freeaddrinfo(server_addr), server_addr = NULL; + return -errno; + } + + if (fcntl(sock, F_SETFL, fcntl(sock, F_GETFL, 0) | O_NONBLOCK) == -1) { + ERROR("fcntl(... | O_NONBLOCK)"); + close(sock); + return -errno; + } + + server_fd = sock; + return 0; +} + +static int create_tunnel(void) { + mtu = UNET_DEFAULT_MTU; + const char* mtu_s = getenv(UNET_ENV_MTU); + if (mtu_s) mtu = atoi(mtu_s); + + const char* ifname = getenv(UNET_ENV_IFNAME); + if (!ifname) ifname = UNET_DEFAULT_IFNAME; + + if (mtu < 576) { + fprintf(stderr, "MTU smaller than 576\n"); + return -EINVAL; + } + + int fd; + if ((fd = open("/dev/net/tun", O_RDWR)) < 0) { + ERROR("open(\"/dev/net/tun\")"); + return -errno; + } + + struct ifreq ifr; + memset(&ifr, 0, sizeof(ifr)); + ifr.ifr_flags = IFF_TUN | IFF_NO_PI; + strncpy(ifr.ifr_name, ifname, IFNAMSIZ); + + if (ioctl(fd, TUNSETIFF, (void*)&ifr) < 0) { + ERROR("ioctl(TUNSETIFF)"); + close(fd); + return -errno; + } + + memcpy(tun_name, ifr.ifr_name, IFNAMSIZ); + fprintf(stderr, "micronet tun ifname: %s\n", ifr.ifr_name); + + if (fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK) == -1) { + ERROR("fcntl(... | O_NONBLOCK)"); + close(fd); + return -errno; + } + + tun_fd = fd; + return 0; +} + +static int configure_net(void) { + char cmd[128]; + + snprintf(cmd, sizeof(cmd), "echo > /etc/resolv.conf"); + if (system(cmd) < 0) { + return -errno; + } + + snprintf(cmd, sizeof(cmd), "ip link set dev %s mtu %d", tun_name, mtu); + if (system(cmd) < 0) { + return -errno; + } + + snprintf(cmd, sizeof(cmd), "ip addr add %s dev %s", subnet, tun_name); + if (system(cmd) < 0) { + return -errno; + } + + snprintf(cmd, sizeof(cmd), "ip link set %s up", tun_name); + if (system(cmd) < 0) { + return -errno; + } + + if (strcmp(gateway, "0.0.0.0") == 0) { + snprintf(cmd, sizeof(cmd), "ip route replace default dev %s", tun_name); + } else { + snprintf(cmd, sizeof(cmd), "ip route replace default via %s", gateway); + } + if (system(cmd) < 0) { + return -errno; + } + + printf("interface %s up, local addr is %s, gateway is %s.\n", tun_name, subnet, gateway); + return 0; +} + +static int sendto_server(struct iovec* iov, int iovlen) { + struct msghdr m; + m.msg_name = server_addr->ai_addr; + m.msg_namelen = server_addr->ai_addrlen; + m.msg_iov = iov; + m.msg_iovlen = iovlen; + m.msg_control = 0; + m.msg_controllen = 0; + m.msg_flags = 0; + + return sendmsg(server_fd, &m, 0); +} + +static int loop() { + uint32_t nid = htonl(node_id); + + int epollfd = epoll_create1(0); + if (epollfd < 0) { + ERROR("epool_create1"); + return -1; + } + + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = tun_fd; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, tun_fd, &ev) < 0) { + ERROR("epoll_ctl"); + close(epollfd); + return -1; + } + + ev.events = EPOLLIN; + ev.data.fd = server_fd; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, server_fd, &ev) < 0) { + ERROR("epoll_ctl"); + close(epollfd); + return -1; + } + +#define SENDTO_SERVER(iov, iovlen) \ + do { \ + if (sendto_server(iov, iovlen) < 0) { \ + ret = -errno; \ + ERROR("sendto_server"); \ + break; \ + } \ + } while(0) + + const int buf_sz = 64 * 1024; + uint8_t* buf = malloc(buf_sz); + assert(buf); + + if (subnet[0] != 0 && configure_net() < 0) { + ERROR("configure_net"); + close(epollfd); + return -1; + } + + const int max_events = 10; + struct epoll_event events[max_events]; + int ret = 0; + int first_loop = 1; + for (;;) { + int timeout = -1; + + if (first_loop) { + timeout = 0; + } else if (subnet[0] == 0) { + timeout = 1 * 1000; + } + + first_loop = 0; + + int nfds = epoll_wait(epollfd, events, max_events, timeout); + if (nfds < 0) { + switch (errno) { + case EINTR: + continue; + }; + + ret = -errno; + ERROR("epoll_wait"); + break; + } + + int n; + for (n=0; n\n" + "\n" + "Ooptions:\n" + " -h Print this screen and quit\n", + arg0 + ); +} + +int main_client(int argc, char* argv[]) { + int opt; + while ((opt = getopt(argc, argv, "h")) != -1) { + switch (opt) { + case 'h': + default: + help(argv[0]); + return EXIT_FAILURE; + }; + } + + if (optind >= argc) { + fprintf(stderr, "ID required\n"); + help(argv[0]); + return EXIT_FAILURE; + } + const char* id_s = argv[optind]; + node_id = atoi(id_s); + + if (node_id <= 0) { + fprintf(stderr, "ID must be set strictly greater to 0\n"); + return EXIT_FAILURE; + } + + if (get_server_addr() < 0) { + fprintf(stderr, "could not resolve server. abort\n"); + return EXIT_FAILURE; + } + + if (create_tunnel() < 0) { + freeaddrinfo(server_addr), server_addr = NULL; + fprintf(stderr, "could not create tunnel. abort\n"); + return EXIT_FAILURE; + } + + //strcpy(subnet, "192.168.42.2/24"); strcpy(gateway, "192.168.42.1"); + loop(); + + close(tun_fd); + freeaddrinfo(server_addr), server_addr = NULL; + + return EXIT_SUCCESS; +} + diff --git a/contrib/micronet/src/common.h b/contrib/micronet/src/common.h new file mode 100644 index 0000000..234f2a8 --- /dev/null +++ b/contrib/micronet/src/common.h @@ -0,0 +1,41 @@ +#ifndef MICRONET_COMMON_H +#define MICRONET_COMMON_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "config.h" + +#define ERROR(func) \ + fprintf(stderr, func " error: %s (%s:%d)\n", strerror(errno), __FILE__, __LINE__) + +#define UNET_STR(x) UNET_STR_(x) +#define UNET_STR_(x) #x + +#define LOG(...) fprintf(stderr, __VA_ARGS__) +#define LOG_SOCKADDR(addr) \ + do { \ + char addr_s[INET_ADDRSTRLEN+1]; \ + inet_ntop(AF_INET, &(addr)->sin_addr,addr_s, sizeof(addr_s)-1); \ + LOG("%s:%d", addr_s, ntohs((addr)->sin_port)); \ + } while(0) + +#define LOG_ADDR(addr) \ + do { \ + char addr_s[INET_ADDRSTRLEN+1]; \ + inet_ntop(AF_INET, addr, addr_s, sizeof(addr_s)-1); \ + LOG("%s", addr_s); \ + } while(0) + + +#endif // MICRONET_COMMON_H + diff --git a/contrib/micronet/src/conf.c b/contrib/micronet/src/conf.c new file mode 100644 index 0000000..b8b0722 --- /dev/null +++ b/contrib/micronet/src/conf.c @@ -0,0 +1,495 @@ +#include "server.h" +#include "conf.h" + +extern char _luacode_server[]; +extern unsigned int _luacode_server_sz; + +lua_State* L = NULL; + +unsigned int nodes_max = 0; +struct node* nodes = NULL; + +node_id luaN_checkid(lua_State* L, int idx) { + lua_Integer i = luaL_checkinteger(L, idx); + if (i < 1 || nodes_max < i) { luaL_error(L, "invalid ID"); } + return i; +} + +node_id luaN_checkidornil(lua_State* L, int idx) { + lua_Integer i = luaL_checkinteger(L, idx); + if (i < 0 || nodes_max < i) { luaL_error(L, "invalid ID"); } + return i; +} + +static inline uint32_t _subnet_netmask(int cidr) { + assert(0 <= cidr && cidr <= 32); + if (cidr == 32) { + return 0xffffffff; + } + + return ((1 << cidr)-1) << (32-cidr); +} + +/* +** Message handler used to run all chunks +*/ +static int msghandler (lua_State *L) { + const char *msg = lua_tostring(L, 1); + if (msg == NULL) { /* is error object not a string? */ + if (luaL_callmeta(L, 1, "__tostring") && /* does it have a metamethod */ + lua_type(L, -1) == LUA_TSTRING) /* that produces a string? */ + return 1; /* that is the message */ + else + msg = lua_pushfstring(L, "(error object is a %s value)", + luaL_typename(L, 1)); + } + luaL_traceback(L, L, msg, 1); /* append a standard traceback */ + return 1; /* return the traceback */ +} + + +/* +** Interface to 'lua_pcall', which sets appropriate message function +** and C-signal handler. Used to run all chunks. +*/ +int docall (lua_State *L, int narg, int nres) { + int status; + int base = lua_gettop(L) - narg; /* function index */ + lua_pushcfunction(L, msghandler); /* push message handler */ + lua_insert(L, base); /* put it under function and args */ + status = lua_pcall(L, narg, nres, base); + lua_remove(L, base); /* remove message handler from the stack */ + return status; +} + +static int _addr_tostring(lua_State* L) { + struct in_addr* addr = luaL_checkudata(L, 1, "ip4"); + + char addr_s[INET_ADDRSTRLEN+1]; + inet_ntop(AF_INET, addr, addr_s, sizeof(addr_s)-1); + + lua_pushstring(L, addr_s); + return 1; +} + +static int _addr_index(lua_State* L) { + struct in_addr* addr = luaL_checkudata(L, 1, "ip4"); + const char* name = luaL_checkstring(L, 2); + int is_get = lua_gettop(L) == 2; + + if (strcmp(name, "s_addr") == 0) { + if (is_get) { + lua_pushnumber(L, ntohl(addr->s_addr)); + return 1; + } else { + addr->s_addr = htonl(luaL_checkinteger(L, 3)); + } + } + + return 0; + //return luaL_error(L, "unknown field: %s", name); +} + +static int _now(lua_State* L) { + lua_Number n = now.tv_sec + (double)now.tv_nsec / 1.0e9; + lua_pushnumber(L, n); + return 1; +} + +static void _pushsubnetmt(lua_State* L); + +static int _subnet_ip(lua_State* L) { + struct subnet* subnet = luaL_checkudata(L, 1, "subnet"); + + struct in_addr* addr = lua_newuserdata(L, sizeof(struct in_addr)); + *addr = subnet->addr; + if (luaL_newmetatable(L, "ip4")) { + lua_pushcfunction(L, _addr_index); + lua_setfield(L, -2, "__index"); + lua_pushcfunction(L, _addr_index); + lua_setfield(L, -2, "__newindex"); + lua_pushcfunction(L, _addr_tostring); + lua_setfield(L, -2, "__tostring"); + } + lua_setmetatable(L, -2); + return 1; +} + +static int _subnet_next(lua_State* L) { + struct subnet* subnet = luaL_checkudata(L, 1, "subnet"); + struct subnet* next = lua_newuserdata(L, sizeof(struct subnet)); + + _pushsubnetmt(L); + lua_setmetatable(L, -2); + + next->cidr = subnet->cidr; + + uint32_t netmask = _subnet_netmask(subnet->cidr); + uint32_t ip = ntohl(subnet->addr.s_addr); + uint32_t subnet_ip = ip & ~netmask; + ++subnet_ip; + if ((subnet_ip & netmask) != 0) { + luaL_error(L, "out of IPs"); + } + ip = (ip & netmask) | subnet_ip; + + next->addr.s_addr = htonl(ip); + return 1; +} + +static int _subnet_tostring(lua_State* L) { + struct subnet* subnet = luaL_checkudata(L, 1, "subnet"); + + char addr_s[INET_ADDRSTRLEN+1]; + inet_ntop(AF_INET, &subnet->addr, addr_s, sizeof(addr_s)-1); + + lua_pushfstring(L, "%s/%d", addr_s, subnet->cidr); + return 1; +} + +static void _pushsubnetmt(lua_State* L) { + if (luaL_newmetatable(L, "subnet")) { + lua_newtable(L); // "__index" + + lua_pushcfunction(L, _subnet_ip); + lua_setfield(L, -2, "ip"); + lua_pushcfunction(L, _subnet_next); + lua_setfield(L, -2, "next"); + + lua_pushcfunction(L, _subnet_tostring); + lua_setfield(L, -2, "__tostring"); + + lua_setfield(L, -2, "__index"); + } +} + +static int _subnet(lua_State* L) { + const char* addr_s = luaL_checkstring(L, 1); + lua_Integer cidr = luaL_checkinteger(L, 2); + + if (cidr < 0 || 32 < cidr) { + luaL_error(L, "bad CIDR"); + } + + struct subnet* subnet = lua_newuserdata(L, sizeof(struct subnet)); + + int ret = inet_pton(AF_INET, addr_s, &subnet->addr); + if (ret <= 0) { + luaL_error(L, "bad IPv4"); + } + subnet->cidr = cidr; + + _pushsubnetmt(L); + lua_setmetatable(L, -2); + + return 1; +} + +static int _randomwan(lua_State* L) { + struct subnet* subnet = lua_newuserdata(L, sizeof(struct subnet)); + + int fd = open("/dev/urandom", 0); + if (read(fd, &subnet->addr, 4) < 4) { + luaL_error(L, "urandom failed"); + } + subnet->cidr = 0; + + _pushsubnetmt(L); + lua_setmetatable(L, -2); + + return 1; +} + +static int _alloc_nodes(lua_State* L) { + if (nodes) { + luaL_error(L, "already allocated"); + } + + lua_Integer max = luaL_checkinteger(L, 1); + + if (max < 0) { + luaL_error(L, "max cannot be negative"); + } + + nodes_max = max; + nodes = calloc(max, sizeof(struct node)); + + return 0; +} + +struct node* _init_node(lua_State* L) { + node_id i = luaN_checkid(L, 1); + const char* type = lua_tostring(L, 2); + node_id up = luaN_checkidornil(L, 3); + + struct node* n = NODE(i); + n->id = i; + n->type = type ? strdup(type) : NULL; + n->up = up; + + return n; +} + +static int packet_tostring(lua_State* L) { + struct packet* p = luaL_checkudata(L, 1, PACKET_MT); + + char src[INET_ADDRSTRLEN+1], dst[INET_ADDRSTRLEN+1]; + assert(inet_ntop(AF_INET, &p->hdr.ip_src, src, sizeof(src)-1)); + assert(inet_ntop(AF_INET, &p->hdr.ip_dst, dst, sizeof(dst)-1)); + + const char* type; + switch (p->hdr.ip_p) { + case IPPROTO_TCP: type = "TCP"; break; + case IPPROTO_UDP: type = "UCP"; break; + case IPPROTO_ICMP: type = "ICMP"; break; + default: type = "?"; break; + }; + + lua_pushfstring(L, "packet p:%s dir:%s ", + type, + p->dir == UP ? "UP" : "DOWN", + src, dst, + p + ); + + const void* payload = packet_ip_payload(p, NULL); + + switch (p->hdr.ip_p) { + case IPPROTO_TCP: + case IPPROTO_UDP: + { + uint16_t sport = ntohs(((const uint16_t*)payload)[0]); + uint16_t dport = ntohs(((const uint16_t*)payload)[1]); + lua_pushfstring(L, "src:%s:%d dst:%s:%d", src, sport, dst, dport); + } + break; + default: + lua_pushfstring(L, "src:%s dst:%s", src, dst); + break; + }; + + lua_pushfstring(L, ": %p", p); + lua_concat(L, 3); + + return 1; +} + +static int packet_index(lua_State* L) { + struct packet* p = luaL_checkudata(L, 1, PACKET_MT); + const char* n = luaL_checkstring(L, 2); + int is_get = lua_gettop(L) == 2; + +#define FIELD(name, var, bitsize, ntoh, hton) \ + do { \ + if (strcmp(n, name) == 0) { \ + if (is_get) { \ + lua_pushnumber(L, ntoh(var)); \ + return 1; \ + } else { \ + uint32_t v = hton(luaL_checknumber(L, 3)); \ + if ((bitsize) > 0) { \ + uint64_t mask = ((((uint64_t)1) << ((bitsize) + 1)) - 1); \ + if ((v & ~mask) != 0) { \ + luaL_error(L, "bad value"); \ + } \ + } \ + var = v; \ + return 0; \ + } \ + } \ + } while(0) + + FIELD("dir", p->dir, 1, , ); + FIELD("from_id", p->from_id, 0, , ); + +#define FIELD_IP_INT(name, bitsize) FIELD(#name, p->hdr.ip_##name, bitsize, ,) +#define FIELD_IP_IP(name) FIELD(#name, p->hdr.ip_##name.s_addr, 32, ntohl, htonl) + + FIELD_IP_INT(hl, 4); + FIELD_IP_INT(v, 4); + FIELD_IP_INT(tos, 8); + FIELD_IP_INT(len, 16); + FIELD_IP_INT(id, 16); + FIELD_IP_INT(off, 16); + FIELD_IP_INT(ttl, 8); + FIELD_IP_INT(p, 8); + FIELD_IP_INT(sum, 16); + FIELD("saddr", p->hdr.ip_src.s_addr, 32, ntohl, htonl); + FIELD("daddr", p->hdr.ip_dst.s_addr, 32, ntohl, htonl); + + void* payload = packet_ip_payload(p, NULL); + + switch (p->hdr.ip_p) { + case IPPROTO_TCP: + case IPPROTO_UDP: + { + uint16_t* p_sport = (((uint16_t*)payload)+0); + uint16_t* p_dport = (((uint16_t*)payload)+1); + + FIELD("sport", *p_sport, 16, ntohs, htons); + FIELD("dport", *p_dport, 16, ntohs, htons); + } + break; + + case IPPROTO_ICMP: + { + struct icmphdr* icmp = payload; + + FIELD("icmp_type", icmp->type, 8, , ); + FIELD("icmp_echo_id", icmp->un.echo.id, 16, ntohs, htons); + } + break; + } + +#undef FIELD +#undef FIELD_IP_INT +#undef FIELD_IP_IP + + return 0; + //return luaL_error(L, "unknown field: %s", n); +} + +static void install_packet(lua_State* L) { + +#define GLOBAL(v) \ + do { \ + lua_pushnumber(L, v); \ + lua_setglobal(L, #v); \ + } while(0) + + GLOBAL(UP); + GLOBAL(DOWN); + GLOBAL(ICMP_ECHO); + GLOBAL(ICMP_ECHOREPLY); + GLOBAL(IPPROTO_ICMP); + GLOBAL(IPPROTO_TCP); + GLOBAL(IPPROTO_UDP); + +#undef GLOBAL + + luaL_newmetatable(L, PACKET_MT); + + lua_pushcfunction(L, packet_index); + lua_setfield(L, -2, "__index"); + lua_pushcfunction(L, packet_index); + lua_setfield(L, -2, "__newindex"); + lua_pushcfunction(L, packet_tostring); + lua_setfield(L, -2, "__tostring"); + + lua_pop(L, 1); +} + +int load_config(const char* confpath) { + int err; + assert(confpath); + + L = luaL_newstate(); + if (!L) { + ERROR("luaL_newstate"); + err = -1; + goto finally; + } + + luaL_openlibs(L); + + lua_pushglobaltable(L); + + lua_pushcfunction(L, _now); + lua_setfield(L, -2, "now"); + lua_pushcfunction(L, _subnet); + lua_setfield(L, -2, "subnet"); + lua_pushcfunction(L, _randomwan); + lua_setfield(L, -2, "randomwan"); + lua_pushcfunction(L, _alloc_nodes); + lua_setfield(L, -2, "_alloc_nodes"); + lua_pushcfunction(L, _peer); + lua_setfield(L, -2, "_peer"); + lua_pushcfunction(L, _link); + lua_setfield(L, -2, "_link"); + lua_pushcfunction(L, _nat); + lua_setfield(L, -2, "_nat"); + lua_pushcfunction(L, _wan); + lua_setfield(L, -2, "_wan"); + + install_packet(L); + + int status = luaL_loadbuffer(L, _luacode_server, _luacode_server_sz, "server.lua"); + if (status == LUA_OK) { + status = docall(L, 0, 0); + } + + if (status == LUA_OK) { + status = luaL_loadfile(L, confpath); + } + + if (status == LUA_OK) { + status = docall(L, 0, 0); + } + + if (status == LUA_OK) { + lua_pushglobaltable(L); + lua_getfield(L, -1, "_build"); + status = docall(L, 0, 0); + } + + if (status != LUA_OK) { + fprintf(stderr, "cannot load confpath: %s\n", + lua_tostring(L, -1) + ); + err = -1; + goto finally; + } + + err = 0; +finally: + return err; +} + +///// + +static void help(char* arg0) { + fprintf(stderr, + "Usage: %s [OPTS] \n" + "\n" + "Ooptions:\n" + " -h Print this screen and quit\n", + arg0 + ); +} + +int main_read(int argc, char* argv[]) { + int opt; + while ((opt = getopt(argc, argv, "h")) != -1) { + switch (opt) { + case 'h': + default: + help(argv[0]); + return EXIT_FAILURE; + }; + } + + if (optind >= argc) { + fprintf(stderr, "configuration required\n"); + help(argv[0]); + return EXIT_FAILURE; + } + + const char* confpath = argv[optind]; + + if (load_config(confpath) < 0) { + fprintf(stderr, "cannot load conf. abort\n"); + return EXIT_FAILURE; + } + + for (unsigned int i=1; i<=nodes_max; ++i) { + struct node* n = NODE(i); + + printf("%d\t%s\t%d\n", n->id, n->type, n->up); + } + + + if (L) lua_close(L), L = NULL; + + return EXIT_SUCCESS; +} + diff --git a/contrib/micronet/src/conf.h b/contrib/micronet/src/conf.h new file mode 100644 index 0000000..7745974 --- /dev/null +++ b/contrib/micronet/src/conf.h @@ -0,0 +1,77 @@ +#ifndef MICRONET_CONF_H +#define MICRONET_CONF_H + +#include "common.h" +#include +#include +#include + +#define PACKET_MT "packet" + +typedef uint32_t node_id; +#define NODEID_NULL ((node_id)0) + +extern lua_State* L; +extern unsigned int nodes_max; +extern struct node* nodes; + +struct node; +struct packet; + +typedef void(*node_kernel_cb)(struct node* n, struct packet* p); + +struct subnet { + struct in_addr addr; + uint8_t cidr; +}; + +struct route { + struct subnet subnet; + node_id id; +}; + +struct node { + node_id id; + struct sockaddr_in addr; + struct packet* pkts_heap,* pkts_tail; + node_kernel_cb kernel; + char* type; + node_id up; + + union { + struct { + struct subnet ip, gw; + } peer; + + struct { + node_id down; + } link; + + struct { + int count; + struct route* routes; + } wan; + + struct { + int kernel_ref; + struct subnet ip, gw; + } nat; + } as; +}; + +int load_config(const char* confpath); + +node_id luaN_checkid(lua_State* L, int idx); +node_id luaN_checkidornil(lua_State* L, int idx); + +struct node* _init_node(lua_State* L); + +static inline struct node* NODE(unsigned int i) { + assert(i > 0 && i <= nodes_max); + return &nodes[i-1]; +} + +int docall (lua_State *L, int narg, int nres); + +#endif // MICRONET_CONF_H + diff --git a/contrib/micronet/src/config.h b/contrib/micronet/src/config.h new file mode 100644 index 0000000..5cbf1b9 --- /dev/null +++ b/contrib/micronet/src/config.h @@ -0,0 +1,11 @@ +#ifndef MICRONET_CONFIG_H +#define MICRONET_CONFIG_H + +#define DEFAULT_SERVER_PORT 4321 + +#define UNET_DEFAULT_MTU 1500 +#define UNET_ENV_MTU "UNET_MTU" + + +#endif // MICRONET_CONFIG_H + diff --git a/contrib/micronet/src/link.c b/contrib/micronet/src/link.c new file mode 100644 index 0000000..7d5004a --- /dev/null +++ b/contrib/micronet/src/link.c @@ -0,0 +1,21 @@ +#include "server.h" + +static void _link_kernel(struct node* n, struct packet* p) { + // XXX implement + if (p->dir == UP) { + sendto_id(n, n->up, p); + + } else { // p->dir == DOWN + sendto_id(n, n->as.link.down, p); + } +} + +int _link(lua_State* L) { + struct node* n = _init_node(L); + node_id down = luaN_checkid(L, NODE_ARGS(1)); + + n->kernel = _link_kernel; + n->as.link.down = down; + + return 0; +} diff --git a/contrib/micronet/src/micronet.c b/contrib/micronet/src/micronet.c new file mode 100644 index 0000000..04cad58 --- /dev/null +++ b/contrib/micronet/src/micronet.c @@ -0,0 +1,54 @@ +#include +#include +#include + +int main_server(int argc, char* argv[]); +int main_client(int argc, char* argv[]); +int main_read(int argc, char* argv[]); + +static int help(char* arg0) { + fprintf(stderr, + "\n" + "Usage: %s COMMAND\n" + "\n" + "Commands:\n" + " client Run a client\n" + " read Read configuration\n" + " server Run a server\n" + "\n", + arg0 + ); + + return EXIT_FAILURE; +} + +int main(int argc, char* argv[]) { + if (argc == 0) { + return EXIT_FAILURE; + } + + char* arg0 = argv[0]; + + int cmd_idx = 0; + char* m = strstr(arg0, "micronet"); + if (m != NULL) { + ++cmd_idx; + } + + if (argc <= cmd_idx) { + return help(arg0); + } + + argc -= cmd_idx; + argv += cmd_idx; + + if (strcmp(argv[0], "client") == 0) { + return main_client(argc, argv); + } else if (strcmp(argv[0], "server") == 0) { + return main_server(argc, argv); + } else if (strcmp(argv[0], "read") == 0) { + return main_read(argc, argv); + } else { + return help(arg0); + } +} diff --git a/contrib/micronet/src/nat.c b/contrib/micronet/src/nat.c new file mode 100644 index 0000000..2e47b78 --- /dev/null +++ b/contrib/micronet/src/nat.c @@ -0,0 +1,48 @@ +#include "server.h" + +#define PORT_MAX 65535 +#define ICMP_ID_MAX 65535 +#define NAT_TIMEOUT 60 + +static void _nat_kernel(struct node* n, struct packet* p) { + lua_rawgeti(L, LUA_REGISTRYINDEX, n->as.nat.kernel_ref); + + lua_pushlightuserdata(L, p); + luaL_setmetatable(L, PACKET_MT); + + if (docall(L, 1, 1) != LUA_OK) { + fprintf(stderr, "kernel error: %s\n", + lua_tostring(L, -1) + ); + } + + int isnum; + int id = lua_tointegerx(L, -1, &isnum); + const char* name = lua_tostring(L, -1); + lua_pop(L, 1); + + if (isnum) { + packet_refresh_sum(p); + sendto_id(n, id, p); + } else { + if (!name) name = "unknown reason"; + DROP(n, p, "%s", name); + } +} + +int _nat(lua_State* L) { + struct node* n = _init_node(L); + struct subnet* up_ip = luaL_checkudata(L, NODE_ARGS(1), "subnet"); + struct subnet* up_gw = luaL_checkudata(L, NODE_ARGS(2), "subnet"); + luaL_checktype(L, NODE_ARGS(3), LUA_TFUNCTION); + + n->kernel = _nat_kernel; + n->as.nat.ip = *up_ip; + n->as.nat.gw = *up_gw; + + lua_pushvalue(L, NODE_ARGS(3)); + n->as.nat.kernel_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + return 0; +} + diff --git a/contrib/micronet/src/nat.c.bkp b/contrib/micronet/src/nat.c.bkp new file mode 100644 index 0000000..c90ec0b --- /dev/null +++ b/contrib/micronet/src/nat.c.bkp @@ -0,0 +1,217 @@ +#include "server.h" + +#define PORT_MAX 65535 +#define ICMP_ID_MAX 65535 +#define NAT_TIMEOUT 60 + +static struct nat_tcpudp* _get_nat_tcpudp(struct node* n, int port) { + struct nat_tcpudp* e = &n->as.nat.tcpudp_map[port]; + + // timeout + if (e->id != NODEID_NULL && now.tv_sec > e->opened_ts + NAT_TIMEOUT) { + e->id = NODEID_NULL; + } + + return e; +} + +static struct nat_icmp* _get_nat_icmp(struct node* n, int iid) { + struct nat_icmp* e = &n->as.nat.icmp_map[iid]; + + // timeout + if (e->id != NODEID_NULL && now.tv_sec > e->opened_ts + NAT_TIMEOUT) { + e->id = NODEID_NULL; + } + + return e; +} + + +static void _nat_kernel_tcpudp(struct node* n, struct packet* p) { + void* tcpudp = packet_ip_payload(p, NULL); + + uint16_t* p_sport = (((uint16_t*)tcpudp)+0); + uint16_t* p_dport = (((uint16_t*)tcpudp)+1); + + if (p->dir == UP) { + // search for a available source port + uint32_t nat_sport = ntohs(*p_sport); + int first = 1; + while (nat_sport <= PORT_MAX) { + struct nat_tcpudp* e = _get_nat_tcpudp(n, nat_sport); + + if (e->id == NODEID_NULL || ( + e->saddr.s_addr == p->hdr.ip_src.s_addr && + e->daddr.s_addr == p->hdr.ip_dst.s_addr && + e->sport == *p_sport && + e->dport == *p_dport + )) { + // create entry if not + if (e->id == NODEID_NULL) { + e->id = p->from_id; + e->saddr = p->hdr.ip_src; + e->daddr = p->hdr.ip_dst; + e->sport = *p_sport; + e->dport = *p_dport; + } + + // touch entry + e->opened_ts = now.tv_sec; + + // modify IP header + p->hdr.ip_src.s_addr = n->as.nat.ip.addr.s_addr; + *p_sport = htons(nat_sport); + + // redirect to up + packet_refresh_sum(p); + sendto_id(n, n->up, p); + return; + } + + if (first) { + nat_sport = 1024; + } else { + ++nat_sport; + } + + first = 0; + } + + DROP(n, p, "NAT: out of ports"); + } + + else { // p->dir == DOWN + // look at entry + struct nat_tcpudp* e = _get_nat_tcpudp(n, ntohs(*p_dport)); + + if (e->id == NODEID_NULL) { + DROP(n, p, "port not opened"); + } + + if ( + e->daddr.s_addr != p->hdr.ip_src.s_addr || + e->dport != *p_sport + ) { + DROP(n, p, "bad NAT mapping (NAT not full cone)"); + } + + // modify IP header + p->hdr.ip_dst.s_addr = e->saddr.s_addr; + *p_dport = e->sport; + + // redirect to down + packet_refresh_sum(p); + sendto_id(n, e->id, p); + return; + } +} + +static void _nat_kernel_icmp(struct node* n, struct packet* p) { + struct icmphdr* icmp = packet_ip_payload(p, NULL); + + if (p->dir == UP) { + if (icmp->type == ICMP_ECHO) { + uint32_t iid = ntohs(icmp->un.echo.id); + int first = 1; + while (iid <= ICMP_ID_MAX) { + struct nat_icmp* e = _get_nat_icmp(n, iid); + + if (e->id == NODEID_NULL || ( + e->siid == icmp->un.echo.id && + e->saddr.s_addr == p->hdr.ip_src.s_addr + )) { + + // create entry if not + if (e->id == NODEID_NULL) { + e->id = p->from_id; + e->saddr.s_addr = p->hdr.ip_src.s_addr; + e->siid = icmp->un.echo.id; + } + + // touch entry + e->opened_ts = now.tv_sec; + + // modify IP header + p->hdr.ip_src.s_addr = n->as.nat.ip.addr.s_addr; + icmp->un.echo.id = htons(iid); + + // redirect to up + packet_refresh_sum(p); + sendto_id(n, n->up, p); + return; + } + + if (first) { + iid = 1; + } else { + ++iid; + } + + first = 0; + } + } + + DROP(n, p, "NAT ICMP: drop"); + } + + else { // p->dir == DOWN + // XXX drop if fragmented + if (icmp->type == ICMP_ECHO) { + p->dir = p->dir == UP ? DOWN : UP; + p->hdr.ip_dst.s_addr = p->hdr.ip_src.s_addr; + p->hdr.ip_src.s_addr = n->as.nat.ip.addr.s_addr; + icmp->type = ICMP_ECHOREPLY; + + sendto_id(n, p->from_id, p); + } + + else if (icmp->type == ICMP_ECHOREPLY) { + struct nat_icmp* e = _get_nat_icmp(n, ntohs(icmp->un.echo.id)); + + if (e->id == NODEID_NULL) { + DROP(n, p, "ICMP ID not opened"); + } + + // modify IP header + p->hdr.ip_dst.s_addr = e->saddr.s_addr; + icmp->un.echo.id = e->siid; + + // redirect to down + packet_refresh_sum(p); + sendto_id(n, e->id, p); + return; + } + } +} + +static void _nat_kernel(struct node* n, struct packet* p) { + switch (p->hdr.ip_p) { + case IPPROTO_UDP: + case IPPROTO_TCP: + return _nat_kernel_tcpudp(n, p); + + case IPPROTO_ICMP: + return _nat_kernel_icmp(n, p); + + default: + DROP(n, p, "unknown IP protocol"); + }; +} + +int _nat(lua_State* L) { + struct node* n = _init_node(L); + struct subnet* up_ip = luaL_checkudata(L, NODE_ARGS(1), "subnet"); + struct subnet* up_gw = luaL_checkudata(L, NODE_ARGS(2), "subnet"); + luaL_checktype(L, NODE_ARGS(3), LUA_TFUNCTION); + + n->kernel = _nat_kernel; + n->as.nat.kernel_ref = luaL_ref(L, NODE_ARGS(3)); + n->as.nat.ip = *up_ip; + n->as.nat.gw = *up_gw; + n->as.nat.tcpudp_map = calloc(PORT_MAX+1, sizeof(struct nat_tcpudp)); + n->as.nat.icmp_map = calloc(PORT_MAX+1, sizeof(struct nat_icmp)); + + + return 0; +} + diff --git a/contrib/micronet/src/peer.c b/contrib/micronet/src/peer.c new file mode 100644 index 0000000..40cac99 --- /dev/null +++ b/contrib/micronet/src/peer.c @@ -0,0 +1,44 @@ +#include "server.h" + +static void _peer_kernel(struct node* n, struct packet* p) { + if (p->dir == UP) { + if (n->as.peer.ip.addr.s_addr != p->hdr.ip_src.s_addr) { + DROP(n, p, "bad source address"); + } + + sendto_id(n, n->up, p); + } + + else { // p->dir == DOWN + if (n->as.peer.ip.addr.s_addr != p->hdr.ip_dst.s_addr) { + DROP(n, p, "bad destination address"); + } + + if (n->addr.sin_addr.s_addr == 0) { + DROP(n, p, "unknown micronet client address"); + } + + struct iovec iov[1]; + iov[0].iov_base = p->body; + iov[0].iov_len = p->sz; + _udp_sendto(&n->addr, iov, 1); + + LOG("\n"); + free(p), p=NULL; + } +} + +int _peer(lua_State* L) { + struct node* n = _init_node(L); + struct subnet* up_ip = luaL_checkudata(L, NODE_ARGS(1), "subnet"); + struct subnet* up_gw = luaL_checkudata(L, NODE_ARGS(2), "subnet"); + + n->kernel = _peer_kernel; + n->as.peer.ip = *up_ip; + n->as.peer.gw = *up_gw; + + + return 0; +} + + diff --git a/contrib/micronet/src/server.c b/contrib/micronet/src/server.c new file mode 100644 index 0000000..7bde194 --- /dev/null +++ b/contrib/micronet/src/server.c @@ -0,0 +1,411 @@ +#include "server.h" + +struct timespec now; +static int server_fd = -1; + +static uint32_t ip_chksum_update(uint32_t sum, void* buf_, int count) { + const uint16_t* buf = (uint16_t*)buf_; + + // Sum up 2-byte values until none or only one byte left. + while (count > 1) { + sum += *(buf++); + count -= 2; + } + + // Add left-over byte, if any. + if (count > 0) { + sum += *(uint8_t*)buf; + } + + return sum; +} + +static uint16_t ip_chksum_final(uint32_t sum) { + uint16_t answer = 0; + // Fold 32-bit sum into 16 bits; we lose information by doing this, + // increasing the chances of a collision. + // sum = (lower 16 bits) + (upper 16 bits shifted right 16 bits) + while (sum >> 16) { + sum = (sum & 0xffff) + (sum >> 16); + } + + // Checksum is one's compliment of sum. + answer = ~sum; + + return answer; + +} + +static inline uint16_t ip_chksum(void* buf, int count) { + uint32_t sum = 0; + sum = ip_chksum_update(sum, buf, count); + return ip_chksum_final(sum); +} + +struct ip_pseudohdr { + uint32_t src; + uint32_t dst; + uint8_t zero; + uint8_t p; + uint16_t len; +}; + +void packet_refresh_sum(struct packet* p) { + size_t payload_sz; + void* payload = packet_ip_payload(p, &payload_sz); + + p->hdr.ip_sum = 0; + p->hdr.ip_sum = ip_chksum(&p->hdr, payload - (void*)&p->hdr); + + switch (p->hdr.ip_p) { + case IPPROTO_UDP: + ((struct udphdr*)payload)->uh_sum = 0; + // XXX calculate checksum + break; + + case IPPROTO_ICMP: + ((struct icmphdr*)payload)->checksum = 0; + ((struct icmphdr*)payload)->checksum = ip_chksum(payload, payload_sz); + break; + + case IPPROTO_TCP: + // XXX the following block seems to be buggy + { + + struct ip_pseudohdr h = { + .src = p->hdr.ip_src.s_addr, + .dst = p->hdr.ip_dst.s_addr, + .zero = 0, + .p = IPPROTO_TCP, + .len = payload_sz, + }; + assert(sizeof(h) == 12); + + ((struct tcphdr*)payload)->th_sum = 0; + + uint32_t sum = 0; + sum = ip_chksum_update(sum, &h, sizeof(h)); + sum = ip_chksum_update(sum, payload, payload_sz); + + ((struct tcphdr*)payload)->th_sum = ip_chksum_final(sum); + } + break; + } +} + +static void process_packet(struct node* n, struct packet* p) { + if (p->hdr.ip_v != 4) { + DROP(n, p, "not IPv4 packet"); + return; + } + + if (n->kernel) { + n->kernel(n, p); + } else { + DROP(n, p, "undefined kernel"); + } +} + +static void run_node(struct node* n) { + // XXX add latency simulation + while (n->pkts_heap) { + struct packet* p = n->pkts_heap; + + n->pkts_heap = p->next; + if (!n->pkts_heap) { + n->pkts_tail = NULL; + } + + process_packet(n, p); + } +} + +void sendto_id(struct node* from_n, node_id to_id, struct packet* p) { + if (to_id == NODEID_NULL) { + DROP(from_n, p, "cannot send to node id 0"); + } + + struct node* to_n = NODE(to_id); + p->from_id = from_n->id; + + if (to_n->pkts_heap) { + to_n->pkts_tail->next = p; + } else { + to_n->pkts_heap = to_n->pkts_tail = p; + } + to_n->pkts_tail = p; + + LOG("-> #%d ", to_id); + + run_node(to_n); +} + +void hex(const uint8_t* buf, size_t sz) { + size_t i; + for (i=0; i 0 && i % 16 == 0) { + printf("\n"); + } + + printf("%.2x ", (int)buf[i]); + } + printf("\n"); +} + +void print_packet(FILE* fh, struct packet* p) { + char src[INET_ADDRSTRLEN+1], dst[INET_ADDRSTRLEN+1]; + uint16_t sport = 0, dport = 0; + void* payload = packet_ip_payload(p, NULL); + + assert(inet_ntop(AF_INET, &p->hdr.ip_src, src, sizeof(src)-1)); + assert(inet_ntop(AF_INET, &p->hdr.ip_dst, dst, sizeof(dst)-1)); + + char proto[64]; + switch (p->hdr.ip_p) { + case IPPROTO_UDP: + strncpy(proto, "udp", sizeof(proto)); + sport = ntohs(((struct udphdr*)payload)->uh_sport); + dport = ntohs(((struct udphdr*)payload)->uh_dport); + break; + + case IPPROTO_TCP: + strncpy(proto, "tcp", sizeof(proto)); + sport = ntohs(((struct tcphdr*)payload)->th_sport); + dport = ntohs(((struct tcphdr*)payload)->th_dport); + break; + + case IPPROTO_ICMP: + strncpy(proto, "icmp", sizeof(proto)); + break; + + default: + snprintf(proto, sizeof(proto), "0x%.2x", p->hdr.ip_p); + break; + } + + + fprintf(stderr, "\nXXX %.8x\n", p->hdr.ip_dst.s_addr); + + fprintf(fh, + "src:%s:%d dst:%s:%d\n" + "proto: %s\n" , + src, sport, dst, dport, + proto + ); +} + +int _udp_sendto(struct sockaddr_in* peer_addr, struct iovec* iov, int iovlen) { + struct msghdr m; + m.msg_name = peer_addr; + m.msg_namelen = sizeof(struct sockaddr_in); + m.msg_iov = iov; + m.msg_iovlen = iovlen; + m.msg_control = 0; + m.msg_controllen = 0; + m.msg_flags = 0; + + return sendmsg(server_fd, &m, 0); +} + +static void on_packet(struct packet* p, struct sockaddr_in* peer_addr) { + (void)peer_addr; + + p->id = ntohl(p->id); + LOG(" ID:%d ", p->id); + if (nodes_max < p->id) { + LOG("over limit\n"); + free(p); + return; + } + + struct node* n = NODE(p->id); + if (strcmp(n->type, "peer") != 0) { + LOG("ERROR: not a peer but a '%s'!\n", n->type); + free(p); + return; + } + + memcpy(&n->addr, peer_addr, sizeof(n->addr)); + + LOG("%s ", n->type); + + if (p->sz == 4) { + LOG("assign? "); + LOG("returns IP:"); + LOG_ADDR(&n->as.peer.ip.addr); + LOG("/%d gw:", n->as.peer.ip.cidr); + LOG_ADDR(&n->as.peer.gw.addr); + LOG("\n"); + + struct iovec iov[3]; + iov[0].iov_base = &n->as.peer.ip.addr; + iov[0].iov_len = sizeof(n->as.peer.ip.addr); + iov[1].iov_base = &n->as.peer.gw.addr; + iov[1].iov_len = sizeof(n->as.peer.gw.addr); + iov[2].iov_base = &n->as.peer.ip.cidr; + iov[2].iov_len = sizeof(n->as.peer.ip.cidr); + _udp_sendto(peer_addr, iov, 3); + + return; + } + + p->sz -= 4; + + //print_packet(stdout, p); +#if 0 + LOG("hex { \n"); + hex(p->body, p->sz); + LOG("} "); +#endif + + p->dir = UP; + sendto_id(n, n->id, p); +} + +/***/ + +static void help(char* arg0) { + fprintf(stderr, + "Usage: %s [OPTS] \n" + "\n" + "Ooptions:\n" + " -h Print this screen and quit\n", + arg0 + ); +} + +int create_server_socket(void) { + server_fd = socket(AF_INET, SOCK_DGRAM, 0); + + if (server_fd < 0) { + ERROR("socket"); + return -errno; + } + + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(DEFAULT_SERVER_PORT); + if (bind(server_fd, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + ERROR("bind"); + return -errno; + } + + if (fcntl(server_fd, F_SETFL, fcntl(server_fd, F_GETFL, 0) | O_NONBLOCK) == -1) { + ERROR("fcntl(... | O_NONBLOCK)"); + close(server_fd); + return -errno; + } + + return 0; +} + +static int loop() { + int epollfd = epoll_create1(0); + if (epollfd < 0) { + ERROR("epool_create1"); + return -1; + } + + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = server_fd; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, server_fd, &ev) < 0) { + ERROR("epoll_ctl"); + close(epollfd); + return -1; + } + + int ret; + const int max_events = 10; + struct epoll_event events[max_events]; + for (;;) { + int timeout = -1; + int nfds = epoll_wait(epollfd, events, max_events, timeout); + + clock_gettime(CLOCK_REALTIME, &now); + + if (nfds < 0) { + switch (errno) { + case EINTR: + continue; + }; + + ret = -errno; + ERROR("epoll_wait"); + break; + } + + int n; + for (n=0; nbuf, sizeof(p->buf), 0, (struct sockaddr*)&peer_addr, &peer_addr_sz); + + if (sz < 0 && errno == EAGAIN) { + break; + } + + if (sz < 0) { + ret = -errno; + ERROR("recvfrom"); + break; + } + + p->sz = sz; + on_packet(p, &peer_addr); + } + } + } + + return ret; +} + +int main_server(int argc, char* argv[]) { + int opt; + while ((opt = getopt(argc, argv, "h")) != -1) { + switch (opt) { + case 'h': + default: + help(argv[0]); + return EXIT_FAILURE; + }; + } + + if (optind >= argc) { + fprintf(stderr, "configuration required\n"); + help(argv[0]); + return EXIT_FAILURE; + } + + const char* confpath = argv[optind]; + + if (create_server_socket() < 0) { + fprintf(stderr, "cannot create server sock. abort\n"); + return EXIT_FAILURE; + } + + if (load_config(confpath) < 0) { + fprintf(stderr, "cannot load conf. abort\n"); + return EXIT_FAILURE; + } + + loop(); + + close(server_fd); + if (nodes) { + for (unsigned int i=1; i<=nodes_max; ++i) { + struct node* n = NODE(i); + if (n->type) free(n->type); + } + + free(nodes); + } + + if (L) lua_close(L), L = NULL; + + return EXIT_SUCCESS; +} + diff --git a/contrib/micronet/src/server.h b/contrib/micronet/src/server.h new file mode 100644 index 0000000..a9ced1c --- /dev/null +++ b/contrib/micronet/src/server.h @@ -0,0 +1,81 @@ +#ifndef MICRONET_SERVER_H +#define MICRONET_SERVER_H + +#include "common.h" +#include +#include +#include +#include +#include +#include "conf.h" + +extern struct timespec now; // updated after each epoll_wait() + +#define DROP(n,p,reason, ...) \ + do { \ + LOG("drop! " reason "\n",##__VA_ARGS__); \ + free(p); \ + return; \ + } while(0) + + +enum direction { + UP = 0, + DOWN = 1 +}; + +struct packet { + struct packet* next; + size_t sz; + enum direction dir; + node_id from_id; + union { + struct { + uint32_t id; + union { + uint8_t body[UNET_DEFAULT_MTU]; + struct ip hdr; + }; + }; + uint8_t buf[sizeof(uint32_t)+UNET_DEFAULT_MTU]; + }; +}; + +static inline void* packet_ip_payload(struct packet* p, size_t *psize) { + size_t ip_hdr_sz = p->hdr.ip_hl*sizeof(uint32_t); + if (psize) { + *psize = p->sz - ip_hdr_sz; + } + return p->body+ip_hdr_sz; +} + +struct nat_tcpudp { + node_id id; + time_t opened_ts; + struct in_addr saddr; + struct in_addr daddr; + uint16_t sport; + uint16_t dport; +}; + +struct nat_icmp { + node_id id; + time_t opened_ts; + uint16_t siid; + struct in_addr saddr; +}; + +#define NODE_ARGS(i) (3+(i)) +void sendto_id(struct node* from_n, node_id to_id, struct packet* p); +int _udp_sendto(struct sockaddr_in* peer_addr, struct iovec* iov, int iovlen); + +int _peer(lua_State* L); +int _link(lua_State* L); +int _nat(lua_State* L); +int _wan(lua_State* L); + +void packet_refresh_sum(struct packet* p); +void print_packet(FILE* fh, struct packet* p); + +#endif // MICRONET_SERVER_H + diff --git a/contrib/micronet/src/server.lua b/contrib/micronet/src/server.lua new file mode 100644 index 0000000..9880e23 --- /dev/null +++ b/contrib/micronet/src/server.lua @@ -0,0 +1,400 @@ +local NODEID_NULL = 0 +local DEFAULT_LAN_SUBNET = subnet("192.168.0.1", 24) +local WAN_SUBNET = subnet("0.0.0.0", 0) + +local node_mt = {} +local nodes = {} + +function node_mt.__bor(n, down) + assert(down.up == nil) + down.up = n + + n.down = n.down or {} + n.down[#n.down+1] = down + + return down +end + +function node(n) + assert(n and n.type) + nodes[#nodes+1] = n + return setmetatable(n, node_mt) +end + +function peer(n) + n = n or {} + + n.type = 'peer' + n._init = function(n) + return _peer( + n.id, + 'peer', + n.up and n.up.id, + n.up_ip, + n.up_gw + ) + end + + return node(n) +end + +function link(n) + n.type = 'link' + n._init = function(n) + return _link( + n.id, + 'link', + n.up and n.up.id, + n.down and n.down.id + ) + end + return node(n) +end + +-- NAT ------------------------------------------------------------------------ + +local NAT_TIMEOUT = 60 +NAT_SYMMETRIC = "symmetric" +NAT_FULL_CONE = "full cone" +NAT_RESTRICTED_CONE = "restricted cone" +NAT_RESTRICTED_PORT = "restricted port" + +local function mapping_rawget(t, up_id) + local e = t[up_id] + + if e then + assert(e.id ~= NODEID_NULL) + + if e.id ~= NODEID_NULL and now() > e.opened_ts + NAT_TIMEOUT then + t[up_id] = nil + e = nil + end + end + + return e +end + +local function mapping_up_key(mode, addr, port) + if mode == NAT_FULL_CONE then + return '' + elseif mode == NAT_RESTRICTED_CONE then + return string.pack('I', addr) + elseif mode == NAT_RESTRICTED_PORT or mode == NAT_SYMMETRIC then + return string.pack('IH', addr, port or 0) + end +end + +local function mapping_get(n, p) + local up_id + local t + local mode + + if p.p == IPPROTO_ICMP and (p.icmp_type == ICMP_ECHO or p.icmp_type == ICMP_ECHOREPLY) then + t = n.icmp_mapping + up_id = p.icmp_echo_id + mode = NAT_SYMMETRIC + + elseif p.p == IPPROTO_TCP or p.p == IPPROTO_UDP then + t = n.tcpudp_mapping + up_id = p.dport + mode = n.mode + + else + error("unknown protocol") + end + + local e = mapping_rawget(t, up_id) + + if not e then + return nil, "not opened" + end + + if e and not e.up_addr[mapping_up_key(mode, p.saddr, p.sport)] then + return nil, string.format("port opened but bad mapping (NAT is %s)", n.mode) + end + + return e +end + +local function mapping_translate(n, p) + local down_id + local min + local t + local mode + + if p.p == IPPROTO_ICMP and p.icmp_type == ICMP_ECHO then + down_id = p.icmp_echo_id + min = 1 + mode = NAT_SYMMETRIC + t = n.icmp_mapping + + elseif p.p == IPPROTO_TCP or p.p == IPPROTO_UDP then + down_id = p.sport + min = 1024 + mode = n.mode + t = n.tcpudp_mapping + + else + error("unknown protocol") + end + + local function check(up_id) + local e = mapping_rawget(t, up_id) + if e and (e.down_addr ~= p.saddr or e.down_id ~= down_id) then + return nil + end + + if not e then + e = { + id = p.from_id, + down_addr = p.saddr, + down_id = down_id, + up_addr = {}, + } + + t[up_id] = e + end + + e.opened_ts = now() + e.up_addr[mapping_up_key(mode, p.daddr, p.dport)] = true + + return e + end + + local up_id = down_id + if mode == NAT_SYMMETRIC then + up_id = (up_id * p.saddr) % 65534 + 1 + up_id = (up_id * p.daddr) % 65534 + 1 + + if p.p == IPPROTO_TCP or p.p == IPPROTO_UDP then + up_id = (up_id * p.dport) % 65534 + 1 + end + end + + assert(up_id ~= nil) + + local v + v = check(up_id) + if v ~= nil then return up_id, v end + + for i = min, 65535 do + v = check(i) + if v ~= nil then return i, v end + end + + return nil +end + + +local function _nat_kernel(n, p) + -- XXX IP fragmentation + + if p.p == IPPROTO_ICMP then + if p.dir == UP then + if p.icmp_type == ICMP_ECHO then + -- allocate opened entry + local up_echo_id, e = mapping_translate(n, p) + + if not up_echo_id then + return "NAT ICMP full" + end + + assert(e.id ~= NODEID_NULL) + + -- modify IP packet + p.saddr = n.up_ip:ip().s_addr + p.icmp_echo_id = up_echo_id + + -- redirect to up + return n.up.id + else + return "unknown ICMP packet" + end + else + if p.icmp_type == ICMP_ECHO then + p.dir = p.dir == UP and DOWN or UP + p.daddr = p.saddr + p.saddr = n.up_ip:ip().s_addr + p.icmp_type = ICMP_ECHOREPLY + + return p.from_id + elseif p.icmp_type == ICMP_ECHOREPLY then + local e, err = mapping_get(n, p) + + if not e then + return err + end + + -- modify IP + p.daddr = e.down_addr + p.icmp_echo_id = e.down_id + + -- redirect to down + return e.id + else + return "unknown ICMP packet" + end + end + elseif p.p == IPPROTO_TCP or p.p == IPPROTO_UDP then + if p.dir == UP then + local nat_sport, e = mapping_translate(n, p) + + if not nat_sport then + return "NAT full" + end + + assert(e.id ~= NODEID_NULL) + + -- modify IP header + p.saddr = n.up_ip:ip().s_addr + p.sport = nat_sport + + -- redirect to up + return n.up.id + else + local e, err = mapping_get(n, p) + + if not e then + return err + end + + -- XXX + + -- modify IP + p.daddr = e.down_addr + p.dport = e.down_id + + -- redirect to down + return e.id + end + else + return "unknown protocol" + end +end + +function nat(n) + n.type = 'nat' + n.mode = n.mode or NAT_RESTRICTED_PORT + + n._init = function(n) + return _nat( + n.id, + 'nat', + n.up and n.up.id, + n.up_ip, + n.up_gw, + function(...) return _nat_kernel(n, ...) end + ) + end + n.lan_subnet = n.lan_subnet or DEFAULT_LAN_SUBNET + n.down_ip = n.lan_subnet + n.lan_ip = n.lan_subnet:next() + n.tcpudp_mapping = {} + n.icmp_mapping = {} + + return node(n) +end + +function _init_wan(n) + -- create WAN simulation + + local routes = {} + + local browse_down + function browse_down(n) + if n.up_ip then + routes[#routes+1] = {n.up_ip, n.id} + else + for _, n_down in ipairs(n.down) do + browse_down(n_down) + end + end + end + browse_down(n) + + + return _wan( + n.id, + 'wan', + NODEID_NULL, + routes + ) +end + +function wan() + return node{ + type='wan', + _init = _init_wan + } +end + +function _build() + local cur_peer_id = 0 + + -- get maximum peer id + for _, node in ipairs(nodes) do + if node.type == 'peer' then + if not node.id then + cur_peer_id = cur_peer_id + 1 + node.id = cur_peer_id + end + + if node.id > cur_peer_id then + cur_peer_id = node.id + end + end + end + + -- allocate IDs + for _, node in ipairs(nodes) do + if not node.id then + cur_peer_id = cur_peer_id + 1 + node.id = cur_peer_id + end + end + + local function browse_and_alloc(n, parent, gw) + assert(node and parent and gw) + + n.up_gw = gw + + if n.up and n.type ~= 'link' then + if n.up_ip then + -- do nothing + elseif parent.lan_ip then + n.up_ip = parent.lan_ip + n.up_gw = gw + + parent.lan_ip = parent.lan_ip:next() + else + n.up_ip = randomwan() + end + end + + if n.type == 'nat' then + parent = n + gw = n.down_ip + end + + for _, n_down in ipairs(n.down or {}) do + browse_and_alloc(n_down, parent, gw) + end + end + + -- allocate IPs + for _, node in ipairs(nodes) do + if not node.up then + node.ip = node.ip or randomwan() + browse_and_alloc(node, node, WAN_SUBNET) + end + end + + -- init every node + _alloc_nodes(cur_peer_id) + for _, node in ipairs(nodes) do + node:_init() + end +end + +function M() end + diff --git a/contrib/micronet/src/wan.c b/contrib/micronet/src/wan.c new file mode 100644 index 0000000..3e7a309 --- /dev/null +++ b/contrib/micronet/src/wan.c @@ -0,0 +1,58 @@ +#include "server.h" + +#define PORT_MAX 65535 + +static void _wan_kernel(struct node* n, struct packet* p) { + assert(p->dir == UP); + + int i; + for (i=0; ias.wan.count; ++i) { + struct route* r = &n->as.wan.routes[i]; + + if (p->hdr.ip_dst.s_addr == r->subnet.addr.s_addr) { + p->dir = DOWN; + sendto_id(n, r->id, p); + return; + } + } + + char dst[INET_ADDRSTRLEN+1]; + assert(inet_ntop(AF_INET, &p->hdr.ip_dst, dst, sizeof(dst)-1)); + DROP(n, p, "unknown route %s", dst); +} + +static struct route* luaN_checkroutes(lua_State* L, int idx, int* pcount) { + luaL_checktype(L, idx, LUA_TTABLE); + + int i; + *pcount = luaL_len(L, idx); + struct route* r = calloc(*pcount, sizeof(struct route)); + + int l = luaL_len(L, idx); + for (i=0; ikernel = _wan_kernel; + n->as.wan.count = count; + n->as.wan.routes = r; + + return 0; +} + diff --git a/docker/0nc.lua b/docker/0nc.lua new file mode 100755 index 0000000..0f9dc14 --- /dev/null +++ b/docker/0nc.lua @@ -0,0 +1,71 @@ +#!/usr/bin/env lua + +require('wh') + +function execf(...) + local cmd = string.format(...) + --print(cmd) + return os.execute(cmd) +end + +function readb64(fp, mode) + local fh = io.open(fp) + if not fh then error(string.format('file not found: %s', fp)) end + local buf = wh.fromb64(fh:read(), mode) + fh:close() + return buf +end + +function writeb64(fp, buf, mode) + local fh = io.open(fp, 'w') + fh:write(wh.tob64(buf, mode) .. '\n') + fh:close() +end + +execf("make > /dev/null 2> /dev/null") +execf("wh clearconf znc") +execf("wh set znc workbit 8 subnet 10.0.42.1/24") +execf("wh set znc endpoint bootstrap.wirehub.io bootstrap yes untrusted peer P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w") + +execf("wh genkey znc | tee /tmp/znc.sk | wh pubkey > /tmp/znc.k") + +local k = readb64('/tmp/znc.k') + +execf("ip link add dev wg1 type wireguard") +execf("wg set wg1 private-key /tmp/znc.sk listen-port 0") +execf("ip link set wg1 up") + +local is_server = arg[1] == nil + +if is_server then + execf("wh genkey znc | tee /tmp/alias.znc.sk | wh pubkey > /tmp/alias.znc.k") + local alias_sk = readb64('/tmp/alias.znc.sk', 'wg') + local alias_k = readb64('/tmp/alias.znc.k') + + execf("wh set znc ip 10.0.42.1 name server.znc router yes peer %s", wh.tob64(k)) + execf("wh set znc ip 10.0.42.2 name client.znc alias %s", wh.tob64(alias_k)) + execf("wh up znc interface wg1 mode nat") + + local invit = wh.tob64(k .. alias_sk) + + print("znc invitation: " .. invit) + + execf("nc -l -p 1234") +else + local keys = wh.fromb64(arg[1]) + local server_k = string.sub(keys, 1, 32) + local alias_sk = string.sub(keys, 33, 64) + local alias_k = wh.publickey(alias_sk) + + writeb64('/tmp/alias.znc.sk', alias_sk, 'wg') + + execf("wh set znc ip 10.0.42.1 name server.znc router yes peer %s", wh.tob64(server_k)) + execf("wh set znc ip 10.0.42.2 name client.znc alias %s", wh.tob64(alias_k)) + execf("wh up znc interface wg1 mode nat") + + execf("sleep 1") + execf("wh auth wg1 %s /tmp/alias.znc.sk", wh.tob64(server_k)) + + execf("sleep 1") + execf("nc 10.0.42.1 1234") +end diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..5a9224a --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,98 @@ +FROM alpine:latest as builder + +RUN apk add --no-cache \ + autoconf \ + automake \ + bison \ + build-base \ + curl \ + flex \ + git \ + libmnl-dev \ + libtool \ + linux-headers \ + readline-dev \ + gdb pv strace valgrind vim # for debug only + +RUN mkdir -p \ + /baseroot/opt/wh/tools \ + /baseroot/usr/bin \ + /baseroot/usr/lib \ + /baseroot/usr/local/lib/lua/5.3 \ + /baseroot/usr/share/bash-completion/completions + +WORKDIR /root +RUN git clone https://github.com/jedisct1/libsodium && \ + git clone https://github.com/miniupnp/miniupnp && \ + curl -R -O http://www.tcpdump.org/release/libpcap-1.9.0.tar.gz && \ + curl -R -O https://www.lua.org/ftp/lua-5.3.5.tar.gz && \ + tar xfz libpcap-1.9.0.tar.gz && tar xfz lua-5.3.5.tar.gz + +# Build libpcap +WORKDIR /root/libpcap-1.9.0 +RUN ./configure && \ + make -j && \ + make install + +# Build sodium +WORKDIR /root/libsodium +RUN git checkout stable && \ + ./autogen.sh && \ + ./configure && \ + make -j && \ + make install + +# Build Lua +WORKDIR /root/lua-5.3.5 +#RUN sed -i 's/MYCFLAGS=/MYCFLAGS=-g/g' src/Makefile && sed -i 's/-O2//g' src/Makefile +RUN make -j linux && \ + make install + +# Build MiniUPNPc +WORKDIR /root/miniupnp/miniupnpc +RUN git checkout miniupnpc_2_1 && \ + make -j && \ + make install && \ + make install DESTDIR=/baseroot + +# Build WireGuard tools +WORKDIR /root/wh/ +COPY deps deps +WORKDIR /root/wh/deps/WireGuard/src/tools +RUN make -j && \ + make install DESTDIR=/baseroot + +# Prepare wh +RUN printf "#!/bin/sh\nexport LUA_PATH=/opt/wh/?.lua\nlua /opt/wh/tools/cli.lua \$@\n" >> /baseroot/usr/bin/wh && \ + chmod +x /baseroot/usr/bin/wh + +# Build WireHub +WORKDIR /root/wh +COPY Makefile . +COPY src src +RUN make -j && \ + cp src/*.lua /baseroot/opt/wh && \ + cp src/tools/*.lua /baseroot/opt/wh/tools && \ + cp .obj/*.so /baseroot/usr/local/lib/lua/5.3/ + +COPY config/* /baseroot/etc/wirehub/ + +WORKDIR /baseroot +RUN cp /usr/local/lib/*.so* usr/lib/ && \ + cp /usr/local/bin/lua* usr/bin && \ + tar cf /baseroot.tar . + +## + +FROM alpine:latest as wh + +RUN apk add --no-cache \ + iptables \ + libmnl \ + readline + +COPY --from=builder /baseroot.tar / +RUN tar xf /baseroot.tar && \ + rm /baseroot.tar && \ + rm -rf /usr/include/* /usr/share/man/* /usr/lib/*.a + diff --git a/docker/Dockerfile.root1 b/docker/Dockerfile.root1 new file mode 100644 index 0000000..ff8aba9 --- /dev/null +++ b/docker/Dockerfile.root1 @@ -0,0 +1,18 @@ +FROM wirehub/wh:latest + +RUN apk add --no-cache \ + bash \ + bash-completion + +WORKDIR /opt/wh + +RUN nc 172.17.0.1 1324 > ./sk + +RUN wh set public workbit 8 peer P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w endpoint 51.15.227.165 && \ + printf "lua src/sink-udp.lua &\nFG=y LOG=2 wh up public private-key ./sk mode direct\n" > ./run-root1.sh && \ + chmod +x ./run-root1.sh + +RUN wh completion get-bash > /usr/share/bash-completion/completions/wh + +ENTRYPOINT ./run-root1.sh + diff --git a/docker/Dockerfile.sandbox b/docker/Dockerfile.sandbox new file mode 100644 index 0000000..de5c544 --- /dev/null +++ b/docker/Dockerfile.sandbox @@ -0,0 +1,35 @@ +FROM wirehub/builder:latest + +RUN (cd /baseroot && tar cf - .) | (cd / && tar xf -) && \ + rm -r /baseroot /opt /usr/local/lib/lua/5.3/whcore.so /usr/bin/wh && \ + printf "#!/bin/sh\nexport LUA_PATH=/root/wh/src/?.lua\nlua /root/wh/src/tools/cli.lua \$@\n" >> /usr/bin/wh && \ + chmod +x /usr/bin/wh + +RUN apk add --no-cache \ + bash \ + bash-completion \ + bmon \ + build-base \ + curl \ + gdb \ + git \ + iptables \ + linux-headers \ + mtr \ + pv \ + strace \ + tcpdump \ + valgrind \ + vim + +ENV DEBUG y +ENV LUA_PATH "/root/wh/src/?.lua" +ENV LUA_CPATH "/root/wh/.obj/?.so" +ENV PATH "/root/wh/docker:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" + +RUN ln -s /root/wh/docker/sandbox.bashrc /root/.bashrc +RUN wh completion get-bash > /usr/share/bash-completion/completions/wh + +WORKDIR /root/wh +COPY docker docker + diff --git a/docker/client.lua b/docker/client.lua new file mode 100755 index 0000000..b8eb9fb --- /dev/null +++ b/docker/client.lua @@ -0,0 +1,76 @@ +#!/usr/bin/env lua + +function execf(...) + local cmd = string.format(...) + print(cmd) + return os.execute(cmd) +end + +seed = io.popen("dd if=/dev/urandom bs=1 count=4"):read() +seed = string.unpack("I", seed) +math.randomseed(seed) + +port = math.floor(math.random()*(65535-1024)+1024) + +execf("make") +os.execute("wh clear jgl") +os.execute("wh set jgl workbit 8 subnet 10.0.42.1/24") + +os.execute("wh set jgl peer P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w endpoint bootstrap.wirehub.io") +--os.execute("wh set jgl peer P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w endpoint 172.17.0.1") +os.execute("wh set jgl name root.jgl peer ZvuWjYZPQL7NGBZKXsB7zJgqVpY3zG_h-8ALBE3QHTM ip 10.0.42.1 router yes") +os.execute("wh set jgl name test1.jgl alias ahfGTIiek0znHEnNTk-G1yjNEoDlhQ_g-OLliAMii3g") + +local id = tonumber(arg[1]) +assert(id) + +local key = arg[2] or 'rand' +local mode = arg[3] or 'nat' + +if key == 'rand' then + execf("wh genkey jgl > sk") +else + execf("echo " .. os.getenv('EXAMPLE' .. key .. '_KEY') .. " > sk") +end + +execf("cat sk | wg pubkey > k") +execf("cat k | wh orchid jgl - | tee orchid") + +local orchid = io.popen("cat orchid"):read() + + +execf("ip link del wg%d", id) +execf("ip link add dev wg%d type wireguard", id) +execf("wg set wg%d private-key ./sk listen-port %d", id, port) +execf("ip addr add 10.0.42.%d/24 dev wg%d", id, id) +execf("ip addr add %s/128 dev wg%d", orchid, id) +execf("ip link set wg%d up", id) + +do + local wh=require'wh' + + local i = 1 + while true do + local i_sk = os.getenv('EXAMPLE' .. tostring(i) .. '_KEY') + + if i_sk == nil then + break + end + + local i_k = io.popen(string.format("echo %s | wg pubkey", i_sk)):read() + + if i ~= tonumber(key) then + execf("wg set wg%d peer %s allowed-ips 10.0.42.%d/32", id, i_k, i) + end + + i = i + 1 + end +end + +if os.getenv('FOREGROUND') then + cmd = 'attach' +else + cmd = 'up' +end + +execf("wh %s jgl interface wg%d mode %s", cmd, id, mode) diff --git a/docker/sandbox.bashrc b/docker/sandbox.bashrc new file mode 100644 index 0000000..53905e9 --- /dev/null +++ b/docker/sandbox.bashrc @@ -0,0 +1,50 @@ +export PS1='\u@\h:\W \$ ' +source /etc/profile.d/bash_completion.sh + +alias t="wh-sandbox-test" + +if [ ! -f /dev/net/tun ]; then + mkdir /dev/net + mknod /dev/net/tun c 10 200 +fi + +function compile_micronet() { + cp -r contrib/micronet /tmp + (cd /tmp/micronet && make clean && make) + cp /tmp/micronet/bin/micronet /usr/local/bin +} + +echo "" +echo "#####################" +echo "# wirehub's sandbox #" +echo "#####################" +echo "" + + +if [ -z "$MICRONET" ]; then + echo "µnet is disabled. Will use default network." +else + if [ -z "$MICRONET_SERVER" ]; then + export MICRONET_SERVER=172.17.0.1 + fi + + echo "µnet is enabled, node is $MICRONET, server is $MICRONET_SERVER" + compile_micronet + UNET_SERVERNAME=$MICRONET_SERVER micronet client $MICRONET & +fi + +if [ ! -z "$ROOT" ]; then + echo "start root" + echo $ROOT > /tmp/sk + lua src/sink-udp.lua & + wh up public private-key /tmp/sk mode direct & + +elif [ ! -z "$T" ]; then + if [ ! -f "tests/keys/config" ]; then + wh-sandbox-test -1 + exit -1 + fi + echo "start test node $T" + wh-sandbox-test $T & +fi + diff --git a/docker/spawn.lua b/docker/spawn.lua new file mode 100755 index 0000000..3fd9aa9 --- /dev/null +++ b/docker/spawn.lua @@ -0,0 +1,17 @@ +#!/usr/bin/env lua + +if arg[1] == nil then + print(string.format("usage: %s ", arg[0])) + print() + print( "Spawn N WireHub peers (without WireGuard interface). This is useful to populate\n" .. + "a network with ephemeron peers." + ) + return +end + +local n = tonumber(arg[1]) + +for i = 1, n do + print("spawn", i) + os.execute("wh up public listen-port 0") +end diff --git a/docker/wh-sandbox-test b/docker/wh-sandbox-test new file mode 100755 index 0000000..4795ad3 --- /dev/null +++ b/docker/wh-sandbox-test @@ -0,0 +1,53 @@ +#!/usr/bin/env lua + +if io.open("tests/keys/config") == nil then + print("You first need to generate test keys.") + print("To do so, run on your host machine ...") + print("") + print(" ./tests/generate-keys.sh") + print("") + print("... and retry") + + return +end + +if arg[1] == nil then + print(string.format("usage: %s ", arg[0])) + + return +end + +local id = tonumber(arg[1]) +assert(id) + +function execf(...) + local cmd = string.format(...) + print("\x1b[1;30m$ " .. cmd .. "\x1b[0m") + return os.execute(cmd) +end + +execf("rm -f /tmp/log") +execf("ip link del wg") +execf("make") + +print(string.rep('-', 80)) + +execf("wh clearconf test") +execf("cp tests/keys/config /etc/wirehub/test") + +local mode = arg[2] or 'unknown' + +skpath = string.format('tests/keys/%d.sk', id) + +--execf('cat %s | wg pubkey > k', skpath) + +print("") +execf("ip link add dev wg type wireguard") +execf("wg set wg private-key %s listen-port 0", skpath) +execf("ip link set wg up") + +if os.getenv("VALGRIND") then + execf("WH_LOGPATH=/tmp/log valgrind /usr/local/bin/lua src/tools/cli.lua up test interface wg mode %s", mode) +else + execf("WH_LOGPATH=/tmp/log wh up test interface wg mode %s", mode) +end diff --git a/src/auth.lua b/src/auth.lua new file mode 100644 index 0000000..190a571 --- /dev/null +++ b/src/auth.lua @@ -0,0 +1,63 @@ +local packet = require('packet') + +local M = {} + +function M.update(n, a, deadlines) + -- still searching? + if not a.p then + return + end + + local deadline = a.req_ts+a.retry+1 + + if now >= deadline then + if a.retry > wh.AUTH_RETRY then + return a:cb(false, "could not auth") + end + + n:_sendto{ + dst=a.p, + m=packet.auth(n, a.p), + sk=a.alias_sk, + } + + a.retry = a.retry + 1 + a.req_ts = now + a.last_seen = now + + deadline = a.req_ts+a.retry+1 + end + + deadlines[#deadlines+1] = deadline +end + +function M.resolve_alias(n, alias, src) + -- alias may be nil + + if alias then + -- copy all attributes of alias to p + src.relay = nil + for k, v in pairs(alias) do + if k ~= 'k' and k ~= 'alias' then + src[k] = alias[k] + end + end + + alias.alias = src.k + end +end + +function M.on_authed(n, alias_k, src) + for a in pairs(n.auths) do + if a.alias_k == alias_k then + local alias = n.kad:get(alias_k) + + M.resolve_alias(n, alias, n.p) + + return a:cb(true) + end + end +end + +return M + diff --git a/src/bwlog.lua b/src/bwlog.lua new file mode 100644 index 0000000..3960cf5 --- /dev/null +++ b/src/bwlog.lua @@ -0,0 +1,60 @@ +local queue = require('queue') + +local MT = { + __index = {}, +} + +function MT.__index.collect(bw) + for i, v in queue.iter(bw) do + local ts = v[1] + + if now - ts <= bw.scale then + break + end + + assert(i == bw.heap) + queue.remove(bw, 1) + end + + bw.last_collect_ts = now +end + +function MT.__index.add(bw, class, dir, sz) + return queue.push(bw, { now, class, dir, sz }) +end + +function MT.__index.add_rx(bw, class, sz) return bw:add(class, 'rx', sz) end +function MT.__index.add_tx(bw, class, sz) return bw:add(class, 'tx', sz) end + +MT.__index.length = queue.length + +function MT.__index.avg(bw) + local acc = {} + for i, v in queue.iter(bw) do + local ts, class, dir, sz = table.unpack(v) + + if now - ts <= bw.scale then + acc[class] = acc[class] or { rx=0, tx=0 } + acc[class][dir] = acc[class][dir] + sz + else + assert(i == bw.heap) + queue.remove(bw, 1) + end + end + + bw.last_collect_ts = now + + for k, v in ipairs(acc) do + v.rx = v.rx / bw.scale + v.tx = v.tx / bw.scale + end + + return acc +end + +return function(bw) + assert(bw and bw.scale) + bw.last_collect_ts = 0 + return setmetatable(bw, MT) +end + diff --git a/src/conf.lua b/src/conf.lua new file mode 100644 index 0000000..e5757e3 --- /dev/null +++ b/src/conf.lua @@ -0,0 +1,193 @@ +local function parseconf(conf) + if not conf then return end + + local entry = {} + local cur_section + + for l in string.gmatch(conf, "[^\n]*") do + if string.sub(l, 1, 1) ~= '#' then + local section = string.match(l, '%[[%a%d]+%]') + + if section then + cur_section = string.sub(section, 2, -2) + cur_section = string.lower(cur_section) + entry[#entry+1] = {_name=cur_section} + + else + local k, v = string.match(l, "(%a+)%s+=%s+(%g+)") + + if k and v then + if not cur_section then + return + end + + k = string.lower(k) + entry[#entry][k] = v + end + end + end + end + + return entry +end + +function wh.fromconf(conf) + if not conf then return end + + local entry = parseconf(conf) + + local r = { + name=nil, -- explicit + workbit=nil, -- explicit + peers={}, + } + local has_section_network = false + for _, section in ipairs(entry) do + if section._name == 'interface' then + r['private-key'] = section.privatekey + + elseif section._name == 'network' then + if has_section_network then return end + has_section_network = true + + r.name = section.name + r.namespace = section.namespace + r.subnet = section.subnetwork + + if section.workbits then + r.workbit = tonumber(section.workbits) + if not r.workbit then return end + end + + elseif section._name == 'peer' then + local p = {} + + if not section.publickey and not section.name then + return + end + + if section.publickey then + local ok, k = pcall(wh.fromb64, section.publickey) + if not ok then return end + p.k = k + end + + if section.endpoint then + p.addr = wh.address(section.endpoint, wh.DEFAULT_PORT) + end + + if section.alias then + local ok, k = pcall(wh.fromb64, section.alias) + if not ok then return end + p.alias = k + end + + p.hostname = section.name + p.is_router = section.router == "yes" + p.is_gateway = section.gateway == "yes" + p.trust = section.trust == "yes" + p.ip = section.ip and wh.address(section.ip) + p.bootstrap = section.bootstrap == "yes" + + if section['allowedips'] then + local r = {} + for subnet in string.gmatch(section['allowedips'], "([^,]+)") do + if not subnet then + return nil + end + r[#r+1] = subnet + end + + p['allowed-ips'] = r + end + + r.peers[#r.peers+1] = p + end + end + + return r +end + +function wh.toconf(conf) + local r = {} + + if conf.namespace or conf.workbit then + if conf['private-key'] then + r[#r+1] = "[Interface]\n" + r[#r+1] = string.format("PrivateKey = %s\n", conf['private-key']) + + r[#r+1] = '\n' + end + + r[#r+1] = "[Network]\n" + + if conf.name then + r[#r+1] = string.format("Name = %s\n", conf.name) + end + + if conf.namespace then + r[#r+1] = string.format("Namespace = %s\n", conf.namespace) + end + + if conf.workbit then + r[#r+1] = string.format("Workbits = %d\n", conf.workbit) + end + + if conf.subnet then + r[#r+1] = string.format("SubNetwork = %s\n", conf.subnet) + end + + for _, p in ipairs(conf.peers) do + r[#r+1] = "\n[Peer]\n" + if p.trust then + r[#r+1] = "Trust = yes\n" + else + r[#r+1] = "# Trust = no\n" + end + + if p.bootstrap then + r[#r+1] = "Bootstrap = yes\n" + end + + if p.hostname then + r[#r+1] = string.format("Name = %s\n", p.hostname) + end + + if p.alias then + r[#r+1] = string.format("Alias = %s\n", wh.tob64(p.alias)) + end + + if p.is_router then + r[#r+1] = "Router = yes\n" + end + + if p.is_gateway then + r[#r+1] = "Gateway = yes\n" + end + + if p.ip then + r[#r+1] = string.format("IP = %s\n", p.ip:addr()) + end + + if p.k then + r[#r+1] = string.format("PublicKey = %s\n", wh.tob64(p.k)) + end + + if p.addr then + r[#r+1] = string.format("Endpoint = %s\n", p.addr) + end + + if p['allowed-ips'] then + r[#r+1] = string.format("AllowedIPs = ") + for i, v in ipairs(p['allowed-ips']) do + if i > 1 then r[#r+1] = ',' end + r[#r+1] = v + end + r[#r+1] = '\n' + end + end + end + + return table.concat(r) +end + diff --git a/src/connectivity.lua b/src/connectivity.lua new file mode 100644 index 0000000..8ec4a22 --- /dev/null +++ b/src/connectivity.lua @@ -0,0 +1,169 @@ +local M = {} + +local function _work_upnp(n) + local u = { + enabled=false, + peers={}, + } + + local d, url + d, u.iaddr, url = wh.upnp.discover_igd(1) + + if d then + printf('UPnP device found: %s', url) + local ok, val = wh.upnp.external_ip(d) + + if ok then + u.external_ip = val + else + u.external_ip = nil + end + + local function _add(port) + local ok, err = wh.upnp.add_redirect(d, { + desc = string.format('WireHub %s', wh.tob64(n.k)), + eport=port, + iaddr=u.iaddr, + iport=port, + lease=wh.UPNP_REFRESH_EVERY, + protocol='udp', + }) + + if ok then + printf('upsert UPnP redirection %s:%d -> %s:%d', + u.external_ip or '???', n.port, u.iaddr, n.port + ) + else + printf('UPnP error: %s. ignore', err) + end + + return ok + end + + if _add(n.port) and _add(n.port_echo) then + u.enabled = true + end + + for _, r in ipairs(wh.upnp.list_redirects(d)) do + local k = string.match(r.desc, "WireHub ([^ ]+)") + + if k then + local ok + ok, k = pcall(wh.fromb64, k) + if not ok then k = nil end + end + + if k and #k == 32 and n.k ~= k then + -- check version if necessary + u.peers[k] = {r.iaddr, r.iport} + end + end + end + + return u +end + +local function update_upnp(n, deadlines) + assert(n.upnp) + + if n.mode ~= 'unknown' then + return + end + + local u = n.upnp + + local deadline = u.last_check + wh.UPNP_REFRESH_EVERY - 30 + + if deadline <= now and not u.checking then + u.checking = true + + n:explain("checking UPnP...") + u.worker:pcall( + function(ok, ...) + u.checking = false + + if not ok then + error(...) + end + local nu = ... + + local peers = nu.peers + nu.peers = nil + + for k, v in pairs(nu) do + u[k] = v + end + + for k, addr in pairs(peers) do + local addr = wh.address(addr[1], addr[2]) + printf("found UPnP device $(yellow)%s$(reset) (%s)", wh.tob64(k), addr) + + local upnp_p = n.kad:touch(k) + upnp_p.addr = addr + end + + u.last_check = now + n.last_connectivity_check = nil + end, + _work_upnp, + { + k=n.k, + port=n.port, + port_echo=n.port_echo, + } + ) + + elseif u.checking then + deadline = nil + else + deadline = u.last_check + wh.UPNP_REFRESH_EVERY - 30 + end + + deadlines[#deadlines+1] = deadline +end + +function M.update(n, deadlines) + if n.upnp then + update_upnp(n, deadlines) + + if n.upnp.checking then + return + end + end + + if n.checking_connectivity then + return + end + + local deadline = (n.last_connectivity_check or 0) + wh.CONNECTIVITY_CHECK_EVERY + if now > deadline then + if n.mode == 'unknown' then + n:explain("checking connectivity...") + n.checking_connectivity = true + n:detect_nat(nil, function(mode) + n.checking_connectivity = false + + n:explain("NAT is $(magenta)%s", mode) + + n.is_nated = mode ~= 'direct' + + n:explain("find self") + n:search(n.k, 'lookup') -- center + n.last_connectivity_check = now + end) + + deadline = nil + else + n:explain("find self") + n:search(n.k, 'lookup') -- center + + n.last_connectivity_check = now + deadline = n.last_connectivity_check + wh.CONNECTIVITY_CHECK_EVERY + end + end + + deadlines[#deadlines+1] = deadline +end + +return M + diff --git a/src/core/.gitignore b/src/core/.gitignore new file mode 100644 index 0000000..549ba3d --- /dev/null +++ b/src/core/.gitignore @@ -0,0 +1,2 @@ +wireguard.c +wireguard.h diff --git a/src/core/common.h b/src/core/common.h new file mode 100644 index 0000000..5919b2b --- /dev/null +++ b/src/core/common.h @@ -0,0 +1,32 @@ +#ifndef WIREHUB_COMMON_H +#define WIREHUB_COMMON_H + +#include "config.h" + +#include +#include +#include +#include +#include +#include +#include + +/* +#include +#include +#include +#include +#include + +#include +#include +*/ +/*** CONSTANTS ***************************************************************/ + +#define crypto_scalarmult_curve25519_KEYBASE64BYTES 44 + +static const uint8_t wh_pkt_hdr[] = {0xff, 0x00, 0x00, 0x00}; +static const int wh_version[3] = {0, 1, 0}; + +#endif // WIREHUB_COMMON_H + diff --git a/src/core/config.h b/src/core/config.h new file mode 100644 index 0000000..7058b99 --- /dev/null +++ b/src/core/config.h @@ -0,0 +1,15 @@ +#ifndef WIREHUB_CONFIG_H +#define WIREHUB_CONFIG_H + +// source: https://www.wireguard.com/install/#kernel-requirements +#define WH_LINUX_MINVERSION { 3, 10 } + +#define WH_ENV_CONFPATH "WH_CONFPATH" +#define WH_DEFAULT_CONFPATH "/etc/wirehub/" +#define WH_DEFAULT_SOCKPATH "/var/run/wirehub/" +#define WH_ENABLE_MINIUPNPC 1 + +#define WH_TUN_ICMP 1 + +#endif // WIREHUB_CONFIG_H + diff --git a/src/core/ipc.c b/src/core/ipc.c new file mode 100644 index 0000000..12b3a35 --- /dev/null +++ b/src/core/ipc.c @@ -0,0 +1,168 @@ +#include "common.h" +#include +#include +#include +#include +#include +#include +#include + +int ipc_prepare(void) { + return system("mkdir -p " WH_DEFAULT_SOCKPATH " 2> /dev/null"); +} + +static int ipc_sockaddr_un(struct sockaddr_un* addr_un, const char* interface) { + addr_un->sun_family = AF_UNIX; + + ssize_t sz = snprintf( + addr_un->sun_path, sizeof(addr_un->sun_path), + WH_DEFAULT_SOCKPATH "%s.sock", interface + ); + + if (sz < 0 || (size_t)sz >= sizeof(addr_un->sun_path)) { + return -1; + } + + return 0; +} + +int ipc_unlink(const char* interface) { + struct sockaddr_un addr_un; + if (ipc_sockaddr_un(&addr_un, interface) < 0) { + return -1; + } + + if (unlink(addr_un.sun_path) < 0) { + return -1; + } + + return 0; +} + +int ipc_bind(const char* interface, int force) { + int sock = -1; + + struct sockaddr_un addr_un; + if (ipc_sockaddr_un(&addr_un, interface) < 0) { + goto err; + } + + if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) { + goto err; + } + + int bind_ret; + if ((bind_ret = bind(sock, (struct sockaddr*)&addr_un, sizeof(addr_un))) < 0) { + if (errno == EADDRINUSE && force) { + if (unlink(addr_un.sun_path) < 0) { + goto err; + } + + bind_ret = bind(sock, (struct sockaddr*)&addr_un, sizeof(addr_un) < 0); + } + + if (bind_ret < 0) { + goto err; + } + } + + if (listen(sock, 1) < 0) { + goto err; + } + + return sock; +err: + if (sock != -1) { + close(sock); + } + return -1; +} + +int ipc_connect(const char* interface) { + int sock = -1; + + struct sockaddr_un addr_un; + if (ipc_sockaddr_un(&addr_un, interface) < 0) { + goto err; + } + + if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) { + goto err; + } + + if (connect(sock, (struct sockaddr*)&addr_un, sizeof(addr_un)) < 0) { + goto err; + } + + if (fcntl(sock, F_SETFL, fcntl(sock, F_GETFL, 0) | O_NONBLOCK) == -1) { + goto err; + } + + return sock; +err: + if (sock != -1) { + close(sock); + } + + return -1; +} + +int ipc_list(int(*cb)(const char*, void*), void* ud) { + DIR* dirp; + int ret = 0; + + if ((dirp = opendir(WH_DEFAULT_SOCKPATH)) == NULL) { + return -1; + } + + struct dirent* dp; + while ((dp = readdir(dirp))) { + char fullpath[PATH_MAX]; + snprintf(fullpath, sizeof(fullpath), "%s%s", WH_DEFAULT_SOCKPATH, dp->d_name); + + struct stat st; + if (stat(fullpath, &st) < 0) { + continue; + } + + if (!S_ISSOCK(st.st_mode)) { + continue; + } + + // remove folder's path and suffix '.sock' + char* ext = strstr(dp->d_name, ".sock"); + if (!ext) { + continue; + } + *ext = 0; + + int ret = cb(dp->d_name, ud); + if (ret < 0) { + break; + } + } + + closedir(dirp), dirp = NULL; + return ret; +} + +int ipc_accept(int sock) { + int new_sock = -1; + if ((new_sock = accept(sock, NULL, NULL)) < 0) { + return -1; + } + + if (fcntl(new_sock, F_SETFL, fcntl(new_sock, F_GETFL, 0) | O_NONBLOCK) == -1) { + goto err; + } + + return new_sock; + +err: + if (new_sock == -1) { + close(new_sock); + } + + return -1; +} + diff --git a/src/core/ipc.h b/src/core/ipc.h new file mode 100644 index 0000000..aa2626d --- /dev/null +++ b/src/core/ipc.h @@ -0,0 +1,12 @@ +#ifndef WIREHUB_IPC_H +#define WIREHUB_IPC_H + +int ipc_prepare(void); +int ipc_bind(const char* interface, int force); +int ipc_connect(const char* interface); +int ipc_accept(int sock); +int ipc_unlink(const char* interface); +int ipc_list(int(*cb)(const char*, void*), void* ud); + +#endif // WIREHUB_IPC_H + diff --git a/src/core/key.c b/src/core/key.c new file mode 100644 index 0000000..3906758 --- /dev/null +++ b/src/core/key.c @@ -0,0 +1,212 @@ +// XXX protect secret data (sodium_malloc, sodium_mlock) + +#include "common.h" +#include +#include + +int trailing_0s_buf(const void* buf, size_t l) { + assert(l % sizeof(uint32_t) == 0); + unsigned int i; + unsigned int r = 0; + for(i=0; iklen; + unsigned char* tohash = alloca(tohash_l); + unsigned char hash[crypto_generichash_BYTES]; + + memcpy(tohash+crypto_scalarmult_curve25519_BYTES, s->k, s->klen); + + for (int i=0; !s->found; ++i) { + if (i == 256) { + __sync_add_and_fetch(&s->stat->count, i); + i = 0; + } + + crypto_sign_ed25519_keypair(ed25519_pk, ed25519_sk); + + // pk should be positive + if ((ed25519_pk[crypto_sign_ed25519_PUBLICKEYBYTES-1] & 0x80) == 0x80) { + continue; + } + + // convert to curve25519 point + if (crypto_sign_ed25519_pk_to_curve25519(tohash, ed25519_pk) < 0) { + continue; + } + + crypto_generichash(hash, sizeof(hash), tohash, tohash_l, NULL, 0); + + unsigned int wb = trailing_0s_buf(hash, sizeof(hash)); + + if (wb >= s->wb) { + if (crypto_sign_ed25519_sk_to_curve25519(s->x25519_sk, ed25519_sk) < 0) { + continue; + } + + pthread_mutex_lock(&s->mutex); + + if (!s->found) { + // found! + s->found = 1; + + memcpy(s->ed25519_pk, ed25519_pk, sizeof(ed25519_pk)); + memcpy(s->ed25519_sk, ed25519_sk, sizeof(ed25519_sk)); + memcpy(s->x25519_pk, tohash, crypto_scalarmult_curve25519_BYTES); + memcpy(s->hash, hash, sizeof(hash)); + s->wb = wb; + + pthread_cond_broadcast(&s->cond); + } + + pthread_mutex_unlock(&s->mutex); + break; + } + } + + return 0; +} + +static void* worker(void* ud) { + struct search_st* s = (struct search_st*)ud; + + search(s); + + pthread_exit(NULL); + return NULL; +} + +int genkey( + uint8_t* ed25519_sk, + const char* key, + int wb, + int num_threads +) { + int err = 0; + struct stat_st stat; + memset(&stat, 0, sizeof(stat)); + + struct search_st s = { + .stat = &stat, + .k = (const uint8_t*)key, + .klen = strlen(key), + .wb = wb, + .mutex = PTHREAD_MUTEX_INITIALIZER, + .cond = PTHREAD_COND_INITIALIZER, + .found = 0, + }; + + pthread_t* threads = calloc(num_threads, sizeof(pthread_t)); + + pthread_mutex_lock(&s.mutex); + for (int i=0; i 0) { + anim_i = (anim_i + 1) % (sizeof(anim)); + } + + fprintf(stderr, " \r%c generating %dworkbit key for '%s' (%.1fkK/s, %ds)", + anim[anim_i], + s.wb, + s.k, + (double)h_per_s/1000.0, + duration + ); + + fflush(stdout); + + struct timespec to; + + clock_gettime(CLOCK_REALTIME, &to); + to.tv_sec += 1; + + int retval = pthread_cond_timedwait(&s.cond, &s.mutex, &to); + + if (retval == 0) { + break; + } else if (retval != ETIMEDOUT) { + err = -1; + goto finally; + } + } +#endif + + memcpy(ed25519_sk, s.ed25519_sk, crypto_sign_ed25519_SECRETKEYBYTES); + pthread_mutex_unlock(&s.mutex); + + assert(err == 0); +finally: + for (int i=0; isa_family) { + case AF_INET: + if (!inet_ntop(a->sa_family, &a->in4.sin_addr, s, sl-1)) { + return 0; + } + break; + + case AF_INET6: + if (!inet_ntop(a->sa_family, &a->in6.sin6_addr, s, sl-1)) { + return 0; + } + break; + default: + return 0; + }; + lua_pushstring(L, s); + return 1; +} + +static int _address_port(lua_State* L) { + struct address* a = luaL_checkudata(L, 1, "address"); + lua_pushinteger(L, address_port(a)); + return 1; +} + +static int _address_same_subnet(lua_State* L) { + const struct address* a = luaL_checkudata(L, 1, "address"); + const struct address* b = luaL_checkudata(L, 2, "address"); + int cidr = luaL_checkinteger(L, 3); + + if (a->sa_family != b->sa_family) { + luaL_error(L, "address does not have the same type"); + return 0; + } + + assert(address_len(a) == address_len(b)); + int max_cidr = address_len(a) * 8; + + if (cidr > max_cidr) { + luaL_error(L, "bad CIDR"); + return 0; + } + + const uint8_t* ap = NULL,* bp = NULL; + +#define _GET_POINTER(ap, a) \ + switch ((a)->sa_family) { \ + case AF_INET: ap = (const uint8_t*)&(a)->in4.sin_addr; break; \ + case AF_INET6: ap = (const uint8_t*)&(a)->in6.sin6_addr; break; \ + } + + _GET_POINTER(ap, a); + _GET_POINTER(bp, b); + +#undef _GET_POINTER + + assert(ap && bp); + + int is_same = 1; + while (cidr > 0) { + int c = cidr>8?8:cidr; + uint8_t mask = ((1 << c)-1)<<(8-c); + + if ((*ap & mask) != (*bp & mask)) { + is_same = 0; + break; + } + + ++ap, ++bp, cidr -= 8; + } + + lua_pushboolean(L, is_same); + return 1; +} + +static inline uint32_t _subnet_mask(int cidr) { + assert(0 <= cidr && cidr <= 32); + if (cidr == 32) { + return 0xffffffff; + } + + return ((1 << cidr)-1) << (32-cidr); +} + +static int _address_subnet_id(lua_State* L) { + struct address* a = luaL_checkudata(L, 1, "address"); + int cidr = luaL_checkinteger(L, 2); + int idx = luaL_checkinteger(L, 3); + + if (a->sa_family != AF_INET) { + luaL_error(L, "address must be IP4"); + } + + if (cidr < 0 || 32 < cidr) { + luaL_error(L, "CIDR is between 0 and 32"); + } + + int64_t max_idx = (1L << (32-cidr)) - 2; + if (idx < 1 || max_idx < idx) { + luaL_error(L, "1 <= idx <= %d is false", max_idx); + } + + uint32_t mask = _subnet_mask(cidr); + uint32_t addr = ntohl(a->in4.sin_addr.s_addr); + addr &= mask; + addr |= idx; + + struct address* n = luaW_newaddress(L); + n->sa_family = n->in4.sin_family = AF_INET; + n->in4.sin_addr.s_addr = htonl(addr); + n->in4.sin_port = a->in4.sin_port; + + return 1; +} + +static int _address_pack(lua_State* L) { + struct address* a = luaL_checkudata(L, 1, "address"); + + luaL_Buffer b; + luaL_buffinit(L, &b); + + switch (a->sa_family) { + case AF_INET: + luaL_addstring(&b, "\x04"); + luaL_addlstring(&b, (const void*)&a->in4.sin_addr, sizeof(a->in4.sin_addr)); + luaL_addlstring(&b, (const void*)&a->in4.sin_port, sizeof(a->in4.sin_port)); + break; + + case AF_INET6: + luaL_addstring(&b, "\x06"); + luaL_addlstring(&b, (const void*)&a->in6.sin6_addr, sizeof(a->in6.sin6_addr)); + luaL_addlstring(&b, (const void*)&a->in6.sin6_port, sizeof(a->in6.sin6_port)); + break; + + default: + luaL_error(L, "bad address"); + }; + + luaL_pushresult(&b); + return 1; +} + +struct address* luaW_newaddress(lua_State* L) { + struct address* a = lua_newuserdata(L, sizeof(struct address)); + + if (luaL_newmetatable(L, "address")) { + lua_pushcfunction(L, _address_tostring); + lua_setfield(L, -2, "__tostring"); + + lua_newtable(L); + + lua_pushcfunction(L, _address_addr); + lua_setfield(L, -2, "addr"); + + lua_pushcfunction(L, _address_port); + lua_setfield(L, -2, "port"); + + lua_pushcfunction(L, _address_same_subnet); + lua_setfield(L, -2, "same_subnet"); + + lua_pushcfunction(L, _address_subnet_id); + lua_setfield(L, -2, "subnet_id"); + + lua_pushcfunction(L, _address_pack); + lua_setfield(L, -2, "pack"); + + lua_setfield(L, -2, "__index"); + + // XXX extend address with methods (:addr(), :port(), :subnet()) + } + lua_setmetatable(L, -2); + + return a; +} + +static int _gc_fd(lua_State* L) { + lua_getfield(L, LUA_REGISTRYINDEX, "fds"); + for (lua_pushnil(L); lua_next(L, -2) != 0; lua_pop(L, 1)) { + if (lua_type(L, -1) == LUA_TBOOLEAN && lua_toboolean(L, -1)) { + int fd = lua_tointeger(L, -2); + fprintf(stderr, "warning: fd %d was not closed.\n", fd); + close(fd); + } + } + + return 0; +} + +void luaW_pushfd(lua_State* L, int fd) { + lua_getfield(L, LUA_REGISTRYINDEX, "fds"); + + if (lua_type(L, -1) == LUA_TNIL) { + lua_pop(L, 1); + lua_newtable(L); + lua_newtable(L); + lua_pushcfunction(L, _gc_fd); + lua_setfield(L, -2, "__gc"); + lua_setmetatable(L, -2); + + lua_pushvalue(L, -1); + lua_setfield(L, LUA_REGISTRYINDEX, "fds"); + } + + lua_pushinteger(L, fd); + lua_pushboolean(L, 1); + lua_settable(L, -3); + lua_pop(L, 1); + + lua_pushinteger(L, fd); +} + +int luaW_getfd(lua_State* L, int idx) { + int isint; + int fd = lua_tointegerx(L, idx, &isint); + + if (!isint) { + luaL_error(L, "bad element #%d (integer expected, got %s)", + idx, lua_typename(L, lua_type(L, idx)) + ); + } + + lua_getfield(L, LUA_REGISTRYINDEX, "fds"); + int found = 1; + if (lua_type(L, -1) == LUA_TNIL) { + found = 0; + } else { + lua_pushinteger(L, fd); + lua_gettable(L, -2); + + found = lua_type(L, -1) == LUA_TBOOLEAN && lua_toboolean(L, -1); + lua_pop(L, 2); + } + + if (!found) { + luaL_error(L, "fd %d is not owned", fd); + } + + return fd; +} + + diff --git a/src/core/luawh.h b/src/core/luawh.h new file mode 100644 index 0000000..22708ae --- /dev/null +++ b/src/core/luawh.h @@ -0,0 +1,61 @@ +#ifndef LUAWH_H +#define LUAWH_H + +#include "common.h" +#include "net.h" + +#include +#include +#include + +int luaW_version(lua_State *L); + +void* luaW_newsecret(lua_State* L, size_t len); +void* luaW_tosecret(lua_State* L, int idx, size_t len); +void* luaW_checksecret(lua_State* L, int idx, size_t len); +void* luaW_ownsecret(lua_State* L, int idx, size_t len); +void luaW_freesecret(void* p); + +// declare a pointer +// example: luaW_declptr(L, "buffer", free); +void luaW_declptr(lua_State* L, const char* mt, void(*del)(void*)); +// push a pointer. pointer is not owned after the call +// example: luaW_pushptr(L, "buffer", malloc(1024)); +void luaW_pushptr(lua_State* L, const char* mt, void* ptr); +// returns pointer after checking it. raises an error if bad type or pointer is +// dangling +// example: luaW_checkptr(L, -1, "buffer"); +void* luaW_checkptr(lua_State* L, int idx, const char* mt); +// returns pointer after checking it. returns null if pointer is dangling +// example: luaW_toptr(L, -1, "buffer") +void* luaW_toptr(lua_State* L, int idx, const char* mt); +// as luaW_checkptr, but owns the pointer +// example: luaW_ownptr(L, -1, "buffer"); +void* luaW_ownptr(lua_State* L, int idx, const char* mt); + +struct address* luaW_newaddress(lua_State* L); + +static inline uint16_t luaW_checkport(lua_State* L, int idx) { + lua_Number port_n = luaL_checkinteger(L, idx); + + if (port_n < 0 || UINT16_MAX < port_n) { + luaL_error(L, "bad port: %d", port_n); + } + + return (uint16_t)port_n; +} + +void luaW_pushfd(lua_State* L, int fd); +int luaW_getfd(lua_State* L, int idx); + +LUAMOD_API int luaopen_tun(lua_State* L); +LUAMOD_API int luaopen_wg(lua_State* L); +LUAMOD_API int luaopen_whcore(lua_State* L); +LUAMOD_API int luaopen_worker(lua_State* L); + +#if WH_ENABLE_MINIUPNPC +LUAMOD_API int luaopen_whupnp(lua_State* L); +#endif + +#endif // LUAWH_H + diff --git a/src/core/mem.h b/src/core/mem.h new file mode 100644 index 0000000..e122f1c --- /dev/null +++ b/src/core/mem.h @@ -0,0 +1,13 @@ +#ifndef WIREHUB_MEM_H +#define WIREHUB_MEM_H + +static inline void* memdup(const void* p, size_t len) { + void* n = malloc(len); + assert(n); + memcpy(n, p, len); + return n; +} + +#endif // WIREHUB_MEM_H + + diff --git a/src/core/net.c b/src/core/net.c new file mode 100644 index 0000000..1f83046 --- /dev/null +++ b/src/core/net.c @@ -0,0 +1,314 @@ +#include "net.h" +#include +#include +#include +#include +#include + +int parse_address(struct address* a, const char* endpoint, uint16_t port) { + struct addrinfo hint, *res = NULL; + int ret; + + if (!endpoint) { + return -1; + } + + const char* addr_s = endpoint,* addr_end = NULL; + // ip6? + if (addr_s[0] == '[') { + ++addr_s; + const char* e = strchr(addr_s, ']'); + if (!e) { + return -1; + } + + addr_end = e; + } + + // ip4? + else if ('0' <= addr_s[0] && addr_s[0] <= '9') { + addr_end = strchr(addr_s, ':'); + if (!addr_end) addr_end = addr_s + strlen(addr_s); + } + + else { + addr_end = strchr(addr_s, ':'); + if (!addr_end) addr_end = addr_s + strlen(addr_s); + } + + const char* port_s = strrchr(endpoint, ':'); + // ignore : from the ip6 string + if (port_s && port_s < addr_end) { + port_s = NULL; + } + + if (port_s) { + int port_i = atoi(port_s+1); + + if (port_i < 0 || UINT16_MAX < port_i) { + return -1; + } + + port = (uint16_t)port_i; + } + + memset(&hint, '\0', sizeof hint); + + hint.ai_family = PF_UNSPEC; + hint.ai_socktype = SOCK_DGRAM; + hint.ai_flags = AI_PASSIVE; + + char* addr = alloca(addr_end-addr_s+1); + memcpy(addr, addr_s, addr_end-addr_s); + addr[addr_end-addr_s] = 0; + + if ((ret = getaddrinfo(addr, NULL, &hint, &res))) { + // more info gai_strerror(ret) + return -1; + } + + a->sa_family = res->ai_family; + + if(res->ai_family == AF_INET) { + a->in4 = *(struct sockaddr_in*)res->ai_addr; + a->in4.sin_port = htons(port); + } + + else if (res->ai_family == AF_INET6) { + a->in6 = *(struct sockaddr_in6*)res->ai_addr; + a->in6.sin6_port = htons(port); + } + + else { + // "unknown address format %d\n",argv[1],res->ai_family); + return -1; + } + + freeaddrinfo(res); + return 0; +} + +socklen_t address_len(const struct address* a) { + switch (a->sa_family) { + case AF_INET: return sizeof(struct sockaddr_in); + case AF_INET6: return sizeof(struct sockaddr_in6); + default: return 0; + }; +} + +const char* format_address(const struct address* a, char* s, size_t sl) { + assert(s); + assert(a); + + // s needs to be at least 47 + // example: [e0be:b85d:88ed:6c3b:a1aa:3f57:ab3:c850]:65535 + + socklen_t inl = address_len(a); + if (inl == 0) { + return NULL; + } + + switch (a->sa_family) { + case AF_INET: + if (!inet_ntop(a->sa_family, &a->in4.sin_addr, s, sl-1)) { + return NULL; + } + break; + + case AF_INET6: + s[0] = '['; + if (!inet_ntop(a->sa_family, &a->in6.sin6_addr, s+1, sl-2)) { + return NULL; + } + strcat(s, "]"); + break; + default: + return NULL; + }; + + char buf[8]; + snprintf(buf, sizeof(buf), ":%d", address_port(a)); + strncat(s, buf, sl); + + return s; +} + +int address_from_sockaddr(struct address* out, const struct sockaddr* in) { + switch (*(sa_family_t*)in) { + case AF_INET: + out->sa_family = AF_INET; + out->in4 = *(struct sockaddr_in*)in; + break; + + case AF_INET6: + out->sa_family = AF_INET6; + out->in6 = *(struct sockaddr_in6*)in; + break; + + default: + return -1; + }; + + return 0; +} + + + + +int socket_udp(const struct address* a) { + int s = socket(a->sa_family, SOCK_DGRAM, 0); + if (s == -1) { + return -1; + } + + if (fcntl(s, F_SETFL, fcntl(s, F_GETFL, 0) | O_NONBLOCK) == -1) { + close(s); + return -1; + } + + if (bind(s, &a->in, address_len(a)) == -1) { + close(s); + return -1; + } + + return s; +} + +int socket_raw_udp(sa_family_t sa_family, int hdrincl) { + int s = socket(sa_family, SOCK_RAW, IPPROTO_UDP); + if (s == -1) { + return -1; + } + + /*if (fcntl(s, F_SETFL, fcntl(s, F_GETFL, 0) | O_NONBLOCK) == -1) { + close(s); + return -1; + }*/ + + if (hdrincl) { + int on = 1; + if (setsockopt(s, IPPROTO_IP, IP_HDRINCL, &on, sizeof(on)) < 0) { + close(s); + return -1; + } + } + + return s; +} + +int ip4_to_udp(const void* d, const void** pdata, size_t* psize, struct address* src, struct address* dst) { + assert(d && psize); + // src and dst may be NULL + + const void* p = d; + + if (*psize < 1) { + return -1; + } + +#define IPHDR ((struct ip*)(p)) + size_t ip_hdr_sz = IPHDR->ip_hl*sizeof(uint32_t); + if (*psize < ip_hdr_sz) { + return -1; + } + + if (IPHDR->ip_p != IPPROTO_UDP) { + return -1; + } + + if (src) { + src->sa_family = src->in4.sin_family = AF_INET; + memcpy(&src->in4.sin_addr, &IPHDR->ip_src, 4); + } + + if (dst) { + dst->sa_family = dst->in4.sin_family = AF_INET; + memcpy(&dst->in4.sin_addr, &IPHDR->ip_dst, 4); + } + +#undef IPHDR + + p += ip_hdr_sz; + + // XXX do IP6 + const int udp_hdr_sz = 8; + + if (*psize < ip_hdr_sz+udp_hdr_sz) { + return -1; + } + +#define UDPHDR ((struct udphdr*)(p)) + + // XXX should check checksum? + + if (src) { + switch (src->sa_family) { + case AF_INET: src->in4.sin_port = UDPHDR->uh_sport; break; + case AF_INET6: src->in6.sin6_port = UDPHDR->uh_sport; break; + }; + } + + if (dst) { + switch(dst->sa_family) { + case AF_INET: dst->in4.sin_port = UDPHDR->uh_dport; break; + case AF_INET6: dst->in6.sin6_port = UDPHDR->uh_dport; break; + }; + } + + + uint16_t udp_sz = ntohs(UDPHDR->uh_ulen); + if (udp_sz < udp_hdr_sz) { + return -1; + } + + if (*psize < ip_hdr_sz+udp_sz) { + fprintf(stderr, "WARNING: *psize:%d, ip_hdr_sz:%d udp_sz:%d\n", + (int)*psize, (int)ip_hdr_sz, (int)udp_sz + ); + + FILE* fh = fopen("/tmp/packet.buf", "wb"); + fwrite(d, *psize, 1, fh); + fclose(fh); + + return -1; + } + +#undef UDPHDR + + p += udp_hdr_sz; + + *pdata = p; + *psize = udp_sz-udp_hdr_sz; + + return 0; +} + +uint16_t checksum_ip(const void* buf_, int count) { + register uint32_t sum = 0; + uint16_t answer = 0; + const uint16_t* buf = buf_; + + // Sum up 2-byte values until none or only one byte left. + while (count > 1) { + sum += *(buf++); + count -= 2; + } + + // Add left-over byte, if any. + if (count > 0) { + sum += *(uint8_t *) buf; + } + + // Fold 32-bit sum into 16 bits; we lose information by doing this, + // increasing the chances of a collision. + // sum = (lower 16 bits) + (upper 16 bits shifted right 16 bits) + while (sum >> 16) { + sum = (sum & 0xffff) + (sum >> 16); + } + + // Checksum is one's compliment of sum. + answer = ~sum; + + return answer; +} + diff --git a/src/core/net.h b/src/core/net.h new file mode 100644 index 0000000..0ca27b9 --- /dev/null +++ b/src/core/net.h @@ -0,0 +1,52 @@ +#ifndef WIREHUB_NET_H +#define WIREHUB_NET_H + +#include "common.h" +#include +#include +#include +#include +#include +#include + +#define IP4_HDRLEN 20 +#define UDP_HDRLEN 8 + +struct address { + int sa_family; + union { + struct sockaddr in; + struct sockaddr_in in4; + struct sockaddr_in6 in6; + }; +}; + +static inline uint16_t address_port(const struct address* a) { + switch (a->sa_family) { + case AF_INET: return ntohs(a->in4.sin_port); + case AF_INET6: return ntohs(a->in6.sin6_port); + default: return 0; + }; +} + +int parse_address(struct address* a, const char* endpoint, uint16_t port); +const char* format_address(const struct address* a, char* s, size_t sl); +int address_from_sockaddr(struct address* out, const struct sockaddr* in); +socklen_t address_len(const struct address* a); +void orchid(struct address* a, const void* cid, size_t cid_sz, const void* m, size_t l, uint16_t port); + +int socket_udp(const struct address* a); +int socket_raw_udp(sa_family_t sa_family, int hdrincl); +int ip4_to_udp(const void* d, const void** pdata, size_t* psize, struct address* src, struct address* dst); + +enum sniff_proto { + SNIFF_PROTO_WG, + SNIFF_PROTO_WH, +}; + +pcap_t* sniff(const char* interface, pcap_direction_t direction, enum sniff_proto proto, const char* expr); + +uint16_t checksum_ip(const void* addr, int len); + +#endif // WIREHUB_NET_H + diff --git a/src/core/orchid.c b/src/core/orchid.c new file mode 100644 index 0000000..c3f5df1 --- /dev/null +++ b/src/core/orchid.c @@ -0,0 +1,26 @@ +#include "net.h" +#include + +void orchid(struct address* a, const void* cid, size_t cid_sz, const void* m, size_t l, uint16_t port) { + unsigned char hash[crypto_generichash_BYTES]; + crypto_generichash_state s; + crypto_generichash_init(&s, NULL, 0, sizeof(hash)); + crypto_generichash_update(&s, (const void*)cid, cid_sz); + crypto_generichash_update(&s, (const void*)m, l); + crypto_generichash_final(&s, hash, sizeof(hash)); + + a->sa_family = a->in6.sin6_family = AF_INET6; + a->in6.sin6_port = htons(port); + + // XXX RFC 4843 states to get the middle 100-bit-long bitstring from the + // hash + + assert(sizeof(a->in6.sin6_addr) <= crypto_generichash_BYTES); + memcpy((uint8_t*)&a->in6.sin6_addr, hash, sizeof(a->in6.sin6_addr)); + ((uint8_t*)&a->in6.sin6_addr)[0] = 0x20; + ((uint8_t*)&a->in6.sin6_addr)[1] = 0x01; + ((uint8_t*)&a->in6.sin6_addr)[2] &= 0x0f; + ((uint8_t*)&a->in6.sin6_addr)[2] |= 0x10; +} + + diff --git a/src/core/os.c b/src/core/os.c new file mode 100644 index 0000000..db30c32 --- /dev/null +++ b/src/core/os.c @@ -0,0 +1,7 @@ +#include "common.h" +#include + +uint64_t now_seconds(void) { + return time(NULL); +} + diff --git a/src/core/os.h b/src/core/os.h new file mode 100644 index 0000000..fa3fd79 --- /dev/null +++ b/src/core/os.h @@ -0,0 +1,8 @@ +#ifndef WIREHUB_OS_H +#define WIREHUB_OS_H + +uint64_t now_seconds(void); + +#endif // WIREHUB_OS_H + + diff --git a/src/core/packet.c b/src/core/packet.c new file mode 100644 index 0000000..898ef0c --- /dev/null +++ b/src/core/packet.c @@ -0,0 +1,40 @@ +#include "packet.h" + +int auth_packet(uint8_t* p, size_t l, const uint8_t* sk, const uint8_t* pk) { + uint8_t k[crypto_scalarmult_curve25519_SCALARBYTES]; + sodium_mlock(k, sizeof(k)); + + if (crypto_scalarmult_curve25519(k, sk, pk) != 0) { + return -1; + } + + crypto_auth_hmacsha512256(packet_mac(p, l), p, packet_mac(p, l)-p, k); + + sodium_munlock(k, sizeof(k)); + + return 0; +} + +int verify_packet(const uint8_t* p, size_t pl, const uint8_t* sk) { + if (pl + +#define packet_flags_TIMEMASK (((uint64_t)-1) >> 1) +#define packet_flags_TIMESHIFT 0 +#define packet_flags_DIRECTMASK 0x1 +#define packet_flags_DIRECTSHIFT 63 + +#define packet_hdr(p) (p+0) +#define packet_src(p) (packet_hdr(p)+4) +#define packet_flags_time(p) (packet_src(p)+crypto_scalarmult_curve25519_BYTES) +#define packet_body(p) (packet_flags_time(p)+8) +#define packet_mac(p,l) (packet_body(p)+l) + +static inline size_t packet_size(size_t l) { + return ( + 4 + + crypto_scalarmult_curve25519_BYTES + + //crypto_scalarmult_curve25519_BYTES + + 8 + + l + + crypto_auth_hmacsha512256_BYTES + ); +} + +int auth_packet(uint8_t* p, size_t l, const uint8_t* sk, const uint8_t* pk); +int verify_packet(const uint8_t* p, size_t pl, const uint8_t* sk); + +#endif // PACKET_H + diff --git a/src/core/pcap.c b/src/core/pcap.c new file mode 100644 index 0000000..6e76a27 --- /dev/null +++ b/src/core/pcap.c @@ -0,0 +1,92 @@ +#include "net.h" + +pcap_t* sniff(const char* interface, pcap_direction_t direction, enum sniff_proto proto, const char* expr) { + assert(interface); + // expr may be NULL + + pcap_t* h; + + if (!(h=pcap_create(interface, NULL))) { + return NULL; + } + + if (pcap_set_timeout(h, 1000)) { + pcap_close(h); + return NULL; + } + + /*const int wh_sniff_buffer_size = 64 * 1024; + if (pcap_set_buffer_size(h, wh_sniff_buffer_size)) { + pcap_close(h); + return NULL; + }*/ + + if (pcap_set_immediate_mode(h, 1)) { + pcap_close(h); + return NULL; + } + + int err = pcap_activate(h); + if (err != 0) { + fprintf(stderr, "error: %s\n", pcap_geterr(h)); + pcap_close(h); + return NULL; + } + + if (pcap_setnonblock(h, 1, NULL) == PCAP_ERROR) { + pcap_close(h); + return NULL; + } + + if (pcap_setdirection(h, direction) == PCAP_ERROR) { + pcap_close(h); + return NULL; + } + + // XXX COMPILER ASSERT + assert(sizeof(wh_pkt_hdr)==4); + + char filter_exp[256]; + switch (proto) { + case SNIFF_PROTO_WG: + snprintf(filter_exp, sizeof(filter_exp), + "udp and udp[8] & 0xf8 == 0 and udp[9]==%d and udp[10]==%d and udp[11]==%d%s", + (int)wh_pkt_hdr[1], + (int)wh_pkt_hdr[2], + (int)wh_pkt_hdr[3], + expr ? expr : "" + ); + break; + + case SNIFF_PROTO_WH: + snprintf(filter_exp, sizeof(filter_exp), + "udp and udp[8]==%d and udp[9]==%d and udp[10]==%d and udp[11]==%d%s", + (int)wh_pkt_hdr[0], + (int)wh_pkt_hdr[1], + (int)wh_pkt_hdr[2], + (int)wh_pkt_hdr[3], + expr ? expr : "" + ); + break; + }; + + struct bpf_program filter; + const int optimize = 0; + if (pcap_compile(h, &filter, filter_exp, optimize, 0) == PCAP_ERROR) { + fprintf(stderr, "error: %s\n", pcap_geterr(h)); + pcap_close(h); + return NULL; + } + + int r = pcap_setfilter(h, &filter); + pcap_freecode(&filter); + + if (r == PCAP_ERROR) { + pcap_close(h); + return NULL; + } + + return h; +} + + diff --git a/src/core/secretdata.c b/src/core/secretdata.c new file mode 100644 index 0000000..8201056 --- /dev/null +++ b/src/core/secretdata.c @@ -0,0 +1,63 @@ +#include "luawh.h" +#include + +static const char* mt = "secret"; + +void* luaW_newsecret(lua_State* L, size_t len) { + void* p = sodium_malloc(sizeof(size_t) + len); + size_t *plen = p; + void* buf = p + sizeof(size_t); + + luaW_declptr(L, mt, sodium_free); + luaW_pushptr(L, mt, p); + + *plen = len; + return buf; +} + +void* luaW_checksecret(lua_State* L, int idx, size_t len) { + void* p = luaW_checkptr(L, idx, mt); + size_t *plen = (size_t*)p; + + if (*plen != len) { + luaL_error(L, "bad secret size (%d expected, got %d)", + (int)len, + (int)*plen + ); + } + + return p+sizeof(size_t); +} + +void* luaW_tosecret(lua_State* L, int idx, size_t len) { + void* p = luaW_toptr(L, idx, mt); + + if (!p) { + return NULL; + } + + size_t *plen = (size_t*)p; + if (*plen != len) { + return NULL; + } + + return p+sizeof(size_t); +} + +void* luaW_ownsecret(lua_State* L, int idx, size_t len) { + void* p = luaW_ownptr(L, idx, mt); + size_t *plen = (size_t*)p; + + if (*plen != len) { + luaL_error(L, "bad secret size (%zu expected, got %zu)", + len, + *plen + ); + } + + return p+sizeof(size_t); +} + +void luaW_freesecret(void* p) { + sodium_free(p-sizeof(size_t)); +} diff --git a/src/core/serdes.c b/src/core/serdes.c new file mode 100644 index 0000000..85e8c27 --- /dev/null +++ b/src/core/serdes.c @@ -0,0 +1,208 @@ +#include +#include +#include +#include "serdes.h" +#include "luawh.h" + +struct load_ud { + int fd; + size_t sz; + char buf[256]; +}; + +static const char* _deser_load(lua_State* L, void* data, size_t* psize) { + (void)L; + + struct load_ud* ld = (struct load_ud*)data; + + *psize = ld->sz; + + if (*psize > sizeof(ld->buf)) { + *psize = sizeof(ld->buf); + } + + if (*psize > 0) { + ssize_t ret = read(ld->fd, ld->buf, *psize); + assert(ret>=0); + assert((size_t)ret<=ld->sz); + *psize = ret; + ld->sz -= ret; + } + + return ld->buf; +} + +#define READ(var) assert(read(fd, &var, sizeof(var)) == sizeof(var)) +int luaW_read(lua_State* L, int fd) { + char type_b; + uint8_t u8; + size_t sz; + lua_Number number; + char* str; + int load_ret; + struct load_ud ld; + + READ(type_b); + + switch ((int)type_b) { + case LUA_TNONE: + return 0; + + case LUA_TNIL: + lua_pushnil(L); + break; + + case LUA_TBOOLEAN: + READ(u8); + lua_pushboolean(L, u8); + break; + + case LUA_TNUMBER: + READ(number); + lua_pushnumber(L, number); + break; + + case LUA_TSTRING: + READ(sz); + // XXX stream? + str = malloc(sz); + assert(read(fd, str, sz) == (ssize_t)sz); + lua_pushlstring(L, str, sz); + free(str); + break; + + case LUA_TTABLE: + lua_newtable(L); + for (;;) { + int r = luaW_read(L, fd); + if (r == -1) { + return -1; + } else if (r == 0) { + break; + } + + if (luaW_read(L, fd) < 0) { + return -1; + } + + lua_rawset(L, -3); + } + break; + + case LUA_TFUNCTION: + READ(ld.sz); + ld.fd = fd; + load_ret = lua_load(L, _deser_load, &ld, "work", NULL); + if (load_ret != LUA_OK) { + lua_pushnumber(L, load_ret); + lua_insert(L, -2); + + while(lua_gettop(L) > 2) { + lua_remove(L, 1); + } + + return -1; + } + break; + + case LUA_TLIGHTUSERDATA: + case LUA_TUSERDATA: + case LUA_TTHREAD: + assert(0 && "unhandled type"); + }; + + return 1; +} + +int luaW_readstack(lua_State* L, int fd) { + int ret; + while((ret = luaW_read(L, fd)) != 0); + return ret; +} +#undef READ + +#define WRITE(var) assert(write(fd, &var, sizeof(var)) == sizeof(var)) +static int _ser_dump(lua_State *L, const void* b, size_t size, void* ud) { + (void)L; + luaL_Buffer* buf = (luaL_Buffer*)ud; + luaL_addlstring(buf, (const char*)b, size); + return 0; +} + +int luaW_write(lua_State* L, int idx, int fd) { + char type_b = lua_type(L, idx); + uint8_t u8; + lua_Number number; + size_t sz; + const char* str; + luaL_Buffer buf; + + WRITE(type_b); + + switch ((int)type_b) { + case LUA_TNONE: + return 0; + + case LUA_TNIL: + break; + + case LUA_TBOOLEAN: + u8 = lua_toboolean(L, idx); + WRITE(u8); + break; + + case LUA_TNUMBER: + number = lua_tonumber(L, idx); + WRITE(number); + break; + + case LUA_TSTRING: + str = lua_tolstring(L, idx, &sz); + WRITE(sz); + assert(write(fd, str, sz) == (ssize_t)sz); + break; + + case LUA_TTABLE: + lua_pushnil(L); /* first key */ + while (lua_next(L, idx) != 0) { + assert(luaW_write(L, lua_gettop(L)-1, fd) == 1); + assert(luaW_write(L, lua_gettop(L), fd) == 1); + lua_pop(L, 1); + } + type_b = LUA_TNONE; + WRITE(type_b); + break; + + case LUA_TFUNCTION: + luaL_buffinit(L, &buf); + lua_pushvalue(L, idx); + if (lua_dump(L, _ser_dump, &buf, 0) != 0) { + return luaL_error(L, "unable to dump given function"); + } + lua_pop(L, 1); + luaL_pushresult(&buf); + str = lua_tolstring(L, -1, &sz); + WRITE(sz); + assert(write(fd, str, sz) == (ssize_t)sz); + lua_pop(L, 1); + break; + + case LUA_TLIGHTUSERDATA: + case LUA_TUSERDATA: + case LUA_TTHREAD: + luaL_error(L, "unhandled type: %s", lua_typename(L, type_b)); + } + + return 1; +} + +void luaW_writestack(lua_State* L, int idx, int fd) { + if (idx < 0) { + idx = lua_gettop(L)+idx+1; + } + + for (; luaW_write(L, idx, fd); ++idx); +} + +#undef WRITE + diff --git a/src/core/serdes.h b/src/core/serdes.h new file mode 100644 index 0000000..7e1d458 --- /dev/null +++ b/src/core/serdes.h @@ -0,0 +1,33 @@ +#ifndef WH_SERDES_H +#define WH_SERDES_H + +#include + +/** Reads `fd`, deserializes one element and pushes it in the stack. + * + * Returns 1 if an element was read; 0 if none element was read; -1 if + * deserialization failed. + */ +int luaW_read(lua_State* L, int fd); + +/** Reads `fd`, deserializes all elements and pushes them in the stack. + * + * Returns 0 if succeed, else -1. + */ +int luaW_readstack(lua_State* L, int fd); + +/** Serializes element at index `idx` and writes it in `fd`. + * + * Returns 1 if an element was written; 0 if not element was written; raises a + * Lua error if something went wrong. + */ +int luaW_write(lua_State* L, int idx, int fd); + +/** Serializes the stack from index `idx` and writes them in `fd`. + * + * Raises a Lua error if something went wrong. + */ +void luaW_writestack(lua_State* L, int idx, int fd); + +#endif // WH_SERDES_H + diff --git a/src/core/smartptr.c b/src/core/smartptr.c new file mode 100644 index 0000000..3069953 --- /dev/null +++ b/src/core/smartptr.c @@ -0,0 +1,122 @@ +#include "luawh.h" + +static const char* _registry_field = "ptrmt"; + +static int _free(lua_State* L) { + void** pp = lua_touserdata(L, 1); + void(*del)(void*) = lua_touserdata(L, lua_upvalueindex(1)); + + if (*pp) { + del(*pp), *pp = NULL; + } + + return 0; +} + +static int _tostring(lua_State* L) { + void** pp = lua_touserdata(L, 1); + + lua_getmetatable(L, 1); + lua_getfield(L, -1, "mt"); + lua_remove(L, -2); + + if (*pp) { + lua_pushfstring(L, "*: %p", *pp); + } else { + lua_pushstring(L, "*: "); + } + lua_concat(L, 2); + + return 1; +} + +static void _newmt(lua_State* L, const char* mt, void(*del)(void*)) { + if (luaL_newmetatable(L, mt)) { + // build metatable + lua_pushstring(L, mt); + lua_setfield(L, -2, "mt"); + + lua_newtable(L); // __index + + lua_pushlightuserdata(L, del); + lua_pushcclosure(L, _free, 1); + lua_pushvalue(L, -1); + lua_setfield(L, -3, "free"); + lua_setfield(L, -3, "__gc"); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, _tostring); + lua_setfield(L, -2, "__tostring"); + + // lazy build of the registry's field + if (lua_getfield(L, LUA_REGISTRYINDEX, _registry_field) == LUA_TNIL) { + lua_pop(L, 1); + lua_newtable(L); + lua_pushvalue(L, -1); + lua_setfield(L, LUA_REGISTRYINDEX, _registry_field); + } + + // register metatable as a pointer's + lua_pushvalue(L, -2); + lua_pushboolean(L, 1); + lua_settable(L, -3); + lua_pop(L, 1); + } +} + +void luaW_declptr(lua_State* L, const char* mt, void(*del)(void*)) { + _newmt(L, mt, del); + lua_pop(L, 1); +} + +void luaW_pushptr(lua_State* L, const char* mt, void* ptr) { + void** pp = lua_newuserdata(L, sizeof(void*)); + *pp = ptr; + + // free is given as a default deleter. if type was previously declared, free + // will be ignored + _newmt(L, mt, free); + lua_setmetatable(L, -2); +} + +static void* getptr(lua_State* L, int idx, const char* mt, int error, int own) { + void **pp = luaL_testudata(L, idx, mt); + if (!pp) { + if (error) { + luaL_error(L, "bad type (pointer %s expected, got %s)", mt, + lua_typename(L, idx)); + } else { + return NULL; + } + } + + void *p = *pp; + if (!p) { + if (error) { + luaL_error(L, "dangling pointer #%d", idx); + } else { + return NULL; + } + } + + if (own) { + *pp = NULL; + } + + return p; +} + +void* luaW_checkptr(lua_State* L, int idx, const char* mt) { + luaL_checkudata(L, idx, mt); + return getptr(L, idx, mt, 1, 0); +} + +void* luaW_ownptr(lua_State* L, int idx, const char* mt) { + return getptr(L, idx, mt, 1, 1); +} + +void* luaW_toptr(lua_State* L, int idx, const char* mt) { + return getptr(L, idx, mt, 0, 0); +} + + diff --git a/src/core/tun.c b/src/core/tun.c new file mode 100644 index 0000000..ea56ee1 --- /dev/null +++ b/src/core/tun.c @@ -0,0 +1,276 @@ +#include "luawh.h" +#include "net.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define MT "tun" + +struct tun { + char* interface; + int fd; +}; + +static void delete_tun(struct tun* t) { + if (t->fd >= 0) { close(t->fd), t->fd = -1; } + if (t->interface) { free(t->interface), t->interface = NULL; } + + free(t); +} + +static void delete_tun_pvoid(void* t) { + return delete_tun((struct tun*)t); +} + +int luaW_newtun(lua_State* L) { + const char* interface = luaL_checkstring(L, 1); + const char* subnet = luaL_checkstring(L, 2); + int mtu = luaL_checkinteger(L, 3); + + if (mtu < 576) { + luaL_error(L, "MTU too small"); + } + + struct tun* t = calloc(1, sizeof(struct tun)); + + if ((t->fd = open("/dev/net/tun", O_RDWR)) < 0) { + delete_tun(t); + luaL_error(L, "could not open \"/dev/net/tun\": %s", strerror(errno)); + } + + struct ifreq ifr; + memset(&ifr, 0, sizeof(ifr)); + ifr.ifr_flags = IFF_TUN | IFF_NO_PI; + if (interface) { + strncpy(ifr.ifr_name, interface, IFNAMSIZ); + } + + if (ioctl(t->fd, TUNSETIFF, (void *) &ifr) < 0) { + delete_tun(t); + luaL_error(L, "ioctl() failed: %s", strerror(errno)); + } + + t->interface = strdup(ifr.ifr_name); + + if (fcntl(t->fd, F_SETFL, fcntl(t->fd, F_GETFL, 0) | O_NONBLOCK) == -1) { + delete_tun(t); + luaL_error(L, "fcntl() failed: %s", strerror(errno)); + } + + // XXX + { + char cmd[128]; + snprintf(cmd, sizeof(cmd), "ip link set dev %s mtu %d", t->interface, mtu); + if (system(cmd) < 0) { + delete_tun(t); + luaL_error(L, "set mtu failed: %s", strerror(errno)); + } + + snprintf(cmd, sizeof(cmd), "ip addr add %s dev %s", subnet, t->interface); + if (system(cmd) < 0) { + delete_tun(t); + luaL_error(L, "set address failed: %s", strerror(errno)); + } + + snprintf(cmd, sizeof(cmd), "ip link set %s up", t->interface); + if (system(cmd) < 0) { + delete_tun(t); + luaL_error(L, "tun up failed: %s", strerror(errno)); + } + } + + luaW_pushptr(L, MT, t); + return 1; +} + +static int _read(lua_State* L) { + struct tun* t = luaW_checkptr(L, 1, MT); + + char pkt[2048]; + ssize_t r = read(t->fd, pkt, sizeof(pkt)); + + if (r < 0 && errno == EAGAIN) { + return 0; + } + + if (r < 0) { + luaL_error(L, "read error(): %s", strerror(errno)); + } + + if (r < 5) { + lua_pushboolean(L, 0); + lua_pushstring(L, "malformed packet"); + return 2; + } + + /* + const size_t tun_hdr_sz = 4; + struct tun_pi* tun_hdr = (struct tun_pi*)pkt; + tun_hdr->proto = ntohs(tun_hdr->proto); + + if (tun_hdr->proto != ETHERTYPE_IP) { + lua_pushboolean(L, 0); + lua_pushstring(L, "unknown protocol"); + return 2; + }*/ + + const struct ip* ip_hdr = (struct ip*)pkt; + size_t ip_hdr_sz = ip_hdr->ip_hl*sizeof(uint32_t); + if ((size_t)r < ip_hdr_sz) { + lua_pushboolean(L, 0); + lua_pushstring(L, "malformed IP header"); + return 2; + } + + if (ip_hdr->ip_hl != 5) { + lua_pushboolean(L, 0); + lua_pushstring(L, "unhandled IP4 options"); + return 2; + } + + size_t l = ntohs(ip_hdr->ip_len); + if (l < ip_hdr_sz) { + lua_pushboolean(L, 0); + lua_pushstring(L, "malformed IP header"); + return 2; + } + l -= ip_hdr_sz; + + if ( + ip_hdr->ip_p != IPPROTO_UDP && + (!WH_TUN_ICMP || ip_hdr->ip_p != IPPROTO_ICMP) + ) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "unhandled protocol: %d", ip_hdr->ip_p); + return 2; + } + + // IPv4 UDP fragmentation + uint16_t ip_off = ntohs(ip_hdr->ip_off); + if (ip_hdr->ip_p == IPPROTO_UDP && + (ip_off&IP_OFFMASK) == 0x0000 + ) { + // packet contains UDP header + if (l < UDP_HDRLEN) { + lua_pushboolean(L, 0); + lua_pushstring(L, "malformed UDP header"); + return 2; + } + + struct udphdr* udp_hdr = (struct udphdr*)(pkt+ip_hdr_sz); + udp_hdr->uh_sport = 0; + udp_hdr->uh_dport = 0; + udp_hdr->uh_sum = 0; + } + + const void* m = pkt+ip_hdr_sz; + + lua_pushboolean(L, 1); + + struct address* src = luaW_newaddress(L); + src->sa_family = src->in4.sin_family = AF_INET; + memcpy(&src->in4.sin_addr, &ip_hdr->ip_src, 4); + src->in4.sin_port = 0; + + struct address* dst = luaW_newaddress(L); + dst->sa_family = dst->in4.sin_family = AF_INET; + memcpy(&dst->in4.sin_addr, &ip_hdr->ip_dst, 4); + dst->in4.sin_port = 0; + + lua_pushlstring(L, (const void*)&ip_hdr->ip_id, sizeof(ip_hdr->ip_id)); + lua_pushlstring(L, (const void*)&ip_hdr->ip_off, sizeof(ip_hdr->ip_off)); + lua_pushlstring(L, m, l); + lua_concat(L, 3); + + return 4; +} + +static int _write(lua_State* L) { + struct tun* t = luaW_checkptr(L, 1, MT); + struct address* src = luaL_checkudata(L, 2, "address"); + struct address* dst = luaL_checkudata(L, 3, "address"); + size_t l; + const void* m = luaL_checklstring(L, 4, &l); + + if (src->sa_family != AF_INET || dst->sa_family != AF_INET) { + luaL_error(L, "address must be IP4"); + } + + if (l < 4) { + luaL_error(L, "malformed raw packet"); + } + + size_t buf_sz = ( + IP4_HDRLEN + // ip + UDP_HDRLEN + // udp + l-4 + ); + + void* buf = calloc(1, buf_sz); + + struct ip* ip_hdr = buf; + ip_hdr->ip_v = 4; + assert(IP4_HDRLEN%sizeof(uint32_t)==0); + ip_hdr->ip_hl = IP4_HDRLEN/sizeof(uint32_t); + ip_hdr->ip_len = htons(IP4_HDRLEN+UDP_HDRLEN+l-4); + ip_hdr->ip_id = *(uint16_t*)(m+0); + ip_hdr->ip_off = *(uint16_t*)(m+2); + ip_hdr->ip_ttl = 255; + ip_hdr->ip_p = IPPROTO_UDP; + memcpy(&ip_hdr->ip_src, &src->in4.sin_addr, 4); + memcpy(&ip_hdr->ip_dst, &dst->in4.sin_addr, 4); + ip_hdr->ip_sum = checksum_ip(ip_hdr, IP4_HDRLEN); + + struct udphdr* udp_hdr = buf+IP4_HDRLEN; + udp_hdr->uh_sport = src->in4.sin_port; + udp_hdr->uh_dport = dst->in4.sin_port; + udp_hdr->uh_ulen = htons(UDP_HDRLEN+l-4); + + memcpy(buf+IP4_HDRLEN+UDP_HDRLEN, m+4, l-4); + + int w_ret = write(t->fd, buf, buf_sz); + free(buf); + + if (w_ret < 0) { + luaL_error(L, "write() error: %s", strerror(errno)); + } + + if ((size_t)w_ret != buf_sz) { + luaL_error(L, "truncated write()"); + } + + return 0; +} + +static int _info(lua_State* L) { + struct tun* t = luaW_checkptr(L, 1, MT); + lua_pushinteger(L, t->fd); + lua_pushstring(L, t->interface); + return 2; +} + +LUAMOD_API int luaopen_tun(lua_State* L) { + luaW_declptr(L, MT, delete_tun_pvoid); + + luaL_getmetatable(L, MT); + lua_getfield(L, -1, "__index"); + lua_pushcfunction(L, _info); + lua_setfield(L, -2, "info"); + lua_pushcfunction(L, _read); + lua_setfield(L, -2, "read"); + lua_pushcfunction(L, _write); + lua_setfield(L, -2, "write"); + lua_pop(L, 2); + + lua_pushcfunction(L, luaW_newtun); + return 1; +} + diff --git a/src/core/wglib.c b/src/core/wglib.c new file mode 100644 index 0000000..9b2cf5d --- /dev/null +++ b/src/core/wglib.c @@ -0,0 +1,497 @@ +#include "wireguard.h" +#include "luawh.h" +#include + +int check_linux_version(void) { + struct utsname name; + uname(&name); + + char minver[] = WH_LINUX_MINVERSION; + char* t = name.release; + unsigned int i = 0; + while (i minver[i]) { + return 1; + } else if (v < minver[i]) { + return 0; + } + + t = NULL; + ++i; + } + + return 1; +} + +int check_wireguard_module(void) { + FILE* fh = fopen("/proc/modules", "r"); + + int found = 0; + char buf[256]; + while (fgets(buf, sizeof(buf), fh)) { + char modname[64]; + sscanf(buf, "%s", modname); + + if (strcmp(modname, "wireguard") == 0) { + found = 1; + break; + } + } + + fclose(fh); + + return found; +} + +static int _check(lua_State* L) { + if (!check_linux_version()) { + lua_pushstring(L, "oldkernel"); + } + + else if (!check_wireguard_module()) { + lua_pushstring(L, "notloaded"); + } + + else { + lua_pushstring(L, "ok"); + } + + return 1; +} + +static void _push_allowedip(lua_State* L, struct wg_allowedip* allowedip) { + lua_newtable(L); + + struct address* ip = luaW_newaddress(L); + ip->sa_family = allowedip->family; + switch (ip->sa_family) { + case AF_INET: memcpy(&ip->in4.sin_addr, &allowedip->ip4, sizeof(allowedip->ip4)); break; + case AF_INET6: memcpy(&ip->in6.sin6_addr, &allowedip->ip6, sizeof(allowedip->ip6)); break; + default: perror("unknown sa_family"); + }; + lua_rawseti(L, -2, 1); + + lua_pushinteger(L, allowedip->cidr); + lua_rawseti(L, -2, 2); +} + +static void _push_peer(lua_State* L, struct wg_peer* p) { + lua_newtable(L); + + if (p->flags & WGPEER_HAS_PUBLIC_KEY) { + lua_pushlstring(L, (const void*)p->public_key, 32); + lua_setfield(L, -2, "public_key"); + } + + if (p->flags & WGPEER_HAS_PRESHARED_KEY) { + memcpy(luaW_newsecret(L, 32), p->preshared_key, 32); + lua_setfield(L, -2, "preshared_key"); + } + + struct address* endpoint = luaW_newaddress(L); + endpoint->sa_family = p->endpoint.addr.sa_family; + switch (endpoint->sa_family) { + case AF_INET: memcpy(&endpoint->in4, &p->endpoint.addr4, sizeof(endpoint->in4)); break; + case AF_INET6: memcpy(&endpoint->in6, &p->endpoint.addr6, sizeof(endpoint->in6)); break; + default: perror("unknown sa_family"); + }; + lua_setfield(L, -2, "endpoint"); + + lua_pushnumber(L, p->last_handshake_time.tv_sec + (double)p->last_handshake_time.tv_nsec / 1.0e9); + lua_setfield(L, -2, "last_handshake_time"); + + lua_pushnumber(L, p->rx_bytes); + lua_setfield(L, -2, "rx_bytes"); + + lua_pushnumber(L, p->tx_bytes); + lua_setfield(L, -2, "tx_bytes"); + + lua_pushnumber(L, p->persistent_keepalive_interval); + lua_setfield(L, -2, "persistent_keepalive_interval"); + + lua_newtable(L); + struct wg_allowedip* allowedip; + wg_for_each_allowedip(p, allowedip) { + _push_allowedip(L, allowedip); + lua_rawseti(L, -2, luaL_len(L, -2)+1); + } + lua_setfield(L, -2, "allowedips"); +} + +static void _push_device(lua_State* L, struct wg_device* d) { + lua_newtable(L); + + lua_pushstring(L, d->name); + lua_setfield(L, -2, "name"); + + lua_pushinteger(L, d->ifindex); + lua_setfield(L, -2, "ifindex"); + + if (d->flags & WGDEVICE_HAS_PUBLIC_KEY) { + lua_pushlstring(L, (const void*)d->public_key, 32); + lua_setfield(L, -2, "public_key"); + } + + if (d->flags & WGDEVICE_HAS_PRIVATE_KEY) { + memcpy(luaW_newsecret(L, 32), d->private_key, 32); + lua_setfield(L, -2, "private_key"); + } + + lua_pushinteger(L, d->fwmark); + lua_setfield(L, -2, "fwmark"); + + lua_pushinteger(L, d->listen_port); + lua_setfield(L, -2, "listen_port"); + + lua_newtable(L); + struct wg_peer* p; + wg_for_each_peer(d, p) { + _push_peer(L, p); + lua_rawseti(L, -2, luaL_len(L, -2)+1); + } + lua_setfield(L, -2, "peers"); +} + +#define IFK(x) if (strcmp(k, x) == 0) + +static void _check_allowedip(lua_State* L, int idx, struct wg_device* d, struct wg_peer* p) { + assert(idx > 0); + + struct wg_allowedip* allowedip = calloc(1, sizeof(struct wg_allowedip)); + + if (p->last_allowedip) { + p->last_allowedip->next_allowedip = allowedip; + } else { + p->first_allowedip = allowedip; + } + p->last_allowedip = allowedip; + + lua_rawgeti(L, idx, 1); + struct address* ip = luaL_testudata(L, -1, "address"); + if (!ip) { + wg_free_device(d); + luaL_error(L, "invalid allowedip's ip"); + } + + allowedip->family = ip->sa_family; + int max_cidr; + switch (ip->sa_family) { + case AF_INET: + memcpy(&allowedip->ip4, &ip->in4.sin_addr, sizeof(allowedip->ip4)); + max_cidr = 32; + break; + case AF_INET6: + memcpy(&allowedip->ip6, &ip->in6.sin6_addr, sizeof(allowedip->ip6)); + max_cidr = 128; + break; + + default: + return perror("unknown sa_family"); + }; + lua_pop(L, 1); + + lua_rawgeti(L, idx, 2); + int isnum; + lua_Integer cidr = lua_tointegerx(L, -1, &isnum); + if (!isnum || cidr < 0 || max_cidr < cidr) { + wg_free_device(d); + luaL_error(L, "invalid CIDR"); + } + allowedip->cidr = cidr; + lua_pop(L, 1); +} + +static void _check_peer(lua_State* L, int idx, struct wg_device* d) { + assert(idx > 0); + + struct wg_peer* p = calloc(1, sizeof(struct wg_peer)); + + if (d->last_peer) { + d->last_peer->next_peer = p; + } else { + d->first_peer = p; + } + d->last_peer = p; + + for (lua_pushnil(L); lua_next(L, idx) != 0; lua_pop(L, 1)) { + const char* k = lua_tostring(L, -2); + + IFK("remove_me") { + if (lua_toboolean(L, -1)) { + p->flags |= WGPEER_REMOVE_ME; + } + } + + else IFK("replace_allowedips") { + if (lua_toboolean(L, -1)) { + p->flags |= WGPEER_REPLACE_ALLOWEDIPS; + } + } + + else IFK("public_key") { + size_t sz; + const char* k = lua_tolstring(L, -1, &sz); + + if (sz != 32) { + wg_free_device(d); + luaL_error(L, "invalid public key"); + } + + memcpy(p->public_key, k, 32); + p->flags |= WGPEER_HAS_PUBLIC_KEY; + } + + else IFK("preshared_key") { + const void* preshared_key = luaW_tosecret(L, -1, 32); + if (!preshared_key) { + wg_free_device(d); + luaL_error(L, "invalid preshared key"); + } + + memcpy(p->preshared_key, preshared_key, 32); + p->flags |= WGPEER_HAS_PRESHARED_KEY; + } + + else IFK("endpoint") { + struct address* endpoint = luaL_testudata(L, -1, "address"); + + if (!endpoint) { + wg_free_device(d); + luaL_error(L, "invalid endpoint"); + } + + p->endpoint.addr.sa_family = endpoint->sa_family; + switch (endpoint->sa_family) { + case AF_INET: memcpy(&p->endpoint.addr4, &endpoint->in4, sizeof(endpoint->in4)); break; + case AF_INET6: memcpy(&p->endpoint.addr6, &endpoint->in6, sizeof(endpoint->in6)); break; + default: perror("unknown sa_family"); + }; + } + + else IFK("persistent_keepalive_interval") { + int isnum; + lua_Integer v = lua_tointegerx(L, -1, &isnum); + + if (!isnum || v < 0) { + luaL_error(L, "invalid persistent_keepalive_interval"); + } + + p->persistent_keepalive_interval = v; + p->flags |= WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL; + } + + else IFK("allowedips") { + for (lua_Integer i = 1; i <= luaL_len(L, -1); ++i) { + lua_rawgeti(L, -1, i); + _check_allowedip(L, lua_gettop(L), d, p); + lua_pop(L, 1); + } + } + } +} + +static struct wg_device* _check_device(lua_State* L, int idx) { + assert(idx > 0); + + luaL_checktype(L, idx, LUA_TTABLE); + + struct wg_device* d = calloc(1, sizeof(struct wg_device)); + for (lua_pushnil(L); lua_next(L, idx) != 0; lua_pop(L, 1)) { + const char* k = lua_tostring(L, -2); + + IFK("name") { + size_t sz; + const char* name = lua_tolstring(L, -1, &sz); + + if (sz > IFNAMSIZ) { + wg_free_device(d); + luaL_error(L, "device's name too long"); + } + + memcpy(d->name, name, sz); + } + + else IFK("ifindex") { + int isnum; + int ifindex = lua_tointegerx(L, -1, &isnum); + + if (!isnum || ifindex < 0 || UINT32_MAX < (unsigned int)ifindex) { + wg_free_device(d); + luaL_error(L, "invalid ifindex"); + } + + d->ifindex = ifindex; + } + + else IFK("replace_peers") { + if (lua_toboolean(L, -1)) { + d->flags |= WGDEVICE_REPLACE_PEERS; + } + } + + else IFK("public_key") { + size_t sz; + const char* k = lua_tolstring(L, -1, &sz); + + if (sz != 32) { + wg_free_device(d); + luaL_error(L, "invalid public key"); + } + + memcpy(d->public_key, k, 32); + d->flags |= WGDEVICE_HAS_PUBLIC_KEY; + } + + else IFK("private_key") { + const void* sk = luaW_tosecret(L, -1, 32); + if (!sk) { + wg_free_device(d); + luaL_error(L, "invalid private key"); + } + memcpy(d->private_key, sk, 32); + d->flags |= WGDEVICE_HAS_PRIVATE_KEY; + } + + else IFK("fwmark") { + int isnum; + int fwmark = lua_tointegerx(L, -1, &isnum); + + if (!isnum || fwmark < 0 || UINT16_MAX < (unsigned int)fwmark) { + wg_free_device(d); + luaL_error(L, "invalid fwmark"); + } + + d->fwmark = fwmark; + d->flags |= WGDEVICE_HAS_FWMARK; + } + + else IFK("listen_port") { + int isnum; + int listen_port = lua_tointegerx(L, -1, &isnum); + + if (!isnum || listen_port < 0 || UINT16_MAX < (unsigned int)listen_port) { + wg_free_device(d); + luaL_error(L, "invalid listen_port"); + } + + d->listen_port = listen_port; + d->flags |= WGDEVICE_HAS_LISTEN_PORT; + } + + else IFK("peers") { + for (lua_Integer i = 1; i <= luaL_len(L, -1); ++i) { + lua_rawgeti(L, -1, i); + _check_peer(L, lua_gettop(L), d); + lua_pop(L, 1); + } + } + + } + + return d; +} + +#undef IFK + +static int _list_device_names(lua_State* L) { + char* device_names = wg_list_device_names(); + if (!device_names) { + return luaL_error(L, "wg_list_device_names() failed"); + } + + lua_newtable(L); + char* device_name; + size_t len; + wg_for_each_device_name(device_names, device_name, len) { + lua_pushlstring(L, device_name, len); + lua_rawseti(L, -2, luaL_len(L, -2) + 1); + } + + free(device_names); + + return 1; +} + +static int _add_device(lua_State* L) { + const char* device_name = luaL_checkstring(L, 1); + int ret = wg_add_device(device_name); + if (ret < 0) { + return luaL_error(L, "wg_add_device() failed: %s", strerror(-ret)); + } + + return 0; +} + +static int _del_device(lua_State* L) { + const char* device_name = luaL_checkstring(L, 1); + int ret = wg_del_device(device_name); + if (ret < 0) { + return luaL_error(L, "wg_del_device() failed: %s", strerror(-ret)); + } + + return 0; +} + +static int _get_device(lua_State* L) { + const char* device_name = luaL_checkstring(L, 1); + struct wg_device* device; + int ret = wg_get_device(&device, device_name); + if (ret < 0) { + if (ret == -ENODEV) { + return 0; + } + + return luaL_error(L, "wg_get_device() failed: %s", strerror(-ret)); + } + + _push_device(L, device); + wg_free_device(device), device = NULL; + + return 1; +} + +static int _set_device(lua_State* L) { + struct wg_device* device = _check_device(L, 1); + int ret = wg_set_device(device); + wg_free_device(device), device = NULL; + + if (ret < 0) { + return luaL_error(L, "wg_set_device() failed: %s", strerror(-ret)); + } + + return 0; +} + +static const luaL_Reg funcs[] = { + {"add", _add_device}, + {"check", _check}, + {"delete", _del_device}, + {"get", _get_device}, + {"list_names", _list_device_names}, + {"set", _set_device}, + + { NULL, NULL} +}; + +LUAMOD_API int luaopen_wg(lua_State* L) { + luaL_checkversion(L); + luaL_newlib(L, funcs); + + { + char minversion[] = WH_LINUX_MINVERSION; + unsigned int i; + + lua_newtable(L); + for (i=0; i +#include +#include +#include +#include +#include +#include +#include + +static void _expanduser(lua_State* L) { + luaL_loadstring(L, + "return string.gsub(..., '~', function() return os.getenv('HOME') end)" + ); + + lua_insert(L, -2); + lua_call(L, 1, 1); +} + +struct pipe_event { + int fds[2]; +}; + +// genkey(str key, int workbit[, int num_threads]) +static int _genkey(lua_State* L) { + const char* key = luaL_checkstring(L, 1); + int workbit = luaL_checkinteger(L, 2); + + int num_threads = 1; + if (lua_gettop(L) == 3) { + num_threads = luaL_checkinteger(L, 3); + } + + if (num_threads == 0) { + num_threads = sysconf(_SC_NPROCESSORS_ONLN)*2-1; + } + + void* sign_sk = luaW_newsecret(L, crypto_sign_ed25519_SECRETKEYBYTES); + if (genkey(sign_sk, key, workbit, num_threads) < 0) { + luaL_error(L, "key generation failed"); + } + + uint8_t sign_pk[crypto_sign_ed25519_PUBLICKEYBYTES]; + crypto_sign_ed25519_sk_to_pk(sign_pk, sign_sk); + lua_pushlstring(L, (const void*)sign_pk, sizeof(sign_pk)); + + void* sk = luaW_newsecret(L, crypto_scalarmult_curve25519_BYTES); + crypto_sign_ed25519_sk_to_curve25519(sk, sign_sk); + + uint8_t pk[crypto_scalarmult_curve25519_BYTES]; + crypto_scalarmult_curve25519_base(pk, sk); + lua_pushlstring(L, (const void*)pk, sizeof(pk)); + + return 4; +} + +static int _publickey(lua_State* L) { + const void* sk; + + if (lua_type(L, 1) == LUA_TSTRING) { + size_t sz; + sk = lua_tolstring(L, 1, &sz); + if (sz != crypto_scalarmult_curve25519_BYTES) { + luaL_error(L, "bad length"); + } + } else { + sk = luaW_checksecret(L, 1, crypto_scalarmult_curve25519_BYTES); + } + + uint8_t pk[crypto_scalarmult_curve25519_BYTES]; + crypto_scalarmult_curve25519_base(pk, sk); + lua_pushlstring(L, (void*)pk, sizeof(pk)); + return 1; +} + +static int _readsk(lua_State* L) { + (void)luaL_checkstring(L, 1); + + if (lua_gettop(L) != 1) { + luaL_error(L, "function only takes one argument"); + } + + _expanduser(L); + + const char* filepath = lua_tostring(L, 1); + assert(filepath); + + FILE* fh = fopen(filepath, "rb"); + + if (!fh && errno == ENOENT) { + lua_pushnil(L); + return 1; + } + + else if (!fh) { + luaL_error(L, "cannot open file '%s': %s", filepath, strerror(errno)); + } + + int success = 0; + long l; + char* secret_b64 = NULL; + if ( + fseek(fh, 0, SEEK_END) >= 0 && + (l = ftell(fh)) >= crypto_scalarmult_curve25519_KEYBASE64BYTES && + fseek(fh, 0, SEEK_SET) >= 0 + ) { + secret_b64 = sodium_malloc(l); + success = fread(secret_b64, 1, l, fh) == (size_t)l; + } + + fclose(fh); + + if (success) { + void* secret = luaW_newsecret(L, crypto_scalarmult_curve25519_BYTES); + + size_t bin_l = crypto_scalarmult_curve25519_BYTES; + const int variant = sodium_base64_VARIANT_ORIGINAL; + const char* b64_end; + success = sodium_base642bin( + secret, + crypto_scalarmult_curve25519_BYTES, + secret_b64, + l, + NULL, + &bin_l, + &b64_end, + variant + ) == 0; + + success &= bin_l == crypto_scalarmult_curve25519_BYTES; + + if (!success) { + luaW_freesecret(luaW_ownsecret(L, -1, crypto_scalarmult_curve25519_BYTES)); + lua_pop(L, 1); + } + } + + if (secret_b64) { + sodium_free(secret_b64), secret_b64 = NULL; + } + + if (!success) { + luaL_error(L, "cannot read file '%s': %s", filepath, strerror(errno)); + } + + return 1; +} + +static int _wgkey(lua_State* L) { + if (lua_type(L, 1) == LUA_TSTRING) { + size_t l; + const void* pk = luaL_checklstring(L, 1, &l); + + if (l != crypto_sign_ed25519_PUBLICKEYBYTES) { + luaL_error(L, "bad public key"); + } + + uint8_t wg_pk[crypto_scalarmult_curve25519_BYTES]; + if (crypto_sign_ed25519_pk_to_curve25519(wg_pk, pk)) { + luaL_error(L, "bad public key"); + } + + lua_pushlstring(L, (void*)wg_pk, sizeof(wg_pk)); + } + + else { + void* sk = luaW_checksecret(L, 1, crypto_sign_ed25519_SECRETKEYBYTES); + void* wg_sk = luaW_newsecret(L, crypto_scalarmult_curve25519_BYTES); + if (crypto_sign_ed25519_sk_to_curve25519(wg_sk, sk)) { + luaL_error(L, "bad private key"); + } + } + + return 1; +} + +static int _workbit(lua_State* L) { + size_t l; + const void* pk = luaL_checklstring(L, 1, &l); + if (l != crypto_sign_ed25519_PUBLICKEYBYTES) { + luaL_error(L, "bad public key"); + } + + const void* k = luaL_checklstring(L, 2, &l); + lua_pushinteger(L, workbit(pk, k, l)); + return 1; +} + +static int _sandbox(lua_State* L) { + (void)L; + return 0; +} + +static int _netdevs(lua_State* L) { + pcap_if_t* devs; + char err[PCAP_ERRBUF_SIZE]; + if (pcap_findalldevs(&devs, err) == PCAP_ERROR) { + lua_pushboolean(L, 0); + lua_pushstring(L, err); + return 2; + } + + lua_newtable(L); + + int n; + pcap_if_t* ifi; + for (n=1, ifi=devs; ifi; ifi=ifi->next) { + lua_newtable(L); + + lua_pushstring(L, ifi->name); + lua_setfield(L, -2, "name"); + + lua_pushstring(L, ifi->description); + lua_setfield(L, -2, "description"); + + lua_newtable(L); + pcap_addr_t* ai; + int o; + for (o=1, ai=ifi->addresses; ai; ai=ai->next) { + lua_newtable(L); + +#define PUSH_SOCKADDR(v) \ + if (v) { \ + struct address* a = luaW_newaddress(L); \ + if (address_from_sockaddr(a, v) == -1) { \ + lua_pop(L, 2); \ + continue; \ + } \ + } else { \ + lua_pushnil(L); \ + } + + PUSH_SOCKADDR(ai->addr); + lua_setfield(L, -2, "addr"); + + PUSH_SOCKADDR(ai->netmask); + lua_setfield(L, -2, "netmask"); + + PUSH_SOCKADDR(ai->broadaddr); + lua_setfield(L, -2, "broadcast"); + + PUSH_SOCKADDR(ai->dstaddr); + lua_setfield(L, -2, "dest"); + +#undef PUSH_SOCKADDR + + lua_seti(L, -2, o++); + } + lua_setfield(L, -2, "addresses"); + +#define PUSH_FLAG(v, f) \ + lua_pushboolean(L, ((ifi->flags & (f)) == (f))); \ + lua_setfield(L, -2, v); + + PUSH_FLAG("loopback", PCAP_IF_LOOPBACK); + PUSH_FLAG("up", PCAP_IF_UP); + PUSH_FLAG("running", PCAP_IF_RUNNING); + PUSH_FLAG("wireless", PCAP_IF_WIRELESS); + //PUSH_FLAG("conn_status", PCAP_IF_CONNECTION_STATUS); + //PUSH_FLAG("conn_status_unknown", PCAP_IF_CONNECTION_STATUS_UNKNOWN); + //PUSH_FLAG("conn_status_connected", PCAP_IF_CONNECTION_STATUS_CONNECTED); + //PUSH_FLAG("conn_status_disconnected", PCAP_IF_CONNECTION_STATUS_DISCONNECTED); + //PUSH_FLAG("conn_status_not_applicable ", PCAP_IF_CONNECTION_STATUS_NOT_APPLICABLE); +#undef PUSH_FLAG + + lua_seti(L, -2, n++); + } + + pcap_freealldevs(devs), devs = NULL; + return 1; +} + +static int _address(lua_State* L) { + if (lua_type(L, 1) == LUA_TSTRING) { + uint16_t port = 0; + if (lua_gettop(L) >= 2) { + port = luaW_checkport(L, 2); + } + + struct address* a = luaW_newaddress(L); + if (parse_address(a, lua_tostring(L, 1), port) == -1) { + luaL_error(L, "bad address: %s", lua_tostring(L, 1)); + } + + return 1; + } + + else { + luaL_checkudata(L, 1, "sockaddr"); + lua_pushvalue(L, 1); + return 1; + } +} + +static int _unpack_address(lua_State* L) { + size_t l; + const char* b = luaL_checklstring(L, 1, &l); + + struct address* a = luaW_newaddress(L); + + switch (b[0]) { + case 0x04: + if (l < 1+4+2) { + luaL_error(L, "bad address"); + } + l = 1+4+2; + + a->sa_family = a->in4.sin_family = AF_INET; + memcpy(&a->in4.sin_addr, b+1, 4); + memcpy(&a->in4.sin_port, b+1+4, 2); + break; + + case 0x06: + if (l < 1+16+2) { + luaL_error(L, "bad address"); + } + l = 1+16+2; + + a->sa_family = a->in6.sin6_family = AF_INET6; + memcpy(&a->in6.sin6_addr, b+1, 16); + memcpy(&a->in6.sin6_port, b+1+16, 2); + break; + + default: + luaL_error(L, "bad packed address: %d", (int)b[0]); + }; + + lua_pushinteger(L, l); + return 2; +} + +static int _orchid(lua_State* L) { + size_t cid_sz; + const char* cid = luaL_checklstring(L, 1, &cid_sz); + + size_t l; + const char* m = luaL_checklstring(L, 2, &l); + + uint16_t port = luaW_checkport(L, 3); + + struct address* a = luaW_newaddress(L); + orchid(a, cid, cid_sz, m, l, port); + + return 1; +} + +static int _set_address_port(lua_State* L) { + struct address* a = luaL_checkudata(L, 1, "address"); + uint16_t port = luaW_checkport(L, 2); + + struct address* an = luaW_newaddress(L); + + memcpy(an, a, sizeof(struct address)); + + switch (a->sa_family) { + case AF_INET: an->in4.sin_port = htons(port); break; + case AF_INET6: an->in6.sin6_port = htons(port); break; + }; + + return 1; +} + +static int _socket_udp(lua_State* L) { + struct address* a = luaL_checkudata(L, 1, "address"); + int s = socket_udp(a); + if (s == -1) { + luaL_error(L, "socket error: %s", strerror(errno)); + } + luaW_pushfd(L, s); + return 1; +} + +static int _socket_raw_udp(lua_State* L) { + const char* proto = luaL_checkstring(L, 1); + + int hdrincl = 0; + sa_family_t sa_family; + if (strcmp(proto, "ip4") == 0) { + sa_family = AF_INET; + } else if (strcmp(proto, "ip6") == 0) { + sa_family = AF_INET6; + } else if (strcmp(proto, "ip4_hdrincl") == 0) { + sa_family = AF_INET; + hdrincl = 1; + } else { + return luaL_error(L, "unknown protocol: %s", proto); + } + + int s = socket_raw_udp(sa_family, hdrincl); + if (s == -1) { + luaL_error(L, "socket error: %s", strerror(errno)); + } + + luaW_pushfd(L, s); + return 1; +} + + +static int _select(lua_State* L) { + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, 2, LUA_TTABLE); + luaL_checktype(L, 3, LUA_TTABLE); + + struct timeval* pval = NULL; + + if (lua_gettop(L) == 4 && lua_type(L, 4) != LUA_TNIL) { + lua_Number timeout = luaL_checknumber(L, 4); + pval = alloca(sizeof(struct timeval)); + pval->tv_sec = (time_t)timeout; + pval->tv_usec = (timeout-pval->tv_sec)*1000000; + } + + fd_set fds[3]; + + lua_Integer nfds = 0; + + for (int i=0; i<3; ++i) { + FD_ZERO(&fds[i]); + int l = luaL_len(L, i+1); + for (lua_Integer j=1; j<=l; ++j) { + lua_geti(L, i+1, j); + int ok; + lua_Integer s = lua_tointegerx(L, -1, &ok); + if (!ok) { + luaL_error(L, "bad file descriptor type (integer expected, got %s)", + lua_typename(L, lua_type(L, -1)) + ); + } + lua_pop(L, 1); + + FD_SET(s, &fds[i]); + + if (s > nfds) { + nfds = s; + } + } + } + + if (select(nfds+1, &fds[0], &fds[1], &fds[2], pval) == -1) { + luaL_error(L, "select(): %s", strerror(errno)); + } + + for (int i=0; i<3; ++i) { + lua_newtable(L); + + int l = luaL_len(L, i+1); + for (lua_Integer j=1; j<=l; ++j) { + lua_geti(L, i+1, j); + lua_Integer s = lua_tointegerx(L, -1, NULL); + lua_pop(L, 1); + + if (FD_ISSET(s, &fds[i])) { + lua_pushboolean(L, 1); + lua_seti(L, -2, s); + } + } + } + + return 3; +} + +static int _send(lua_State* L) { + int flags = 0; + size_t l; + int fd = luaW_getfd(L, 1); + const char* m = luaL_checklstring(L, 2, &l); + if (lua_type(L, 3) == LUA_TNUMBER) { + flags = luaL_checkinteger(L, 3); + } + + int r = send(fd, m, l, flags); + + lua_pushinteger(L, r); + return 1; +} + +static int _sendto(lua_State* L) { + int flags = 0; + struct address* a = NULL; + size_t l; + int fd = luaW_getfd(L, 1); + const char* m = luaL_checklstring(L, 2, &l); + if (lua_type(L, 3) == LUA_TNUMBER) { + flags = luaL_checkinteger(L, 3); + a = luaL_checkudata(L, 4, "address"); + } else { + a = luaL_checkudata(L, 3, "address"); + } + + int r = sendto(fd, m, l, flags, &a->in, address_len(a)); + + lua_pushinteger(L, r); + return 1; +} + +static int _sendto_raw_udp(lua_State* L) { + int fd4 = luaW_getfd(L, 1); + int fd6 = luaW_getfd(L, 2); + size_t l; + const char* m = luaL_checklstring(L, 3, &l); + uint16_t src_port = luaW_checkport(L, 4); + struct address* dst_addr = luaL_checkudata(L, 5, "address"); + + if (l >= 0x10000 - UDP_HDRLEN) { + luaL_error(L, "packet too long"); + } + + // prepare packet + void* pkt = malloc(UDP_HDRLEN+l); + +#define UDPHDR ((struct udphdr*)(pkt+0)) + UDPHDR->uh_sport = htons(src_port); + UDPHDR->uh_dport = htons(address_port(dst_addr)); + UDPHDR->uh_ulen = htons(UDP_HDRLEN+l); + UDPHDR->uh_sum = 0x0000; +#undef UDPHDR + + memcpy(pkt+UDP_HDRLEN, m, l); + + int fd; + switch (dst_addr->sa_family) { + case AF_INET: fd = fd4; break; + case AF_INET6: fd = fd6; break; + default: return luaL_error(L, "bad address family"); + }; + +#if 0 + printf("sendto(fd=%d, #m=%d, #addr=%d)\n", fd, (int)(8+l), address_len(dst_addr)); + + printf(" \t"); + for (int i=0; i<16; ++i) { + printf("%.2x ", i); + } + for (int i=0; i<(int)(8+l); ++i) { + if (i%16 == 0) { + printf("\n%.4x\t", i); + } + + printf("%.2x ", (int)((uint8_t*)pkt)[i]); + } + printf("\n"); +#endif + + const int flags = 0; + ssize_t r = sendto(fd, pkt, 8+l, flags, &dst_addr->in, address_len(dst_addr)); + + free(pkt); + + if (r < 0) { + lua_pushboolean(L, 0); + lua_pushstring(L, strerror(errno)); + return 2; + } else if ((size_t)r != 8+l) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "truncated send: %dB != %dB", r, 8+l); + return 2; + } else { + lua_pushboolean(L, 1); + return 1; + } +} + +static int _recv(lua_State* L) { + int fd = luaW_getfd(L, 1); + int l = luaL_checkinteger(L, 2); + int flags = 0; + if (lua_gettop(L) == 3) { + flags = luaL_checkinteger(L, 3); + } + + luaL_Buffer b; + char* m = luaL_buffinitsize(L, &b, l); + ssize_t r = recv(fd, m, l, flags); + + if (r < 0 && ( + errno == EAGAIN || + errno == ECONNRESET)) { + return 0; + } + + if (r < 0) { + luaL_error(L, "recv() failed: %s (%d)", strerror(errno), errno); + } + + luaL_pushresultsize(&b, r); + return 1; +} + +static int _recvfrom(lua_State* L) { + int fd = luaW_getfd(L, 1); + int l = luaL_checkinteger(L, 2); + int flags = 0; + if (lua_gettop(L) == 3) { + flags = luaL_checkinteger(L, 3); + } + + luaL_Buffer b; + char* m = luaL_buffinitsize(L, &b, l); + struct address* a = luaW_newaddress(L); + socklen_t al = sizeof(a->in6); + ssize_t r = recvfrom(fd, m, l, flags, &a->in, &al); + + if (r < 0 && errno == EAGAIN) { + return 0; + } + + if (r < 0) { + luaL_error(L, "recvfrom() failed: %s", strerror(errno)); + } + + if (r >= 0) { + switch (al) { + case sizeof(struct sockaddr_in): a->sa_family = AF_INET; break; + case sizeof(struct sockaddr_in6): a->sa_family = AF_INET6; break; + default: luaL_error(L, "bad address size: %d", al); + }; + } + + luaL_pushresultsize(&b, r); + lua_insert(L, -2); + return 2; +} + +static int _sendto_raw_wg(lua_State* L) { + int fd4 = luaW_getfd(L, 1); + size_t l; + const uint8_t* m = (const uint8_t*)luaL_checklstring(L, 2, &l); + struct address* src_addr = luaL_checkudata(L, 3, "address"); + uint16_t wg_port = luaW_checkport(L, 4); + + // fd4 should be opened with wh.socket_raw_udp("ip4_hdrincl") + + // src_addr must be IP4 + if (src_addr->sa_family != AF_INET) { + luaL_error(L, "bad address"); + return 0; + } + +#if 0 // XXX + // src_addr must be from 127.0.0.0/8 + if ((ntohl(src_addr->in4.sin_addr.s_addr) & 0xff000000) != 0x7f000000) { + luaL_error(L, "address is not loopback"); + return 0; + } +#endif + + // packet must be wireguard + if (m[0] > 4 || m[1] != 0 || m[2] != 0 || m[3] != 0) { + luaL_error(L, "not a wireguard packet."); + return 0; + } + + // 0x10000 - 0x08 (UDP) - 0x14 (IP) + if (l >= 0x10000 - IP4_HDRLEN - UDP_HDRLEN) { + luaL_error(L, "packet too long"); + } + + // prepare packet + void* pkt = malloc(IP4_HDRLEN+UDP_HDRLEN+l); + + struct sockaddr_in dst_addr; + dst_addr.sin_family = AF_INET; + dst_addr.sin_addr.s_addr = htonl(0x7f000001); + dst_addr.sin_port = htons(wg_port); + +#define IPHDR ((struct ip*)(pkt+0)) + memset(IPHDR, 0, sizeof(struct ip)); + IPHDR->ip_hl = IP4_HDRLEN/sizeof(uint32_t); + IPHDR->ip_v = 4; + IPHDR->ip_tos = 0; + IPHDR->ip_len = htons(IP4_HDRLEN+UDP_HDRLEN); + IPHDR->ip_id = 0; + IPHDR->ip_off = 0; + IPHDR->ip_ttl = 255; + IPHDR->ip_p = IPPROTO_UDP; + memcpy(&IPHDR->ip_src, &src_addr->in4.sin_addr, 4); + memcpy(&IPHDR->ip_dst, &dst_addr.sin_addr, 4); + IPHDR->ip_sum = 0; + //IPHDR->ip_sum = checksum_ip(IPHDR, IP4_HDRLEN); +#undef IPHDR + +#define UDPHDR ((struct udphdr*)(pkt+IP4_HDRLEN)) + UDPHDR->uh_sport = htons(address_port(src_addr)); + UDPHDR->uh_dport = dst_addr.sin_port; + UDPHDR->uh_ulen = htons(UDP_HDRLEN+l); + UDPHDR->uh_sum = 0x0000; +#undef UDPHDR + + memcpy(pkt+IP4_HDRLEN+UDP_HDRLEN, m, l); + +#if 0 + printf("sendto_raw_wg(fd=%d, #m=%d, #addr=%d)\n", fd, (int)(8+l), address_len(dst_addr)); + + printf(" \t"); + for (int i=0; i<16; ++i) { + printf("%.2x ", i); + } + for (int i=0; i<(int)(8+l); ++i) { + if (i%16 == 0) { + printf("\n%.4x\t", i); + } + + printf("%.2x ", (int)((uint8_t*)pkt)[i]); + } + printf("\n"); +#endif + + const int flags = 0; + ssize_t r = sendto(fd4, pkt, IP4_HDRLEN+UDP_HDRLEN+l, flags, (struct sockaddr*)&dst_addr, sizeof(dst_addr)); + + free(pkt); + + if (r < 0) { + lua_pushboolean(L, 0); + lua_pushstring(L, strerror(errno)); + return 2; + } else if ((size_t)r != 0x1c+l) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "truncated send: %dB != %dB", r, 8+l); + return 2; + } else { + lua_pushboolean(L, 1); + return 1; + } +} + +static int luaW_checkb64variant(lua_State* L, int idx) { + int variant = sodium_base64_VARIANT_URLSAFE_NO_PADDING; + + int t = lua_type(L, idx); + if (t != -1 && t != LUA_TNIL) { + const char* s = luaL_checkstring(L, idx); + if (strcmp(s, "wh") == 0) { + variant = sodium_base64_VARIANT_URLSAFE_NO_PADDING; + } else if (strcmp(s, "wg") == 0) { + variant = sodium_base64_VARIANT_ORIGINAL; + } + } + + return variant; +} + +static int _tob64(lua_State* L) { + size_t l; + const char* m = luaL_checklstring(L, 1, &l); + int variant = luaW_checkb64variant(L, 2); + + size_t b64l = sodium_base64_ENCODED_LEN(l, variant); + luaL_Buffer b; + char* b64 = luaL_buffinitsize(L, &b, b64l); + sodium_bin2base64(b64, b64l, (const void*)m, l, variant); + + luaL_pushresultsize(&b, strlen(b64)); + return 1; +} + +static int _fromb64(lua_State* L) { + size_t b64l; + const char* b64 = luaL_checklstring(L, 1, &b64l); + int variant = luaW_checkb64variant(L, 2); + + luaL_Buffer b; + void* bin = luaL_buffinitsize(L, &b, b64l); + + size_t l = b64l; + if (sodium_base642bin(bin, l, b64, b64l, NULL, &l, NULL, variant) != 0) { + luaL_error(L, "invalid base64: len:%d", b64l); + } + + luaL_pushresultsize(&b, l); + return 1; +} + +static int _randombytes(lua_State* L) { + int sz = luaL_checkinteger(L, 1); + if (sz < 0) { + luaL_error(L, "arg #1 is not positive"); + } + + luaL_Buffer b; + void* buf = luaL_buffinitsize(L, &b, sz); + randombytes_buf(buf, sz); + luaL_pushresultsize(&b, sz); + + return 1; +} + +static int _packet(lua_State* L) { + size_t l; + void* src_wg_sk = luaW_checksecret(L, 1, crypto_scalarmult_curve25519_BYTES); + + uint8_t src_wg_pk[crypto_scalarmult_curve25519_BYTES]; + if (crypto_scalarmult_base(src_wg_pk, src_wg_sk)) { + luaL_error(L, "bad private key"); + } + + const void* dst_wg_pk = luaL_checklstring(L, 2, &l); + if (l != crypto_sign_ed25519_PUBLICKEYBYTES) { + luaL_error(L, "bad public key"); + } + + luaL_checktype(L, 3, LUA_TBOOLEAN); + uint64_t is_nated = lua_toboolean(L, 3) ? 1 : 0; + + const void* m = luaL_checklstring(L, 4, &l); + uint64_t flags_time_b = 0; + flags_time_b |= (htobe64(now_seconds()) & packet_flags_TIMEMASK) << packet_flags_TIMESHIFT; + flags_time_b |= (is_nated & packet_flags_DIRECTMASK) << packet_flags_DIRECTSHIFT; + + size_t sz = packet_size(l); + luaL_Buffer b; + void* pkt = luaL_buffinitsize(L, &b, sz); + + memcpy(packet_hdr(pkt), wh_pkt_hdr, sizeof(wh_pkt_hdr)); + memcpy(packet_src(pkt), src_wg_pk, crypto_scalarmult_curve25519_BYTES); + memcpy(packet_flags_time(pkt), &flags_time_b, sizeof(flags_time_b)); + memcpy(packet_body(pkt), m, l); + + if (auth_packet(pkt, l, src_wg_sk, dst_wg_pk)) { + luaL_error(L, "auth failed"); + } + + luaL_pushresultsize(&b, sz); + + return 1; +} + +static int _open_packet(lua_State* L) { + void* dst_wg_sk = luaW_checksecret(L, 1, crypto_scalarmult_curve25519_BYTES); + size_t sz; + const void* pkt = luaL_checklstring(L, 2, &sz); + + if (verify_packet(pkt, sz, dst_wg_sk)) { + return 0; + } + + uint64_t flags_time_s; + memcpy(&flags_time_s, packet_flags_time(pkt), sizeof(flags_time_s)); + uint64_t time_s = be64toh((flags_time_s >> packet_flags_TIMESHIFT) & packet_flags_TIMEMASK); + uint64_t is_nated = (flags_time_s >> packet_flags_DIRECTSHIFT) & packet_flags_DIRECTMASK; + + lua_pushlstring(L, packet_src(pkt), crypto_scalarmult_curve25519_BYTES); + lua_pushboolean(L, is_nated); + lua_pushinteger(L, time_s); + lua_pushlstring(L, packet_body(pkt), sz-packet_size(0)); + + return 4; +} + +static int _now(lua_State* L) { + struct timespec now; + clock_gettime(CLOCK_REALTIME, &now); + lua_Number n = now.tv_sec + (double)now.tv_nsec / 1.0e9; + lua_pushnumber(L, n); + return 1; +} + +static int _bid(lua_State* L) { + size_t sz1, sz2; + const char* s1 = luaL_checklstring(L, 1, &sz1); + const char* s2 = NULL; + + if (lua_gettop(L) == 2) { + s2 = luaL_checklstring(L, 2, &sz2); + } + + if (s2 && sz1 != sz2) { + luaL_error(L, "not same length."); + } + +#define sz sz1 + assert(sz % sizeof(uint32_t) == 0); + unsigned int i; + unsigned int r = 0; + for(i=0; ifds[1], "\x2a", 1) < 0) { + luaL_error(L, "write() failed: %s", strerror(errno)); + } + + return 0; +} + +static int _clear_pipe_event(lua_State* L) { + struct pipe_event* pe = luaW_checkptr(L, 1, "pipe_event"); + + char buf[128]; + if (read(pe->fds[0], buf, sizeof(buf)) < 0) { + luaL_error(L, "read() failed: %s", strerror(errno)); + } + + return 0; +} + +static void _pipe_event_delete(void* ud) { + struct pipe_event* pe = ud; + + close(pe->fds[0]); + close(pe->fds[1]); + free(pe); +} + +static int _close_pipe_event(lua_State* L) { + struct pipe_event* pe = luaW_ownptr(L, 1, "pipe_event"); + + _pipe_event_delete(pe); + return 0; +} + +static int _pipe_event_fd(lua_State* L) { + struct pipe_event* pe = luaW_checkptr(L, 1, "pipe_event"); + lua_pushinteger(L, pe->fds[0]); + return 0; +} + +static int _pipe_event(lua_State* L) { + int fds[2]; + if (pipe(fds)) { + luaL_error(L, "pipe() failed: %s", strerror(errno)); + } + + struct pipe_event* pe = malloc(sizeof(struct pipe_event)); + pe->fds[0] = fds[0]; + pe->fds[1] = fds[1]; + luaW_pushptr(L, "pipe_event", pe); + + return 1; +} + +static int _todate(lua_State* L) { + lua_Number nf = luaL_checknumber(L, 1); + time_t n = nf; // cast + + luaL_Buffer b; + size_t sz = sizeof"1991-08-25T20:57:08Z"; + char* buf = luaL_buffinitsize(L, &b, sz); + sz = strftime(buf, sz, "%FT%TZ", gmtime(&n)); + luaL_pushresultsize(&b, sz); + return 1; +} + +static int _sniff(lua_State* L) { + const char* interface = luaL_checkstring(L, 1); + const char* direction_s = luaL_checkstring(L, 2); + const char* proto_s = luaL_checkstring(L, 3); + const char* expr = lua_tostring(L, 4); + + pcap_direction_t direction; + if (strcmp(direction_s, "in") == 0) { + direction = PCAP_D_IN; + } else if (strcmp(direction_s, "out") == 0) { + direction = PCAP_D_OUT; + } else if (strcmp(direction_s, "inout") == 0) { + direction = PCAP_D_INOUT; + } else { + luaL_error(L, "bad direction"); + return 0; + } + + enum sniff_proto proto; + if (strcmp(proto_s, "wg") == 0) { + proto = SNIFF_PROTO_WG; + } else if (strcmp(proto_s, "wh") == 0) { + proto = SNIFF_PROTO_WH; + } else { + luaL_error(L, "unknown proto"); + return 0; + } + + pcap_t* h = sniff(interface, direction, proto, expr); + + if (!h) { + luaL_error(L, "pcap init error"); + } + + luaW_pushptr(L, "pcap", h); + return 1; +} + +static int _get_pcap(lua_State* L) { + pcap_t* h = luaW_checkptr(L, 1, "pcap"); + + int fd = pcap_get_selectable_fd(h); + if (fd == PCAP_ERROR) { + luaL_error(L, "pcap_get_selectable() failed: %s", pcap_geterr(h)); + } + + luaW_pushfd(L, fd); + + struct timeval* tv = pcap_get_required_select_timeout(h); + + if (tv) { + lua_Number timeout = tv->tv_sec + (lua_Number)tv->tv_usec / 1e6; + lua_pushnumber(L, timeout); + } else { + lua_pushnil(L); + } + + return 2; +} + +static int _pcap_next_udp(lua_State* L) { + pcap_t* h = luaW_checkptr(L, 1, "pcap"); + + struct pcap_pkthdr* hdr = NULL; + const u_char* data = NULL; + int r = pcap_next_ex(h, &hdr, &data); + + if (r == PCAP_ERROR) { + return luaL_error(L, "pcap_next_ex() failed: %s", pcap_geterr(h)); + } + + assert (hdr); + + if (!data) { + return 0; + } + + if (hdr->caplen != hdr->len) { + return 0; + } + + const size_t pcap_hdr_sz = 16; + if (hdr->len < pcap_hdr_sz) { + return 0; + } + + uint16_t proto; + memcpy(&proto, data+14, sizeof(proto)); + proto = ntohs(proto); + + if (proto != ETHERTYPE_IP) { + return 0; + } + + const void* m; + const void* d = data + pcap_hdr_sz; + size_t l = hdr->len - pcap_hdr_sz; + struct address* src = luaW_newaddress(L); + struct address* dst = luaW_newaddress(L); + if (ip4_to_udp(d, &m, &l, src, dst) == -1) { + printf("FAILED! ip4_to_udp()\n"); + return 0; + } + + lua_pushlstring(L, m, l); + return 3; +} + +static int _close_pcap(lua_State* L) { + pcap_close(luaW_ownptr(L, 1, "pcap")); + return 0; +} + +static inline int _parse_wgkey(void* bin, size_t* pbinl, const char* b64) { + if (strcmp(b64, "(none)") == 0) { + *pbinl = 0; + return 0; + } else { + return sodium_base642bin( + (void*)bin, *pbinl, + b64, strlen(b64), + NULL, pbinl, + NULL, sodium_base64_VARIANT_ORIGINAL + ); + } +} + +static int _burnsk(lua_State* L) { + void* sk = luaW_ownsecret(L, 1, crypto_scalarmult_curve25519_BYTES); + luaW_freesecret(sk); + + return 0; +} + +static int _revealsk(lua_State* L) { + void* sk = luaW_ownsecret(L, 1, crypto_scalarmult_curve25519_BYTES); + lua_pushlstring(L, sk, crypto_scalarmult_curve25519_BYTES); + luaW_freesecret(sk); + + return 1; +} + +static const char* _confpath(void) { + const char* confpath = getenv(WH_ENV_CONFPATH); + if (!confpath) { + confpath = WH_DEFAULT_CONFPATH; + } + + return confpath; +} + +static void _pushconfpath(lua_State* L, const char* name) { + const char* confpath = _confpath(); + + luaL_Buffer b; + luaL_buffinit(L, &b); + luaL_addstring(&b, confpath); + size_t confpath_len = strlen(confpath); + if (confpath_len == 0 || confpath[confpath_len-1] != '/') { + luaL_addchar(&b, '/'); + } + if (name) { + luaL_addstring(&b, name); + } + luaL_pushresult(&b); + _expanduser(L); +} + +static int _removeconf(lua_State* L) { + const char* name = luaL_checkstring(L, 1); + + _pushconfpath(L, name); + const char* conf_filepath = lua_tostring(L, -1); + + if (unlink(conf_filepath) < 0 && errno != ENOENT) { + luaL_error(L, "cannot remove conf '%s': %s", conf_filepath, strerror(errno)); + } + + return 0; +} + +static int _writeconf(lua_State* L) { + const char* name = luaL_checkstring(L, 1); + size_t conf_sz; + + if (lua_gettop(L) == 2 && lua_type(L, 2) == LUA_TNIL) { + return _removeconf(L); + } + + const char* conf = luaL_checklstring(L, 2, &conf_sz); + const char* confpath = _confpath(); + + { + luaL_Buffer b; + luaL_buffinit(L, &b); + luaL_addstring(&b, "mkdir -p \""); + luaL_addstring(&b, confpath); + luaL_addstring(&b, "\""); + luaL_pushresult(&b); + const char* cmd = lua_tostring(L, -1); + if (system(cmd) < 0) { + luaL_error(L, "command '%s' failed: %s", cmd, strerror(errno)); + } + lua_pop(L, 1); + } + + { + _pushconfpath(L, name); + const char* conf_filepath = lua_tostring(L, -1); + + FILE* fh = fopen(conf_filepath, "wb"); + if (!fh) { + luaL_error(L, "cannot open conf '%s': %s", conf_filepath, strerror(errno)); + } + if (fwrite(conf, conf_sz, 1, fh) != 1) { + luaL_error(L, "cannot write conf '%s': %s", conf_filepath, strerror(errno)); + } + fclose(fh), fh=NULL; + } + + return 0; +} + +static int _listconf(lua_State* L) { + DIR* dirp; + + _pushconfpath(L, NULL); + const char* confpath = lua_tostring(L, -1); + if ((dirp = opendir(confpath)) == NULL) { + luaL_error(L, "opendir() failed: %s", strerror(errno)); + } + + lua_newtable(L); + struct dirent* dp; + while ((dp = readdir(dirp))) { + if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) { + continue; + } + + lua_pushstring(L, dp->d_name); + lua_seti(L, -2, luaL_len(L, -2)+1); + } + + closedir(dirp), dirp = NULL; + return 1; + +} + +static void _pushfile(lua_State* L, const char* filepath) { + FILE* fh = fopen(filepath, "rb"); + + if (!fh && errno == ENOENT) { + lua_pushnil(L); + return; + } + + else if (!fh) { + luaL_error(L, "cannot open file '%s': %s", filepath, strerror(errno)); + } + + luaL_Buffer b; + long l = -1; + int did_read = 0; + if ( + fseek(fh, 0, SEEK_END) >= 0 && + (l = ftell(fh)) >= 0 && + fseek(fh, 0, SEEK_SET) >= 0 + ) { + char* p = luaL_buffinitsize(L, &b, l); + did_read = fread(p, 1, l, fh) == (size_t)l; + } + + fclose(fh); + + if (!did_read) { + luaL_error(L, "cannot read file '%s': %s", filepath, strerror(errno)); + } + + assert(l >= 0); + luaL_pushresultsize(&b, l); +} + + +static int _readconf(lua_State* L) { + const char* name = luaL_checkstring(L, 1); + + _pushconfpath(L, name); + const char* conf_filepath = lua_tostring(L, -1); + + _pushfile(L, conf_filepath); + return 1; +} + +static int _ipc_prepare(lua_State* L) { + if (ipc_prepare() < 0) { + luaL_error(L, "prepare ipc failed: %s", strerror(errno)); + } + + return 0; +} + +static int _ipc_connect(lua_State* L) { + const char* interface = luaL_checkstring(L, 1); + + int sock; + if ((sock = ipc_connect(interface)) < 0) { + switch (errno) { + case ENOENT: return 0; + default: luaL_error(L, "connect ipc failed(): %s", strerror(errno)); + }; + } + + luaW_pushfd(L, sock); + return 1; +} + +static int _ipc_unlink(lua_State* L) { + const char* interface = lua_tostring(L, lua_upvalueindex(1)); + assert(interface); + + lua_pushboolean(L, ipc_unlink(interface) >= 0); + return 1; +} + +static int _ipc_bind(lua_State* L) { + const char* interface = luaL_checkstring(L, 1); + luaL_checktype(L, 2, LUA_TBOOLEAN); + int force = lua_toboolean(L, 2); + + int sock; + if ((sock = ipc_bind(interface, force)) < 0) { + luaL_error(L, "connect ipc failed(): %s", strerror(errno)); + } + + luaW_pushfd(L, sock); + lua_pushstring(L, interface); + lua_pushcclosure(L, _ipc_unlink, 1); + return 2; +} + +static int _ipc_accept(lua_State* L) { + int sock = luaW_getfd(L, 1); + + int new_sock; + if ((new_sock = ipc_accept(sock)) < 0) { + luaL_error(L, "connect ipc failed(): %s", strerror(errno)); + } + + luaW_pushfd(L, new_sock); + return 1; +} + +static int _ipc_list_cb(const char* name, void* ud) { + lua_State* L = ud; + + lua_pushstring(L, name); + lua_seti(L, -2, luaL_len(L, -2)+1); + + return 0; +} + +static int _ipc_list(lua_State* L) { + lua_newtable(L); + if (ipc_list(_ipc_list_cb, L)) { + luaL_error(L, "IPC list failed: %s", strerror(errno)); + } + return 1; +} + +static int _syslog_print(lua_State* L) { + int n = lua_gettop(L); /* number of arguments */ + int i; + lua_getglobal(L, "tostring"); + + luaL_Buffer b; + luaL_buffinit(L, &b); + + for (i=1; i<=n; i++) { + const char *s; + size_t l; + lua_pushvalue(L, -1); /* function to be called */ + lua_pushvalue(L, i); /* value to print */ + lua_call(L, 1, 1); + s = lua_tolstring(L, -1, &l); /* get result */ + if (s == NULL) + return luaL_error(L, "'tostring' must return a string to 'print'"); + if (i>1) luaL_addstring(&b, "\t"); + luaL_addlstring(&b, s, l); + lua_pop(L, 1); /* pop result */ + } + + luaL_pushresult(&b); + const char* l = lua_tostring(L, -1); + syslog(LOG_NOTICE, l); + + return 0; +} + +static int _daemon(lua_State* L) { + pid_t pid = fork(); + + if (pid < 0) { + luaL_error(L, "fork() failed: %s", strerror(errno)); + } + + if (pid > 0) { + exit(0); + } + + if (setsid() < 0) { + fprintf(stderr, "setsid() failed: %s\n", strerror(errno)); + exit(EXIT_FAILURE); + } + + //TODO: Implement a working signal handler */ + signal(SIGCHLD, SIG_IGN); + signal(SIGHUP, SIG_IGN); + + pid = fork(); + + if (pid < 0) { + fprintf(stderr, "second fork() failed: %s\n", strerror(errno)); + exit(EXIT_FAILURE); + } + + if (pid > 0) { + exit(0); + } + + /* Set new file permissions */ + umask(0); + + /* Change the working directory to the root directory */ + /* or another appropriated directory */ + //chdir("/"); + + /* Close all open file descriptors */ + for (int x = sysconf(_SC_OPEN_MAX); x>=0; x--) + { + close (x); + } + + lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS); + lua_pushstring(L, "DAEMON"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + + /* Open the log file */ + openlog ("wirehub", LOG_PID, LOG_DAEMON); + + lua_pushstring(L, "_print"); + lua_pushcfunction(L, _syslog_print); + lua_rawset(L, -3); + + return 0; +} + +static int _color_mode(lua_State* L) { + (void)L; + + static int mode = -1; + const char *var; + + lua_getglobal(L, "DAEMON"); + int is_daemon = lua_type(L, -1) == LUA_TBOOLEAN && lua_toboolean(L, -1); + lua_pop(L, 1); + + if (is_daemon) { + mode = 0; + } + + else { + var = getenv("WH_COLOR_MODE"); + + if (var && !strcmp(var, "always")) { + mode = 1; + } else if (var && !strcmp(var, "never")) { + mode = 0; + } else { + mode = isatty(1) ? 1 : 0; + } + } + + assert(mode != -1); + lua_pushboolean(L, mode); + return 1; +} + +static const luaL_Reg funcs[] = { + {"address", _address}, + {"bid", _bid}, + {"burnsk", _burnsk}, + {"clear_pipe_event", _clear_pipe_event}, + {"close", _close}, + {"close_pcap", _close_pcap}, + {"close_pipe_event", _close_pipe_event}, + {"color_mode", _color_mode}, + {"daemon", _daemon}, + {"fromb64", _fromb64}, + {"genkey", _genkey}, + {"get_pcap", _get_pcap}, + {"ipc_accept", _ipc_accept}, + {"ipc_bind", _ipc_bind}, + {"ipc_connect", _ipc_connect}, + {"ipc_prepare", _ipc_prepare}, + {"ipc_list", _ipc_list}, + {"listconf", _listconf}, + {"netdevs", _netdevs}, + {"now", _now}, + {"open_packet", _open_packet}, + {"orchid", _orchid}, + {"packet", _packet}, + {"pcap_next_udp", _pcap_next_udp}, + {"pipe_event", _pipe_event}, + {"pipe_event_fd", _pipe_event_fd}, + {"publickey", _publickey}, + {"randombytes", _randombytes}, + {"readconf", _readconf}, + {"readsk", _readsk}, + {"recv", _recv}, + {"recvfrom", _recvfrom}, + {"revealsk", _revealsk}, + {"sandbox", _sandbox}, + {"select", _select}, + {"send", _send}, + {"sendto", _sendto}, + {"sendto_raw_udp", _sendto_raw_udp}, + {"sendto_raw_wg", _sendto_raw_wg}, + {"set_address_port", _set_address_port}, + {"set_pipe_event", _set_pipe_event}, + {"sniff", _sniff}, + {"socket_raw_udp", _socket_raw_udp}, + {"socket_udp", _socket_udp}, + {"tob64", _tob64}, + {"todate", _todate}, + {"unpack_address", _unpack_address}, + {"version", luaW_version}, + {"wgkey", _wgkey}, + {"workbit", _workbit}, + {"writeconf", _writeconf}, + {"xor", _xor}, + {NULL, NULL}, +}; + +static void _pcap_close(void* ud) { + fprintf(stderr, "warning: pcap handler %p not closed.\n", ud); + pcap_close((pcap_t*)ud); +} + +LUAMOD_API int luaopen_whcore(lua_State* L) { + luaL_checkversion(L); + + // initialize sodium + if (sodium_init() == -1) { + luaL_error(L, "sodium init failed."); + } + + luaL_newlib(L, funcs); + +#if WH_ENABLE_MINIUPNPC + assert(luaopen_whupnp(L) == 1); + lua_setfield(L, -2, "upnp"); +#endif + + assert(luaopen_worker(L) == 1); + lua_setfield(L, -2, "worker"); + + assert(luaopen_tun(L) == 1); + lua_setfield(L, -2, "tun"); + + assert(luaopen_wg(L) == 1); + lua_setfield(L, -2, "wg"); + + lua_newtable(L); + lua_setfield(L, LUA_REGISTRYINDEX, "wh_fds"); + + luaW_declptr(L, "secret", sodium_free); + luaW_declptr(L, "pcap", _pcap_close); + luaW_declptr(L, "pipe_event", _pipe_event_delete); + + return 1; +} + diff --git a/src/core/whupnplib.c b/src/core/whupnplib.c new file mode 100644 index 0000000..4e2e87f --- /dev/null +++ b/src/core/whupnplib.c @@ -0,0 +1,253 @@ +#include "luawh.h" + +#if WH_ENABLE_MINIUPNPC + +#include +#include +#include + +static const char* mt = "device_igd"; + +struct device_igd { + struct UPNPUrls urls; + struct IGDdatas data; + int type; +}; + +static int _discover_igd(lua_State* L) { + int delay_ms = luaL_checknumber(L, 1) * 1000; + int upnp_err = 0; + struct UPNPDev* devices = upnpDiscover(delay_ms, NULL, NULL, UPNP_LOCAL_PORT_ANY, 0, 2, &upnp_err); + + if (upnp_err != 0) { + luaL_error(L, "upnpDiscover failed(): %d", upnp_err); + } + + if (!devices) { + return 0; + } + + struct device_igd* d = lua_newuserdata(L, sizeof(struct device_igd)); + luaL_newmetatable(L, mt); + lua_setmetatable(L, -2); + + char lanaddr[64] = ""; + int ret = UPNP_GetValidIGD(devices, &d->urls, &d->data, lanaddr, sizeof(lanaddr)); + + if (ret < 0) { + luaL_error(L, "UPNP_GetValidIGD failed(): %d", ret); + } + + if (ret == 0) { + return 0; + } + + d->type = ret; + + if (lanaddr[0]) { + lua_pushstring(L, lanaddr); + } else { + lua_pushnil(L); + } + + lua_pushstring(L, d->urls.controlURL); + + return 3; +} + +static int _external_ip(lua_State* L) { + struct device_igd* d = luaL_checkudata(L, 1, mt); + char ip[40]; + + int r = UPNP_GetExternalIPAddress(d->urls.controlURL, d->data.first.servicetype, ip); + + if (r != UPNPCOMMAND_SUCCESS) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "GetExternalIPAddress() failed: %d (%s)", + r, strupnperror(r) + ); + return 2; + } + + lua_pushboolean(L, 1); + lua_pushstring(L, ip); + return 2; +} + +static int _list_redirects(lua_State* L) { + struct device_igd* d = luaL_checkudata(L, 1, mt); + + lua_newtable(L); + + for (int i = 0; ; ++i) { + char index[16]; + snprintf(index, sizeof(index), "%d", i); + + char int_client[40] = ""; + char int_port[16] = ""; + char ext_port[16] = ""; + char protocol[4] = ""; + char desc[80] = ""; + char enabled[16] = ""; + char host[64] = ""; + char duration[16] = ""; + + + int ret = UPNP_GetGenericPortMappingEntry( + d->urls.controlURL, d->data.first.servicetype, index, ext_port, + int_client, int_port, protocol, desc, enabled, host, duration + ); + + if (ret != 0) { + break; + } + + lua_newtable(L); + +#define lua_pushstringnumber(L, s) \ + do { \ + if (lua_stringtonumber(L, s) == 0) { \ + lua_pushnil(L); \ + } \ + } while(0) + + for (char* c=protocol; *c; ++c) { + *c = tolower(*c); + } + + lua_pushstring(L, protocol); + lua_setfield(L, -2, "protocol"); + + lua_pushstringnumber(L, ext_port); + lua_setfield(L, -2, "eport"); + + lua_pushstring(L, int_client); + lua_setfield(L, -2, "iaddr"); + + lua_pushstringnumber(L, int_port); + lua_setfield(L, -2, "iport"); + + lua_pushstring(L, desc); + lua_setfield(L, -2, "desc"); + + lua_pushstring(L, host); + lua_setfield(L, -2, "host"); + + lua_pushstringnumber(L, duration); + lua_setfield(L, -2, "lease"); + +#undef lua_pushstringnumber + + lua_seti(L, -2, i+1); + } + + return 1; +} + +static int _add_redirect(lua_State* L) { + struct device_igd* d = luaL_checkudata(L, 1, mt); + luaL_checktype(L, 2, LUA_TTABLE); + + lua_getfield(L, -1, "protocol"); + const char* protocol_const = lua_tostring(L, -1); + if (!protocol_const) { luaL_error(L, "field 'protocol' is nil"); } + char protocol[8]; + strncpy(protocol, protocol_const, sizeof(protocol)-1); + for (char* c=protocol; *c; ++c) { *c = toupper(*c); } + + lua_pop(L, 1); + + lua_getfield(L, -1, "eport"); + char ext_port[6]; + snprintf(ext_port, sizeof(ext_port), "%d", luaW_checkport(L, -1)); + lua_pop(L, 1); + + lua_getfield(L, -1, "iport"); + char int_port[6]; + snprintf(int_port, sizeof(int_port), "%d", luaW_checkport(L, -1)); + lua_pop(L, 1); + + lua_getfield(L, -1, "iaddr"); + const char* int_client = lua_tostring(L, -1); + lua_pop(L, 1); + + lua_getfield(L, -1, "lease"); + char lease[16]; + snprintf(lease, sizeof(lease), "%lld", lua_tointeger(L, -1)); + lua_pop(L, 1); + + lua_getfield(L, -1, "desc"); + const char* desc = lua_tostring(L, -1); + lua_pop(L, 1); + + if (!int_client) { luaL_error(L, "field 'iaddr', is nil"); } + + int ret = UPNP_AddPortMapping( + d->urls.controlURL, d->data.first.servicetype, ext_port, int_port, + int_client, desc, protocol, NULL, lease + ); + + if (ret != UPNPCOMMAND_SUCCESS) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "AddPortMapping(%s, %s, %s) failed: %d (%s)", + ext_port, int_port, int_client, ret, strupnperror(ret) + ); + return 2; + } + + lua_pushboolean(L, 1); + return 1; +} + +static int _remove_redirect(lua_State* L) { + struct device_igd* d = luaL_checkudata(L, 1, mt); + luaL_checktype(L, 2, LUA_TTABLE); + + lua_getfield(L, -1, "protocol"); + const char* protocol_const = lua_tostring(L, -1); + if (!protocol_const) { luaL_error(L, "field 'protocol' is nil"); } + char protocol[8]; + strncpy(protocol, protocol_const, sizeof(protocol)-1); + for (char* c=protocol; *c; ++c) { *c = toupper(*c); } + lua_pop(L, 1); + + lua_getfield(L, -1, "eport"); + char ext_port[6]; + snprintf(ext_port, sizeof(ext_port), "%d", luaW_checkport(L, -1)); + lua_pop(L, 1); + + int ret = UPNP_DeletePortMapping( + d->urls.controlURL, d->data.first.servicetype, ext_port, protocol, NULL + ); + + if (ret != UPNPCOMMAND_SUCCESS) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "DeletePortMapping(%s, %s) failed: %d (%s)", + ext_port, protocol, ret, strupnperror(ret) + ); + return 2; + } + + lua_pushboolean(L, 1); + return 1; +} + +static const luaL_Reg funcs[] = { + { "discover_igd", _discover_igd }, + { "external_ip", _external_ip }, + { "list_redirects", _list_redirects }, + { "add_redirect", _add_redirect }, + { "remove_redirect", _remove_redirect }, + { NULL, NULL} +}; + +LUAMOD_API int luaopen_whupnp(lua_State* L) { + luaL_checkversion(L); + + luaL_newlib(L, funcs); + + return 1; +} + +#endif // WH_ENABLE_MINIUPNPC + diff --git a/src/core/workerlib.c b/src/core/workerlib.c new file mode 100644 index 0000000..4d1cc32 --- /dev/null +++ b/src/core/workerlib.c @@ -0,0 +1,188 @@ +#include "luawh.h" +#include +#include +#include "serdes.h" + +#define MT "worker" + +struct pipefd { + int r_fd; + int w_fd; +}; + +struct worker { + struct pipefd req, resp; + volatile int running; + char* name; + lua_State* L; + pthread_t thread; +}; + +static int _tostring(lua_State* L) { + struct worker* w = luaW_toptr(L, 1, MT); + + if (w) { + lua_pushfstring(L, "worker* '%s': %p", w->name, w); + } else { + lua_pushstring(L, "worker*: "); + } + + return 1; +} + + +static void delete_worker(struct worker* w) { + if (w->running) { + char type_b = LUA_TNONE; + assert(write(w->req.w_fd, &type_b, 1) == 1); + pthread_join(w->thread, NULL); + } + + if (w->req.r_fd != -1) { close(w->req.r_fd), w->req.r_fd = -1; } + if (w->req.w_fd != -1) { close(w->req.w_fd), w->req.w_fd = -1; } + if (w->resp.r_fd != -1) { close(w->resp.r_fd), w->resp.r_fd = -1; } + if (w->resp.w_fd != -1) { close(w->resp.w_fd), w->resp.w_fd = -1; } + if (w->name) { free(w->name), w->name = NULL; } + if (w->L) { lua_close(w->L), w->L = NULL; } + if (w->name) { free(w->name), w->name = NULL; } + + free(w); +} + +static void delete_worker_pvoid(void* w) { + return delete_worker((struct worker*)w); +} + +static void* worker(void* ud) { + struct worker* w = ud; + lua_State* L = w->L; + + for (;;) { + lua_settop(L, 0); + + int success = luaW_readstack(L, w->req.r_fd) == 0; + + if (success && lua_gettop(L) < 2) { + break; + } + + success &= lua_pcall(L, lua_gettop(L)-2, LUA_MULTRET, 0) == LUA_OK; + lua_pushboolean(L, success); + lua_insert(L, 2); + luaW_writestack(L, 1, w->resp.w_fd); + } + + luaW_writestack(L, 1, w->resp.w_fd); + w->running = 0; + + return NULL; +} + +int luawh_pushworker(lua_State* L) { + const char* name = lua_tostring(L, 1); + + struct worker* w = calloc(1, sizeof(struct worker)); + assert(w); + + w->name = name ? strdup(name) : NULL; + w->req.r_fd = w->req.w_fd = -1; + w->resp.r_fd = w->resp.w_fd = -1; + + if (pipe((int*)&w->req)) { + delete_worker(w); + luaL_error(L, "pipe() failed: %s", strerror(errno)); + } + + if (pipe((int*)&w->resp)) { + delete_worker(w); + luaL_error(L, "pipe() failed: %s", strerror(errno)); + } + + w->L = luaL_newstate(); + luaL_openlibs(w->L); + + if (pthread_create(&w->thread, NULL, worker, w)) { + delete_worker(w); + luaL_error(L, "pthread_create() failed: %s", strerror(errno)); + } + + w->running = 1; + + luaW_pushptr(L, MT, w); + return 1; +} + +static int _pushwork(lua_State* L) { + struct worker* w = luaW_checkptr(L, 1, MT); + luaL_checktype(L, 2, LUA_TFUNCTION); + luaL_checktype(L, 3, LUA_TFUNCTION); + lua_pushvalue(L, 2); + int ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushinteger(L, ref); + lua_replace(L, 2); + luaW_writestack(L, 2, w->req.w_fd); + + return 0; +} + +static int _update(lua_State* L) { + struct worker* w = luaW_checkptr(L, 1, MT); + luaL_checktype(L, 2, LUA_TTABLE); + lua_pushinteger(L, w->resp.r_fd); + lua_seti(L, 2, luaL_len(L, 2)+1); + + return 0; +} + +static int _on_readable(lua_State* L) { + struct worker* w = luaW_checkptr(L, 1, MT); + luaL_checktype(L, 2, LUA_TTABLE); + + lua_pushinteger(L, w->resp.r_fd); + lua_gettable(L, 2); + + if (lua_toboolean(L, -1)) { + lua_settop(L, 1); + + if (luaW_readstack(L, w->resp.r_fd) != 0) { + luaL_error(L, "deserialization failed"); + } + + lua_pushvalue(L, 2); + lua_gettable(L, LUA_REGISTRYINDEX); + assert (lua_type(L, -1) == LUA_TFUNCTION); + lua_replace(L, 2); + lua_call(L, lua_gettop(L)-2, 0); + } + + return 0; +} + +LUAMOD_API int luaopen_worker(lua_State* L) { + luaW_declptr(L, MT, delete_worker_pvoid); + + luaL_getmetatable(L, MT); + lua_getfield(L, -1, "__index"); + + lua_pushcfunction(L, _pushwork); + lua_setfield(L, -2, "pcall"); + + lua_pushcfunction(L, _update); + lua_setfield(L, -2, "update"); + + lua_pushcfunction(L, _on_readable); + lua_setfield(L, -2, "on_readable"); + + lua_pop(L, 1); + + lua_pushcfunction(L, _tostring); + lua_setfield(L, -2, "__tostring"); + + lua_pop(L, 1); + + lua_pushcfunction(L, luawh_pushworker); + + return 1; +} + diff --git a/src/handlers.lua b/src/handlers.lua new file mode 100644 index 0000000..76b0b1d --- /dev/null +++ b/src/handlers.lua @@ -0,0 +1,321 @@ +local packet = require('packet') +local peer = require('peer') + +local auth = require('auth') +local kad = require('kad') +local nat = require('nat') +local search = require('search') + +local H = {} + +local function log_cmd(n, m, src, fmt, ...) + --printf("%s - %s - %d - %s", src, wh.todate(now), #m, string.format(fmt, ...)) + if n.log >= 2 then + printf("%s -> :%d: %s (%dB)", src, n.port, string.format(fmt, ...), #m) + end +end + +H[packet.cmds.ping] = function(n, m, src, via) + local arg = string.sub(m, 2, 2) + if arg == '\x00' then + arg = 'normal' + elseif arg == '\x01' then + arg = 'swapsrc' + elseif arg == '\x02' then + arg = 'direct' + end + + local body = string.sub(m, 3) + + --if src.lazy and (arg ~= "normal" or #body ~= 0) then return end + + -- ignore ping with body bigger than 8 + if #body > 8 then + printf("$(red)drop too big ping") + return + end + + -- by default, respond via same port, except if argument is 'swapsrc' + local echo + + if via == 'relay' then + echo = false + else + echo = via == 'echo' + if arg == 'swapsrc' then echo = not echo end + end + + if arg == 'direct' then + -- remove relay + src = {addr=src.addr, k=src.k} + end + + log_cmd(n, m, src, "$(yellow)ping$(reset)(%s, %s)", arg, wh.tob64(body)) + n:_sendto{ + dst=src, + m=packet.pong(n.port_echo, src.addr, body), + from_echo=echo + } +end + +H[packet.cmds.pong] = function(n, m, src) + --if src.lazy then return end + + local i = 2 + local l + + local src_port_echo, src_addr_echo, public_addr + + public_addr, l = wh.unpack_address(string.sub(m, i)) + i = i + l + + src_port_echo = string.unpack("!H", string.sub(m, i, i+1)) + src_addr_echo = wh.set_address_port(src.addr, src_port_echo) + i = i + 2 + + src.addr_echo = src_addr_echo + + local body = string.sub(m, i) + + log_cmd(n, m, src, "$(yellow)pong$(reset)(%s, port_echo=%s, self=%s)", wh.tob64(body), src_addr_echo, public_addr) + + kad.on_pong(n, src) + nat.on_pong(n, body, src) + search.on_pong(n, body, src) +end + +H[packet.cmds.search] = function(n, m, src) + local k = string.sub(m, 2) + log_cmd(n, m, src, "$(yellow)search$(reset)(%s)", n:key(k)) + + local closest = n.kad:kclosest(k, wh.KADEMILIA_K) + + n:_sendto{dst=src, m=packet.result(k, closest)} +end + +H[packet.cmds.result] = function(n, m, src) + if src.lazy then return end + + local pks = string.sub(m, 2, 33) + local closest = {} + local i = 34 + local l + + while i < #m do + local p = {} + + local flag = string.sub(m, i, i) + i = i + 1 + + do + p.k = string.sub(m, i, i+31); + i = i + 32 + + p.addr, l = wh.unpack_address(string.sub(m, i)) + i = i + l + end + + if flag == '\x01' then + local relay = {} + relay.k = string.sub(m, i, i+31) + i = i + 32 + relay.addr = wh.unpack_address(string.sub(m, i)) + i = i + l + + -- prefer own source + p.relay = n.kad:get(relay.k) or peer(relay) + + elseif flag == '\x02' then + p.relay = src + end + + closest[#closest+1] = peer(p) + end + + log_cmd(n, m, src, "$(yellow)result$(reset)(#%d)", #closest) + + search.on_result(n, pks, closest, src) +end + +H[packet.cmds.relay] = function(n, m, src) + local i = 2 + local l + + local dst_k = string.sub(m, i, i+32-1) + i = i + 32 + if #dst_k ~= 32 then + return + end + + local relayed_m = string.sub(m, i) + + local dst = n.kad:get(dst_k) + if not dst then + return + end + + -- XXX bandwidth limit + -- XXX whitelist management + -- XXX keep source in kademilia for some time as it is currently relaying + -- with dst + -- XXX SECURITY ISSUE! + -- Do not let anyone send a 'relayed' packet with any type of body. Just + -- sign the digest, not the entire body! (AEAD?) + -- OK for the POC + -- XXX make sure to keep dst in the kademilia table + + log_cmd(n, m, src, "$(blue)relay$(reset)(%s, %d)", n:key(dst_k), #relayed_m) + + if dst.relay then + printf("$(red)cannot forward relayed packet to relayed peer %s$(reset)", n:key(dst)) + return + end + + if not dst.addr then + printf("$(red)unknown route for relayed packet to %s$(reset)", n:key(dst)) + return + end + + n:_sendto{dst=dst, m=packet.relayed(src, relayed_m)} +end + +H[packet.cmds.relayed] = function(n, m, relay) + local i = 2 + local l + + local src_addr + src_addr, l = wh.unpack_address(string.sub(m, i)) + i = i + l + assert(src_addr) -- XXX + + local me = string.sub(m, i) + + log_cmd(n, m, relay, "$(blue)relayed$(reset)(%s, %d)", src_addr, #me) + + local src_k, time, src_is_nated + src_k, src_is_nated, time, m = wh.open_packet(n.sk, me) + + if m == nil then + printf("$(red)relayed packet dropped!$(reset)") + return + end + + -- XXX SECURITY ISSUE! + -- a node received a relayed must check that the source indeed sent the + -- relayed packet through the relay! + + local src_relay + if src_is_nated then + src_relay = relay + end + + return n:read( + m, + src_addr, + src_k, + src_is_nated, + time, + 'relay', + src_relay + ) +end + +H[packet.cmds.auth] = function(n, m, alias) + if not alias.alias then + -- not set as alias + return + end + + local me = string.sub(m, 2) + local src_k, src_is_nated, src_time, src_m = wh.open_packet(n.sk, me) + + if src_m == nil then + log_cmd(n, m, alias, "$(yellow)auth$(reset)($(red)bad$(reset))") + return + end + + if src_k ~= src_m then + log_cmd(n, m, alias, "$(yellow)auth$(reset)($(red)invalid$(reset))") + return + end + + log_cmd(n, m, alias, "$(yellow)auth$(reset)(%s)", n:key(src_k)) + + local src = n.kad:touch(src_k) + + auth.resolve_alias(n, alias, src) + + n:_sendto{ + dst=src, + m=packet.authed(alias.k), + } +end + +H[packet.cmds.authed] = function(n, m, src) + local alias_k = string.sub(m, 2) + + log_cmd(n, m, alias, "$(yellow)authed$(reset)(%s)", n:key(alias_k)) + + auth.on_authed(n, alias_k, src) +end + +H[packet.cmds.fragment] = function(n, m, src) + if not n.lo then + return + end + + local id, b = string.unpack(">HB", string.sub(m, 2, 4)) + local num = b&0x7f + local mf = b&0x80==0x80 + m = string.sub(m, 5) + + log_cmd(n, m, alias, "$(yellow)fragment$(reset)(id:%.4x, num:%d, mf:%s, m:%d)", id, num, mf, #m) + + if not src.fragments then + src.fragments = {} + end + local id = src.k .. string.pack("H", id) + local sess = src.fragments[id] + if not sess then + sess = { + deadline=now+wh.FRAGMENT_TIMEOUT, + id=id, + } + src.fragments[#src.fragments+1] = sess + table.sort(src.fragments, function(a, b) + return a.deadline < b.deadline + end) + src.fragments[sess.id] = sess + end + + sess[num+1] = m + + -- if last fragment was received + if not mf then + sess.last = num+1 + end + + if not sess.last or #sess ~= sess.last then + return + end + + local m = table.concat(sess) + + n.lo:recv_datagram(src, m) + + -- clean + src.fragments[id] = nil + for i, v in ipairs(src.fragments) do + if v == sess then + table.remove(src.fragments, i) + break + end + end + + --if #src.fragments == 0 then + -- src.fragments = nil + --end +end + +return H + diff --git a/src/handlers_ipc.lua b/src/handlers_ipc.lua new file mode 100644 index 0000000..b18457c --- /dev/null +++ b/src/handlers_ipc.lua @@ -0,0 +1,283 @@ +return function(n) + local H = {} + + function H.down(send, close) + n:stop() + send('OK\n') + return close() + end + + H['getent ([^%s]+)'] = function(send, close, k) + return n:getent(k, function(k) + if k then + local r = {} + r[#r+1] = wh.tob64(k) + + local p = n.kad:get(k) + if p and p.ip then + r[#r+1] = ' ' .. tostring(p.ip) + else + r[#r+1] = ' nil' + end + r[#r+1] = '\n' + send(table.concat(r)) + end + + return close() + end) + end + + H['dump ([^%s]+)'] = function(send, close, k) + return n:getent(k, function(k) + if k then + local p = n.kad:get(k) + + if p then + send(dump(p)) + end + end + + return close() + end) + end + + local function _resolve(send, close, cmd, k) + local args + if cmd == 'gethostbyname' then + args = {name = k} + else + assert(cmd == 'gethostbyaddr') + args = {ip = k} + end + + return n:resolve(args, function(k, hostname, ip) + if k then + send(string.format("%s\t%s\t%s\n", + wh.tob64(k), + hostname or '', + ip or '' + )) + end + + return close() + end) + end + + H['(gethostbyname) ([^%s]+)'] = _resolve + H['(gethostbyaddr) ([^%s]+)'] = _resolve + + function H.info(send, close) + send('WireHub %s\n', wh.version) + send('Uptime: %.1f\n', now-start_time) + if opts.interface then + send('Interface: %s\n', opts.interface) + end + + send('Key: %s\n', wh.tob64(wh.publickey(n.sk))) + send('ListenPort: %d\n', n.port) + return close() + end + + function H.key(send, close) + send('%s\n', wh.tob64(wh.publickey(n.sk))) + return close() + end + + H['describe ([^%s]+)'] = function(send, close, mode) + send(n:describe(mode) .. '\n') + return close() + end + + function H.list(send, close) + local append = function(p, s) + send(tostring(s) .. '\t' .. (p.trust and 'trusted' or 'untrusted') .. '\n') + end + + for bid, bucket in pairs(n.kad.buckets) do + for _, p in ipairs(bucket) do + append(p, wh.tob64(p.k)) + + if p.hostname then + append(p, p.hostname) + end + end + end + + return close() + end + + function H.dumpkad(send, close) + for bid, bucket in pairs(n.kad.buckets) do + for i, p in ipairs(bucket) do + + local d = {} + for k, v in pairs(p) do d[k] = v end + d.bid = bid + + send("%s", dump(d) .. '\n') + end + end + return close() + end + + local function _nat(send, close, k) + if k == 4 then k = nil end + + local function detect(k) + return n:detect_nat(k, function(mode) + send('%s\n', mode) + close() + end) + end + + if k then + return n:getent(k, function(k) + if not k then + send('invalid key\n') + return close() + end + + return detect(k) + end) + else + return detect() + end + end + + H['nat()'] = _nat + H['nat ([^%d]+)'] = _nat + + local function _search(send, close, cmd, k) + local s + + n:getent(k, function(k) + if not k then + send('invalid key\n') + return close() + end + + s = n:search(k, cmd, nil, nil, function(s, p, via) + if p then + local mode + if p.relay then + mode = wh.tob64(p.relay.k) + elseif p.is_nated then + mode = '(nat)' + else + mode = '(direct)' + end + + send('%s %s %s %s\n', + wh.tob64(k), + mode, + p.addr, + wh.tob64(via.k) + ) + else + close() + end + end) + end) + + return function() + if s then + n:stop_search(s) + end + + -- XXX getent stop + end + end + + H['(p2p) ([^%s]+)'] = _search + H['(lookup) ([^%s]+)'] = _search + H['(ping) ([^%s]+)'] = _search + + H['connect ([^%s]+)'] = function(send, close, k) + local s + + n:getent(k, function(k) + if not k then + send('invalid hostname or key\n') + return close() + end + + s = n:connect(k, nil, function(s, p, p2p, endpoint) + if p then + if p2p then + send("p2p ") + end + send("%s\n", endpoint) + end + return close() + end) + end) + + return function() + if s then + n:stop_search(s) + end + + -- XXX getent stop + end + end + + H['forget ([^%s]+)'] = function(send, close, k) + return n:getent(k, function(k) + if not k then + send('invalid hostname or key\n') + return close() + end + + n:forget(k) + return close() + end) + end + + H['authenticate ([^%s]+) (.+)'] = function(send, close, k, path) + local alias_sk = cpcall(wh.readsk, path) + if not alias_sk then + send('!\n') + return close() + end + + local a + n:getent(k, function(k) + if not k then + send('invalid key or unknown hostname\n') + return close() + end + + a = n:authenticate(k, alias_sk, function(a, success, errmsg) + if success then + send('authenticated!\n') + else + send(string.format('failed: %s\n', errmsg)) + end + + return close() + end) + end) + + return function() + if a then + n:stop_authenticate(a) + end + end + end + + H.bw = function(send, close) + if n.bw then + for k, avg in pairs(n.bw:avg()) do + send(string.format("%s\t%s\t%s\n", + wh.tob64(k), + avg.rx, + avg.tx + )) + end + end + return close() + end + + return H +end + diff --git a/src/helpers.lua b/src/helpers.lua new file mode 100644 index 0000000..28a7113 --- /dev/null +++ b/src/helpers.lua @@ -0,0 +1,318 @@ +-- random seed +math.randomseed(string.unpack("I", wh.randombytes(4))) + +function execf(...) + return os.execute(string.format(...)) +end + +function tointeger(s) + local n = tonumber(s) + if n == nil then return nil end + if math.floor(n) ~= n then return nil end + return n +end + + +function string.join(val, tbl) + local r = {} + for i, v in pairs(tbl) do + if i > 1 then r[#r+1] = val end + r[#r+1] = tostring(v) + end + return table.concat(r) +end + +function dump(x, level) + level = level or 0 + + if type(x) == 'table' then + local function format_k(k) + if type(k) == 'number' then + return string.format('[%d]', k) + elseif type(k) == 'string' then + local r = dump(k) + if string.sub(r, 1, 1) == '"' and string.sub(r, -1, -1) == '"' then + return string.sub(r, 2, -2) + else + return string.format('[%s]', r) + end + else + -- XXX check for invalid Lua characters for key + return tostring(k) + end + end + + local r = {'{'} + + level = level + 1 + + local keys = {} + for k in pairs(x) do keys[#keys+1] = k end + table.sort(keys, function(a,b) + local a_type = type(a) + local b_type = type(b) + + if a_type ~= b_type then + a = a_type + b = b_type + end + + return a < b + end) + for i, k in ipairs(keys) do + local v = x[k] + + if i > 1 then r[#r+1] = ',' end + r[#r+1] = '\n' + + r[#r+1] = string.rep(' ', level) .. string.format('%s = %s', format_k(k), dump(v, level)) + end + level = level - 1 + + r[#r+1] = '\n' .. string.rep(' ', level) .. '}' + + return table.concat(r) + elseif type(x) == 'string' then + if #x == 32 then + return string.format('base64(%s)', wh.tob64(x)) + end + + return string.format('%q', x) + else + return tostring(x) + end +end + +function parsearg(idx, fields) + local state = {} + while true do + if arg[idx] == nil then + break + end + + local field = arg[idx] + local field_func = fields[field] + + local value + if field_func == true then -- XXX replace true by 'boolean' + value = true + + elseif not field_func or not arg[idx+1] then + printf('Invalid argument: %s', field) + return + else + idx = idx + 1 + local errmsg + value, errmsg = field_func(arg[idx]) + + if value == nil then + printf("Invalid argument: %s. %s", field, errmsg or '') + return + end + end + + state[field] = value + idx = idx + 1 + end + + return state +end + +function parsebool(s) + if not s then + return + end + + s = s:lower() + if s == 'yes' or s == 'true' or s == '1' then + return true + elseif s == 'no' or s == 'false' or s == '0' then + return false + end +end + +local notif = "" +do + local C = { + reset=0, + bold=1, + black='0;30', + red='0;31', + green='0;32', + orange='0;33', + blue='0;34', + magenta='0;35', + cyan='0;36', + gray='0;37', + darkgray='1;30', + lightred='1;31', + lightgreen='1;32', + yellow='1;33', + lightblue='1;34', + lightpurple='1;35', + lightcyan='1;36', + white='1;37', + } + + function format_color(s, color_mode) + local any_col = false + s = string.gsub(s, "$%(%a+%)", function(c) + c = string.sub(c, 3, -2) + local col = C[c] + assert(col, "unknown color") + + if color_mode == nil then + color_mode = wh.color_mode() + end + if not color_mode then return '' end + + any_col = true + return string.format("\x1b[%sm", col) + end) + if any_col then + s = s .. "\x1b[0m" + end + return s + end + + _G['log'] = false + + function printf(...) + local fmt = string.format(...) + + if _G['log'] then + _G['log'](format_color(fmt, false)) + end + + io.stdout:write(string.format("\r%s\r", string.rep(' ', #notif))) + print(format_color(fmt)) + io.stdout:write(notif) + io.stdout:flush() + end +end + +function hexdump(x) + if #x == 0 then + return "\n" + end + local r = {} + for i = 1, #x do + r[#r+1] = string.format("%.2x ", string.byte(string.sub(x, i, i))) + if i > 1 and i % 16 == 1 then + r[#r+1] = "\n" + elseif i > 1 and i % 8 == 1 then + r[#r+1] = " " + end + end + return table.concat(r) +end + +function status(fmt, ...) + if true then return end + if not wh.color_mode() then + return + end + + local prev_notif = notif + if fmt == nil then + notif = nil + else + notif = string.format(fmt, ...) + notif = string.format('$(lightblue)(%s)$(reset)', notif) + notif = format_color(notif) + end + + io.stdout:write("\r" .. string.rep(' ', #(prev_notif or "")) .. "\r") + + if notif then + io.stdout:write("\r" .. notif) + end + + io.stdout:flush() +end + +function cpcall(cb, ...) + return (function(ok, ...) + if ok then + return ... + else + return + end + end)(xpcall(cb, function(msg) return print(debug.traceback(msg, 2)) end, ...)) +end + +local exits_cb = {} +function atexit(cb, ...) + assert(cb) + exits_cb[#exits_cb+1] = table.pack(cb, ...) +end + +function _do_atexits() + for _, x in ipairs(exits_cb) do + cpcall(table.unpack(x)) + end +end + +function min(a) + local m = nil + for i, v in ipairs(a) do + if v ~= nil and (m == nil or m > v) then + m = v + end + end + return m +end + +function max(a) + local m = nil + for i, v in ipairs(a) do + if v ~= nil and (m == nil or m < v) then + m = v + end + end + return m +end + +function randomrange(s, e) + return math.floor(math.random()*(e-s)+s) +end + +function dbg(fmt, ...) + return printf("$(red)" .. fmt .. "$(reset)", ...) +end + +function errorf(...) + return error(string.format(...)) +end + +function disable_globals() + setmetatable(_G, { + __newindex = function(t, n, v) + if n == '_' then + return + end + + if t[n] == nil then + error(string.format("cannot set any global: %s", n)) + end + + return rawset(t, n, v) + end, + }) +end + +local MEMUNITS = { "B", "KiB", "MiB", "GiB" } +function memunit(v) + local unit + for _, u in ipairs(MEMUNITS) do + if v < 800 then + unit = u + break + end + + v = v / 1024 + end + + return string.format("%.1f%s", v, unit) +end + diff --git a/src/ipc.lua b/src/ipc.lua new file mode 100644 index 0000000..63f9eef --- /dev/null +++ b/src/ipc.lua @@ -0,0 +1,152 @@ +local MT = {__index = {}} +local I = MT.__index + +local function close(ipc, sock) + local state = ipc.states[sock] + ipc.states[sock] = nil + + if state == nil then + return + end + + if state.close_cb then + cpcall(state.close_cb) + end + + wh.close(sock) +end + +function I.close(ipc) + for sock in pairs(ipc.states) do + close(ipc, sock) + end + ipc.states = {} + + if ipc.listen_sock then + wh.close(ipc.listen_sock) + ipc.listen_sock = nil + end + + if ipc.close_cb then + ipc.close_cb() + ipc.close_cb = nil + end +end + +local function on_sock_readable(ipc, sock, cmd) + local state = ipc.states[sock] + + if not state then + return + end + + if not cmd or #cmd == 0 then + return close(ipc, sock) + end + + if state.wait_cmd and cmd and #cmd > 0 then + state.wait_cmd = false + + -- remove trailing \n + while string.sub(cmd, -1) == '\n' do + cmd = string.sub(cmd, 1, -2) + end + + local function send(...) + if not ipc.states[sock] then + return + end + + local s = string.format(...) + if wh.send(sock, s) ~= #s then + error("send truncated") + end + end + + local function _close() + close(ipc, sock) + end + + local done + for pattern, cb in pairs(ipc.h) do + done = (function(...) + if ... == nil then + return false + end + + state.close_cb = cpcall(cb, send, _close, ...) + return true + end)(string.match(cmd, pattern)) + + if done then + break + end + end + + if not done then + send('?\n') + _close() + end + end + +end + +function I.on_readable(ipc, r) + if ipc.listen_sock and r[ipc.listen_sock] then + r[ipc.listen_sock] = nil + + local new_sock = wh.ipc_accept(ipc.listen_sock) + ipc.states[new_sock] = {wait_cmd=true} + end + + for sock in pairs(ipc.states) do + if r[sock] then + local cmd = wh.recv(sock, 65535) + on_sock_readable(ipc, sock, cmd) + end + end +end + +function I.update(ipc, socks) + if ipc.listen_sock then + socks[#socks+1] = ipc.listen_sock + end + + for sock, state in pairs(ipc.states) do + socks[#socks+1] = sock + end + + return nil +end + +local M = {} + +function M.bind(interface_name, h) + assert(interface_name and h) + local listen_sock, close_cb = wh.ipc_bind(interface_name, false) + + return setmetatable({ + close_cb=close_cb, + states={}, + listen_sock=listen_sock, + h=h, + }, MT) +end + +function M.call(interface_name, cmd) + assert(interface_name) + local sock = wh.ipc_connect(interface_name) + + if not sock then + return + end + + if wh.send(sock, cmd .. '\n') ~= (#cmd+1) then + error("send truncated") + end + + return sock +end + +return M + diff --git a/src/kad.lua b/src/kad.lua new file mode 100644 index 0000000..636ed0f --- /dev/null +++ b/src/kad.lua @@ -0,0 +1,171 @@ +local packet = require('packet') +local time = require('time') + +local M = {} + +local function explain(n, p, fmt, ...) + return n:explain("(peer %s) " .. fmt, n:key(p), ...) +end + +local function update_peer(n, p, excedent) + -- take advantage of iterating over all nodes to remove fragments, + -- without management of any deadlines + if p.fragments then + local to_remove = {} + local count = #p.fragments + for i, sess in ipairs(p.fragments) do + if (count > wh.FRAGMENT_MAX or + sess.deadline <= now) then + + to_remove[#to_remove+1] = i + count = count - 1 + end + end + + for i = #to_remove, 1, -1 do + local sess_i = to_remove[i] + local sess = p.fragments[sess_i] + printf("$(red)drop fragment session %s$(reset)", wh.tob64(sess.id)) + p.fragments[sess.id] = nil + table.remove(p.fragments, sess_i) + end + end + + -- keeps aliases for ever + if p.alias then + return 'inf' + end + + -- if relay was forgotten + if p.relay and p.relay.addr == nil then + p.relay = nil + end + + -- XXX if excedent??? keep trusted peers infinitely? + + + local last_seen = p.last_seen or 0 + local ping_retry = p.ping_retry or 0 + + -- XXX NOTE XXX + -- if + -- current peer is NAT-ed and remote peer is not, OR + -- current peer has a tunnel opened to remote peer, OR + -- peer is excedent + + local should_ping + local reason + + if not p.addr and not p.relay then + should_ping = false + + elseif p.wg_connected then + reason = 'wireguard is enabled' + should_ping = true + + -- XXX should only ping the closest peers, not all! + elseif n.is_nated then + reason = 'current peer is NAT-ed' + should_ping = true + + elseif excedent then + reason = 'peer is excedent' + should_ping = true + end + + if should_ping then + local ping_retry + if not p.bootstrap then + ping_retry = wh.PING_RETRY + end + + local do_ping, deadline = time.retry_ping_backoff( + p, + wh.NAT_TIMEOUT - n.jitter_rand, + ping_retry, + wh.PING_BACKOFF + ) + + if do_ping then + explain(n, p, "alive? (%s)", reason) + n:_sendto{dst=p, m=packet.ping()} + end + + return deadline + end + + -- if address was forgotten + if p.addr == nil then + if p.trust then + return 'inf' + else + return nil + end + end + + if not p.is_nated then + return 'inf' + end + + -- p is NAT-ed. Forget if it does not contact current peer after a certain + -- amount of time + local deadline = last_seen + wh.NAT_TIMEOUT * 2 + if deadline <= now then + return nil + end + + return deadline +end + +function M.update(n, deadlines) + -- maintain the Kademilia tree + + for bid, bucket in pairs(n.kad.buckets) do + table.sort(bucket, function(a, b) return (a.first_seen or 0) < (b.first_seen or 0) end) + + local to_remove = {} + + local c = 0 + for i, p in ipairs(bucket) do + local excedent = c >= n.kad.K + + local deadline = update_peer(n, p, excedent) + + if deadline == nil then + p.addr = nil + p.addr_echo = nil + p.is_nated = nil + p.relay = nil + + if p.trust then + explain(n, p, "forget!") + n.kad.touched[p.k] = p + else + n.kad.touched[p.k] = nil + to_remove[#to_remove+1] = i + end + elseif deadline ~= 'inf' then + deadlines[#deadlines+1] = deadline + c = c + 1 + end + end + + for i = #to_remove, 1, -1 do + local p = bucket[to_remove[i]] + explain(n, p, "remove!") + table.remove(bucket, to_remove[i]) + bucket[p.k] = nil + end + end +end + +function M.on_pong(n, src) + if (src.ping_retry or 0) > 0 then + explain(n, src, "alive!") + end + + src.ping_retry = 0 +end + +return M + diff --git a/src/kadstore.lua b/src/kadstore.lua new file mode 100644 index 0000000..2f5edf8 --- /dev/null +++ b/src/kadstore.lua @@ -0,0 +1,102 @@ +-- kademilia + +local peer = require('peer') + +local MT = { + __index = {}, +} + +function MT.__index.touch(t, k) + if t.root.k == k then + t.touched[k] = t.root + return t.root + end + + local bid = wh.bid(t.root.k, k) + local b = t.buckets[bid] + if not b then + b = {} + t.buckets[bid] = b + end + + local p = b[k] + if p == nil then + p = { + k=k, + } + + b[#b+1] = p + b[k] = p + end + + t.touched[p.k] = p + + return peer(p) +end + +function MT.__index.get(t, k) + if t.root.k == k then + return t.root + end + + local bid = wh.bid(t.root.k, k) + local b = t.buckets[bid] + if not b then return end + + return b[k] +end + +function MT.__index.clear_touched(t) + t.touched = {} +end + +function MT.__index.kclosest(t, k, count, filter_cb) + local empty = {} + if count == nil then count = t.K end + local bid = wh.bid(t.root.k, k) + + local r = {} + + local function extend(i) + for _, p in ipairs(t.buckets[i] or empty) do + if (p.k and + p.addr and + not p.alias and ( + not filter_cb or + filter_cb(p) + ) + ) then + r[#r+1] = {wh.xor(p.k, k), p} + end + end + end + + extend(bid) + + if #r < count then + for i = bid+1, #t.root.k*8 do extend(i) end + end + + for i = bid-1, 1, -1 do + if #r >= count then + break + end + + extend(i) + end + + table.sort(r, function(a, b) return a[1] < b[1] end) + return r +end + +return function(root_k, kad_k) + assert(root_k and kad_k) + + return setmetatable({ + buckets={}, + K=kad_k, + touched={}, + root={k=root_k}, + }, MT) +end + diff --git a/src/key.lua b/src/key.lua new file mode 100644 index 0000000..75c465d --- /dev/null +++ b/src/key.lua @@ -0,0 +1,43 @@ +local dbg_keys = {idx=1} +function wh.key(k_or_p, n) + local k, p + + if type(k_or_p) == 'table' then + p = k_or_p + k = p.k + else + p = nil + k = k_or_p + + if k == nil then + return nil + end + end + + if not p and n then + p = n.kad:get(k) + end + + if p then + if p.hostname then + return p.hostname + elseif p.ip then + return string.format("", p.ip) + end + end + + if false then + local dk = dbg_keys[k] + if dk == nil then + dk = string.format("KEY_%d", dbg_keys.idx) + dbg_keys.idx = dbg_keys.idx + 1 + dbg_keys[k] = dk + end + return dk + end + + local b64 = wh.tob64(k) + b64 = string.sub(b64, 1, 10) + return b64 +end + diff --git a/src/lo.lua b/src/lo.lua new file mode 100644 index 0000000..0f6e4d7 --- /dev/null +++ b/src/lo.lua @@ -0,0 +1,230 @@ +local TRY_AUTOCONNECT_EVERY_S = 60 + +local MT = { + __index = {} +} + +local function _alloc(lo) + -- XXX manages lo.cidr + -- XXX slow + for i = 1024, 65535 do + local a = wh.set_address_port(lo.addr, i) + if not lo.addr_ks[a:pack()] then + return a + end + end + + error("no more free space for ports") + return nil +end + +function MT.__index.touch(lo, k) + local a = lo.k_addrs[k] + if not a then + a = _alloc(lo, k) + lo.k_addrs[k] = a + lo.addr_ks[a:pack()] = k + end + assert(a) + + return a +end + +function MT.__index.free(lo, k) + local a = lo.k_addrs[k] + if a then + lo.addr_ks[a:pack()] = nil + lo.k_addrs[k] = nil + end +end + +function MT.__index.update(lo, socks) + local deadlines = {} + local timeout + lo.sniff_fd, timeout = wh.get_pcap(lo.sniff) + + socks[#socks+1] = lo.sniff_fd + if timeout then + deadlines[#deadlines+1] = now+timeout + end + + for _, c in pairs(lo.connects) do + if not c.ac_deadline then + -- do nothing + elseif c.ac_deadline > now then + deadlines[#deadlines+1] = c.ac_deadline + else + lo.connects[c.k] = nil + end + end + + return min(deadlines) +end + +function MT.__index.touch_tunnel(lo, p) + if not p.tunnel then + p.tunnel = { + lo_addr = lo:touch(p.k) + } + end + return p.tunnel +end + +function MT.__index.free_tunnel(lo, p) + if p.tunnel then + lo:free(p.k) + p.tunnel = nil + end +end + +local function on_connect(lo, c, dst, p2p) + --dbg(dump(dst)) + if dst then + if p2p then + printf("$(green)auto-connect to %s succeed$(reset)", lo.n:key(dst)) + else + printf("$(orange)cannot establish P2P connection with %s$(reset)", lo.n:key(dst)) + + assert(dst.tunnel) + dst.tunnel.last_tx = now + end + + for _, m in ipairs(c.pkt_buf or {}) do + -- XXX + lo.n:send_datagram(dst, m) + end + + else + printf("$(red)could not find %s$(reset)", lo.n:key(c.k)) + end + + c.ac_deadline = now + TRY_AUTOCONNECT_EVERY_S +end + +function MT.__index.on_readable(lo, r) + while r[lo.sniff_fd] do + local src_addr, dst_lo_addr, m = wh.pcap_next_udp(lo.sniff) + + if not m then + break + end + + local dst_k = lo.addr_ks[dst_lo_addr:pack()] + + if not dst_k then + printf("$(red)error: unknown lo addr: %s$(reset)", dst_lo_addr) + return + end + + local dst = lo.n.kad:get(dst_k) + if not dst then + return + end + + -- is peer set with a tunnel? if so, redirect the wireguard packet. + -- else, try to connect while buffering the packets + + if lo.auto_connect then + local c = lo.connects[dst.k] + if not c then + printf("$(green)auto-connecting to %s$(reset)", lo.n:key(dst)) + c = lo.n:connect(dst.k, nil, function(...) + return on_connect(lo, ...) + end) + lo.connects[dst.k] = c + end + + if not c.pkt_buf then + c.pkt_buf = {} + end + + c.pkt_buf[#c.pkt_buf+1] = m + if #c.pkt_buf > lo.buffer_max then + table.remove(c.pkt_buf, 1) + end + end + + if dst.tunnel then + dst.tunnel.last_tx = now + local through_tunnel = lo.n:send_datagram(dst, m) + + if not through_tunnel then + lo:free_tunnel(dst) + end + end + end +end + +function MT.__index.forget(lo, k) + lo.connects[k] = nil +end + +function MT.__index.recv_datagram(lo, src, m) + lo:touch_tunnel(src) + src.tunnel.last_rx = now + + printf("$(orange)receive datagram %dB$(reset)", #m) + + local ret, errmsg = wh.sendto_raw_wg(lo.sock, m, src.tunnel.lo_addr, lo.n.port) + + if not ret then + printf('$(red)error: could not send packet: %s', errmsg) + end +end + +function MT.__index.close(lo) + if lo.sniff then + wh.close_pcap(lo.sniff) + lo.sniff = nil + end + + if lo.sniff_fd then + wh.close(lo.sniff_fd) + lo.sniff_fd = nil + end + + if lo.sock then + wh.close(lo.sock) + lo.sock = nil + end +end + +return function(lo) + assert(lo.n) + + if lo.auto_connect == nil then + lo.auto_connect = true + end + + if not lo.cidr then + lo.cidr = 32 + end + + if not lo.addr then + -- XXX + lo.addr = wh.address(string.format("127.%d.%d.%d", + randomrange(1, 254), + randomrange(1, 254), + randomrange(1, 254) + ), 0) + end + lo.subnet = lo.addr:addr() .. '/' .. tostring(lo.cidr) + + lo.k_addrs = {} + lo.addr_ks = {} + lo.tunnels = {} + lo.connects = {} + + if lo.auto_connect then + if not lo.buffer_max then + lo.buffer_max = 1 + end + end + + -- XXX lazy? + lo.sniff = wh.sniff('any', 'in', 'wg', " and dst net " .. lo.subnet) + lo.sock = wh.socket_raw_udp('ip4_hdrincl') + + return setmetatable(lo, MT) +end + diff --git a/src/nat.lua b/src/nat.lua new file mode 100644 index 0000000..b543758 --- /dev/null +++ b/src/nat.lua @@ -0,0 +1,101 @@ +local time = require('time') +local packet = require('packet') +local peer = require('peer') + +local M = {} + +local function explain(n, d, fmt, ...) + return n:explain("(nat %s) " .. fmt, n:key(d.k), ...) +end + +function M.update(n, d, deadlines) + local to_remove = {} + + local do_ping, deadline + + -- check if offline. if node answers, peek its echo node + if d.may_offline then + do_ping, deadline = time.retry_backoff(d, 'retry', 'req_ts', wh.PING_RETRY, wh.PING_BACKOFF) + + if do_ping then + explain(n, d, "offline? (retry: %d)", d.retry) + n:_sendto{dst=d.p, m=packet.ping(nil, d.uid)} + end + + if deadline == nil then + return d.cb('offline') + end + end + + -- maybe public? ping node and ask to answer as an echo. if it is not + -- received after several retry, consider the node as not public. + if not d.may_offline and d.may_direct then + assert(d.p_echo, 'a pong with the echo address should have been received') + + do_ping, deadline = time.retry_backoff(d, 'retry', 'req_ts', wh.PING_RETRY, wh.PING_BACKOFF) + + if do_ping then + explain(n, d, "direct? (retry: %d)", d.retry) + n:_sendto{dst=d.p, m=packet.ping('swapsrc', d.uid)} + end + + if deadline == nil then + explain(n, d, "not direct!") + d.may_direct = false + d.req_ts = 0 + d.retry = 0 + d.uid = wh.randombytes(8) + end + end + + if not d.may_offline and not d.may_direct and d.may_cone then + do_ping, deadline = time.retry_backoff(d, 'retry', 'req_ts', wh.PING_RETRY, wh.PING_BACKOFF) + + if do_ping then + -- UDP hole punch + explain(n, d, "cone? (retry: %d)", d.retry) + n:_sendto{dst=d.p_echo, m=packet.ping('normal', d.uid_echo)} + n:_sendto{dst=d.p, m=packet.ping('swapsrc', d.uid)} + end + + if deadline == nil then + explain(n, d, "not cone!") + d.may_cone = false + d.req_ts = 0 + d.retry = 0 + d.uid=wh.randombytes(8) + end + end + + if not d.may_offline and not d.may_direct and not d.may_cone then + return d.cb('blocked') + end + + assert(deadline ~= nil) + deadlines[#deadlines+1] = deadline +end + + +function M.on_pong(n, body, src) + for d in pairs(n.nat_detectors) do + if d.uid == body then + if d.may_offline then + assert(src.addr_echo) + d.p_echo = peer{k=src.k, addr=src.addr_echo} + d.may_offline = false + d.req_ts = 0 + d.retry = 0 + d.uid=wh.randombytes(8) + explain(n, d, "online!") + elseif d.may_direct then + d.cb("direct") + elseif d.may_cone then + d.cb("cone") + end + end + end + +end + +return M + diff --git a/src/node.lua b/src/node.lua new file mode 100644 index 0000000..81f272d --- /dev/null +++ b/src/node.lua @@ -0,0 +1,1054 @@ +local handlers = require('handlers') +local packet = require('packet') +local peer = require('peer') +local time = require('time') + +local auth = require('auth') +local kad = require('kad') +local nat = require('nat') +local search = require('search') +local connectivity = require('connectivity') + +local M = {} +local MT = { + __index = {}, +} + +function MT.__index._extend(n, s, closest, src) + s.states[src.k] = {retry=0, rep=true} + + local set = {} + for _, c in ipairs(s.closest) do + local p = c[2] + set[p:pack()] = true + end + + for _, c in ipairs(closest) do + local dist = c[1] + local p = c[2] + + local st = s.states[p.k] + + if ( + -- if peer has address + p.addr and + + -- if peer already answered, skip, + (st == nil or not st.rep) and + + -- ignore doublons + not set[p:pack()] + ) then + --printf('extend $(cyan)%s', p) + s.closest[#s.closest+1] = {dist, n:add(p)} + set[p:pack()] = true + + if s.cb and p.k == s.k then + cpcall(s.cb, s, p, src) + end + end + end + + -- XXX maybe give a preference to trusted peers, or peers which are public + -- or geographically close + table.sort(s.closest, function(a, b) return a[1] < b[1] end) +end + +function MT.__index._sendto(n, opts) + assert(opts.dst) + + opts.sk = opts.sk or n.sk + + if opts.dst.k == n.k then + error("cannot send to self") + end + + --if math.random() < .3 then return end + + local me = opts.me + if not me then + me = wh.packet(opts.sk, opts.dst.k, n.is_nated, opts.m) + + if opts.dst.relay then + me = wh.packet(opts.sk, opts.dst.relay.k, n.is_nated, packet.relay( + opts.dst.k, me + )) + end + end + + local udp_dst = opts.dst.relay and opts.dst.relay or opts.dst + + local udp_dst_addr + if opts.to_echo then + if not udp_dst.addr_echo then + errorf("unknown echo address for %s", udp_dst) + end + + udp_dst_addr = udp_dst.addr_echo + else + if not udp_dst.addr then + errorf("unknown address for %s", udp_dst) + end + udp_dst_addr = udp_dst.addr + end + + local port = opts.from_echo and n.port_echo or n.port + + -- DEBUG + if n.log >= 2 then + local msg = {} + msg[#msg+1] = string.format('%s <- :%d: $(yellow)', udp_dst, port) + + if opts.dst.relay then + msg[#msg+1] = string.format("relay$(reset)(%s, $(yellow)", opts.dst.addr) + end + + if opts.m then + msg[#msg+1] = packet.cmds[string.byte(string.sub(opts.m, 1, 1))+1] or "???" + else + msg[#msg+1] = string.format("", #me) + end + msg[#msg+1] = "$(reset)" + + if opts.dst.relay then + msg[#msg+1] = ')' + end + + msg[#msg+1] = string.format(" (%dB)", #me) + printf(table.concat(msg)) + end + + if n.bw then + n.bw:add_tx(udp_dst.k, #me) + end + + local ret, errmsg = wh.sendto_raw_udp(n.sock4_raw, n.sock6_raw, me, port, udp_dst_addr) + + if not ret then + printf('$(red)error: could not send packet: %s', errmsg) + end +end + +function MT.__index.update(n, socks) + local timeout + local deadlines = {} + + socks[#socks+1] = n.sock_echo + socks[#socks+1] = wh.pipe_event_fd(n.pe) + + n.in_udp_fd, timeout = wh.get_pcap(n.in_udp) + socks[#socks+1] = n.in_udp_fd + if timeout then + deadlines[#deadlines+1] = now+timeout + end + + if n.upnp then + n.upnp.worker:update(socks) + end + + connectivity.update(n, deadlines) + + if n.ns then + n.ns.worker:update(socks) + end + + for d in pairs(n.nat_detectors) do + nat.update(n, d, deadlines) + end + + for s in pairs(n.searches) do + search.update(n, s, deadlines) + end + + for a in pairs(n.auths) do + auth.update(n, a, deadlines) + end + + kad.update(n, deadlines) + + if n.lo then + deadlines[#deadlines+1] = n.lo:update(socks) + end + + if (n.bw and + n.bw:length() ~= 0 and + time.every(deadlines, n.bw, 'last_collect_ts', n.bw.scale)) then + + n.bw:collect() + end + + --if DEBUG then + -- for k, p in pairs(n.kad.touched) do + -- dbg('touched %s', p) + -- end + --end + + return min(deadlines) +end + +function MT.__index.add(n, other) + assert(other.k) + local self = n.kad:touch(other.k) + + assert(self.k == other.k) + + local changed = ( + -- do not change bootstrap + not self.bootstrap and + + -- ignore older others + (not other.last_seen or other.last_seen >= (self.last_seen or 0)) and + + ( + -- replace if we don't know the address (worst situation) + not self.addr or + + -- replace if other peer is not relayed + not other.relay or + + -- replace if our route is through a relay + self.relay + + -- XXX if the relay is of good quality, avoid to change it by a bad + -- one + ) + ) + + if changed then + self.addr = other.addr + self.is_nated = other.is_nated + self.last_ping = nil + self.last_seen = other.last_seen or self.last_seen + self.ping_retry = 0 + self.relay = other.relay + end + + return self, changed +end + +function MT.__index.getent(n, hostname, result_cb) + if hostname == nil then + return nil + end + + local cbs = {} + + if n.ns then + for _, ns in ipairs(n.ns) do + cbs[#cbs+1] = ns + end + end + + -- might be a shorther version of Base64 WireHub Base64 + cbs[#cbs+1] = function(n, k, cb) + local test = function(p) + local e = string.find(wh.tob64(p.k), k) + return e + end + + local match + + if test(n.kad.root) then + match = n.kad.root + end + + for _, bucket in ipairs(n.kad.buckets) do + for _, p in ipairs(bucket) do + if test(p) then + if match then + -- there's an possible ambiguity. fails + return cb(nil) + else + match = p + end + end + end + end + + if match then + return cb(match.k) + else + return cb(nil) + end + end + + -- might be Base64 from WireHub + cbs[#cbs+1] = function(n, k, cb) + local ok, k = pcall(wh.fromb64, k, 'wh') + + if ok then + if #k ~= 32 then k = nil end + else + k = nil + end + + return cb(k) + end + + -- might be Base64 from WireGuard + cbs[#cbs+1] = function(n, k, cb) + local k = pcall(wh.fromb64, k, 'wg') + if k then + if #k ~= 32 then k = nil end + else + k = nil + end + return cb(k) + end + + -- might be a hostname + cbs[#cbs+1] = function(n, h, cb) + -- XXX manages index for hostnames + + if n.kad.root.hostname == h then + return cb(n.kad.root.k) + end + + for _, bucket in ipairs(n.kad.buckets) do + for _, p in ipairs(bucket) do + if p.hostname == h and p.k then + return cb(p.k) + end + end + end + + return cb() + end + + local key = nil + local cont_cb + + function cont_cb() + local cb + key, cb = next(cbs, key) + + if key and cb then + return cb(n, hostname, function(k) + if k then + return result_cb(k) + else + return cont_cb() + end + end) + else + return result_cb(nil) + end + end + + return cont_cb() +end + +function MT.__index.search(n, k, mode, count, timeout, cb) + assert(k) + + if mode == nil then mode = 'ping' end + if mode ~= 'lookup' and + mode ~= 'p2p' and + mode ~= 'ping' then + error("arg #3 must be 'p2p', 'lookup' or 'ping'") + end + if count == nil then count = wh.KADEMILIA_K end + if timeout == nil then timeout = wh.SEARCH_TIMEOUT end + + local s = setmetatable({ + cb=cb, + closest={}, + count=count, + deadline=now+timeout, + k=k, + mode=mode, + running=true, + uid1=wh.randombytes(8), + uid2=wh.randombytes(8), + may_offline=true, + states={}, + }, { + __index = S + }) + + n.searches[s] = true + + -- bootstrap + n:_extend(s, n.kad:kclosest(s.k, wh.KADEMILIA_K), n.kad.root) + + return s +end + +function MT.__index.stop_search(n, s) + if s.running then + s.running = false + + --printf('stop search $(cyan)%s', n:key(s)) + + n.searches[s] = nil + + if s.cb then + cpcall(s.cb, s, nil) + end + end +end + +function MT.__index.detect_nat(n, k, cb) + local p + if k == nil then + -- XXX get the closest node which is public! + local closest = n.kad:kclosest(n.k, 1, function(p) + return p.bootstrap + end) + if #closest == 0 then + return cb("offline") + end + + p = closest[1][2] + k = p.k + else + p = n.kad:get(k) + + if not p then + error(string.format("no route to %s", n:key(k))) + end + end + + local d = { + may_cone=true, + may_offline=true, + may_direct=true, + k=k, + req_ts=0, + retry=0, + uid=wh.randombytes(8), + uid_echo=wh.randombytes(8), + p=p, + p_echo=nil, -- explicit + } + + d.cb = function(...) + n.nat_detectors[d] = nil + return cb(...) + end + + + n.nat_detectors[d] = true +end + +function MT.__index.authenticate(n, k, alias_sk, cb) + local a = { + alias_sk = alias_sk, + alias_k = wh.publickey(alias_sk), + k = k, + retry=0, + req_ts=0, + } + + a.cb = function(ok, ...) + if not n.auths[a] then + return + end + + n.auths[a] = nil + + if a.alias_sk then + wh.burnsk(a.alias_sk) + a.alias_sk = nil + end + + if a.s then + n:stop_search(a.s) + a.s = nil + end + + if cb then + cpcall(cb, ok, ...) + end + end + + a.s = n:search(a.k, 'lookup', nil, nil, function(s, p, via) + if not a.s then + return + end + a.s = nil + + n:stop_search(s) + + if not p then + return a:cb(false, "not found") + end + + a.p = p + end) + + n.auths[a] = true + + return a +end + +function MT.__index.stop_authenticate(n, a) + a:cb(false, 'interrupted') +end + +function MT.__index.connect(n, dst_k, timeout, cb) + local count = 1 + local p_relay + local cbed = false + + return n:search(dst_k, 'p2p', count, timeout, function(s, p, via) + if cbed then return end + + if p and not p.relay and p.addr then + cbed = true + return cb(s, p, true, p.addr) + elseif p and p.relay then + p_relay = p + elseif n.lo and p_relay then + local tunnel = n.lo:touch_tunnel(p_relay) + cbed = true + return cb(s, p_relay, false, tunnel.lo_addr) + else + cbed = true + return cb(s) + end + end) +end + +function MT.__index.forget(n, dst_k) + local p = n.kad:get(dst_k) + if not p then + return + end + + -- do not forget bootstrap nodes + if p.bootstrap then + return + end + + n.kad:touch(dst_k) + p.addr = nil + p.addr_echo = nil + p.first_seen = nil + p.is_nated = nil + p.last_ping = nil + p.last_seen = nil + p.ping_retry = nil + p.relay = nil + p.tunnel = nil + + if n.lo then + n.lo:forget(dst_k) + end +end + + +function MT.__index.send_datagram(n, dst, m) + if type(dst) == 'string' then + dst = n.kad:get(dst) + + -- unknown destination. close tunnel + if not dst then + return false + end + end + + if dst.relay then + local num = 0 + while true do + local fragment = string.sub(m, num*wh.FRAGMENT_MTU+1, (num+1)*wh.FRAGMENT_MTU) + local mf = #m > ((num+1)*wh.FRAGMENT_MTU) + + if #fragment == 0 then + break + end + + n:_sendto{ + dst=dst, + m=packet.fragment( + n.frag_counter, + num, + mf, + fragment + ) + } + + num = num + 1 + assert(num < 64) + end + + n.frag_counter = (n.frag_counter + 1) % 0x10000 + + return true + elseif dst.addr then + n:_sendto{dst=dst, me=m} + + return false + else + printf("$(red)no route to %s. drop datagram$(reset)", n:key(p)) + + return true + end +end + +function MT.__index.read(n, m, src_addr, src_k, src_is_nated, time, via, relay) + -- peer's key needs enough workbit + if n.workbit == 0 or wh.workbit(src_k, n.namespace) < n.workbit then + return + end + + local cmd = string.sub(m, 1, 1) + if #cmd == 0 then return end + + -- a relayed message must not be of type 'relayed' + if via == 'relay' and cmd == packet.cmds.relayed then + printf("$(red)drop a double relayed packet$(reset)") + return + end + + local src, better_route = n:add{ + addr = src_addr, + is_nated = src_is_nated, + k = src_k, + last_seen = now, + relay = relay, + } + + src.first_seen = src.first_seen or now + src.last_seen = now + + if src.bootstrap and src_is_nated then + printf("$(yellow)INVALID: bootstrap cannot be behind a NAT$(reset)") + return + end + + local real_relay + -- better connection. However the new route will be used only to respond to this + -- it is ok to answer relayed requests through relay, even we already know a + -- request. + if not better_route and not src.relay and relay then + real_relay = src.relay + src.relay = relay + else + real_relay = src.relay + end + + local h = handlers[cmd] + if h then + h(n, m, src, via) + else + printf("$(red)unknown cmd: {%d} (%dB)\t%s", string.byte(cmd), #m, src) + end + + src.relay = real_relay +end + +function MT.__index.on_readable(n, r) + if r[wh.pipe_event_fd(n.pe)] then + wh.clear_pipe_event(n.pe) + end + + if n.upnp then + n.upnp.worker:on_readable(r) + end + + if n.ns then + n.ns.worker:on_readable(r) + end + + while r[n.in_udp_fd] or r[n.sock_echo] do + local me, src_addr + local via + + if r[n.in_udp_fd] then + src_addr, _, me = wh.pcap_next_udp(n.in_udp) + via = "normal" + + if not me then + r[n.in_udp_fd] = nil + end + end + + if r[n.sock_echo] then + me, src_addr = wh.recvfrom(n.sock_echo, 1500) -- XXX MTU? + via = "echo" + end + + -- if no more packet, break + if me == nil then + break + else + local src_k, src_is_nated, time, m = wh.open_packet(n.sk, me) + + -- XXX do something with time + + -- if message is valid, + if m ~= nil then + if n.bw then + n.bw:add_rx(src_k, #me) + end + + n:read(m, src_addr, src_k, src_is_nated, time, via) + end + end + end + + if n.lo then + n.lo:on_readable(r) + end +end + +function MT.__index.close(n) + if n.upnp then + n.upnp.worker:free() + end + + if n.ns then + n.ns.worker:free() + end + + if n.lo then + n.lo:close() + end + + wh.close(n.sock4_raw) + n.sock4_raw = nil + + wh.close(n.sock6_raw) + n.sock6_raw = nil + + wh.close(n.sock_echo) + n.sock_echo = nil + + wh.close_pcap(n.in_udp) + n.in_udp = nil + + wh.close_pipe_event(n.pe) + n.pe = nil + + if n.in_udp_fd then + wh.close(n.in_udp_fd) + n.in_udp_fd = nil + end +end + +function MT.__index.describe(n, mode) + if mode == nil then mode = 'all' end + + assert(mode == 'all' or mode == 'light') + + local r = {} + + if n.name then + r[#r+1] = string.format("network $(bold)%s$(reset), ", n.name) + end + + r[#r+1] = "node " + + if n.p.hostname then + r[#r+1] = string.format("$(bold)%s$(reset) ", n.p.hostname) + end + + do + local mode = {} + if n.is_nated then mode[#mode+1] = "NAT" end + if n.p.is_router then mode[#mode+1] = "ROUTER" end + if n.p.is_gateway then mode[#mode+1] = "GATEWAY" end + + r[#r+1] = string.format("<%s>\n", string.join(',', mode)) + end + + r[#r+1] = string.format(" public key: %s\n", wh.tob64(n.p.k)) + + --r[#r+1] = string.format(" port: %d, port echo: %d\n", n.port, n.port_echo) + --if n.workbit then + -- r[#r+1] = string.format(" namespace: %s, workbit: %d\n", n.namespace, n.workbit) + --end + + if mode == 'all' then + local any_nat = false + for d in pairs(n.nat_detectors) do + if not any_nat then + any_nat = true + r[#r+1] = " $(bold)nat detecting$(reset)\n" + end + + local mode + if d.may_offline then + mode = "OFFLINE?" + elseif d.may_direct then + mode = "DIRECT?" + elseif d.may_cone then + mode = "CONE?" + else + mode = "BLOCKED" + end + + r[#r+1] = string.format(" %s (%s)\n", n:key(d), wh.tob64(d.uid)) + r[#r+1] = string.format(" mode: %s (retry: %d)\n", mode, d.retry) + end + end + + if mode == 'all' then + local any_search = false + for s in pairs(n.searches) do + if not any_search then + any_search = true + r[#r+1] = " $(bold)searches$(reset)\n" + end + + r[#r+1] = string.format(" %s (%d queued, closest %d)\n", + n:key(s), #s.closest, + s.closest[1] and wh.bid(s.k, s.closest[1][1]) or 0 + ) + + if s.deadline-now<5 then + r[#r+1] = string.format(" timeout in %.1fs\n", s.deadline-now) + end + end + end + + local peers = {} + for bid, bucket in pairs(n.kad.buckets) do + for _, p in ipairs(bucket) do + peers[#peers+1] = { + bid=bid, + p=p, + } + end + end + + local filter_cb, comp_cb + if mode == 'light' then + filter_cb = function(p) + return p.p.trust + end + end + comp_cb = function(a, b) + local a_w_hostname = a.p.hostname and not a.p.alias + local b_w_hostname = b.p.hostname and not b.p.alias + + if a_w_hostname and b_w_hostname then + return a.p.hostname < b.p.hostname + elseif a_w_hostname and not b_w_hostname then + return true + elseif not a_w_hostname and b_w_hostname then + return false + else + return a.p.k < b.p.k + end + end + + do + if filter_cb then + for i = #peers, 1, -1 do + local p = peers[i] + if not filter_cb(p) then + table.remove(peers, i) + end + end + end + + if comp_cb then + table.sort(peers, comp_cb) + end + end + + local bw = n.bw and n.bw:avg() + + if #peers > 0 then + r[#r+1] = "\n $(bold)peers$(reset)\n" + for _, x in ipairs(peers) do + local bid = x.bid + local p = x.p + + r[#r+1] = " " + + local active = p.last_seen and now-p.last_seen <= wh.KEEPALIVE_TIMEOUT and not p.alias + + if active then + r[#r+1] = "$(green)" + end + + if p.alias then + r[#r+1] = '◌ ' + elseif p.relay then + r[#r+1] = '○ ' + elseif p.is_nated and p.addr then + r[#r+1] = '◒ ' + elseif p.addr then + r[#r+1] = '● ' + else + r[#r+1] = ' ' + end + + r[#r+1] = string.format("$(reset) %s", n:key(p)) + + if p.alias then + if type(p.alias) == 'string' then + r[#r+1] = string.format(" is %s", p.hostname or n:key(p.alias)) + end + elseif p.relay then + elseif p.addr then + r[#r+1] = string.format(': %s', p.addr) + end + + if p.is_router then r[#r+1] = ' (master)' end + if p.is_gateway then r[#r+1] = ' (gw)' end + + if mode == 'all' then + r[#r+1] = string.format(" (bucket:%d)", bid) + end + + if bw and bw[p.k] then + local b = bw[p.k] + + r[#r+1] = " (" + + if b.tx > 0 then + r[#r+1] = string.format("↑ %s/s", memunit(bw[p.k].tx)) + end + + if b.tx > 0 and b.rx > 0 then + r[#r+1] = ", " + end + + if b.rx > 0 then + r[#r+1] = string.format("↓ %s/s", memunit(bw[p.k].rx)) + end + + r[#r+1] = ")" + end + + r[#r+1] = "\n" + end + end + + + return table.concat(r) +end + +function MT.__index.stop(n) + n.running = false + wh.set_pipe_event(n.pe) +end + +function MT.__index.key(n, p_or_k) + return wh.key(p_or_k, n) +end + +function MT.__index.explain(n, ...) + if n.log >= 1 then + local fmt = select(1, ...) + printf("$(green)" .. fmt .. "$(reset)", select(2, ...)) + end +end + +function MT.__index.resolve(n, opts, cb) + local peers = {} + + local function cont() + if opts.k then + local p = n.kad:get(opts.k) + + if p then + peers[#peers+1] = p + end + end + + if opts.ip then + if n.kad.root.ip == opts.ip then + peers[#peers+1] = n.kad.root + end + + for bid, bucket in pairs(n.kad.buckets) do + for _, p in ipairs(bucket) do + if p.ip == opts.ip then + peers[#peers+1] = p + end + end + end + end + + -- check every peer are identical, else drop + for i = 2, #peers do + if peers[1] ~= peers[i] then + return cb() + end + end + + local peer = peers[1] + if peer then + return cb(peer.k, peer.hostname, peer.ip) + end + + return cb() + end + + if opts.name then + return n:getent(opts.name, function(k) + opts.k = k + cont() + end) + else + return cont() + end +end + +function M.new(n) + assert(n.sk and n.port and n.port_echo) + + if n.mode == nil then n.mode = 'unknown' end + if n.bw == nil then n.bw = true end + + if n.workbit == nil then + n.workbit = 0 + else + assert(n.namespace) + end + + n.log = n.log or 0 + n.running = true + n.k = wh.publickey(n.sk) + n.in_udp = wh.sniff('any', 'in', 'wh', " and dst port " .. tostring(n.port)) + n.sock_echo = wh.socket_udp(wh.address('0.0.0.0', n.port_echo)) + n.sock4_raw = wh.socket_raw_udp("ip4") + n.sock6_raw = wh.socket_raw_udp("ip6") + n.kad = require('kadstore')(n.k, wh.KADEMILIA_K) + n.p = n.kad.root + n.searches = {} + n.connects = {} + n.auths = {} + n.nat_detectors = {} + n.jitter_rand = math.random() * 1 + n.pe = wh.pipe_event() + n.frag_counter = math.floor(math.random() * 0xffff) + + if n.bw then + n.bw = require('bwlog'){scale=1.0} + end + + n.is_nated = n.mode ~= 'direct' + + if wh.upnp then + n.upnp = { + worker = wh.worker('upnp'), + enabled = false, + last_check = 0, + checking = false, + } + + n.upnp.worker:pcall(function() end, function() + require('wh') + require('helpers') + end) + end + + if n.ns then + n.ns.worker = wh.worker('ns') + + n.ns.worker:pcall(function() end, function() + require('wh') + require('helpers') + end) + end + + return setmetatable(n, MT) +end + +return M diff --git a/src/ns_keybase.lua b/src/ns_keybase.lua new file mode 100644 index 0000000..6b2c10e --- /dev/null +++ b/src/ns_keybase.lua @@ -0,0 +1,42 @@ +local TIMEOUT = 2 +local CMD = string.format("curl -m %s -s ", TIMEOUT) + +local function generate_url(path) + local hostname, user = string.match(path, "(.+)%.(.+)") + + if not hostname and not user then + user = path + hostname = "default" + end + + return string.format("https://%s.keybase.pub/wirehub/%s", user, hostname) +end + +return function(n, k, cb) + local path = string.match(k, "(.+)%.kb.wh") + + if not path then + return cb(nil) + end + + local url = generate_url(path) + local cmd = CMD .. url + + n.ns.worker:pcall( + function(ok, resp) + if resp then + local ok, k = pcall(wh.fromb64, resp) + + if ok then + return cb(k) + end + end + + return cb(nil) + end, + function(cmd) + return io.popen(cmd):read() + end, + cmd + ) +end diff --git a/src/packet.lua b/src/packet.lua new file mode 100644 index 0000000..41b35b7 --- /dev/null +++ b/src/packet.lua @@ -0,0 +1,131 @@ +local M = {} + +local cmds = { + 'ping', + 'pong', + 'search', + 'result', + 'relay', + 'relayed', + 'auth', + 'authed', + 'fragment', +} +for i, str in ipairs(cmds) do cmds[str] = string.pack("B", i-1) end + +M.cmds = cmds + +function M.ping(arg, body) + body = body or '' + assert(#body <= 8) + + if arg == nil or arg == 'normal' then + arg = "\x00" + elseif arg == 'swapsrc' then + arg = "\x01" + elseif arg == 'direct' then + arg = "\x02" + end + + return table.concat{cmds.ping, arg, body or ''} +end +function M.pong(port_echo, src, body) + return table.concat{ + cmds.pong, + src:pack(), + string.pack("!H", port_echo), + body, + } +end +function M.search(k) + return table.concat{cmds.search, k} +end + +function M.result(k, closest) + local m = {cmds.result, k} + + for i, c in ipairs(closest) do + local p = c[2] + + if i > wh.KADEMILIA_K then + break + end + + if p.relay then + m[#m+1] = "\x01" + elseif p.is_nated then + m[#m+1] = "\x02" + else + m[#m+1] = "\x00" + end + + do + m[#m+1] = p.k + m[#m+1] = p.addr:pack() + end + + if p.relay then + m[#m+1] = p.relay.k + m[#m+1] = p.relay.addr:pack() + end + end + + return table.concat(m) +end + +function M.relay(dst, body) + assert(#dst == 32) + assert(type(body) == "string") + return table.concat{ + cmds.relay, + dst, + body, + } +end + +function M.relayed(src, body) + return table.concat{ + cmds.relayed, + src.addr:pack(), + body, + } +end + +function M.auth(n, dst) + local m = n.k + + return table.concat{ + cmds.auth, + wh.packet(n.sk, dst.k, false, m), + } +end + +function M.authed(alias_k) + return table.concat{ + cmds.authed, + alias_k, + } +end + +function M.fragment(id, num, mf, m) + -- mf: More Fragment + + assert(id&0xffff==id) + assert(num&0x7f==num) + + assert(#m <= wh.FRAGMENT_MTU) + + local b = num + if mf then + b = 0x80 | b + end + + return table.concat{ + cmds.fragment, + string.pack(">HB", id, b), + m + } +end + +return M + diff --git a/src/peer.lua b/src/peer.lua new file mode 100644 index 0000000..57dd6d1 --- /dev/null +++ b/src/peer.lua @@ -0,0 +1,55 @@ +local M = {} +local MT = { + __index = {}, +} + +function MT.__index.pack(p) + local r = { p.k, p.addr:pack() } + + --if p.relay ~= nil then + -- r[#r+1] = p.relay.k + -- r[#r+1] = p.relay.addr:pack() + --end + + return table.concat(r) +end + +if DEBUG then + function MT.__newindex(r, attr, val) + local t = "" + if false then + t = debug.traceback() + t = string.match(t, "[^\n]*\n[^\n]*\n\t.*src/([^\n]*) *\n.*") + t = "$(reset)\t(" .. t .. ")" + end + + if rawget(r, attr) ~= val then + printf("$(darkgray)--$(reset) $(bold)%s$(reset).$(blue)%s$(reset) = $(blue)%s"..t, + wh.key(r), attr, dump(val) + ) + end + + rawset(r, attr, val) + end +end + +function MT.__tostring(p) + local r = {} + + local s = string.format("%s@%s", wh.key(p), p.addr) + + if p.is_nated then + s = s .. ' (NAT)' + end + + if p.relay then + s = s .. string.format(" (relayed)") + end + + return s +end + +return function(p) + return setmetatable(p, MT) +end + diff --git a/src/queue.lua b/src/queue.lua new file mode 100644 index 0000000..7b77cc1 --- /dev/null +++ b/src/queue.lua @@ -0,0 +1,83 @@ +local M = {} + +function M.push(q, e) + q = q or {} + q.tail = (q.tail or 0) + 1 + q[q.tail] = e + q.heap = q.heap or q.tail + return q +end + +function M.remove(q, c) + if not q then + return + end + + if not q.tail then + assert(not q.heap) + return + end + + local i = q.heap + while i <= q.tail and i-q.heap < c do + q[i] = nil + i = i + 1 + end + q.heap = i + + if q.tail < q.heap then + q.heap = nil + q.tail = nil + return nil + end + + return q +end + +function M.pop(q) + if not q.heap then + return + end + + local v = q[q.heap] + + M.remove(q, 1) + + return v +end + +function M.length(q) + if not q or not q.heap then + return 0 + end + + assert(q.tail) + + return q.tail - q.heap +end + +local function queue_next(q, k) + if not q or not q.heap then + return + end + assert(q.tail) + + if k == nil then + k = q.heap + else + k = k + 1 + end + + if k > q.tail then + return + end + + return k, q[k] +end + +function M.iter(q) + return queue_next, q, nil +end + +return M + diff --git a/src/search.lua b/src/search.lua new file mode 100644 index 0000000..4732bd0 --- /dev/null +++ b/src/search.lua @@ -0,0 +1,183 @@ +local peer = require('peer') +local packet = require('packet') + +local M = {} + +local function explain(n, s, fmt, ...) + return n:explain("(search %s) " .. fmt, n:key(s.k), ...) +end + +function M.update(n, s, deadlines) + local to_remove = {} + + if s.deadline ~= nil then + -- if search timeout, remove search + if now >= s.deadline then + explain(n, s, "stop") + n:stop_search(s) + return + end + + deadlines[#deadlines+1] = s.deadline + end + + for i, c in ipairs(s.closest) do + local p = c[2] + + local deadline + local st = s.states[p.k] or {retry=0, rep=false} + + if p.k == s.k then + if s.mode == 'ping' then + deadline = (st.req_ts or 0)+st.retry+1 + + if now >= deadline then + explain(n, s, "check if peer %s is alive", n:key(p)) + + n:_sendto{dst=p, m=packet.ping('normal', s.uid1)} + st.retry = st.retry + 1 + st.req_ts = now + st.last_seen = now + deadline = st.req_ts+st.retry+1 + end + elseif s.mode == 'p2p' then + if st.is_punched then + deadline = st.req_ts + st.retry + 1 + elseif st.is_online then + deadline = st.req_ts + wh.MAX_PUNCH_TIMEOUT + else + deadline = (st.req_ts or 0) + st.retry + 1 + end + + if now >= deadline and st.is_punched then + if st.retry < wh.PING_RETRY then + explain(n, s, "is %s alive?", n:key(p)) + + assert(not p.relay, dump{p=p, st=st}) + n:_sendto{dst=p, m=packet.ping('normal', s.uid2)} + st.req_ts = now + st.retry = st.retry + 1 + + deadline = st.req_ts + st.retry + 1 + end + + elseif now >= deadline then + if st.retry > wh.MAX_PUNCH_RETRY then + explain(n, s, "maximum tentative of punch with %s. abort!", n:key(p)) + n:stop_search(s) + return + end + + explain(n, s, "try to punch to %s", n:key(p)) + local p_direct = peer{k=p.k, addr=p.addr} + n:_sendto{dst=p_direct, m=packet.ping('normal', s.uid2)} + n:_sendto{dst=p, m=packet.ping('normal', s.uid1)} + n:_sendto{dst=p, m=packet.ping('direct', s.uid2)} + + st.retry = st.retry + 1 + st.req_ts = now + + if st.is_online then + deadline = st.req_ts + wh.MAX_PUNCH_TIMEOUT + else + deadline = (st.req_ts or 0) + st.retry + 1 + end + end + end + + -- do not send packet to self + elseif p.k == n.k then + -- keep deadline to nil + + -- remove if search has too many nodes + elseif i > s.count then + -- keep deadline to nil + + -- ignore if node does not have any address + elseif not p.addr then + -- keep deadline to nil + + -- no response and not enough retry, send a find and wait retry sec + elseif not st.rep and st.retry <= wh.PING_RETRY then + deadline = (st.req_ts or 0)+st.retry+1 + + if now >= deadline then + n:_sendto{dst=p, m=packet.search(s.k)} + st.retry = st.retry + 1 + st.req_ts = now + st.rep = false + deadline = st.req_ts+st.retry+1 + end + end + + -- save state + s.states[p.k] = st + if deadline ~= nil then + deadlines[#deadlines+1] = deadline + + else + to_remove[#to_remove+1] = i + end + end + + for i = #to_remove, 1, -1 do + table.remove(s.closest, to_remove[i]) + end + + if #s.closest == 0 then + n:stop_search(s) + end +end + +function M.on_pong(n, body, src) + for s in pairs(n.searches) do + if s.k == src.k then + local st = s.states[src.k] + + if s.mode == 'ping' then + if s.uid1 == body then + explain(n, s, "%s is alive!", n:key(src)) + + cpcall(s.cb, s, src, src) + n:stop_search(s) + end + + elseif s.mode == 'p2p' then + if s.uid1 == body then + explain(n, s, "%s is alive!", n:key(src)) + + st.is_online = true + elseif s.uid2 == body and st.is_punched then + explain(n, s, "UDP hole punching is stable with %s!", n:key(src)) + + cpcall(s.cb, s, src, src) + n:stop_search(s) + + elseif not src.relay and s.uid2 == body then + explain(n, s, "punched to %s!", n:key(src)) + st.is_punched = true + st.retry = 0 + s.uid2 = wh.randombytes(8) + end + end + + break + end + end +end + +function M.on_result(n, pks, closest, src) + for s in pairs(n.searches) do + if pks == s.k then + local s_closest = {} + for i, p in ipairs(closest) do + s_closest[i] = {wh.xor(s.k, p.k), p} + end + + n:_extend(s, s_closest, src) + end + end +end + +return M + diff --git a/src/sink-udp.lua b/src/sink-udp.lua new file mode 100644 index 0000000..dbdee6a --- /dev/null +++ b/src/sink-udp.lua @@ -0,0 +1,9 @@ +require'wh' + +local s = wh.socket_udp(wh.address('0.0.0.0', wh.DEFAULT_PORT)) + +while true do + wh.select({s}, {}, {}) + wh.recvfrom(s, 1500) +end + diff --git a/src/time.lua b/src/time.lua new file mode 100644 index 0000000..754ed6a --- /dev/null +++ b/src/time.lua @@ -0,0 +1,78 @@ +local M = {} + +function M.every(deadlines, obj, field_ts, value) + local last_ts = obj[field_ts] + + if now - last_ts >= value then + obj[field_ts] = now + deadlines[#deadlines+1] = now + value + return true + else + deadlines[#deadlines+1] = last_ts + value + return false + end +end + +function M.retry_backoff(obj, retry_field, last_field, retry_max, backoff) + local deadline + + -- wait for deadline + deadline = ( + (obj[last_field] or 0) + + (obj[retry_field] or 0) * backoff + ) + + if now <= deadline then + return false, deadline + end + + -- deadline is reached. If retry_max is reached too, timeout + if retry_max ~= nil and (obj[retry_field] or 0) >= retry_max then + return false, nil + end + + -- action has to be performed; calculate next deadline + obj[last_field] = now + obj[retry_field] = (obj[retry_field] or 0) + 1 + deadline = ( + obj[last_field] + + obj[retry_field] * backoff + ) + + return true, deadline + +end + +local function retry_ping_backoff_deadline(p, retry_every, backoff) + local v = max{ + (p.last_seen or 0) + retry_every, + (p.last_ping or 0) + (p.ping_retry or 0) * backoff, + } + + assert(v ~= nil) + return v +end + +function M.retry_ping_backoff(p, retry_every, retry_max, backoff) + local deadline + deadline = retry_ping_backoff_deadline(p, retry_every, backoff) + + if now <= deadline then + return false, deadline + end + + -- deadline is reached. If retry_max is reached too, timeout + if retry_max ~= nil and (p.ping_retry or 0) >= retry_max then + return false, nil + end + + -- action has to be performed; calculate next deadline + p.last_ping = now + p.ping_retry = (p.ping_retry or 0) + 1 + deadline = retry_ping_backoff_deadline(p, retry_every, backoff) + + return true, deadline +end + +return M + diff --git a/src/tools/authenticate.lua b/src/tools/authenticate.lua new file mode 100644 index 0000000..62c3e4f --- /dev/null +++ b/src/tools/authenticate.lua @@ -0,0 +1,70 @@ +function help() + print('Usage: wh authenticate {|} ') +end + +if arg[2] == 'help' then + return help() +end + +local interface = arg[2] +local k = arg[3] +local alias_sk_path = arg[4] + +if not interface or not k or not alias_sk_path then + return help() +end + +-- XXX +local alias_sk = wh.readsk(alias_sk_path) +if not alias_sk then + printf('cannot load alias private key: %s', alias_sk_path) + return help() +end + +local cmd = string.format('authenticate %s %s', k, alias_sk_path) +local ok, value = pcall(require('ipc').call, interface, cmd) + +if not ok then + printf("error when connecting to WireHub daemon: %s", value) +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return +end + +local via_k, addr, mode, is_nated, relay + +local resp = {} +now = wh.now() +while true do + local r = wh.select({sock}, {}, {}, now+30) + --now = wh.now() + + if not r[sock] then + wh.close(sock) + r={'timeout'} + return -1 -- timeout + end + + local buf = wh.recv(sock, 65535) + if not buf or #buf == 0 then + break + end + + resp[#resp+1] = buf +end +wh.close(sock) + +resp = table.concat(resp) +if string.match(resp, 'authenticated!') then + return 0 +else + resp = string.match(resp, 'failed: (.*)\n') + if resp then + printf('%s', resp) + end + return -1 +end + diff --git a/src/tools/cli.lua b/src/tools/cli.lua new file mode 100644 index 0000000..f25076f --- /dev/null +++ b/src/tools/cli.lua @@ -0,0 +1,131 @@ +function help() + print( + "Usage: wh []\n" .. + "\n" .. + "Available setup subcommands\n" .. + " addconf: Appends a configuration file to a WireHub network\n" .. + " clearconf: Clear the current network configuration\n" .. + " down: Detach a Wireguard interface from a WireHub network (daemon)\n" .. + " genkey: Generates a new private key for a WireHub network\n" .. + " pubkey: Reads a private key from stdin and writes a public key to stdout\n" .. + " set: Change the current network configuration\n" .. + " setconf: Applies a configuration file to a WireHub network\n" .. + " showconf: Shows the current configuration of a given WireHub network\n".. + " up: Create a WireHub network and interface (daemon)\n" .. + " workbit: Print workbits for a given WireGuard public key\n" .. + "\n" .. + "Available network subcommands\n" .. + " auth: Authenticate with an alias' private key\n" .. + " connect: Connect to a WireHub peer\n" .. + " connect-p2p: Connect to a WireHub peer\n" .. + " forget: Forget one WireHub peer\n" .. + " lookup Lookup for a WireHub peer\n" .. + " ping: Ping a WireHub peer\n" .. + " resolve: Resolve a hostname among all WireHub networks\n" .. + "Available status subcommands\n" .. + " show: Shows the current configuration\n" .. + "\n" .. + "Available advanced subcommands\n" .. + " completion: Auto-completion helper\n" .. + " ipc: Send a IPC command to a WireHub daemon\n" .. + " orchid: Print the ORCHID IPv6 of a given node\n" .. + "" + ) +end + +require('wh') +require('helpers') + +SUBCMDS = { + -- private methods + '_completion', + + -- public methods + "addconf", + "authenticate", + "clearconf", + "completion", + "connect", + "down", + "forget", + "genkey", + "help", + "ipc", + "lookup", + "orchid", + "ping", + "pubkey", + "resolve", + "set", + "setconf", + "show", + "showconf", + "up", + "workbit", +} +for _, k in ipairs(SUBCMDS) do SUBCMDS[k] = true end + +wh.ipc_prepare() + +local cmd = arg[1] or 'show' + +if cmd == 'auth' then cmd = 'authenticate' end +if cmd == '_completion' then cmd = 'completion' end + +if not SUBCMDS[cmd] then + printf("Invalid subcommand: `%s'", cmd) + cmd = 'help' +end + +if cmd == 'connect-p2p' or cmd == 'ping' or cmd == 'lookup' then + cmd = 'search' +end + +if cmd == 'addconf' or cmd == 'clearconf' or cmd == 'setconf' then + cmd = 'conf' +end + +-- secret cannot be revealed except in these modes +if cmd ~= 'genkey' and + cmd ~= 'genconf' + then + wh.reveal_secret = nil +end + +if cmd == 'help' or cmd == '--help' then + return help() +end + +if cmd == 'up' then + local r = wh.wg.check() + + if r == 'oldkernel' then + printf( + "==========\n" .. + "$(red)$(bold)Sorry, Linux kernel must be >%s.$(reset)\n" .. + "More info: https://www.wireguard.com/install/#kernel-requirements\n" .. + "==========$(reset)\n", + string.join('.', wh.wg.LINUX_MINVER) + ) + + elseif r == 'notloaded' then + printf( + "==========\n" .. + "$(red)$(bold)WireGuard module is not loaded!$(reset)\n" .. + "\n" .. + " $(bold)You might want to install WireGuard first!$(reset)\n" .. + " https://www.wireguard.com/install/\n" .. + "==========$(reset)\n" + ) + end +end + +disable_globals() + +local retcode = cpcall(require, 'tools.' .. cmd) +if retcode == true then retcode = 0 end + +_do_atexits() +status(nil) +os.exit(retcode or 0) + diff --git a/src/tools/completion.lua b/src/tools/completion.lua new file mode 100644 index 0000000..3a74429 --- /dev/null +++ b/src/tools/completion.lua @@ -0,0 +1,186 @@ +if arg[1] == 'completion' then + function help() + printf( + 'Usage: wh completion {get-bash}\n' .. + '\n' .. + 'To enable auto-completion with `bash-completion`, run:\n' .. + ' wh completion get-bash | sudo tee /usr/share/bash-completion/completions/wh\n' + ) + return + end + + if arg[2] == nil or arg[2] == 'help' then + return help() + end + + if arg[2] == "get-bash" then + print( + '_wh()\n' .. + '{\n' .. + ' local opts cur\n' .. + ' _init_completion || return\n' .. + '\n' .. + ' opts=`wh _completion ${COMP_CWORD} ${COMP_WORDS[@]}`\n' .. + '\n' .. + ' #if [[ $cur == -* ]] ; then\n' .. + ' COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )\n' .. + ' return 0\n' .. + ' #fi\n' .. + '}\n' .. + 'complete -F _wh wh\n' + ) + else + printf("unknown argument: %s", arg[2]) + return -1 + end + + return +end + +assert(arg[1] == '_completion') + +if #arg < 3 then + return +end + +local opt_count = 0 +local function opt(s) + print(s) + opt_count = opt_count + 1 +end + +local function optlist(l) + table.sort(l) + for _, v in ipairs(l) do opt(v) end +end + +local function listpeers(interface) + local ok, value = pcall(require('ipc').call, interface, 'list') + if not ok then + return + end + + local sock = value + if not sock then + return + end + + local trusted = {} + local untrusted = {} + local buf = "" + now = wh.now() + while true do + local r = wh.select({sock}, {}, {}, now+1) + now = wh.now() + + if not r[sock] then + break + end + + buf = buf .. (wh.recv(sock, 65535) .. "") + if #buf == 0 then + break + end + + while true do + local name, trust, i = string.match(buf, '([^%s]+)%s+([^%s]+)\n()') + + if not name or not trust then + break + end + buf = string.sub(buf, i) + + if trust == 'trusted' then + trusted[#trusted+1] = name + else + untrusted[#untrusted+1] = name + end + end + end + wh.close(sock) + + return trusted, untrusted +end + +local cur_idx = tonumber(arg[2]) +local cmd = {} +for i = 3, #arg do cmd[#cmd+1] = arg[i] end + +if cur_idx <= 1 then + local public_subcmds = {} + for _, v in ipairs(SUBCMDS) do + local is_private = string.sub(v, 1, 1) == '_' + if not is_private then + public_subcmds[#public_subcmds+1] = v + end + end + + optlist(public_subcmds) + return +end + +local subcmd = cmd[2] + +if cur_idx == 2 then + if ( + subcmd == 'authenticate' or + subcmd == 'connect' or + subcmd == 'down' or + subcmd == 'forget' or + subcmd == 'ipc' or + subcmd == 'lookup' or + subcmd == 'ping' or + subcmd == 'show' + ) then + local interfaces = wh.ipc_list() + + if subcmd == 'show' and #interfaces > 1 then + opt('all') + end + + optlist(wh.ipc_list()) + return + end + + if ( + subcmd == 'addconf' or + subcmd == 'clearconf' or + subcmd == 'genkey' or + subcmd == 'orchid' or + subcmd == 'set' or + subcmd == 'show' or + subcmd == 'showconf' or + subcmd == 'up' or + subcmd == 'workbit' + ) then + optlist(wh.listconf()) + end +end + +if cur_idx == 3 then + if ( + subcmd == 'connect' or + subcmd == 'forget' or + subcmd == 'lookup' or + subcmd == 'ping' + ) then + local trusted, untrusted = listpeers(cmd[3]) + + if trusted and untrusted then + optlist(trusted) + if cmd[cur_idx+1] or #trusted == 0 then + optlist(untrusted) + end + end + end + + if subcmd == 'show' then + opt('all') + opt('light') + end +end + +if cur_idx == 2 and opt_count ~= 1 then + opt('help') +end + diff --git a/src/tools/conf.lua b/src/tools/conf.lua new file mode 100644 index 0000000..356a7bc --- /dev/null +++ b/src/tools/conf.lua @@ -0,0 +1,105 @@ +local cmd = arg[1] + +function help() + if cmd == 'clearconf' then + printf("Usage: wh %s ", cmd) + else + printf("Usage: wh %s {|'-'}", cmd) + end +end + +if arg[2] == 'help' then + return help() +end + +local name = arg[2] + +if not name then + return help() +end + +if cmd == 'clearconf' then + wh.writeconf(name, nil) + return +end + +local conf_s = wh.readconf(name) +local conf +if conf_s then + conf = wh.fromconf(conf_s) +end + +if not arg[3] then + return help() +end + +local conf_filepath = arg[3] +local conf_fh +if conf_filepath == '-' then + conf_fh = io.stdin +else + conf_fh = io.open(conf_filepath, 'r') +end + +local conf_s = {} +while true do + local c = conf_fh:read() + if c == nil then + break + end + + conf_s[#conf_s+1] = c .. '\n' +end +conf_s = table.concat(conf_s) + +if conf_fh ~= io.stdin then + conf_fh:close() +end + +local updated_conf = wh.fromconf(conf_s) + +if not updated_conf then + printf('Invalid configuration') + return -1 +end + +if cmd == 'setconf' or conf == nil then + conf = updated_conf + conf.name = name +else + assert(cmd == 'addconf') + + conf.workbit = updated_conf.workbit or conf.workbit + + local to_add = {} + for _, up in ipairs(updated_conf.peers) do + local found = false + for _, p in ipairs(conf.peers) do + if p.k == up.k then + p.hostname = up.hostname or p.hostname + if up.is_router ~= nil then p.is_router = up.is_router end + if up.is_gateway ~= nil then p.is_gateway = up.is_gateway end + if up.trust ~= nil then p.trust = up.trust end + p.ip = up.ip or p.ip + if up.bootstrap ~= nil then p.bootstrap = up.bootstrap end + p['allowed-ips'] = up['allowed-ips'] or p['allowed-ips'] + found = true + break + end + end + + if not found then + to_add[#to_add+1] = up + end + end + + for _, p in ipairs(to_add) do + conf.peers[#conf.peers+1] = p + end +end + +local conf = wh.toconf(conf) +wh.fromconf(conf) -- check conf + +wh.writeconf(name, conf) + diff --git a/src/tools/connect.lua b/src/tools/connect.lua new file mode 100644 index 0000000..4552715 --- /dev/null +++ b/src/tools/connect.lua @@ -0,0 +1,127 @@ +function help() + printf('Usage: wh connect ') +end + +if arg[2] == 'help' then + return help() +end + +local interface = arg[2] +local k = arg[3] + +if not interface or not k then + return help() +end + +local ipc_cmd = string.format('connect %s', k) +local ok, value = pcall(require('ipc').call, interface, ipc_cmd) +if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return +end + +local b64k, via_b64k, addr, mode, is_nated, relay + +now = wh.now() +while true do + local r = wh.select({sock}, {}, {}, now+30) + now = wh.now() + + if not r[sock] then + wh.close(sock) + printf("timeout") + return -1 -- timeout + end + + local buf = wh.recv(sock, 65535) + + if not buf or #buf == 0 then + break + end + + b64k, mode, addr, via_b64k = string.match(buf, '([^ ]+) ([^ ]+) ([^ ]+) ([^ ]+)\n') + + if not via_b64k or not addr or not mode then + printf("$(red)bad format: %s$(reset)", buf) + wh.close(sock) + return -1 + end + + addr = wh.address(addr, 0) + + if mode == '(direct)' then + mode = 'direct' + elseif mode == '(nat)' then + mode = 'nat' + is_nated = true + else + relay = wh.fromb64(mode) + mode = 'relay' + end +end +wh.close(sock) + +local found +local m = {} +if mode == 'relay' then + found = true + m[#m+1] = string.format('relay %s', wh.key(relay)) +elseif mode == 'nat' then + found = true + m[#m+1] = string.format('nat %s', addr) +elseif mode == 'direct' then + found = true + m[#m+1] = string.format('direct %s', addr) +else + found = false + m[#m+1] = "not found" +end + +m = table.concat(m) + +if cmd == 'lookup' then + print(m) + return found and 0 or -1 + +elseif cmd == 'ping' then + if found then + printf('ping %s: time=%.2fms', m, (now-time_before_ping)*1000.0) + else + print('offline') + end + +elseif cmd == 'p2p' then + local k = wh.fromb64(b64k) + if mode == 'relay' then + printf('unable to open a p2p connection') + return -1 + end + + local peer = { + public_key=k, + endpoint=addr, + } + + if mode == 'nat' then + peer.persistent_keepalive_interval = wh.NAT_TIMEOUT + end + + + local ok, err = pcall(wh.wg.set, {name=interface, peers={peer}}) + + if not ok then + printf('error when setting up wireguard interface %s: %s', interface, err) + return -1 + end + + printf('connected to %s', m) +end + +return found and 0 or -1 + diff --git a/src/tools/down.lua b/src/tools/down.lua new file mode 100644 index 0000000..0c99a31 --- /dev/null +++ b/src/tools/down.lua @@ -0,0 +1,51 @@ +function help() + print('Usage: wh down ') +end + +local interface = arg[2] +if not interface then + return help() +end + +local ipc=require'ipc' + +local ok, value = pcall(ipc.call, interface, 'down') + +if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return -1 +end + +local ret = -1 + +now = wh.now() +while true do + local r = wh.select({sock}, {}, {}, now+30) + now = wh.now() + + if not r[sock] then + printf("timeout") + break + end + + local buf = wh.recv(sock, 65535) + + if not buf or #buf == 0 then + break + end + + if buf == "OK\n" then + ret = 0 + break + end +end +wh.close(sock) + +return ret + diff --git a/src/tools/find.lua b/src/tools/find.lua new file mode 100644 index 0000000..ab8d253 --- /dev/null +++ b/src/tools/find.lua @@ -0,0 +1,59 @@ +function help() + print('Usage: wh find ') +end + +local interface = arg[2] +local b64k = arg[3] + +if not interface or not b64k then + return help() +end + +local wg = wh.wg.get(interface) + +if not wg then + print("Unable to access interface: No such device") + return +end + +local ok, k = pcall(wh.fromb64, b64k) + +if not ok then + print("Invalid key") + return +end + +local ipc=require'ipc' + +local ok, value = pcall(ipc.call, interface, 'search ' .. b64k) + +if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return +end + +local buf = '' +while true do + local r = wh.select({sock}, {}, {}, 60) + + if #r == 0 then + break + end + + local chunk = wh.recv(sock, 65535) + buf = buf .. chunk + + local newline_idx = string.find(buf, '\n') + if newline_idx then + local line = string.sub(buf, 1, newline_idx-1) + buf = string.sub(buf, newline_idx+1) + + on_new_endpoint(line) + end +end diff --git a/src/tools/forget.lua b/src/tools/forget.lua new file mode 100644 index 0000000..a8b4ccb --- /dev/null +++ b/src/tools/forget.lua @@ -0,0 +1,52 @@ +function help() + printf('Usage: wh forget ') +end + +if arg[2] == 'help' then + return help() +end + +local interface = arg[2] +local k = arg[3] + +if not interface or not k then + return help() +end + +local ipc_cmd = string.format('forget %s', k) +local ok, value = pcall(require('ipc').call, interface, ipc_cmd) +if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return +end + +local b64k, via_b64k, addr, mode, is_nated, relay + +now = wh.now() +while true do + local r = wh.select({sock}, {}, {}, now+30) + now = wh.now() + + if not r[sock] then + wh.close(sock) + printf("timeout") + return -1 -- timeout + end + + local buf = wh.recv(sock, 65535) + + if not buf or #buf == 0 then + break + end + + printf("$(red)bad format: %s$(reset)", buf) + break +end +wh.close(sock) + diff --git a/src/tools/genkey.lua b/src/tools/genkey.lua new file mode 100644 index 0000000..9160962 --- /dev/null +++ b/src/tools/genkey.lua @@ -0,0 +1,36 @@ +function help() + print('Usage: wh genkey [threads ]]') +end + +if arg[2] == 'help' then + return help() +end + +local name = arg[2] + +if not name then + return help() +end + +local opts = parsearg(3, { + threads=tonumber, +}) + +if not opts then + return help() +end + +local conf = wh.fromconf(wh.readconf(name)) + +if not conf then + printf("Unknown network `%s'", name) + return help() +end + +local sign_sk, sign_k, sk, k = wh.genkey( + conf.namespace or 'public', + conf.workbit or 0, + opts.threads or 0 +) + +print(wh.tob64(wh.revealsk(sk), 'wg')) diff --git a/src/tools/ipc.lua b/src/tools/ipc.lua new file mode 100644 index 0000000..7e4f976 --- /dev/null +++ b/src/tools/ipc.lua @@ -0,0 +1,60 @@ +function help() + print('Usage: wh ipc ') +end + +local interface = arg[2] + +if not interface then + return help() +end + +local s = {} +local idx = 3 +while arg[idx] do + if arg[idx] == '-' then + s[#s+1] = io.stdin:read() + s[#s+1] = ' ' + break + else + s[#s+1] = arg[idx] + s[#s+1] = ' ' + idx = idx + 1 + end +end + +if #s == 0 then + return help() +end + +s[#s] = nil +s = table.concat(s) + +local ipc=require'ipc' + +local ok, value = pcall(ipc.call, interface, s) + +if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return +end + +while true do + wh.select({sock}, {}, {}, nil) + + local buf = wh.recv(sock, 65535) + + if not buf or #buf == 0 then + break + end + + io.stdout:write(buf) + io.stdout:flush() +end + +wh.close(sock) diff --git a/src/tools/orchid.lua b/src/tools/orchid.lua new file mode 100644 index 0000000..c7a88d5 --- /dev/null +++ b/src/tools/orchid.lua @@ -0,0 +1,46 @@ +function help() + print('Usage: wh orchid ') +end + +if arg[2] == 'help' then + return help() +end + +local name = arg[2] + +if not name then + return help() +end + +local conf = wh.fromconf(wh.readconf(name)) + +if not conf then + printf("Unknown network `%s'", name) + return help() +end + +local b64k = arg[3] + +if b64k == '-' then + b64k = io.stdin:read() +end + +if b64k == nil then + return help() +end + +local ok, value = pcall(wh.fromb64, b64k) + +if not ok then + printf("Invalid key: %s", value) + return +end + +local k = value + +local addr = wh.orchid(conf.namespace or 'public', k, 0) + +addr = tostring(addr) +addr = string.sub(addr, 2, string.find(addr, ']')-1) + +print(addr) diff --git a/src/tools/pubkey.lua b/src/tools/pubkey.lua new file mode 100644 index 0000000..eb7b0de --- /dev/null +++ b/src/tools/pubkey.lua @@ -0,0 +1,22 @@ +function help() + print('Usage: wh pubkey') +end + +if arg[2] == 'help' then + return help() +end + +local b64k = io.stdin:read() + +local ok, value = pcall(wh.fromb64, b64k, 'wg') + +if not ok then + printf("Invalid key: %s", value) + return +end + +local sk = value + +local k = wh.publickey(sk) + +print(wh.tob64(k)) diff --git a/src/tools/resolve.lua b/src/tools/resolve.lua new file mode 100644 index 0000000..799b661 --- /dev/null +++ b/src/tools/resolve.lua @@ -0,0 +1,89 @@ +function help() + printf('Usage: wh resolve ') +end + +if not arg[2] or arg[2] == 'help' then + return help() +end + +local name = arg[2] + +local function resolve(cmd, name) + local resolv = {} + + for _, interface in ipairs(wh.ipc_list()) do + local ipc_cmd = string.format('%s %s', cmd, name) + local ok, value = pcall(require('ipc').call, interface, ipc_cmd) + if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return + end + + local sock = value + if not sock then + print("Interface not attached to WireHub") + return + end + + now = wh.now() + while true do + local r = wh.select({sock}, {}, {}, now+1) + now = wh.now() + + if not r[sock] then + wh.close(sock) + printf("timeout") + return -1 -- timeout + end + + local buf = wh.recv(sock, 65535) + + if not buf or #buf == 0 then + break + end + + local b64k, hostname, ip = string.match(buf, '([^%s]+)\t([^%s]*)\t([^%s]*)\n') + + if not b64k then + break + end + + if hostname and #hostname == 0 then + hostname = nil + end + + if ip and #ip == 0 then + ip = nil + end + + if hostname or ip then + resolv[#resolv+1] = {interface, b64k, hostname, ip} + end + break + end + wh.close(sock) + end + + return resolv +end + +local is_host = true +local resolv = resolve('gethostbyname', name) + +if #resolv == 0 then + is_host = false + resolv = resolve('gethostbyaddr', name) +end + +if #resolv == 1 then + local r = resolv[1][is_host and 4 or 3] + if r then + print(r) + end +elseif #resolv >= 2 then + print("multiple results") + for _, v in ipairs(resolv) do + printf(" %s: %s", resolv[1], resolv[2]) + end + return -1 +end diff --git a/src/tools/search.lua b/src/tools/search.lua new file mode 100644 index 0000000..9a0f905 --- /dev/null +++ b/src/tools/search.lua @@ -0,0 +1,135 @@ +local cmd = arg[1] + +function help() + printf('Usage: wh %s ', cmd) +end + +if arg[2] == 'help' then + return help() +end + +local interface = arg[2] +local k = arg[3] + +if not interface or not k then + return help() +end + +local time_before_ping = wh.now() + +if cmd == 'connect-p2p' then cmd = 'p2p' end + +local ipc_cmd = string.format('%s %s', cmd, k) +local ok, value = pcall(require('ipc').call, interface, ipc_cmd) +if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return +end + +local sock = value +if not sock then + print("Interface not attached to WireHub") + return +end + +local b64k, via_b64k, addr, mode, is_nated, relay + +now = wh.now() +while true do + local r = wh.select({sock}, {}, {}, now+30) + now = wh.now() + + if not r[sock] then + wh.close(sock) + printf("timeout") + return -1 -- timeout + end + + local buf = wh.recv(sock, 65535) + + if not buf or #buf == 0 then + break + end + + b64k, mode, addr, via_b64k = string.match(buf, '([^ ]+) ([^ ]+) ([^ ]+) ([^ ]+)\n') + + if not via_b64k or not addr or not mode then + printf("$(red)bad format: %s", buf) + wh.close(sock) + return -1 + end + + addr = wh.address(addr, 0) + + if mode == '(direct)' then + mode = 'direct' + elseif mode == '(nat)' then + mode = 'nat' + is_nated = true + else + relay = wh.fromb64(mode) + mode = 'relay' + end +end +wh.close(sock) + +local found +local m = {} +if cmd == 'ping' then + found = via_b64k and via_b64k == b64k +elseif mode == 'relay' then + found = true + m[#m+1] = string.format('relay %s', wh.key(relay)) +elseif mode == 'nat' then + found = true + m[#m+1] = string.format('nat %s', addr) +elseif mode == 'direct' then + found = true + m[#m+1] = string.format('direct %s', addr) +else + found = false + m[#m+1] = "not found" +end + +m = table.concat(m) + +if cmd == 'lookup' then + print(m) + return found and 0 or -1 + +elseif cmd == 'ping' then + if found then + printf('ping %s: time=%.2fms', m, (now-time_before_ping)*1000.0) + else + print('offline') + end + +elseif cmd == 'p2p' then + local k = wh.fromb64(b64k) + if mode == 'relay' then + printf('unable to open a p2p connection') + return -1 + end + + local peer = { + public_key=k, + endpoint=addr, + } + + if mode == 'nat' then + peer.persistent_keepalive_interval = wh.NAT_TIMEOUT + end + + + local ok, err = pcall(wh.wg.set, {name=interface, peers={peer}}) + + if not ok then + printf('error when setting up wireguard interface %s: %s', interface, err) + return -1 + end + + printf('connected to %s', m) +end + +return found and 0 or -1 + diff --git a/src/tools/set.lua b/src/tools/set.lua new file mode 100644 index 0000000..63072f0 --- /dev/null +++ b/src/tools/set.lua @@ -0,0 +1,181 @@ +function help() + print('Usage: wh set ' .. + '[namespace ] ' .. + '[subnet ] ' .. + '[workbit ] ' .. + '[{ peer | ' .. + '[name ] } ' .. + '[alias ] ' .. + '[bootstrap {yes|no}] ' .. + '[allowed-ips /[,/]...] ' .. + '[endpoint :] ' .. + '[gateway {yes|no}] ' .. + '[ip ] ' .. + '[router {yes|no}] ' .. + '[untrusted] ' .. + '[remove] ' .. + ']' + ) +end + +if arg[2] == 'help' then + return help() +end + +local name = arg[2] + +if not name then + return help() +end + +local function tohost(n) + local m = string.match(n, '([%a%d%.%-]+)') + if m ~= n then + n = nil + end + return n +end + +local function tosubnet(x) -- XXX TODO + return tostring(x) +end + +local function split_comma(x, cb) + local r = {} + for subnet in string.gmatch(x, "([^,]+)") do + if not subnet then + return nil + end + r[#r+1] = subnet + end + + return r +end + +local function fromb64_or_false(x) + if x == 'none' or x == 'remove' then + return false + end + + return wh.fromb64(x) +end + +local opts = parsearg(3, { + alias=fromb64_or_false, + ['allowed-ips']=function(x) return split_comma(x, tosubnet) end, + bootstrap=parsebool, + endpoint=function(s) return wh.address(s, wh.DEFAULT_PORT) end, + gateway=parsebool, + name=tohost, + ip=function(s) + -- XXX + if not wh.address(s, 0) then + return nil + end + + return s + end, + namespace=tostring, + peer=wh.fromb64, + remove=true, + router=parsebool, + subnet=tosubnet, + untrusted=true, + workbit=tonumber, +}) + +if not opts then + return help() +end + +if not opts.peer and not opts.name then + for _, o in ipairs{ + 'allowed-ips', + 'bootstrap', + 'endpoint', + 'gateway', + 'ip', + 'remove', + 'router', + 'untrusted', + } do + if opts[o] then + printf('Invalid argument: %s', o) + return help() + end + end +end + +local conf = wh.fromconf(wh.readconf(name)) + +conf = conf or {peers={}} + +local k_map = {} +local host_map = {} +for i, p in ipairs(conf.peers) do + if p.k then k_map[p.k] = i end + if p.hostname then host_map[p.hostname] = i end +end + +conf.name = name +conf.namespace = opts.namespace or conf.namespace or 'public' +conf.workbit = opts.workbit or conf.workbit +conf.subnet = opts.subnet or conf.subnet + +local k_idx = opts.peer and k_map[opts.peer] +local host_idx = opts.name and host_map[opts.name] + +if opts.peer or opts.name then + if opts.peer and opts.name and k_idx ~= host_idx and host_idx then + printf('Host already exists: %s', opts.name) + return help() + end + + local idx = k_idx or host_idx + + if opts.remove then + if idx then + table.remove(conf.peers, idx) + end + else + if not idx then + idx = #conf.peers+1 + conf.peers[idx] = {} + end + local p = conf.peers[idx] + + if opts.peer then + -- check workbit is respected + local wb = wh.workbit(opts.peer, conf.namespace) + + if wb < (conf.workbit or 0) then + printf("Insufficient workbit: %d (minimum is %d)", wb, conf.workbit or 0) + return + end + + p.k = opts.peer + end + + p.addr = opts.endpoint or p.addr + + if opts.alias then + p.alias = opts.alias + elseif opts.alias == false then + p.alias = nil + end + + p['allowed-ips'] = opts['allowed-ips'] or p['allowed-ips'] + if opts.gateway then p.is_gateway = opts.gateway end + p.hostname = opts.name or p.hostname + p.ip = opts.ip and wh.address(opts.ip) or p.ip + if opts.router then p.is_router = opts.router end + if p.trust == nil then p.trust = true end + if opts.untrusted then p.trust = false end + if opts.bootstrap then p.bootstrap = opts.bootstrap end + end +end + +local conf = wh.toconf(conf) +wh.fromconf(conf) -- check conf + +wh.writeconf(name, conf) diff --git a/src/tools/show.lua b/src/tools/show.lua new file mode 100644 index 0000000..571c472 --- /dev/null +++ b/src/tools/show.lua @@ -0,0 +1,86 @@ +function help() + print("Usage: wh show [{ | all} [{light|all}] ] ") +end + +local ipc = require'ipc' + +local all = {} +local interface = arg[2] + +local mode = 'light' +if interface then + mode = arg[3] + if mode == nil then mode = 'light' end + local mode_ok = ({ + light=true, + all=true, + })[mode] + + if not mode_ok then + printf('invalid mode: %s', mode) + return help() + end +end + +if interface == 'all' then interface = nil end + +local function call(interface, cmd) + local ok, value = pcall(ipc.call, interface, cmd) + + if not ok then + printf("%s\nError when connecting to WireHub daemon.", value) + return + end + + local sock = value + if not sock then + return + end + + local buf = {} + while true do + local r = wh.select({sock}, {}, {}, 1) + + if not r[sock] then + r[#r+1] = '\n(daemon timeout)' + break + end + + local chunk = wh.recv(sock, 65535) + if not chunk or #chunk == 0 then + break + end + + buf[#buf+1] = chunk + end + + wh.close(sock) + + return table.concat(buf) +end + +local names +local whs = wh.ipc_list() +for _, v in ipairs(whs) do whs[v] = true end +if interface then + if not whs[interface] then + printf('invalid interface: %s', interface) + return help() + end + + names = {interface} +else + names = whs +end + +table.sort(names) + +for _, name in ipairs(names) do + local cmd = string.format('describe %s', mode) + local info = call(name, cmd) + + if info then + printf("interface $(bold)%s$(reset), %s", name, info) + end +end + diff --git a/src/tools/showconf.lua b/src/tools/showconf.lua new file mode 100644 index 0000000..aec5956 --- /dev/null +++ b/src/tools/showconf.lua @@ -0,0 +1,29 @@ +function help() + print('Usage: wh showconf ') +end + +if arg[2] == 'help' then + return help() +end + +local name = arg[2] + +if not name then + return help() +end + +local conf_s = wh.readconf(name) + +if not conf_s then + return +end + +local conf = wh.fromconf(conf_s) + +if not conf then + print('Invalid configuration') + return +end + +print(wh.toconf(conf)) + diff --git a/src/tools/up.lua b/src/tools/up.lua new file mode 100644 index 0000000..33c7eeb --- /dev/null +++ b/src/tools/up.lua @@ -0,0 +1,335 @@ +local cmd = arg[1] + +function help() + printf('Usage: wh up {interface | [private-key ] [listen-port ]} [mode {unknown | direct | nat}]') +end + +-- XXX move this after configuration is checked +do + local FG = os.getenv('FG') + local FG = FG == 'y' or FG == '1' + if not FG then + wh.daemon() + end +end + +do + local log_path = os.getenv('WH_LOGPATH') + if log_path then + local log_fh = io.open(log_path, "a") + + function log(s) + log_fh:write(s .. '\n') + log_fh:flush() + end + + atexit(log_fh.close, log_fh) + end +end +-- XXX ---------------------------------------- + +local name = arg[2] + +if not name then + return help() +end + +local private_key_path +local wg +local opts = parsearg(3, { + interface = function(s) + wg = wh.wg.get(s) + return s + end, + ["private-key"] = function(path) + private_key_path = path + return wh.readsk(path) + end, + ["listen-port"] = tonumber, + mode = function(s) + if s ~= 'unknown' and + s ~= 'direct' and + s ~= 'nat' then + s = nil + end + return s + end, +}) + +if not opts then + return help() +end + +local conf = wh.fromconf(wh.readconf(name)) + +if not conf then + printf("Unknown network `%s'", name) + return help() +end + +if not opts.interface and not opts['private-key'] then + printf('no key specified. generates an ephemeron one. this might be long...') + local _, _, sk, k = wh.genkey( + conf.namespace, + conf.workbit or 0, + 0 + ) + + opts['private-key'] = sk +end + +-- now is a global +now = wh.now() +local start_time = now + +-- + +status("starting...") + +local private_key +local listen_port +local n + +-- + +if opts.interface then + if not conf.subnet then + printf('subnetwork not defined: %s', name) + return help() + end + + if wg then + if not wg.private_key or not wg.public_key then + printf("Interface %s does not have a private key", opts.interface) + return + end + + if not wg.listen_port then + wg.listen_port = wh.DEFAULT_PORT + -- XXX + execf('wg set %s listen-port %d', opts.interface, wg.listen_port) + end + + local wb = wh.workbit(wg.public_key, conf.namespace) + if wb < (conf.workbit or 0) then + printf("Insufficient workbit: %d (minimum is %d)", wb, conf.workbit or 0) + return + end + + private_key = wg.private_key + listen_port = wg.listen_port + else + error("ephemeron wireguard interface not implemented") + --execf("ip link add dev %s type wireguard", opts.interface) + --execf("wg set %s private-key %s listen-port 0", opts.interface, skpath) + --execf("ip link set %s up", opts.interface) + end +else + private_key = opts['private-key'] + listen_port = opts['listen-port'] or wh.DEFAULT_PORT + + if listen_port == 0 then + listen_port = randomrange(1024, 65535) + end + + assert(private_key) +end + +-- + +local n_log = tonumber(os.getenv('LOG')) or 0 + +n = wh.new{ + name=name, + sk=private_key, + port=listen_port, + port_echo=listen_port+1, -- XXX ? + namespace=conf.namespace, + workbit=conf.workbit, + mode=opts.mode, + log=n_log, + ns={ + require('ns_keybase'), + }, +} + +atexit(n.close, n) + +-- + +local ipc_conn +local handlers = require('handlers_ipc')(n) +ipc_conn = require('ipc').bind(opts.interface or wh.key(n.k), handlers) +atexit(ipc_conn.close, ipc_conn) + +-- + +for _, pconf in ipairs(conf.peers) do + -- do not bootstrap with self + local p + if pconf.k then + p = n.kad:touch(pconf.k) + p.addr = pconf.addr + elseif pconf.alias then + p = n.kad:touch(pconf.alias) + p.alias = true + end + + if p then + p.trust = pconf.trust + p.hostname = pconf.hostname + p.ip = pconf.ip + p.is_gateway = pconf.is_gateway + p.is_router = pconf.is_router + p.bootstrap = pconf.bootstrap + + if false then + local m = {} + m[#m+1] = string.format("add %s %s", + p.alias and 'alias' or 'peer', + p.hostname or wh.tob64(p.k) + ) + + if p.is_router then m[#m+1] = " (router)" end + if p.is_gateway then m[#m+1] = " (gateway)" end + + printf(table.concat(m)) + end + + + if p.bootstrap then + printf("bootstrap with $(yellow)%s$(reset) (%s)", wh.tob64(p.k), p.addr) + end + end +end + +--[[ +local s +if n.mode == 'unknown' then + n:detect_nat(nil, function(mode) + -- mode=blocked, cone, direct, offline + printf("$(magenta)NAT TYPE: %s", mode) + + n.is_nated = mode ~= 'direct' + + s = n:search(n.k, 'lookup') + end) +else + s = n:search(n.k, 'lookup') +end +--]] + +-- log + +do + local msg = {"wirehub listening as $(yellow)", wh.tob64(wh.publickey(private_key)), "$(reset)"} + if DEVICE then msg[#msg+1] = string.format(" for $(yellow)%s$(reset)", DEVICE) end + msg[#msg+1] = string.format(" on $(yellow)%d$(reset) (port echo %d)", n.port, n.port_echo) + msg[#msg+1] = string.format(" (mode: $(yellow)%s$(reset))", n.mode) + printf(table.concat(msg)) +end + +-- main loop + +if opts.interface then + n.lo = require('lo'){ + n = n, + auto_connect = true, + interface = opts.interface .. '-rl' + } +end + +local wgsync +if n.lo and true then + wgsync = require('wgsync').new{ + n = n, + interface = opts.interface, + subnet = conf.subnet, + } + + atexit(wgsync.close, wgsync) +end + +local self = {} + +local LOADING_CHARS = {'-', '\\', '|', '/'} +local LOADING_CHARS = {'▄▄', '█▄', '█ ', '█▀', '▀▀', '▀█', ' █', '▄█'} +local lc_idx = 1 + +now = wh.now() +while n.running do + local socks = {} + local timeout + do + local deadlines = {} + deadlines[#deadlines+1] = n:update(socks) + + if ipc_conn then + deadlines[#deadlines+1] = ipc_conn:update(socks) + end + + if wgsync then + deadlines[#deadlines+1] = wgsync:update(socks) + end + + -- + + local deadline = min(deadlines) + + if deadline ~= nil then + timeout = deadline-now + if timeout < 0 then timeout = 0 end + end + end + + do + if self.ip ~= n.p.ip then + self.ip = n.p.ip + + local ip_subnet = ( + self.ip:addr() .. + string.sub(conf.subnet, string.find(conf.subnet, '/'), -1) + ) + + printf('$(green)new ip: %s$(reset)', ip_subnet) + execf('ip addr add %s dev %s', ip_subnet, opts.interface) + end + + if self.hostname ~= n.p.hostname then + + end + end + + status( + '%s waiting (fds: %d, timeout: %s)', + LOADING_CHARS[lc_idx], + #socks, + timeout and string.format('%.1fs', timeout) or '(none)' + ) + + n.kad:clear_touched() + + local r + do + -- Not sure why, but one pcall is not enough to catch the "interrupted" + -- launched by lua if the user press CTRL+C + pcall(pcall, function() r = wh.select(socks, {}, {}, timeout) end) + if not r then break end + now = wh.now() + end + + do + lc_idx = (lc_idx % (#LOADING_CHARS)) + 1 + status('%s', LOADING_CHARS[lc_idx]) + end + + do + n:on_readable(r) + + if ipc_conn then + ipc_conn:on_readable(r) + end + end +end + +status('exiting...') diff --git a/src/tools/workbit.lua b/src/tools/workbit.lua new file mode 100644 index 0000000..ce39f4d --- /dev/null +++ b/src/tools/workbit.lua @@ -0,0 +1,40 @@ +function help() + print('Usage: wh workbit ') +end + +if arg[2] == 'help' then + return help() +end + +local name = arg[2] + +if not name then + return help() +end + +local conf_s = wh.readconf(name) + +if not conf_s then + return +end + +local conf = wh.fromconf(conf_s) + +if not conf then + print('Invalid configuration') + return +end + +local k = io.stdin:read() +local ok, value = pcall(wh.fromb64, k) + +if not ok then + printf("Invalid name: %s", value) + return +end + +k = value + +local wb = wh.workbit(k, conf.namespace) + +print(wb) diff --git a/src/wgsync.lua b/src/wgsync.lua new file mode 100644 index 0000000..c15ebee --- /dev/null +++ b/src/wgsync.lua @@ -0,0 +1,133 @@ +-- XXX handles sync from wg to wh + +local REFRESH_EVERY = wh.NAT_TIMEOUT / 2 + +local M = {} + +local MT = { + __index = {} +} + +function MT.__index.update(sy, socks) + local deadlines = {} + + local deadline = (sy.last_sync or 0) + REFRESH_EVERY + if deadline <= now then + local wg = wh.wg.get(sy.interface) + for _, wg_p in pairs(wg.peers) do + local k = wg_p.public_key + local p = sy.n.kad:get(k) + + if p then + p.wg_connected = wg_p.last_handshake_time > 0 + + if (p.last_seen or 0) < wg_p.last_handshake_time then + p.last_seen = wg_p.last_handshake_time + end + + if wg_p.rx_bytes ~= (sy.p_rx[k] or 0) then + p.last_seen = now + sy.p_rx[k] = wg_p.rx_bytes + end + end + end + + sy.last_sync = now + deadline = (sy.last_sync or 0) + REFRESH_EVERY + end + deadlines[#deadlines+1] = deadline + + for k, p in pairs(sy.n.kad.touched) do + local comment = nil + if k ~= sy.n.p.k and (not p or p.trust) then + local peer + + -- destroy tunnel + if sy.n.lo and p and p.tunnel and p.addr and not p.relay then + sy.n.lo:free_tunnel(p) + assert(not p.tunnel) + end + + if p and p.trust and not p.alias and p.ip then + peer = { + public_key = p.k, + replace_allowedips=true, + allowedips={}, + } + + if p.ip then + -- XXX check subnet + + local slash_idx = string.find(sy.subnet, '/') + local cidr = string.sub(sy.subnet, slash_idx+1) + + -- XXX IPv6 Orchid + + peer.allowedips[#peer.allowedips+1] = {p.ip, 32} + end + + p.endpoint = nil + if sy.n.p.ip then + if p.tunnel then + p.endpoint = 'lo' + elseif p.addr and not p.relay then + p.endpoint = p.addr + elseif sy.n.lo then + sy.n.lo:touch_tunnel(p) + p.endpoint = 'lo' + else + p.endpoint = nil + end + end + + if p.endpoint ~= p._old_endpoint then + if p.endpoint == 'lo' then + peer.endpoint = p.tunnel.lo_addr + elseif p.endpoint then + peer.endpoint = p.endpoint + else + comment = "replace" + wh.wg.set{name=sy.interface, peers={public_key=p.k, remove_me=true}} + end + + p._old_endpoint = p.endpoint + end + + if p.endpoint ~= nil and p.endpoint ~= 'lo' and p.is_nated then + peer.persistent_keepalive_interval = wh.NAT_TIMEOUT + else + peer.persistent_keepalive_interval = 0 + end + + elseif not p then + comment = "remove" + peer = { + public_key = p.k, + remove_me = true, + } + end + + if peer then + comment = comment or "upsert" + --printf("$(orange)%s %s$(reset)", comment, dump(peer)) + + wh.wg.set{name=sy.interface, peers={peer}} + end + end + end + + + return min(deadlines) +end + +function MT.__index.close(sy) +end + +function M.new(sy) + assert(sy.n and sy.interface) + sy.p_rx = {} + return setmetatable(sy, MT) +end + +return M + diff --git a/src/wh.lua b/src/wh.lua new file mode 100644 index 0000000..e35f4d6 --- /dev/null +++ b/src/wh.lua @@ -0,0 +1,66 @@ +--DEBUG = true + +-- variable's nomenclature: +-- a for Authentication session +-- M for Module +-- n for Node +-- t for kademila Tree +-- s for Search session +-- d for nat Detecting session +-- p for Peer address +-- sc for Search & Connect +-- +-- Nomenclature: +-- by default, a key is public (a key <=> a public key) +-- a secret key <=> a private key + +-- wh is a global +assert(wh == nil) +_G['wh'] = require('whcore') + +local VERSION = {0, 1, 0} + +-- check version +do + local major, minor, revision = wh.version() + + if major ~= VERSION[1] or minor ~= VERSION[2] or revision ~= VERSION[3] then + error(string.format("version mismatch: version is %d.%d.%d, core's is %d.%d.%d", + major, minor, revision, + VERSION[1], VERSION[2], VERSION[3] + )) + end + + wh.version = setmetatable(VERSION, { + __tostring = function(v) + return string.format('%d.%d.%d', table.unpack(VERSION)) + end + }) +end + +-- constants +wh.AUTH_RETRY = 4 +wh.CONNECTIVITY_CHECK_EVERY = 5*60 +wh.DEFAULT_PORT = 62096 +wh.FRAGMENT_MAX = 4 +wh.FRAGMENT_MTU = 1024 -- XXX +wh.FRAGMENT_TIMEOUT = 4 +wh.KADEMILIA_K = 20 +wh.KEEPALIVE_TIMEOUT = 25 +wh.MAX_PUNCH_RETRY = 10 +wh.MAX_PUNCH_TIMEOUT = .5 +wh.NAT_TIMEOUT = 25 +wh.PING_BACKOFF = .5 +wh.PING_RETRY = 4 +wh.SEARCH_TIMEOUT = 5 +wh.UPNP_REFRESH_EVERY = 10*60 + +-- sanity check +assert(wh.FRAGMENT_MTU >= 1024, "65536/MTU <= 64") + +-- additional extensions +require('key') -- add method wh.key +require('conf') -- add wh.fromconf & wh.toconf +wh.new = require('node').new + +_G['now'] = 0 diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..0bdfd49 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +keys diff --git a/tests/generate-keys.sh b/tests/generate-keys.sh new file mode 100755 index 0000000..667e41b --- /dev/null +++ b/tests/generate-keys.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +NAME=test + +wh clearconf $NAME +wh set $NAME workbit 8 subnet 10.0.42.1/24 +wh set $NAME name bootstrap peer P17zMwXJFbBdJEn05RFIMADw9TX5_m2xgf31OgNKX3w untrusted bootstrap yes endpoint bootstrap.wirehub.io + +CWD="$(dirname "$0")" +KPATH="$CWD/keys" + +rm -f $KPATH/*.{sk,k} +mkdir -p $KPATH + +for i in {1..9} +do + echo "generating key $i..." + wh genkey $NAME | tee $KPATH/$i.sk | wh pubkey > $KPATH/$i.k + wh set $NAME name $i.$NAME ip 10.0.42.$i peer `cat $KPATH/$i.k` +done + +wh showconf $NAME > $KPATH/config +