# Item 49: Register Class Existence with `__init_subclass__`

In [1]:
# Say we want to implement our own serialized representation of a Python object using JSON
import json

class Serializable:
    """Base class that captures constructor arguments and dumps them as JSON.

    Subclasses pass the values they want persisted up to ``__init__``; those
    values are kept verbatim in ``self.args`` and emitted by ``serialize``.
    """

    def __init__(self, *args):
        # Keep the positional arguments exactly as received (a tuple).
        self.args = args

    def serialize(self):
        """Return a JSON string of the form ``{"args": [...]}``."""
        payload = {'args': self.args}
        return json.dumps(payload)

In [2]:
# This class makes it easy to serialize simple, immutable data structures like Point2D to a string
class Point2D(Serializable):
    """Two-dimensional point whose coordinates are captured for ``serialize``."""

    def __init__(self, x, y):
        # Hand the coordinates to the base class so they appear in the JSON.
        super().__init__(x, y)
        self.x, self.y = x, y

    def __repr__(self):
        return f'Point2D({self.x}, {self.y})'

# Build a point and show its repr alongside its JSON encoding.
point = Point2D(5, 3)
print('Object:    ', point)
print('Serialized:', point.serialize())

Object:     Point2D(5, 3)
Serialized: {"args": [5, 3]}


In [4]:
# We now need to deserialize this JSON string and construct the Point2D object it represents.
class Deserializable(Serializable):
    """Serializable that can also rebuild an instance from its JSON form."""

    @classmethod
    def deserialize(cls, json_data):
        """Construct a ``cls`` instance from a string produced by ``serialize``.

        The ``"args"`` list from the JSON is unpacked positionally into
        ``cls``'s constructor, so this only round-trips when ``__init__``
        accepts the same values it passed to ``Serializable.__init__``.
        """
        return cls(*json.loads(json_data)['args'])

In [5]:
# Using Deserializable makes it easy to serialize and deserialize simple, immutable objects in a generic
# way
class BetterPoint2D(Deserializable):
    """2-D point that inherits ``deserialize`` from ``Deserializable``."""

    def __init__(self, x, y):
        # The base class records (x, y) so deserialize can replay them.
        super().__init__(x, y)
        self.x, self.y = x, y

    def __repr__(self):
        # NOTE(review): reports 'Point2D', not 'BetterPoint2D'; kept as-is so
        # the recorded notebook output stays accurate.
        return f'Point2D({self.x}, {self.y})'

# Round-trip: serialize a BetterPoint2D, then rebuild it through the classmethod
# on the concrete class -- the caller must already know the target type.
# (The repr prints "Point2D" because BetterPoint2D.__repr__ hard-codes it.)
before = BetterPoint2D(5, 3)
print('Before:    ', before)
data = before.serialize()
print('Serialized:', data)
after = BetterPoint2D.deserialize(data)
print('After:     ', after)

Before:     Point2D(5, 3)
Serialized: {"args": [5, 3]}
After:      Point2D(5, 3)


The problem with this approach is that it only works if we know the intended type of the serialized data ahead of time (e.g., `Point2D`, `BetterPoint2D`). Ideally, we would have a large number of classes serializing to JSON and one common function that could deserialize any of them back to a corresponding Python `object`.

In [6]:
# We do the above by including a serialized object's class name in the JSON data
class BetterSerializable:
    """Serializable base that records its concrete class name in the JSON.

    Embedding the class name lets one generic ``deserialize`` function pick
    the right constructor without the caller knowing the type in advance.
    """

    def __init__(self, *args):
        self.args = args

    def serialize(self):
        """Return ``{"class": <class name>, "args": [...]}`` as a JSON string."""
        class_name = self.__class__.__name__
        payload = {'class': class_name, 'args': self.args}
        return json.dumps(payload)

    def __repr__(self):
        # Render e.g. "Vector3D(10, -7, 3)" from str() of each stored argument.
        class_name = self.__class__.__name__
        rendered_args = ', '.join(map(str, self.args))
        return f'{class_name}({rendered_args})'

In [7]:
# Then we can maintain a mapping of class names back to constructors for those objects
# Maps a class's ``__name__`` to the class itself so ``deserialize`` can find
# the right constructor from the name stored in the JSON payload.
registry = {}

def register_class(target_class):
    """Record ``target_class`` in ``registry`` under its ``__name__``.

    Returns ``target_class`` unchanged, so this also works as a class
    decorator (``@register_class``). Registering a second class with the
    same ``__name__`` replaces the earlier entry.
    """
    registry[target_class.__name__] = target_class
    return target_class

def deserialize(data):
    """Rebuild an object from JSON of the form ``{"class": ..., "args": [...]}``.

    The class is looked up in ``registry`` by the ``"class"`` field and called
    with the ``"args"`` list unpacked positionally.

    Raises:
        KeyError: if the named class was never registered, or if the payload
            lacks the ``"class"`` / ``"args"`` keys.
    """
    params = json.loads(data)
    target_class = registry[params['class']]
    return target_class(*params['args'])

