# Linked List with Pivot

Given a linked list of numbers and a pivot k, partition the linked list so that all nodes less than k come before nodes greater than or equal to k.

For example, given the linked list 5 -> 1 -> 8 -> 0 -> 3 and k = 3, the solution could be 1 -> 0 -> 5 -> 8 -> 3.
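To make the requirement concrete, the sketch below checks whether a plain Python list of values is validly partitioned around k. `is_partitioned` is a hypothetical helper, not part of the original notebook. It also shows that more than one ordering can satisfy the condition.

```python
def is_partitioned(values, k):
    """Return True if every value < k appears before every value >= k."""
    seen_high = False  # becomes True once we pass the first value >= k
    for v in values:
        if v >= k:
            seen_high = True
        elif seen_high:
            # a value < k after a value >= k breaks the partition
            return False
    return True

print(is_partitioned([1, 0, 5, 8, 3], 3))  # True (the example answer)
print(is_partitioned([0, 1, 5, 8, 3], 3))  # True (another valid answer)
print(is_partitioned([5, 1, 8, 0, 3], 3))  # False (the unpartitioned input)
```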

# First, define some helper classes: a node and a linked list wrapper with head and tail pointers.

class Node: 
    def __init__(self, val, next=None): 
        self.val = val
        self.next = next
        
class LinkedList: 
    def __init__(self): 
        self.head = None
        self.tail = None
        
    # Insert an element at the front of the list in O(1).
    def insert(self, data): 
        if not self.head:
            self.head = self.tail = Node(data)
        else: 
            tmp = Node(data)
            tmp.next = self.head
            self.head = tmp
    
    # Append an element at the back of the list.
    # This is O(1) because we keep a pointer to the tail.
    def append(self, data): 
        if not self.head: 
            self.head = self.tail = Node(data) 
        else: 
            tmp = Node(data)
            self.tail.next = tmp
            self.tail = self.tail.next
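
The partition functions below take the head `Node` of a list, not a `LinkedList`. As a minimal sketch of how one might build the example input and read results back, here are two hypothetical helpers, `from_list` and `to_list`. The classes are repeated in condensed form (without `insert`) so the cell runs on its own.

```python
class Node:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next

class LinkedList:
    def __init__(self):
        self.head = None
        self.tail = None

    def append(self, data):
        # O(1) thanks to the tail pointer
        node = Node(data)
        if not self.head:
            self.head = self.tail = node
        else:
            self.tail.next = node
            self.tail = node

def from_list(values):
    """Build a LinkedList from a Python iterable (hypothetical helper)."""
    ll = LinkedList()
    for v in values:
        ll.append(v)
    return ll

def to_list(head):
    """Collect the values reachable from `head` into a Python list (hypothetical helper)."""
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

ll = from_list([5, 1, 8, 0, 3])
print(to_list(ll.head))  # [5, 1, 8, 0, 3]
```

With these, a call such as `to_list(partition(ll.head, 3).head)` would show the partitioned values.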

# Split the input into three linked lists holding the elements
# smaller than, equal to, and larger than the pivot, then join them.

def partition(head, pivot): 
    low = LinkedList()
    middle = LinkedList()
    high = LinkedList()
    
    while head: 
        if head.val < pivot: 
            low.append(head.val)
        elif head.val == pivot: 
            middle.append(head.val) 
        else: 
            high.append(head.val) 
        head = head.next
        
    # Copy the elements equal to the pivot onto the end of `low`.
    m = middle.head
    while m: 
        low.append(m.val)
        m = m.next
        
    # Then copy the elements larger than the pivot.
    h = high.head
    while h: 
        low.append(h.val)
        h = h.next
        
    # `low` now holds the concatenation low + middle + high.
    return low
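
The last two loops copy every middle and high element into `low`, allocating new nodes. Since each `LinkedList` tracks its tail, the three lists could instead be spliced together by relinking tail pointers in O(1). The sketch below assumes a hypothetical `extend` method added to `LinkedList`; it shares the nodes of the lists it joins rather than copying them.

```python
class Node:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next

class LinkedList:
    def __init__(self):
        self.head = None
        self.tail = None

    def append(self, data):
        node = Node(data)
        if not self.head:
            self.head = self.tail = node
        else:
            self.tail.next = node
            self.tail = node

    def extend(self, other):
        """Splice `other` onto the end of this list in O(1) (hypothetical method).
        The nodes of `other` are shared, not copied."""
        if not other.head:
            return
        if not self.head:
            self.head, self.tail = other.head, other.tail
        else:
            self.tail.next = other.head
            self.tail = other.tail

def partition(head, pivot):
    low, middle, high = LinkedList(), LinkedList(), LinkedList()
    while head:
        if head.val < pivot:
            low.append(head.val)
        elif head.val == pivot:
            middle.append(head.val)
        else:
            high.append(head.val)
        head = head.next
    # join the three buckets in the order low, middle, high without copying
    low.extend(middle)
    low.extend(high)
    return low

head = Node(5, Node(1, Node(8, Node(0, Node(3)))))
result = partition(head, 3)
node, values = result.head, []
while node:
    values.append(node.val)
    node = node.next
print(values)  # [1, 0, 3, 5, 8]
```

Because the original `partition` copies values in the same order, both versions should produce the same sequence; the splice only avoids the second round of allocations.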

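## Notice that elements equal to k don't need to be in the middle of the list
The problem only requires that nodes less than k come before nodes greater than or equal to k, so we don't need a separate bucket for values equal to k. Hence we can solve it more simply: traverse the input list once, inserting each element less than k at the front of a new linked list and appending everything else at the back. This reverses the relative order of the smaller elements, which is fine because the problem does not require the original order to be preserved.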
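
To see what insert-at-front does to the order, here is a trace on the example input using `collections.deque` as a stand-in for the linked list: `appendleft` plays the role of `LinkedList.insert` and `append` the role of `LinkedList.append`. `partition_values` is a hypothetical name used only for illustration.

```python
from collections import deque

def partition_values(values, pivot):
    out = deque()
    for v in values:
        if v < pivot:
            out.appendleft(v)  # like LinkedList.insert: goes to the front
        else:
            out.append(v)      # like LinkedList.append: goes to the back
    return list(out)

print(partition_values([5, 1, 8, 0, 3], 3))  # [0, 1, 5, 8, 3]
```

The result differs from the 1 -> 0 -> 5 -> 8 -> 3 given in the problem statement, but it is still a valid partition.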



def partition(head, pivot): 
    new = LinkedList()
    
    while head: 
        if head.val < pivot: 
            new.insert(head.val)   # smaller values go to the front
        else: 
            new.append(head.val)   # everything else goes to the back
        head = head.next
        
    return new
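
Both versions above allocate a new node for every element. If modifying the input list is acceptable, one alternative (not from the original notebook) is to relink the existing nodes into two chains behind sentinel ("dummy") head nodes and then join the chains. This uses O(1) extra space and keeps the original relative order within each group. `partition_in_place` is a hypothetical name, and unlike the functions above it returns the new head `Node` rather than a `LinkedList`.

```python
class Node:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next

def partition_in_place(head, pivot):
    """Relink the existing nodes: values < pivot first, the rest after.
    Keeps relative order within each group and uses O(1) extra space."""
    low_dummy, high_dummy = Node(None), Node(None)  # sentinel heads
    low, high = low_dummy, high_dummy               # tails of the two chains
    while head:
        if head.val < pivot:
            low.next = head
            low = head
        else:
            high.next = head
            high = head
        head = head.next
    high.next = None             # the last high node may still point at a low node
    low.next = high_dummy.next   # attach the >= pivot chain after the < pivot chain
    return low_dummy.next

head = Node(5, Node(1, Node(8, Node(0, Node(3)))))
node, values = partition_in_place(head, 3), []
while node:
    values.append(node.val)
    node = node.next
print(values)  # [1, 0, 5, 8, 3]
```

On the example it happens to produce exactly the ordering given in the problem statement.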