In [1]:
import os
# Ask CUDA to enumerate devices by PCI bus ID so the index below is stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only device "7" to this process (machine-specific; adjust on other hosts).
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [2]:
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from tqdm.autonotebook import tqdm

from sklearn.model_selection import train_test_split

from glob import glob
import torch
import torch.nn as nn
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
import torchvision
import tifffile as tiff
from torchvision import transforms
from tqdm.auto import tqdm
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics.classification import BinaryF1Score
from torchmetrics.classification import BinaryPrecision
from torchmetrics.classification import BinaryRecall


import joblib
import torch.nn.functional as F

# Global random seed; handed to seed_everything() further down in the notebook.
seed = 20240308

In [3]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Currently using "{device}" device.')

Currently using "cuda" device.


In [4]:
import random
def seed_everything(seed=42):
    """Seed every random number generator used by the notebook.

    Sets ``PYTHONHASHSEED`` and seeds Python's ``random`` module, NumPy,
    and PyTorch (CPU, current CUDA device, and all CUDA devices). cuDNN is
    switched to deterministic mode with auto-tuning disabled.

    :param seed: integer seed applied to every generator
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python built-in RNG, NumPy RNG, PyTorch CPU RNG, then the CUDA RNGs
    # (current device first, then every device for multi-GPU setups).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Make PyTorch behave as deterministically as it can.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Apply the notebook-wide seed before any dataset, loader, or model is created.
seed_everything(seed)

In [5]:
# Training hyperparameters.
batch_size = 6      # samples per batch for both DataLoaders
image_size = 256    # NOTE(review): not referenced elsewhere in the visible notebook
num_classes = 1     # NOTE(review): not referenced elsewhere in the visible notebook
epochs = 40         # number of passes of the training loop

In [6]:
# Training split: directories holding the image and mask GeoTIFFs.
path_images = '/home/hyj/ChanHyung/Image_segementation/Forest_Fire_Segmentation/dataset/train_img/'
path_masks = '/home/hyj/ChanHyung/Image_segementation/Forest_Fire_Segmentation/dataset/train_mask/'

# Both listings are sorted so that row i of the frame pairs the i-th image
# with the i-th mask; this relies on matching file-name ordering.
images_paths = sorted(str(p) for p in glob(path_images + '*.tif'))
masks_paths = sorted(str(p) for p in glob(path_masks + '*.tif'))

# Column 0 = image path, column 1 = mask path (CustomDataset indexes by position).
df_train = pd.DataFrame({'images': images_paths, 'masks': masks_paths})

df_train.head(2)

Unnamed: 0,images,masks
0,/home/hyj/ChanHyung/Image_segementation/Forest...,/home/hyj/ChanHyung/Image_segementation/Forest...
1,/home/hyj/ChanHyung/Image_segementation/Forest...,/home/hyj/ChanHyung/Image_segementation/Forest...


In [7]:
# Held-out split: directories holding the image and mask GeoTIFFs.
# NOTE: this cell rebinds path_images / path_masks / images_paths / masks_paths
# that the training-split cell defined.
path_images = '/home/hyj/ChanHyung/Image_segementation/Forest_Fire_Segmentation/dataset/test_img/'
path_masks = '/home/hyj/ChanHyung/Image_segementation/Forest_Fire_Segmentation/dataset/test_mask/'

# Both listings are sorted so that row i of the frame pairs the i-th image
# with the i-th mask; this relies on matching file-name ordering.
images_paths = sorted(str(p) for p in glob(path_images + '*.tif'))
masks_paths = sorted(str(p) for p in glob(path_masks + '*.tif'))

# Column 0 = image path, column 1 = mask path (CustomDataset indexes by position).
df_test = pd.DataFrame({'images': images_paths, 'masks': masks_paths})

df_test.head(2)

Unnamed: 0,images,masks
0,/home/hyj/ChanHyung/Image_segementation/Forest...,/home/hyj/ChanHyung/Image_segementation/Forest...
1,/home/hyj/ChanHyung/Image_segementation/Forest...,/home/hyj/ChanHyung/Image_segementation/Forest...


In [9]:
class CustomDataset(Dataset):
    """Dataset of multi-band TIFF images with optional segmentation masks.

    Column 0 of ``dataframe`` holds image paths and column 1 (when present)
    holds mask paths. Each image is read as float32, indexed as
    (height, width, band) with band indices up to 9, and turned into a
    10-channel tensor: bands 0-6, 8 and 9 plus a derived AFI channel
    (band 6 divided by band 1).

    :param dataframe: frame of file paths, indexed positionally
    :param train_mode: when True a mask is always returned with the image
    :param transform: stored on the instance but not applied by this class
    """

    def __init__(self, dataframe, train_mode=True, transform=None):
        self.dataframe = dataframe
        self.train_mode = train_mode
        self.transform = transform

        # Divisor used to scale pixel values (the 16-bit unsigned maximum).
        self.MAX_PIXEL_VALUE = 65535

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` or just ``image`` for row ``idx``.

        A mask is returned when ``train_mode`` is set or when the frame has
        exactly two columns; the image tensor is channel-first.
        """
        bands = tiff.imread(self.dataframe.iloc[idx, 0]).astype('float32')

        # AFI channel: band 6 over band 1, with the denominator clamped to
        # at least 1 so that division by zero cannot occur.
        afi = np.expand_dims(bands[:, :, 6] / np.maximum(bands[:, :, 1], 1), axis=2)

        # Keep bands 0-6, 8 and 9 (band 7 is dropped) and append AFI last.
        kept_bands = bands[:, :, [0,1,2,3,4,5,6,8,9]]
        stacked = np.concatenate((kept_bands, afi), axis=2)

        # HWC -> CHW, then scale every channel (AFI included) by the max pixel value.
        features = torch.from_numpy(stacked).permute(2, 0, 1) / self.MAX_PIXEL_VALUE

        wants_mask = self.train_mode or len(self.dataframe.columns) == 2
        if not wants_mask:
            return features

        mask = tiff.imread(self.dataframe.iloc[idx, 1]).astype('float32')
        # Add a leading channel dimension to the mask.
        return features, torch.from_numpy(mask).unsqueeze(0)

