<a href="https://colab.research.google.com/github/lilloo04/DSA/blob/main/Algorithm_HW_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random, copy
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Basic settings
BAR_COLOR = "#6666ff"        # colour of an ordinary bar
HIGHLIGHT_COLOR = "#ff6666"  # colour of the bars highlighted in the current step
DONE_COLOR = "#66cc66"       # colour of every bar once an algorithm has finished
NUM_ELEMENTS = 20            # size of the randomly generated array

# 정렬 알고리즘

def selection_sort(arr):
    """Step through selection sort on a copy of *arr*, one frame per yield.

    Every frame is ``(items, i, j, comparisons, exchanges)``. ``items`` is
    the same working list on each yield (it keeps being mutated), so a
    consumer that wants a snapshot has to copy it. ``i`` and ``j`` are the
    two indices to highlight. *arr* itself is never modified.
    """
    items = arr.copy()
    size = len(items)
    comparisons = 0
    exchanges = 0
    for slot in range(size):
        smallest = slot
        candidate = slot + 1
        while candidate < size:
            comparisons += 1
            # Frame is emitted before the comparison is evaluated.
            yield items, smallest, candidate, comparisons, exchanges
            if items[candidate] < items[smallest]:
                smallest = candidate
            candidate += 1
        if smallest != slot:
            items[slot], items[smallest] = items[smallest], items[slot]
            exchanges += 1
        # One frame per outer pass, whether or not a swap happened.
        yield items, slot, smallest, comparisons, exchanges

def bubble_sort(arr):
    """Step through bubble sort on a copy of *arr*, one frame per yield.

    Every frame is ``(items, left, right, comparisons, exchanges)`` where
    ``left``/``right`` are the adjacent indices being compared. A frame is
    emitted before each comparison and again right after each swap. The
    yielded list is the live working list, not a snapshot.
    """
    items = arr.copy()
    comparisons = 0
    exchanges = 0
    # ``last`` is the final index of the still-unsorted prefix.
    for last in range(len(items) - 1, 0, -1):
        for left in range(last):
            right = left + 1
            comparisons += 1
            yield items, left, right, comparisons, exchanges
            if items[left] > items[right]:
                items[left], items[right] = items[right], items[left]
                exchanges += 1
                yield items, left, right, comparisons, exchanges

def insertion_sort(arr):
    """Step through insertion sort on a copy of *arr*, one frame per yield.

    Every frame is ``(items, i, j, comparisons, shifts)``: ``i`` is the
    index of the element being inserted and ``j`` the probe position to its
    left (``j`` can be ``-1`` once the probe has run off the front). While
    elements are being shifted the list temporarily holds a duplicate; the
    pending value is written back just before the last frame of each pass.
    The yielded list is the live working list, not a snapshot.
    """
    items = arr.copy()
    comparisons = 0
    shifts = 0
    for pos in range(1, len(items)):
        pending = items[pos]
        probe = pos - 1
        comparisons += 1
        yield items, pos, probe, comparisons, shifts
        while probe >= 0:
            if not items[probe] > pending:
                break
            comparisons += 1
            items[probe + 1] = items[probe]
            shifts += 1
            probe -= 1
            yield items, pos, probe, comparisons, shifts
        items[probe + 1] = pending
        yield items, pos, probe, comparisons, shifts

def quick_sort(arr):
    """Step through an iterative quicksort on a copy of *arr*.

    Sub-ranges are kept on an explicit stack instead of recursing; each
    range is partitioned around its last element. Every frame is
    ``(items, i, j, comparisons, exchanges)``. A frame is emitted before
    each comparison with the pivot, after each swap inside the partition
    loop, and after the pivot is moved into place. The yielded list is the
    live working list, not a snapshot.
    """
    items = arr.copy()
    comparisons = 0
    exchanges = 0
    pending = [(0, len(items) - 1)]
    while pending:
        lo, hi = pending.pop()
        if lo >= hi:
            continue  # ranges of zero or one element need no work
        pivot = items[hi]
        boundary = lo  # first index of the "not smaller than pivot" region
        for scan in range(lo, hi):
            comparisons += 1
            yield items, scan, hi, comparisons, exchanges
            if items[scan] < pivot:
                items[boundary], items[scan] = items[scan], items[boundary]
                exchanges += 1
                boundary += 1
                yield items, boundary, scan, comparisons, exchanges
        items[boundary], items[hi] = items[hi], items[boundary]
        exchanges += 1
        yield items, boundary, hi, comparisons, exchanges
        # Right part is pushed last, so it is processed first.
        pending.append((lo, boundary - 1))
        pending.append((boundary + 1, hi))

