# Trading Backend

Abstract base class that defines the API the trading backend must implement.

In [63]:
from abc import ABC, abstractmethod  
      
# abstract class to represent a stock trading platform
class AbstractTradingBackend(ABC):
    """Interface that every stock trading backend has to implement.

    All methods are abstract, so the class cannot be instantiated directly.
    The query methods carry a placeholder body that returns an empty list.
    """

    @abstractmethod
    def __init__(self):
        """Set up the backend."""

    @abstractmethod
    def log_transaction(self, transactionRecord):
        """Add ``transactionRecord`` to the set of completed transactions."""

    @abstractmethod
    def sorted_transactions(self, stock_name):
        """Return all transactions of ``stock_name`` (str), sorted by
        increasing trade value."""
        return []

    @abstractmethod
    def min_transactions(self, stock_name):
        """Return the transactions of ``stock_name`` (str) that have the
        minimum trade value."""
        return []

    @abstractmethod
    def max_transactions(self, stock_name):
        """Return the transactions of ``stock_name`` (str) that have the
        maximum trade value."""
        return []

    @abstractmethod
    def floor_transactions(self, stock_name, threshold_value):
        """Return the transactions of ``stock_name`` (str) with the largest
        trade value below ``threshold_value`` (float)."""
        return []

    @abstractmethod
    def ceiling_transactions(self, stock_name, threshold_value):
        """Return the transactions of ``stock_name`` (str) with the smallest
        trade value above ``threshold_value`` (float)."""
        return []

    @abstractmethod
    def range_transactions(self, stock_name, from_value, to_value):
        """Return the transactions of ``stock_name`` (str) whose trade value
        lies within the range [``from_value``, ``to_value``] (floats)."""
        return []

Left-leaning red-black binary search tree implementation to efficiently log and query large amounts of data.

In [64]:
# ADD AUXILIARY DATA STRUCTURE DEFINITIONS AND HELPER CODE HERE

# Companies the stock trading platform specialises in.
# This list is the single extension point: StockTradingPlatform() creates one
# TransactionTree per entry, so a company can have transactions logged for it
# as soon as its name is added here.
STOCK_NAMES = [
    "Barclays",
    "HSBA",
    "Lloyds Banking Group",
    "NatWest Group",
    "Standard Chartered",
    "3i",
    "Abrdn",
    "Hargreaves Lansdown",
    "London Stock Exchange Group",
    "Pershing Square Holdings",
    "Schroders",
    "St. James's Place plc.",
]

# data type storing individual transactions
class Transaction():
    """A single completed transaction.

    Built from an indexable record laid out as
    ``[name, price, quantity, time]``; only the first four positions are read.
    """

    def __init__(self, transaction):
        # Design note: the name could be stored as an integer index into the
        # list of known stock names to save memory. The original author found
        # that the string <-> integer mapping made loading transactions into
        # the tree slower, so the plain string is kept; which trade-off is
        # right depends on whether time or memory matters more to the client.
        self.name = transaction[0]
        self.price = transaction[1]
        self.quantity = transaction[2]
        self.time = transaction[3]

    # getter methods kept for callers that rely on them
    def get_name(self):
        """Return the stock name of the transaction."""
        return self.name

    def get_price(self):
        """Return the price stored for the transaction."""
        return self.price

    def get_quantity(self):
        """Return the quantity stored for the transaction."""
        return self.quantity

    def get_time(self):
        """Return the time stamp stored for the transaction."""
        return self.time

    def __repr__(self):
        """Return ``[name, price, quantity, time]`` for printing.

        Every field is formatted through an f-string, so a transaction that
        was built without validation (e.g. with a non-string name or time)
        can still be printed instead of raising ``TypeError``.
        """
        return f"[{self.name}, {self.price}, {self.quantity}, {self.time}]"

# node representation in a binary search tree
class Node():
    """One node of the binary search tree.

    ``value`` is wrapped in a list so that further values can be appended
    when the same key is inserted again; a plain list is enough because no
    operations on the stored values themselves are required.
    """

    def __init__(self, key, value, colour):
        self.key = key
        self.colour = colour
        self.value = [value]
        # a freshly created node has no children yet
        self.left = None
        self.right = None

