# In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
import csv
import os

# --- ENERGY FUNCTIONS ---
def brightness(img):
    """Convert a 3-channel image to a single-channel grayscale image.

    The conversion code is ``COLOR_BGR2GRAY``, so the channels are weighted
    assuming OpenCV's BGR order (blue first). An image in RGB order would get
    its red and blue weights swapped.
    """
    to_gray = cv2.COLOR_BGR2GRAY
    gray = cv2.cvtColor(img, to_gray)
    return gray

def edgeness(gray):
    """Return the edge-energy map |d/dx| + |d/dy| built from 3x3 Sobel derivatives.

    The input is cast to float32. Both derivatives are written into a float64
    (``CV_64F``) output, so the result is a float64 map with the same height
    and width as ``gray``.
    """
    src = np.float32(gray)
    dx, dy = (
        cv2.Sobel(src, cv2.CV_64F, order_x, order_y, ksize=3)
        for order_x, order_y in ((1, 0), (0, 1))
    )
    return np.abs(dx) + np.abs(dy)

# --- SEAM REMOVAL ---
def remove_seam(img, seam):
    """Return a copy of ``img`` with the pixels on ``seam`` removed (one per row).

    Args:
        img: array of shape (height, width) or (height, width, ...). Any
            trailing axes, such as colour channels, are preserved.
        seam: iterable of ``(row, col)`` pairs naming the pixels to delete. It
            must remove exactly one pixel from every row.

    Returns:
        Array of shape (height, width - 1, ...) with the same dtype as ``img``.

    Raises:
        ValueError: if the seam does not remove exactly one pixel per row.
    """
    height, width = img.shape[:2]
    mask = np.ones((height, width), dtype=bool)
    for row, col in seam:
        mask[row, col] = False
    # Check explicitly; otherwise the reshape below fails with an opaque message.
    if np.any(mask.sum(axis=1) != width - 1):
        raise ValueError("seam must remove exactly one pixel from every row")
    # Boolean indexing flattens the (height, width) axes. Rebuild them and keep
    # whatever trailing axes the image has, so 2-D grayscale input works too.
    return img[mask].reshape((height, width - 1) + img.shape[2:])

# --- GENETIC STRUCTURES ---
class Individual:
    """Genome of one candidate seam: a pivot pixel plus per-row column offsets."""

    def __init__(self, pivot_row, pivot_col, deltas):
        # Row/column of the anchor pixel that decode_individual grows the seam from.
        self.pivot_row, self.pivot_col = pivot_row, pivot_col
        # Column offsets applied between consecutive rows (see decode_individual).
        self.deltas = deltas

def decode_individual(ind, height, width):
    """Expand an individual into a connected seam with one pixel per row.

    Starting from the pivot pixel ``(ind.pivot_row, ind.pivot_col)``, the seam
    is grown down to row ``height - 1`` and up to row 0. At each step the
    column moves by the matching entry of ``ind.deltas`` (missing entries
    count as 0) and is clipped into ``[0, width - 1]``.

    NOTE(review): below the pivot, row ``i`` adds ``deltas[i - 1]``; above the
    pivot, row ``i`` adds ``deltas[i]`` while stepping upwards. Each gene
    therefore links rows ``k`` and ``k + 1``, but its sign is applied in
    opposite directions on the two sides of the pivot. Confirm this is
    intended.

    Args:
        ind: object with ``pivot_row``, ``pivot_col`` and ``deltas``.
        height: number of rows in the image.
        width: number of columns; bounds every non-pivot column.

    Returns:
        List of ``(row, col)`` tuples ordered by row. The pivot column itself
        is not clipped.
    """
    pivot_row, pivot_col = ind.pivot_row, ind.pivot_col
    deltas = ind.deltas
    n_genes = len(deltas)

    # Walk upwards collecting rows pivot_row-1 .. 0, then reverse once.
    # Inserting at the front of the list on every step would be O(height^2).
    above = []
    col = pivot_col
    for row in range(pivot_row - 1, -1, -1):
        delta = deltas[row] if row < n_genes else 0
        col = np.clip(col + delta, 0, width - 1)
        above.append((row, col))
    above.reverse()

    seam = above
    seam.append((pivot_row, pivot_col))

    # Walk downwards from the pivot to the last row.
    col = pivot_col
    for row in range(pivot_row + 1, height):
        delta = deltas[row - 1] if row - 1 < n_genes else 0
        col = np.clip(col + delta, 0, width - 1)
        seam.append((row, col))

    return seam

