In [2]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset
import cv2
import json
import numpy as np
from PIL import Image
import os
import glob
import random
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from sklearn.model_selection import train_test_split
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import DeepLabV3
from torchvision.models.optical_flow import raft_small
import datetime
import ptlflow
from ptlflow.utils import flow_utils
from ptlflow.utils.io_adapter import IOAdapter

  from .autonotebook import tqdm as notebook_tqdm


[!!alt_cuda_corr is not compiled!!]


In [153]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''
raft_model = raft_small(pretrained=True, progress=False).to(device)
raft_model = raft_model.eval()
'''
flow_model = ptlflow.get_model('fastflownet', ckpt_path='things')

gamma_values = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8]  # Gamma values
rotate_values = [-90, 0, 90, 180]
def transform(img):
    return img

def gamma_correction(img, gamma):
    return TF.adjust_gamma(img, gamma)

flip_types = ["horizontal", "vertical", "both", "none"]

def random_flip(img, flip_type):
    if flip_type == "horizontal":
        return TF.hflip(img)
    elif flip_type == "vertical":
        return TF.vflip(img)
    elif flip_type == "both":
        return TF.hflip(TF.vflip(img))
    return img  # No flip

def correct_affine(prev_frame, curr_frame):
    """
    특징점을 기반으로 Affine Transform을 사용하여 시점 보정 (이동 + 회전 포함)
    """
    orb = cv2.ORB_create()  # ORB 특징점 검출기 사용
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)

    # 특징점 검출
    kp1, des1 = orb.detectAndCompute(prev_gray, None)
    kp2, des2 = orb.detectAndCompute(curr_gray, None)

    # 특징점이 검출되지 않았을 경우 대비
    if des1 is None or des2 is None:
        return curr_frame  # 원본 프레임 반환 (보정 없이 유지)

    # 특징점 매칭 (BFMatcher 사용)
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = bf.match(des1, des2)

    # 매칭이 너무 적으면 보정 불가능
    if len(matches) < 10:
        return curr_frame  # 원본 프레임 반환

    # 좋은 매칭 선택
    matches = sorted(matches, key=lambda x: x.distance)[:50]
    src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)

    # Affine Transform Matrix 계산
    M, _ = cv2.estimateAffine2D(dst_pts, src_pts)

    # 변환 행렬이 None이면 원본 반환
    if M is None:
        return curr_frame

    # Affine 변환 적용
    h, w = prev_frame.shape[:2]
    corrected_frame = cv2.warpAffine(curr_frame, M, (w, h))

    return corrected_frame

# {'Instrument', 'Care', 'Bubble', 'unkown', 'unknown', 'Fat', 'SoftTIssue', 'Dura', 'BF', 'SoftTissue', 'vessel', 'Vessel', 'Bone', 'LF', 'SofrTissue'}
num_classes = 11 # Background, BF, Vessel, Instrument, Care, Bubble, Fat, Bone, LF, Dura, SoftTissue

# 클래스별 라벨 매핑
class_map = {"BF": 1, "Vessel": 9, "vessel": 9, "Instrument": 3, "Care": 4, "Bubble": 5, "Fat": 6, "Bone": 7, "LF": 8, "Dura": 2, "SoftTissue": 10, "SofrTissue": 10, "SoftTIssue": 10}

# 데이터셋 클래스 정의
class BleedingDataset(Dataset):
    def __init__(self, image_files, x_offset=420, transform=None, augmentation=False):
        self.image_paths = image_files
        self.json_paths = [f.replace('.jpeg', '.json').replace('.png', '.json') for f in self.image_paths]
        self.transform = transform
        self.augmentation = augmentation
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        self.toTensor = transforms.ToTensor()
        self.image1 = None
        self.image2 = None
        self.image1_resized = None
        self.image2_resized = None
        self.x_offset = x_offset
    
    def __len__(self):
        return len(self.image_paths) - 2

    def __getitem__(self, idx):
        x_offset = self.x_offset
        # 원본 이미지 로드
        if idx == 0:
            image_path1 = self.image_paths[idx]
            
            # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
            file_bytes1 = np.fromfile(image_path1, dtype=np.uint8)  
            image1 = cv2.imdecode(file_bytes1, cv2.IMREAD_COLOR)  # BGR로 로드
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)  # RGB 변환

            # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
            image1 = image1[0:1080, x_offset:x_offset + 1080]  # crop(x1, y1, x2, y2)
            image1_resized = cv2.resize(image1, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
            self.image1_resized = image1_resized
            image1 = torch.from_numpy(image1_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

            self.image1 = image1


            image_path2 = self.image_paths[idx+1]
            
            # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
            file_bytes2 = np.fromfile(image_path2, dtype=np.uint8)  
            image2 = cv2.imdecode(file_bytes2, cv2.IMREAD_COLOR)  # BGR로 로드
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)  # RGB 변환

            # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
            image2 = image2[0:1080, x_offset:x_offset +1080]  # crop(x1, y1, x2, y2)
            image2_resized = cv2.resize(image2, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
            self.image2_resized = image2_resized
            image2 = torch.from_numpy(image2_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

            self.image2 = image2

        else:
            self.image1 = self.image2
            image1 = self.image1
            self.image2 = self.image3
            image2 = self.image2

        
        image_path3 = self.image_paths[idx+2]
            
        # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
        file_bytes3 = np.fromfile(image_path3, dtype=np.uint8)  
        image3 = cv2.imdecode(file_bytes3, cv2.IMREAD_COLOR)  # BGR로 로드
        image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)  # RGB 변환

        # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
        image3 = image3[0:1080, x_offset:x_offset+1080]  # crop(x1, y1, x2, y2)
        image3_resized = cv2.resize(image3, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
        self.image3_resized = image3_resized
        image3 = torch.from_numpy(image3_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

        self.image3 = image3
        
        # JSON 파일 로드
        json_path = self.json_paths[idx+1]
        with open(json_path, 'r') as f:
            data = json.load(f)

        # 빈 마스크 생성
        mask = np.zeros((data["imageHeight"], data["imageWidth"]), dtype=np.uint8)

        # 출혈(BF) 영역 폴리곤 마스크 생성
        for shape in data["shapes"]:
            label = shape["label"]
            if label in class_map:
                points = np.array(shape["points"], dtype=np.int32)
                cv2.fillPoly(mask, [points], class_map[label])

        
        # PIL 이미지 변환 후 Tensor 변환
        mask = Image.fromarray(mask)
        mask = mask.crop((x_offset, 0, x_offset+1080, 1080))
        mask = self.transform_mask(mask)

        # if self.augmentation:
        gamma = random.choice(gamma_values)
        image1_gamma = gamma_correction(image1, gamma)
        image2_gamma = gamma_correction(image2, gamma)
        image3_gamma = gamma_correction(image3, gamma)

        rotate_degree = random.choice(rotate_values)
        image1_rotate = TF.rotate(image1_gamma, angle=rotate_degree)
        image2_rotate = TF.rotate(image2_gamma, angle=rotate_degree)
        image3_rotate = TF.rotate(image3_gamma, angle=rotate_degree)

        seq = torch.stack([image1_rotate, image2_rotate, image3_rotate], dim=0)
        
        mask = TF.rotate(mask, angle=rotate_degree)
            
        return seq, mask

# 데이터셋 클래스 정의
class BleedingDatasetTest(Dataset):
    def __init__(self, image_files, x_offset=0, transform=None, augmentation=False):
        self.image_paths = image_files
        # self.yolo_image_paths = yolo_image_files
        self.transform = transform
        self.augmentation = augmentation
        self.transform_image = transform_image
        self.toTensor = transforms.ToTensor()
        self.image1 = None
        self.image2 = None
        self.yolo_image1 = None
        self.yolo_image2 = None
        self.image1_resized = None
        self.image2_resized = None
        self.x_offset = x_offset

    def __len__(self):
        return len(self.image_paths) - 2

    def __getitem__(self, idx):
        x_offset = self.x_offset
        if idx == 0:
            image_path1 = self.image_paths[idx]
            
            # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
            file_bytes1 = np.fromfile(image_path1, dtype=np.uint8)  
            image1 = cv2.imdecode(file_bytes1, cv2.IMREAD_COLOR)  # BGR로 로드
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)  # RGB 변환

            # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
            image1 = image1[0:1080, x_offset:x_offset + 1080]  # crop(x1, y1, x2, y2)
            image1_resized = cv2.resize(image1, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
            self.image1_resized = image1_resized
            image1 = torch.from_numpy(image1_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

            self.image1 = image1


            image_path2 = self.image_paths[idx+1]
            
            # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
            file_bytes2 = np.fromfile(image_path2, dtype=np.uint8)  
            image2 = cv2.imdecode(file_bytes2, cv2.IMREAD_COLOR)  # BGR로 로드
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)  # RGB 변환

            # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
            image2 = image2[0:1080, x_offset:x_offset + 1080]  # crop(x1, y1, x2, y2)
            image2_resized = cv2.resize(image2, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
            self.image2_resized = image2_resized
            image2 = torch.from_numpy(image2_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

            self.image2 = image2

        else:
            self.image1 = self.image2
            image1 = self.image1
            self.image2 = self.image3
            image2 = self.image2

        
        image_path3 = self.image_paths[idx+2]
            
        # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
        file_bytes3 = np.fromfile(image_path3, dtype=np.uint8)  
        image3 = cv2.imdecode(file_bytes3, cv2.IMREAD_COLOR)  # BGR로 로드
        image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)  # RGB 변환

        # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
        image3 = image3[0:1080, x_offset:x_offset + 1080]  # crop(x1, y1, x2, y2)
        image3_resized = cv2.resize(image3, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
        self.image3_resized = image3_resized
        image3 = torch.from_numpy(image3_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

        self.image3 = image3

        seq = torch.stack([image1, image2, image3], dim=0)
        
        return seq

# 데이터셋 클래스 정의
class BleedingDatasetTestVideo(Dataset):
    def __init__(self, image_files, x_offset=0, transform=None, augmentation=False):
        self.image_paths = image_files
        self.transform = transform
        self.augmentation = augmentation
        self.transform_image = transform_image
        self.toTensor = transforms.ToTensor()
        self.image1 = None
        self.image2 = None
        self.image1_resized = None
        self.image2_resized = None
        self.x_offset = x_offset

    def __len__(self):
        return len(self.image_paths) - 2

    def __getitem__(self, idx):
        # 원본 이미지 로드
        x_offset = self.x_offset
        if idx == 0:
            image_path1 = self.image_paths[idx]
            
            # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
            file_bytes1 = np.fromfile(image_path1, dtype=np.uint8)  
            image1 = cv2.imdecode(file_bytes1, cv2.IMREAD_COLOR)  # BGR로 로드
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)  # RGB 변환

            # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
            image1 = image1[0:2160, x_offset:x_offset + 2160]  # crop(x1, y1, x2, y2)
            image1_resized = cv2.resize(image1, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
            self.image1_resized = image1_resized
            image1 = torch.from_numpy(image1_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

            self.image1 = image1


            image_path2 = self.image_paths[idx+1]
            
            # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
            file_bytes2 = np.fromfile(image_path2, dtype=np.uint8)  
            image2 = cv2.imdecode(file_bytes2, cv2.IMREAD_COLOR)  # BGR로 로드
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)  # RGB 변환

            # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
            image2 = image2[0:2160, x_offset:x_offset + 2160]  # crop(x1, y1, x2, y2)
            image2_resized = cv2.resize(image2, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
            self.image2_resized = image2_resized
            image2 = torch.from_numpy(image2_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

            self.image2 = image2

        else:
            self.image1 = self.image2
            image1 = self.image1
            self.image2 = self.image3
            image2 = self.image2

        
        image_path3 = self.image_paths[idx+2]
            
        # 🚀 빠른 파일 로딩: np.fromfile() + cv2.imdecode()
        file_bytes3 = np.fromfile(image_path3, dtype=np.uint8)  
        image3 = cv2.imdecode(file_bytes3, cv2.IMREAD_COLOR)  # BGR로 로드
        image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)  # RGB 변환

        # 🚀 빠른 크롭 (Pillow보다 OpenCV가 빠름)
        image3 = image3[0:2160, x_offset:x_offset + 2160]  # crop(x1, y1, x2, y2)
        image3_resized = cv2.resize(image3, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
        self.image3_resized = image3_resized
        image3 = torch.from_numpy(image3_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

        self.image3 = image3

        seq = torch.stack([image1, image2, image3], dim=0)
        
        return seq

class TemporalSelfAttention(nn.Module):
    def __init__(self, in_channels=3, heads=3):
        super(TemporalSelfAttention, self).__init__()
        self.heads = heads
        self.scale = (in_channels // heads) ** 0.5
        
        self.qkv_proj = nn.Conv3d(in_channels, in_channels * 3, kernel_size=1)
        self.output_proj = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        
        # 게이트 학습 모듈 추가
        self.gate_conv = nn.Sequential(
            nn.Conv3d(in_channels, in_channels // 2, kernel_size=1),
            nn.ReLU(),
            nn.Conv3d(in_channels // 2, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        입력: x [B, C, T, H, W]  
        출력: [B, C, H, W] — 시간 축 요약된 feature
        """
        B, C, T, H, W = x.shape
        qkv = self.qkv_proj(x)  # [B, 3*C, T, H, W]
        qkv = qkv.view(B, 3, self.heads, C // self.heads, T, H, W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # 각각 [B, heads, C//heads, T, H, W]

        # Attention score 계산
        attn_scores = (q * k).sum(dim=2) / self.scale  # [B, heads, T, H, W]
        attn_weights = F.softmax(attn_scores, dim=2)  # 시간축(T)에 대해 softmax

        '''
        # Value에 attention weight 곱해서 시간축 요약
        attn_output = (attn_weights.unsqueeze(2) * v).sum(dim=3)  # [B, heads, C//heads, H, W]

        # heads 합치기
        out = attn_output.reshape(B, C, H, W)
        out = self.output_proj(out.unsqueeze(2)).squeeze(2)  # Conv3D→2D처럼 사용

        return out
        '''
        weighted_v = (attn_weights.unsqueeze(2) * v).sum(dim=3)  # [B, heads, C//heads, H, W]
        out = weighted_v.reshape(B, C, H, W).unsqueeze(2)

        out = self.output_proj(out)  # [B, C, 1, H, W]

        # 🎯 추가된 게이트로 시간정보의 신뢰도 조절
        gate = self.gate_conv(x)  # [B, 1, T, H, W]
        gate_score = gate[:, :, 1]  # 현재 프레임 위치 gate만 사용 (T=3에서 가운데)

        gated_out = out.squeeze(2) * gate_score  # 게이트 적용 후 output 반환

        return gated_out
        
class CustomBackBone(nn.Module):
    def __init__(self):
        super(CustomBackBone, self).__init__()
        
        # feature extractor
        self.rgb_conv = torchvision.models.mobilenet_v3_small(pretrained=True).features # (B, 576, H/32, W/32)
        self.flow_conv = torchvision.models.mobilenet_v3_small(pretrained=True).features # (B, 576, H/32, W/32)
        self.temporal_attention = TemporalSelfAttention(in_channels=3, heads=3)
        
        # fusion
        self.fusion = nn.Conv2d(576*2, 576, kernel_size=1)
        self.out_channels = 576
        
    
    def forward(self, x):
        x_seq = x
        # x_seq: [B, T(3), C, H, W] → input은 연속된 RGB 프레임
        B, T, C, H, W = x_seq.shape
        rgb_feature = self.rgb_conv(x_seq[:, 1, :, :, :])
        x_seq = x_seq.permute(0, 2, 1, 3, 4)  # → [B, C, T, H, W]
        x_attended = self.temporal_attention(x_seq)  # → [B, C, H, W]
        flow_feature = self.flow_conv(x_attended)  # → [B, 576, H/32, W/32]
        stack_feature = self.fusion(torch.cat([rgb_feature, flow_feature], dim=1))
        
        return {"out": stack_feature + rgb_feature}

# 데이터 변환 정의
transform_image = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),  # 일반 이미지용,
    transforms.ToTensor(),
    #lambda x: x.long(),
])

transform_mask = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),  # mask 용,
    transforms.ToTensor(),
    lambda x: x * 255,  # 다시 255를 곱하여 (0,255) 범위로 변환
    lambda x: x.long(),
])

def apply_transform_cv2(image_np):
    """OpenCV를 이용하여 numpy.ndarray를 변환"""
    image_resized = cv2.resize(image_np, (512, 512), interpolation=cv2.INTER_LINEAR)  # BILINEAR 보간
    image_tensor = torch.from_numpy(image_resized).permute(2, 0, 1).float() / 255.0  # (H, W, C) → (C, H, W)

    return image_tensor

# Dice Loss 정의
def dice_loss(pred, target, smooth=1e-6):
    pred = F.softmax(pred, dim=1)  # 여러 클래스 예측 확률로 변환
    target_onehot = F.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2)  # One-hot encoding
    intersection = (pred * target_onehot).sum(dim=(2,3))
    union = pred.sum(dim=(2,3)) + target_onehot.sum(dim=(2,3))
    dice = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice.mean()  # 다중 클래스 dice loss

def focal_loss(pred, target, gamma=2.0):
    alpha = [1.0, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
    pred = F.softmax(pred, dim=1)  # 확률 분포
    target_onehot = F.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()
    pt = (pred * target_onehot).sum(dim=1)
    log_pt = torch.log(pt + 1e-6)  # log(0) 방지
    # Focal Loss 적용
    focal_loss = -((1 - pt) ** gamma) * log_pt
    return focal_loss.mean()

    '''
    ce_loss = -(target_onehot * torch.log(pred + 1e-6))  # Cross Entropy 기반
    focal_loss = (1 - pred) ** gamma * ce_loss
    return focal_loss.mean()
    '''

def iou_loss(pred, target, smooth=1e-6):
    pred = F.softmax(pred, dim=1)
    target_onehot = F.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2)
    intersection = (pred * target_onehot).sum(dim=(2,3))
    union = pred.sum(dim=(2,3)) + target_onehot.sum(dim=(2,3)) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return 1 - iou.mean()

def fp_penalty_loss(pred, target):
    pred_soft = F.softmax(pred, dim=1)
    pred_label = pred_soft.argmax(dim=1)  # [B, H, W]
    fp_mask = (pred_label != target) & (pred_label != 0) & (target == 0)

    fp_confidence = pred_soft.max(dim=1)[0]  # confidence score
    penalty = (fp_mask * fp_confidence).mean()
    return penalty

def loss_fn(pred, target):
    return ((dice_loss(pred, target) + focal_loss(pred, target) + iou_loss(pred, target)) / 3) + 0.5 * fp_penalty_loss(pred, target)


[32m2025-03-25 12:31:06.417[0m | [1mINFO    [0m | [36mptlflow[0m:[36mrestore_model[0m:[36m280[0m - [1mRestored model state from checkpoint: things[0m


In [154]:
image_dir = "0014_spine_endoscope_data/"
image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(('.jpeg', '.png', '.jpg'))])  # 이미지 파일 리스트

# train 데이터셋 및 DataLoader 생성
train_dataset = BleedingDataset(image_files, transform=transform, augmentation=True)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)

In [155]:
# 🔹 Custom Backbone을 사용한 DeepLabV3 모델 정의
custom_backbone = CustomBackBone()

# 🔹 DeepLabV3 모델에 Custom Backbone 연결
model = DeepLabV3(
    backbone=custom_backbone,
    classifier=DeepLabHead(custom_backbone.out_channels, num_classes)  # classifier의 입력 크기를 backbone에 맞춤
)

# 출력 채널 변경 (COCO 클래스 → predict 클래스)
# model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)

# model.load_state_dict(torch.load("deeplabv3_bleeding_self_attention_best.pth"))

# GPU 사용 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 옵티마이저 설정
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for images, masks in train_loader:
        if len(images) <= 1: # batch number
            continue
        images, masks = images.to(device), masks.squeeze(1).to(device)
        # stacked_images = torch.cat(images, dim=1)
        optimizer.zero_grad()
        outputs = model(images)["out"]  # DeepLabV3의 출력 가져오기
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

# 모델 저장
torch.save(model.state_dict(), "deeplabv3_bleeding_self_attention_new.pth")
print("모델 저장 완료!")


Epoch [1/100], Loss: 0.9095
Epoch [2/100], Loss: 0.9008
Epoch [3/100], Loss: 0.8864
Epoch [4/100], Loss: 0.8754
Epoch [5/100], Loss: 0.8809
Epoch [6/100], Loss: 0.8712
Epoch [7/100], Loss: 0.8557
Epoch [8/100], Loss: 0.8502
Epoch [9/100], Loss: 0.8342
Epoch [10/100], Loss: 0.8216
Epoch [11/100], Loss: 0.8181
Epoch [12/100], Loss: 0.8032
Epoch [13/100], Loss: 0.8120
Epoch [14/100], Loss: 0.8014
Epoch [15/100], Loss: 0.7917
Epoch [16/100], Loss: 0.7841
Epoch [17/100], Loss: 0.7706
Epoch [18/100], Loss: 0.7764
Epoch [19/100], Loss: 0.7600
Epoch [20/100], Loss: 0.7655
Epoch [21/100], Loss: 0.7484
Epoch [22/100], Loss: 0.7380
Epoch [23/100], Loss: 0.7461
Epoch [24/100], Loss: 0.7453
Epoch [25/100], Loss: 0.7283
Epoch [26/100], Loss: 0.7140
Epoch [27/100], Loss: 0.7102
Epoch [28/100], Loss: 0.6952
Epoch [29/100], Loss: 0.6968
Epoch [30/100], Loss: 0.6847
Epoch [31/100], Loss: 0.6732
Epoch [32/100], Loss: 0.6651
Epoch [33/100], Loss: 0.6593
Epoch [34/100], Loss: 0.6480
Epoch [35/100], Loss: 0

In [157]:
# video test

def get_bounding_boxes(probability_map, threshold=0.4):
    """
    확률값이 threshold 이상인 픽셀들을 Bounding Box로 감싸기
    """
    # 확률 맵을 0~255로 정규화 후 이진화 (Threshold 적용)
    binary_mask = (probability_map > threshold).astype(np.uint8) * 255

    # Contour(외곽선) 찾기
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Bounding Boxes 저장
    bounding_boxes = [cv2.boundingRect(cnt) for cnt in contours]

    return bounding_boxes

def extract_frames(video_path, output_folder, fps=5):
    # 비디오 캡처 객체 생성
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video file.")
        return
    
    # 원본 비디오의 FPS 및 총 프레임 수 가져오기
    video_fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # 프레임을 저장할 폴더 생성
    os.makedirs(output_folder, exist_ok=True)
    
    frame_interval = video_fps // fps  # 몇 프레임마다 저장할지 계산
    frame_count = 0
    saved_count = 0
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        if frame_count % frame_interval == 0:
            frame_filename = os.path.join(output_folder, f"frame_{saved_count:05d}.png")
            cv2.imwrite(frame_filename, frame)
            saved_count += 1
        
        frame_count += 1
    
    cap.release()
    print(f"Extracted {saved_count} frames and saved to {output_folder}")


video_file = "video5.mp4"  # MP4 파일 경로
# video_image_dir = "bleeding_test_1/"  # 저장할 image folder
# video_image_dir = "test_riwo/video2/"  # 저장할 image folder
video_image_dir = "test_riwo/video5/"  # 저장할 image folder


yolo_video_image_dir = "test_riwo/video2_yolo/"  # 저장할 image folder

# extract_frames(video_file, video_image_dir)


# output_dir = "bleeding_test_result_1"
# output_dir = "test_riwo/video2_result_new/"
output_dir = "test_riwo/video5_result_full/"

video_image_dirs = [
    "test_riwo/video5_important_image/", "test_riwo/video5_important_image2/",
    #"test_riwo/video2/", "test_riwo/video5/", # "video1_image/", "video2_image/", "video3_image/", "video4_image/", "video5_image/",
    "bleeding_test_1/", "bleeding_test_2/", "bleeding_test_3/"
]
output_dirs = [
    "test_riwo/video5_important_output1/", "test_riwo/video5_important_output2/",
    #"test_riwo/video2_output/", "test_riwo/video5_output/", # "video1_output/", "video2_output/", "video3_output/", "video4_output/", "video5_output/",
    "bleeding_test_1_output/", "bleeding_test_2_output/", "bleeding_test_3_output/"
]
offsets = [
    0, 0, 
    #280, 0, # 280, 280, 280, 0, 0,
    840, 840, 840
]
resols = [
    "low", "low",
    #"low", "low", # "low", "low", "low", "low", "low", 
    "high", "high", "high"
]

def test_video(video_image_dir, output_dir, x_offset, resol):
    output_dir = os.path.join("test_new/", output_dir)
    
    video_image_files = sorted([os.path.join(video_image_dir, f) for f in os.listdir(video_image_dir) if f.endswith(('.jpeg', '.png', '.jpg'))])  # 이미지 파일 리스트
    # yolo_video_image_files = sorted([os.path.join(yolo_video_image_dir, f) for f in os.listdir(yolo_video_image_dir) if f.endswith(('.jpeg', '.png', '.jpg'))])  # 이미지 파일 리스트
    
    # test 데이터셋 및 DataLoader 생성
    if resol == "low":
        video_test_dataset = BleedingDatasetTest(video_image_files, x_offset=x_offset, transform=transform, augmentation=False)
    else:
        video_test_dataset = BleedingDatasetTestVideo(video_image_files, x_offset=x_offset, transform=transform, augmentation=False)

    video_test_dataloader = DataLoader(video_test_dataset, batch_size=16, shuffle=False)
    
    
    # 모델 로드 (ResNet50 기반)
    # model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
    
    # 출력 채널 변경 (COCO 클래스 → predict 클래스)
    # model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)
    
    # 🔹 Custom Backbone을 사용한 DeepLabV3 모델 정의
    custom_backbone = CustomBackBone()
    
    # 🔹 DeepLabV3 모델에 Custom Backbone 연결
    model = DeepLabV3(
        backbone=custom_backbone,
        classifier=DeepLabHead(custom_backbone.out_channels, num_classes)  # classifier의 입력 크기를 backbone에 맞춤
    )
    
    model.load_state_dict(torch.load("deeplabv3_bleeding_self_attention_new.pth"))
    
    # GPU 사용 설정
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    
    os.makedirs(output_dir, exist_ok=True)
    
    total_num = 0
    start_time = datetime.datetime.now()
    
    for images in video_test_dataloader:
        
        images = images.to(device)
        with torch.no_grad():
            preds = model(images)["out"]  # DeepLabV3의 출력 가져오기
            images = images[:, 1, :, :, :]
        
        for image, pred in zip(images, preds):
            total_num += 1
            
            # 📌 후처리 (Threshold)
            pred = F.softmax(pred, dim=0)
            max_probs, pred_mask = torch.max(pred, dim=0)  # (512, 512)
    
            # threshold = 0.4
            # pred_mask[max_probs < threshold] = 0
    
            # NumPy 변환
            pred_mask = pred_mask.cpu().numpy()
            
            # 원본 이미지, 마스크 변환
            original_image = image.cpu().numpy().transpose(1,2,0)
            original_image = (original_image * 255).astype(np.uint8)  # 정규화 해제
           
            # ✅ 컬러맵 적용 (GT = Green, Pred = Red, Overlap = Yellow)
            overlay = np.array(original_image)
            overlay = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)
    
            green = [0, 255, 0]  # Prediction - Green
            red = [0, 0, 255]  # Prediction - Red
            blue = [255, 0, 0]  # Prediction - Blue
    
            yellow = [0, 255, 255]  # Overlapping - Yellow
            BF_color = [0, 237, 204] # BF color in YOLO
    
            '''
            mask_layer = np.zeros((512, 512, 3), dtype=np.uint8)
            mask_layer[pred_mask == 1] = green  # Prediction
            '''
    
            # 결과 마스크 생성
            # thresholded_mask_layer = np.zeros((512, 512, 3), dtype=np.uint8)
            # thresholded_mask_layer = np.zeros_like(overlay, dtype=np.uint8)
            # thresholded_mask_layer = cv2.applyColorMap(np.uint8(pred[1].cpu() * 255), cv2.COLORMAP_JET)
            
            # 확률이 높은 영역에 대한 Bounding Box 찾기
            bounding_boxes = get_bounding_boxes(pred[1].cpu().numpy(), threshold=0.5)
            
            for (x, y, w, h) in bounding_boxes:
                cv2.rectangle(overlay, (x, y), (x + w, y + h), BF_color, 2)
                # 텍스트 위치 설정 (Bounding Box의 오른쪽 위)
                text_x = x + w - 30  # 오른쪽 끝 - 30px
                text_y = y + 15  # 위쪽 여백 고려
                
                # 텍스트 바탕 박스 (검은색)
                cv2.rectangle(overlay, (text_x - 2, text_y - 12), 
                              (text_x + 22, text_y + 3), (0, 0, 0), -1)
        
                # 'BF' 텍스트 추가 (노란색)
                cv2.putText(overlay, "BF", (text_x, text_y), 
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1, cv2.LINE_AA)
            '''
            # 설정값
            grid_size = 4  # 4x4 그리드
            cell_size = 512 // grid_size  # 각 셀의 크기
            binary_pred_mask = (pred_mask == 1).astype(np.uint8)
    
            for i in range(grid_size):
                for j in range(grid_size):
                    # 각 셀의 위치 계산
                    x_start, x_end = i * cell_size, (i + 1) * cell_size
                    y_start, y_end = j * cell_size, (j + 1) * cell_size
                    
                    # 셀 내부의 마스크 값이 1인 픽셀의 비율 계산
                    cell = binary_pred_mask[x_start:x_end, y_start:y_end]
                    mask_ratio = np.mean(cell)
                    
                    # 마스크 비율 threshold
                    threshold = 0.001
    
                    if (i == 0 or i == grid_size - 1) and (j == 0 or j == grid_size - 1): # corner
                        threshold = 0.001
                    elif (i == 0 or i == grid_size - 1) or (j == 0 or j == grid_size - 1): # side
                        threshold = 0.001
                    
                    if mask_ratio > threshold:
                        thresholded_mask_layer[x_start:x_end, y_start:y_end] = green
            
            thresholded_mask_layer[pred_mask==1] = BF_color
            '''
    
            # 512x512 -> 1080x1080 업샘플링 (cubic interpolation 사용)
            # image_1080 = cv2.resize(thresholded_mask_layer, (1080, 1080), interpolation=cv2.INTER_CUBIC)
    
            '''
            # 1920x1080 빈 이미지 생성 (검은색 배경)
            image_1920x1080 = np.zeros((1080, 1920, 3), dtype=np.uint8)
            
            # 1080x1080 이미지를 1920x1080 이미지의 가로 280~1360 위치에 배치
            x_offset = 0  # 왼쪽 시작점
            y_offset = 0    # 세로 정렬 (상단부터 채우기)
            
            # 이미지 삽입
            image_1920x1080[y_offset:y_offset + 1080, x_offset:x_offset + 1080] = image_1080
            '''
            
            # ✅ 최종 합성
            # blended = cv2.addWeighted(overlay, 0.8, thresholded_mask_layer, 0.4, 0)
            blended = overlay
    
            # blended = cv2.resize(blended, (1920, 1080), interpolation=cv2.INTER_LANCZOS4)
    
            # ✅ 저장
            filename = f"output_{total_num:04d}.png"
            output_path = os.path.join(output_dir, filename)
            cv2.imwrite(output_path, blended)
    
    end_time = datetime.datetime.now()
    print("image num: " + str(len(video_test_dataloader) * 16))
    print("total time(ms):", end='')
    print(end_time - start_time)
    
    os.system(f"cd {output_dir} && rm output.mp4")
    os.system(f"cd {output_dir} && ffmpeg -framerate 25 -i output_%04d.png -c:v libx264 -pix_fmt yuv420p output.mp4 & cd ..")
    print("video test completed")

for video_image_dir, output_dir, offset, resol in zip(video_image_dirs, output_dirs, offsets, resols):
    test_video(video_image_dir, output_dir, offset, resol)

image num: 208
total time(ms):0:00:18.540117
video test completed


ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena

image num: 208
total time(ms):0:00:12.164772
video test completed


rm: cannot remove 'output.mp4': No such file or directory
ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx26

image num: 736
total time(ms):0:03:24.493396
video test completed


rm: cannot remove 'output.mp4': No such file or directory
ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx26

image num: 160
total time(ms):0:00:39.278794
video test completed


rm: cannot remove 'output.mp4': No such file or directory
ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx26

image num: 1536
total time(ms):0:06:26.785671
video test completed


rm: cannot remove 'output.mp4': No such file or directory
ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx26