In [120]:
import torch
from torch.utils.data import Dataset
import pandas as pd
from torchvision.datasets import ImageFolder
from tqdm import tqdm, trange
import sklearn


In [127]:
class GeneralDataset(Dataset):
    """Pairwise image dataset built on top of a torchvision ``ImageFolder``.

    Every image found under ``root`` is paired with every other image of the
    same class ("similar", label 0) and with images of the other classes
    ("different", label 1). The different pairs are then down-sampled per class
    so that each class contributes at most as many different pairs as similar
    pairs, and both sets are concatenated (and shuffled) into ``self.dataframe``.

    Args:
        root: directory passed to ``ImageFolder(root=...)``.
        main_transform: optional callable applied to the main image.
        pair_transform: optional callable taking ``(main_image, comp_image)``
            and returning the transformed pair; applied before the
            per-image transforms.
        comp_transform: optional callable applied to the comparison image.
        invert: when true, both images are passed through ``ImageOps.invert``.
        **kwargs: accepted for call-site compatibility and ignored.

    NOTE(review): ``Image`` and ``ImageOps`` are used below but are not imported
    in the visible import cell — presumably ``from PIL import Image, ImageOps``
    lives in another cell; confirm before running on a fresh kernel.
    """

    # Column order shared by the similar- and different-pair frames.
    _PAIR_COLUMNS = ('main_image', 'main_label_idx', 'main_label_name',
                     'comp_image', 'comp_label_idx', 'comp_label_name',
                     'label', 'status')

    def __init__(self, root,
                 main_transform=None, pair_transform=None, comp_transform=None,
                 invert=False, **kwargs):
        # `super(GeneralDataset).__init__()` built an unbound super object and
        # never reached the parent initialiser; the zero-argument form does.
        super().__init__()
        self.root = root
        self.main_transform = main_transform
        self.pair_transform = pair_transform
        self.comp_transform = comp_transform
        self.invert = invert
        self._dataset = ImageFolder(root=self.root)
        self.label_image_dict = self._create_label_image_dict()
        self.dataframe = self._build_pair_set()

    def _build_pair_set(self):
        """Create, balance and combine the similar/different pair frames."""
        simm_df = self._create_similar_pair()
        diff_df = self._create_different_pair()
        diff_df = self._balance_different_pair(simm_df, diff_df)
        return self._combine_simm_diff_pair(simm_df, diff_df)

    def _create_label_image_dict(self):
        """Return ``{class_index: [image_path, ...]}`` with keys in ascending order."""
        label_image_dict = {}
        # sorted() is stable, so paths keep ImageFolder's order within a class.
        for impath, label in sorted(self._dataset.imgs, key=lambda item: item[1]):
            label_image_dict.setdefault(label, []).append(impath)
        return label_image_dict

    def _pair_record(self, main_img, main_key, comp_img, comp_key):
        """Build one dataframe row describing a (main, comparison) image pair."""
        is_different = main_key != comp_key
        return {
            'main_image': main_img,
            'main_label_idx': main_key,
            'main_label_name': self._dataset.classes[main_key],
            'comp_image': comp_img,
            'comp_label_idx': comp_key,
            'comp_label_name': self._dataset.classes[comp_key],
            'label': int(is_different),
            'status': 'different' if is_different else 'similar',
        }

    def _create_similar_pair(self):
        """Return a frame of all ordered same-class pairs (label 0).

        An image is never paired with itself, but both (a, b) and (b, a) are
        emitted.
        """
        records = [
            self._pair_record(main_img, key, comp_img, key)
            for key, paths in self.label_image_dict.items()
            for idx, main_img in enumerate(paths)
            for jdx, comp_img in enumerate(paths)
            if idx != jdx
        ]
        # Explicit columns keep the schema intact even when no pair exists.
        return pd.DataFrame(records, columns=list(self._PAIR_COLUMNS))

    def _create_different_pair(self):
        """Return a frame of all ordered cross-class pairs (label 1)."""
        records = []
        for main_key, main_paths in tqdm(self.label_image_dict.items()):
            for diff_key, diff_paths in self.label_image_dict.items():
                if main_key == diff_key:
                    continue
                records.extend(
                    self._pair_record(main_img, main_key, comp_img, diff_key)
                    for main_img in main_paths
                    for comp_img in diff_paths
                )
        return pd.DataFrame(records, columns=list(self._PAIR_COLUMNS))

    def _balance_different_pair(self, simm_df, diff_df, random_state=1261):
        """Down-sample different pairs so each class matches its similar-pair count.

        For every class, as many different pairs are drawn (without
        replacement) as that class has similar pairs. When a class has fewer
        different pairs than similar pairs, all of its different pairs are
        kept instead of failing; classes with nothing to draw are skipped.

        Returns:
            The concatenation of the per-class samples; each sample has its
            own 0-based index. Empty (same columns as ``diff_df``) when no
            class contributes any row.
        """
        sampled = []
        for label_name in self._dataset.classes:
            n_similar = int((simm_df['main_label_name'] == label_name).sum())
            diff_for_class = diff_df[diff_df['main_label_name'] == label_name]
            # Sampling by count avoids the ZeroDivisionError of the ratio form
            # and the ValueError pandas raises for frac > 1 without replacement.
            n_keep = min(n_similar, len(diff_for_class))
            if n_keep == 0:
                continue
            sampled.append(
                diff_for_class.sample(n=n_keep, random_state=random_state)
                .reset_index(drop=True)
            )

        if not sampled:
            # pd.concat([]) raises; hand back an empty frame with the same schema.
            return diff_df.iloc[0:0]
        return pd.concat(sampled)

    def _combine_simm_diff_pair(self, simm_df, diff_df, shuffle=True, random_state=1261):
        """Concatenate both pair frames and optionally shuffle the rows.

        ``random_state`` seeds the shuffle (it was previously ignored in favour
        of a hard-coded 1261, which is still the default). The returned frame
        always has a fresh 0..n-1 index.
        """
        main_df = pd.concat([simm_df, diff_df]).reset_index(drop=True)
        if shuffle:
            main_df = sklearn.utils.shuffle(main_df, random_state=random_state)
            main_df = main_df.reset_index(drop=True)
        return main_df

    def _load_image(self, path: str, to_rgb=True):
        """Open ``path`` and return it converted to "RGB" (default) or "L"."""
        mode = "RGB" if to_rgb else "L"
        # convert() returns a new, fully loaded image, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(path) as image:
            return image.convert(mode)

    def _preprocess_label(self, label):
        """Return ``label`` as a float32 tensor of shape ``(1,)``."""
        # Built with torch directly: `np` is not imported in this notebook.
        return torch.tensor([float(label)], dtype=torch.float32)

    def __len__(self):
        """Number of pairs in the combined dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(main_image, comp_image, label)`` for the pair at row ``idx``."""
        record = self.dataframe.iloc[idx]
        main_image = self._load_image(record['main_image'])
        comp_image = self._load_image(record['comp_image'])
        label = self._preprocess_label(record['label'])

        if self.invert:
            main_image = ImageOps.invert(main_image)
            comp_image = ImageOps.invert(comp_image)

        # The joint transform runs first so the per-image transforms see its output.
        if self.pair_transform:
            main_image, comp_image = self.pair_transform(main_image, comp_image)

        if self.main_transform:
            main_image = self.main_transform(main_image)

        if self.comp_transform:
            comp_image = self.comp_transform(comp_image)

        return main_image, comp_image, label
        
    

