In [8]:
import os
import torch
import pickle
import numpy as np
from tqdm import tqdm
from datetime import datetime
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Subset
from collections import Counter, defaultdict
from torchvision import transforms

In [2]:
def unpickle(file):
    """Deserialize and return the object stored in the pickle file at ``file``.

    The file is opened in binary mode and loaded with ``encoding="bytes"``,
    so 8-bit strings pickled by Python 2 are returned as ``bytes`` objects.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading; only
    call this on files that come from a trusted source.
    """
    with open(file, "rb") as fo:
        # Named ``payload`` rather than ``dict`` so the builtin is not shadowed.
        payload = pickle.load(fo, encoding="bytes")

    return payload
def load_data(data_path):
    """Load the pickled dataset at ``data_path`` and return its arrays.

    Returns a ``(data, labels)`` tuple of NumPy arrays built from the
    ``"data"`` and ``"labels"`` entries of the unpickled object. Prints
    ``data loaded.`` once both arrays have been created.
    """
    payload = unpickle(data_path)
    samples, targets = (np.array(payload[key]) for key in ("data", "labels"))
    print("data loaded.")
    return samples, targets

In [5]:
# Read the raw arrays, then turn them into torch tensors.
data, labels = load_data("HASYv2")
# Drop the singleton axis 1 of the label array (raises if that axis is not of size 1).
labels = labels.squeeze(1)

# Reorder the axes as (3, 2, 0, 1) so the last source axis becomes the leading
# one, then convert to float32 / int64 tensors.
data = torch.from_numpy(data.transpose(3, 2, 0, 1)).float()
labels = torch.from_numpy(labels).long()

data loaded.


In [6]:
# Sanity check: show the tensor shapes, one per line.
print(data.shape, labels.shape, sep="\n")

torch.Size([168233, 3, 32, 32])
torch.Size([168233])


In [9]:
# Tally how many samples carry each class label.
label_counts = Counter(labels.tolist())

# Report the number of distinct classes and the size of the rarest one.
class_total = len(label_counts)
print("原始类别数：", class_total)
fewest_samples = min(label_counts.values())
print("最小样本数：", fewest_samples)

原始类别数： 369
最小样本数： 51


In [10]:
# Augmentation pipeline: a small random affine jitter (rotation up to 10 degrees,
# shift up to 5% per axis, scale within 0.95-1.05) followed by a mild Gaussian blur.
# NOTE(review): RandomAffine fills pixels exposed by the rotation/shift with 0 by
# default — confirm that this matches the background value of the images.
_affine_jitter = transforms.RandomAffine(
    degrees=10,
    translate=(0.05, 0.05),
    scale=(0.95, 1.05),
)
_soft_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))

augment = transforms.Compose([_affine_jitter, _soft_blur])

In [11]:
# Group sample positions by class label: label -> list of indices into `labels`.
label_to_indices = defaultdict(list)
for position, class_id in enumerate(labels.tolist()):
    label_to_indices[class_id].append(position)

In [17]:
# Oversample every class that has fewer than TARGET_PER_CLASS images by appending
# randomly augmented copies of its existing images. Classes already at or above
# the target are left untouched.
TARGET_PER_CLASS = 1000  # minimum number of samples per class after augmentation
AUGMENT_SEED = 0         # fixed seed so a re-run generates the same synthetic samples

# Seed torch's global RNG, which torch.randint below draws from; torchvision's
# tensor transforms sample their random parameters from the same generator.
torch.manual_seed(AUGMENT_SEED)

device = data.device
aug_data = []
aug_labels = []

for label, indices in tqdm(label_to_indices.items()):
    count = len(indices)
    need = TARGET_PER_CLASS - count
    if need <= 0:
        continue

    idx_pool = torch.tensor(indices)
    # Draw all source indices for this class in one call (sampling with replacement)
    # instead of one torch.randint(...).item() round trip per generated image.
    src_indices = idx_pool[torch.randint(0, count, (need,))].tolist()

    for src_idx in src_indices:
        img = data[src_idx].cpu()  # force CPU to avoid problems inside the transforms
        aug_data.append(augment(img))
        aug_labels.append(label)

if len(aug_data) > 0:
    # Stack the per-image tensors along a new leading axis, then move the result to
    # the device of the original tensors so torch.cat cannot hit a device mismatch
    # (the augmented images were produced on CPU).
    aug_data = torch.stack(aug_data, dim=0).to(device)
    aug_labels = torch.tensor(aug_labels, dtype=torch.long).to(labels.device)

    data_final = torch.cat([data, aug_data], dim=0)
    labels_final = torch.cat([labels, aug_labels], dim=0)
else:
    # Nothing to add: every class already meets the target.
    data_final = data
    labels_final = labels

print("增强后数据形状：", data_final.shape)
print("增强后标签形状：", labels_final.shape)

100%|██████████| 369/369 [05:07<00:00,  1.20it/s]


增强后数据形状： torch.Size([398626, 3, 32, 32])
增强后标签形状： torch.Size([398626])


In [18]:
# Persist the balanced tensors as a single dict checkpoint.
# NOTE(review): the file name says "500", but the augmentation cell tops classes
# up to 1000 samples — confirm the intended name before anything else depends on it.
balanced_payload = {"data": data_final, "labels": labels_final}
torch.save(balanced_payload, "HASYv2_balanced_500.pt")