# Hexagonal Coordinate System

This code implements a basic hexagonal coordinate system, as well as a hexagonal grid. Much of it is based on the explanations from [Red Blob Games](https://www.redblobgames.com/grids/hexagons/implementation.html).

In [1]:
import sys
from typing import NamedTuple
from math import sin, cos, pi, sqrt

We start by defining some basic data structures for points, hexes, vertices and edges.

In [2]:
class Point(NamedTuple):
    """A 2D position (or displacement) in pixel space.

    The fields are annotated as float because the pixel helpers in this
    module (hex_to_center_pixel, hex_vertex_pixel_offset, point_rotate, ...)
    produce non-integer coordinates; plain ints are still accepted.
    """
    x: float
    y: float

class _Hex(NamedTuple):
    q: int
    r: int
    s: int
    def __str__(self):
        return f'({self.q}, {self.r}, {self.s})'
    def __eq__(self, other):
        return self.q == other.q and self.r == other.r

def Hex(q: int, r: int, s=None) -> _Hex:
    """Create a hex from cube coordinates.

    ``s`` is optional: when omitted it is derived as ``-q - r``. When it is
    given, the triple is validated so that every hex has exactly one
    representation.

    Raises:
        ValueError: if an explicit ``s`` violates q + r + s == 0.
    """
    if s is not None and q + r + s != 0:
        raise ValueError("q + r + s must equal 0")
    return _Hex(q, r, -q - r if s is None else s)

class _Vertex(NamedTuple):
    q: int
    r: int
    t: int

def Vertex(q: int, r: int, t: int) -> _Vertex:
    """Create the vertex owned by hex (q, r) with corner flag ``t``.

    vertex_to_vertex_pixels draws t == 1 at corner direction 2 of the owning
    hex and t == 0 at corner direction 5.

    Raises:
        ValueError: if ``t`` is not 0 or 1.
    """
    if t in [0, 1]:
        return _Vertex(q, r, t)
    raise ValueError("t must be 0 or 1")

class _Edge(NamedTuple):
    a: Hex
    b: Hex
    def __eq__(self, other):
        eq = (self.a == other.a and self.b == other.b) or (self.a == other.b and self.b == other.a)
        return eq

def Edge(a: Hex, b: Hex) -> _Edge:
    """Build the edge shared by two adjacent hexes.

    Raises:
        ValueError: unless ``a`` and ``b`` are exactly one step apart. This
            also rejects ``a == b`` (distance 0): a hex does not share an
            edge with itself.
    """
    if hex_distance(a, b) != 1:
        raise ValueError("hexes must be adjacent")
    return _Edge(a, b)
    

Next, we define some functions that are useful for these structures.

In [3]:
def hex_add(a: Hex, b: Hex) -> Hex:
    """Return the component-wise sum of two cube coordinates."""
    q, r, s = a.q + b.q, a.r + b.r, a.s + b.s
    return Hex(q, r, s)

def point_add(a: Point, b: Point) -> Point:
    """Return the vector sum ``a + b`` of two points."""
    sum_x = a.x + b.x
    sum_y = a.y + b.y
    return Point(sum_x, sum_y)

def hex_subtract(a: Hex, b: Hex) -> Hex:
    """Return ``a - b`` in cube coordinates (the offset that takes b to a)."""
    dq = a.q - b.q
    dr = a.r - b.r
    ds = a.s - b.s
    return Hex(dq, dr, ds)

def point_subtract(a: Point, b: Point) -> Point:
    """Return the vector difference ``a - b`` of two points."""
    diff_x, diff_y = a.x - b.x, a.y - b.y
    return Point(diff_x, diff_y)

def hex_scale(a: Hex, k: int) -> Hex:
    """Multiply every cube coordinate of ``a`` by the factor ``k``."""
    return Hex(*(component * k for component in (a.q, a.r, a.s)))

def point_scale(p: Point, k: float) -> Point:
    """Scale both coordinates of ``p`` by ``k`` (relative to (0, 0))."""
    scaled_x, scaled_y = p.x * k, p.y * k
    return Point(scaled_x, scaled_y)

def hex_rotate_left(a: Hex) -> Hex:
    """Rotate ``a`` one sixth of a turn about the origin hex.

    Maps hex_direction(i) to hex_direction((i + 1) % 6); it is the inverse
    of hex_rotate_right.
    """
    q, r, s = a.q, a.r, a.s
    return Hex(-s, -q, -r)

def hex_rotate_right(a: Hex) -> Hex:
    """Rotate ``a`` one sixth of a turn about the origin hex.

    Maps hex_direction(i) to hex_direction((i - 1) % 6); it is the inverse
    of hex_rotate_left.
    """
    q, r, s = a.q, a.r, a.s
    return Hex(-r, -s, -q)

def point_rotate(p: Point, a: float, o = Point(0,0)) -> Point:
    """Rotate point ``p`` by ``a`` radians around pivot ``o``.

    Uses the standard 2D rotation matrix (counter-clockwise when the y axis
    points up). ``o`` defaults to the origin.
    """
    c, s = cos(a), sin(a)
    dx, dy = p.x - o.x, p.y - o.y
    return Point(o.x + dx * c - dy * s, o.y + dx * s + dy * c)

def hex_length(a: Hex) -> int:
    """Return how many hex steps ``a`` is from the origin hex.

    In cube coordinates this is half the sum of the absolute components.
    """
    total = sum(abs(component) for component in (a.q, a.r, a.s))
    return total // 2

def hex_distance(a: Hex, b: Hex) -> int:
    """Return the number of hex steps between ``a`` and ``b``."""
    offset = hex_subtract(a, b)
    return hex_length(offset)
    

In [4]:
def hex_direction(direction: int) -> Hex:
    """Return the unit offset for one of the six neighbour directions.

    Args:
        direction: 0-5; consecutive indices are one sixth of a turn apart.

    Raises:
        ValueError: if ``direction`` is outside 0-5.
    """
    if direction not in range(6):
        raise ValueError('direction must be between 0 and 5')
    offsets = [
        (1, 0, -1), (1, -1, 0), (0, -1, 1),
        (-1, 0, 1), (-1, 1, 0), (0, 1, -1),
    ]
    return Hex(*offsets[direction])

def hex_neighbour(hex_: Hex, direction: int) -> Hex:
    """Return the hex adjacent to ``hex_`` in ``direction`` (0-5).

    Raises:
        ValueError: if ``direction`` is outside 0-5.
    """
    if direction in range(6):
        return hex_add(hex_, hex_direction(direction))
    raise ValueError('direction must be between 0 and 5')

def hex_neighbourhood(hex_: Hex) -> list[Hex]:
    """Return the six hexes adjacent to ``hex_``, in direction order 0..5."""
    return [hex_neighbour(hex_, d) for d in range(6)]

def vertex_to_hex_trip(v: Vertex) -> list[Hex]:
    """Return the three hexes that meet at vertex ``v``.

    The first entry is always the owning hex Hex(v.q, v.r). The other two
    are its neighbours in directions 4 and 5 when v.t == 0, and in
    directions 1 and 2 otherwise.
    """
    owner = Hex(v.q, v.r)
    directions = (4, 5) if v.t == 0 else (1, 2)
    return [owner] + [hex_neighbour(owner, d) for d in directions]

def vertex_neighbourhood(v: Vertex) -> list[Vertex]:
    """Return the three vertices joined to ``v`` by a single edge.

    Results follow vertex_to_hex_trip order.
    - The owning hex Hex(v.q, v.r) contributes Vertex(q - 1, r + 2, 1) when
      v.t == 0, or Vertex(q + 1, r - 2, 0) otherwise.
    - Each of the other two hexes contributes its own corner with the
      opposite flag.
    """
    owner = Hex(v.q, v.r)
    if v.t == 0:
        across = Vertex(v.q - 1, v.r + 2, 1)
    else:
        across = Vertex(v.q + 1, v.r - 2, 0)
    flipped = (v.t + 1) % 2
    return [
        across if hex_ == owner else Vertex(hex_.q, hex_.r, flipped)
        for hex_ in vertex_to_hex_trip(v)
    ]
            
def vertex_distance(u: Vertex, v: Vertex) -> int:
    """Return the number of edges on a shortest path from ``u`` to ``v``.

    Vertices with the same flag t are 2 * hex_distance apart, measured
    between their owning hexes (i.e. |dq| + |dr| + |ds|). Otherwise every
    neighbour of ``v`` has u's flag, so the answer is one edge plus the best
    same-flag distance to one of those neighbours.
    """
    origin = Hex(u.q, u.r)
    if u.t == v.t:
        return 2 * hex_distance(origin, Hex(v.q, v.r))
    return 1 + min(
        2 * hex_distance(origin, Hex(n.q, n.r)) for n in vertex_neighbourhood(v)
    )

Of course, hexes, vertices and edges are closely related. It's helpful to be able to map from one type to another.

In [5]:
def hex_to_vertices(a: Hex) -> list[Vertex]:
    """Return the six corner vertices of hex ``a``.

    Each corner is named by its owning hex and flag, so only indices 2
    (Vertex(q, r, 1)) and 5 (Vertex(q, r, 0)) are owned by ``a`` itself. The
    other four belong to neighbouring hexes. These two indices match the
    corner directions vertex_to_vertex_pixels uses for t == 1 and t == 0.
    """
    q, r = a.q, a.r
    corner_specs = [
        (q, r + 1, 1),
        (q + 1, r - 1, 0),
        (q, r, 1),
        (q, r - 1, 0),
        (q - 1, r + 1, 1),
        (q, r, 0),
    ]
    return [Vertex(*spec) for spec in corner_specs]

def hex_to_vertex(a: "_Hex", direction: int) -> "_Vertex":
    """Return a single corner vertex of hex ``a``.

    Args:
        a: the hex whose corner is wanted.
        direction: corner index 0-5, in hex_to_vertices order.

    Raises:
        ValueError: if ``direction`` is outside 0-5. This matches
            hex_direction / hex_neighbour. Without the check, a negative
            index would silently wrap around the corner list.
    """
    if direction not in range(6):
        raise ValueError('direction must be between 0 and 5')
    return hex_to_vertices(a)[direction]

def vertex_to_hexes(v: Vertex) -> list[Hex]:
    """Return the three hexes meeting at vertex ``v``, owning hex first.

    This is the same set of hexes as vertex_to_hex_trip, listed in a
    different order.
    """
    q, r = v.q, v.r
    if v.t != 1:
        return [Hex(q, r), Hex(q, r + 1), Hex(q - 1, r + 1)]
    return [Hex(q, r), Hex(q, r - 1), Hex(q + 1, r - 1)]

def hex_to_edges(a: Hex) -> list[Edge]:
    """Return the six edges of ``a``, one per neighbour, in direction order."""
    edges = []
    for neighbour in hex_neighbourhood(a):
        edges.append(Edge(a, neighbour))
    return edges

def hex_to_edge(a: "_Hex", direction: int) -> "_Edge":
    """Return the edge between ``a`` and its neighbour in ``direction``.

    Args:
        a: the hex whose edge is wanted.
        direction: neighbour direction 0-5, in hex_to_edges order.

    Raises:
        ValueError: if ``direction`` is outside 0-5. This matches
            hex_direction / hex_neighbour instead of letting negative
            indices wrap around.
    """
    if direction not in range(6):
        raise ValueError('direction must be between 0 and 5')
    return hex_to_edges(a)[direction]

def vertex_pair_to_edge(u: Vertex, v: Vertex) -> Edge:
    """Return the edge joining vertices ``u`` and ``v``.

    Two adjacent vertices share exactly two hexes, and the edge is defined
    by that pair.

    Raises:
        ValueError: if ``u`` and ``v`` do not share exactly two hexes.
    """
    shared = list(set(vertex_to_hexes(u)) & set(vertex_to_hexes(v)))
    if len(shared) == 2:
        first, second = shared
        return Edge(first, second)
    raise ValueError("u and v must be adjacent vertices")

def edge_to_vertex_pair(e: Edge) -> list[Vertex]:
    """Return the vertices at the ends of edge ``e``.

    These are the corners common to both of the edge's hexes. The list
    order follows set iteration, not any geometric convention.
    """
    shared = set(hex_to_vertices(e.a)) & set(hex_to_vertices(e.b))
    return list(shared)
    

Now, we define some additional structures that specify properties of the grid.

In [6]:
# Conversion constants for one hex orientation:
#   f0..f3      forward 2x2 matrix, (q, r) -> pixel offset (hex_to_center_pixel)
#   b0..b3      backward matrix, pixel offset -> (q, r) (center_pixel_to_hex)
#   start_angle angle of corner 0, in sixths of a full turn
#               (hex_vertex_pixel_offset)
Orientation = NamedTuple('Orientation', [
    ('f0', float), ('f1', float), ('f2', float), ('f3', float),
    ('b0', float), ('b1', float), ('b2', float), ('b3', float),
    ('start_angle', float),
])

# How a hex grid is placed on the screen:
#   orientation  forward/backward conversion matrices (an Orientation)
#   size         hex size in pixels; hex_vertex_pixel_offset uses it as the
#                centre-to-corner distance
#   origin       pixel position of the centre of hex (0, 0)
Layout = NamedTuple('Layout', [
    ('orientation', Orientation),
    ('size', int),
    ('origin', Point),
])

In [7]:
# Pointy-topped hexes: corner 0 at 30 degrees (start_angle 0.5 sixth-turns).
layout_pointy = Orientation(
    f0=sqrt(3.0), f1=sqrt(3.0) / 2.0, f2=0.0, f3=3.0 / 2.0,
    b0=sqrt(3.0) / 3.0, b1=-1.0 / 3.0, b2=0.0, b3=2.0 / 3.0,
    start_angle=0.5,
)
# Flat-topped hexes: corner 0 at 0 degrees.
layout_flat = Orientation(
    f0=3.0 / 2.0, f1=0.0, f2=sqrt(3.0) / 2.0, f3=sqrt(3.0),
    b0=2.0 / 3.0, b1=0.0, b2=-1.0 / 3.0, b3=sqrt(3.0) / 3.0,
    start_angle=0.0,
)
# Pointy layout with hex size 35 and hex (0, 0) centred on pixel (0, 0).
layout_grid = Layout(orientation=layout_pointy, size=35, origin=Point(0, 0))

To display a hexagonal grid, we need to be able to convert hexes, vertices and edges to points on a 2D plane.

In [8]:
def hex_to_center_pixel(layout: Layout, h: Hex) -> Point:
    """Return the pixel coordinates of the centre of hex ``h``.

    Applies the orientation's forward matrix (f0..f3) to (q, r), scales the
    result by ``layout.size`` and shifts it by ``layout.origin``.
    """
    m, origin = layout.orientation, layout.origin
    px = (m.f0 * h.q + m.f1 * h.r) * layout.size
    py = (m.f2 * h.q + m.f3 * h.r) * layout.size
    return Point(px + origin.x, py + origin.y)

def center_pixel_to_hex(layout: Layout, p: Point, fractional: bool = False) -> Hex:
    """Convert a pixel position to the hex that contains it.

    Inverts hex_to_center_pixel with the orientation's backward matrix
    (b0..b3). The resulting fractional cube coordinates are then rounded to
    the nearest hex, so the result compares equal to the grid's hexes. Even
    an exact centre pixel produces values like 1.0000000000000002 before
    rounding.

    Args:
        layout: grid layout (orientation, hex size, pixel origin).
        p: pixel position to convert.
        fractional: if True, skip rounding and return the raw fractional
            cube coordinates (the earlier behaviour of this function).

    Returns:
        The containing hex, or fractional coordinates when requested.
    """
    M = layout.orientation
    size = layout.size
    origin = layout.origin
    x = (p.x - origin.x) / size
    y = (p.y - origin.y) / size
    q = M.b0 * x + M.b1 * y
    r = M.b2 * x + M.b3 * y
    s = -q - r
    if fractional:
        return Hex(q, r, s)
    rq, rr, rs = round(q), round(r), round(s)
    dq, dr, ds = abs(rq - q), abs(rr - r), abs(rs - s)
    # Rounding each coordinate independently can break q + r + s == 0.
    # Rebuild the coordinate with the largest rounding error from the other two.
    if dq > dr and dq > ds:
        rq = -rr - rs
    elif dr > ds:
        rr = -rq - rs
    else:
        rs = -rq - rr
    return Hex(rq, rr, rs)

def hex_vertex_pixel_offset(layout: Layout, direction: int):
    """Return the offset from a hex centre to its corner ``direction``.

    Corner angles step by one sixth of a turn, starting from the
    orientation's start_angle. The offset length equals ``layout.size``.
    """
    theta = 2.0 * pi * (layout.orientation.start_angle - direction) / 6.0
    radius = layout.size
    return Point(radius * cos(theta), radius * sin(theta))

def hex_to_vertex_pixels(layout: Layout, h: Hex):
    """Return the pixel positions of the six corners of ``h``, corner 0 first."""
    center = hex_to_center_pixel(layout, h)
    corner_offsets = (hex_vertex_pixel_offset(layout, d) for d in range(6))
    return [Point(center.x + off.x, center.y + off.y) for off in corner_offsets]

In [9]:
def vertex_to_vertex_pixels(layout: Layout, v: Vertex) -> Point:
    """Return the pixel position of vertex ``v``.

    The vertex is drawn at corner 2 of its owning hex when v.t == 1, and at
    corner 5 otherwise.
    """
    owner = Hex(v.q, v.r, -v.q - v.r)
    center = hex_to_center_pixel(layout, owner)
    corner = 2 if v.t == 1 else 5
    return point_add(center, hex_vertex_pixel_offset(layout, corner))

def edge_to_edge_pixels(layout: Layout, e: Edge) -> list[Point]:
    """Return the pixel positions of the endpoints of ``e``.

    The order is the one given by edge_to_vertex_pair.
    """
    pixels = []
    for vertex in edge_to_vertex_pair(e):
        pixels.append(vertex_to_vertex_pixels(layout, vertex))
    return pixels

def edge_to_shorter_edge_pixels(layout: Layout, e: Edge, fraction=None) -> list[Point]:
    """Return the pixel endpoints of ``e`` with both ends pulled inwards.

    Each endpoint moves towards the other by ``fraction`` of the edge vector.
    The segment keeps its direction and midpoint but loses
    ``2 * fraction`` of its length.

    Args:
        layout: grid layout used to place the edge's vertices.
        e: the edge to convert.
        fraction: share of the edge vector trimmed from each end. Defaults to
            ``layout.size * 0.001`` (the original hard-coded value). This
            default grows with the hex size (size 35 -> 3.5% per end), so
            pass an explicit value for large layouts.

    Returns:
        ``[start, end]`` in the order given by edge_to_edge_pixels.
    """
    if fraction is None:
        fraction = layout.size * 0.001
    vertices = edge_to_edge_pixels(layout, e)
    start, end = vertices[0], vertices[1]
    # (start - end) points from end towards start. Subtracting it from start
    # and adding it to end moves both points towards the midpoint. A single
    # formula covers every edge orientation.
    adjust = point_scale(point_subtract(start, end), fraction)
    return [point_subtract(start, adjust), point_add(end, adjust)]

We need to be able to generate a grid of hexes. It's also helpful to be able to get all vertices and edges in a grid.

In [10]:
def create_hexagonal_grid(radius: int) -> list[Hex]:
    """Return every hex within ``radius`` steps of the origin hex.

    Hexes are listed column by column: q ascending, then r ascending within
    each column. A radius of n yields 3n^2 + 3n + 1 hexes. A negative
    radius yields an empty list.
    """
    grid = []
    for q in range(-radius, radius + 1):
        # Clamp r so that s = -q - r also stays within [-radius, radius].
        r_min = max(-radius, -q - radius)
        r_max = min(radius, -q + radius)
        grid.extend(Hex(q, r, -q - r) for r in range(r_min, r_max + 1))
    return grid

def hex_grid_index(h: Hex, grid: list[Hex]) -> int:
    """Return the position of ``h`` within ``grid``.

    Delegates to ``grid.index``, so a hex that is not in the grid raises
    ValueError.
    """
    position = grid.index(h)
    return position

def vertex_list(grid: list[Hex]) -> list[Vertex]:
    """Return the distinct vertices of all hexes in ``grid``.

    Vertices appear in first-seen order: walking ``grid`` and each hex's
    corners in hex_to_vertices order. A set records what has already been
    emitted, so de-duplication costs O(1) per corner instead of a linear
    scan of the result list.
    """
    seen = set()
    vertices = []
    for hex_ in grid:
        for v in hex_to_vertices(hex_):
            if v not in seen:
                seen.add(v)
                vertices.append(v)
    return vertices

def edge_list(grid: list[Hex]) -> list[Edge]:
    """Return the distinct edges of all hexes in ``grid``.

    Boundary edges are included: they pair a grid hex with a neighbour that
    lies outside the grid. Edges appear in first-seen order (walking
    ``grid``, then each hex's neighbours in direction order).

    The previous implementation iterated the global ``hex_grid`` instead of
    the ``grid`` argument. It therefore failed (NameError) or returned the
    wrong grid's edges whenever that global was missing or different.
    """
    seen = set()
    edges = []
    for hex_ in grid:
        for n in hex_neighbourhood(hex_):
            # Edges are unordered pairs, so key them by the pair as a set.
            # This is O(1) per check instead of scanning the edge list.
            key = frozenset((hex_, n))
            if key not in seen:
                seen.add(key)
                edges.append(Edge(hex_, n))
    return edges

Unit tests to be placed here.

In [11]:
if __name__ == "__main__":
    print("Insert unit tests here")
    # Smoke checks with hand-verified expected values.
    # Vertex(-2, 2, 1) has t == 1, so all three of its neighbours have t == 0.
    assert vertex_neighbourhood(Vertex(-2, 2, 1)) == [
        Vertex(-1, 0, 0), Vertex(-1, 1, 0), Vertex(-2, 1, 0),
    ]
    # Vertex(-1, 1, 0) is one of those neighbours, so they are one edge apart.
    assert vertex_distance(Vertex(-1, 1, 0), Vertex(-2, 2, 1)) == 1
    hex_grid = create_hexagonal_grid(2)
    # A radius-n grid has 3n^2 + 3n + 1 hexes, 6(n + 1)^2 vertices and
    # 9n^2 + 15n + 6 edges (boundary edges included). Here n = 2.
    assert len(hex_grid) == 19
    assert len(vertex_list(hex_grid)) == 54
    assert len(edge_list(hex_grid)) == 72

Insert unit tests here
