# test

In [None]:
import os
import json
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Dataset class used for test-time inference
class TestDataset(Dataset):
    """Image-only dataset that serves the files found in ``<root_dir>/images``.

    Each item is an ``(image, img_name)`` pair. The image is read with
    OpenCV, converted from BGR to RGB and, when a transform was supplied,
    replaced by the ``'image'`` entry of the transform's result.
    """

    def __init__(self, root_dir, transform=None):
        """Index the image files of the dataset.

        Args:
            root_dir: Directory that contains an ``images`` sub-directory.
            transform: Optional callable invoked as ``transform(image=...)``
                that returns a mapping with an ``'image'`` key.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images')
        # Keep regular files only: nested directories (for example a
        # notebook's ``.ipynb_checkpoints``) would otherwise be counted as
        # samples and fail when they are read as images.
        self.image_files = sorted(
            name
            for name in os.listdir(self.images_dir)
            if os.path.isfile(os.path.join(self.images_dir, name))
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one image.

        Args:
            idx: Position in the sorted list of file names.

        Returns:
            Tuple ``(image, img_name)`` where ``img_name`` is the bare file
            name (no directory part).

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than
            # raising; fail here with the offending path instead of letting
            # cvtColor raise an unrelated-looking error.
            raise ValueError(f"Could not read image file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Compare against None explicitly so that a transform object that
        # happens to be falsy (e.g. a container-like pipeline of length 0)
        # is still applied.
        if self.transform is not None:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, img_name

# Test-time preprocessing: resize to 256x256, normalize, convert to a tensor.
# NOTE(review): mean/std look like the usual ImageNet channel statistics —
# confirm they match what the model was trained with.
test_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)

# Build the test dataset and a loader that yields one image per batch,
# in the dataset's own (sorted file name) order.
test_dataset = TestDataset(transform=test_transform, root_dir='test')
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

# Load the model and put it in inference mode.
# `device` is referenced here and by the prediction call further down, so it
# has to be defined first; use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = deeplabv3_mobilenet_v3_large(pretrained=False, num_classes=25)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run can still be loaded on a CPU-only machine.
state_dict = torch.load(
    'best_autonomous_driving_segmentation_model.pth',
    map_location=device,
)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Prediction function
def predict(model, dataloader, device):
    """Run the model over every batch and collect one class map per image.

    Args:
        model: Module that is switched to eval mode and called on each image
            batch; its result must be a mapping with an ``'out'`` entry whose
            dimension 1 is reduced with arg-max.
        dataloader: Iterable yielding ``(images, img_names)`` batches.
        device: Device each image batch is moved to before the forward pass.

    Returns:
        List of ``(img_name, pred_mask)`` tuples in iteration order, where
        ``pred_mask`` is the NumPy array of arg-max indices for that image.
    """
    model.eval()
    results = []
    # Gradients are not needed for inference.
    with torch.no_grad():
        for batch, names in dataloader:
            scores = model(batch.to(device))['out']
            masks = scores.argmax(dim=1).cpu().numpy()
            # Pair every file name with the mask at the same batch position.
            results.extend((name, masks[pos]) for pos, name in enumerate(names))
    return results

# Run inference over the whole test set
predictions = predict(model, test_loader, device)

# Save the predicted masks
output_dir = 'predictions'
os.makedirs(output_dir, exist_ok=True)

for img_name, pred_mask in predictions:
    # The masks hold arg-max class indices (int64 coming from torch); the
    # model has 25 classes, so the values fit in 8 bits. Hand OpenCV an
    # explicit uint8 image instead of relying on implicit dtype conversion.
    mask_u8 = pred_mask.astype(np.uint8)
    # Nearest-neighbour keeps the values valid class indices (no blending).
    pred_mask_resized = cv2.resize(mask_u8, (256, 256), interpolation=cv2.INTER_NEAREST)
    # Derive the output name from the stem so that every input extension
    # (.jpg, .jpeg, .png, upper-case variants, ...) ends up as a lossless
    # '<stem>_pred.png'; a plain str.replace('.jpg', ...) leaves other
    # extensions untouched and would, e.g., write '.jpeg' masks as JPEG.
    stem, _ = os.path.splitext(img_name)
    output_path = os.path.join(output_dir, stem + '_pred.png')
    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(output_path, pred_mask_resized):
        raise OSError(f"Failed to write predicted mask: {output_path}")

print("Prediction complete. Predicted masks are saved in the 'predictions' directory.")


## Modified file naming to include the 'split' string