# Chapter 17: Socket Fundamentals

**Networking and Protocols**

Sockets are the foundational building block for network communication in Python.
The `socket` module provides a low-level interface to the BSD socket API, enabling
programs to send and receive data over TCP, UDP, and other protocols. Understanding
sockets is essential even when using higher-level libraries, because they underpin
everything from HTTP clients to database drivers.

## Socket Basics: Address Families and Socket Types

A socket is created with two key parameters:

- **Address family** (`AF_INET` for IPv4, `AF_INET6` for IPv6, `AF_UNIX` for local IPC)
- **Socket type** (`SOCK_STREAM` for TCP, `SOCK_DGRAM` for UDP)

TCP (`SOCK_STREAM`) provides a reliable, ordered byte stream with connection setup.
UDP (`SOCK_DGRAM`) provides connectionless, unreliable datagrams -- faster but
with no delivery guarantees.

In [None]:
import socket


# Inspect the key constants
print("=== Address Families ===")
print(f"AF_INET  (IPv4):  {socket.AF_INET}")
print(f"AF_INET6 (IPv6):  {socket.AF_INET6}")
print(f"AF_UNIX  (local): {socket.AF_UNIX}")

print("\n=== Socket Types ===")
print(f"SOCK_STREAM (TCP): {socket.SOCK_STREAM}")
print(f"SOCK_DGRAM  (UDP): {socket.SOCK_DGRAM}")

# Create a TCP socket and inspect its properties
tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print(f"\nTCP socket: {tcp_sock}")
print(f"  family:  {tcp_sock.family}")
print(f"  type:    {tcp_sock.type}")
print(f"  fileno:  {tcp_sock.fileno()}")
tcp_sock.close()

# Create a UDP socket and inspect
udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
print(f"\nUDP socket: {udp_sock}")
print(f"  family:  {udp_sock.family}")
print(f"  type:    {udp_sock.type}")
udp_sock.close()

## Socket as a Context Manager

Sockets implement the context manager protocol. Using `with` ensures the socket
is closed automatically, even if an exception occurs. This prevents file descriptor
leaks, which can exhaust operating system resources in long-running servers.

In [None]:
import socket


# Preferred pattern: socket as context manager
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    print(f"Socket open, fileno={sock.fileno()}")
    # The socket is usable within this block

# After the block, the socket is closed automatically
print(f"Socket closed, fileno={sock.fileno()}")
# fileno() returns -1 on a closed socket on most platforms


# Equivalent manual pattern (less safe):
sock2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    print(f"\nManual socket open, fileno={sock2.fileno()}")
finally:
    sock2.close()
    print(f"Manual socket closed, fileno={sock2.fileno()}")

## DNS Resolution with `getaddrinfo()`

`socket.getaddrinfo()` resolves a hostname into one or more address tuples
that can be used to create and connect sockets. It is protocol-agnostic and
handles both IPv4 and IPv6. This is the recommended way to resolve addresses
rather than using `gethostbyname()`, which only returns IPv4.

In [None]:
import socket


def resolve_address(
    host: str,
    port: int,
    family: socket.AddressFamily = socket.AF_UNSPEC,
    type_: socket.SocketKind = socket.SOCK_STREAM,
) -> list[
    tuple[
        socket.AddressFamily,
        socket.SocketKind,
        int,
        str,
        # IPv4 sockaddr is (host, port); IPv6 adds (flowinfo, scope_id)
        tuple[str, int] | tuple[str, int, int, int],
    ]
]:
    """Resolve a host:port to a list of address info tuples."""
    results = socket.getaddrinfo(host, port, family, type_)
    return results


# Resolve localhost for TCP
print("=== Resolving 'localhost' port 80 (TCP) ===")
for info in resolve_address("localhost", 80):
    family, socktype, proto, canonname, sockaddr = info
    print(f"  family={family.name}, type={socktype.name}, addr={sockaddr}")

# Resolve for UDP only
print("\n=== Resolving 'localhost' port 53 (UDP) ===")
for info in resolve_address("localhost", 53, type_=socket.SOCK_DGRAM):
    family, socktype, proto, canonname, sockaddr = info
    print(f"  family={family.name}, type={socktype.name}, addr={sockaddr}")

# Resolve an external hostname (may return IPv4 and IPv6)
print("\n=== Resolving 'example.com' port 443 (TCP) ===")
try:
    for info in resolve_address("example.com", 443):
        family, socktype, proto, canonname, sockaddr = info
        print(f"  family={family.name}, type={socktype.name}, addr={sockaddr}")
except socket.gaierror as e:
    print(f"  DNS resolution failed: {e}")
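
The tuples returned by `getaddrinfo()` can be passed straight to `socket.socket()` and `connect()`. The following is a minimal sketch of that idea, not a definitive implementation: the helper name `connect_first_available` and its `timeout` parameter are assumptions made for illustration. It tries each resolved address in turn and returns the first socket that connects. The standard library's `socket.create_connection()` offers similar behavior and is usually the better choice in real code.

```python
import socket


def connect_first_available(
    host: str, port: int, timeout: float = 2.0
) -> socket.socket:
    """Try every address getaddrinfo() returns; return the first that connects.

    Hypothetical helper for illustration only.
    """
    last_error: OSError | None = None
    for family, socktype, proto, _canonname, sockaddr in socket.getaddrinfo(
        host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
    ):
        try:
            sock = socket.socket(family, socktype, proto)
        except OSError as exc:  # e.g., address family not supported locally
            last_error = exc
            continue
        try:
            sock.settimeout(timeout)
            sock.connect(sockaddr)
        except OSError as exc:
            last_error = exc
            sock.close()
            continue
        return sock
    raise last_error or OSError(f"No addresses found for {host}:{port}")


# Demonstrate against a local listener so no external network is needed
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
    listener.bind(("127.0.0.1", 0))
    listener.listen(1)
    port = listener.getsockname()[1]

    with connect_first_available("127.0.0.1", port) as conn:
        peer_port = conn.getpeername()[1]
        print(f"Connected to 127.0.0.1:{peer_port}")
```

Closing the failed socket before moving to the next address keeps file descriptors from leaking when a host resolves to several unreachable addresses.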

## Socket Options: `SO_REUSEADDR` and Timeouts

Socket options control low-level behavior:

- **`SO_REUSEADDR`**: Allows a server socket to bind to a local address and port
  that still has connections lingering in the `TIME_WAIT` state (e.g., after a
  recent restart). Without this, restarting a server immediately may fail with
  "Address already in use". These are the Unix semantics; on Windows the option
  behaves differently.
- **Timeouts**: `settimeout()` controls how long blocking operations (connect,
  recv, accept) wait before raising `socket.timeout` (a deprecated alias of the
  built-in `TimeoutError` since Python 3.10). A value of `None` means
  blocking indefinitely; `0` means non-blocking.

In [None]:
import socket


# Demonstrate SO_REUSEADDR
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
    # Check default value
    reuse_before = server.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
    print(f"SO_REUSEADDR before: {reuse_before}")

    # Enable SO_REUSEADDR -- standard practice for server sockets on Unix-like systems
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    reuse_after = server.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
    print(f"SO_REUSEADDR after:  {reuse_after}")

# Demonstrate timeouts
print("\n=== Socket Timeouts ===")
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    # Default: blocking (None)
    print(f"Default timeout: {sock.gettimeout()}")

    # Set a 2-second timeout
    sock.settimeout(2.0)
    print(f"After settimeout(2.0): {sock.gettimeout()}")

    # Non-blocking mode
    sock.setblocking(False)
    print(f"After setblocking(False): timeout={sock.gettimeout()}")

    # Back to blocking
    sock.setblocking(True)
    print(f"After setblocking(True): timeout={sock.gettimeout()}")

# Demonstrate timeout exception
print("\n=== Timeout on connect ===")
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.settimeout(0.001)  # Very short timeout
    try:
        # 192.0.2.1 is a reserved documentation address (TEST-NET-1), so the
        # attempt should time out or fail with an OS error such as "unreachable"
        sock.connect(("192.0.2.1", 80))
    except socket.timeout:
        print("  Connection timed out (socket.timeout raised)")
    except OSError as e:
        print(f"  OS error: {e}")
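
The outcome of the connect demo above depends on the local network. As a complementary sketch, a receive timeout can be reproduced deterministically with `socket.socketpair()`, which returns two already-connected sockets on the same machine. The variable names here are illustrative assumptions.

```python
import socket

# socketpair() gives two connected sockets without any network access
left, right = socket.socketpair()
with left, right:
    right.settimeout(0.1)  # recv() will wait at most 0.1 seconds

    try:
        right.recv(1024)  # Nothing has been sent yet, so this times out
        timed_out = False
    except socket.timeout:
        timed_out = True
    print(f"recv() with no data pending timed out: {timed_out}")

    # Once data is available, the same call returns immediately
    left.sendall(b"ping")
    data = right.recv(1024)
    print(f"recv() after sendall(): {data!r}")
```

The timeout applies to each blocking call separately, so a socket that timed out once is still usable afterwards.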

## TCP Client-Server Echo Example

A classic TCP pattern: the **echo server** accepts connections and sends back
whatever data it receives. We use `threading` to run the server in the background
so client and server can coexist in a single notebook.

TCP flow:
1. Server: `bind()` -> `listen()` -> `accept()` (blocks until client connects)
2. Client: `connect()` -> `sendall()` -> `recv()`
3. Server: `recv()` -> `sendall()` (echoes data back)

In [None]:
import socket
import threading


HOST: str = "127.0.0.1"
PORT: int = 0  # Let the OS assign a free port


def echo_server(server_sock: socket.socket, ready: threading.Event) -> None:
    """Accept one connection and echo data back."""
    server_sock.listen(1)
    ready.set()  # Signal that the server is ready to accept

    conn, addr = server_sock.accept()
    with conn:
        print(f"[Server] Connection from {addr}")
        while True:
            data: bytes = conn.recv(1024)
            if not data:
                print("[Server] Client disconnected")
                break
            print(f"[Server] Received: {data!r}")
            conn.sendall(data)  # Echo back


def echo_client(port: int, messages: list[str]) -> list[str]:
    """Connect to the echo server and send messages."""
    responses: list[str] = []
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((HOST, port))
        print(f"[Client] Connected to {HOST}:{port}")

        for msg in messages:
            sock.sendall(msg.encode("utf-8"))
            data = sock.recv(1024)
            response = data.decode("utf-8")
            print(f"[Client] Sent: {msg!r}, Received: {response!r}")
            responses.append(response)

    return responses


# Set up the server socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
    server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_sock.bind((HOST, PORT))
    actual_port: int = server_sock.getsockname()[1]
    print(f"Server bound to {HOST}:{actual_port}")

    # Start server in a background thread
    ready = threading.Event()
    server_thread = threading.Thread(
        target=echo_server, args=(server_sock, ready), daemon=True
    )
    server_thread.start()
    ready.wait()  # Wait until server is listening

    # Run the client
    messages = ["Hello, server!", "How are you?", "Goodbye!"]
    responses = echo_client(actual_port, messages)

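    # Give the server thread time to log the disconnect before the summary
    server_thread.join(timeout=2)
    print(f"\nAll messages echoed correctly: {messages == responses}")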

## Handling Partial Sends and Receives with `sendall()`

TCP is a **stream protocol** -- it does not preserve message boundaries. A single
`send()` may not transmit all bytes at once, and a single `recv()` may return
partial data. Key rules:

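- **`sendall(data)`**: Repeatedly calls `send()` until all bytes are sent or an
  error occurs. Prefer this over raw `send()`, which returns the number of bytes
  actually sent and leaves the retry logic to you.
- **`recv(bufsize)`**: Returns up to `bufsize` bytes; an empty result means the
  peer closed the connection. You must loop to receive a complete message, using
  a framing protocol (length prefix, delimiter, etc.).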
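
To make the first rule concrete, here is a rough sketch of the loop that `sendall()` performs on your behalf. The function name `send_all_manually` and the helper `drain` are hypothetical, and the real implementation lives in C inside the `socket` module; this version only illustrates how the return value of `send()` is used.

```python
import socket
import threading


def send_all_manually(sock: socket.socket, data: bytes) -> int:
    """Keep calling send() until every byte has been handed to the OS."""
    view = memoryview(data)  # Slicing a memoryview avoids copying the payload
    total_sent = 0
    while total_sent < len(view):
        sent = sock.send(view[total_sent:])  # May accept only part of the data
        if sent == 0:
            raise ConnectionError("Socket connection broken")
        total_sent += sent
    return total_sent


def drain(sock: socket.socket, sink: list[bytes]) -> None:
    """Read until the peer closes, collecting every chunk."""
    while True:
        chunk = sock.recv(65536)
        if not chunk:
            break
        sink.append(chunk)


left, right = socket.socketpair()
received: list[bytes] = []
reader = threading.Thread(target=drain, args=(right, received))
reader.start()

payload = b"x" * 1_000_000  # Large enough that one send() may not take it all
with left:
    sent_total = send_all_manually(left, payload)
reader.join()
right.close()

received_total = sum(len(chunk) for chunk in received)
print(f"Sent {sent_total} bytes, received {received_total} bytes")
```

The reader thread is needed because a blocking `send()` stalls once the kernel buffers fill up and nobody is reading on the other end.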

In [None]:
import socket
import struct
import threading


def recv_exactly(sock: socket.socket, num_bytes: int) -> bytes:
    """Receive exactly num_bytes from the socket.

    Handles partial receives by looping until all bytes are collected.
    Raises ConnectionError if the connection is closed prematurely.
    """
    chunks: list[bytes] = []
    bytes_received: int = 0
    while bytes_received < num_bytes:
        chunk = sock.recv(num_bytes - bytes_received)
        if not chunk:
            raise ConnectionError(
                f"Connection closed after {bytes_received}/{num_bytes} bytes"
            )
        chunks.append(chunk)
        bytes_received += len(chunk)
    return b"".join(chunks)


def send_message(sock: socket.socket, message: bytes) -> None:
    """Send a length-prefixed message (4-byte big-endian length header)."""
    header: bytes = struct.pack("!I", len(message))  # 4-byte unsigned int
    sock.sendall(header + message)


def recv_message(sock: socket.socket) -> bytes:
    """Receive a length-prefixed message."""
    header: bytes = recv_exactly(sock, 4)
    (length,) = struct.unpack("!I", header)
    return recv_exactly(sock, length)


# Demonstrate length-prefixed messaging
def framing_server(server_sock: socket.socket, ready: threading.Event) -> None:
    server_sock.listen(1)
    ready.set()
    conn, _ = server_sock.accept()
    with conn:
        # Receive and echo back 3 messages
        for _ in range(3):
            msg = recv_message(conn)
            print(f"[Server] Received {len(msg)} bytes: {msg!r}")
            send_message(conn, msg)


with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind(("127.0.0.1", 0))
    port = srv.getsockname()[1]

    ready = threading.Event()
    t = threading.Thread(target=framing_server, args=(srv, ready), daemon=True)
    t.start()
    ready.wait()

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
        client.connect(("127.0.0.1", port))

        # Send messages of varying sizes
        test_messages: list[bytes] = [
            b"Short",
            b"A medium-length message for testing",
            b"X" * 500,  # Larger message
        ]

        for original in test_messages:
            send_message(client, original)
            echoed = recv_message(client)
            assert echoed == original, f"Mismatch: {original!r} != {echoed!r}"
            print(f"[Client] Sent {len(original)} bytes, echoed correctly")

    t.join(timeout=2)
    print("\nAll length-prefixed messages exchanged correctly!")
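
Length prefixes are one framing option; delimiters are another. The sketch below shows a newline-delimited reader under a few assumptions: the class name `LineReader` is hypothetical, the delimiter never appears inside a message, and no maximum message size is enforced (real code would want a limit).

```python
import socket


class LineReader:
    """Buffer incoming bytes and hand back one delimiter-terminated message at a time."""

    def __init__(self, sock: socket.socket, delimiter: bytes = b"\n") -> None:
        self._sock = sock
        self._delimiter = delimiter
        self._buffer = bytearray()

    def read_message(self) -> bytes:
        while True:
            index = self._buffer.find(self._delimiter)
            if index != -1:
                message = bytes(self._buffer[:index])
                # Drop the message and its delimiter; keep any leftover bytes
                del self._buffer[: index + len(self._delimiter)]
                return message
            chunk = self._sock.recv(1024)
            if not chunk:
                raise ConnectionError("Connection closed before delimiter")
            self._buffer += chunk


left, right = socket.socketpair()
with left, right:
    right.settimeout(2.0)
    # Two sends that do not line up with the three logical messages
    left.sendall(b"one\ntwo\nthr")
    left.sendall(b"ee\n")

    reader = LineReader(right)
    lines = [reader.read_message() for _ in range(3)]

print(lines)  # [b'one', b'two', b'three']
```

Leftover bytes stay in the buffer between calls, which is what lets a message split across two `recv()` results be reassembled.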

## UDP Send/Receive Example

UDP (`SOCK_DGRAM`) is connectionless -- each `sendto()` and `recvfrom()` is
an independent datagram. There is no handshake, `listen()`, or `accept()`;
calling `connect()` on a UDP socket is optional and only sets a default peer address.
Datagrams may arrive out of order, be duplicated, or be lost entirely.

UDP is suitable for:
- DNS lookups
- Real-time audio/video streaming
- Game state updates
- Any scenario where low latency matters more than reliability

In [None]:
import socket
import threading


def udp_echo_server(server_sock: socket.socket, ready: threading.Event) -> None:
    """Simple UDP echo server that handles a fixed number of datagrams."""
    ready.set()
    for _ in range(3):  # Handle 3 datagrams then stop
        data, client_addr = server_sock.recvfrom(1024)
        print(f"[UDP Server] Received {data!r} from {client_addr}")
        server_sock.sendto(data.upper(), client_addr)  # Echo back uppercased


# Create UDP server socket
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as server_sock:
    server_sock.bind(("127.0.0.1", 0))
    port = server_sock.getsockname()[1]
    print(f"UDP server listening on 127.0.0.1:{port}")

    ready = threading.Event()
    t = threading.Thread(
        target=udp_echo_server, args=(server_sock, ready), daemon=True
    )
    t.start()
    ready.wait()

    # UDP client
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client_sock:
        client_sock.settimeout(2.0)  # Timeout since UDP has no guarantees

        messages = [b"hello udp", b"networking is fun", b"datagrams!"]
        for msg in messages:
            client_sock.sendto(msg, ("127.0.0.1", port))
            try:
                data, server_addr = client_sock.recvfrom(1024)
                print(f"[UDP Client] Sent: {msg!r}, Got: {data!r}")
            except socket.timeout:
                print(f"[UDP Client] No response for {msg!r} (lost datagram)")

    t.join(timeout=2)
    print("\nUDP echo complete!")

## Multi-Client TCP Server with Threading

A production-like TCP server handles multiple clients concurrently. Each accepted
connection is dispatched to a new thread. This pattern is the foundation for
many network services, though for high concurrency, `asyncio` (covered in
notebook 03) or `selectors` may be preferred.

In [None]:
import socket
import threading
import time


def handle_client(
    conn: socket.socket,
    addr: tuple[str, int],
    client_id: int,
) -> None:
    """Handle a single client connection in its own thread."""
    with conn:
        print(f"[Server] Client #{client_id} connected from {addr}")
        while True:
            data = conn.recv(1024)
            if not data:
                break
            response = f"[echo #{client_id}] {data.decode()}".encode()
            conn.sendall(response)
        print(f"[Server] Client #{client_id} disconnected")


def multi_client_server(
    server_sock: socket.socket,
    ready: threading.Event,
    max_clients: int = 3,
) -> None:
    """Accept up to max_clients connections, each in a new thread."""
    server_sock.listen(5)
    server_sock.settimeout(5.0)  # Don't block forever
    ready.set()

    handlers: list[threading.Thread] = []
    for i in range(max_clients):
        try:
            conn, addr = server_sock.accept()
            t = threading.Thread(
                target=handle_client, args=(conn, addr, i + 1), daemon=True
            )
            t.start()
            handlers.append(t)
        except socket.timeout:
            break

    for t in handlers:
        t.join(timeout=5)


def run_client(port: int, client_name: str, messages: list[str]) -> None:
    """A simple client that sends messages and prints responses."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect(("127.0.0.1", port))
        for msg in messages:
            sock.sendall(msg.encode())
            response = sock.recv(1024).decode()
            print(f"  [{client_name}] Sent: {msg!r}, Got: {response!r}")
            time.sleep(0.05)  # Small delay between messages


# Launch server
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind(("127.0.0.1", 0))
    port = srv.getsockname()[1]

    ready = threading.Event()
    server_t = threading.Thread(
        target=multi_client_server, args=(srv, ready, 3), daemon=True
    )
    server_t.start()
    ready.wait()

    # Launch 3 clients concurrently
    client_threads: list[threading.Thread] = []
    for i in range(1, 4):
        t = threading.Thread(
            target=run_client,
            args=(port, f"Client-{i}", [f"msg-{j}" for j in range(1, 3)]),
        )
        client_threads.append(t)
        t.start()

    for t in client_threads:
        t.join()

    server_t.join(timeout=3)
    print("\nMulti-client echo server demo complete!")
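
For comparison with the thread-per-connection design, here is a hedged sketch of the same echo behavior built on the `selectors` module, where a single thread watches all sockets. The function name `serve_with_selector` and the `expected_clients` stopping rule are assumptions made so the demo terminates. Calling `sendall()` on a non-blocking socket is tolerable here only because the payloads are tiny; a real server would buffer pending output and wait for write readiness.

```python
import selectors
import socket
import threading


def serve_with_selector(server_sock: socket.socket, expected_clients: int) -> None:
    """Echo data for several clients from one thread, then stop."""
    sel = selectors.DefaultSelector()
    server_sock.setblocking(False)
    sel.register(server_sock, selectors.EVENT_READ, data="listener")

    closed = 0
    while closed < expected_clients:
        for key, _events in sel.select(timeout=1.0):
            if key.data == "listener":
                conn, _addr = key.fileobj.accept()
                conn.setblocking(False)
                sel.register(conn, selectors.EVENT_READ, data="client")
            else:
                conn = key.fileobj
                data = conn.recv(1024)
                if data:
                    conn.sendall(data)  # Fine for tiny payloads only
                else:
                    sel.unregister(conn)
                    conn.close()
                    closed += 1
    sel.close()


with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
    listener.bind(("127.0.0.1", 0))
    listener.listen()
    port = listener.getsockname()[1]

    server = threading.Thread(
        target=serve_with_selector, args=(listener, 2), daemon=True
    )
    server.start()

    # Both clients stay connected at once; one server thread serves both
    clients = [
        socket.create_connection(("127.0.0.1", port), timeout=2.0)
        for _ in range(2)
    ]
    replies: list[bytes] = []
    for client, text in zip(clients, (b"alpha", b"beta")):
        client.sendall(text)
        replies.append(client.recv(1024))
    for client in clients:
        client.close()

    server.join(timeout=3.0)

print(replies)
```

The trade-off is that no handler may block: one slow `recv()` or `sendall()` would stall every client, which is why the accepted sockets are switched to non-blocking mode.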

## TCP vs UDP Comparison

| Feature | TCP (`SOCK_STREAM`) | UDP (`SOCK_DGRAM`) |
|---|---|---|
| **Connection** | Connection-oriented (3-way handshake) | Connectionless |
| **Reliability** | Reliable, ordered delivery (ACKs and retransmission) | No delivery or ordering guarantees |
| **Data boundaries** | Byte stream (no message boundaries) | Preserves datagram boundaries |
| **Flow control** | Yes (TCP windowing) | None |
| **Overhead** | Higher (20-byte minimum header, handshake, ACKs) | Lower (8-byte header) |
| **Use cases** | HTTP, SSH, email, file transfer | DNS, streaming, gaming |

In [None]:
import socket
from dataclasses import dataclass


@dataclass(frozen=True)
class ProtocolInfo:
    name: str
    socket_type: socket.SocketKind
    connection_oriented: bool
    reliable: bool
    preserves_boundaries: bool
    typical_uses: list[str]


protocols: list[ProtocolInfo] = [
    ProtocolInfo(
        name="TCP",
        socket_type=socket.SOCK_STREAM,
        connection_oriented=True,
        reliable=True,
        preserves_boundaries=False,
        typical_uses=["HTTP/HTTPS", "SSH", "SMTP", "FTP", "Database connections"],
    ),
    ProtocolInfo(
        name="UDP",
        socket_type=socket.SOCK_DGRAM,
        connection_oriented=False,
        reliable=False,
        preserves_boundaries=True,
        typical_uses=["DNS", "DHCP", "Video streaming", "Online gaming", "VoIP"],
    ),
]

for proto in protocols:
    print(f"=== {proto.name} ({proto.socket_type.name}) ===")
    print(f"  Connection-oriented: {proto.connection_oriented}")
    print(f"  Reliable delivery:   {proto.reliable}")
    print(f"  Message boundaries:  {proto.preserves_boundaries}")
    print(f"  Typical uses: {', '.join(proto.typical_uses)}")
    print()
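
The "data boundaries" row can be observed directly over the loopback interface. This is a small sketch with illustrative variable names: two UDP datagrams are read back as two separate messages, while the same two writes on a TCP connection arrive as one undifferentiated byte stream. How many `recv()` calls the TCP side needs is not guaranteed, so the example only inspects the joined result.

```python
import socket

# UDP: each sendto() is delivered as its own datagram
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as receiver:
    receiver.bind(("127.0.0.1", 0))
    receiver.settimeout(2.0)
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
        sender.sendto(b"first", receiver.getsockname())
        sender.sendto(b"second", receiver.getsockname())
    udp_reads = [receiver.recvfrom(1024)[0] for _ in range(2)]
print(f"UDP reads:  {udp_reads}")

# TCP: the two writes become part of a single byte stream
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
    listener.bind(("127.0.0.1", 0))
    listener.listen(1)
    with socket.create_connection(listener.getsockname(), timeout=2.0) as client:
        conn, _addr = listener.accept()
        with conn:
            conn.settimeout(2.0)
            client.sendall(b"first")
            client.sendall(b"second")
            client.shutdown(socket.SHUT_WR)  # Signal end of stream

            tcp_chunks: list[bytes] = []
            while chunk := conn.recv(1024):
                tcp_chunks.append(chunk)

tcp_stream = b"".join(tcp_chunks)
print(f"TCP stream: {tcp_stream!r} in {len(tcp_chunks)} recv() call(s)")
```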

## Best Practices and Common Pitfalls

The following helper demonstrates several best practices when creating server
sockets: using `SO_REUSEADDR`, proper error handling, port 0 for testing,
and the context manager pattern.

In [None]:
import socket
from contextlib import contextmanager
from typing import Generator


@contextmanager
def create_server_socket(
    host: str = "127.0.0.1",
    port: int = 0,
    backlog: int = 5,
    timeout: float | None = None,
) -> Generator[socket.socket, None, None]:
    """Create a properly configured TCP server socket.

    Best practices:
    - Uses SO_REUSEADDR to avoid 'Address already in use' errors
    - Port 0 lets the OS assign a free port (great for testing)
    - Context manager ensures the socket is always closed
    - Optional timeout prevents indefinite blocking
    """
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(backlog)
        if timeout is not None:
            server.settimeout(timeout)
        actual_addr = server.getsockname()
        print(f"Server listening on {actual_addr[0]}:{actual_addr[1]}")
        yield server
    finally:
        server.close()
        print("Server socket closed")


# Usage demonstration
with create_server_socket(timeout=1.0) as srv:
    host, port = srv.getsockname()
    print(f"  Assigned port: {port}")
    print(f"  Timeout: {srv.gettimeout()}s")
    print(f"  SO_REUSEADDR: {srv.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)}")

# Common pitfalls to avoid:
print("\n=== Common Pitfalls ===")
pitfalls: dict[str, str] = {
    "Using send() instead of sendall()": (
        "send() may not transmit all bytes. Always use sendall() for complete delivery."
    ),
    "Assuming recv() returns full message": (
        "TCP is a stream protocol. Use length-prefix or delimiter-based framing."
    ),
    "Forgetting SO_REUSEADDR on servers": (
        "Restarting a server may fail with 'Address in use' without this option."
    ),
    "Not setting timeouts": (
        "Blocking sockets without timeouts can hang indefinitely."
    ),
    "Not closing sockets": (
        "File descriptor leaks. Always use context managers."
    ),
}

for pitfall, explanation in pitfalls.items():
    print(f"  - {pitfall}")
    print(f"    Fix: {explanation}")

## Summary

This notebook covered the fundamentals of socket programming in Python:

1. **Socket creation** with `AF_INET`/`SOCK_STREAM` (TCP) and `SOCK_DGRAM` (UDP)
2. **Context managers** for safe socket lifecycle management
3. **DNS resolution** with `getaddrinfo()` for protocol-agnostic address lookup
4. **Socket options** like `SO_REUSEADDR` and timeout configuration
5. **TCP client-server** echo pattern with `bind()`, `listen()`, `accept()`, `connect()`
6. **Partial send/receive handling** with `sendall()` and length-prefixed framing
7. **UDP datagrams** with `sendto()` and `recvfrom()`
8. **Multi-client servers** using threading for concurrent connections

The next notebook covers HTTP and URL handling with `urllib`, building on these
socket fundamentals with higher-level abstractions.