# Image Segmentation

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models
import torchmetrics

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

## Data

In [None]:
# Dataset root plus its three split directories. Everything stays a plain
# string with a trailing slash so paths can be built by concatenation.
data_dir = '../../dataset/'
train_dir, val_dir, unlabel_dir = (
    data_dir + split for split in ('train/', 'val/', 'unlabeled/')
)

In [None]:
# Open frame 21 of train/video_0; the bare name on the last line is the
# cell's display value.
img = Image.open(f"{data_dir}train/video_0/image_21.png")
img

In [None]:
# Load the mask array stored for train/video_0, report its shape, and plot
# the entry at index 21.
masks = np.load(f"{data_dir}train/video_0/mask.npy")
print(masks.shape)
plt.imshow(masks[21])

In [None]:
# Take the first mask stored for train/video_0.
mask = np.load(f"{data_dir}train/video_0/mask.npy")[0]
# Instances are encoded as distinct values. np.unique returns them sorted,
# and the first one is dropped as the background id.
obj_ids = np.unique(mask)[1:]
print(obj_ids)

num_objs = len(obj_ids)

### Dataset

In [None]:
class UnlabeledDataset(Dataset):
    """Dataset of individual RGB frames taken from directories of video frames.

    Every sub-directory of ``root`` is treated as one video that holds frames
    named ``image_0.png`` ... ``image_{frames_per_video - 1}.png``. Items are
    ordered video by video (directories sorted by name), then frame by frame.

    Args:
        root: Directory containing one sub-directory per video.
        transform: Optional callable applied to each loaded PIL image.
        frames_per_video: Number of frames expected in every video directory.

    Raises:
        ValueError: If ``frames_per_video`` is smaller than 1.
    """

    def __init__(self, root='../../dataset/unlabeled/', transform=None,
                 frames_per_video=22):
        if frames_per_video < 1:
            raise ValueError("frames_per_video must be at least 1")
        self.root = root
        self.transform = transform
        self.frames_per_video = frames_per_video
        # Keep directories only: stray files in root (e.g. .DS_Store) would
        # otherwise be counted as videos and break __len__ / __getitem__.
        self.vid_list = sorted(
            name for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )
        self.img_list = [f'image_{i}.png' for i in range(frames_per_video)]

    def __len__(self):
        """Return the total number of frames across all videos."""
        return len(self.vid_list) * self.frames_per_video

    def __getitem__(self, idx):
        """Return the frame at flat index ``idx``.

        The image is converted to RGB; if a transform was given, its output
        is returned instead of the PIL image.
        """
        # Flat index -> (which video, which frame inside that video).
        vid_idx, img_idx = divmod(idx, self.frames_per_video)
        img_path = os.path.join(self.root, self.vid_list[vid_idx], self.img_list[img_idx])
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the with-block exits.
        with Image.open(img_path) as frame:
            img = frame.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img