In [1]:
import datasets

In [2]:
# Load every split of the 'ethz/food101' dataset; the result is a DatasetDict (see the repr two cells below).
dataset = datasets.load_dataset('ethz/food101')

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00008.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

train-00001-of-00008.parquet:   0%|          | 0.00/464M [00:00<?, ?B/s]

train-00002-of-00008.parquet:   0%|          | 0.00/472M [00:00<?, ?B/s]

train-00003-of-00008.parquet:   0%|          | 0.00/464M [00:00<?, ?B/s]

train-00004-of-00008.parquet:   0%|          | 0.00/475M [00:00<?, ?B/s]

train-00005-of-00008.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

train-00006-of-00008.parquet:   0%|          | 0.00/478M [00:00<?, ?B/s]

train-00007-of-00008.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

validation-00000-of-00003.parquet:   0%|          | 0.00/423M [00:00<?, ?B/s]

validation-00001-of-00003.parquet:   0%|          | 0.00/413M [00:00<?, ?B/s]

validation-00002-of-00003.parquet:   0%|          | 0.00/426M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/75750 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25250 [00:00<?, ? examples/s]

In [3]:
# Inspect the available splits, their features and their sizes.
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 75750
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 25250
    })
})

In [8]:
# Handle on the training split. NOTE(review): not used by any of the cells visible here.
train_data = dataset['train']

In [None]:
from pathlib import Path
import pandas as pd
import datasets
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from PIL import ImageFile
import numpy as np

# Let PIL load image files that are cut short instead of raising, so a damaged file does not
# abort the bulk export below (save_example still skips anything that fails outright).
ImageFile.LOAD_TRUNCATED_IMAGES = True

def save_example(index, mode, global_path, all_data):
    """Write example ``index`` of ``all_data`` to ``global_path`` as an RGB PNG.

    The output file is called ``<index>_<mode>.png``.

    Args:
        index: Position of the example in ``all_data``.
        mode: Split name; only used to build the file name.
        global_path: Directory (a ``pathlib.Path``) the PNG is written into.
        all_data: Indexable collection whose items are dicts with ``'image'`` and ``'label'``.

    Returns:
        ``(image_path, label, filename)`` when the example was saved, otherwise ``None``.
        Any exception raised while reading, converting or saving is caught and printed,
        so one bad example does not stop a bulk export.
    """
    try:
        record = all_data[index]
        picture, label = record['image'], record['label']
        rgb_picture = picture.convert("RGB")
        filename = f'{index}_{mode}.png'
        destination = global_path / filename
        rgb_picture.save(destination, format='png')
    except Exception as err:
        print(f"[Error] image with index {index} is missed: {err}")
        return None
    else:
        return destination, label, filename

def get_data(target_dir='food101', mode='train', num_workers=8):
    """Export one split of ``ethz/food101`` to PNG files plus a CSV map file.

    Every example is written to ``<target_dir>/data/<index>_<mode>.png`` by a pool of
    worker threads (via ``save_example``); examples that fail are skipped. The surviving
    rows are written to ``<target_dir>/<mode>_map.csv`` with columns ``fpath``,
    ``target`` and ``name``, ordered by ``target`` (ties keep dataset index order).

    Args:
        target_dir: Root output directory; created if it does not exist.
        mode: Dataset split to export; also used in the PNG file names.
        num_workers: Number of threads that decode and save images.

    Returns:
        The DataFrame that was written to the CSV map file.
    """
    from functools import partial

    all_data = datasets.load_dataset('ethz/food101')[mode]

    global_path = Path(target_dir) / 'data'
    global_path.mkdir(parents=True, exist_ok=True)

    save_func = partial(save_example, mode=mode, global_path=global_path, all_data=all_data)
    fpath, targets, names = [], [], []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Executor.map yields results in submission (index) order, so the collected rows are deterministic.
        results = executor.map(save_func, range(len(all_data)))
        for res in tqdm(results, total=len(all_data), desc="Сохранение изображений"):
            if res is None:  # save_example has already printed why this example was skipped
                continue
            path, label, name = res
            fpath.append(path)
            targets.append(label)
            names.append(name)

    df = pd.DataFrame({'fpath': fpath, 'target': targets, 'name': names})
    # A stable sort keeps index order inside each class, so re-sorting the CSV later is a no-op
    # and the per-class row order does not depend on the sorting algorithm.
    df = df.sort_values(by='target', kind='stable')

    map_file_path = Path(target_dir) / f"{mode}_map.csv"
    df.to_csv(map_file_path, index=False)
    n_skipped = len(all_data) - len(df)
    print(f"Downloading is done! Saved {len(df)} images ({n_skipped} skipped), map file: {map_file_path}")
    return df


