# 1. 라이브러리 및 기본 설정

In [None]:
import os
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import lightning as L
import warnings
from copy import deepcopy
from PIL import Image, ImageOps
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from torchvision.io import read_image
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from timm.models.vision_transformer import Block, Attention, VisionTransformer
from torchvision.transforms import v2 as transforms
from tqdm import tqdm
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from multiprocessing import Pool

In [None]:
# Global experiment settings shared by the cells below.
config = {
    'seed': 40,        # random_state used for the train/validation split
    'batch_size': 96,  # batch size of every DataLoader
}

# 2. 데이터 준비 및 전처리

## 2.1 Image Augmentation

**주의)** 아래 증강(augmentation) 셀들은 한 번만 실행하는 것을 권장합니다. 어차피 JigsawDataset(train 모드)이 샘플을 불러올 때마다 타일을 무작위로 다시 섞는 코드가 있습니다!

### 2.1.1 원본이미지 추출

In [None]:
def process_image(index, train_df):
    """Undo the tile shuffle of one training image and save the restored image.

    Args:
        index: Positional row index into ``train_df``; also used in the output file name.
        train_df: DataFrame with an ``img_path`` column and columns ``'1'``..``'16'``.
            Column ``str(p)`` holds the 1-based original position of the tile that is
            currently shown at position ``p`` (row-major over the 4x4 grid).

    Returns:
        dict with the ``ID`` and ``img_path`` of the restored image saved in ``./origin/``.
    """
    sample = train_df.iloc[index]
    file_name = sample['img_path'].split('/')[-1]

    # Select the order columns by name (as JigsawDataset does) instead of relying on
    # them being every column after the first two.
    target_positions = [int(sample[str(p)]) for p in range(1, 17)]

    with Image.open('./train/' + file_name) as train_img:
        width, height = train_img.size
        cell_width = width // 4
        cell_height = height // 4

        origin_img = Image.new("RGB", (width, height))

        # Move each tile from where it is now to where it originally belongs.
        # Iterating over current positions avoids an O(16) list.index per tile.
        for current_idx, target_pos in enumerate(target_positions):
            current_row, current_col = divmod(current_idx, 4)
            target_row, target_col = divmod(target_pos - 1, 4)

            tile = train_img.crop((
                current_col * cell_width,
                current_row * cell_height,
                (current_col + 1) * cell_width,
                (current_row + 1) * cell_height,
            ))
            origin_img.paste(tile, (target_col * cell_width, target_row * cell_height))

    origin_name = f'ORIGIN_{index:05}.jpg'
    origin_path = './origin/' + origin_name
    origin_img.save(origin_path)

    return {'ID': origin_name, 'img_path': origin_path}

# Worker-side copy of the training DataFrame, installed once per worker process by
# _init_origin_worker so it is not serialized again for every task.
_ORIGIN_SOURCE_DF = None


def _init_origin_worker(train_df):
    """Process-pool initializer: remember ``train_df`` in this worker process."""
    global _ORIGIN_SOURCE_DF
    _ORIGIN_SOURCE_DF = train_df


def _process_image_at(index):
    """Worker task: restore row ``index`` of the DataFrame installed by the initializer."""
    return process_image(index, _ORIGIN_SOURCE_DF)


def main(train_df, num_processes=16):
    """Restore every shuffled training image in parallel and write ./origin.csv.

    Args:
        train_df: DataFrame read from ./train.csv (``img_path`` plus order columns).
        num_processes: Number of worker processes (default 16, as before).

    The restored images go to ./origin/ (created if missing). The CSV lists their
    ``ID`` and ``img_path`` in the same row order as ``train_df``.
    """
    from concurrent.futures import ProcessPoolExecutor

    # Image.save() does not create missing directories.
    os.makedirs('./origin', exist_ok=True)

    # Before, every task carried its own pickled copy of the whole DataFrame, and tqdm
    # wrapped the already-finished starmap result, so the bar never showed progress.
    # Now the initializer hands train_df to each worker once, and executor.map yields
    # results lazily in input order, so the progress bar is real.
    with ProcessPoolExecutor(
        max_workers=num_processes,
        initializer=_init_origin_worker,
        initargs=(train_df,),
    ) as executor:
        results = list(tqdm(
            executor.map(_process_image_at, range(len(train_df)), chunksize=32),
            total=len(train_df),
        ))

    origin_df = pd.DataFrame(results)
    origin_df.to_csv('./origin.csv', index=False)
    print('./origin.csv 저장완료')


if __name__ == '__main__':
    train_df = pd.read_csv('./train.csv')  # adjust the path to your data layout
    main(train_df)

### 2.1.2 45도 rotate

