<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/split_cifar10(S_CIFAR10).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Download and extract the CIFAR-10 dataset (binary version).

In [2]:
! wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
! tar xvzf cifar-10-binary.tar.gz

--2022-04-28 16:38:54--  https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170052171 (162M) [application/x-gzip]
Saving to: ‘cifar-10-binary.tar.gz’


2022-04-28 16:39:01 (28.0 MB/s) - ‘cifar-10-binary.tar.gz’ saved [170052171/170052171]

cifar-10-batches-bin/
cifar-10-batches-bin/data_batch_1.bin
cifar-10-batches-bin/batches.meta.txt
cifar-10-batches-bin/data_batch_3.bin
cifar-10-batches-bin/data_batch_4.bin
cifar-10-batches-bin/test_batch.bin
cifar-10-batches-bin/readme.html
cifar-10-batches-bin/data_batch_5.bin
cifar-10-batches-bin/data_batch_2.bin


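Split the train and test sets by class: write one binary file per class (e.g. `data/train/cat.bin`, `data/test/cat.bin`), keeping the original CIFAR-10 record format of 1 label byte followed by 3072 pixel bytes.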

In [3]:
import os
import numpy as np
from skimage import io

path = 'cifar-10-batches-bin'

# Labels, indexed by the class id stored in each record's first byte.
label_strings = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Each CIFAR-10 binary record is 1 label byte followed by 3 * 32 * 32 = 3072 pixel bytes.
RECORD_BYTES = 1 + 3 * 32 * 32


def _load_records(batch_paths):
    """Read CIFAR-10 binary batch files and stack them into one (N, RECORD_BYTES) uint8 array.

    The record count is inferred from each file's size instead of being hard-coded;
    reshape still raises if a file is not a whole number of records.
    """
    batches = [np.fromfile(batch_path, dtype='uint8').reshape(-1, RECORD_BYTES)
               for batch_path in batch_paths]
    return np.concatenate(batches, axis=0)


def _write_class_files(records, out_dir):
    """Write one '<label_name>.bin' file per class into out_dir.

    Records keep their original byte layout (label byte + pixel bytes) and their
    original relative order within each class. Class sizes are not assumed.
    """
    os.makedirs(out_dir, exist_ok=True)
    labels = records[:, 0]
    # Every record must map to a known class, otherwise it would be silently dropped.
    assert labels.size == 0 or int(labels.max()) < len(label_strings), 'unexpected label id'
    for class_idx, class_name in enumerate(label_strings):
        class_records = records[labels == class_idx]
        # tofile writes in C order, i.e. record after record.
        class_records.tofile(os.path.join(out_dir, class_name + '.bin'))


# Split train dataset (data_batch_1..5) into data/train/<class>.bin.
train_batch_paths = [os.path.join(path, 'data_batch_' + str(i) + '.bin') for i in range(1, 6)]
_write_class_files(_load_records(train_batch_paths), 'data/train')

# Split test dataset into data/test/<class>.bin.
_write_class_files(_load_records([os.path.join(path, 'test_batch.bin')]), 'data/test')

Training setup: device selection, a seeding helper, hyperparameters, and the train / evaluation loops.

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import models
import torchvision.transforms as transforms
import os
import argparse
import copy
import random
import numpy as np
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def seed_everything(seed=12):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator (default 12).

    Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    affects subprocesses launched afterwards, not str hashing in this kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Hyperparameters. parse_args(args=[]) parses an empty list instead of sys.argv
# (inside a Jupyter kernel sys.argv holds the kernel's own flags), so the defaults
# below are what a run uses; edit them or pass a list of flags to override.
parser = argparse.ArgumentParser(description='BalancedLSF Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lr_schedule', default=0, type=int, help='lr scheduler')
parser.add_argument('--batch_size', default=1024, type=int, help='batch size')
# Help text previously duplicated '--batch_size'; this one is for evaluation.
parser.add_argument('--test_batch_size', default=2048, type=int, help='test batch size')
parser.add_argument('--num_epoch', default=50, type=int, help='epoch number')
parser.add_argument('--num_classes', type=int, default=10, help='number classes')
args = parser.parse_args(args=[])

def train(model, trainloader, criterion, optimizer):
    """Run one optimisation pass over every batch yielded by `trainloader`.

    Args:
        model: network to train; switched to training mode first.
        trainloader: iterable of (inputs, targets) batches.
        criterion: loss function called as criterion(outputs, targets).
        optimizer: optimizer stepping the model's parameters.

    Batches are moved to the module-level `device` before the forward pass.
    Returns nothing.
    """
    model.train()
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

def test(model, testloader):
    """Return the top-1 accuracy of `model` over `testloader` as a fraction.

    Args:
        model: network to evaluate; switched to eval mode (and left there).
        testloader: iterable of (inputs, targets) batches.

    Predictions are the arg-max along dim 1 of the model output. Gradients are
    disabled during evaluation. An empty loader raises ZeroDivisionError.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).max(dim=1).indices
            n_seen += targets.shape[0]
            n_correct += (predicted == targets).sum().item()
    return n_correct / n_seen