Skip to content

Commit

Permalink
Restructure auth token usage and storage
Browse files Browse the repository at this point in the history
  • Loading branch information
fkantelberg committed Jun 5, 2023
1 parent 233cc15 commit 1db7477
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 47 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ repos:
- id: flake8
additional_dependencies: [flake8-bugbear]

- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.8.0
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort

- repo: https://github.com/pre-commit/mirrors-pylint
rev: v2.7.4
- repo: https://github.com/pylint-dev/pylint
rev: v2.17.4
hooks:
- id: pylint
2 changes: 1 addition & 1 deletion src/socket_proxy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _format_action_invocation(self, action: argparse.Action) -> str:


def basic_group(parser: argparse.ArgumentParser, server: bool = False) -> None:
group = parser.add_argument_group("Security")
group = parser.add_argument_group("Basic")
group.add_argument(
"--config",
default=None,
Expand Down
18 changes: 17 additions & 1 deletion src/socket_proxy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ipaddress
import logging
import os
from datetime import datetime
from typing import TypeVar, Union

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -46,11 +47,26 @@ class ReachedClientLimit(Exception):


class AuthType(enum.IntEnum):
"""Helper for authentication tokens"""
"""Helper for authentication token types"""

TOTP = 0x01
HOTP = 0x02

def __str__(self) -> str:
return {AuthType.TOTP: "totp", AuthType.HOTP: "hotp"}.get(self.value)


class AuthToken:
"""Helper for authentication tokens"""

def __init__(self, dt: datetime = None):
if not dt:
self.creation = datetime.now()
elif isinstance(dt, str):
self.creation = datetime.fromisoformat(dt)
else:
self.creation = dt


class InternetType(enum.IntEnum):
"""Helper for IP addresses and identification"""
Expand Down
77 changes: 48 additions & 29 deletions src/socket_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import uuid
from asyncio import StreamReader, StreamWriter
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, List, Tuple, Union

Expand Down Expand Up @@ -37,7 +38,6 @@ def __init__(
self.max_tunnels = base.config.max_tunnels
self.http_ssl = base.config.http_ssl
self.tunnels = {}
self.tokens = {}
self.sc = utils.generate_ssl_context(
cert=cert,
key=key,
Expand All @@ -47,8 +47,11 @@ def __init__(
)
self.http_proxy = self.server = None

# Authentication
self.authentication = authentication
self.tokens = defaultdict(dict)
self.auth_timeout = auth_timeout

self.event = event.EventSystem(
event.EventType.Server,
url=base.config.hook_url,
Expand Down Expand Up @@ -80,41 +83,42 @@ def _load_persisted_state(self, file: str = None) -> None:
return

# Restore the tokens
for tkn, dt in state.get("tokens", {}).items():
self.tokens[tkn] = datetime.fromisoformat(dt) if dt else None
self.tokens.clear()

# Stay compatible
for token, dt in state.get("tokens", {}).items():
ttype = base.AuthType.TOTP if dt else base.AuthType.HOTP
self.tokens[ttype][token] = base.AuthToken(dt)

for auth_type in base.AuthType:
for token, creation in state.get(f"tokens_{auth_type}", {}).items():
self.tokens[auth_type][token] = base.AuthToken(creation)

def _persist_state(self, file: str = None) -> None:
def _save_persisted_state(self, file: str = None) -> None:
"""Persist the internal state of the proxy server like tokens"""
file = file or base.config.persist_state
if not file:
return

state = self.get_persistant_state_dict()
with open(file, "w+", encoding="utf-8") as fp:
json.dump(
{
"tokens": {
tkn: dt.isoformat(" ") if dt else dt
for tkn, dt in self.tokens.items()
}
},
fp,
)
json.dump(state, fp)

async def idle(self) -> None:
"""This methods will get called regularly to apply timeouts"""
dt = datetime.now() - timedelta(seconds=self.auth_timeout)
changes = False
for token, t in list(self.tokens.items()):
if t and token and t < dt:
self.tokens.pop(token, None)
for token, t in list(self.tokens[base.AuthType.TOTP].items()):
if t.creation < dt:
self.tokens[base.AuthType.TOTP].pop(token, None)
changes = True
_logger.info(f"Invalidated token {token}")
await self.event.send(msg="token_invalidate", token=token)

if self.authentication and not self.tokens:
if self.authentication and not self.tokens[base.AuthType.TOTP]:
self.generate_token()
elif changes:
self._persist_state()
self._save_persisted_state()

# Flush the event queue
await self.event.flush()
Expand All @@ -125,11 +129,12 @@ def generate_token(self, hotp: bool = False) -> str:
return None

token = str(uuid.uuid4())
self.tokens[token] = None if hotp else datetime.now()
ttype = "hotp" if hotp else "totp"
_logger.info(f"Generated authentication token {token} [{ttype}]")
auth_type = base.AuthType.HOTP if hotp else base.AuthType.TOTP
self.tokens[auth_type][token] = base.AuthToken()

_logger.info(f"Generated authentication token {token} [{auth_type}]")
self.event.send_nowait(msg="token_generate", token=token, hotp=bool(hotp))
self._persist_state()
self._save_persisted_state()
return token

async def _api_handle(self, path: Tuple[str], request: api.Request) -> Any:
Expand All @@ -143,12 +148,13 @@ async def _api_handle(self, path: Tuple[str], request: api.Request) -> Any:
def _verify_auth_token(self, pkg: package.AuthPackage) -> bool:
"""Verify an authentication package"""
if pkg.token_type == base.AuthType.TOTP:
return pkg.token in {tk for tk, dt in self.tokens.items() if dt}
return pkg.token in self.tokens[base.AuthType.TOTP]

if pkg.token_type == base.AuthType.HOTP:
for token, dt in self.tokens.items():
if dt is None and utils.hotp_verify(token, pkg.token):
return True
return any(
utils.hotp_verify(token, pkg.token)
for token in self.tokens[base.AuthType.HOTP]
)

return False

Expand Down Expand Up @@ -282,6 +288,15 @@ async def loop(self) -> None:
async with self.server:
await self.server.serve_forever()

def get_persistant_state_dict(self) -> dict:
"""Generate a dictionary with all persistance information"""
return {
f"tokens_{t}": {
token: t.creation.isoformat(" ") for token, t in self.tokens[t].items()
}
for t in base.AuthType
}

def get_state_dict(self) -> dict:
"""Generate a dictionary which shows the current state of the server"""
tunnels = {}
Expand All @@ -294,15 +309,19 @@ def get_state_dict(self) -> dict:
"port": self.http_port,
}

state = self.get_persistant_state_dict()
# Stay compatible
state["tokens"] = {
**state["tokens_totp"],
**dict.fromkeys(state["tokens_hotp"], None),
}
return {
**state,
"http": http if self.http_domain else {},
"tcp": {
"host": self.host,
"port": self.port,
},
"tokens": {
t: dt.isoformat(" ") if dt else dt for t, dt in self.tokens.items()
},
"tunnels": tunnels,
}

Expand Down
2 changes: 2 additions & 0 deletions src/socket_proxy/tunnel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
cert: str = None,
key: str = None,
auth_token: str = None,
client_id: str = None,
**kwargs,
):
super().__init__(api_type=api.APIType.Client, **kwargs)
Expand All @@ -32,6 +33,7 @@ def __init__(
self.addr = []
self.last_ping = self.last_pong = None
self.auth_token = auth_token
self.client_id = client_id

self.ping_enabled = base.config.ping

Expand Down
2 changes: 1 addition & 1 deletion src/socket_proxy/tunnel_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_dimension(self) -> None:
"""Get the dimensions of the current window"""
self.height, self.width = self.scr.getmaxyx()

# pylint: disable=W0613,R0201
# pylint: disable=W0613
def fmt_port(self, ip_type: InternetType, ip: IPvXAddress, port: int) -> str:
"""Format an address"""
return f"{ip}:{port}" if ip else str(port)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ async def test_authenticated_tunnel_api():
response = await srv._api_index(req_mock)
token = response.text.strip().replace('"', "")
assert response.status == 200
assert srv.tokens[token]
assert srv.tokens[base.AuthType.TOTP][token]
await asyncio.sleep(0.1)

req_mock = mock.AsyncMock(path="/api/token/hotp")
response = await srv._api_index(req_mock)
token = response.text.strip().replace('"', "")
assert response.status == 200
assert srv.tokens[token] is None
assert token in srv.tokens[base.AuthType.HOTP]
await asyncio.sleep(0.1)

srv.authentication = False
Expand Down
21 changes: 12 additions & 9 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,28 +113,31 @@ async def test_server_state_persistent():
srv.generate_token()
srv.generate_token(True)

srv._persist_state(fp.name)
srv._save_persisted_state(fp.name)

fp.seek(0)
data = srv.tokens
srv.tokens = {}
data = srv.tokens[base.AuthType.TOTP]
srv.tokens[base.AuthType.TOTP] = {}
srv._load_persisted_state(fp.name)
assert data == srv.tokens
assert srv.tokens
assert list(data) == list(srv.tokens[base.AuthType.TOTP])
assert srv.tokens[base.AuthType.TOTP]


@pytest.mark.asyncio
async def test_proxy_token_cleanup():
(port,) = unused_ports(1)
async with server(port) as srv:
srv.tokens["old-token"] = datetime(1970, 1, 1)
srv.authentication = False
srv.tokens[base.AuthType.TOTP]["old-token"] = base.AuthToken(
datetime(1970, 1, 1)
)
await srv.idle()
assert "old-token" not in srv.tokens
assert not srv.tokens
assert "old-token" not in srv.tokens[base.AuthType.TOTP]
assert not srv.tokens[base.AuthType.TOTP]

srv.authentication = True
await srv.idle()
assert srv.tokens
assert srv.tokens[base.AuthType.TOTP]


@pytest.mark.asyncio
Expand Down

0 comments on commit 1db7477

Please sign in to comment.