In [8]:
# To ensure that deserialize always works properly, we must call register_class for every class we may
# want to deserialize in the future
class EvenBetterPoint2D(BetterSerializable):
    """2-D point relying on BetterSerializable for its JSON form and repr."""

    def __init__(self, x, y):
        super().__init__(x, y)
        self.x, self.y = x, y

# Manual registration step: without it, deserialize() cannot find this class.
register_class(EvenBetterPoint2D)

In [9]:
# Now we can deserialize an arbitrary JSON string without having to know which class it contains
# Round-trip through the generic module-level deserialize(): the target class
# is looked up from the "class" field of the JSON, not named by the caller.
before = EvenBetterPoint2D(5, 3)
print('Before:    ', before)
data = before.serialize()
print('Serialized:', data)
after = deserialize(data)
print('After:     ', after)


Before:     EvenBetterPoint2D(5, 3)
Serialized: {"class": "EvenBetterPoint2D", "args": [5, 3]}
After:      EvenBetterPoint2D(5, 3)


In [11]:
# The problem with the previous approach is that it's possible to forget to call register_class
class Point3D(BetterSerializable):
    """3-D point; deliberately never passed to ``register_class``."""

    def __init__(self, x, y, z):
        super().__init__(x, y, z)
        self.x, self.y, self.z = x, y, z

# Forgot to call register_class! Whoops!

In [13]:
# This causes the code to break at runtime, when we finally try to deserialize an instance of a class
# we forgot to register
point = Point3D(5, 9, -4)
data = point.serialize()
try:
    deserialize(data)
except KeyError as exc:
    # Demonstrate the failure without aborting "Restart & Run All": the
    # lookup in the registry fails because Point3D was never registered.
    # Prints "KeyError: 'Point3D'".
    print(f'{type(exc).__name__}: {exc}')

KeyError: 'Point3D'

In this instance, if we forget to call `register_class` after the class statement body, the class never enters the registry, so deserializing its instances fails at runtime. A metaclass, however, lets us ensure that `register_class` is called automatically for every class the metaclass creates.

In [14]:
# Here, we use a metaclass to register the new type immediately after the class's body
class Meta(type):
    """Metaclass that passes every class it creates to ``register_class``."""

    def __new__(meta, name, bases, class_dict, **kwargs):
        # Use super() rather than calling type.__new__ directly, so Meta
        # cooperates with other metaclasses in a multiple-inheritance MRO.
        # Class keyword arguments are forwarded so they reach __init_subclass__.
        cls = super().__new__(meta, name, bases, class_dict, **kwargs)
        register_class(cls)
        return cls

class RegisteredSerializable(BetterSerializable, metaclass=Meta):
    """Serializable base; it and all its subclasses are registered by ``Meta``."""

In [15]:
# When we define a subclass of RegisteredSerializable, we can be confident that the call to register_class
# happened and deserialize always works as expected
class Vector3D(RegisteredSerializable):
    """Three-component vector, registered automatically by its metaclass."""

    def __init__(self, x, y, z):
        super().__init__(x, y, z)
        self.x = x
        self.y = y
        self.z = z

# Vector3D was never passed to register_class explicitly; its metaclass
# registered it, so the generic deserialize() can rebuild it.
before = Vector3D(10, -7, 3)
print('Before:    ', before)
data = before.serialize()
print('Serialized:', data)
print('After:     ', deserialize(data))

Before:     Vector3D(10, -7, 3)
Serialized: {"class": "Vector3D", "args": [10, -7, 3]}
After:      Vector3D(10, -7, 3)


An even better approach is to use the `__init_subclass__` special method (available since Python 3.6). It is called whenever a subclass is defined, so it applies custom logic at class-definition time without the visual noise of a metaclass.

In [17]:
class BetterRegisteredSerializable(BetterSerializable):
    """Serializable base that registers each subclass as it is defined.

    ``__init_subclass__`` runs only for subclasses, so this base class itself
    is not added to the registry.
    """

    def __init_subclass__(cls, **kwargs):
        # Forward class keyword arguments (``class X(Base, flag=...)``) up the
        # MRO so cooperative bases still receive them.
        super().__init_subclass__(**kwargs)
        register_class(cls)

class Vector1D(BetterRegisteredSerializable):
    """Single-component vector, registered via ``__init_subclass__``."""

    def __init__(self, magnitude):
        # Passing magnitude up makes it the sole serialized argument.
        super().__init__(magnitude)
        self.magnitude = magnitude


# Vector1D was registered by BetterRegisteredSerializable.__init_subclass__
# when its class statement ran, so deserialize() can rebuild it.
before = Vector1D(6)
print('Before:    ', before)
data = before.serialize()
print('Serialized:', data)
print('After:     ', deserialize(data))

Before:     Vector1D(6)
Serialized: {"class": "Vector1D", "args": [6]}
After:      Vector1D(6)