def fitness(ind, energy_map):
    """Score an individual by the mean energy of the seam it decodes to.

    Lower is better. If any seam pixel lies outside ``energy_map``, the score
    is ``float('inf')``. Only the unclipped pivot column can normally be out
    of range, because decode_individual clips the other columns.
    """
    seam = decode_individual(ind, *energy_map.shape)
    n_rows, n_cols = energy_map.shape
    row_idx = np.array([r for r, _ in seam])
    col_idx = np.array([c for _, c in seam])

    outside = ((row_idx < 0) | (row_idx >= n_rows)
               | (col_idx < 0) | (col_idx >= n_cols))
    if outside.any():
        return float('inf')

    seam_energy = np.sum(energy_map[row_idx, col_idx])
    return seam_energy / len(seam)

def int_to_ternary_array(n, digits):
    """Encode ``n`` in balanced ternary (digits -1, 0, 1), most significant first.

    Exactly ``digits`` digits are produced, as an int8 array. If ``n`` needs
    more digits, the high-order part is silently dropped; for example ``n=2``
    with one digit gives ``[-1]``.
    """
    trits = []
    for _ in range(digits):
        n, rem = divmod(n, 3)
        if rem == 2:
            # 2 == 3 - 1: emit -1 and carry one into the next (higher) digit.
            rem = -1
            n += 1
        trits.append(rem)
    # Digits were produced least significant first.
    return np.array(trits[::-1], dtype=np.int8)

def ternary_array_to_int(arr):
    """Decode a most-significant-first ternary digit array into a Python int.

    Works for balanced digits (-1, 0, 1) as produced by int_to_ternary_array.
    """
    # Place values 3**(len-1), ..., 3**1, 3**0, aligned with arr.
    weights = 3 ** np.arange(len(arr) - 1, -1, -1)
    return int(np.sum(np.asarray(arr) * weights))

def crossover(p1, p2, width, height):
    """Create a child individual from two parents.

    The column offsets (``deltas``) use single-point crossover: a prefix from
    ``p1`` followed by the suffix of ``p2``. The pivot column is recombined by
    writing both parents' columns in balanced ternary, splicing the digit
    strings at a random point and clipping the result into ``[0, width - 1]``.
    The child's pivot row is drawn uniformly from ``[0, height)``.

    Args:
        p1, p2: parents with ``pivot_col`` and equal-length ``deltas``.
        width: image width; bounds the child's pivot column.
        height: image height; bounds the child's pivot row.

    Returns:
        A new ``Individual``.
    """
    n_genes = len(p1.deltas)
    if n_genes >= 2:
        point = np.random.randint(1, n_genes)
        child_deltas = np.concatenate([p1.deltas[:point], p2.deltas[point:]])
    else:
        # randint(1, n) requires n >= 2; with 0 or 1 genes there is no cut point.
        child_deltas = np.asarray(p1.deltas).copy()

    # d balanced-ternary digits span [-(3**d - 1) / 2, (3**d - 1) / 2], so every
    # column 0..width-1 fits once 3**d >= 2*width - 1. ceil(log3(width)) is not
    # enough: e.g. width=128 gives 5 digits (max 121), and larger columns would
    # silently wrap. Starting at 2 digits guarantees a valid cut point.
    digits = 2
    while 3 ** digits < 2 * width - 1:
        digits += 1

    t1 = int_to_ternary_array(p1.pivot_col, digits)
    t2 = int_to_ternary_array(p2.pivot_col, digits)
    pivot_point = np.random.randint(1, digits)
    child_ternary = np.concatenate([t1[:pivot_point], t2[pivot_point:]])
    # A spliced digit string can decode outside the image (even negative).
    new_col = int(np.clip(ternary_array_to_int(child_ternary), 0, width - 1))

    new_row = np.random.randint(0, height)
    return Individual(new_row, new_col, child_deltas)

