In [None]:
# Mount Google Drive (Colab only) at /content/drive; every dataset path below lives
# under /content/drive/MyDrive. If Drive is already mounted, the call only reports
# that (see the cell output) instead of remounting.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install / upgrade segmentation-models-pytorch (and albumentations) quietly; all pip
# output is discarded by the redirect.
# NOTE(review): `-U` with no version pins means a re-run may pull different versions;
# consider pinning and using `%pip` so the install targets the running kernel.
# albumentations is not used in the visible cells.
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

In [None]:
from torchsummary import summary
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import cv2
from tqdm import tqdm

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_type)
device

device(type='cuda')

In [None]:
# UNet++ with a ResNet-101 encoder and one output channel per class (6 classes).
# activation=None: the network returns raw logits. The training loss,
# smp.losses.DiceLoss(mode='multiclass'), defaults to from_logits=True and applies its
# own softmax, so a built-in 'softmax' head would apply softmax twice and flatten the
# loss/gradients. The argmax-based IoU metric is unaffected by this choice; apply
# torch.softmax(logits, dim=1) at inference time if class probabilities are needed.
model = smp.UnetPlusPlus(encoder_name='resnet101', classes=6, activation=None).to(device)

In [None]:
# print(model)

In [None]:
# Seed PyTorch (all devices) and NumPy's legacy global RNG, in that order.
# NOTE(review): this cell runs after the model was constructed, so any random weight
# initialisation done at construction time is not covered by this seed; move it
# before the model cell if reproducible initialisation matters.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
# Google Drive is mounted at /content/drive; all image folders live under MyDrive.
_DRIVE_ROOT = '/content/drive/MyDrive'

# Image directories for the training split, one per emotion.
train_anger_dir = f'{_DRIVE_ROOT}/train/anger'
train_happy_dir = f'{_DRIVE_ROOT}/train/happy'
train_panic_dir = f'{_DRIVE_ROOT}/train/panic'
train_sadness_dir = f'{_DRIVE_ROOT}/train/sadness'

# Image directories for the validation split, one per emotion.
val_anger_dir = f'{_DRIVE_ROOT}/val/anger'
val_happy_dir = f'{_DRIVE_ROOT}/val/happy'
val_panic_dir = f'{_DRIVE_ROOT}/val/panic'
val_sadness_dir = f'{_DRIVE_ROOT}/val/sadness'

In [None]:
# Mask archives (.npz), one per split and emotion. np.load on an .npz returns an
# NpzFile that keeps the archive open and reads each array only when its key is
# accessed, so masks are loaded lazily per sample.
_SEG_ROOT = '/content/drive/MyDrive/segmentation'


def _open_mask_archive(split, emotion):
    """Open <_SEG_ROOT>/<split>/<split>_<emotion>.npz with np.load."""
    return np.load(f'{_SEG_ROOT}/{split}/{split}_{emotion}.npz')


anger_npz_file = _open_mask_archive('train', 'anger')
happy_npz_file = _open_mask_archive('train', 'happy')
panic_npz_file = _open_mask_archive('train', 'panic')
sadness_npz_file = _open_mask_archive('train', 'sadness')
val_anger_npz_file = _open_mask_archive('val', 'anger')
val_happy_npz_file = _open_mask_archive('val', 'happy')
val_panic_npz_file = _open_mask_archive('val', 'panic')
val_sadness_npz_file = _open_mask_archive('val', 'sadness')

