In [1]:
from pathlib import Path
import os, sys
# Make the directory two levels above the current working directory importable
# (presumably the repository root, where `params` / `evariste` live — confirm).
sys.path.append(os.fspath(Path.cwd().parent.parent))

## Param reloading

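We can reload parameters from a dictionary that is either flat (dotted keys such as `"b.a.value"`) or nested (flattened first with `flatten_dict`). Conversely, a dataclass instance can be exported as a nested dict with `asdict`.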

In [2]:
from dataclasses import dataclass
from params import Params
from params.params import flatten_dict, asdict

@dataclass
class A(Params):
    # Innermost toy params class: a single int field.
    value: int
        
@dataclass
class B(Params):
    # Middle level of the toy hierarchy: an int plus a nested A.
    value: int
    a: A  # seen from C as the flat key "b.a.value"
        
@dataclass
class C(Params):
    # Top-level toy params class; its flat keys are "value", "b.value" and "b.a.value".
    value: int
    b: B


# The same parameters written twice: with dotted flat keys, and as a nested
# dict that `flatten_dict` turns into the flat form. Both reload to the same C.
flat_params = {"value": 1, "b.value": 2, "b.a.value": 3}
nested_params = {"value": 1, "b": {"value": 2, "a": {"value": 3}}}

print(C.from_flat(flat_params))
print(C.from_flat(flatten_dict(nested_params)))

# Export an instance back to a nested dict.
some_C = C(value=1, b=B(value=2, a=A(value=3)))
print(asdict(some_C))

C(value=1, b=B(value=2, a=A(value=3)))
C(value=1, b=B(value=2, a=A(value=3)))
{'value': 1, 'b': {'value': 2, 'a': {'value': 3}}}


## Oh no! I changed my Params class!

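Sometimes we want to add new fields to our classes. This is rather benign: as long as the new field has a default value, parameters saved before the field existed are reloaded automagically (without a default, reloading fails with `MissingArg`, as shown below). However, what happens when we decide that `value` just isn't the right name anymore in class `B`? We have to migrate.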

In [3]:
from params import MissingArg
# Let's redefine B with a new required param (no default) and try to reload the old flat params

@dataclass
class B(Params):
    # B redefined with an extra *required* field (no default): params saved
    # before this field existed cannot fill it, so reloading raises MissingArg.
    value: int
    new_value:int
    a: A
        
@dataclass
class C(Params):
    # Redefined so that the `b: B` annotation refers to the B class just above.
    # The previous C would still point at the previous B.
    value: int
    b: B

# These old params have no "b.new_value" entry and the field has no default,
# so reloading is expected to fail with MissingArg.
old_flat_params = {"value": 1, "b.value": 2, "b.a.value": 3}
try:
    C.from_flat(old_flat_params)
except MissingArg as missing:
    print(missing)

Parameter b.new_value unspecified and has no default


In [4]:
@dataclass
class B(Params):
    # Same extra field, now with a default, so old params reload with new_value=0.
    value: int
    a: A
    # Moved last: dataclass fields with defaults must follow fields without.
    new_value:int = 0
        
@dataclass
class C(Params):
    # Redefined so that the `b: B` annotation refers to the B class just above.
    value: int
    b: B


# `new_value` now has a default, so the old flat params reload without error.
reloaded = C.from_flat({"value": 1, "b.value": 2, "b.a.value": 3})
print(reloaded)

C(value=1, b=B(value=2, a=A(value=3), new_value=0))


We can add migrations for a class by appending to that class's list in the global `migrations` mapping (e.g. `migrations[B].append(...)`). Migrations are executed in the order in which they were appended.

TODO: A good exercise to get familiar with this would be to give migrations names and make sure each one is registered only once.

In [5]:
from params.migrate import migrations, migrate

@dataclass
class B(Params):
    # `value` renamed to `new_name_for_value`. Old flat dicts still contain
    # "b.value", so reloading them needs a migration, which is registered below.
    new_name_for_value: int
    a: A
        
@dataclass
class C(Params):
    # Redefined so that the `b: B` annotation refers to the renamed B above.
    value: int
    b: B

        
from functools import partial
from evariste.trainer.migrations import rename_prefix

In [6]:
from params.migrate import Schema, FlatDict, Migration, warn_migration, migrate
from dataclasses import fields, is_dataclass


def rename_field(flat_dict: FlatDict, old: str, new: str) -> FlatDict:
    """Return a copy of ``flat_dict`` in which key ``old`` is renamed to ``new``.

    Meant to be registered as a migration via
    ``partial(rename_field, old=..., new=...)``.

    :param flat_dict: flat parameter dict; it is never modified.
    :param old: key to rename.
    :param new: replacement key. If ``old`` is present and ``new`` already
        exists too, the value of ``new`` is overwritten by the value of ``old``.
    :return: a new dict, equal to ``flat_dict`` when ``old`` is absent.
    """
    new_dict = dict(flat_dict)  # shallow copy: don't mutate the caller's dict
    if old not in new_dict:
        # Nothing to migrate (e.g. params already use the new name). Stay silent
        # instead of reporting a rename that did not happen.
        return new_dict
    warn_migration(f"Changing {old} to {new} in flatdict")
    new_dict[new] = new_dict.pop(old)
    return new_dict

# Register the rename as a migration for B, at most once, so that re-running this
# cell does not stack duplicates. `partial` objects do not define equality, so an
# existing registration is recognised by the wrapped function's name plus its
# keywords.
_rename_value = partial(rename_field, old="value", new="new_name_for_value")
_b_migrations = migrations[B]
if not any(
    isinstance(m, partial)
    and getattr(m.func, "__name__", None) == rename_field.__name__
    and m.keywords == _rename_value.keywords
    for m in _b_migrations
):
    _b_migrations.append(_rename_value)

In [8]:
# Apply the registered migrations to the old-style flat params, then reload.
migrated = migrate(C, {"value": 1, "b.value": 2, "b.a.value": 3})
print(C.from_flat(migrated))

[93m[MIGRATION] - Changing value to new_name_for_value in flatdict[0m
C(value=1, b=B(new_name_for_value=2, a=A(value=3)))


## Reloading old models

For historical reasons, the first "migration" that is applied is `trainer_args_from_old`. Maybe one day we'll be able to get rid of it.

In the meantime, to reload a checkpoint, use `evariste.model.utils.reload_ckpt`, which returns the `TrainerArgs`, the `Dictionary`, and the reloaded modules.

This utility function will handle params migration for you.