In [6]:
# Export the train split (the default mode) into check_dir/ using 20 worker threads.
# NOTE(review): the recorded output below was produced by an earlier version of the export code;
# its error and summary messages differ from the ones save_example/get_data print now.
get_data('check_dir', num_workers=20)

Сохранение изображений:  82%|████████▏ | 62340/75750 [04:15<00:51, 260.90it/s]

[Ошибка] Пропущен элемент 62336: 'utf-8' codec can't decode byte 0x83 in position 34: invalid start byte


Сохранение изображений: 100%|██████████| 75750/75750 [05:08<00:00, 245.83it/s]


Сохранено 75749 изображений в check_dir/train_map.csv


In [7]:
# Read the map file written by get_data('check_dir') above. The path is relative to the notebook's
# working directory, like the 'fpath' entries inside the CSV, instead of a machine-specific absolute path.
df = pd.read_csv(Path('check_dir') / 'train_map.csv')

In [None]:
def uniform_split(df, num_samples, random_state, num_classes=101):
    """Draw ``num_samples`` rows from every class ``0 .. num_classes - 1`` of ``df``.

    Args:
        df: DataFrame with an integer ``target`` column.
        num_samples: Number of rows sampled (without replacement) per class.
        random_state: Seed forwarded to ``DataFrame.sample`` for every class.
        num_classes: Number of classes, labelled ``0 .. num_classes - 1``. Defaults to 101
            (Food-101); pass ``None`` to use the number of distinct values in ``df['target']``.

    Returns:
        ``(split_samples, split_ids)``: a list with one sampled DataFrame per class, and a flat
        list of the sampled rows' index labels in the same order.

    Raises:
        ValueError: (from ``DataFrame.sample``) when a class has fewer than ``num_samples`` rows.
    """
    if num_classes is None:
        num_classes = df["target"].nunique()

    split_samples = []
    split_ids = []
    for class_id in range(num_classes):
        class_samples = df[df["target"] == class_id].sample(
            num_samples, random_state=random_state
        )
        split_samples.append(class_samples)
        split_ids.extend(class_samples.index)
    return split_samples, split_ids

In [None]:
def set_uniform_split(df, target_dir='food101', amount_of_clients=10):
    """Assign every row of ``df`` to a client so that each class is spread evenly across clients.

    Within each class (rows taken in ``df``'s current order) the first
    ``size // amount_of_clients`` rows go to client 1, the next block to client 2, and so on up
    to client ``amount_of_clients``. The ``size % amount_of_clients`` leftover rows of the class
    go to an extra overflow client ``amount_of_clients + 1``.

    The 1-based client ids are stored in a new ``clients`` column of ``df`` (``df`` is modified
    in place). The frame is then written to
    ``<target_dir>/<amount_of_clients>_clients_trin_map_file.csv``, creating ``target_dir`` if needed.

    Args:
        df: DataFrame with a ``target`` column; it does not need to be sorted.
        target_dir: Directory for the output CSV.
        amount_of_clients: Number of regular clients (the overflow client is extra).

    Returns:
        ``df`` itself, now with the ``clients`` column.
    """
    # Position of each row inside its own class and that class's size, aligned with df's rows.
    # Working per row (instead of concatenating per-class blocks) keeps the assignment correct
    # even when df is not sorted by target.
    position = df.groupby('target').cumcount().to_numpy()
    class_size = df['target'].map(df['target'].value_counts()).to_numpy()

    per_client = class_size // amount_of_clients
    evenly_split = per_client * amount_of_clients
    # np.maximum avoids dividing by zero for classes smaller than amount_of_clients;
    # all rows of such classes land in the overflow client anyway.
    clients = np.where(
        position < evenly_split,
        position // np.maximum(per_client, 1) + 1,
        amount_of_clients + 1,
    )
    df['clients'] = clients.astype(np.int64)

    target_path = Path(target_dir)
    target_path.mkdir(parents=True, exist_ok=True)
    # NOTE(review): 'trin' looks like a typo for 'train'; the file name is kept so existing
    # readers of this file still find it.
    path_to_save = target_path / f'{amount_of_clients}_clients_trin_map_file.csv'
    df.to_csv(path_to_save, index=False)
    return df
    