# Left-leaning red-black binary search tree implementation to store transactions of a company
# implementation migrated from https://algs4.cs.princeton.edu/33balanced/RedBlackBST.java.html
class TransactionTree():
    """Left-leaning red-black binary search tree mapping keys to value lists.

    Ported from https://algs4.cs.princeton.edu/33balanced/RedBlackBST.java.html.
    Inserting an existing key appends to that node's value list instead of
    overwriting it. The public query methods return the string
    "TransactionTree is empty" when the tree holds no nodes.
    """

    def __init__(self):
        self.RED = True
        self.BLACK = False
        self.root = None

        # smallest / largest key ever inserted; lets callers bounds-check
        # floor, ceiling and range queries without walking the tree
        self.min_key_value = None
        self.max_key_value = None

    def is_red(self, node):
        """Return True if ``node`` exists and is red; a missing node is not red."""
        return node is not None and node.colour is self.RED

    def is_empty(self):
        """Return True if the tree has no nodes."""
        return self.root is None

    # helper function for debugging
    def get(self, key):
        """Look ``key`` up starting from the root (see ``search``)."""
        return self.search(self.root, key)

    # helper function for debugging
    def search(self, node, key):
        """Return the value list stored under ``key`` in the subtree at
        ``node``, or the string "Key doesn't exist" if it is absent."""
        current = node
        while current is not None:
            if key < current.key:
                current = current.left
            elif key > current.key:
                current = current.right
            else:
                return current.value

        return "Key doesn't exist"

    def sorted(self):
        """Return every value list in the tree, ordered by increasing key."""
        if self.is_empty():
            return "TransactionTree is empty"

        in_order = []
        self.traverse(in_order, self.root)
        return in_order

    def traverse(self, sorted_list, node):
        """Append the value lists of the subtree at ``node`` to
        ``sorted_list`` in key order (in-order walk with an explicit stack)."""
        pending = []
        current = node
        while pending or current is not None:
            # descend as far left as possible, remembering the path
            while current is not None:
                pending.append(current)
                current = current.left
            current = pending.pop()
            sorted_list.append(current.value)
            current = current.right

    # Design note: the min/max value lists could instead be cached on every
    # insertion, making min()/max() O(1) at the cost of extra work per put().
    # That only pays off if min/max are queried after nearly every insertion,
    # so the tree walk in min()/max() below is used instead.

    def check_min_value(self, key):
        """Record ``key`` as the minimum key if it is smaller than the current one."""
        smallest = self.min_key_value
        if smallest is None or key < smallest:
            self.min_key_value = key

    def check_max_value(self, key):
        """Record ``key`` as the maximum key if it is larger than the current one."""
        largest = self.max_key_value
        if largest is None or key > largest:
            self.max_key_value = key

    def put(self, key, value):
        """Insert the ``key``/``value`` pair and keep the root black."""
        self.check_min_value(key)
        self.check_max_value(key)

        self.root = self.insert(self.root, key, value)
        self.root.colour = self.BLACK

    def insert(self, node, key, value):
        """Insert into the subtree at ``node`` and return its (possibly new) root."""
        if node is None:
            return Node(key, value, self.RED)

        if key < node.key:
            node.left = self.insert(node.left, key, value)
        elif key > node.key:
            node.right = self.insert(node.right, key, value)
        else:
            # existing key: extend its value list rather than overwrite it
            node.value.append(value)

        # restore the left-leaning red-black shape on the way back up;
        # the order of these three fix-ups matters
        red = self.is_red
        if red(node.right) and not red(node.left):
            node = self.rotate_left(node)
        if red(node.left) and red(node.left.left):
            node = self.rotate_right(node)
        if red(node.left) and red(node.right):
            self.flip_colours(node)

        return node

    # helper functions for balancing the red-black tree
    def rotate_left(self, node):
        """Rotate the right child of ``node`` up and return it."""
        pivot = node.right
        node.right, pivot.left = pivot.left, node
        pivot.colour, node.colour = node.colour, self.RED
        return pivot

    def rotate_right(self, node):
        """Rotate the left child of ``node`` up and return it."""
        pivot = node.left
        node.left, pivot.right = pivot.right, node
        pivot.colour, node.colour = node.colour, self.RED
        return pivot

    def flip_colours(self, node):
        """Invert the colour of ``node`` and of both its children."""
        for member in (node, node.left, node.right):
            member.colour = not member.colour

    def min(self):
        """Return the value list stored under the smallest key."""
        if self.is_empty():
            return "TransactionTree is empty"

        return self.get_min(self.root)

    def get_min(self, node):
        """Return the value list of the leftmost node below ``node``."""
        while node.left is not None:
            node = node.left
        return node.value

    def max(self):
        """Return the value list stored under the largest key."""
        if self.is_empty():
            return "TransactionTree is empty"

        return self.get_max(self.root)

    def get_max(self, node):
        """Return the value list of the rightmost node below ``node``."""
        while node.right is not None:
            node = node.right
        return node.value

    def floor(self, key):
        """Return the value list of the largest key <= ``key`` (None if
        every key is larger)."""
        if self.is_empty():
            return "TransactionTree is empty"
        return self.get_floor(self.root, key)

    def get_floor(self, node, key):
        """Floor lookup restricted to the subtree at ``node``."""
        best = None
        while node is not None:
            if key < node.key:
                node = node.left
            else:
                # node qualifies; anything better can only be to its right
                best = node.value
                node = node.right
        return best

    def ceiling(self, key):
        """Return the value list of the smallest key >= ``key`` (None if
        every key is smaller)."""
        if self.is_empty():
            return "TransactionTree is empty"
        return self.get_ceiling(self.root, key)

    def get_ceiling(self, node, key):
        """Ceiling lookup restricted to the subtree at ``node``."""
        best = None
        while node is not None:
            if key > node.key:
                node = node.right
            else:
                # node qualifies; anything better can only be to its left
                best = node.value
                node = node.left
        return best

    def range(self, low, high):
        """Return the value lists of all keys in [``low``, ``high``], in key order."""
        if self.is_empty():
            return "TransactionTree is empty"

        matches = []
        self.get_range(matches, self.root, low, high)
        return matches

    def get_range(self, range_list, node, low, high):
        """Append to ``range_list`` the value lists of the subtree at
        ``node`` whose keys fall in [``low``, ``high``]."""
        if node is None:
            return

        node_key = node.key
        # only descend into a side that can still contain keys in range
        if low < node_key:
            self.get_range(range_list, node.left, low, high)
        if low <= node_key and high >= node_key:
            range_list.append(node.value)
        if high > node_key:
            self.get_range(range_list, node.right, low, high)