def mutate(ind, mutation_rate=0.1, width=None):
    """Return a mutated copy of ``ind``; ``ind`` itself is not modified.

    Roughly ``mutation_rate * len(ind.deltas)`` randomly chosen offsets (at
    least one, when any candidate exists) are redrawn from {-1, 0, 1}. The
    offset at index ``ind.pivot_row`` is always redrawn as well, when it is a
    valid index. The pivot column is replaced by a uniform draw from
    ``[0, width)``.

    NOTE(review): because the column is always redrawn, any column inherited
    from the parent is discarded. That includes the column ``crossover``
    recombines, since genetic_seam_search mutates every child. Confirm this is
    intended.

    Args:
        ind: object with ``pivot_row`` and an ndarray ``deltas``.
        mutation_rate: fraction of offsets to redraw.
        width: image width; required and must be positive.

    Returns:
        A new ``Individual`` with the same pivot row.

    Raises:
        ValueError: if ``width`` is None or not positive.
    """
    # Validate first so nothing is computed for a call that will fail.
    if width is None or width <= 0:
        raise ValueError("width must be a positive integer")

    new_deltas = ind.deltas.copy()
    n = len(new_deltas)  # height - 1 for individuals from generate_random_individual
    pivot_is_gene = 0 <= ind.pivot_row < n

    candidates = [pos for pos in range(n) if pos != ind.pivot_row]
    k = min(max(1, int(mutation_rate * n)), len(candidates))

    # Keep the positions integer-typed. np.random.choice on an empty list
    # returns a float array, and indexing with a float array raises
    # IndexError (this happened for images only 1-2 rows tall).
    if k > 0:
        positions = np.random.choice(candidates, size=k, replace=False).astype(np.intp)
    else:
        positions = np.empty(0, dtype=np.intp)
    if pivot_is_gene:
        positions = np.append(positions, np.intp(ind.pivot_row))

    new_deltas[positions] = np.random.choice([-1, 0, 1], size=len(positions))

    new_col = np.random.randint(0, width)
    return Individual(ind.pivot_row, new_col, new_deltas)

def generate_random_individual(height, width):
    """Create an individual with a uniformly random pivot and random offsets.

    The pivot is drawn from ``[0, height) x [0, width)``. The ``height - 1``
    offsets are drawn uniformly from {-1, 0, 1}.
    """
    # Arguments are evaluated left to right: row, then column, then offsets.
    return Individual(
        np.random.randint(0, height),
        np.random.randint(0, width),
        np.random.choice([-1, 0, 1], size=height - 1),
    )

def genetic_seam_search(energy_map, population_size=20, generations=100,
                        early_stop_patience=20, min_improvement=1e-3,
                        crossover_rate=0.8):
    """Search for a low-energy vertical seam with a genetic algorithm.

    Each generation:
      * keeps the best half of the population, ranked by ``fitness`` (mean
        seam energy, lower is better);
      * fills the rest with children whose parents are drawn from the
        survivors, weighted towards lower energy. Each child is a crossover of
        two parents with probability ``crossover_rate``, otherwise the first
        parent, and every child is then mutated.

    The search stops after ``generations`` iterations, or earlier once the best
    fitness has failed to beat the best-so-far by at least ``min_improvement``
    for ``early_stop_patience`` consecutive generations.

    Args:
        energy_map: 2-D energy array (rows, cols).
        population_size: number of individuals per generation.
        generations: maximum number of generations.
        early_stop_patience: generations without sufficient improvement
            before stopping.
        min_improvement: minimum decrease in best fitness that counts as
            progress.
        crossover_rate: probability that a child comes from crossover.

    Returns:
        A generator that yields the decoded seam (list of ``(row, col)``) of
        every individual in the final population tied for the best fitness.

    Raises:
        ValueError: if ``generations > 0`` and ``population_size < 4``. Parent
            selection draws two distinct survivors from the best half.
    """
    if generations > 0 and population_size < 4:
        raise ValueError("population_size must be at least 4 so that parent "
                         "selection has two distinct survivors to draw from")

    h, w = energy_map.shape
    n_survivors = population_size // 2
    pop = [generate_random_individual(h, w) for _ in range(population_size)]

    best_fitness = float('inf')
    patience_counter = 0

    for _ in range(generations):
        fitnesses = np.array([fitness(ind, energy_map) for ind in pop])
        order = np.argsort(fitnesses)
        elite = order[:n_survivors]
        survivors = [pop[i] for i in elite]

        # Early stopping: count generations whose best did not beat the
        # best-so-far by at least min_improvement.
        current_best_fitness = fitnesses[order[0]]
        if best_fitness - current_best_fitness < min_improvement:
            patience_counter += 1
        else:
            best_fitness = current_best_fitness
            patience_counter = 0
        if patience_counter >= early_stop_patience:
            break

        # Lower energy is better, so weight each survivor by its distance
        # below the worst survivor. The epsilon keeps every weight positive.
        fit_vals = fitnesses[elite]
        weights = fit_vals.max() - fit_vals + 1e-6
        probs = weights / weights.sum()

        new_pop = survivors[:]
        while len(new_pop) < population_size:
            p1, p2 = np.random.choice(survivors, size=2, replace=False, p=probs)
            child = crossover(p1, p2, w, h) if np.random.rand() < crossover_rate else p1
            new_pop.append(mutate(child, width=w))
        pop = new_pop

    # Score the final population once and reuse the scores for the tie filter.
    final_fitnesses = [fitness(ind, energy_map) for ind in pop]
    best_fitness = min(final_fitnesses)
    best_inds = [ind for ind, f in zip(pop, final_fitnesses) if f == best_fitness]
    return (decode_individual(ind, h, w) for ind in best_inds)