In [48]:
# Order rows by class label, then renumber the index as 0..n-1 (the old index is discarded).
df = df.sort_values('target').reset_index(drop=True)

In [54]:
def set_pathology_split(df, std, target_dir='food101', amount_of_clients=10, random_state=42):
    """Shuffle ``df`` and cut it into ``amount_of_clients`` contiguous chunks of random size.

    With ``mean = len(df) // amount_of_clients`` and ``spread = int(mean * std)``, the sizes of
    the first ``amount_of_clients - 1`` clients are drawn uniformly from
    ``[mean - spread, mean + spread)``; the last client receives all remaining rows. Rows are
    shuffled before cutting, so only the client sizes differ, not their label mix.

    The shuffled frame with a 1-based ``client`` column is written to
    ``<target_dir>/<amount_of_clients>_clients_pathology_split_train_map_file.csv``
    (``target_dir`` is created if needed). The input ``df`` is not modified.

    Args:
        df: DataFrame to split.
        std: Relative half-width of the size range, as a fraction of ``mean``.
        target_dir: Directory for the output CSV.
        amount_of_clients: Number of clients.
        random_state: Seed for both the client sizes and the shuffle.

    Returns:
        The shuffled DataFrame with the ``client`` column.

    Raises:
        ValueError: if the drawn sizes leave a negative number of rows for the last client,
            or (from NumPy) if ``std`` is negative.
    """
    mean = len(df) // amount_of_clients
    spread = int(mean * std)
    # Seeded from random_state, which the previous version ignored in favour of a hard-coded 42.
    rng = np.random.default_rng(seed=random_state)
    if spread == 0:
        # rng.integers(low, high) requires low < high; without spread every client gets `mean` rows.
        amount_for_each_client = [mean] * (amount_of_clients - 1)
    else:
        amount_for_each_client = [
            int(rng.integers(mean - spread, mean + spread)) for _ in range(amount_of_clients - 1)
        ]

    last_client_size = len(df) - sum(amount_for_each_client)
    if last_client_size < 0:
        raise ValueError(
            f"Drawn client sizes exceed the dataset ({sum(amount_for_each_client)} > {len(df)}); "
            f"use a smaller std"
        )
    amount_for_each_client.append(last_client_size)

    # One seeded permutation is already a uniform shuffle and, unlike repeated unseeded
    # df.sample(frac=1) calls, makes the split reproducible.
    shuffled = df.iloc[rng.permutation(len(df))].reset_index(drop=True)
    shuffled['client'] = np.repeat(np.arange(1, amount_of_clients + 1), amount_for_each_client)

    target_path = Path(target_dir)
    target_path.mkdir(parents=True, exist_ok=True)
    path_to_save = target_path / f'{amount_of_clients}_clients_pathology_split_train_map_file.csv'
    shuffled.to_csv(path_to_save, index=False)
    return shuffled
    


In [78]:
def flexible_split(df, total_clients=10, head_classes=20, head_clients=2, random_state=42, min_per_client=2):
    """Partition the row index labels of ``df`` into ``total_clients`` label-skewed groups.

    Classes are processed in ascending label order. The rows of each class are shuffled with
    ``random_state``, and every client first receives ``min_per_client`` of them, so every client
    sees every class. The remaining rows of the first ``head_classes`` classes are dealt
    round-robin to clients ``0 .. head_clients - 1``. The remaining rows of all other classes are
    dealt round-robin to clients ``head_clients .. total_clients - 1``.

    Args:
        df: DataFrame with a ``target`` column; labels need not be contiguous.
        total_clients: Number of groups to produce.
        head_classes: How many of the (sorted) classes feed the head clients.
        head_clients: How many clients (the first ones) receive the head classes' leftovers.
        random_state: Seed for the per-class shuffle.
        min_per_client: Rows of every class guaranteed to each client.

    Returns:
        A list of ``total_clients`` lists of index labels of ``df``.

    Raises:
        ValueError: if a class has fewer than ``total_clients * min_per_client`` rows, or if
            ``head_clients`` is not in ``[0, total_clients]``, or if leftover rows must go to an
            empty group of clients.
    """
    if not 0 <= head_clients <= total_clients:
        raise ValueError(f"head_clients must be between 0 and {total_clients}, got {head_clients}")

    tail_clients = total_clients - head_clients
    reserved = total_clients * min_per_client
    clients = [[] for _ in range(total_clients)]

    # Iterate over the labels actually present; the rank (not the raw label) decides head/tail,
    # which is the same thing for the usual contiguous labels 0..K-1.
    for rank, class_id in enumerate(sorted(df['target'].unique())):
        class_df = df[df["target"] == class_id]
        indices = class_df.sample(frac=1, random_state=random_state).index.tolist()

        if len(indices) < reserved:
            raise ValueError(f"Not enough samples in class {class_id}: minimum {reserved} is needed")

        for i in range(total_clients):
            clients[i].extend(indices[i * min_per_client:(i + 1) * min_per_client])

        remaining = indices[reserved:]
        if not remaining:
            continue

        if rank < head_classes:
            first_client, group_size = 0, head_clients
        else:
            first_client, group_size = head_clients, tail_clients
        if group_size == 0:
            raise ValueError(f"Class {class_id} has leftover rows but its client group is empty")

        for i, idx in enumerate(remaining):
            clients[first_client + i % group_size].append(idx)

    return clients