def merge_sort(arr):
    """Step through a bottom-up merge sort on a copy of *arr*.

    Runs of length ``width`` are merged pairwise, with ``width`` doubling
    after every pass. Every frame is ``(a, k, -1, comparisons, writes)``
    and is emitted after each element is written back at index ``k``.

    The two runs are merged explicitly, so ``comparisons`` counts the
    element comparisons actually performed rather than an estimate. The
    yielded list is the live working list, not a snapshot; *arr* itself is
    never modified.
    """
    a = arr.copy()
    n = len(a)
    c = s = 0
    width = 1
    while width < n:
        for lo in range(0, n, 2 * width):
            mid = min(lo + width, n)
            hi = min(lo + 2 * width, n)
            # Snapshot both runs: the merge overwrites a[lo:hi] in place.
            left, right = a[lo:mid], a[mid:hi]
            li = ri = 0
            for k in range(lo, hi):
                if li < len(left) and ri < len(right):
                    c += 1
                    # "<=" takes from the left run on ties (stable merge).
                    take_left = left[li] <= right[ri]
                else:
                    # One run is exhausted: drain the other, no comparison.
                    take_left = li < len(left)
                if take_left:
                    a[k] = left[li]
                    li += 1
                else:
                    a[k] = right[ri]
                    ri += 1
                s += 1
                yield a, k, -1, c, s
        width *= 2