In [128]:
# Build the pair dataset from the training folder with the default options
# (no transforms, no inversion).
root = '../dataset/train/'
dataset = GeneralDataset(root=root)

100%|██████████| 40/40 [00:00<00:00, 840.24it/s]


In [131]:
# Fetch the first pair.
# NOTE(review): the output recorded below is a raw dataframe record, while the
# class as written returns (main_image, comp_image, label) — the output looks
# stale; re-run this cell to refresh it.
dataset[0]

main_image         ../dataset/train/s10/2.pgm
main_label_idx                              1
main_label_name                           s10
comp_image         ../dataset/train/s10/4.pgm
comp_label_idx                              1
comp_label_name                           s10
label                                       0
status                                similar
Name: 0, dtype: object

In [66]:

# Group image paths by class index, visiting classes in ascending index order.
# NOTE(review): `_dataset` is not defined in the visible cells — presumably an
# ImageFolder created in another cell; confirm before a fresh-kernel run.
img_label_list = sorted(_dataset.imgs, key=lambda pair: pair[1])
label_img_dict = {}
for impath, label in img_label_list:
    label_img_dict.setdefault(label, []).append(impath)

In [69]:
# Build every ordered same-class pair (an image is never paired with itself).
# `label` and `status` are appended per row like the other columns; they used
# to be overwritten with a scalar and only worked through pandas broadcasting.
# NOTE(review): `_dataset` and `label_img_dict` come from other cells.
simm_pair = {'main_image': [], 'main_label_idx':[], 'main_label_name': [], 
             'comp_image': [], 'comp_label_idx':[], 'comp_label_name': [], 
             'label': [], 'status':[]}
