In [None]:
import warnings
import matplotlib as mpl
# Restrict matplotlib's own logging to messages at "error" level.
mpl.set_loglevel("error")
# Ignore UserWarning-category warnings whose originating module matches "matplotlib".
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

import sys
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from dotenv import dotenv_values
import ipywidgets as widgets
from IPython.display import display, clear_output
from torch.utils.data import DataLoader

# Hydra and Jupyter notebooks conflict over command-line arguments, so reset argv.
sys.argv = ['']
# Read PYTHONPATH from the dotenv values and add it to the import path
# (only when it is set and not already present).
python_path = dotenv_values().get('PYTHONPATH')
if python_path and python_path not in sys.path:
    sys.path.append(python_path)

from src.dataset.CvImageDatasetFastEx import get_datasets
#from src.dataset.CvImageDataset import get_datasets
from src.models.CustomModelEx import CustomModelEx

# A button (label reads "next batch") and an output area.
# NOTE(review): in the code visible here `button` is never referenced again, and
# `output` is rebound to a new widgets.Output() further down next to
# prev_button/next_button — these two look like leftovers; confirm before removing.
button = widgets.Button(description="다음 배치")
output = widgets.Output()

# ---- model settings (forwarded to CustomModelEx in prepare_data) ----
# Other backbone names that have been tried: 'resnet50', 'efficientnet_b4', ...
model_name = "tf_efficientnet_b4"
num_classes = 17
drop_out = 0.4
learning_rate = 1e-4

# ---- data loading settings (forwarded to prepare_data) ----
BATCH_SIZE = 16
num_workers = 0

# ---- training settings (not referenced by the code visible in this file) ----
EPOCHS = 50
do_test = True


# Fix the random seed
def random_seed(seed_num=42):
    """Seed the random number generators through PyTorch Lightning.

    The whole job is delegated to ``pl.seed_everything``. An earlier
    hand-written variant (setting ``PYTHONHASHSEED`` and seeding ``random``,
    ``numpy`` and ``torch`` one by one) was retired in favour of this single
    call; per the original author's note, ``seed_everything`` takes care of
    that work.

    Args:
        seed_num: Seed value forwarded to ``pl.seed_everything``.
    """
    # The cuDNN switches (torch.backends.cudnn.deterministic / .benchmark)
    # are not modified by this function.
    pl.seed_everything(seed=seed_num)

# Data preparation helper
def prepare_data(batch_size=32, num_workers=4):
    """Build the DataLoader used for browsing the validation images.

    Seeds the RNGs with 42, instantiates ``CustomModelEx`` from the
    module-level model settings and passes it to ``get_datasets``. Of the
    three datasets returned, only the second one (the validation split) is
    kept and wrapped in a ``DataLoader``.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the loader.

    Returns:
        A ``DataLoader`` over the validation dataset that does not shuffle
        and keeps the final partial batch.
    """
    random_seed(42)

    # NOTE(review): the model is only built to be handed to get_datasets —
    # presumably for model-specific preprocessing; confirm against get_datasets.
    model_kwargs = dict(
        model_name=model_name,
        num_classes=num_classes,
        learning_rate=learning_rate,
        drop_rate=drop_out,
    )
    model = CustomModelEx(**model_kwargs)

    # Create the datasets; only the middle element of the triple is needed.
    _first, val_dataset, _third = get_datasets(model)

    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,  # no shuffling for validation
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )
    # val_dataset is the separate validation dataset.
    return DataLoader(val_dataset, **loader_options)


# Prepare the data loader (runs at import / cell execution time)
val_loader = prepare_data(batch_size=BATCH_SIZE, num_workers=num_workers)


# Global index of the page (batch) currently shown; updated by the click handlers
page = 0

# Total number of batches in the loader
num_batches = len(val_loader)

# Create the navigation buttons (labels read "previous batch" / "next batch")
# and the output widget the figures are rendered into
prev_button = widgets.Button(description="이전 배치")
next_button = widgets.Button(description="다음 배치")
output = widgets.Output()

def show_images_from_loader_by_page(page_idx, max_images=10):
    """Plot the images of a single batch (page) of ``val_loader`` as a grid.

    The loader is iterated from its start until batch number ``page_idx`` is
    reached, so reaching a later page means loading every batch before it.
    Up to ``max_images`` images of that batch are drawn, at most five per
    row. If the loader has no batch with that index, a message is printed
    instead.

    Args:
        page_idx: Zero-based index of the batch to display.
        max_images: Upper bound on how many images of the batch are drawn.
            The count is additionally limited by the actual batch size.
    """
    for batch_idx, (images, _) in enumerate(val_loader):
        if batch_idx != page_idx:
            continue

        # The loader keeps a final partial batch (drop_last=False), so a batch
        # can hold fewer than max_images images; never index past its end.
        shown = min(images.size(0), max_images)
        # max(1, ...) keeps the grid well-defined even if nothing is shown.
        cols = max(1, min(shown, 5))
        rows = max(1, (shown + cols - 1) // cols)

        plt.figure(figsize=(3 * cols, 3 * rows))
        for i in range(shown):
            # Move the first axis last for imshow
            # (presumably channels-first tensors — TODO confirm).
            img = images[i].permute(1, 2, 0).numpy()
            plt.subplot(rows, cols, i + 1)
            plt.imshow(img)
            plt.axis('off')
        plt.suptitle(f"view image Page ({page_idx+1}/{num_batches})")
        plt.tight_layout()
        # NOTE(review): plt.show() was disabled in the original and is left
        # out here as well; confirm figures still render inside the Output
        # widget when this runs from a button callback.
        break
    else:
        print(f"페이지 {page_idx}에 해당하는 배치가 없습니다.")

def on_prev_clicked(b):
    """Click handler for the "previous batch" button.

    Moves the global ``page`` one step back and redraws that page inside
    ``output``. Does nothing when the first page is already shown.

    Args:
        b: The button instance passed by the click callback (unused).
    """
    global page
    if page <= 0:
        return

    page -= 1
    with output:
        # wait=True defers clearing until the new content is ready.
        clear_output(wait=True)
        show_images_from_loader_by_page(page)

def on_next_clicked(b):
    """Click handler for the "next batch" button.

    Advances the global ``page`` by one and redraws that page inside
    ``output``. Does nothing when the last page is already shown.

    Args:
        b: The button instance passed by the click callback (unused).
    """
    global page
    if page >= num_batches - 1:
        return

    page += 1
    with output:
        # wait=True defers clearing until the new content is ready.
        clear_output(wait=True)
        show_images_from_loader_by_page(page)


def main():
    """Wire up the navigation buttons and show the first page."""
    # Register the click handlers.
    handlers = (
        (prev_button, on_prev_clicked),
        (next_button, on_next_clicked),
    )
    for btn, handler in handlers:
        btn.on_click(handler)

    # Render the initial page into the output widget.
    with output:
        show_images_from_loader_by_page(page)

    # Display the button row followed by the output widget.
    controls = widgets.HBox([prev_button, next_button])
    display(controls, output)
  
# Launch the batch viewer when this file is executed directly.
if __name__ == "__main__":
    main()