In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Configuration for the CIFAR-10 pipeline.
DATA_ROOT = './data'
BATCH_SIZE = 4
NUM_WORKERS = 2
# Per-channel constants for Normalize. These are NOT the dataset's real statistics:
# ToTensor scales pixels to [0, 1] and (x - 0.5) / 0.5 then maps them to [-1, 1].
# The result is not zero-mean / unit-variance data; that would need the true CIFAR-10
# per-channel mean and std.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# Set up the transforms for the data: PIL image -> float tensor in [0, 1] -> [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Load the training and test datasets (downloaded into DATA_ROOT on first run)
trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Create data loaders for the training and test sets; only training data is shuffled
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
import matplotlib.pyplot as plt

In [None]:
import pickle
import numpy as np

def unpickle(file):
    """Load a single pickled CIFAR-10 batch file and return the unpickled object.

    Args:
        file: Path to the batch file (e.g. ``.../data_batch_1`` or ``.../test_batch``).

    Returns:
        The object stored in the pickle. For the CIFAR-10 python batches, callers in
        this notebook index it with ``bytes`` keys (``b'data'``, ``b'labels'``).

    ``encoding='bytes'`` keeps 8-bit strings from a Python 2 pickle as ``bytes``
    instead of decoding them as text. The batch files are presumably written by
    Python 2, which is why the keys come back as ``bytes``.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading. Only call this
    on trusted files, such as the official CIFAR-10 archive fetched by torchvision.
    """
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch

# Set the path to the directory containing the CIFAR10 dataset
# (the python batches extracted by the torchvision download in the first cell)
data_dir = './data/cifar-10-batches-py/'

# Load the five training batches. Collect the raw b'data' chunks and concatenate once:
# calling np.vstack inside the loop re-copies the growing array on every iteration.
train_chunks = []
train_labels = []
for batch_idx in range(1, 6):
    batch_data = unpickle(f'{data_dir}data_batch_{batch_idx}')
    train_chunks.append(batch_data[b'data'])
    train_labels.extend(batch_data[b'labels'])
train_data = np.concatenate(train_chunks, axis=0)

# Load the test data
test_batch = unpickle(data_dir + 'test_batch')
test_data = test_batch[b'data']
test_labels = test_batch[b'labels']

# Each image is stored flat in channel-major (3, 32, 32) order. Reshape, then move the
# channel axis last to get (N, H, W, C). That is the layout matplotlib expects; it is
# NOT PyTorch's (N, C, H, W). The -1 infers the image count instead of hard-coding
# 50000 / 10000.
train_data = train_data.reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))
train_labels = np.array(train_labels)

test_data = test_data.reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))
test_labels = np.array(test_labels)

# Sanity checks: one label per image in each split
assert len(train_data) == len(train_labels), (train_data.shape, train_labels.shape)
assert len(test_data) == len(test_labels), (test_data.shape, test_labels.shape)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# CIFAR-10 label index -> human-readable class name
class_names = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck'
}

# Get the first image and label from the training set. The dataset's transform
# (ToTensor + Normalize) returns a channels-first tensor; assumes trainset uses the
# Normalize(mean=0.5, std=0.5) transform from the first cell, i.e. values in [-1, 1].
image, label = trainset[0]

# Undo Normalize(0.5, 0.5) to get back to [0, 1]. imshow clips float RGB data outside
# [0, 1], so displaying the normalized tensor directly distorts the colours.
# permute moves channels last (H, W, C), as imshow expects.
image_hwc = (image * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()

class_name = class_names[label]

# Display the image, titled with its class
fig, ax = plt.subplots()
ax.imshow(image_hwc)
ax.set_title(class_name)
ax.axis('off')
plt.show()

# Print the class name
print(class_name)
