Ideia central do sampler

In [None]:
# import numpy as np

# accumulation = 512
# batch_size = n_classes = 32 # isso é o mesmo que n_classes
# accumulation_steps = accumulation // batch_size

# # Seed
# np.random.seed(42)

# label_set = set(dataset.labels)
# labels_to_indices = {label: np.where(np.array(dataset.labels) == label)[0] for label in label_set}

# # Selecionar n_classes labels aleatórias
# selected_labels = np.random.choice(list(label_set), batch_size, replace=False)

# # Selecionar accumulation_steps imagens para cada label, retornar um dict com os indices
# indices = {label: np.random.choice(labels_to_indices[label], accumulation_steps, replace=False) for label in selected_labels}

# # Organizar em um array com outros batch_size arrays, de forma que cada array tenha accumulation_steps elementos, sendo 1 de cada label
# indices_array = np.array([[indices[label][i] for label in selected_labels] for i in range(accumulation_steps)])
# len(indices_array)

In [6]:
import pandas as pd
from torch.utils.data import Dataset, Sampler, DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
import numpy as np
import sys
import os
import torch
from PIL import Image
from tqdm.notebook import tqdm

transform = Compose(
    [
    Resize((112, 112)),
    ToTensor(), 
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

In [7]:
class TripletDataset(Dataset):
    def __init__(self, images_df, transform=None, dtype=torch.bfloat16):
        self.labels = images_df['id'].values
        self.image_paths = images_df['path'].values
        self.transform = transform
        self.dtype = dtype

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        image = image.to(self.dtype)
        label = torch.tensor(self.labels[idx], dtype=torch.int16)
        
        return image, label, idx

class BalancedBatchSampler(Sampler):
    def __init__(self, dataset, accumulation, batch_size):
        self.dataset = dataset
        self.accumulation = accumulation
        self.batch_size = batch_size
        self.accumulation_steps = accumulation // batch_size
        self.label_set = list(set(dataset.labels))
        self.labels_to_indices = {label: np.where(np.array(dataset.labels) == label)[0] for label in self.label_set}
        self.num_batches = len(dataset) // batch_size
        self.num_accumulation_batches = self.num_batches // self.accumulation_steps
        
        self.counter = 0

    def __iter__(self):
        pbar = tqdm(total=self.num_accumulation_batches, desc="Accumulated Batches")
        for _ in range(self.num_accumulation_batches):
            selected_labels = np.random.choice(self.label_set, self.batch_size, replace=False)
            indices = []
            for _ in range(self.accumulation_steps):
                batch_indices = []
                for label in selected_labels:
                    idx = np.random.choice(self.labels_to_indices[label])
                    batch_indices.append(idx)
                indices.extend(batch_indices)
                self.counter += 1
                yield batch_indices
            pbar.update(1)
        pbar.close()

    def __len__(self):
        return self.num_accumulation_batches

In [8]:
df = pd.read_csv('../data/CASIA/casia_train.csv')

batch_size = n_classes = 32
accumulation = 512

df['path'] = df['path'].apply(lambda x: '../data/CASIA/casia-faces/' + x)
dataset = TripletDataset(df, transform=transform)
sampler = BalancedBatchSampler(dataset=dataset, batch_size=batch_size, accumulation=accumulation)
dataloader = DataLoader(dataset, batch_sampler=sampler, pin_memory=False)
len(dataloader)  * accumulation, len(dataloader)

(76800, 150)

In [9]:
from tqdm.notebook import tqdm
import time

# Itera pelo dataloader, salvando a quantidade de cada classe em um dicionário
classes = {}
indices = []
counter = 0
# Contar quantos de cada classe num batch
#progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), unit='batch')
for i, batch in enumerate(dataloader):
    counter += 1
    imgs, labels, idx = batch
    if i == 0:
        print(type(imgs), type(labels), type(idx))
    indices.append(idx)
    
    for label in labels:
        if label in classes:
            classes[label.item()] += 1
        else:
            classes[label.item()] = 1
    
    #progress_bar.update(1)
    #if (i+1) % (accumulation // batch_size) == 0: break
#progress_bar.close()

counter, i

Accumulated Batches:   0%|          | 0/150 [00:00<?, ?it/s]

<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>


KeyboardInterrupt: 

In [None]:
indices = np.array(indices).flatten()
len(indices)#, indices

76800

In [None]:
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt

def denormalize(tensor, mean, std):
    if tensor.ndim == 3:  # Check if the tensor is (C, H, W)
        mean = torch.tensor(mean).view(-1, 1, 1)
        std = torch.tensor(std).view(-1, 1, 1)
        tensor = tensor * std + mean  # Apply denormalization
    elif tensor.ndim == 4:  # Check if the tensor is (B, C, H, W)
        mean = torch.tensor(mean).view(1, -1, 1, 1)
        std = torch.tensor(std).view(1, -1, 1, 1)
        tensor = tensor * std + mean  # Apply denormalization for batch of images
    tensor = torch.clamp(tensor, 0, 1)
    return tensor

def plot_images_side_by_side(dataset, indices, images_per_row=10, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    num_images = len(indices)
    num_rows = (num_images + images_per_row - 1) // images_per_row
    
    plt.figure(figsize=(images_per_row * 2, num_rows * 2))
    
    for idx, image_index in enumerate(indices, start=1):
        image, label, index = dataset[image_index]  # Assuming dataset returns a tuple (image, label)
        image = denormalize(image, mean, std)  # Denormalize the image
        if isinstance(image, torch.Tensor):
            image = image.permute(1, 2, 0)  # Change (C, H, W) to (H, W, C)
            image = image.numpy()  # Convert to numpy array
        
        plt.subplot(num_rows, images_per_row, idx)
        plt.imshow(image)
        plt.axis('off')
        plt.title(f'Index: {image_index}\nLabel: {label}', fontsize=10)  # Display index and label
    
    plt.tight_layout()
    plt.show()

In [None]:
len(indices)

76800

In [None]:
plot_images_side_by_side(dataset, indices, images_per_row=batch_size)

KeyboardInterrupt: 

Error in callback <function _draw_all_if_interactive at 0x7fdb2800c700> (for post_execute), with arguments args (),kwargs {}:


ValueError: Image size of 6400x480000 pixels is too large. It must be less than 2^16 in each direction.

ValueError: Image size of 6400x480000 pixels is too large. It must be less than 2^16 in each direction.

<Figure size 6400x480000 with 3315 Axes>

In [None]:
# Converter o dicionário para um dataframe
res_df = pd.DataFrame.from_dict(classes, orient='index', columns=['count'])

In [None]:
# Adicionar uma linha 'Total' com a soma
df_res = pd.concat([res_df.sort_index(), pd.DataFrame(res_df.sum(), columns=['Total']).T]).T

In [None]:
df_res.shape[1]-1

In [None]:
df_res

In [None]:
df_res.T