def heap_sort(arr):
    """Step through heap sort on a copy of *arr*, one frame per swap.

    A max-heap is built first; then the root is repeatedly swapped with the
    last element of the shrinking heap and sifted back down. Every frame is
    ``(items, i, j, comparisons, exchanges)`` and is emitted right after a
    swap of positions ``i`` and ``j``. The yielded list is the live working
    list, not a snapshot.
    """
    items = arr.copy()
    size = len(items)
    comparisons = 0
    exchanges = 0

    def sift_down(limit, root):
        """Restore the max-heap property below *root* within ``items[:limit]``."""
        nonlocal comparisons, exchanges
        while True:
            biggest = root
            left = 2 * root + 1
            right = left + 1
            if left < limit:
                comparisons += 1
                if items[left] > items[biggest]:
                    biggest = left
            if right < limit:
                comparisons += 1
                if items[right] > items[biggest]:
                    biggest = right
            if biggest == root:
                return
            items[root], items[biggest] = items[biggest], items[root]
            exchanges += 1
            yield items, root, biggest, comparisons, exchanges
            root = biggest  # keep sifting from the child that was swapped

    # Build phase: sift every internal node, from the last one up to the root.
    for node in reversed(range(size // 2)):
        yield from sift_down(size, node)
    # Extraction phase: move the current maximum behind the shrinking heap.
    for end in range(size - 1, 0, -1):
        items[end], items[0] = items[0], items[end]
        exchanges += 1
        yield items, end, 0, comparisons, exchanges
        yield from sift_down(end, 0)

def shell_sort(arr):
    """Step through Shell sort (gap halving) on a copy of *arr*.

    Every frame is ``(a, i, j, comparisons, shifts)``: ``i`` is the index of
    the element being inserted into its gap-sequence and ``j`` the current
    probe position. While elements are being shifted the list temporarily
    holds a duplicate of one value.

    A frame is also emitted after the pending value is written back, so the
    last frame of each insertion (and therefore the last frame overall)
    shows the value in place rather than the intermediate state with the
    duplicate. The yielded list is the live working list, not a snapshot.
    """
    a = arr.copy()
    n = len(a)
    c = s = 0
    gap = n // 2
    while gap:
        for i in range(gap, n):
            temp = a[i]
            j = i
            c += 1
            yield a, i, j - gap, c, s
            while j >= gap and a[j - gap] > temp:
                c += 1
                a[j] = a[j - gap]
                j -= gap
                s += 1
                yield a, i, j, c, s
            a[j] = temp
            # Without this frame a consumer would never see temp written
            # back after the final shift.
            yield a, i, j, c, s
        gap //= 2

def counting_sort(arr):
    """Step through counting sort on a copy of integer list *arr*.

    Every frame is ``(a, index, -1, comparisons, operations)``. During the
    counting phase ``index`` is the position of the element just tallied;
    during the write-back phase it is the position just overwritten.
    ``comparisons`` stays at 0 because no elements are compared.

    Counts are stored relative to the smallest value, so negative integers
    are handled; an empty input produces no frames. The yielded list is the
    live working list, not a snapshot.
    """
    a = arr.copy()
    c = s = 0
    if not a:
        return  # min()/max() would raise on an empty list
    low = min(a)
    count = [0] * (max(a) - low + 1)
    for idx, num in enumerate(a):
        count[num - low] += 1
        s += 1
        # Highlight the element's position, not its value.
        yield a, idx, -1, c, s
    pos = 0
    for offset, cnt in enumerate(count):
        for _ in range(cnt):
            a[pos] = low + offset
            s += 1
            yield a, pos, -1, c, s
            pos += 1

def radix_sort(arr):
    """Step through LSD radix sort (base 10) on a copy of integer list *arr*.

    Each pass distributes the values by one decimal digit using a stable
    counting pass, then copies the result back. Every frame is
    ``(a, index, -1, comparisons, operations)`` and is emitted after each
    position is overwritten during the copy-back. ``comparisons`` stays at 0.

    When the input contains negative values, digits are taken from
    ``value - min(arr)`` so that all keys are non-negative; inputs without
    negatives are processed on their raw values. An empty input produces no
    frames. The yielded list is the live working list, not a snapshot.
    """
    a = arr.copy()
    c = s = 0
    if not a:
        return  # min()/max() would raise on an empty list
    # Shift only when needed so non-negative inputs keep their own digits.
    shift = min(0, min(a))
    max_key = max(a) - shift
    exp = 1
    while max_key // exp > 0:
        count = [0] * 10
        output = [0] * len(a)
        for num in a:
            count[((num - shift) // exp) % 10] += 1
            s += 1
        # Prefix sums turn digit counts into end positions.
        for digit in range(1, 10):
            count[digit] += count[digit - 1]
        # Walk backwards so equal digits keep their relative order.
        for num in reversed(a):
            digit = ((num - shift) // exp) % 10
            count[digit] -= 1
            output[count[digit]] = num
            s += 1
        for i, num in enumerate(output):
            a[i] = num
            yield a, i, -1, c, s
        exp *= 10

# State variables
# Display name -> generator function, in the order the charts are drawn.
sorts = {
    "Selection": selection_sort,
    "Bubble": bubble_sort,
    "Insertion": insertion_sort,
    "Quick": quick_sort,
    "Merge": merge_sort,
    "Heap": heap_sort,
    "Shell": shell_sort,
    "Counting": counting_sort,
    "Radix": radix_sort,
}
# Per-algorithm state, all keyed by the display names used in ``sorts``.
data_states = {}   # latest array snapshot
generators = {}    # running step generator
highlights = {}    # (i, j) indices to highlight
counters = {}      # (comparisons, exchanges)
step_counter = 0    # number of steps taken since the last reset
done_order = []     # names of finished algorithms, in finishing order
original_data = []  # the array every algorithm started from

# Widget setup
reset_btn = widgets.Button(
    description="🔄 초기화",
    layout=widgets.Layout(width="120px"),
)
step_btn = widgets.Button(
    description="▶ 다음 단계",
    layout=widgets.Layout(width="120px"),
)
# Free-text field for a user-supplied, comma-separated array.
input_box = widgets.Text(
    placeholder='예: 5, 12, 30, 1, 7',
    description='배열 입력:',
)
# Holds the validation message for the text field; empty when input is valid.
error_label = widgets.HTML(
    value='',
    layout=widgets.Layout(margin='0 0 10px 0'),
)
# Output area that the charts are rendered into.
output = widgets.Output()

# 시각화 함수
def draw_all():
    """Draw one bar chart per algorithm from the current module-level state.

    Reads ``data_states``, ``highlights``, ``counters``, ``done_order`` and
    ``step_counter``. Bars of a finished algorithm are drawn in
    ``DONE_COLOR``; otherwise the two highlighted indices use
    ``HIGHLIGHT_COLOR`` and the rest ``BAR_COLOR``. After the figure is
    shown, the finishing order is printed.

    The figure is closed explicitly once shown so that repeated redraws do
    not accumulate open figures on backends where ``plt.show()`` does not
    close them.
    """
    fig, axs = plt.subplots(len(sorts), 1, figsize=(10, 2.5 * len(sorts)))
    if len(sorts) == 1:
        axs = [axs]  # subplots() returns a bare Axes for a single row
    try:
        for ax, name in zip(axs, sorts):
            arr = data_states[name]
            i, j = highlights[name]
            if name in done_order:
                colors = [DONE_COLOR] * len(arr)
            else:
                colors = [
                    HIGHLIGHT_COLOR if x in (i, j) else BAR_COLOR
                    for x in range(len(arr))
                ]
            ax.bar(range(len(arr)), arr, color=colors)
            c, s = counters[name]
            ax.set_title(f"{name} | Comparison: {c}  Exchange: {s}", loc='left')
            # Keep zero inside the range so negative bars are not clipped;
            # for non-negative data this is the same (0, max + 10) as before.
            ax.set_ylim(min(0, min(arr)), max(0, max(arr)) + 10)
            ax.set_xticks([])
            ax.set_yticks([])
        fig.subplots_adjust(top=0.92, hspace=0.6)
        fig.suptitle(f"sorting step: {step_counter}", fontsize=18, fontweight='bold')
        plt.show()
    finally:
        plt.close(fig)
    print("Completed sorting order:", " → ".join(done_order))

# 초기화 함수
def reset_all(_=None):
    """Start over with a fresh random array and redraw every chart.

    Generates ``NUM_ELEMENTS`` random integers between 10 and 100
    (inclusive), resets the step counter and finishing order, and gives
    every algorithm its own copy of the data and a new generator. The
    parameter is unused; it only absorbs the button passed by ``on_click``.
    """
    global step_counter, done_order, original_data
    original_data = [random.randint(10, 100) for _slot in range(NUM_ELEMENTS)]
    done_order = []
    step_counter = 0
    for label, make_steps in sorts.items():
        data_states[label] = original_data.copy()
        generators[label] = make_steps(original_data.copy())
        highlights[label] = (-1, -1)  # nothing highlighted yet
        counters[label] = (0, 0)
    with output:
        clear_output(wait=True)
        draw_all()

# step 진행 함수
def step_all(_=None):
    """Advance every unfinished algorithm by one frame and redraw.

    For each algorithm not yet in ``done_order`` the next frame is pulled
    from its generator and stored as a snapshot; an exhausted generator
    causes the algorithm's name to be appended to ``done_order``. The
    parameter is unused; it only absorbs the button passed by ``on_click``.

    Once every algorithm has finished, further calls do nothing, so the
    displayed step number stops growing.
    """
    global step_counter
    pending = [name for name in sorts if name not in done_order]
    if not pending:
        return  # nothing left to advance; keep the last rendering
    for name in pending:
        try:
            state, i, j, c, s = next(generators[name])
        except StopIteration:
            done_order.append(name)
        else:
            # Generators keep mutating the list they yield, so snapshot it.
            data_states[name] = state.copy()
            highlights[name] = (i, j)
            counters[name] = (c, s)
    step_counter += 1
    with output:
        clear_output(wait=True)
        draw_all()

# 사용자 배열 입력 처리
def on_array_input(change):
    """Restart all algorithms on the array typed into the input box.

    *change* is the change dictionary delivered by ``observe``; its
    ``'new'`` entry is the current text. The text is split on commas, empty
    pieces are ignored, and each remaining piece must parse as an integer.
    With fewer than two numbers, or any piece that is not an integer, the
    error label is filled in and the current state is left untouched.

    Only the parsing step is guarded: a failure while drawing is no longer
    swallowed and misreported as an input error.
    """
    global original_data, step_counter, done_order
    text = change['new']
    try:
        nums = [int(x.strip()) for x in text.split(',') if x.strip()]
    except ValueError:
        nums = []  # treated like "too few numbers" below
    if len(nums) < 2:
        error_label.value = "<span style='color:red;'>⚠️ 숫자만 쉼표로 구분해 입력하세요 (예: 3, 14, 15, 92)</span>"
        return
    original_data = nums.copy()
    step_counter = 0
    done_order = []
    for name in sorts:
        data_states[name] = nums.copy()
        generators[name] = sorts[name](nums.copy())
        highlights[name] = (-1, -1)
        counters[name] = (0, 0)
    with output:
        clear_output(wait=True)
        draw_all()
    error_label.value = ''

# Event wiring: buttons call their handlers on click; the text box calls
# on_array_input whenever its 'value' trait changes.
reset_btn.on_click(reset_all)
step_btn.on_click(step_all)
input_box.observe(on_array_input, names='value')

# Build and show the UI: a row with the two buttons, then the input box and
# the error label stacked below it, followed by the chart output area.
display(widgets.VBox([
    widgets.HBox([reset_btn, step_btn]),
    input_box,
    error_label
]))
display(output)
# Populate the state with a random array and draw the initial charts.
reset_all()


VBox(children=(HBox(children=(Button(description='🔄 초기화', layout=Layout(width='120px'), style=ButtonStyle()), …

Output()