In [None]:
def rotate_and_shuffle_image_left_45(index, img_path, crop_size=312):
    """Rotate an image 45° counter-clockwise, zoom into its centre, shuffle its 4x4 tiles and save it.

    Args:
        index: Row index of the source image; used in the output file name.
        img_path: Path of an unshuffled (restored) image.
        crop_size: Side in pixels of the centred square cut from the rotated canvas
            before it is resized back to the original size (default: the previously
            hard-coded 312).

    Returns:
        dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)`` holds the
        1-based original position of the tile placed at position ``p``.
    """
    with Image.open(img_path) as src:
        orig_width, orig_height = src.size
        # PIL rotates counter-clockwise for positive angles. expand=True enlarges the
        # canvas so nothing is clipped; the uncovered corners get PIL's fill colour.
        img = src.rotate(45, expand=True)
    rotated_width, rotated_height = img.size

    # Centre crop to cut away most of the empty corners.
    # NOTE(review): the crop is free of corner fill only if the source side is at least
    # about crop_size * sqrt(2) (~441 px for 312). Confirm the image size.
    left = (rotated_width - crop_size) // 2
    top = (rotated_height - crop_size) // 2
    right = (rotated_width + crop_size) // 2
    bottom = (rotated_height + crop_size) // 2
    img = img.crop((left, top, right, bottom))

    # Scale back to the original size. A second "centre-crop to the original size"
    # step used to follow; it was a no-op after this resize and has been removed.
    img = img.resize((orig_width, orig_height), Image.LANCZOS)
    width, height = img.size

    # Cut the 4x4 tiles (row-major).
    tiles = []
    for i in range(16):
        row, col = divmod(i, 4)
        tiles.append(img.crop((col * width // 4, row * height // 4,
                               (col + 1) * width // 4, (row + 1) * height // 4)))

    # Position i of the output shows the tile originally at shuffled_indices[i].
    shuffled_indices = random.sample(range(16), 16)
    shuffled_img = Image.new('RGB', (width, height))
    for i, idx in enumerate(shuffled_indices):
        row, col = divmod(i, 4)
        shuffled_img.paste(tiles[idx], (col * width // 4, row * height // 4))

    aug_name = f'augment_2_left_45_{index:05}.jpg'
    aug_path = './augment_2_left_45/' + aug_name
    shuffled_img.save(aug_path)

    data = {'ID': aug_name, 'img_path': aug_path}
    for i, idx in enumerate(shuffled_indices, 1):
        data[str(i)] = idx + 1  # labels are 1-based
    return data

def main(num_processes=16):
    """Build the 45° counter-clockwise augmentation for every image listed in ./origin.csv.

    Images are written to ./augment_2_left_45/ (created if missing). Labels go to
    ./augment_2_left_45.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')  # adjust the path to your data layout
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_left_45', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image_left_45,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_left_45.csv', index=False)
    print('./augment_2_left_45.csv 저장완료')


if __name__ == '__main__':
    main()


### 2.1.3 90도 rotate

In [None]:
def rotate_and_shuffle_image(index, img_path):
    """Rotate an image 90° counter-clockwise, shuffle its 4x4 tiles and save the result.

    Returns a dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)``
    holds the 1-based original position of the tile placed at position ``p``.
    """
    rotated = Image.open(img_path).rotate(90, expand=True)

    w, h = rotated.size
    tile_w, tile_h = w // 4, h // 4

    def box(slot):
        # Pixel box of grid slot `slot` (row-major over the 4x4 grid).
        r, c = divmod(slot, 4)
        return (c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h)

    pieces = [rotated.crop(box(slot)) for slot in range(16)]
    perm = random.sample(range(16), 16)

    canvas = Image.new('RGB', (w, h))
    for slot, src in enumerate(perm):
        canvas.paste(pieces[src], box(slot)[:2])

    aug_name = f'augment_2_left_90_{index:05}.jpg'
    aug_path = './augment_2_left_90/' + aug_name
    canvas.save(aug_path)

    record = {'ID': aug_name, 'img_path': aug_path}
    record.update({str(slot): src + 1 for slot, src in enumerate(perm, 1)})
    return record

def main(num_processes=16):
    """Build the 90° counter-clockwise augmentation for every image listed in ./origin.csv.

    Images are written to ./augment_2_left_90/ (created if missing). Labels go to
    ./augment_2_left_90.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')  # adjust the path to your data layout
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_left_90', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_left_90.csv', index=False)
    print('./augment_2_left_90.csv 저장완료')


if __name__ == '__main__':
    main()

### 2.1.4 135도 rotate

In [None]:
def rotate_and_shuffle_image_left_135(index, img_path, crop_size=312):
    """Rotate an image 135° counter-clockwise, zoom into its centre, shuffle its 4x4 tiles and save it.

    Args:
        index: Row index of the source image; used in the output file name.
        img_path: Path of an unshuffled (restored) image.
        crop_size: Side in pixels of the centred square cut from the rotated canvas
            before it is resized back to the original size (default: the previously
            hard-coded 312).

    Returns:
        dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)`` holds the
        1-based original position of the tile placed at position ``p``.
    """
    with Image.open(img_path) as src:
        orig_width, orig_height = src.size
        # Rotate 135° counter-clockwise (the old comment wrongly said 45°). expand=True
        # enlarges the canvas so nothing is clipped; uncovered corners get PIL's fill colour.
        img = src.rotate(135, expand=True)
    rotated_width, rotated_height = img.size

    # Centre crop to cut away most of the empty corners.
    # NOTE(review): the crop is free of corner fill only if the source side is at least
    # about crop_size * sqrt(2) (~441 px for 312). Confirm the image size.
    left = (rotated_width - crop_size) // 2
    top = (rotated_height - crop_size) // 2
    right = (rotated_width + crop_size) // 2
    bottom = (rotated_height + crop_size) // 2
    img = img.crop((left, top, right, bottom))

    # Scale back to the original size. A second "centre-crop to the original size"
    # step used to follow; it was a no-op after this resize and has been removed.
    img = img.resize((orig_width, orig_height), Image.LANCZOS)
    width, height = img.size

    # Cut the 4x4 tiles (row-major).
    tiles = []
    for i in range(16):
        row, col = divmod(i, 4)
        tiles.append(img.crop((col * width // 4, row * height // 4,
                               (col + 1) * width // 4, (row + 1) * height // 4)))

    # Position i of the output shows the tile originally at shuffled_indices[i].
    shuffled_indices = random.sample(range(16), 16)
    shuffled_img = Image.new('RGB', (width, height))
    for i, idx in enumerate(shuffled_indices):
        row, col = divmod(i, 4)
        shuffled_img.paste(tiles[idx], (col * width // 4, row * height // 4))

    aug_name = f'augment_2_left_135_{index:05}.jpg'
    aug_path = './augment_2_left_135/' + aug_name
    shuffled_img.save(aug_path)

    data = {'ID': aug_name, 'img_path': aug_path}
    for i, idx in enumerate(shuffled_indices, 1):
        data[str(i)] = idx + 1  # labels are 1-based
    return data

def main(num_processes=16):
    """Build the 135° counter-clockwise augmentation for every image listed in ./origin.csv.

    Images are written to ./augment_2_left_135/ (created if missing). Labels go to
    ./augment_2_left_135.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')  # adjust the path to your data layout
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_left_135', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image_left_135,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_left_135.csv', index=False)
    print('./augment_2_left_135.csv 저장완료')


if __name__ == '__main__':
    main()

### 2.1.5 180도 rotate

In [None]:
def rotate_and_shuffle_image_180(index, img_path):
    """Rotate an image by 180°, shuffle its 4x4 tiles and save the result.

    Returns a dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)``
    holds the 1-based original position of the tile placed at position ``p``.
    """
    rotated = Image.open(img_path).rotate(180, expand=True)

    w, h = rotated.size
    tile_w, tile_h = w // 4, h // 4

    def box(slot):
        # Pixel box of grid slot `slot` (row-major over the 4x4 grid).
        r, c = divmod(slot, 4)
        return (c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h)

    pieces = [rotated.crop(box(slot)) for slot in range(16)]
    perm = random.sample(range(16), 16)

    canvas = Image.new('RGB', (w, h))
    for slot, src in enumerate(perm):
        canvas.paste(pieces[src], box(slot)[:2])

    aug_name = f'augment_2_180_{index:05}.jpg'
    aug_path = './augment_2_180/' + aug_name
    canvas.save(aug_path)

    record = {'ID': aug_name, 'img_path': aug_path}
    record.update({str(slot): src + 1 for slot, src in enumerate(perm, 1)})
    return record

def main_180(num_processes=16):
    """Build the 180° augmentation for every image listed in ./origin.csv.

    Images are written to ./augment_2_180/ (created if missing). Labels go to
    ./augment_2_180.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_180', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image_180,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_180.csv', index=False)
    print('./augment_2_180.csv 저장완료')


if __name__ == '__main__':
    main_180()

### 2.1.6 225도 rotate

In [None]:
def rotate_and_shuffle_image_right_135(index, img_path, crop_size=312):
    """Rotate an image 135° clockwise (= 225° counter-clockwise), zoom in, shuffle its 4x4 tiles and save it.

    Args:
        index: Row index of the source image; used in the output file name.
        img_path: Path of an unshuffled (restored) image.
        crop_size: Side in pixels of the centred square cut from the rotated canvas
            before it is resized back to the original size (default: the previously
            hard-coded 312).

    Returns:
        dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)`` holds the
        1-based original position of the tile placed at position ``p``.
    """
    with Image.open(img_path) as src:
        orig_width, orig_height = src.size
        # A negative angle rotates clockwise in PIL (the old comment wrongly said
        # "left 45°"). expand=True enlarges the canvas so nothing is clipped.
        img = src.rotate(-135, expand=True)
    rotated_width, rotated_height = img.size

    # Centre crop to cut away most of the empty corners.
    # NOTE(review): the crop is free of corner fill only if the source side is at least
    # about crop_size * sqrt(2) (~441 px for 312). Confirm the image size.
    left = (rotated_width - crop_size) // 2
    top = (rotated_height - crop_size) // 2
    right = (rotated_width + crop_size) // 2
    bottom = (rotated_height + crop_size) // 2
    img = img.crop((left, top, right, bottom))

    # Scale back to the original size. A second "centre-crop to the original size"
    # step used to follow; it was a no-op after this resize and has been removed.
    img = img.resize((orig_width, orig_height), Image.LANCZOS)
    width, height = img.size

    # Cut the 4x4 tiles (row-major).
    tiles = []
    for i in range(16):
        row, col = divmod(i, 4)
        tiles.append(img.crop((col * width // 4, row * height // 4,
                               (col + 1) * width // 4, (row + 1) * height // 4)))

    # Position i of the output shows the tile originally at shuffled_indices[i].
    shuffled_indices = random.sample(range(16), 16)
    shuffled_img = Image.new('RGB', (width, height))
    for i, idx in enumerate(shuffled_indices):
        row, col = divmod(i, 4)
        shuffled_img.paste(tiles[idx], (col * width // 4, row * height // 4))

    aug_name = f'augment_2_right_135_{index:05}.jpg'
    aug_path = './augment_2_right_135/' + aug_name
    shuffled_img.save(aug_path)

    data = {'ID': aug_name, 'img_path': aug_path}
    for i, idx in enumerate(shuffled_indices, 1):
        data[str(i)] = idx + 1  # labels are 1-based
    return data

def main(num_processes=16):
    """Build the 135° clockwise (225° counter-clockwise) augmentation for ./origin.csv.

    Images are written to ./augment_2_right_135/ (created if missing). Labels go to
    ./augment_2_right_135.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')  # adjust the path to your data layout
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_right_135', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image_right_135,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_right_135.csv', index=False)
    print('./augment_2_right_135.csv 저장완료')


if __name__ == '__main__':
    main()

### 2.1.7 270도 rotate

In [None]:
def rotate_and_shuffle_image_right_90(index, img_path):
    """Rotate an image 90° clockwise, shuffle its 4x4 tiles and save the result.

    Returns a dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)``
    holds the 1-based original position of the tile placed at position ``p``.
    """
    # A negative angle rotates clockwise in PIL.
    rotated = Image.open(img_path).rotate(-90, expand=True)

    w, h = rotated.size
    tile_w, tile_h = w // 4, h // 4

    def box(slot):
        # Pixel box of grid slot `slot` (row-major over the 4x4 grid).
        r, c = divmod(slot, 4)
        return (c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h)

    pieces = [rotated.crop(box(slot)) for slot in range(16)]
    perm = random.sample(range(16), 16)

    canvas = Image.new('RGB', (w, h))
    for slot, src in enumerate(perm):
        canvas.paste(pieces[src], box(slot)[:2])

    aug_name = f'augment_2_right_90_{index:05}.jpg'
    aug_path = './augment_2_right_90/' + aug_name
    canvas.save(aug_path)

    record = {'ID': aug_name, 'img_path': aug_path}
    record.update({str(slot): src + 1 for slot, src in enumerate(perm, 1)})
    return record

def main_right_90(num_processes=16):
    """Build the 90° clockwise augmentation for every image listed in ./origin.csv.

    Images are written to ./augment_2_right_90/ (created if missing). Labels go to
    ./augment_2_right_90.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_right_90', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image_right_90,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_right_90.csv', index=False)
    print('./augment_2_right_90.csv 저장완료')


if __name__ == '__main__':
    main_right_90()

### 2.1.8 315도 rotate

In [None]:
from PIL import Image, ImageOps
import random
from multiprocessing import Pool
import pandas as pd
from tqdm import tqdm

def rotate_and_shuffle_image_right_45(index, img_path, crop_size=312):
    """Rotate an image 45° clockwise (= 315° counter-clockwise), zoom in, shuffle its 4x4 tiles and save it.

    Args:
        index: Row index of the source image; used in the output file name.
        img_path: Path of an unshuffled (restored) image.
        crop_size: Side in pixels of the centred square cut from the rotated canvas
            before it is resized back to the original size (default: the previously
            hard-coded 312).

    Returns:
        dict with ``ID``, ``img_path`` and keys ``'1'``..``'16'``. Key ``str(p)`` holds the
        1-based original position of the tile placed at position ``p``.
    """
    with Image.open(img_path) as src:
        orig_width, orig_height = src.size
        # A negative angle rotates clockwise in PIL (the old comment wrongly said "left").
        # expand=True enlarges the canvas so nothing is clipped.
        img = src.rotate(-45, expand=True)
    rotated_width, rotated_height = img.size

    # Centre crop to cut away most of the empty corners.
    # NOTE(review): the crop is free of corner fill only if the source side is at least
    # about crop_size * sqrt(2) (~441 px for 312). Confirm the image size.
    left = (rotated_width - crop_size) // 2
    top = (rotated_height - crop_size) // 2
    right = (rotated_width + crop_size) // 2
    bottom = (rotated_height + crop_size) // 2
    img = img.crop((left, top, right, bottom))

    # Scale back to the original size. A second "centre-crop to the original size"
    # step used to follow; it was a no-op after this resize and has been removed.
    img = img.resize((orig_width, orig_height), Image.LANCZOS)
    width, height = img.size

    # Cut the 4x4 tiles (row-major).
    tiles = []
    for i in range(16):
        row, col = divmod(i, 4)
        tiles.append(img.crop((col * width // 4, row * height // 4,
                               (col + 1) * width // 4, (row + 1) * height // 4)))

    # Position i of the output shows the tile originally at shuffled_indices[i].
    shuffled_indices = random.sample(range(16), 16)
    shuffled_img = Image.new('RGB', (width, height))
    for i, idx in enumerate(shuffled_indices):
        row, col = divmod(i, 4)
        shuffled_img.paste(tiles[idx], (col * width // 4, row * height // 4))

    aug_name = f'augment_2_right_45_{index:05}.jpg'
    aug_path = './augment_2_right_45/' + aug_name
    shuffled_img.save(aug_path)

    data = {'ID': aug_name, 'img_path': aug_path}
    for i, idx in enumerate(shuffled_indices, 1):
        data[str(i)] = idx + 1  # labels are 1-based
    return data

def main(num_processes=16):
    """Build the 45° clockwise (315° counter-clockwise) augmentation for ./origin.csv.

    Images are written to ./augment_2_right_45/ (created if missing). Labels go to
    ./augment_2_right_45.csv in ./origin.csv row order.

    Args:
        num_processes: Number of worker processes (default 16, as before).
    """
    from concurrent.futures import ProcessPoolExecutor

    origin_df = pd.read_csv('./origin.csv')  # adjust the path to your data layout
    # Image.save() does not create missing directories.
    os.makedirs('./augment_2_right_45', exist_ok=True)

    # executor.map yields results lazily in input order, so tqdm reports real progress.
    # Pool.starmap only returned once all work was done, so the old bar jumped 0 -> 100%.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(rotate_and_shuffle_image_right_45,
                         origin_df.index, origin_df['img_path'], chunksize=32),
            total=len(origin_df),
        ))

    aug_df = pd.DataFrame(results)
    aug_df.to_csv('./augment_2_right_45.csv', index=False)
    print('./augment_2_right_45.csv 저장완료')


if __name__ == '__main__':
    main()

## 2.2 Data Preprocessing

In [None]:
def alternate_rows(df1, df2, df3, df4, df5, df6, df7, df8):
    """Interleave eight DataFrames row by row.

    The result is ``df1[0], df2[0], ..., df8[0], df1[1], df2[1], ...`` with a fresh
    RangeIndex. Only the first ``len(df1)`` rows of each frame are used.

    Raises:
        IndexError: if any of ``df2``..``df8`` has fewer rows than ``df1``.
    """
    frames = [df1, df2, df3, df4, df5, df6, df7, df8]
    n_rows = len(df1)
    for position, frame in enumerate(frames[1:], start=2):
        if len(frame) < n_rows:
            raise IndexError(
                f'df{position} has {len(frame)} rows, fewer than len(df1)={n_rows}'
            )

    # Stack the frames so that frame k occupies rows k*n_rows .. (k+1)*n_rows - 1, then
    # reorder with a single vectorised take instead of ~8*n_rows .iloc lookups.
    stacked = pd.concat([frame.iloc[:n_rows] for frame in frames], ignore_index=True)
    interleaved = np.arange(len(frames) * n_rows).reshape(len(frames), n_rows).T.ravel()
    return stacked.iloc[interleaved].reset_index(drop=True)

In [None]:
# Read the original training CSV and the seven rotation-augmented CSVs, then split.
train_df = pd.read_csv('./train.csv')
aug_df_left_45 = pd.read_csv('./augment_2_left_45.csv')
aug_df_left_90 = pd.read_csv('./augment_2_left_90.csv')
aug_df_left_135 = pd.read_csv('./augment_2_left_135.csv')
aug_df_180 = pd.read_csv('./augment_2_180.csv')
aug_df_right_45 = pd.read_csv('./augment_2_right_45.csv')
aug_df_right_90 = pd.read_csv('./augment_2_right_90.csv')
aug_df_right_135 = pd.read_csv('./augment_2_right_135.csv')
variant_dfs = [
    train_df,
    aug_df_left_45,
    aug_df_left_90,
    aug_df_left_135,
    aug_df_180,
    aug_df_right_45,
    aug_df_right_90,
    aug_df_right_135,
]
train_df = alternate_rows(*variant_dfs)

# Row i of every CSV comes from the same source image (origin.csv is built row by row
# from train.csv, and each augmentation row by row from origin.csv). After interleaving,
# rows [8*i, 8*i + 8) are therefore the 8 variants of image i. A plain row-level split
# puts variants of one image in both train and val, which leaks content and inflates
# validation scores. Split by source image instead.
n_variants = len(variant_dfs)
source_ids = np.arange(len(train_df)) // n_variants
_, val_sources = train_test_split(np.unique(source_ids), test_size=0.2, random_state=config['seed'])
is_val = np.isin(source_ids, val_sources)
# train_test_split used to return shuffled rows. Keep both splits shuffled so batches
# mix different source images.
val_df = train_df[is_val].sample(frac=1, random_state=config['seed'])
train_df = train_df[~is_val].sample(frac=1, random_state=config['seed'])

test_df = pd.read_csv('./test.csv')

In [None]:
# NOTE(review): dead code. This triple-quoted string is never executed. It looks like an
# earlier experiment (a 4-way interleave plus a hand-built validation set). `n` and
# `remainder` are not defined in the visible code, and alternate_rows takes exactly
# 8 frames, so the snippet would not run as written.
'''
train_df = alternate_rows(train_df, aug_df_left_90, aug_df_180, aug_df_right_90)
train_df2 = pd.read_csv('./train.csv')
train_df2 = train_df2.sample(frac=1).reset_index(drop=True)

val_df_left_90 = aug_df_left_90.sample(n)
val_df_180 = aug_df_180.sample(n)
val_df_right_90 = aug_df_right_90.sample(n + remainder)
val_df2 = pd.concat([val_df_left_90, val_df_180, val_df_right_90], ignore_index=True)
'''

# 3. 커스텀 데이터셋 및 데이터 로더 정의

In [None]:
# Custom Dataset for 4x4 jigsaw-puzzle images.
class JigsawDataset(Dataset):
    """Dataset of 4x4 jigsaw-puzzle images.

    ``df`` must provide an ``img_path`` column. For the 'train' and 'val' modes it must
    also provide columns ``'1'``..``'16'``. Column ``str(p)`` holds the 1-based original
    position of the tile shown at position ``p`` of the stored image (positions are
    row-major over the 4x4 grid).

    Modes:
        'train': undo the stored shuffle, then re-shuffle randomly on every access.
            Returns the source image, the re-shuffled image, the new order, its
            adjacency matrix, and the puzzle score of the new order against the
            identity order.
        'val': return the stored image with its labelled order and adjacency matrix.
        'inference': return only the stored image.

    ``data_path`` is accepted for backward compatibility but is not used.
    """

    _MODES = ('train', 'val', 'inference')
    _ORDER_COLUMNS = [str(i) for i in range(1, 17)]
    # (row offset, col offset, code stored at [this][neighbour], code stored at
    # [neighbour][this]); codes: 1 = up, 2 = down, 3 = left, 4 = right.
    _NEIGHBOURS = ((-1, 0, 1, 2), (1, 0, 2, 1), (0, -1, 3, 4), (0, 1, 4, 3))

    def __init__(self, df, data_path, mode='train'):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.df = df
        self.mode = mode

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # read_image returns a (C, H, W) tensor; work with numpy arrays in every mode.
        image = read_image(row['img_path']).numpy()

        if self.mode == 'inference':
            return {'image': image}

        # 0-based: shuffle_order[p] = original position of the tile shown at position p.
        shuffle_order = row[self._ORDER_COLUMNS].values - 1

        if self.mode == 'val':
            return {
                'image': image,
                'order': shuffle_order,
                'adjacency_matrix': self.get_adjacency_matrix(shuffle_order.tolist()),
            }

        # train: restore the source image, then draw a fresh random shuffle on each call.
        image_src = self.reset_image(image, shuffle_order)
        image_reshuffle, reshuffle_order = self.shuffle_image(image_src)
        return {
            'image_src': image_src,
            'image_reshuffle': image_reshuffle,
            'order': reshuffle_order,
            'adjacency_matrix': self.get_adjacency_matrix(reshuffle_order),
            'score': self.get_score(range(16), reshuffle_order),
        }

    def reset_image(self, image, shuffle_order):
        """Rearrange a shuffled (C, H, W) image back into the original tile order.

        ``shuffle_order[p]`` is the 0-based original position of the tile shown at
        position ``p``.
        """
        c, h, w = image.shape
        block_h, block_w = h // 4, w // 4
        image_src = [[0 for _ in range(4)] for _ in range(4)]
        for idx, order in enumerate(shuffle_order):
            h_idx, w_idx = divmod(order, 4)
            h_idx_shuffle, w_idx_shuffle = divmod(idx, 4)
            image_src[h_idx][w_idx] = image[:, block_h * h_idx_shuffle: block_h * (h_idx_shuffle + 1),
                                            block_w * w_idx_shuffle: block_w * (w_idx_shuffle + 1)]
        return np.concatenate([np.concatenate(image_row, -1) for image_row in image_src], -2)

    def shuffle_image(self, image):
        """Randomly shuffle the 4x4 tiles of a (C, H, W) image.

        Returns:
            (shuffled image, order), where ``order[p]`` is the 0-based original position
            of the tile placed at position ``p``.
        """
        c, h, w = image.shape
        block_h, block_w = h // 4, w // 4
        shuffle_order = list(range(0, 16))
        random.shuffle(shuffle_order)
        image_shuffle = [[0 for _ in range(4)] for _ in range(4)]
        for idx, order in enumerate(shuffle_order):
            h_idx, w_idx = divmod(order, 4)
            h_idx_shuffle, w_idx_shuffle = divmod(idx, 4)
            image_shuffle[h_idx_shuffle][w_idx_shuffle] = image[:, block_h * h_idx: block_h * (h_idx + 1),
                                                                block_w * w_idx: block_w * (w_idx + 1)]
        image_shuffle = np.concatenate([np.concatenate(image_row, -1) for image_row in image_shuffle], -2)
        return image_shuffle, shuffle_order

    def get_adjacency_matrix(self, order):
        """Build the 16x16 adjacency matrix of the tiles as displayed in a shuffled image.

        Args:
            order: sequence of length 16; ``order[p]`` is the 0-based original position
                of the tile displayed at position ``p``.

        Returns:
            int array ``adj`` of shape (16, 16), indexed by displayed position.
            ``adj[a][b]`` is 1/2/3/4 when the tile at ``b`` is originally directly
            above/below/left of/right of the tile at ``a``, and 0 otherwise.
        """
        order = list(order)
        # slot_of[o] = displayed position of the tile whose original position is o.
        slot_of = {orig: slot for slot, orig in enumerate(order)}
        adj_matrix = np.zeros((16, 16), dtype=int)
        for slot, orig in enumerate(order):
            row, col = divmod(orig, 4)
            # The old code tested (-1, 0) twice, so its "down" branch was unreachable; the
            # matrix was still complete only thanks to the symmetric writes.
            for d_row, d_col, code_here, code_there in self._NEIGHBOURS:
                n_row, n_col = row + d_row, col + d_col
                if not (0 <= n_row < 4 and 0 <= n_col < 4):
                    continue
                other = slot_of[n_row * 4 + n_col]
                adj_matrix[slot][other] = code_here
                adj_matrix[other][slot] = code_there
        return adj_matrix

    def get_score(self, order_true, order_pred):
        """Puzzle score: mean of the 1x1, 2x2, 3x3 and 4x4 sub-grid accuracies.

        A kxk accuracy is the fraction of all kxk windows of the 4x4 grid in which
        ``order_pred`` matches ``order_true`` exactly.
        """
        puzzle_a = np.array(order_true, dtype=int).reshape(4, 4)
        puzzle_s = np.array(order_pred, dtype=int).reshape(4, 4)

        accuracies = {'1x1': np.mean(puzzle_a == puzzle_s)}
        for size in range(2, 5):
            starts = [(i, j) for i in range(5 - size) for j in range(5 - size)]
            correct_count = sum(
                np.array_equal(puzzle_a[r:r + size, c:c + size], puzzle_s[r:r + size, c:c + size])
                for r, c in starts
            )
            accuracies[f'{size}x{size}'] = correct_count / len(starts)

        return (accuracies['1x1'] + accuracies['2x2'] + accuracies['3x3'] + accuracies['4x4']) / 4.

In [None]:
# Collate function that turns a list of JigsawDataset samples into a batch.
class JigsawCollateFn:
    """Collate JigsawDataset samples into batched tensors.

    Each image (CHW array) is converted to an HWC uint8 PIL image and passed through
    ``transform``. ``mode`` must match the dataset mode: 'train', 'val' or 'inference'.

    The adjacency matrices are returned under the key ``'adjacency_matrx'`` (sic); the
    spelling is kept because downstream code may read that key.
    """

    _MODES = ('train', 'val', 'inference')

    def __init__(self, transform, mode):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode
        self.transform = transform

    def _pixel_values(self, batch, key):
        """Transform ``data[key]`` of every sample and stack the results."""
        return torch.stack([
            self.transform(Image.fromarray(data[key].astype(np.uint8).transpose(1, 2, 0)))
            for data in batch
        ])

    def __call__(self, batch):
        if self.mode == 'inference':
            return {'pixel_values': self._pixel_values(batch, 'image')}

        image_key = 'image_reshuffle' if self.mode == 'train' else 'image'
        pixel_values = self._pixel_values(batch, image_key)
        # Go through one int64 numpy array: building a tensor from a list of ndarrays is
        # slow (torch warns about it), and the 'val' orders may be object-dtype arrays.
        order = torch.from_numpy(np.asarray([data['order'] for data in batch], dtype=np.int64))
        adjacency_matrx = torch.from_numpy(
            np.asarray([data['adjacency_matrix'] for data in batch], dtype=np.int64)
        )
        return {
            'pixel_values': pixel_values,
            'order': order,
            'adjacency_matrx': adjacency_matrx,
        }


In [None]:
# Bicubic resize to 256x256, then ImageNet mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize(size=(256,256), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = JigsawDataset(df=train_df, data_path='', mode='train')
val_dataset = JigsawDataset(df=val_df, data_path='', mode='val')
pred_dataset = JigsawDataset(df=test_df, data_path='', mode='inference')

# NOTE(review): the collate functions capture *this* transform object. `transform` is
# reassigned later (timm create_transform), but these DataLoaders keep using the one
# above unless they are rebuilt. Confirm which transform is intended.
# shuffle=True re-orders the training samples every epoch; without it every epoch
# walks train_df in the same fixed order.
train_dataloader = DataLoader(train_dataset, collate_fn=JigsawCollateFn(transform, 'train'), batch_size=config['batch_size'],
                              shuffle=True, num_workers=16, prefetch_factor=3)
val_dataloader = DataLoader(val_dataset, collate_fn=JigsawCollateFn(transform, 'val'), batch_size=config['batch_size'],
                            num_workers=16, prefetch_factor=3)
pred_dataloader = DataLoader(pred_dataset, collate_fn=JigsawCollateFn(transform, 'inference'), batch_size=config['batch_size'],
                             num_workers=16, prefetch_factor=3)

# 4. Model architecture 정의

## 4.1. Vision Transformer (ViT)

In [None]:
def attention_forward(self, x, attn_bias=None):
    """Multi-head self-attention with an optional additive bias on the attention logits.

    Replacement for timm's ``Attention.forward``.

    Args:
        x: tokens of shape (B, N, C).
        attn_bias: optional tensor broadcastable to (B, num_heads, N, N). It is added to
            the scaled logits before the softmax (e.g. large negative values mask pairs).

    Returns:
        Tensor of shape (B, N, C).
    """
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
    q, k = self.q_norm(q), self.k_norm(k)

    q = q * self.scale
    attn = q @ k.transpose(-2, -1)  # (B, num_heads, N, N)
    if attn_bias is not None:
        # Bug fix: the sum used to be computed and thrown away (`attn + attn_bias`),
        # so the bias never reached the softmax.
        attn = attn + attn_bias
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = attn @ v

    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
# Monkey-patch timm's Attention class so every attention layer accepts an optional attn_bias.
Attention.forward = attention_forward

In [None]:
def block_forward(self, x_and_attn_bias):
    """Transformer block forward that threads an attention bias through the block stack.

    Input and output are both ``(x, attn_bias)`` tuples, so the blocks can be chained
    by a container that passes a single value between layers. ``attn_bias`` is handed
    to this block's attention and returned unchanged for the next block.
    """
    hidden, bias = x_and_attn_bias
    attn_out = self.attn(self.norm1(hidden), bias)
    hidden = hidden + self.drop_path1(self.ls1(attn_out))
    mlp_out = self.mlp(self.norm2(hidden))
    hidden = hidden + self.drop_path2(self.ls2(mlp_out))
    return hidden, bias
Block.forward = block_forward

In [None]:
def vision_transformer_forward_features(self, x, embed_bias=None, attn_bias=None):
    """Run the ViT trunk with optional additive biases.

    Args:
        x: input image batch.
        embed_bias: optional tensor added to the token embeddings right after the
            position embedding.
        attn_bias: optional bias forwarded, via the block tuples, to every attention layer.

    Returns:
        Token features after the final ``self.norm``.
    """
    tokens = self._pos_embed(self.patch_embed(x))
    if embed_bias is not None:
        tokens = tokens + embed_bias
    tokens = self.norm_pre(self.patch_drop(tokens))
    tokens, _ = self.blocks((tokens, attn_bias))
    return self.norm(tokens)
VisionTransformer.forward_features = vision_transformer_forward_features

In [None]:
def vision_transformer_forward(self, x, embed_bias=None, attn_bias=None):
    """Return ``forward_features`` output directly (no head), forwarding both biases."""
    return self.forward_features(x, embed_bias, attn_bias)
VisionTransformer.forward = vision_transformer_forward

In [None]:
# Pretrained ViT backbone without a classification head (num_classes=0).
model = timm.create_model('vit_medium_patch16_gap_256', pretrained=True, num_classes=0)

model_config = {
    'image_size':256,
    'patch_size':16,
    'hidden_size':512,
    'num_attention_heads':8,
}

# Eval transform derived from the model's pretrained config (input size, interpolation,
# mean/std), using timm's documented resolve_data_config({}, model=...) form.
transform_config = timm.data.resolve_data_config({}, model=model)
# Removing crop_pct does NOT disable cropping: timm then falls back to its default
# crop_pct (0.875), i.e. resize to ~292 px and centre-crop 256, which cuts into the
# border puzzle tiles. crop_pct=1.0 with 'squash' resizes the whole image straight to
# the input size with no crop, like the explicit Resize((256, 256)) transform defined earlier.
transform_config['crop_pct'] = 1.0
transform_config['crop_mode'] = 'squash'

transform = timm.data.create_transform(
    **transform_config
)

## 4.2. JigsawElectra 모델 클래스

In [None]:
# Custom two-stage model architecture for solving 4x4 jigsaw puzzles.
class JigsawElectra(nn.Module):
    """Two-stage 4x4 jigsaw solver that shares one ViT backbone between stages.

    Stage 1 (``local_forward``): the backbone encodes the shuffled image, two
    linear projections of its tokens are compared pairwise, and the token-pair
    scores are pooled per piece pair and classified into 5 connection classes,
    i.e. a prediction of how (or whether) every pair of pieces is adjacent.

    Stage 2 (``global_forward``): the stage-1 predictions become extra inputs to
    the same backbone. Per-piece "piece-type" embeddings (cross, corner, edge,
    ... shapes) are added to the positional embedding, and per-pair
    "connect-type" embeddings are passed as an attention bias. Pooled piece
    features are then classified into 16 classes (the target order).

    All weights except the stage-specific heads are shared; the caller combines
    the two losses for a joint update.

    The puzzle grid is fixed at 4x4 pieces. ``config`` must provide
    ``image_size``, ``patch_size``, ``hidden_size`` and
    ``num_attention_heads``; ``image_size / 4`` must be a multiple of
    ``patch_size``.
    """

    # Pieces along each side of the puzzle (always 4 x 4 = 16 pieces).
    _GRID = 4

    def __init__(self, model, config):
        super().__init__()
        # Every config entry becomes an attribute (image_size, patch_size, ...).
        for k, v in config.items():
            setattr(self, k, v)
        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        # Backbone patches along one side of a single puzzle piece.
        self.num_patch_per_block = int(self.image_size / self._GRID / self.patch_size)
        # Backbone tokens covering one piece. This is the pooling window of both
        # conv heads and the expansion factor of the attention bias. The original
        # hard-coded int(image_size / 16), which equals this value only for
        # image_size=256 with patch_size=16.
        self.tokens_per_piece = self.num_patch_per_block ** 2
        self.model = model

        # One learned vector per piece slot (row-major over the 4x4 grid).
        self.pos_emb = nn.Parameter(torch.randn(16, self.hidden_size))
        # Piece-shape embedding added to pos_emb in stage 2; row 0 ("unknown") is padding and zeroed.
        self.piece_type_emb = nn.Embedding(10, self.hidden_size, padding_idx=0)
        self.piece_type_emb.weight.data[0,:]=0
        self.piece_type_emb.weight.data = self.piece_type_emb.weight.data*0.1
        # One bias per attention head for each of the 5 connection classes; row 0 is padding.
        self.connect_type_emb = nn.Embedding(5, self.num_attention_heads, padding_idx=0)
        self.connect_type_emb.weight.data[0,:]=0
        self.connect_type_emb.weight.data = self.connect_type_emb.weight.data*0.1

        # Stage-1 head: pairwise token scores pooled per piece pair -> 5 connection classes.
        self.local_linear1 = nn.LazyLinear(self.hidden_size)
        self.local_linear2 = nn.LazyLinear(self.hidden_size)
        self.local_conv = nn.Conv2d(self.num_attention_heads, self.num_attention_heads, self.tokens_per_piece, self.tokens_per_piece)
        self.local_clf = nn.Sequential(
            nn.LazyLinear(self.num_attention_heads),
            nn.Tanh(),
            nn.LazyLinear(5),
        )

        # Stage-2 head: tokens pooled per piece -> 16 order classes.
        self.global_conv = nn.Conv1d(self.hidden_size, self.hidden_size, self.tokens_per_piece, self.tokens_per_piece)
        self.global_clf = nn.Sequential(
            nn.LazyLinear(self.hidden_size),
            nn.Tanh(),
            nn.LazyLinear(16),
        )

    def _transpose(self, x):
        """Split hidden states into heads and regroup tokens piece by piece.

        Args:
            x: (B, N, hidden_size) tokens, assumed to be in row-major patch order
                (the order the positional bias is built for).

        Returns:
            (B, num_attention_heads, N, attention_head_size). Each run of
            ``tokens_per_piece`` consecutive tokens belongs to one piece; piece
            groups are ordered column-major over the 4x4 grid (group g is the
            piece at row g % 4, column g // 4).
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape).permute(0, 2, 1, 3)
        b, h, l, d = x.shape
        # Horizontal strips of P patches: entry r * GRID + c is patch row r, piece column c.
        strips = x.reshape(b, h, -1, self.num_patch_per_block, d)
        # One chunk per patch row; concatenating along the patch axis gathers each
        # piece column's strips from every row. The split size is the grid width
        # (the original used num_patch_per_block, equal only when it is 4).
        rows = strips.split(self._GRID, 2)
        return torch.cat(rows, 3).reshape(b, h, l, d)

    def _expand_pieces_to_patches(self, piece_values):
        """Broadcast one vector per piece to every backbone patch of that piece.

        Args:
            piece_values: (..., 16, D) with pieces in row-major order over the grid.

        Returns:
            (..., N, D) with N = 16 * tokens_per_piece in row-major patch order.
        """
        *lead, _, dim = piece_values.shape
        p = self.num_patch_per_block
        grid = piece_values.reshape(*lead, self._GRID, 1, self._GRID, 1, dim)
        grid = grid.expand(*lead, self._GRID, p, self._GRID, p, dim)
        return grid.reshape(*lead, -1, dim)

    def _connect_attention_bias(self, connect_type):
        """Turn (B, 16, 16) connection classes into a (B, heads, N, N) attention bias.

        Entry [b, :, i * T + s, j * T + t] (T = tokens_per_piece) holds the
        embedding of connect_type[b, i, j]: tokens are laid out piece by piece.
        NOTE(review): the backbone attention indexes tokens in patch order, and
        connect_type's piece indices follow ``_transpose``'s column-major
        grouping -- confirm the layouts agree before relying on this bias.
        """
        bias = self.connect_type_emb(connect_type)  # (B, 16, 16, heads)
        bias = bias.repeat_interleave(self.tokens_per_piece, dim=1)
        bias = bias.repeat_interleave(self.tokens_per_piece, dim=2)
        return bias.permute(0, 3, 1, 2)

    def local_forward(self, x, label=None):
        """Stage 1: predict a connection class for every ordered pair of pieces.

        Args:
            x: image batch passed to the backbone.
            label: optional integer targets in [0, 5) with B * 16 * 16 entries.

        Returns:
            (logits (B, 16, 16, 5), softmax probabilities, cross-entropy loss or None)
        """
        hidden = self.model(x, embed_bias=self._expand_pieces_to_patches(self.pos_emb))
        left = self._transpose(self.local_linear1(hidden))
        right = self._transpose(self.local_linear2(hidden))
        # (B, heads, N, N) token-pair scores; [i, j] = right_i . left_j
        scores = torch.matmul(left, right.transpose(-1, -2)).transpose(-1, -2)
        # Pool each tokens_per_piece x tokens_per_piece tile into one piece-pair score.
        pair_features = self.local_conv(scores).permute(0, 2, 3, 1)
        logits = self.local_clf(pair_features)
        probs = logits.softmax(dim=-1)
        loss = None
        if label is not None:
            loss = F.cross_entropy(logits.reshape(-1, 5), label.reshape(-1))
        return logits, probs, loss

    def global_forward(self, x, piece_type=None, connect_type=None, label=None):
        """Stage 2: predict the target order class of every piece.

        Args:
            x: image batch passed to the backbone.
            piece_type: optional (B, 16) piece-type ids in [0, 10).
            connect_type: optional (B, 16, 16) connection classes in [0, 5).
            label: optional integer targets in [0, 16) with B * 16 entries.

        Returns:
            (logits (B, 16, 16), softmax probabilities, cross-entropy loss or None)
        """
        pos_emb = self._expand_pieces_to_patches(self.pos_emb)
        if piece_type is not None:
            b = piece_type.shape[0]
            piece_emb = self.piece_type_emb(piece_type).reshape(b, self._GRID * self._GRID, -1)
            pos_emb = self._expand_pieces_to_patches(piece_emb) + pos_emb

        attn_bias = None
        if connect_type is not None:
            attn_bias = self._connect_attention_bias(connect_type)

        hidden = self.model(
            x,
            embed_bias=pos_emb,
            attn_bias=attn_bias,
        )
        hidden = self._transpose(hidden)
        b, h, l, d = hidden.shape
        hidden = hidden.permute(0, 1, 3, 2).reshape(b, h * d, l)
        # Pool each piece's tokens into one feature vector per piece.
        piece_features = self.global_conv(hidden).permute(0, 2, 1)
        logits = self.global_clf(piece_features)
        probs = logits.softmax(dim=-1)

        loss = None
        if label is not None:
            loss = F.cross_entropy(logits.reshape(-1, 16), label.reshape(-1))
        return logits, probs, loss

# 5. PyTorch Lightning 모듈 정의

In [None]:
# Lightning module that drives training, validation and inference for JigsawElectra.
class LitJigsawElectra(L.LightningModule):
    """Lightning wrapper around JigsawElectra.

    Batches are dicts; the keys read here are 'pixel_values', 'adjacency_matrx'
    (stage-1 targets) and 'order' (stage-2 targets, optional in predict_step).
    """

    # Piece shape keyed by how many of a piece's connection entries fall in
    # classes 1..4 (class-0 entries are ignored). Anything else maps to 0
    # ("unknown").
    _PIECE_TYPE_BY_COUNTS = {
        (0, 1, 0, 1): 1,  # ┌
        (0, 1, 1, 1): 2,  # ㅜ
        (0, 1, 1, 0): 3,  # ㄱ
        (1, 1, 0, 1): 4,  # ㅏ
        (1, 1, 1, 0): 5,  # ㅓ
        (1, 0, 0, 1): 6,  # ㄴ
        (1, 0, 1, 1): 7,  # ㅗ
        (1, 0, 1, 0): 8,  # ┘
        (1, 1, 1, 1): 9,  # +
    }

    def __init__(self, model, config):
        super().__init__()
        self.config = config
        self.jigsaw_electra = JigsawElectra(model, config)
        # Number of solve-and-reassemble passes used by predict_step.
        self.inference_iter = 1
        # (global probs, true order) pairs collected during a validation epoch.
        self.validation_step_outputs = []

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=1e-6)     # tuned learning rate
        return opt

    def training_step(self, batch):
        """Joint step: 0.2 * stage-1 loss + stage-2 loss."""
        x_local, x_local_probs, loss_local = self.jigsaw_electra.local_forward(batch['pixel_values'], batch['adjacency_matrx'])
        # Stage 2 is conditioned on the model's own (detached) stage-1 predictions.
        connect_type = x_local_probs.argmax(-1).detach()
        piece_type = self.connect_to_piece(connect_type).detach()
        x_global, x_global_probs, loss_global = self.jigsaw_electra.global_forward(batch['pixel_values'], piece_type=piece_type, connect_type=connect_type, label=batch['order'])
        loss = loss_local*0.2 + loss_global
        self.log("train_loss_local", loss_local, on_step=True, on_epoch=False)
        self.log("train_loss_global", loss_global, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch):
        """Log both stage losses and stage-1 accuracy; keep stage-2 probs for epoch-end scoring."""
        x_local, x_local_probs, loss_local = self.jigsaw_electra.local_forward(batch['pixel_values'], batch['adjacency_matrx'])
        self.log("val_loss_local", loss_local)
        connect_type = x_local_probs.argmax(-1).detach()
        piece_type = self.connect_to_piece(connect_type).detach()
        local_accuracy = torch.mean(1*(connect_type == batch['adjacency_matrx']), dtype=torch.float32)
        self.log("val_acc_local", local_accuracy)
        x_global, x_global_probs, loss_global = self.jigsaw_electra.global_forward(batch['pixel_values'], piece_type=piece_type, connect_type=connect_type, label=batch['order'])
        self.log("val_loss_global", loss_global)
        self.validation_step_outputs.append((x_global_probs, batch['order']))
        return

    def predict_step(self, batch):
        """Predict piece orders, refining over ``self.inference_iter`` passes.

        Each pass solves the current image and reassembles it with the predicted
        order before the next pass. Orders are composed so the returned
        ``reorder`` always refers to the ORIGINAL input: reorder[b, i] is the
        predicted target slot of the piece found at slot i.

        Returns:
            (last-pass global probs, reorder LongTensor (B, 16) on CPU,
            batch['order'] or None)
        """
        pixel_values = batch['pixel_values']
        label = batch.get('order', None)
        reorder = None
        for _ in range(self.inference_iter):
            _, x_local_probs, _ = self.jigsaw_electra.local_forward(pixel_values)
            connect_type = x_local_probs.argmax(-1).detach()
            piece_type = self.connect_to_piece(connect_type).detach()
            # Both stages must see the same, current image. Previously the
            # untouched batch['pixel_values'] was used here, mixing stage-1
            # predictions on the reordered image with the original one.
            _, x_global_probs, _ = self.jigsaw_electra.global_forward(pixel_values, piece_type=piece_type, connect_type=connect_type)
            step_order = self._probs_to_order(x_global_probs)
            # The piece originally at slot i now sits at reorder[i]; this pass
            # moves it on to step_order[reorder[i]].
            reorder = step_order if reorder is None else torch.gather(step_order, 1, reorder)
            pixel_values = self._reorder_image(pixel_values, step_order)
        return x_global_probs, reorder, label

    def connect_to_piece(self, connect_types):
        """Classify each piece's shape from its row of predicted connection classes.

        Args:
            connect_types: integer tensor (B, 16, 16); row i holds the predicted
                connection class between piece i and every piece.

        Returns:
            LongTensor (B, 16) of piece-type ids in [0, 9] on connect_types' device.

        NOTE(review): before the minlength fix below, types 3, 5 and 8 were
        never produced, so checkpoints trained with the old code likely have
        untrained embeddings for them -- consider fine-tuning.
        """
        device = connect_types.device
        piece_types = []
        for sample in connect_types.detach().cpu():
            row_types = []
            for row in sample:
                # minlength=5 keeps a slot for every class even when the largest
                # class is absent; without it the shapes whose class-4 count is
                # 0 (ㄱ, ㅓ, ┘) could never match.
                counts = torch.bincount(row, minlength=5)[1:5]
                row_types.append(self._PIECE_TYPE_BY_COUNTS.get(tuple(counts.tolist()), 0))
            piece_types.append(row_types)
        return torch.LongTensor(piece_types).to(device)

    def on_validation_epoch_end(self):
        """Score the epoch's stage-2 predictions and log sub-puzzle accuracies."""
        if not self.validation_step_outputs:
            # Nothing was collected this epoch; torch.cat would fail on an empty list.
            return
        order_pred = []
        order_true = []
        for probs, order in self.validation_step_outputs:
            order_pred.append(self._probs_to_order(probs))
            order_true.append(order)
        order_pred = torch.cat(order_pred).detach().cpu().numpy()
        order_true = torch.cat(order_true).detach().cpu().numpy()

        score, accuracies = self._get_score(order_true, order_pred)

        # Accuracy for each sub-puzzle size.
        self.log("val_score_1x1", accuracies['1x1'])
        self.log("val_score_2x2", accuracies['2x2'])
        self.log("val_score_3x3", accuracies['3x3'])
        self.log("val_score_4x4", accuracies['4x4'])
        self.log("val_score", score)

        self.validation_step_outputs.clear()
        return

    def _get_score(self, order_true, order_pred):
        """Mean of exact-match accuracies over 1x1, 2x2, 3x3 and 4x4 sub-puzzles.

        Args:
            order_true, order_pred: integer arrays (num_puzzles, 16).

        Returns:
            (score, {'1x1': ..., '2x2': ..., '3x3': ..., '4x4': ...})
        """
        # Top-left corners of every 2x2 (resp. 3x3) window on the 4x4 grid.
        combinations_2x2 = [(i, j) for i in range(3) for j in range(3)]
        combinations_3x3 = [(i, j) for i in range(2) for j in range(2)]
        accuracies = {}
        accuracies['1x1'] = np.mean(order_true == order_pred)

        for size in range(2, 5):
            correct_count = 0
            total_subpuzzles = 0
            for i in range(len(order_true)):
                puzzle_a = order_true[i].reshape(4, 4)
                puzzle_s = order_pred[i].reshape(4, 4)
                combinations = combinations_2x2 if size == 2 else combinations_3x3 if size == 3 else [(0, 0)]
                for start_row, start_col in combinations:
                    rows = slice(start_row, start_row + size)
                    cols = slice(start_col, start_col + size)
                    if np.array_equal(puzzle_a[rows, cols], puzzle_s[rows, cols]):
                        correct_count += 1
                    total_subpuzzles += 1
            accuracies[f'{size}x{size}'] = correct_count / total_subpuzzles
        score = (accuracies['1x1'] + accuracies['2x2'] + accuracies['3x3'] + accuracies['4x4']) / 4.
        return score, accuracies

    def _probs_to_order(self, probs):
        """Greedily turn (B, 16, 16) probabilities into a permutation per sample.

        Repeatedly picks the highest remaining (slot, class) entry and removes
        that slot and class. Returns a CPU LongTensor (B, 16).
        """
        order = []
        for prob in probs:
            prob = prob.reshape(16,16).clone()
            indices = [-1 for _ in range(16)]
            for _ in range(16):
                i, j = divmod(int(prob.argmax()),16)
                indices[i]=j
                prob[i, :] = float('-inf')
                prob[:, j] = float('-inf')
            order.append(indices)
        order = torch.LongTensor(order)
        return order

    def _reorder_image(self, images, reorders):
        """Reassemble each image by moving the tile at slot i to slot reorders[b][i].

        Works directly on the images' device and dtype; the output starts as
        zeros, so any uncovered area (e.g. when H or W is not divisible by 4)
        stays zero.
        """
        images_reordered = []
        for image, reorder in zip(images, reorders):
            _, h, w = image.shape
            block_h, block_w = h // 4, w // 4
            canvas = torch.zeros_like(image)
            for src, dst in enumerate(reorder.tolist()):
                src_row, src_col = divmod(src, 4)
                dst_row, dst_col = divmod(dst, 4)
                canvas[:, block_h * dst_row : block_h * (dst_row + 1), block_w * dst_col : block_w * (dst_col + 1)] = \
                    image[:, block_h * src_row : block_h * (src_row + 1), block_w * src_col : block_w * (src_col + 1)]
            images_reordered.append(canvas)
        return torch.stack(images_reordered)
        

# 6. Trainer 설정 및 실행

In [None]:
# Keep the 3 best checkpoints by val_score (weights only).
checkpoint_callback = ModelCheckpoint(
    monitor='val_score',
    mode='max',
    dirpath='./checkpoint3_aug/',
    filename='jigsawelectra-vitgap-{epoch:02d}-{val_score:.6f}',
    save_top_k=3,
    save_weights_only=True
)

In [None]:
# Stop when val_score has not improved for 15 validation checks.
earlystopping_callback = EarlyStopping(monitor="val_score", mode="max", patience=15)

In [None]:
# Fresh model on top of the (monkeypatched) timm backbone.
lit_jigsaw_electra = LitJigsawElectra(model, model_config)

In [None]:
# If run after the previous cell, this replaces that model with a saved checkpoint.
lit_jigsaw_electra = LitJigsawElectra.load_from_checkpoint('./checkpoint3_aug/jigsawelectra-vitgap-epoch=11-val_score=0.994698.ckpt',model=model, config=model_config)
lit_jigsaw_electra.inference_iter=1 # scored 0.9920312059 with lr 1e-6 (02-05)

In [None]:
# TensorBoard logger setup
logger = TensorBoardLogger("lightning_logs", name="24-02-05")

# NOTE(review): seeding happens after the model cells above, so their random
# initialisation is not controlled by this seed.
L.seed_everything(config['seed'])
trainer = L.Trainer(max_epochs=150, precision='bf16-mixed', callbacks=[checkpoint_callback, earlystopping_callback], logger=logger)

In [None]:
# Optional: silence PyTorch's "tensor from a list of numpy.ndarrays" UserWarning.
warnings.filterwarnings('ignore', category=UserWarning, message="Creating a tensor from a list of numpy.ndarrays is extremely slow.*")

trainer.fit(lit_jigsaw_electra, train_dataloader, val_dataloader)

# 7. 모델 평가 및 예측

## 7.1 Model evaluation

In [None]:
# Reload the chosen checkpoint for evaluation.
lit_jigsaw_electra = LitJigsawElectra.load_from_checkpoint('./checkpoint3_aug/jigsawelectra-vitgap-epoch=11-val_score=0.994698.ckpt',model=model, config=model_config)
lit_jigsaw_electra.inference_iter=1 # scored 0.9920312059 with lr 1e-6 (02-05)

In [None]:
# Plain trainer used only for prediction.
trainer = L.Trainer()

In [None]:
# Optional: silence PyTorch's "tensor from a list of numpy.ndarrays" UserWarning.
warnings.filterwarnings('ignore', category=UserWarning, message="Creating a tensor from a list of numpy.ndarrays is extremely slow.*")

# One (global probs, predicted order, true order) tuple per batch, as returned by predict_step.
val_preds = trainer.predict(lit_jigsaw_electra, val_dataloader)

In [None]:
val_order_pred = torch.cat([order_pred for pixel_values, order_pred, order_true in val_preds]).cpu().numpy()
val_order_true = torch.cat([order_true for pixel_values, order_pred, order_true in val_preds]).cpu().numpy()

In [None]:
lit_jigsaw_electra._get_score(val_order_true, val_order_pred) 
# inference_iter=1 늘린다고 좋아지지 않음. pretrained image clf 를 이용하여 선별적으로 iterative하게 하면 더 좋아질지도.

## 7.2 Inference

In [None]:
preds = trainer.predict(lit_jigsaw_electra, pred_dataloader)

In [None]:
order_pred = torch.cat([order_pred for pixel_values, order_pred, _ in preds]).cpu().numpy()

In [None]:
submission = pd.read_csv('./sample_submission.csv')

In [None]:
submission.iloc[:,1:] = order_pred + 1

In [None]:
submission.to_csv('./submission/24-02-05-1_submission1.csv', index=False)