In [31]:
from __future__ import print_function
import sys

# TODO: just use list indices as normal and use pop()? (instead of inserting at 0, etc)

class InMemoryDatabase(object):
    """Versioned key/value store with a reverse index of value counts.

    ``data`` maps each variable name to a list of versions, newest first
    (index 0 is the value ``get`` returns).  ``value_index`` maps each
    newest value to the number of variables currently holding it; ``None``
    is counted like any other value.
    """

    def __init__(self):
        # TODO: underscores?
        self.data = {}
        self.value_index = {}

    # TODO: underscores?
    def increment_value_index(self, value):
        """Record that one more variable currently holds ``value``."""
        self.value_index[value] = self.value_index.get(value, 0) + 1

    def decrement_value_index(self, value):
        """Record that one fewer variable holds ``value``.

        The entry is deleted once its count reaches zero.
        Raises KeyError if ``value`` is not in the index.
        """
        remaining = self.value_index[value] - 1
        if remaining:
            self.value_index[value] = remaining
        else:
            del self.value_index[value]

    def is_set(self, var):
        """Return True if ``var`` has at least one stored version.

        A variable whose newest version is ``None`` still counts as set here.
        """
        return var in self.data

    def get(self, var):
        """Return the newest value of ``var``, or None if it has no versions."""
        return self.data[var][0] if self.is_set(var) else None

    def _add_or_replace(self, var, value, new=True):
        """Push a new version of ``var`` (new=True) or overwrite its newest one.

        With ``new=True`` a version is ALWAYS pushed, even when ``value``
        equals the current value: every ``add`` must be undone by exactly one
        ``remove``, otherwise a later ``remove`` would pop a version that
        belongs to an earlier ``add``.

        With ``new=False`` the newest version is overwritten in place;
        writing the value the variable already has is a no-op.  Replacing an
        unset variable raises KeyError, except that replacing it with ``None``
        (the value ``get`` already reports for it) does nothing.
        """
        if not self.is_set(var):
            if not new:
                if value is None:
                    # get() already reports None for an unset variable.
                    return None
                raise KeyError(var)
            self.data[var] = [value]
            self.increment_value_index(value)
            return None

        versions = self.data[var]
        current_value = versions[0]
        unchanged = current_value == value
        if new:
            versions.insert(0, value)
        elif unchanged:
            # Setting the value to its existing value should have no effect
            return None
        else:
            versions[0] = value

        if not unchanged:
            # The visible value changed, so move one count between entries.
            self.decrement_value_index(current_value)
            self.increment_value_index(value)

    def add(self, var, value):
        """Push ``value`` as a new newest version of ``var``."""
        self._add_or_replace(var, value, new=True)

    def change(self, var, value):
        """Overwrite the newest version of ``var`` with ``value``."""
        self._add_or_replace(var, value, new=False)

    def remove(self, var):
        """Drop the newest version of ``var``, exposing the previous one.

        The variable is deleted entirely when no versions remain.
        Raises KeyError if ``var`` is not set.
        """
        versions = self.data[var]
        removed = versions.pop(0)
        if versions:
            # The previous version becomes visible again.
            self.increment_value_index(versions[0])
        else:
            del self.data[var]
        self.decrement_value_index(removed)

    def flatten(self):
        """Discard every version except the newest one for each variable."""
        for versions in self.data.values():
            del versions[1:]

    def num_equal_to(self, value):
        """Return how many variables currently hold ``value`` (0 if none)."""
        return self.value_index.get(value, 0)

    def __repr__(self):
        return 'Data: {}\nValue Index: {}'.format(self.data, self.value_index)

    