In [None]:
class CustomSegmentationDataset(Dataset):
    """Pairs every image in a directory with its mask from a mask archive.

    The image's file name is used as the key into ``mask_npz_file`` to look up the
    corresponding mask array.

    Args:
        image_dir: directory whose entries are all treated as image files.
        mask_npz_file: mapping from image file name to mask array (in this notebook,
            the NpzFile returned by ``np.load`` on an ``.npz`` archive).
        transform: callable applied to the RGB ``PIL.Image``. If None, the PIL image
            is returned as-is (the default DataLoader collate cannot batch PIL images).
        mask_transform: callable applied to the mask ``PIL.Image``. If None, the mask
            is converted with ``torch.from_numpy(np.array(mask))``.
        read_attempts: number of times ``cv2.imread`` is tried before giving up.
    """

    def __init__(self, image_dir, mask_npz_file, transform=None, mask_transform=None,
                 read_attempts=3):
        self.image_dir = image_dir
        self.mask_npz_file = mask_npz_file
        self.transform = transform
        self.mask_transform = mask_transform
        self.read_attempts = max(1, int(read_attempts))
        # os.listdir order is arbitrary; sort so sample indices (and the unshuffled
        # validation order) are reproducible across runs and machines.
        self.image_files = sorted(os.listdir(self.image_dir))

    def __len__(self):
        return len(self.image_files)

    def _read_rgb(self, image_path):
        """Read ``image_path`` with OpenCV and return it as an RGB ndarray.

        cv2.imread signals failure by returning None rather than raising, which
        previously surfaced as an opaque ``!_src.empty()`` assertion inside cvtColor.
        The training log shows that failure in epoch 9 after eight full epochs had
        read every file, which suggests a transient read error (e.g. on the Drive
        mount). So retry with a short back-off, then raise an error naming the file.

        Raises:
            OSError: if the image still cannot be decoded after ``read_attempts`` tries.
        """
        import time  # only needed on the retry path

        for attempt in range(self.read_attempts):
            image = cv2.imread(image_path)
            if image is not None:
                return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if attempt + 1 < self.read_attempts:
                time.sleep(0.5 * (attempt + 1))
        raise OSError(
            f"cv2.imread could not read {image_path!r} "
            f"after {self.read_attempts} attempt(s)"
        )

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx`` after applying the transforms.

        Raises:
            OSError: if the image file cannot be read.
            KeyError: if the mask archive has no entry for the image's file name.
        """
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)
        image = Image.fromarray(self._read_rgb(image_path))

        try:
            mask = self.mask_npz_file[image_file]
        except KeyError:
            raise KeyError(
                f"no mask named {image_file!r} in the mask archive for {self.image_dir!r}"
            ) from None

        # Convert the mask to a PIL image so it can go through PIL-based transforms.
        mask = Image.fromarray(mask)

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        else:
            mask = torch.from_numpy(np.array(mask))

        # Drop a leading size-1 channel dimension if the mask transform produced one;
        # squeeze(0) is a no-op otherwise.
        return image, mask.squeeze(0)

In [None]:
# Target spatial size passed to transforms.Resize for both images and masks.
resize_size = (512, 512)

# Image pipeline: resize (bilinear interpolation is appropriate for RGB pixels),
# then convert the PIL image to a float tensor with ToTensor.
transform = transforms.Compose([
    transforms.Resize(resize_size),  # resize the image
    transforms.ToTensor(),           # convert the image to a tensor
])

# Mask pipeline: masks hold integer class ids, so they must be resized with
# nearest-neighbour interpolation. Resize's default (bilinear) blends neighbouring ids
# at region boundaries and invents labels that were never annotated (e.g. a pixel
# between classes 0 and 4 can become 2).
mask_transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.NEAREST),
    # PIL mask -> integer tensor (presumably (H, W) for single-channel masks).
    transforms.Lambda(lambda x: torch.tensor(np.array(x))),
])


In [None]:
# Create one dataset per (split, emotion); all share the image and mask transforms.
def _make_segmentation_dataset(image_dir, npz_archive):
    """Build a CustomSegmentationDataset using the notebook's shared transforms."""
    return CustomSegmentationDataset(
        image_dir, npz_archive, transform=transform, mask_transform=mask_transform
    )


# Training datasets, one per emotion.
train_anger_dataset = _make_segmentation_dataset(train_anger_dir, anger_npz_file)
train_happy_dataset = _make_segmentation_dataset(train_happy_dir, happy_npz_file)
train_panic_dataset = _make_segmentation_dataset(train_panic_dir, panic_npz_file)
train_sadness_dataset = _make_segmentation_dataset(train_sadness_dir, sadness_npz_file)

# Validation datasets, one per emotion.
val_anger_dataset = _make_segmentation_dataset(val_anger_dir, val_anger_npz_file)
val_happy_dataset = _make_segmentation_dataset(val_happy_dir, val_happy_npz_file)
val_panic_dataset = _make_segmentation_dataset(val_panic_dir, val_panic_npz_file)
val_sadness_dataset = _make_segmentation_dataset(val_sadness_dir, val_sadness_npz_file)

batch_size = 4

# Merge the per-emotion datasets into one training split and one validation split.
train_datasets = ConcatDataset(
    [train_anger_dataset, train_happy_dataset, train_panic_dataset, train_sadness_dataset]
)
val_datasets = ConcatDataset(
    [val_anger_dataset, val_happy_dataset, val_panic_dataset, val_sadness_dataset]
)

# NOTE(review): each mask archive is an NpzFile holding one open file handle, and
# worker processes forked by the DataLoader inherit it. Concurrent reads through a
# shared handle may race; if corrupted masks or zip CRC errors appear, reopen the
# archives per worker (e.g. via worker_init_fn).
_common_loader_kwargs = dict(batch_size=batch_size, pin_memory=True)
train_loader = DataLoader(train_datasets, shuffle=True, num_workers=4,
                          **_common_loader_kwargs)
val_loader = DataLoader(val_datasets, shuffle=False, num_workers=1,
                        **_common_loader_kwargs)




In [None]:
def compute_iou(preds, labels, num_classes):
    """Mean intersection-over-union across classes for one batch.

    Args:
        preds: score tensor with classes along dim 1 (logits or probabilities);
            each pixel's predicted class is the argmax over that dim.
        labels: integer class-id tensor, expected to match (or broadcast with)
            the argmax of ``preds``.
        num_classes: classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        ``np.nanmean`` of the per-class IoUs. A class absent from both the
        prediction and the labels scores NaN and is skipped in the mean.
    """
    predicted = preds.argmax(dim=1)

    def _class_iou(c):
        hit_pred = predicted == c
        hit_true = labels == c
        overlap = torch.logical_and(hit_pred, hit_true).sum().float().item()
        combined = torch.logical_or(hit_pred, hit_true).sum().float().item()
        return overlap / combined if combined else float('nan')

    return np.nanmean([_class_iou(c) for c in range(num_classes)])


In [None]:
num_epochs = 20
num_classes = 6  # must match `classes` passed to the model
# Optional: add nn.CrossEntropyLoss() and train on (ce + dice) instead of dice alone.
# smp's multiclass DiceLoss takes class scores with classes on dim 1 and integer targets.
criterion_dice = smp.losses.DiceLoss(mode='multiclass')
optimizer = optim.Adam(model.parameters(), lr=0.001)

# NOTE(review): 'saved_models' is relative to the working directory; on Colab that is
# typically the VM's local disk rather than the mounted Drive, so checkpoints may be
# lost on disconnect — consider a path under /content/drive.
save_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)

train_losses = []
val_losses = []
train_ious = []
val_ious = []
best_val_loss = float('inf')


