In [1]:
import typing
import abc
from dataclasses import dataclass, asdict
import json

In [25]:

@dataclass
class ModelConfig(metaclass=abc.ABCMeta):

	@classmethod
	@abc.abstractclassmethod
	def load(cls, obj: dict) -> "ModelConfig":
		raise NotImplementedError

	@abc.abstractmethod
	def serialize(self) -> dict:
		raise NotImplementedError



T_config = typing.TypeVar("T_config", bound=ModelConfig)
class ConfiguredModel(
		typing.Generic[T_config], 
		metaclass=abc.ABCMeta,
	):
	config_class: type = None

	def __init__(self, config: T_config):
		super().__init__()
		if self.config_class is None:
			raise NotImplementedError("you need to set `config_class` for your model")
		if not isinstance(config, self.config_class):
			raise TypeError(f"config must be an instance of {self.config_class = }, got {type(config) = }")
		self.config = config
		self.data = "test data"

	def save(self) -> None:
		print(json.dumps(dict(
				config = self.config.serialize(),
				data = self.data,
			),
			indent=4,
		))
			

	@classmethod
	def load(cls, data: str) -> "ConfiguredModel":
		obj = json.loads(data)
		# get the config class from the type annotation
		# print(typing.get_type_hints(cls))
		# config_class = typing.get_type_hints(cls)["config"]
		# print(config_class, type(config_class))
		print(f"{cls.config_class = } {type(cls.config_class) = }")
		# print(f"{cls.config.__class__ = } {type(cls.config.__class__) = }")


		# config = config_class(**obj["config"])
		# model: "ConfiguredModel" = cls(config)
		# model.data = obj["data"]
		# return model

In [26]:
@dataclass
class MyConfig(ModelConfig):
	"""basic test GPT config"""
	n_layers: int
	n_vocab: int

	def serialize(self) -> dict:
		return asdict(self)
	
	@classmethod
	def load(cls, obj: dict) -> "MyGPTConfig":
		return cls(**obj)

class MyGPT(ConfiguredModel[MyConfig]):

	def __init__(self, config: MyConfig):
		super().__init__(config)
		self.transformer = lambda x: (sum(x), config.n_layers, config.n_vocab)

	def forward(self, x):
		return self.transformer(x)

In [27]:
config: MyConfig = MyConfig(
	n_layers=2,
	n_vocab=128,
)

model: MyGPT = MyGPT(config)

model.save()

model.load("""{
    "config": {
        "n_layers": 2,
        "n_vocab": 128
    },
    "data": "test data"
}""")

{
    "config": {
        "n_layers": 2,
        "n_vocab": 128
    },
    "data": "test data"
}
cls.config_class = ~T_config type(cls.config_class) = <class 'typing.TypeVar'>