class DbSession(object):
    """Transactional front end over an ``InMemoryDatabase``.

    ``transaction_stack`` is ordered newest first.  Each entry is the set of
    variable names that were written while that transaction was the current
    one; ``current_trans`` is always ``transaction_stack[0]``.  The last
    entry is a base level that cannot be rolled back or committed.
    """

    def __init__(self):
        # TODO: underscores?
        self.database = InMemoryDatabase()
        self.transaction_stack = []
        self.reset_transaction_state()

    def reset_transaction_state(self):
        """Collapse the stack to a single base entry.

        The newest existing set is kept as the base; a fresh empty set is
        used when the stack is empty.
        """
        # Transaction stack will always have a 'base' transaction - cannot be rolled back/committed
        if not self.transaction_stack:
            self.current_trans = set()
        else:
            self.current_trans = self.transaction_stack[0]
        self.transaction_stack = [self.current_trans]

    def pop_transaction(self):
        """Discard the newest transaction and make the next one current."""
        self.transaction_stack.pop(0)
        self.current_trans = self.transaction_stack[0]

    def begin(self):
        """Open a new (possibly nested) transaction."""
        self.current_trans = set()
        self.transaction_stack.insert(0, self.current_trans)

    def has_open_transaction(self):
        """Return True if anything besides the base level is on the stack."""
        return len(self.transaction_stack) > 1

    def rollback(self):
        """Undo the newest transaction, or print NO TRANSACTION if none is open."""
        if not self.has_open_transaction():
            print('NO TRANSACTION')
        else:
            # Each recorded variable had exactly one version pushed by this
            # transaction, so removing one version per variable undoes it.
            for var in list(self.current_trans):
                self.database.remove(var)
            self.pop_transaction()

    def commit(self):
        """Make all open transactions permanent, or print NO TRANSACTION."""
        if not self.has_open_transaction():
            print('NO TRANSACTION')
        else:
            self.database.flatten()
            self.reset_transaction_state()

    def set_var(self, var, value):
        """Set ``var`` to ``value`` within the current transaction.

        The first write to a variable in a transaction pushes a new version;
        later writes in the same transaction overwrite that version.
        """
        if var in self.current_trans:
            self.database.change(var, value)
        else:
            self.database.add(var, value)
            self.current_trans.add(var)

    def unset_var(self, var):
        """Unset ``var`` by storing ``None`` as its value."""
        self.set_var(var, None)

    def get_var(self, var):
        """Print the current value of ``var``, or NULL if it is unset."""
        value = self.database.get(var)
        # Compare against None explicitly: a truthiness test would report
        # legitimately stored falsy values such as 0 as NULL.
        print('NULL' if value is None else value)

    def num_equal_to(self, value):
        """Print the number of variables currently equal to ``value``."""
        print(self.database.num_equal_to(value))

    def __repr__(self):
        return '{}\nTransaction Stack: {}'.format(self.database, self.transaction_stack)


class CommandInterpreter(object):
    """Placeholder for a textual command front end; ``execute`` is a stub.

    Commands listed in the original notes:

    DB commands
        SET name value - Set the variable name to the value value. Neither
            variable names nor values will contain spaces.
        GET name - Print out the value of the variable name, or NULL if that
            variable is not set.
        UNSET name - Unset the variable name, making it just like that
            variable was never set.
        NUMEQUALTO value - Print out the number of variables that are
            currently set to value. If no variables equal that value, print 0.
        END

    Transaction commands
        BEGIN - Open a new transaction block. Transaction blocks can be
            nested; a BEGIN can be issued inside of an existing block.
        ROLLBACK - Undo all of the commands issued in the most recent
            transaction block, and close the block. Print nothing if
            successful, or print NO TRANSACTION if no transaction is in
            progress.
        COMMIT
    """

    def __init__(self, database):
        """Keep a reference to the object the commands are meant for."""
        self.database = database

    def execute(self, command_input):
        """Not implemented yet: ignores ``command_input`` and returns None."""
        return None

In [34]:
# Demo: nested transactions with an unset, one rollback and a final commit.
db = DbSession()
# Written at the base level, outside any transaction.
db.set_var('a', 50)
db.begin()
# Prints 50: the base value is visible inside the transaction.
db.get_var('a')
# First write to 'a' in this transaction, so a new version is pushed.
db.set_var('a', 60)
db.begin()
# Unset is stored as a None version on top of 60 and 50.
db.unset_var('a')
# Prints NULL.
db.get_var('a')
print(db)
# Drops the None version pushed by the inner transaction.
db.rollback()
# Prints 60 again.
db.get_var('a')
print('')
print(db)
# Keeps only the newest version of each variable and closes all transactions.
db.commit()
# Still prints 60 after the commit.
db.get_var('a')
print('')
print(db)


