In [26]:
from tqdm import tqdm
import torch
import gc
import numpy as np
from datasets import load_dataset
from torchvision import transforms as T
from imblearn.over_sampling import SMOTE
from torch.utils.data import DataLoader, Dataset

In [2]:
# Root folder of the image dataset, loaded with the Hugging Face `imagefolder` builder.
# Earlier (Windows) location, kept for reference:
# dataset_folder_name = r"D:\Study\M.E. IoTWT\5. Term 4\Deep Learning\Project\short_imbalanced_dataset"
# NOTE(review): the path below has no leading "/", so it is resolved relative to the
# kernel's working directory -- confirm that "/mnt/local/..." was not intended.
dataset_folder_name = 'mnt/local/data/kalexu97/short_imbalanced_dataset'
# The commented-out `split` argument is an unused alternative that would carve a 75/25 split out of 'train'.
dataset = load_dataset("imagefolder", data_dir = dataset_folder_name) #, split = ['train[:75%]', 'train[75%:]'])

In [3]:
# Show the summary of the training split (feature names and number of rows).
dataset['train']

Dataset({
    features: ['image', 'label'],
    num_rows: 1816
})

In [4]:
# Geometric and photometric augmentation steps, listed in the order they are applied.
# The final output is a 480x480 centre crop.
_AUGMENTATION_STEPS = [
    T.Resize(512),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.CenterCrop(size=(480, 480)),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]
initial_transform = T.Compose(_AUGMENTATION_STEPS)

In [5]:
def func_transform(examples):
    """Batch transform: replace the 'image' column by augmented grayscale 'pixel_values'.

    Every entry of ``examples['image']`` is converted to mode 'L' and passed through
    ``initial_transform``. The batch dict is modified in place and returned.
    """
    augmented = [initial_transform(picture.convert('L')) for picture in examples['image']]
    examples.pop('image')
    examples['pixel_values'] = augmented
    return examples

In [6]:
# Register `func_transform` so that samples are converted when they are accessed.
# NOTE(review): `dataset` is indexed by split name elsewhere, so the transform (including its
# random flips / colour jitter) presumably applies to every split, not only 'train' -- confirm.
resized_data = dataset.with_transform(func_transform)

In [7]:
# Sanity check: print the pixel matrix of the first transformed training sample (if any).
first_sample = next(iter(resized_data['train']), None)
if first_sample is not None:
    print(np.array(first_sample['pixel_values']))

[[4 5 6 ... 4 5 4]
 [5 5 6 ... 6 6 5]
 [6 7 7 ... 6 6 5]
 ...
 [6 8 8 ... 4 4 3]
 [6 7 7 ... 3 4 3]
 [5 7 7 ... 2 2 2]]


In [8]:
# Two-way lookup between class names and their integer ids (ids are stored as strings).
labels = dataset['train'].features['label'].names
label2id = {name: str(index) for index, name in enumerate(labels)}
id2label = {str(index): name for index, name in enumerate(labels)}

In [13]:
def return_image_label(dataset, majority_label = 0):
    """Split an iterable of samples into majority-class and minority-class parts.

    Parameters
    ----------
    dataset : iterable of mapping
        Every sample must provide ``'pixel_values'`` (anything ``np.array`` accepts)
        and ``'label'``.
    majority_label : default 0
        Label value treated as the majority class; every other label is counted
        as a minority class.

    Returns
    -------
    tuple of four lists
        ``(images_majority, images_minority, labels_majority, labels_minority)``.
        Images are NumPy arrays; order follows the iteration order of ``dataset``.
    """
    images_majority, labels_majority = [], []
    images_minority, labels_minority = [], []
    for sample in tqdm(dataset):
        image = np.array(sample['pixel_values'])
        label = sample['label']
        if label == majority_label:
            images_majority.append(image)
            labels_majority.append(label)
        else:
            images_minority.append(image)
            labels_minority.append(label)
    # One collection after the pass releases the per-sample temporaries.
    gc.collect()
    return images_majority, images_minority, labels_majority, labels_minority

def samp_strategy(y_majority, y_minority, minority_ratio = 0.4, majority_ratio = 0.7):
    """Build the sampling strategies used to rebalance the dataset.

    Parameters
    ----------
    y_majority : sequence
        Labels of the majority-class samples (only its length is used).
    y_minority : iterable
        Labels of all minority-class samples.
    minority_ratio : float, default 0.4
        Each minority class is up-sampled to ``int(len(y_majority) * minority_ratio)`` samples.
    majority_ratio : float, default 0.7
        Returned unchanged as the majority sampling strategy.

    Returns
    -------
    (dict, float)
        ``{class_label: target_count}`` for every minority class, and the majority strategy.
    """
    up_sampling_class_size = int(len(y_majority) * minority_ratio)

    # Count the samples currently available for each minority class.
    current_sizes = {}
    for class_label in y_minority:
        current_sizes[class_label] = current_sizes.get(class_label, 0) + 1

    # An over-sampler cannot shrink a class: a target below the current size is rejected
    # by imblearn, so never request fewer samples than the class already has.
    minority_sampling_strategy = {
        class_label: max(up_sampling_class_size, count)
        for class_label, count in current_sizes.items()
    }
    majority_sampling_strategy = majority_ratio
    return minority_sampling_strategy, majority_sampling_strategy

In [35]:
# Tensor conversion + normalisation, built separately for the training and test pipelines.
def _tensor_and_normalize():
    """Return a fresh ``ToTensor`` -> ``Normalize(mean = 0.5, std = 0.5)`` pipeline."""
    steps = [
        T.ToTensor(),
        T.Normalize(mean = 0.5, std = 0.5),
    ]
    return T.Compose(steps)