for key, list_value in label_img_dict.items():
    class_name = _dataset.classes[key]
    for idx, main_img in enumerate(list_value):
        for jdx, comp_image in enumerate(list_value):
            if idx == jdx:
                continue
            simm_pair['main_image'].append(main_img)
            simm_pair['main_label_name'].append(class_name)
            simm_pair['main_label_idx'].append(key)
            simm_pair['comp_image'].append(comp_image)
            simm_pair['comp_label_name'].append(class_name)
            simm_pair['comp_label_idx'].append(key)
            # Same class on both sides, so the pair label is always 0.
            simm_pair['label'].append(0)
            simm_pair['status'].append('similar')
simm_df = pd.DataFrame(simm_pair)



In [70]:
# Inspect the similar-pair frame (every row has label 0 / status 'similar').
simm_df

Unnamed: 0,main_image,main_label_idx,main_label_name,comp_image,comp_label_idx,comp_label_name,label,status
0,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/2.pgm,0,s1,0,similar
1,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/3.pgm,0,s1,0,similar
2,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/4.pgm,0,s1,0,similar
3,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/6.pgm,0,s1,0,similar
4,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/8.pgm,0,s1,0,similar
...,...,...,...,...,...,...,...,...
1675,../dataset/train/s9/9.pgm,39,s9,../dataset/train/s9/10.pgm,39,s9,0,similar
1676,../dataset/train/s9/9.pgm,39,s9,../dataset/train/s9/3.pgm,39,s9,0,similar
1677,../dataset/train/s9/9.pgm,39,s9,../dataset/train/s9/4.pgm,39,s9,0,similar
1678,../dataset/train/s9/9.pgm,39,s9,../dataset/train/s9/7.pgm,39,s9,0,similar


In [91]:
from tqdm import tqdm, trange

# Build every ordered cross-class pair: each image of a class is paired with
# each image of every other class.
# `label` and `status` are appended per row like the other columns; they used
# to be overwritten with a scalar and only worked through pandas broadcasting.
# NOTE(review): `_dataset` and `label_img_dict` come from other cells.
diff_pair = {'main_image': [], 'main_label_idx':[], 'main_label_name': [], 
             'comp_image': [], 'comp_label_idx':[], 'comp_label_name': [], 
             'label': [], 'status':[]}

for main_key, main_list_value in tqdm(label_img_dict.items()):
    for diff_key, diff_list_value in label_img_dict.items():
        if main_key == diff_key:
            continue
        for idx, main_img in enumerate(main_list_value):
            for jdx, comp_image in enumerate(diff_list_value):
                diff_pair['main_image'].append(main_img)
                diff_pair['main_label_name'].append(_dataset.classes[main_key])
                diff_pair['main_label_idx'].append(main_key)
                diff_pair['comp_image'].append(comp_image)
                diff_pair['comp_label_name'].append(_dataset.classes[diff_key])
                diff_pair['comp_label_idx'].append(diff_key)
                # The classes differ here by construction, so the label is 1.
                diff_pair['label'].append(1)
                diff_pair['status'].append('different')
diff_df = pd.DataFrame(diff_pair)


100%|██████████| 40/40 [00:00<00:00, 525.21it/s]


In [92]:
# Inspect the different pairs whose main image belongs to class 's1'.
diff_df[diff_df['main_label_name'] == 's1']