50
NULL
Data: {'a': [None, 60, 50]}
Value Index: {None: 1}
Transaction Stack: [set(['a']), set(['a']), set(['a'])]
60

Data: {'a': [60, 50]}
Value Index: {60: 1}
Transaction Stack: [set(['a']), set(['a'])]
60

Data: {'a': [60]}
Value Index: {60: 1}
Transaction Stack: [set(['a'])]


In [19]:
# Demo: two nested transactions, each rolled back in turn.
db = DbSession()
db.begin()
db.set_var('a', 10)
# Prints 10.
db.get_var('a')
db.begin()
db.set_var('a', 20)
# Prints 20: the inner transaction's version is the visible one.
db.get_var('a')
# Undo the inner transaction.
db.rollback()
# Prints 10: the outer transaction's version is visible again.
db.get_var('a')
# Undo the outer transaction, which removes the only remaining version.
db.rollback()
# Prints NULL: 'a' is no longer set.
db.get_var('a')

10
20
10
NULL


In [150]:
# Demo: three nested transactions exercising the value index, then four
# rollbacks (one more than there are open transactions).
# NOTE(review): the expected values in these comments were traced against the
# class definitions shown in this notebook. The execution counts are out of
# order, so the recorded output may come from an earlier revision of the
# classes - re-run from a fresh kernel to confirm.
db = DbSession()
db.begin()
db.set_var('a', 1)
db.set_var('b', 2)
db.set_var('c', 3)
# Overwrites this transaction's version of 'c' with None.
db.unset_var('c')
# print(db)
db.begin()
# Prints 2.
db.get_var('b')
# Prints NULL: 'e' was never set.
db.get_var('e')
# The first write to 'b' here pushes a version; the next ones overwrite it.
db.set_var('b', 1)
db.set_var('b', 3)
db.set_var('b', 4)
# Prints 0.
db.num_equal_to(5)
db.set_var('b', 5)
# Prints 1 ('b').
db.num_equal_to(5)
db.set_var('c', 3)
db.set_var('c', 5)
# Prints 2 ('b' and 'c').
db.num_equal_to(5)
# print(db)
db.begin()
db.set_var('b', 2)
# print(db)
print('PRE-ROLLBACK:\n{}'.format(db))
# Undo the third transaction: 'b' goes back to 5.
db.rollback()
print('\nAFTER ROLLBACK 1:\n{}'.format(db))
# Undo the second transaction: 'b' goes back to 2 and 'c' back to None.
db.rollback()
print('\nAFTER ROLLBACK 2:\n{}'.format(db))
# Prints 0: no variable equals 5 once the second transaction is undone.
db.num_equal_to(5)
# Undo the first transaction: 'a', 'b' and 'c' are all removed.
db.rollback()
print('\nAFTER ROLLBACK 3:\n{}'.format(db))
# Prints 0.
db.num_equal_to(5)
# Only the base level is left, so this prints NO TRANSACTION.
db.rollback()
# print(db)

2
NULL
0
1
2
PRE-ROLLBACK:
Data: {'a': [1], 'c': [5], 'b': [2, 5, 2]}
Value Index: {1: 1, 2: 1, 5: 1}
Transaction Stack: [set(['b']), set(['c', 'b']), set(['a', 'b'])]

AFTER ROLLBACK 1:
Data: {'a': [1], 'c': [5], 'b': [5, 2]}
Value Index: {1: 1, 5: 2}
Transaction Stack: [set(['c', 'b']), set(['a', 'b'])]

AFTER ROLLBACK 2:
Data: {'a': [1], 'c': [5], 'b': [5, 2]}
Value Index: {1: 1, 5: 2}
Transaction Stack: [set(['a', 'b'])]
2

AFTER ROLLBACK 3:
Data: {'a': [1], 'b': [2]}
Value Index: {1: 1, 2: 1}
Transaction Stack: []
0
NO TRANSACTION