# --- SHRINK OPERATIONS ---
def shrink_image_with_genetic_seam(img, target_width):
    """Narrow ``img`` to ``target_width`` columns by removing vertical seams.

    Each iteration recomputes the edge energy of the current image, takes the
    first seam returned by ``genetic_seam_search``, records that seam's energy
    and removes the seam.

    Args:
        img: 3-channel image array (rows, cols, channels), in the channel
            order ``brightness`` expects.
        target_width: desired number of columns. Nothing is removed if the
            image is already this narrow.

    Returns:
        Tuple ``(image, avg_time, total_time, loss_ratio)``:
          * ``avg_time`` / ``total_time``: seconds spent inside
            ``genetic_seam_search`` per seam and in total (energy computation
            and seam removal are not timed);
          * ``loss_ratio``: summed energy of the removed seams (each measured
            on the energy map of the image it was removed from) divided by the
            total edge energy of the original image. It is 0.0 when the
            original image has no edge energy.
    """
    total_energy = np.sum(edgeness(brightness(img)))
    lost_energy = 0.0
    times = []

    while img.shape[1] > target_width:
        energy = edgeness(brightness(img))

        t0 = time.time()
        seam = next(genetic_seam_search(energy))
        times.append(time.time() - t0)

        rows, cols = zip(*seam)
        lost_energy += np.sum(energy[rows, cols])
        img = remove_seam(img, seam)

    total_time = sum(times)
    avg_time_per_seam = total_time / len(times) if times else 0
    # A flat image has zero edge energy; avoid 0/0 -> nan.
    loss_ratio = lost_energy / total_energy if total_energy > 0 else 0.0
    return img, avg_time_per_seam, total_time, loss_ratio



def shrink_image_height_with_genetic_seam(img, target_height):
    """Reduce ``img`` to ``target_height`` rows using horizontal seams.

    The image is rotated 90 degrees so that its rows become columns, shrunk
    with ``shrink_image_with_genetic_seam``, and rotated back.

    Returns:
        Tuple ``(image, avg_time, total_time, loss_ratio)`` as returned by
        ``shrink_image_with_genetic_seam``, with the image in the original
        orientation.
    """
    rotated = np.rot90(img, k=1)
    # After rotation the column axis of ``rotated`` is the original row axis,
    # so the target width is exactly target_height. Scaling it by
    # rotated.shape[1] / rotated.shape[0] is only correct for square images.
    shrunk, avg_time, total_time, loss = shrink_image_with_genetic_seam(
        rotated, int(target_height))
    return np.rot90(shrunk, k=-1), avg_time, total_time, loss