# Minimal torchvision pipeline (ToTensor only).
# NOTE(review): this object is not passed to any CustomDataset in the visible
# notebook, and CustomDataset does not apply its `transform` argument.
transform = transforms.Compose([
    transforms.ToTensor(),
])


In [10]:
# Datasets over the train / held-out frames. df_test has two columns, so the
# validation dataset still yields (image, mask) pairs even with train_mode=False.
train_dataset = CustomDataset(df_train)
valid_dataset = CustomDataset(df_test, train_mode=False)
# test_dataset = CustomDataset(test, train_mode=False)

# Only the training loader shuffles; validation keeps file order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [11]:
import torch
import torch.nn as nn


class InceptionConv(nn.Module):
    """Inception-style residual block.

    Four parallel convolution branches see the same input; their outputs
    are concatenated along the channel axis, fused by a 1x1 convolution,
    added to a 1x1-projected copy of the input, and passed through ReLU.
    Every convolution preserves the spatial size.

    :param in_channels: number of channels of the input tensor
    :param out_channels: number of channels of the returned tensor
    :param mid_channels: width of the intermediate layers; defaults to
        ``out_channels`` when omitted (or falsy)
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        # Branch 1: two stacked 3x3 convolutions.
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Branch 2: two stacked 5x5 convolutions.
        self.double_conv2 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Branch 3: a single 1x1 convolution; note it emits `mid_channels` maps.
        self.double_conv3 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )

        # Branch 4: a 3x3 convolution followed by a 1x1 convolution.
        self.double_conv4 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Branches 1, 2 and 4 emit `out_channels` maps while branch 3 emits
        # `mid_channels`, so the fusion conv takes 3 * out + mid inputs. This
        # equals 4 * out_channels in the default configuration and keeps the
        # block usable when a different `mid_channels` is requested.
        fused_channels = 3 * out_channels + mid_channels
        self.conv_Inc = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # 1x1 projection of the input so it can be added to the fused branches.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        """Run all branches on ``x`` and return the fused residual activation."""
        outputs = [self.double_conv1(x), self.double_conv2(x), self.double_conv3(x), self.double_conv4(x)]
        output2 = self.conv_Inc(torch.cat(outputs, 1))
        xx = output2 + self.conv_skip(x)
        xx_o = self.relu(xx)
        return xx_o


class UpSample(nn.Module):
    """Doubles the spatial resolution of a feature map.

    :param in_channels: input channels (used only by ``conv_transpose``)
    :param out_channels: output channels (used only by ``conv_transpose``)
    :param up_sample_mode: ``'conv_transpose'`` for a learned 2x2 stride-2
        transposed convolution, or ``'bilinear'`` for parameter-free
        bilinear interpolation, which leaves the channel count unchanged
    :raises ValueError: for any other ``up_sample_mode``
    """

    def __init__(self, in_channels, out_channels, up_sample_mode):
        super().__init__()

        # Reject unknown modes up front, then pick one of the two layers.
        if up_sample_mode not in ('conv_transpose', 'bilinear'):
            raise ValueError("Unsupported `up_sample_mode` (can take one of `conv_transpose` or `bilinear`)")

        if up_sample_mode == 'bilinear':
            layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.up_sample = layer

    def forward(self, down_input):
        """Return ``down_input`` upsampled by a factor of two."""
        return self.up_sample(down_input)


class UnetGridGatingSignal(nn.Module):
    """1x1 convolution + batch norm + ReLU that produces a gating signal.

    :param input_dim: channels of the incoming feature map
    :param output_dim: channels of the produced gating signal
    :param stride: stride of the 1x1 convolution
    :param padding: padding of the 1x1 convolution
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super().__init__()

        projection = nn.Conv2d(input_dim, output_dim, kernel_size=1, stride=stride, padding=padding)
        self.conv_block = nn.Sequential(projection, nn.BatchNorm2d(output_dim), nn.ReLU())

    def forward(self, x):
        """Return the gating signal computed from ``x``."""
        return self.conv_block(x)


