In [3]:
import random
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from PIL import Image
import itertools

class SiameseDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.image_folder = datasets.ImageFolder(root=root_dir)
        self.transform = transform
        self.image_pairs = list(itertools.combinations_with_replacement(range(len(self.image_folder)), 2))
        self.targets = [int(self.image_folder.targets[idx1] == self.image_folder.targets[idx2]) for idx1, idx2 in self.image_pairs]

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, index):
        idx1, idx2 = self.image_pairs[index]
        img1,_ = self.image_folder[idx1]
        img2,_ = self.image_folder[idx2]
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, self.targets[index]

  Referenced from: <2D1B8D5C-7891-3680-9CF9-F771AE880676> /opt/homebrew/Caskroom/miniconda/base/envs/oneshot-face/lib/python3.9/site-packages/torchvision/image.so
  warn(


In [28]:
import torch
from torch.utils.data import Sampler
import numpy as np
from collections import Counter

class RandomUnderSampler(Sampler):
    def __init__ (self, targets, seed=None, shuffle=False):
        self.targets = np.array(targets)
        self.class_counts = Counter(self.targets)
        self.classes = self.class_counts.keys()
        self.indices = {cls : np.where(self.targets==cls)[0] for cls in self.classes}
        self.seed = seed
        self.min_count = min(self.class_counts.values())
        self.shuffle = shuffle

    def __iter__(self):
        sampled_indices = []
        for cls, indices in self.indices.items():
            if self.seed is not None:
                np.random.seed(self.seed)
            sampled_indices.extend(np.random.choice(indices, self.min_count))
        if self.shuffle:
            np.random.shuffle(sampled_indices)
        return iter(sampled_indices)
    
    def __len__(self):
        return self.min_count * len(self.classes)
        
        

In [5]:
ds = SiameseDataset('data/raw', transform=transforms.ToTensor())

In [37]:
sampler = RandomUnderSampler(ds.targets)

In [7]:
dl = torch.utils.data.DataLoader(ds, sampler=sampler, batch_size=510980)

In [66]:
targets = []
indices = []
for idx in sampler:
    targets.append(ds.targets[idx])
    indices.append(idx)
print(np.bincount(targets))
print(indices[0])

[255490 255490]
47327957


In [1]:
for batch in dl:
    img1, img2, label = batch
    print(label)

NameError: name 'dl' is not defined

In [9]:
import numpy as np

In [10]:
class_count = np.bincount(ds.targets)

In [11]:
class_count

array([87307271,   255490])

In [12]:
from collections import Counter

In [13]:
class_count = Counter(ds.targets)

In [36]:
class_count

Counter({0: 87307271, 1: 255490})

In [70]:
class_count.keys()

dict_keys([1, 0])

In [21]:
minority_class_count = min(class_count.values())
minority_class_count

255490

In [37]:
for i in class_count.items():
    print(i)

(1, 255490)
(0, 87307271)


In [34]:
print([item[0] for item in class_count.items()])

[1, 0]


In [None]:
for cls, i

In [25]:
indices = {cls: np.where(ds.targets == cls) for cls in class_count}
indices

  indices = {cls: np.where(ds.targets == cls) for cls in class_count}


{1: (array([], dtype=int64),), 0: (array([], dtype=int64),)}

In [None]:
indices_per_class = {cls: np.where(self.labels == cls)[0] for cls in unique_classes}

In [56]:
np.unique(ds.targets)

array([0, 1])

In [7]:
len(ds.targets)

87562761

In [65]:
targets = np.array(ds.targets)
targets

array([1, 0, 0, ..., 1, 0, 1])

In [67]:
np.where(targets==0)[0]

array([       1,        2,        3, ..., 87562756, 87562757, 87562759])

In [69]:
classes = np.unique(ds.targets)
classes

array([0, 1])

In [72]:
indices_per_class = {cls: np.where(targets == cls)[0] for cls in classes}
indices_per_class

{0: array([       1,        2,        3, ..., 87562756, 87562757, 87562759]),
 1: array([       0,    13233,    26465, ..., 87562755, 87562758, 87562760])}

In [73]:
sampled_indices = []

In [75]:
min_count  = min(class_count.values())
min_count

255490

In [76]:
for cls, indices in indices_per_class.items():
    sampled_indices.extend(np.random.choice(indices,min_count))

In [78]:
len(sampled_indices)

510980