def shrink_image_alternating_seams(img, target_width, target_height):
    """Shrink ``img`` to (target_height, target_width) by alternating seam directions.

    Vertical seam first, then horizontal, and so on. Once one dimension
    reaches its target, only the other direction is used.

    Args:
        img: 3-channel image array (rows, cols, channels), in the channel
            order ``brightness`` expects.
        target_width: desired number of columns.
        target_height: desired number of rows.

    Returns:
        Tuple ``(image, avg_time, total_time, loss_ratio)``:
          * ``avg_time`` / ``total_time``: seconds spent inside
            ``genetic_seam_search`` per seam and in total. Rotation, energy
            computation and seam removal are not timed, for either direction.
          * ``loss_ratio``: summed energy of the removed seams divided by the
            total edge energy of the original image. It is 0.0 when the
            original image has no edge energy.
    """
    current_height, current_width = img.shape[:2]
    total_energy = np.sum(edgeness(brightness(img)))
    lost_energy = 0.0
    times = []
    remove_vertical = True

    while current_width > target_width or current_height > target_height:
        if remove_vertical and current_width > target_width:
            horizontal = False
        elif not remove_vertical and current_height > target_height:
            horizontal = True
        else:
            # The preferred direction is already at its target; switch and retry.
            remove_vertical = not remove_vertical
            continue

        img, seam_energy, elapsed = _remove_genetic_seam(img, rotate=horizontal)
        lost_energy += seam_energy
        times.append(elapsed)
        if horizontal:
            current_height -= 1
        else:
            current_width -= 1
        remove_vertical = not remove_vertical

    total_time = sum(times)
    avg_time = total_time / len(times) if times else 0
    # A flat image has zero edge energy; avoid 0/0 -> nan.
    loss_ratio = lost_energy / total_energy if total_energy > 0 else 0.0
    return img, avg_time, total_time, loss_ratio


def _remove_genetic_seam(img, rotate):
    """Remove one GA-found seam: vertical, or horizontal when ``rotate`` is true.

    Returns ``(new_image, seam_energy, search_seconds)``.
    """
    # Horizontal seams are found as vertical seams of the rotated image.
    # Only the energy of the image actually searched is computed.
    work = np.rot90(img, k=1) if rotate else img
    energy = edgeness(brightness(work))

    t0 = time.time()
    seam = next(genetic_seam_search(energy))
    elapsed = time.time() - t0

    rows, cols = zip(*seam)
    seam_energy = np.sum(energy[rows, cols])
    work = remove_seam(work, seam)
    return (np.rot90(work, k=-1) if rotate else work), seam_energy, elapsed