def _run_epoch(loader, train, desc):
    """Run one pass over ``loader`` and return ``(mean loss, mean IoU)`` per sample.

    With ``train=True`` the model is in train mode, gradients are enabled and the
    optimizer steps after every batch. With ``train=False`` the model is in eval mode
    and gradients are disabled. Batch values are weighted by batch size so a short
    final batch counts proportionally.
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    running_loss = 0.0
    running_iou = 0.0

    with torch.set_grad_enabled(train):
        for inputs, labels in tqdm(loader, desc=desc):
            inputs = inputs.to(device)
            # Targets must be integer (long) class ids for the multiclass Dice loss.
            labels = labels.long().to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion_dice(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            batch = inputs.size(0)
            running_loss += loss.item() * batch
            running_iou += compute_iou(outputs, labels, num_classes) * batch

    num_samples = len(loader.dataset)
    return running_loss / num_samples, running_iou / num_samples


# Model training
for epoch in range(num_epochs):
    train_epoch_loss, train_epoch_iou = _run_epoch(
        train_loader, train=True, desc=f'Epoch {epoch+1}/{num_epochs}')
    val_epoch_loss, val_epoch_iou = _run_epoch(
        val_loader, train=False, desc='Validation')
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_epoch_loss:.4f}, Train IoU: {train_epoch_iou:.4f}, Validation Loss: {val_epoch_loss:.4f}, Validation IoU: {val_epoch_iou:.4f}")

    # Per-epoch checkpoint so any epoch can be reloaded later.
    model_name = f'Unet++_model_epoch{epoch+1}.pth'
    model_path = os.path.join(save_dir, model_name)
    torch.save(model.state_dict(), model_path)

    # Track the lowest validation loss and keep a separate copy of that checkpoint.
    if val_epoch_loss < best_val_loss:
        best_val_loss = val_epoch_loss
        torch.save(model.state_dict(), os.path.join(save_dir, 'Unet++_model_best.pth'))

    train_losses.append(train_epoch_loss)
    val_losses.append(val_epoch_loss)
    train_ious.append(train_epoch_iou)
    val_ious.append(val_epoch_iou)

  return self._call_impl(*args, **kwargs)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 1/20: 100%|██████████| 1487/1487 [33:24<00:00,  1.35s/it]
Validation: 100%|██████████| 300/300 [05:35<00:00,  1.12s/it]


Epoch [1/20], Train Loss: 0.7394, Train IoU: 0.6731, Validation Loss: 0.7400, Validation IoU: 0.7093


Epoch 2/20: 100%|██████████| 1487/1487 [33:03<00:00,  1.33s/it]
Validation: 100%|██████████| 300/300 [05:52<00:00,  1.17s/it]


Epoch [2/20], Train Loss: 0.7344, Train IoU: 0.7329, Validation Loss: 0.7375, Validation IoU: 0.7419


Epoch 3/20: 100%|██████████| 1487/1487 [33:10<00:00,  1.34s/it]
Validation: 100%|██████████| 300/300 [08:01<00:00,  1.61s/it]


Epoch [3/20], Train Loss: 0.7341, Train IoU: 0.7398, Validation Loss: 0.7388, Validation IoU: 0.7289


Epoch 4/20: 100%|██████████| 1487/1487 [33:14<00:00,  1.34s/it]
Validation: 100%|██████████| 300/300 [05:34<00:00,  1.11s/it]


Epoch [4/20], Train Loss: 0.7333, Train IoU: 0.7490, Validation Loss: 0.7365, Validation IoU: 0.7497


Epoch 5/20: 100%|██████████| 1487/1487 [32:54<00:00,  1.33s/it]
Validation: 100%|██████████| 300/300 [05:38<00:00,  1.13s/it]


Epoch [5/20], Train Loss: 0.7342, Train IoU: 0.7490, Validation Loss: 0.7359, Validation IoU: 0.7575


Epoch 6/20: 100%|██████████| 1487/1487 [32:55<00:00,  1.33s/it]
Validation: 100%|██████████| 300/300 [05:33<00:00,  1.11s/it]


Epoch [6/20], Train Loss: 0.7318, Train IoU: 0.7558, Validation Loss: 0.7380, Validation IoU: 0.7276


Epoch 7/20: 100%|██████████| 1487/1487 [36:05<00:00,  1.46s/it]
Validation: 100%|██████████| 300/300 [06:32<00:00,  1.31s/it]


Epoch [7/20], Train Loss: 0.7337, Train IoU: 0.7578, Validation Loss: 0.7364, Validation IoU: 0.7394


Epoch 8/20: 100%|██████████| 1487/1487 [32:53<00:00,  1.33s/it]
Validation: 100%|██████████| 300/300 [05:34<00:00,  1.12s/it]


Epoch [8/20], Train Loss: 0.7302, Train IoU: 0.7637, Validation Loss: 0.7364, Validation IoU: 0.7474


Epoch 9/20:  96%|█████████▌| 1422/1487 [31:28<01:26,  1.33s/it]


error: Caught error in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataset.py", line 348, in __getitem__
    return self.datasets[dataset_idx][sample_idx]
  File "<ipython-input-10-3f34aa09d625>", line 16, in __getitem__
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
cv2.error: OpenCV(4.8.0) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'



In [None]:
#  for val_inputs, val_labels in tqdm(val_loader,  desc=f'Validation'):
#             val_inputs = val_inputs.to(device)
#             val_labels = val_labels.long().to(device)