In [1]:
from libs.config import get_config
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import os

In [2]:
from logging import getLogger


__all__ = ["get_dataset"]
logger = getLogger(__name__)

In [3]:
config = get_config('/mnt/sda1/Summarization/SurgSum/result/model_name=resnet50d/config.yaml')

## Create DataFrames

In [4]:
csv_path = Path(config.dataset_dir) / "csv"
dirs = list(csv_path.glob('*csv'))

In [12]:
all_df = pd.DataFrame()
for i, path in enumerate(tqdm(sorted(dirs))):
    file_name = path.stem
    # if file_name != 'video01':
    #     continue
    tmp = pd.read_csv(path)

    tmp['video_idx'] = int(file_name[-2:])  
    img_path = os.path.join(config.dataset_dir, 'video_split', file_name)
    tmp['file_name'] = sorted(os.listdir(img_path)[:len(tmp)])
    
    if config.val_vid_idx == int(file_name[-2:]):
        tmp['stage'] = 'val'
        factor = int(30 / config.fps_sampling_test)
    else:
        tmp['stage'] = 'train'
        factor = int(30 / config.fps_sampling)
    tmp = tmp.iloc[::factor]
    print(f'{file_name} after subsampling: {len(tmp)}')
    all_df = pd.concat([all_df, tmp], axis=0)


 60%|██████    | 3/5 [00:00<00:00, 10.42it/s]

video00 after subsampling: 4395
video01 after subsampling: 1903
video02 after subsampling: 3912


100%|██████████| 5/5 [00:00<00:00,  8.53it/s]

video03 after subsampling: 4896
video05 after subsampling: 4206





In [13]:
all_df.head()

Unnamed: 0,Frame,time,field,phase,summary,video_idx,file_name,stage
0,0,0:00:00.00,False,irrelevant,0.0,0,video00_000001.png,train
30,30,0:00:01.00,False,irrelevant,-1.0,0,video00_000031.png,train
60,60,0:00:02.00,False,irrelevant,-1.0,0,video00_000061.png,train
90,90,0:00:03.00,False,irrelevant,-1.0,0,video00_000091.png,train
120,120,0:00:04.00,False,irrelevant,-1.0,0,video00_000121.png,train


## dataset

In [54]:
from logging import getLogger
from typing import Any, Dict, Optional

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

__all__ = ["get_dataloader"]

logger = getLogger(__name__)


class ExtractorDataset(torch.utils.data.Dataset):
    def __init__(self, df, config, stage="train"):
        self.stage = stage
        self.df = df
        self.df = self.df[self.df["stage"] == self.stage]
        self.config = config
        self.class_labels = self.get_labels()

    def __getitem__(self, index):
        row = self.df.iloc[index]
        video_name = "video" + str(row.video_idx).zfill(2)
        data_path = os.path.join(self.config.dataset_dir, "video_split", video_name, row.file_name)
        img = Image.open(data_path)
        img = np.array(img)
        if self.stage == 'train':
            img = self.transform()(image=img)["image"]
        label = torch.tensor(self.class_labels[row.phase])

        return img.float(), label.float()
        


    def __len__(self):
        return len(self.df)
    
    def get_labels(self):
        class_labels = {}
        for i,label in enumerate(self.df.phase.unique()):
            class_labels[label] = i
            print(label, i)
        return class_labels

    def transform(self):
        transforms = [
                A.Normalize(mean=(0,0,0), std=(1,1,1)),
        ]
        
        if self.stage == 'train':
            if self.config.aug_ver == 1:
                transforms += [
                A.RandomResizedCrop(always_apply=False, p=1.0, height=self.img_size, width=self.img_size, scale=(0.7, 1.2), ratio=(0.75, 1.3), interpolation=1),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                ]
            elif self.config.aug_ver == 2:
                transforms += [
                    A.HorizontalFlip(p=0.3),
                    A.VerticalFlip(p=0.3),
                ]
            
        transforms.append(ToTensorV2(p=1))

        return A.Compose(transforms)


In [55]:
ds = ExtractorDataset(all_df, config, 'train')

irrelevant 0
design 1
anesthesia 2
incision 3
hemostasis 4
dissection 5
closure 6
others 7


In [56]:
img, label = ds.__getitem__(0)

In [57]:
img.shape

torch.Size([3, 250, 250])

In [58]:
label

tensor(0.)

In [35]:
all_df.phase.unique()

array(['irrelevant', 'design', 'anesthesia', 'incision', 'hemostasis',
       'dissection', 'closure', 'others'], dtype=object)

In [37]:
class_labels = {}
for i,label in enumerate(all_df.phase.unique()):
    class_labels[label] = i
    print(label, i)
class_labels

irrelevant 0
design 1
anesthesia 2
incision 3
hemostasis 4
dissection 5
closure 6
others 7


{'irrelevant': 0,
 'design': 1,
 'anesthesia': 2,
 'incision': 3,
 'hemostasis': 4,
 'dissection': 5,
 'closure': 6,
 'others': 7}

In [None]:
class_labels = {
    'irrelevant':0, 
    'design':1, 
    'anesthesia':2, 
    'incision':3, 
    'hemostasis':4,
    'dissection':5, 
    'closure':6, 
    'others':7
    }