# --- LOGGING ---
def save_results_to_csv(img_path, original_shape, resized_shape, avg_time, total_time,
                        loss_ratio, total_elapsed_time, csv_path="test_ketqua_case_3.csv"):
    """Append one result row to a UTF-8 CSV log, writing the header first if needed.

    The header (Vietnamese column names) is written when the file does not
    exist yet or is empty.

    Args:
        img_path: path of the processed image. It is the first column, which
            the script's resume logic matches on.
        original_shape, resized_shape: array shapes. Only their first two
            entries are written, formatted as ``"AxB"``.
        avg_time, total_time, total_elapsed_time: seconds, written with 4
            decimals.
        loss_ratio: energy-loss fraction, written as a percentage with 2
            decimals.
        csv_path: destination file. Defaults to ``test_ketqua_case_3.csv`` in
            the current working directory.
    """
    header = [
        "Tên ảnh", 
        "Kích thước gốc", 
        "Kích thước mới", 
        "Avg time/seam (s)", 
        "Tổng thời gian tìm seam(s)", 
        "Tỉ lệ năng lượng mất (%)", 
        "Tổng thời gian xử lý (s)"
    ]

    # An existing but empty file (e.g. left behind by an interrupted run)
    # still needs the header.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

    with open(csv_path, 'a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(header)
        writer.writerow([
            img_path,
            f"{original_shape[0]}x{original_shape[1]}",
            f"{resized_shape[0]}x{resized_shape[1]}",
            f"{avg_time:.4f}",
            f"{total_time:.4f}",
            f"{loss_ratio * 100:.2f}",
            f"{total_elapsed_time:.4f}"
        ])

# --- MAIN ---
if __name__ == '__main__':
    # NOTE(review): these are raw strings, so each doubled backslash is kept
    # literally (e.g. "D:\\data"). This presumably works because Windows
    # usually collapses repeated separators; confirm. The resulting paths are
    # also the keys stored in the CSV for resuming, so they are left unchanged.
    base_input_folder = r"D:\\data\\data_case_3"
    output_folder = r"D:\\data\\result3\\result_test_case_3"
    reduce_height = False  # only consulted when xenke is False
    xenke = True  # True: alternate vertical/horizontal seams
    os.makedirs(output_folder, exist_ok=True)

    folders_to_process = ["128x128", "256x256", "512x512", "1024x1024"]

    # Resume support: image paths already logged in the CSV are skipped.
    processed_images = set()
    csv_path = "test_ketqua_case_3.csv"
    if os.path.exists(csv_path):
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerate an empty file
            for row in reader:
                if row:
                    processed_images.add(row[0])

    for folder_name in folders_to_process:
        input_folder = os.path.join(base_input_folder, folder_name)

        for root, _dirs, files in os.walk(input_folder):
            for fname in files:
                if not fname.lower().endswith(('.jpg', '.png', '.jpeg')):
                    continue

                img_path = os.path.join(root, fname)
                if img_path in processed_images:
                    print(f"Đã xử lý, bỏ qua: {img_path}")
                    continue

                rel_path = os.path.relpath(root, base_input_folder)
                out_dir = os.path.join(output_folder, rel_path)
                os.makedirs(out_dir, exist_ok=True)
                out_path = os.path.join(out_dir, fname)

                print(f"\n>>> Đang xử lý: {img_path}")

                start = time.time()

                # cv2.imread yields BGR. Carve in BGR because brightness()
                # converts with COLOR_BGR2GRAY; converting to RGB first would
                # swap the red/blue luminance weights. RGB is used only for display.
                img = cv2.imread(img_path)
                if img is None:
                    print(f"Lỗi khi đọc ảnh: {img_path}")
                    continue

                original_shape = img.shape
                target_width = int(original_shape[1] * 0.9)
                target_height = int(original_shape[0] * 0.9)

                if xenke:
                    result, avg_time, total_time, loss_ratio = shrink_image_alternating_seams(img, target_width, target_height)
                elif reduce_height:
                    result, avg_time, total_time, loss_ratio = shrink_image_height_with_genetic_seam(img, target_height)
                else:
                    result, avg_time, total_time, loss_ratio = shrink_image_with_genetic_seam(img, target_width)

                total_elapsed_time = time.time() - start
                resized_shape = result.shape
                # Rotations can leave a non-contiguous view; make a plain copy
                # before handing it to OpenCV.
                result = np.ascontiguousarray(result)

                print(f"Kích thước gốc: {original_shape}")
                print(f"Kích thước mới: {resized_shape}")
                print(f"Thời gian TB mỗi seam: {avg_time:.4f}s")
                print(f"Tổng thời gian seam: {total_time:.4f}s")
                print(f"Tổng thời gian xử lý: {total_elapsed_time:.4f}s")
                print(f"Tỉ lệ năng lượng mất: {loss_ratio:.2%}")

                # Show the original and the shrunk image (matplotlib expects RGB).
                plt.figure(figsize=(10, 5))
                plt.subplot(1, 2, 1)
                plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                plt.title("Ảnh gốc")
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
                plt.title("Ảnh đã thu nhỏ")
                plt.axis('off')
                plt.tight_layout()
                plt.show()
                # Release the figure; with non-interactive backends figures
                # otherwise accumulate across the whole batch.
                plt.close()

                # The result is already BGR, as cv2.imwrite expects. imwrite
                # reports failure via its return value; an image that was not
                # saved is not logged, so the next run retries it.
                if not cv2.imwrite(out_path, result):
                    print(f"Lỗi khi lưu ảnh: {out_path}")
                    continue
                save_results_to_csv(img_path, original_shape, resized_shape,
                                    avg_time, total_time, loss_ratio, total_elapsed_time)