# no extensive testing done
# Failed checks raise AssertionError explicitly (rather than through `assert`
# statements) so that validation still runs under `python -O`, which strips
# assert statements.
class InputValidator():
    """Checks arguments passed to the trading platform.

    Every ``validate_*`` method returns None on success and raises
    ``AssertionError`` when a check fails. The threshold/range checks compare
    their arguments directly, so a ``None`` bound (or any other incomparable
    value) surfaces as ``TypeError`` from the comparison itself.
    """

    def __init__(self):
        # make sure to change date_format if format in TransactionDataGenerator changes
        self.date_format = '%d/%m/%Y %H:%M:%S'

    @staticmethod
    def _require(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy."""
        if not condition:
            raise AssertionError(message)

    def _require_stock_name(self, stock_name):
        """Shared check for methods that only take a stock name."""
        self._require(type(stock_name) is str, "stock name must be a str")

    def validate_log(self, transactionRecord):
        """Check that ``transactionRecord`` is a list of exactly
        ``[str, positive float, positive int, str]``.

        Types are compared exactly, so e.g. an int price or a bool quantity
        is rejected.
        """
        self._require(type(transactionRecord) is list, "record must be a list")
        self._require(len(transactionRecord) == 4, "record must have 4 fields")
        self._require(type(transactionRecord[0]) is str, "name must be a str")
        self._require(type(transactionRecord[1]) is float, "price must be a float")
        self._require(transactionRecord[1] > 0, "price must be positive")
        self._require(type(transactionRecord[2]) is int, "quantity must be an int")
        self._require(transactionRecord[2] > 0, "quantity must be positive")
        self._require(type(transactionRecord[3]) is str, "time must be a str")

        # uncomment only if "from datetime import datetime" is allowed at the top of this section
        # datetime.strptime(transactionRecord[3], self.date_format)

    def validate_sorted(self, stock_name):
        """Check that ``stock_name`` is a str."""
        self._require_stock_name(stock_name)

    def validate_min(self, stock_name):
        """Check that ``stock_name`` is a str."""
        self._require_stock_name(stock_name)

    def validate_max(self, stock_name):
        """Check that ``stock_name`` is a str."""
        self._require_stock_name(stock_name)

    def validate_floor(self, thresholdValue, minKey):
        """Check that ``thresholdValue`` is strictly greater than ``minKey``."""
        self._require(thresholdValue > minKey, "threshold must exceed the minimum key")

    def validate_ceiling(self, thresholdValue, maxKey):
        """Check that ``thresholdValue`` is strictly smaller than ``maxKey``."""
        self._require(thresholdValue < maxKey, "threshold must be below the maximum key")

    def validate_range(self, fromValue, toValue, minKey, maxKey):
        """Check that [``fromValue``, ``toValue``] is a well-formed range that
        overlaps [``minKey``, ``maxKey``]."""
        self._require(fromValue <= toValue, "fromValue must not exceed toValue")
        self._require(fromValue <= maxKey, "range starts above the maximum key")
        self._require(toValue >= minKey, "range ends below the minimum key")
Concrete implementation of the API defined by the abstract class.

In [65]:
class StockTradingPlatform(AbstractTradingBackend):
    """Trading backend that keeps one TransactionTree per known company.

    Transactions are keyed by their total trade value (price * quantity).
    Invalid input never raises to the caller: a message is printed and query
    methods return an empty list.
    """

    def __init__(self):
        self.validator = InputValidator()
        # one independent tree per company listed in STOCK_NAMES
        self.company_transactions = {
            stock_name: TransactionTree() for stock_name in STOCK_NAMES
        }

    def _get_tree(self, stock_name, validate=None):
        """Return the non-empty tree of ``stock_name``, or None after
        printing why no tree can be queried.

        ``validate`` is an optional validator method that is called with
        ``stock_name`` before the lookup.
        """
        try:
            if validate is not None:
                validate(stock_name)
            tree = self.company_transactions[stock_name]
        except AssertionError:
            print("Argument must be a string")
            return None
        except (KeyError, TypeError):
            # TypeError: an unhashable stock_name can never be a dictionary key
            print("Company not in dictionary")
            return None

        if tree.is_empty():
            print("TransactionTree is empty")
            return None
        return tree

    def log_transaction(self, transaction_record):
        """Validate ``transaction_record`` and store it in its company's
        tree under the key price * quantity.

        Prints a message and stores nothing if the record is malformed or
        names an unknown company.
        """
        try:
            self.validator.validate_log(transaction_record)
            transaction = Transaction(transaction_record)
            total_value = transaction.get_price() * transaction.get_quantity()
            self.company_transactions[transaction.get_name()].put(total_value, transaction)
        except AssertionError:
            print("Incorrect transaction format")
        except KeyError:
            print("Company not in dictionary")
        except ValueError:
            # only reachable once date validation is enabled in the validator
            print("Incorrect date format")

    def sorted_transactions(self, stock_name):
        """Return the transaction lists of ``stock_name`` ordered by
        increasing trade value (one inner list per distinct trade value);
        [] if there is nothing to return."""
        tree = self._get_tree(stock_name, self.validator.validate_sorted)
        return [] if tree is None else tree.sorted()

    def min_transactions(self, stock_name):
        """Return the transactions of ``stock_name`` with the minimum trade
        value; [] if there is nothing to return."""
        tree = self._get_tree(stock_name, self.validator.validate_min)
        return [] if tree is None else tree.min()

    def max_transactions(self, stock_name):
        """Return the transactions of ``stock_name`` with the maximum trade
        value; [] if there is nothing to return."""
        tree = self._get_tree(stock_name, self.validator.validate_max)
        return [] if tree is None else tree.max()

    def floor_transactions(self, stock_name, threshold_value):
        """Return the transactions of ``stock_name`` with the largest trade
        value not above ``threshold_value``.

        The threshold must be strictly greater than the smallest logged
        trade value; otherwise a message is printed and [] is returned.
        """
        tree = self._get_tree(stock_name)
        if tree is None:
            return []

        try:
            self.validator.validate_floor(threshold_value, tree.min_key_value)
            return tree.floor(threshold_value)
        except AssertionError:
            print("Threshold is too small")
        except TypeError:
            # threshold cannot be compared with the stored trade values
            print("Threshold must be a number")
        return []

    def ceiling_transactions(self, stock_name, threshold_value):
        """Return the transactions of ``stock_name`` with the smallest trade
        value not below ``threshold_value``.

        The threshold must be strictly smaller than the largest logged
        trade value; otherwise a message is printed and [] is returned.
        """
        tree = self._get_tree(stock_name)
        if tree is None:
            return []

        try:
            self.validator.validate_ceiling(threshold_value, tree.max_key_value)
            return tree.ceiling(threshold_value)
        except AssertionError:
            print("Threshold is too large")
        except TypeError:
            # threshold cannot be compared with the stored trade values
            print("Threshold must be a number")
        return []

    def range_transactions(self, stock_name, from_value, to_value):
        """Return the transaction lists of ``stock_name`` whose trade value
        lies in [``from_value``, ``to_value``], ordered by trade value.

        Prints a message and returns [] if the bounds are inverted or do not
        overlap the logged trade values.
        """
        tree = self._get_tree(stock_name)
        if tree is None:
            return []

        try:
            self.validator.validate_range(
                from_value, to_value, tree.min_key_value, tree.max_key_value
            )
            return tree.range(from_value, to_value)
        except AssertionError:
            print("Incorrect fromValue, toValue pair")
        except TypeError:
            # bounds cannot be compared with the stored trade values
            print("Bounds must be numbers")
        return []