In [96]:
def assign_groups_to_df(df, clients):
    """Return a copy of ``df`` with an int64 ``client`` column built from index groups.

    Args:
        df: DataFrame whose index labels are referenced by ``clients``.
        clients: Sequence of lists of index labels; the row labels in ``clients[k]`` get client id
            ``k`` (e.g. the output of ``flexible_split``). If a label appears in several groups,
            the later group wins.

    Returns:
        A copy of ``df`` with the ``client`` column; ``df`` itself is not modified.

    Raises:
        ValueError: if some rows of ``df`` are not assigned to any client.
        KeyError: (from pandas) if a group references a label missing from ``df.index``.
    """
    df = df.copy()
    # Start from NaN (float) so rows that no group claims remain detectable before the int cast.
    client_column = pd.Series(np.nan, index=df.index, dtype=np.float64)

    for client_id, indices in enumerate(clients):
        client_column.loc[indices] = client_id

    missing = client_column.isna()
    if missing.any():
        raise ValueError(
            f"{int(missing.sum())} row(s) are not assigned to any client, "
            f"e.g. index {missing.idxmax()!r}"
        )

    df['client'] = client_column.astype(np.int64)
    return df


In [97]:
# Split the target-sorted df prepared earlier into client groups and attach them as a 'client' column.
groups = flexible_split(df)  # from the previous step
df_with_groups = assign_groups_to_df(df, groups)


In [98]:
# Rows assigned to client 1; the 'client' column is int64, so compare against an int.
df_group = df_with_groups.loc[df_with_groups['client'].eq(1)]

In [104]:
# Heterogeneous-split map file. Relative to the notebook's working directory instead of a
# machine-specific absolute path. NOTE(review): assumes the notebook runs from the project root
# (the directory that also contains check_dir/) — confirm.
df = pd.read_csv(Path('food101') / 'image_data' / 'food101_hetero_map_file.csv')

In [105]:
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

# Per-channel mean and std of the images listed in df['fpath'].
# Pixel values are scaled to [0, 1]; the statistics are pooled over every pixel of every image,
# so larger images carry more weight.

# Running per-channel accumulators. Every image is converted to RGB, so there are exactly 3 channels.
sum_channels = np.zeros(3, dtype=np.float64)
sum_sq_channels = np.zeros(3, dtype=np.float64)
n_pixels_total = 0

for img_path in tqdm(df['fpath'], desc="Processing images"):
    # `with` closes the underlying file as soon as the pixels have been read.
    with Image.open(img_path) as img:
        # float64 so the per-image sums of hundreds of thousands of values keep full precision.
        arr = np.asarray(img.convert('RGB'), dtype=np.float64) / 255.0  # scale to [0, 1]
    h, w, _ = arr.shape

    # Per-channel sums over height and width.
    sum_channels += arr.sum(axis=(0, 1))
    sum_sq_channels += np.square(arr).sum(axis=(0, 1))
    n_pixels_total += h * w

if n_pixels_total == 0:
    raise ValueError("No pixels to aggregate: df['fpath'] is empty")

# Mean of each channel.
mean = sum_channels / n_pixels_total

# Variance: E[X^2] - (E[X])^2; clip tiny negative round-off before taking the square root.
var = np.clip(sum_sq_channels / n_pixels_total - mean**2, 0.0, None)
std = np.sqrt(var)

print("Mean (R,G,B):", mean)
print("Std  (R,G,B):", std)


Processing images: 100%|██████████| 71912/71912 [16:55<00:00, 70.81it/s]

Mean (R,G,B): [0.54936988 0.44513756 0.34351461]
Std  (R,G,B): [0.27317306 0.27592136 0.27996224]