class AttentionBlock(nn.Module):
    """Attention gate that re-weights an encoder skip connection.

    The skip connection is brought to the gate's resolution by a stride-2
    1x1 convolution, summed with the projected gate, and squashed to a
    single-channel sigmoid map. That map is upsampled by two and multiplied
    into the skip connection.
    """

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of feature maps (channels) in the gating tensor
        :param F_l: number of feature maps in the encoder skip connection
        :param n_coefficients: number of channels of the intermediate attention features
        """
        super().__init__()

        def conv_bn(in_ch, out_ch, stride=1):
            # 1x1 convolution followed by batch normalisation.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0),
                nn.BatchNorm2d(out_ch),
            ]

        # Layers are created in a fixed order so parameter initialisation
        # consumes the random stream identically on every construction.
        self.W_gate = nn.Sequential(*conv_bn(F_g, n_coefficients))
        self.W_x = nn.Sequential(*conv_bn(F_l, n_coefficients, stride=2))
        self.psi = nn.Sequential(*conv_bn(n_coefficients, 1), nn.Sigmoid())
        self.UpSampling2D = nn.Upsample(scale_factor=2)

        # Not used by forward(); kept so checkpoints keep the same parameters.
        self.conv = nn.Sequential(*conv_bn(n_coefficients, n_coefficients))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal, at half the spatial size of ``skip_connection``
        :param skip_connection: activation from the corresponding encoder layer
        :return: ``skip_connection`` scaled element-wise by the attention map
        """
        gate_features = self.W_gate(gate)
        skip_features = self.W_x(skip_connection)

        attention = self.psi(self.relu(gate_features + skip_features))
        attention = self.UpSampling2D(attention)

        # Broadcast the single-channel map over every skip channel.
        return skip_connection * attention.expand_as(skip_connection)