Unnamed: 0,main_image,main_label_idx,main_label_name,comp_image,comp_label_idx,comp_label_name,label,status
0,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s10/1.pgm,1,s10,1,different
1,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s10/10.pgm,1,s10,1,different
2,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s10/2.pgm,1,s10,1,different
3,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s10/4.pgm,1,s10,1,different
4,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s10/5.pgm,1,s10,1,different
...,...,...,...,...,...,...,...,...
1906,../dataset/train/s1/9.pgm,0,s1,../dataset/train/s9/3.pgm,39,s9,1,different
1907,../dataset/train/s1/9.pgm,0,s1,../dataset/train/s9/4.pgm,39,s9,1,different
1908,../dataset/train/s1/9.pgm,0,s1,../dataset/train/s9/7.pgm,39,s9,1,different
1909,../dataset/train/s1/9.pgm,0,s1,../dataset/train/s9/8.pgm,39,s9,1,different


In [103]:
# Down-sample the different pairs of each class to the number of similar pairs
# that class has. Sampling by count (instead of a similar/different ratio)
# avoids a ZeroDivisionError when a class has no different pairs and the
# ValueError pandas raises for frac > 1 without replacement.
# NOTE: this cell overwrites `diff_df`, so re-running it samples from the
# already reduced frame — rebuild `diff_df` first.
diff_df_list = []
for name in _dataset.classes:
    len_simm_idx = int((simm_df['main_label_name'] == name).sum())
    diff_df_by_idx = diff_df[diff_df['main_label_name'] == name]
    n_keep = min(len_simm_idx, len(diff_df_by_idx))
    if n_keep == 0:
        continue
    diff_df_list.append(diff_df_by_idx.sample(n=n_keep, random_state=1261))

# pd.concat([]) raises, so fall back to an empty frame with the same columns.
diff_df = pd.concat(diff_df_list) if diff_df_list else diff_df.iloc[0:0]




In [104]:
# Inspect the balanced different-pair frame; the index still carries the row
# numbers from before sampling (no reset_index in the cell above).
diff_df

Unnamed: 0,main_image,main_label_idx,main_label_name,comp_image,comp_label_idx,comp_label_name,label,status
1670,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s5/5.pgm,35,s5,1,different
638,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s22/10.pgm,14,s22,1,different
361,../dataset/train/s1/3.pgm,0,s1,../dataset/train/s17/7.pgm,8,s17,1,different
1775,../dataset/train/s1/2.pgm,0,s1,../dataset/train/s7/7.pgm,37,s7,1,different
1427,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s37/9.pgm,30,s37,1,different
...,...,...,...,...,...,...,...,...
75738,../dataset/train/s9/7.pgm,39,s9,../dataset/train/s31/8.pgm,24,s31,1,different
75555,../dataset/train/s9/9.pgm,39,s9,../dataset/train/s28/5.pgm,20,s28,1,different
75339,../dataset/train/s9/4.pgm,39,s9,../dataset/train/s24/8.pgm,16,s24,1,different
75419,../dataset/train/s9/10.pgm,39,s9,../dataset/train/s26/3.pgm,18,s26,1,different


In [94]:
# Scratch value.
# NOTE(review): the recorded output below (1911) cannot come from this
# assignment — the cell was presumably edited after it last ran.
idx = 0


1911

In [105]:
# Stack similar and different pairs; each frame keeps its own index values
# (there is no reset_index here, unlike GeneralDataset._combine_simm_diff_pair).
main_df = pd.concat([simm_df, diff_df])

In [106]:
# Inspect the combined frame (similar pairs first, then the sampled different pairs).
main_df

Unnamed: 0,main_image,main_label_idx,main_label_name,comp_image,comp_label_idx,comp_label_name,label,status
0,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/2.pgm,0,s1,0,similar
1,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/3.pgm,0,s1,0,similar
2,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/4.pgm,0,s1,0,similar
3,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/6.pgm,0,s1,0,similar
4,../dataset/train/s1/10.pgm,0,s1,../dataset/train/s1/8.pgm,0,s1,0,similar
...,...,...,...,...,...,...,...,...
75738,../dataset/train/s9/7.pgm,39,s9,../dataset/train/s31/8.pgm,24,s31,1,different
75555,../dataset/train/s9/9.pgm,39,s9,../dataset/train/s28/5.pgm,20,s28,1,different
75339,../dataset/train/s9/4.pgm,39,s9,../dataset/train/s24/8.pgm,16,s24,1,different
75419,../dataset/train/s9/10.pgm,39,s9,../dataset/train/s26/3.pgm,18,s26,1,different
