In [102]:

class CheckPointHandler:
    """Keep a linear undo/redo history of this object's own instance attributes.

    Every checkpoint is a snapshot ``{attribute name: value}`` of all instance
    attributes except the bookkeeping fields ``version_num`` and ``history``.
    Values that provide a ``copy()`` method (e.g. lists, dicts, sets) are
    shallow-copied both when a snapshot is taken and when it is restored. As a
    result, mutating a live attribute in place never alters a stored checkpoint.
    Other values are stored by reference.

    ``history[0]`` is always the empty state ``{}``. Creating a new checkpoint
    while not at the newest version discards the newer versions (the history is
    a single line, not a tree of branches).
    """

    # Bookkeeping attribute names: never snapshotted, never accepted by add_vars.
    _RESERVED = frozenset({'version_num', 'history'})

    def __init__(self, **pars):
        """Create a handler; if keyword arguments are given, record them as version 1."""
        self.version_num = 0  # index into ``history`` of the live version
        self.history = [{}]   # list of snapshots; version 0 is the empty state
        if pars:
            self.add_vars(**pars)

    def make_check_point(self):
        """Snapshot the current attributes as a new version and make it current.

        Versions after the current one (reachable via ``redo``) are discarded.
        """
        # A checkpoint is exactly "add no new variables, then snapshot".
        self.add_vars()

    def add_vars(self, **pars):
        """Set the given attributes and record the resulting state as a new version.

        The new version is a snapshot of the current attributes updated with
        ``pars``. Versions after the current one are discarded first.

        Raises:
            ValueError: if ``pars`` names a bookkeeping field (``version_num`` or
                ``history``); setting those would corrupt the history.
        """
        clashes = self._RESERVED.intersection(pars)
        if clashes:
            raise ValueError(
                'reserved attribute name(s): {}'.format(', '.join(sorted(clashes))))
        # Drop the redo tail: only one linear history is remembered.
        del self.history[self.version_num + 1:]
        new_version = self.__make_version()
        new_version.update({key: self.__copy_value(value) for key, value in pars.items()})
        self.history.append(new_version)
        self.__change_version(self.version_num + 1)

    def undo(self):
        """Step back to the previous version.

        Raises:
            Exception: if already at version 0.
        """
        if self.version_num <= 0:
            raise Exception("First checkpoint!")
        self.__change_version(self.version_num - 1)

    def redo(self):
        """Step forward to the next version.

        Raises:
            Exception: if already at the newest version.
        """
        if self.version_num + 1 >= len(self.history):
            raise Exception("Last checkpoint!")
        self.__change_version(self.version_num + 1)

    def print_versions(self):
        """Print every stored version, marking the current one with a leading '+'."""
        for index, version in enumerate(self.history):
            marker = '+' if index == self.version_num else ''
            print('{}version {}: {}'.format(marker, index, version))

    @staticmethod
    def __copy_value(value):
        """Return a shallow copy of ``value`` if it has a ``copy()`` method, else ``value``.

        Classes and modules are returned unchanged. Their ``copy`` attribute (e.g.
        the unbound ``dict.copy``, or a module-level ``copy`` function) is not a
        no-argument copy of the object itself, so calling it would raise.
        """
        import types
        if isinstance(value, (type, types.ModuleType)):
            return value
        return value.copy() if hasattr(value, 'copy') else value

    def __make_version(self):
        """Return a snapshot (copied values) of the current non-bookkeeping attributes."""
        return {key: self.__copy_value(value)
                for key, value in vars(self).items()
                if key not in self._RESERVED}

    def __change_version(self, version_num):
        """Make ``history[version_num]`` the live state and set ``self.version_num``.

        Every live attribute that is absent from the target version is deleted,
        including ones assigned directly since the last checkpoint. Restored
        values are copies, so later in-place edits cannot reach the snapshot.
        """
        target = self.history[version_num]
        # Build the list first: deleting while iterating vars(self) is unsafe.
        stale = [key for key in vars(self)
                 if key not in self._RESERVED and key not in target]
        for key in stale:
            delattr(self, key)
        for key, value in target.items():
            setattr(self, key, self.__copy_value(value))
        self.version_num = version_num


In [103]:
# The constructor records a=10 as version 1. Reassigning `a` and calling
# make_check_point() records the new value as version 2. '+' marks the live version.
cph = CheckPointHandler(a=10)
cph.a = 20
cph.make_check_point()
cph.print_versions()

version 0: {}
version 1: {'a': 10}
+version 2: {'a': 20}


In [105]:
# add_vars() sets attributes and snapshots them in one step. Version 3 overrides `a`
# and adds `b`; version 4 keeps those and adds `c` and `d`.
cph.add_vars(a=1,b=20)
cph.add_vars(c=-1,d=-2)
cph.print_versions()

version 0: {}
version 1: {'a': 10}
version 2: {'a': 20}
version 3: {'a': 1, 'b': 20}
+version 4: {'a': 1, 'b': 20, 'c': -1, 'd': -2}


In [106]:
# Step back one version: version 3 (which has no `c`/`d`) becomes live again.
cph.undo()
cph.print_versions()

version 0: {}
version 1: {'a': 10}
version 2: {'a': 20}
+version 3: {'a': 1, 'b': 20}
version 4: {'a': 1, 'b': 20, 'c': -1, 'd': -2}


In [107]:
# `a` and `b` hold their version-3 values after undo().
print(cph.a, cph.b)

1 20


In [108]:
# Version 3 has no `c`/`d`, so undo() removed them. The AttributeError is the
# point of this cell: catch and show it so Restart & Run All does not stop here.
try:
    print(cph.c, cph.d)
except AttributeError as err:
    print('AttributeError:', err)

AttributeError: 'CheckPointHandler' object has no attribute 'c'

In [109]:
# redo() moves forward to version 4 again, restoring `c` and `d`.
cph.redo()
cph.print_versions()

version 0: {}
version 1: {'a': 10}
version 2: {'a': 20}
version 3: {'a': 1, 'b': 20}
+version 4: {'a': 1, 'b': 20, 'c': -1, 'd': -2}


In [110]:
# All four attributes are back after redo() to version 4.
print(cph.a, cph.b, cph.c, cph.d)

1 20 -1 -2