class base_Unet(nn.Module):
    """Six-level attention U-Net built from InceptionConv blocks.

    The encoder applies six InceptionConv blocks with 2x2 max-pooling
    between them. Each of the five decoder stages bilinearly upsamples the
    coarser features, gates the matching encoder output with an
    AttentionBlock, concatenates both and fuses them with an InceptionConv.
    A final 1x1 convolution produces the output, to which no activation is
    applied.

    :param img_ch: number of input channels
    :param output_ch: number of output channels
    :param filters: channel widths of the six encoder levels
    """

    def __init__(self, img_ch=10, output_ch=1, filters=[16, 32, 64, 128, 256, 512]):
        super().__init__()
        width = filters

        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=None)
        # Defined but not applied in forward().
        self.Dropout = nn.Dropout(p=0.1, inplace=True)

        # Encoder, shallow to deep.
        self.conv_1 = InceptionConv(img_ch, width[0])
        self.conv_2 = InceptionConv(width[0], width[1])
        self.conv_3 = InceptionConv(width[1], width[2])
        self.conv_4 = InceptionConv(width[2], width[3])
        self.conv_5 = InceptionConv(width[3], width[4])
        self.conv_6 = InceptionConv(width[4], width[5])

        # Defined but not applied in forward().
        self.gating = UnetGridGatingSignal(width[5], width[4], 1, 0)

        # Decoder stage 1: level 6 -> level 5.
        self.upsample1 = UpSample(width[5], width[5], up_sample_mode='bilinear')
        self.Att1 = AttentionBlock(F_g=width[5], F_l=width[4], n_coefficients=width[4])
        self.up_conv1 = InceptionConv(width[5] + width[4], width[4])

        # Decoder stage 2: level 5 -> level 4.
        self.upsample2 = UpSample(width[4], width[4], up_sample_mode='bilinear')
        self.Att2 = AttentionBlock(F_g=width[4], F_l=width[3], n_coefficients=width[3])
        self.up_conv2 = InceptionConv(width[4] + width[3], width[3])

        # Decoder stage 3: level 4 -> level 3.
        self.upsample3 = UpSample(width[3], width[3], up_sample_mode='bilinear')
        self.Att3 = AttentionBlock(F_g=width[3], F_l=width[2], n_coefficients=width[2])
        self.up_conv3 = InceptionConv(width[3] + width[2], width[2])

        # Decoder stage 4: level 3 -> level 2.
        self.upsample4 = UpSample(width[2], width[2], up_sample_mode='bilinear')
        self.Att4 = AttentionBlock(F_g=width[2], F_l=width[1], n_coefficients=width[1])
        self.up_conv4 = InceptionConv(width[2] + width[1], width[1])

        # Decoder stage 5: level 2 -> level 1.
        self.upsample5 = UpSample(width[1], width[1], up_sample_mode='bilinear')
        self.Att5 = AttentionBlock(F_g=width[1], F_l=width[0], n_coefficients=width[0])
        self.up_conv5 = InceptionConv(width[1] + width[0], width[0])

        self.output_layer = nn.Conv2d(width[0], output_ch, 1, 1)

    def forward(self, x):
        """Return the output map for the input batch ``x``.

        The input is pooled five times, so its height and width are expected
        to be divisible by 32 for the skip connections to line up.
        """
        # Encode: remember each level's pre-pooling activation as a skip.
        encoder_blocks = (self.conv_1, self.conv_2, self.conv_3, self.conv_4, self.conv_5)
        skips = []
        features = x
        for block in encoder_blocks:
            features = block(features)
            skips.append(features)
            features = self.MaxPool(features)
        features = self.conv_6(features)

        # Decode: walk the skips deepest-first.
        decoder_stages = (
            (self.upsample1, self.Att1, self.up_conv1),
            (self.upsample2, self.Att2, self.up_conv2),
            (self.upsample3, self.Att3, self.up_conv3),
            (self.upsample4, self.Att4, self.up_conv4),
            (self.upsample5, self.Att5, self.up_conv5),
        )
        for (upsample, attention, fuse), skip in zip(decoder_stages, reversed(skips)):
            upsampled = upsample(features)
            gated_skip = attention(features, skip)
            features = fuse(torch.cat((gated_skip, upsampled), dim=1))

        return self.output_layer(features)

