In [6]:
from PIL import Image
import os

# === Config ===
image_path = "input_data/X_test/"  # Đường dẫn đến ảnh gốc
ROWS = 3
COLS = 5

def split_image(image_path, rows, cols):
    """Đọc ảnh và cắt thành rows x cols, trả về danh sách các mảnh."""
    img = Image.open(image_path)
    width, height = img.size

    piece_width = width // cols
    piece_height = height // rows

    pieces = []
    for r in range(rows):
        for c in range(cols):
            left = c * piece_width
            upper = r * piece_height
            right = left + piece_width
            lower = upper + piece_height
            piece = img.crop((left, upper, right, lower))
            pieces.append(piece)

    return pieces

def load_all_pieces(image_dir, rows, cols):
    """
    Duyệt toàn bộ ảnh trong thư mục và trả về mảng 3 chiều:
    all_pieces[image_index][row][col] = PIL.Image
    """
    all_pieces = []
    image_files = sorted(
        [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    )

    print(f"Found {len(image_files)} images in {image_dir}")

    for filename in image_files:
        image_path = os.path.join(image_dir, filename)
        print(f"- Processing {filename}...")
        pieces = split_image(image_path, rows, cols)
        all_pieces.append(pieces)

    print(f"\nTotal images processed: {len(all_pieces)}")
    return all_pieces

all_pieces = load_all_pieces(image_path, ROWS, COLS)
# if DEBUG:
#     all_pieces = all_pieces[:20]  # Giới hạn để debug nhanh

# Ví dụ: xem thông tin
print(f"Total images: {len(all_pieces)}")  # số lượng ảnh
if all_pieces:
    print(f"Rows per image: {len(all_pieces[0])}")


Found 100 images in input_data/X_test/
- Processing Alfred_Sisley_115_shuffled.jpg...
- Processing Alfred_Sisley_188_shuffled.jpg...
- Processing Alfred_Sisley_205_shuffled.jpg...
- Processing Alfred_Sisley_232_shuffled.jpg...
- Processing Alfred_Sisley_6_shuffled.jpg...
- Processing Amedeo_Modigliani_112_shuffled.jpg...
- Processing Amedeo_Modigliani_143_shuffled.jpg...
- Processing Amedeo_Modigliani_2_shuffled.jpg...
- Processing Amedeo_Modigliani_65_shuffled.jpg...
- Processing Amedeo_Modigliani_73_shuffled.jpg...
- Processing Andrei_Rublev_32_shuffled.jpg...
- Processing Andrei_Rublev_50_shuffled.jpg...
- Processing Andrei_Rublev_89_shuffled.jpg...
- Processing Andy_Warhol_11_shuffled.jpg...
- Processing Andy_Warhol_161_shuffled.jpg...
- Processing Andy_Warhol_171_shuffled.jpg...
- Processing Andy_Warhol_70_shuffled.jpg...
- Processing Andy_Warhol_88_shuffled.jpg...
- Processing Camille_Pissarro_3_shuffled.jpg...
- Processing Caravaggio_19_shuffled.jpg...
- Processing Caravaggio_22

In [7]:
import numpy as np
from PIL import Image
import os
import csv
import random
from tqdm import tqdm
import pandas as pd
# =========================
# COST: MSE theo viền kề
# =========================
def _to_np_rgb(pil_img):
    return np.asarray(pil_img.convert("RGB"), dtype=np.float32)

def _border_arrays(pieces):
    arrs = [_to_np_rgb(p) for p in pieces]
    tops    = [a[0, :, :]  for a in arrs]  # (W,3)
    bottoms = [a[-1, :, :] for a in arrs]  # (W,3)
    lefts   = [a[:, 0, :]  for a in arrs]  # (H,3)
    rights  = [a[:, -1, :] for a in arrs]  # (H,3)
    return tops, bottoms, lefts, rights

def _mse_rgb_edge(edge1, edge2):
    L = min(edge1.shape[0], edge2.shape[0])
    if L <= 0:
        return 0.0
    diff = edge1[:L].astype(np.float32) - edge2[:L].astype(np.float32)
    return float(np.mean(diff * diff))

def compute_cost_matrix_mse(pieces):
    """
    - H[i,j]: cost đặt j bên phải i  -> so sánh right(i) vs left(j)
    - V[i,j]: cost đặt j bên dưới i -> so sánh bottom(i) vs top(j)
    """
    n = len(pieces)
    H = np.zeros((n, n), dtype=np.float64)
    V = np.zeros((n, n), dtype=np.float64)

    tops, bottoms, lefts, rights = _border_arrays(pieces)
    for i in range(n):
        r_i = rights[i]
        b_i = bottoms[i]
        for j in range(n):
            if i == j:
                continue
            H[i, j] = _mse_rgb_edge(r_i, lefts[j])
            V[i, j] = _mse_rgb_edge(b_i, tops[j])
    return H, V
# =========================
# MST-BASED SOLVER
# =========================
class DSU:
    def __init__(self, n):
        self.p = list(range(n))
        self.r = [0]*n
    def find(self, x):
        while self.p[x] != x:
            self.p[x] = self.p[self.p[x]]
            x = self.p[x]
        return x
    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb: return False
        if self.r[ra] < self.r[rb]:
            self.p[ra] = rb
        elif self.r[ra] > self.r[rb]:
            self.p[rb] = ra
        else:
            self.p[rb] = ra
            self.r[ra] += 1
        return True

def invert_perm_c2o_to_o2c(c2o):
    """c2o[c]=o  ->  o2c[o]=c"""
    n = len(c2o)
    o2c = [0]*n
    for c, o in enumerate(c2o):
        o2c[o] = c
    return o2c

def invert_perm_o2c_to_c2o(o2c):
    """o2c[o]=c  ->  c2o[c]=o"""
    n = len(o2c)
    c2o = [0]*n
    for o, c in enumerate(o2c):
        c2o[c] = o
    return c2o

class MSTPuzzleSolver:
    """
    Lắp ghép bằng MST trên đồ thị mảnh ảnh.
    - Cạnh (i,j) chọn hướng có tổng sai số đôi chiều nhỏ hơn:
        ngang: H[i,j]+H[j,i] => offset (1,0) (j ở bên phải i)
        dọc:   V[i,j]+V[j,i] => offset (0,1) (j ở bên dưới i)
    - Sau khi có MST, BFS để gán (x,y) tương đối cho từng mảnh.
    - Chuẩn hoá (x,y) -> (cột, hàng) [0..C-1], [0..R-1], tạo ánh xạ cell->piece.
    """
    def __init__(self, rows, cols):
        self.rows = rows
        self.cols = cols

    def _build_edges(self, H, V):
        n = H.shape[0]
        edges = []  # (w, i, j, dx, dy)
        for i in range(n):
            for j in range(i+1, n):
                w_h = H[i, j] + H[j, i]
                w_v = V[i, j] + V[j, i]
                if w_h <= w_v:
                    # j nằm bên phải i: offset (dx=+1, dy=0)
                    edges.append((float(w_h), i, j, 1, 0))
                else:
                    # j nằm bên dưới i: offset (dx=0, dy=+1)
                    edges.append((float(w_v), i, j, 0, 1))
        edges.sort(key=lambda x: x[0])
        return edges

    def _mst(self, n, edges):
        dsu = DSU(n)
        used = []
        for w, i, j, dx, dy in edges:
            if dsu.union(i, j):
                used.append((i, j, dx, dy))
                if len(used) == n-1:
                    break
        return used  # danh sách cạnh trong MST

    def _layout_from_mst(self, n, mst_edges):
        # adjacency với offset hai chiều
        adj = [[] for _ in range(n)]
        for i, j, dx, dy in mst_edges:
            adj[i].append((j, dx, dy))
            adj[j].append((i, -dx, -dy))

        # BFS: đặt piece 0 ở (0,0)
        from collections import deque
        pos = {0: (0, 0)}
        q = deque([0])
        while q:
            u = q.popleft()
            ux, uy = pos[u]
            for v, dx, dy in adj[u]:
                if v not in pos:
                    pos[v] = (ux + dx, uy + dy)
                    q.append(v)
        # tất cả mảnh phải nằm trong MST nên đều có pos
        return pos

    def _normalize_to_grid(self, pos_map):
        # map to 0..C-1 and 0..R-1 by compressing coordinates
        xs = sorted({x for x, y in pos_map.values()})
        ys = sorted({y for x, y in pos_map.values()})
        # tạo mapping compact
        x_to_c = {x: idx for idx, x in enumerate(xs)}
        y_to_r = {y: idx for idx, y in enumerate(ys)}
        # nếu số mức khác cols/rows, sẽ xử lý dự phòng phía dưới
        return x_to_c, y_to_r

    def run(self, pieces):
        """
        Trả về (c2o, total_cost_approx)
        c2o[c] = o: piece index tại cell c theo row-major (r*cols + c)
        """
        n = len(pieces)
        assert n == self.rows * self.cols, "Số mảnh không khớp rows*cols"

        H, V = compute_cost_matrix_mse(pieces)
        edges = self._build_edges(H, V)
        mst_edges = self._mst(n, edges)
        pos = self._layout_from_mst(n, mst_edges)

        # Chuẩn hoá tọa độ
        x_to_c, y_to_r = self._normalize_to_grid(pos)
        uniq_cols = len(x_to_c)
        uniq_rows = len(y_to_r)

        # Dựng lưới cell -> piece
        grid = [[None for _ in range(self.cols)] for __ in range(self.rows)]
        placed = set()

        def safe_put(r, c, piece):
            if 0 <= r < self.rows and 0 <= c < self.cols and grid[r][c] is None:
                grid[r][c] = piece
                placed.add(piece)
                return True
            return False

        # Trường hợp lý tưởng: uniq_cols==self.cols và uniq_rows==self.rows
        if uniq_cols == self.cols and uniq_rows == self.rows:
            for piece, (x, y) in pos.items():
                r = y_to_r[y]
                c = x_to_c[x]
                safe_put(r, c, piece)
        else:
            # Dự phòng: nén toạ độ theo thứ tự (y,x), lấp row-major
            order = sorted([(y, x, p) for p, (x, y) in pos.items()])
            idx = 0
            for r in range(self.rows):
                for c in range(self.cols):
                    if idx < len(order):
                        _, _, p = order[idx]
                        safe_put(r, c, p)
                        idx += 1

        # Lấp ô còn trống bằng mảnh chưa đặt (hiếm khi xảy ra)
        remaining = [p for p in range(n) if p not in placed]
        it = iter(remaining)
        for r in range(self.rows):
            for c in range(self.cols):
                if grid[r][c] is None:
                    try:
                        grid[r][c] = next(it)
                    except StopIteration:
                        grid[r][c] = 0  # không xảy ra, nhưng để chắc chắn

        # Tạo chromosome c->o
        c2o = []
        for r in range(self.rows):
            for c in range(self.cols):
                c2o.append(int(grid[r][c]))

        # Ước lượng tổng cost theo H,V (tuỳ chọn)
        def approx_cost(c2o_arr):
            tot = 0.0
            g = np.array(c2o_arr, dtype=int).reshape(self.rows, self.cols)
            for r in range(self.rows):
                for c in range(self.cols):
                    cur = g[r, c]
                    if c < self.cols - 1:
                        tot += H[cur, g[r, c+1]]
                    if r < self.rows - 1:
                        tot += V[cur, g[r+1, c]]
            return float(tot)

        return c2o, approx_cost(c2o)

    def assemble_image(self, pieces, order):
        n = self.rows * self.cols
        if len(order) != n or set(order) != set(range(n)):
            # đảm bảo là hoán vị hợp lệ
            used, out = set(), []
            for g in order:
                if isinstance(g, (int, np.integer)) and 0 <= g < n and g not in used:
                    out.append(int(g)); used.add(int(g))
            out.extend([x for x in range(n) if x not in used])
            order = out[:n]
        grid = np.array(order, dtype=int).reshape(self.rows, self.cols)
        w, h = pieces[0].size
        out = Image.new("RGB", (w*self.cols, h*self.rows))
        for r in range(self.rows):
            for c in range(self.cols):
                out.paste(pieces[grid[r, c]], (c*w, r*h))
        return out


class PuzzleRunnerLocal:
    def __init__(self, solver, pieces_list, image_dir, output_dir, output_img_dir, y_true_csv):
        self.solver = solver
        self.pieces_list = pieces_list
        self.image_dir = image_dir
        self.output_dir = output_dir
        self.output_img_dir = output_img_dir
        self.y_true_csv = y_true_csv
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.output_img_dir, exist_ok=True)

    def run_all(self):
        results = []
        image_files = sorted([f for f in os.listdir(self.image_dir)
                              if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

        for idx, pieces in enumerate(tqdm(self.pieces_list, desc="Local search for images")):
            best_order, _ = self.solver.run(pieces)  # c->o (cell->piece)
            image_name = image_files[idx]

            # Lưu ảnh lắp theo c->o
            img = self.solver.assemble_image(pieces, best_order)
            img.save(os.path.join(self.output_img_dir, f"{image_name}_solved.png"))

            # Ghi trực tiếp c->o ra CSV (rất tự nhiên: piece_at_r_c)
            results.append([image_name] + best_order[:])

        # Ghi output.csv (c->o)
        output_csv = os.path.join(self.output_dir, "output.csv")
        with open(output_csv, "w", newline="") as f:
            writer = csv.writer(f)
            header = ["image_filename"] + [
                f"piece_at_{r}_{c}" for r in range(self.solver.rows) for c in range(self.solver.cols)
            ]
            writer.writerow(header)
            writer.writerows(results)

        print(f"Saved output to {output_csv}")
        print(f"Solved images saved to {self.output_img_dir}")
        return output_csv

    def evaluate(self, output_csv):
        df_pred = pd.read_csv(output_csv)
        df_true = pd.read_csv(self.y_true_csv)

        correct_count = 0
        ppa_scores = []

        true_map = {row['image_filename']: row.values[1:].astype(int)
                    for _, row in df_true.iterrows()}

        for _, row in df_pred.iterrows():
            fname = row['image_filename']
            if fname not in true_map:
                continue
            pred = row.values[1:].astype(int)  # chúng ta xuất c->o
            gt   = true_map[fname]            # chưa chắc c->o hay o->c

            # 1) GT là c->o: so trực tiếp
            match_direct = (pred == gt).sum()

            # 2) GT là o->c: đảo về c->o rồi so
            if set(gt) == set(range(len(gt))):
                gt_as_c2o = np.array(invert_perm_o2c_to_c2o(gt), dtype=int)
                match_inv = (pred == gt_as_c2o).sum()
            else:
                match_inv = -1

            match_best = max(match_direct, match_inv)
            ppa_scores.append(match_best / len(gt))
            if match_best == len(gt):
                correct_count += 1

        total = len(df_true)
        acc = (correct_count / total) * 100 if total else 0.0
        mean_ppa = float(np.mean(ppa_scores)) if ppa_scores else 0.0

        print(f"\nTotal images: {total}")
        print(f"Correctly solved: {correct_count}/{total} ({acc:.2f}%)")
        print(f"Average PPA: {mean_ppa:.4f}")

In [8]:
mst_solver = MSTPuzzleSolver(rows=ROWS, cols=COLS)

runner = PuzzleRunnerLocal(
    solver=mst_solver,
    pieces_list=all_pieces,
    image_dir=image_path,
    output_dir="output_data",
    output_img_dir="output_data/output_images",
    y_true_csv="input_data/Y_test.csv"
)

out_csv = runner.run_all()
runner.evaluate(out_csv)


Local search for images: 100%|██████████| 100/100 [00:02<00:00, 39.95it/s]

Saved output to output_data\output.csv
Solved images saved to output_data/output_images

Total images: 100
Correctly solved: 0/100 (0.00%)
Average PPA: 0.0993



