Splay 트리 구현

In [None]:
class Node:
    """A binary-tree node holding one key plus parent/left/right links."""

    def __init__(self, key):
        """Create a detached node for ``key``; all three links start as None."""
        self.key = key
        # A fresh node is not attached to any tree yet.
        self.parent = self.left = self.right = None

class BST:
    """Unbalanced binary search tree built from ``Node`` objects.

    Keys must be comparable with ``<`` and ``==``. A key equal to an existing
    one is accepted and stored in the right subtree of the equal key.
    """

    def __init__(self):
        """Create an empty tree (``root`` is None)."""
        self.root = None

    def search(self, key):
        """Return the first node on the root-to-leaf path whose key equals
        ``key``, or None when no such node exists."""
        v = self.root
        while v:
            if key == v.key:
                return v
            # Smaller keys live on the left, everything else on the right.
            v = v.left if key < v.key else v.right
        return None

    def insert(self, key):
        """Insert ``key`` as a new leaf and return the newly created node."""
        new_node = Node(key)
        if self.root is None:
            self.root = new_node
            return new_node

        v = self.root
        while True:
            if key < v.key:
                if v.left is None:
                    v.left = new_node
                    new_node.parent = v
                    break
                v = v.left
            else:
                # Keys greater than or equal to v.key descend to the right.
                if v.right is None:
                    v.right = new_node
                    new_node.parent = v
                    break
                v = v.right
        return new_node

    # The traversals below use an explicit stack instead of recursion so that
    # very deep (degenerate, chain-like) trees do not raise RecursionError.

    def preorder(self, v):
        """Print the keys of the subtree rooted at ``v`` in node-left-right
        order, each followed by a single space."""
        stack = [v]
        while stack:
            node = stack.pop()
            if node:
                print(node.key, end=' ')
                # Push right first so that left is processed first.
                stack.append(node.right)
                stack.append(node.left)

    def inorder(self, v):
        """Print the keys of the subtree rooted at ``v`` in left-node-right
        order, each followed by a single space."""
        stack = []
        node = v
        while stack or node:
            # Walk as far left as possible, remembering the path.
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            print(node.key, end=' ')
            node = node.right

    def postorder(self, v):
        """Print the keys of the subtree rooted at ``v`` in left-right-node
        order, each followed by a single space."""
        # Collect nodes in node-right-left order; its reverse is postorder.
        visited = []
        stack = [v]
        while stack:
            node = stack.pop()
            if node:
                visited.append(node)
                stack.append(node.left)
                stack.append(node.right)
        for node in reversed(visited):
            print(node.key, end=' ')

class SplayTree(BST):
    """BST variant that rotates accessed nodes up to the root.

    ``insert`` always moves the new node to the root; ``search`` moves the
    found node to the root and leaves the tree untouched on a miss.
    """

    def rotate(self, x):
        """Rotate ``x`` one level up, above its parent.

        Subtree ownership and parent links are updated so the in-order
        sequence is unchanged. ``self.root`` is updated when ``x`` ends up
        without a parent. Calling this on a node that has no parent is a
        no-op.
        """
        p = x.parent
        if p is None:
            # x is already at the top; there is nothing to rotate over.
            return
        g = p.parent

        if x is p.left:  # right rotation: x's right subtree moves under p
            p.left = x.right
            if x.right:
                x.right.parent = p
            x.right = p
        else:  # left rotation: x's left subtree moves under p
            p.right = x.left
            if x.left:
                x.left.parent = p
            x.left = p

        p.parent = x
        x.parent = g

        # Re-attach x where p used to hang.
        if g is None:
            self.root = x
        elif g.left is p:
            g.left = x
        else:
            g.right = x

    def splay(self, x):
        """Rotate ``x`` upward until it has no parent (i.e. it is the root)."""
        while x.parent is not None:
            if x.parent.parent is None:
                # Zig: the parent is the root, a single rotation finishes.
                self.rotate(x)
            else:
                # Zig-zig (straight line) and zig-zag (bent) are handled the
                # same way here: x itself is rotated twice.
                # NOTE(review): the textbook zig-zig step rotates the parent
                # first and then x. The original author deliberately chose
                # "x first, twice"; the resulting tree shape (and therefore
                # traversal output) depends on this, so it is kept as is.
                self.rotate(x)
                self.rotate(x)

    def search(self, key):
        """Find ``key``; on a hit splay the node to the root and return it,
        otherwise return None without restructuring the tree."""
        v = super().search(key)
        if v:
            self.splay(v)
        return v

    def insert(self, key):
        """Insert ``key`` as in ``BST.insert``, splay the new node to the
        root and return it."""
        v = super().insert(key)
        self.splay(v)
        return v

    def delete(self, x):
        """Remove node ``x`` from the tree.

        ``x`` may be None (e.g. the result of a failed ``search``), in which
        case nothing happens. Otherwise ``x`` is splayed to the root, removed,
        and its two subtrees are joined: the maximum node of the left subtree
        becomes the new root and adopts the right subtree.
        """
        if x is None:
            return

        # x is usually the root already after a search; splay to be certain.
        self.splay(x)

        left, right = x.left, x.right
        # Detach the removed node so it no longer references live subtrees.
        x.left = x.right = None

        if left:
            left.parent = None
            m = left
            while m.right:
                m = m.right
            # Bring the maximum of the left subtree to that subtree's root;
            # being the maximum, it has no right child afterwards.
            self.splay(m)
            m.right = right
            if right:
                right.parent = m
            self.root = m
        else:
            if right:
                right.parent = None
            self.root = right


# Interactive driver: reads one command per line from standard input and
# applies it to a single SplayTree. Supported commands are
#   insert <int> | delete <int> | search <int> | preorder | inorder |
#   postorder | exit
# The loop ends on "exit", on end of input, or on a blank line.
T = SplayTree()
while True:
    try:
        line = input()
    except EOFError:
        break

    cmd = line.split()
    if not cmd:
        # An empty (or whitespace-only) line ends the session.
        break

    op = cmd[0]
    if op in ('insert', 'delete', 'search'):
        # These commands need exactly one integer key; a missing or
        # non-numeric argument is reported instead of crashing the loop.
        try:
            key = int(cmd[1])
        except (IndexError, ValueError):
            print("* not allowed command. enter a proper command!")
            continue

        if op == 'insert':
            v = T.insert(key)
            print("+ {0} is inserted".format(v.key))
        elif op == 'delete':
            v = T.search(key)
            T.delete(v)
            # NOTE(review): this message is printed even when the key was not
            # in the tree (v is None); kept unchanged in case the output
            # format is fixed by whoever consumes it.
            print("- {0} is deleted".format(key))
        else:
            v = T.search(key)
            if v is None:
                print("* {0} is not found!".format(cmd[1]))
            else:
                print("* {0} is found!".format(cmd[1]))
    elif op == 'preorder':
        T.preorder(T.root)
        print()
    elif op == 'postorder':
        T.postorder(T.root)
        print()
    elif op == 'inorder':
        T.inorder(T.root)
        print()
    elif op == 'exit':
        break
    else:
        print("* not allowed command. enter a proper command!")