In [12]:
# Build the network with its default arguments (10 input channels, 1 output channel)
# and move it to the selected device.
model = base_Unet().to(device)

In [13]:
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors
    (Alg. 1 in the Lovasz-Softmax paper).

    :param gt_sorted: 1-D tensor of 0/1 ground-truth labels, ordered by
        decreasing error
    :return: 1-D tensor of the same length holding the per-position weights
    """
    total_positive = gt_sorted.sum()
    intersection = total_positive - gt_sorted.float().cumsum(0)
    union = total_positive + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1 - intersection / union

    # With zero or one pixel there are no consecutive differences to take.
    if len(gt_sorted) <= 1:
        return jaccard

    # Keep the first entry and replace the rest with first-order differences.
    return torch.cat((jaccard[:1], jaccard[1:] - jaccard[:-1]))


def hinge(pred, label):
    """Return the hinge errors ``1 - pred * sign``, where 0/1 labels map to -1/+1."""
    return 1 - pred * (2 * label - 1)


def lovasz_hinge_flat(logits, labels, ignore_index):
    """
    Binary Lovasz hinge loss over flattened predictions.

    :param logits: tensor of raw scores (any shape; flattened internally)
    :param labels: tensor of binary ground-truth labels (0 or 1), same number
        of elements as ``logits``
    :param ignore_index: label value to exclude from the loss, or None
    :return: scalar loss tensor
    """
    flat_logits = logits.contiguous().view(-1)
    flat_labels = labels.contiguous().view(-1)

    # Drop positions whose label equals the ignore value.
    if ignore_index is not None:
        keep = flat_labels != ignore_index
        flat_logits, flat_labels = flat_logits[keep], flat_labels[keep]

    # Order the hinge errors from largest to smallest and weight them with
    # the Lovasz gradient of the labels in that same order.
    sorted_errors, order = torch.sort(hinge(flat_logits, flat_labels), dim=0, descending=True)
    weights = lovasz_grad(flat_labels[order.data])

    # elu(x) + 1 is used as the surrogate in place of a plain ReLU.
    return torch.dot(F.elu(sorted_errors) + 1, weights)
    
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss.

    The flattening and ignore-label filtering are delegated to
    ``lovasz_hinge_flat``, which performs both itself.

    :param logits: [B, H, W] tensor of raw scores at each pixel
    :param labels: [B, H, W] tensor of binary ground-truth masks (0 or 1)
    :param per_image: average the loss computed separately for each image
        instead of computing one loss over the whole batch
    :param ignore: void class id to exclude, or None
    :return: scalar loss tensor
    """
    if not per_image:
        return lovasz_hinge_flat(logits, labels, ignore)

    per_image_losses = [
        lovasz_hinge_flat(log.unsqueeze(0), lab.unsqueeze(0), ignore)
        for log, lab in zip(logits, labels)
    ]
    if not per_image_losses:
        # Empty batch: return a zero that stays attached to `logits`.
        return logits.sum() * 0.0
    return torch.stack(per_image_losses).mean()