final_transform = _tensor_and_normalize()

test_transform = _tensor_and_normalize()

In [15]:
class MyDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs with optional SMOTE oversampling.

    Parameters
    ----------
    dataset : indexable, sized collection of mappings
        Every sample must provide ``'pixel_values'`` and ``'label'``.
    transform : callable or None
        Applied to every image. With ``oversampling = True`` it is applied once, eagerly,
        to the resampled images; otherwise it is applied on access.
    oversampling : bool
        When true, minority classes are up-sampled with SMOTE at construction time.
    random_state : int or None
        Seed forwarded to SMOTE so that the synthetic samples are reproducible.
    """

    def __init__(self, dataset, transform = None, oversampling = False, random_state = None):
        self.dataset = dataset
        self.transform = transform
        self.oversampling = oversampling
        self.random_state = random_state
        # True once `transform` has already been applied to every stored image.
        self._pretransformed = False

        self.total_len = len(dataset)

        if oversampling:
            # Replaces self.dataset and refreshes self.total_len.
            self.oversample_dataset()

    def __len__(self):
        return self.total_len

    def _unpack(self, sample):
        """Return ``(image, label)`` for one stored sample, transforming the image if still needed."""
        image, label = sample['pixel_values'], sample['label']
        if self.transform is not None and not self._pretransformed:
            image = self.transform(image)
        return image, label

    def __iter__(self):
        for sample in self.dataset:
            yield self._unpack(sample)

    def __getitem__(self, idx):
        # No gc.collect() here: a full collection per sample makes data loading very slow.
        return self._unpack(self.dataset[idx])

    def oversample_dataset(self):
        """Up-sample the minority classes with SMOTE and replace ``self.dataset`` by the result.

        The stored samples become dicts with ``'pixel_values'`` (transformed if a transform
        was given, raw arrays otherwise) and ``'label'``; minority samples come first.

        Raises
        ------
        ValueError
            If the wrapped dataset has no majority or no minority samples.
        """
        X_majority, X_minority, y_majority, y_minority = return_image_label(self.dataset)
        if len(X_majority) == 0 or len(X_minority) == 0:
            raise ValueError("oversampling needs at least one majority and one minority sample")

        X_majority_array = np.stack(X_majority)
        X_minority_array = np.stack(X_minority)
        del X_majority, X_minority

        # Remember the per-image shape instead of assuming 480x480.
        image_shape = X_majority_array.shape[1:]
        pixel_dtype = X_minority_array.dtype

        X_majority_reshaped = X_majority_array.reshape(X_majority_array.shape[0], -1)
        X_minority_reshaped = X_minority_array.reshape(X_minority_array.shape[0], -1)

        # Only the minority strategy is used; the majority class is kept as is.
        minority_strategy, _ = samp_strategy(y_majority, y_minority)
        smote = SMOTE(sampling_strategy = minority_strategy, random_state = self.random_state)

        # SMOTE interpolates between neighbours; doing that on unsigned integer pixels can
        # wrap around, so interpolate in floating point and round back afterwards.
        integer_pixels = np.issubdtype(pixel_dtype, np.integer)
        if integer_pixels:
            float_dtype = np.float32 if pixel_dtype.itemsize <= 2 else np.float64
            X_minority_reshaped = X_minority_reshaped.astype(float_dtype)

        X_minority_resampled, y_minority_resampled = smote.fit_resample(X_minority_reshaped, y_minority)

        if integer_pixels:
            limits = np.iinfo(pixel_dtype)
            X_minority_resampled = np.clip(
                np.rint(X_minority_resampled), limits.min, limits.max
            ).astype(pixel_dtype)

        X_resampled = np.concatenate((X_minority_resampled, X_majority_reshaped))
        y_resampled = np.concatenate((y_minority_resampled, y_majority)).tolist()
        del X_minority_resampled, X_majority_reshaped, X_minority_reshaped

        X_resampled = X_resampled.reshape((X_resampled.shape[0],) + tuple(image_shape))

        if self.transform is not None:
            images = [self.transform(image) for image in X_resampled]
            self._pretransformed = True
        else:
            # Keep the resampled data even when there is nothing to transform.
            images = list(X_resampled)

        self.dataset = [
            {'pixel_values' : image, 'label' : label}
            for image, label in zip(images, y_resampled)
        ]
        # The number of samples changed, so the reported length must follow.
        self.total_len = len(self.dataset)
        gc.collect()

In [16]:
# Build the training set: iterate the transformed 'train' split, oversample it with SMOTE
# and apply `final_transform` to the resampled images.
oversampled_dataset = MyDataset(resized_data['train'], transform = final_transform, oversampling = True)

100%|██████████| 1816/1816 [02:44<00:00, 11.02it/s]


In [None]:
# Deterministic evaluation preprocessing: same resize / crop geometry as the training
# pipeline but without random flips or colour jitter, followed by `test_transform`.
test_preprocess = T.Compose([
    T.Resize(512),
    T.CenterCrop(size = (480, 480)),
    test_transform,
])

# The previous loop stored each transformed image in a throw-away sample dict, so no
# result was kept. Collect the processed samples into `test_dataset` instead.
# NOTE(review): this materialises the whole test split in memory -- confirm it fits.
test_dataset = MyDataset([
    {'pixel_values': test_preprocess(sample['image'].convert('L')), 'label': sample['label']}
    for sample in tqdm(dataset['test'])
])

In [9]:
# train_loader = DataLoader(oversampled_dataset, batch_size = 32, shuffle = True, num_workers = 2)
# test_loader = DataLoader(test_dataset, batch_size = 32, num_workers = 2)