class LovaszLoss(nn.Module):
    """
    Module wrapper around the binary Lovasz hinge loss.

    The loss is computed over all elements of the batch at once (logits and
    labels are flattened by ``lovasz_hinge_flat``).

    :param ignore_index: label value to exclude from the loss, or None
    """

    def __init__(self, ignore_index=None):
        super(LovaszLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        """Return the scalar loss for raw ``logits`` against 0/1 ``labels``."""
        loss = lovasz_hinge_flat(logits, labels, self.ignore_index)
        return loss

In [14]:
def train(model, epoch, dataloader, optimizer, criterion, f1_score_metric, miou_metric):
    """Run one training epoch and return its averaged metrics.

    Both metric objects are reset on entry and on exit, so the reported
    values cover exactly this epoch's training batches and no state leaks
    into whatever uses the metrics next.

    :param model: network to optimise (switched to train mode)
    :param epoch: 1-based epoch number, used only for the progress-bar label
    :param dataloader: iterable yielding ``(images, labels)`` batches
    :param optimizer: optimiser stepping ``model``'s parameters
    :param criterion: loss called as ``criterion(output, labels)``
    :param f1_score_metric: torchmetrics-style F1 metric (updated in place)
    :param miou_metric: torchmetrics-style IoU metric (updated in place)
    :return: dict with ``train_loss`` (float), ``train_f1`` and ``train_miou``
    """
    model.train()
    train_loss = []

    # Start from a clean slate: without this, statistics accumulated by a
    # previous validation pass would be mixed into this epoch's train F1.
    f1_score_metric.reset()
    miou_metric.reset()

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs} Training")
    for imgs, labels in progress_bar:
        imgs, labels = imgs.float().to(device), labels.to(device)
        optimizer.zero_grad()

        output = model(imgs)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_loss.append(batch_loss)

        # Compute F1 score and mIoU on detached predictions so the metrics
        # never hold on to the autograd graph.
        predictions = output.detach()
        f1_score = f1_score_metric(predictions, labels)
        miou = miou_metric(predictions, labels)

        # Update the progress bar with loss, F1 score and mIoU.
        progress_bar.set_postfix(loss=batch_loss, F1_score=f1_score.item(), mIoU=miou.item())

    # Average the training metrics over the epoch.
    avg_train_loss = sum(train_loss) / len(train_loss)
    train_f1 = f1_score_metric.compute()
    train_miou = miou_metric.compute()

    f1_score_metric.reset()
    miou_metric.reset()
    torch.cuda.empty_cache()

    return {"train_loss": avg_train_loss, "train_f1":train_f1, "train_miou":train_miou}

In [15]:
def validation(model, dataloader, criterion, f1_score_metric, miou_metric):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Both metric objects are reset on entry and on exit so no validation
    statistics leak into the next training epoch.

    :param model: network to evaluate (switched to eval mode)
    :param dataloader: iterable yielding ``(images, labels)`` batches
    :param criterion: loss called as ``criterion(output, labels)``
    :param f1_score_metric: torchmetrics-style F1 metric (updated in place)
    :param miou_metric: torchmetrics-style IoU metric (updated in place)
    :return: dict with ``saved_outputs`` and ``saved_labels`` (all batches
        concatenated along dim 0, still on the compute device),
        ``val_loss`` (float), ``val_f1`` and ``val_miou``
    """
    model.eval()
    val_loss = []
    saved_outputs, saved_labels = [], []

    f1_score_metric.reset()
    miou_metric.reset()

    progress_bar = tqdm(dataloader, desc="Valid")
    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs, labels = imgs.float().to(device), labels.to(device)

            output = model(imgs)
            loss = criterion(output, labels)

            f1_score = f1_score_metric(output, labels)
            miou = miou_metric(output, labels)

            # Keep raw outputs so the caller can sweep decision thresholds.
            saved_outputs.append(output)
            saved_labels.append(labels)

            batch_loss = loss.item()
            progress_bar.set_postfix(loss=batch_loss, F1_score=f1_score.item(), mIoU=miou.item())
            val_loss.append(batch_loss)

    saved_outputs = torch.cat(saved_outputs, dim=0)
    saved_labels = torch.cat(saved_labels, dim=0)
    avg_val_loss = sum(val_loss) / len(val_loss)
    val_f1 = f1_score_metric.compute()
    val_miou = miou_metric.compute()

    # Reset BOTH metrics; leaving F1 populated would contaminate the next
    # consumer of the same metric object.
    f1_score_metric.reset()
    miou_metric.reset()
    torch.cuda.empty_cache()

    return {"saved_outputs":saved_outputs, "saved_labels":saved_labels, "val_loss":avg_val_loss, "val_f1":val_f1, "val_miou":val_miou}

    

In [16]:
def visual(epoch, k, model, train_CFG, val_CFG, best_val_miou, threshold_min=-10, threshold_max=1, interv=0.1):
    """Report epoch metrics, sweep decision thresholds, and checkpoint on improvement.

    The saved validation outputs are binarised at every threshold in
    ``np.arange(threshold_min, threshold_max, interv)`` (thresholds are
    applied to the raw model outputs) and scored with IoU. If the best IoU
    exceeds ``best_val_miou``, the whole model is saved under
    ``./Fold_Model/``. The IoU-vs-threshold curve is plotted.

    :param epoch: 1-based epoch number (for logging and the file name)
    :param k: identifier placed first in the checkpoint file name
    :param model: model object to save on improvement
    :param train_CFG: dict returned by ``train``
    :param val_CFG: dict returned by ``validation``
    :param best_val_miou: best IoU seen so far
    :param threshold_min: first threshold of the sweep
    :param threshold_max: exclusive upper bound of the sweep
    :param interv: step between thresholds
    :return: dict with the (possibly updated) ``best_val_miou``
    """
    # Print the epoch results.
    print(f"Epoch {epoch}/{epochs} - Train Loss: {train_CFG['train_loss']:.4f}, Train F1: {train_CFG['train_f1']:.4f}, Train MIoU: {train_CFG['train_miou']:.4f}, "
            f"Val Loss: {val_CFG['val_loss']:.4f}, Val F1: {val_CFG['val_f1']:.4f}, Val MIoU: {val_CFG['val_miou']:.4f}")

    miou_values = []
    thresholds = np.arange(threshold_min, threshold_max, interv)
    miou_tmp = BinaryJaccardIndex().to(device)

    # Move both tensors to the notebook's selected device instead of
    # hard-coding CUDA, so the sweep also works on the CPU fallback.
    outputs_on_device = val_CFG['saved_outputs'].to(device)
    labels_on_device = val_CFG['saved_labels'].to(device)

    for threshold in tqdm(thresholds):
        # Convert the outputs to a binary mask at this threshold.
        miou_tmp.reset()
        saved_outputs_binary = outputs_on_device > threshold
        miou = miou_tmp(saved_outputs_binary, labels_on_device)
        miou_values.append(miou.item())

    # Find the highest mIoU and the threshold that produced it.
    max_miou = max(miou_values)
    max_threshold = thresholds[miou_values.index(max_miou)]

    # Save the model whenever the best mIoU improves.
    if max_miou > best_val_miou:
        best_val_miou = max_miou
        # torch.save does not create missing directories.
        os.makedirs("./Fold_Model", exist_ok=True)
        torch.save(model, f"./Fold_Model/{k}_{epoch}_{max_miou:.4f}_{max_threshold:.4f}.pth")

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, miou_values, marker='o', linestyle='-', color='b')
    plt.title(f'max_threshold: {max_threshold}, max_MIoU: {max_miou}')
    plt.xlabel('Threshold')
    plt.ylabel('Mean Intersection over Union (MIoU)')
    plt.grid(True)
    plt.show()

    torch.cuda.empty_cache()
    return {"best_val_miou": best_val_miou}

In [17]:
# Lovasz hinge loss on the raw model outputs.
criterion = LovaszLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)  # if unfreeze=True -> 1e-4, 1e-5, so not to ruin good init w
# Cosine annealing with a restart every 10 epochs. eta_min is the floor of
# each cycle and must stay BELOW the base lr (1e-4); the previous value of
# 1e-3 made every cycle raise the learning rate towards 1e-3 instead of
# decaying it.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)

In [18]:
# Prepare torchmetrics objects for computing F1 score and mIoU.
f1_score_metric = BinaryF1Score().to(device)
# IoU metric used during training/validation, with a 0.55 decision threshold.
miou_metric = BinaryJaccardIndex(threshold=0.55).to(device)
# NOTE(review): precision_metric and recall_metric are not used in the visible notebook.
precision_metric = BinaryPrecision().to(device)
recall_metric = BinaryRecall().to(device)
# Scratch IoU metric; the training loop re-creates it each epoch for the threshold sweep.
miou_tmp = BinaryJaccardIndex().to(device)
model.to(device)

# Baseline a checkpoint must beat before it is saved by the training loop.
best_val_miou = 0.8158
best_val_loss = 999999999999
best_model = None

In [None]:
# Main training loop: per epoch, train, validate, sweep decision thresholds
# on the saved validation outputs, checkpoint on improvement, and plot the
# IoU-vs-threshold curve.
best_val_miou = 0.8158       # a checkpoint is only saved when this is exceeded
best_val_loss = 999999999999
best_model = None
for epoch in range(1, epochs + 1):
    model.train()
    train_loss = []

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} Training")
    for imgs, labels in progress_bar:
        imgs, labels = imgs.float().to(device), labels.to(device)
        optimizer.zero_grad()

        output = model(imgs)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

        # Compute F1 score and mIoU on detached predictions so the metrics
        # never hold on to the autograd graph.
        f1_score = f1_score_metric(output.detach(), labels)
        miou = miou_metric(output.detach(), labels)

        # Update the progress bar with loss, F1 score and mIoU.
        progress_bar.set_postfix(loss=loss.item(), F1_score=f1_score.item(), mIoU=miou.item())

    # Average the training metrics over the epoch.
    avg_train_loss = sum(train_loss) / len(train_loss)
    train_f1 = f1_score_metric.compute()
    train_miou = miou_metric.compute()


    model.eval()
    val_loss = []
    saved_outputs = []
    saved_labels = []
    # Clear the training statistics before accumulating validation ones.
    f1_score_metric.reset()
    miou_metric.reset()
    progress_bar = tqdm(valid_dataloader, desc="Valid")

    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs, labels = imgs.float().to(device), labels.to(device)

            output = model(imgs)
            loss = criterion(output, labels)


            # Compute F1 score and mIoU (validation phase).
            f1_score = f1_score_metric(output, labels)
            miou = miou_metric(output, labels)

            # Keep raw outputs and labels for the threshold sweep below.
            saved_outputs.append(output)
            saved_labels.append(labels)

            # Update the progress bar with validation loss, F1 score and mIoU.
            progress_bar.set_postfix(loss=loss.item(), F1_score=f1_score.item(), mIoU=miou.item())
            val_loss.append(loss.item())

    saved_outputs = torch.cat(saved_outputs, dim=0)
    saved_labels = torch.cat(saved_labels, dim=0)
    avg_val_loss = sum(val_loss) / len(val_loss)
    val_f1 = f1_score_metric.compute()
    val_miou = miou_metric.compute()

    # Print the epoch results.
    print(f"Epoch {epoch}/{epochs} - Train Loss: {avg_train_loss:.4f}, Train F1: {train_f1:.4f}, Train MIoU: {train_miou:.4f}, "
            f"Val Loss: {avg_val_loss:.4f}, Val F1: {val_f1:.4f}, Val MIoU: {val_miou:.4f}")

    scheduler.step()


    # Leave both metrics empty for the next epoch's training pass.
    f1_score_metric.reset()
    miou_metric.reset()

    miou_values = []
    thresholds = np.arange(-10, 1, 0.1)
    miou_tmp = BinaryJaccardIndex().to(device)

    for threshold in tqdm(thresholds):
        # Convert the outputs to a binary mask at this threshold.
        miou_tmp.reset()
        saved_outputs_binary = saved_outputs > threshold
        # Use the notebook's selected device rather than hard-coding CUDA,
        # so the sweep also works on the CPU fallback.
        miou = miou_tmp(saved_outputs_binary.to(device), saved_labels)
        miou_values.append(miou.item())

    # Find the highest mIoU and the threshold that produced it.
    max_miou = max(miou_values)
    max_threshold = thresholds[miou_values.index(max_miou)]
    # Save the model whenever the best mIoU improves.
    if max_miou > best_val_miou:
        best_val_miou = max_miou
        # torch.save does not create missing directories.
        os.makedirs("./model_save", exist_ok=True)
        torch.save(model, f"./model_save/{epoch}_jihwan1_{max_miou:.4f}_{max_threshold:.4f}.pth")

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, miou_values, marker='o', linestyle='-', color='b')
    plt.title(f'max_threshold: {max_threshold}, max_MIoU: {max_miou}')
    plt.xlabel('Threshold')
    plt.ylabel('Mean Intersection over Union (MIoU)')
    plt.grid(True)
    plt.show()

    torch.cuda.empty_cache()



