In [78]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.backends.cudnn as cudnn

import torchvision
from torchvision import transforms as T
from torchvision.datasets import CIFAR10

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np

from torchmetrics import Accuracy
from tqdm import tqdm

In [79]:
transform_train = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])

In [80]:
train_set = CIFAR10(root='cifar10',
                    train=True,
                    download=True,
                    transform=transform_train)

Files already downloaded and verified


In [81]:
test_set = CIFAR10(root='cifar10',
                   train=False,
                   download=True,
                   transform=transform_train)

Files already downloaded and verified


In [82]:
train_set

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: cifar10
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
           )

In [83]:
train_set[12][0]

tensor([[[ 0.7501,  0.7113,  0.7113,  ...,  0.6144,  0.5368,  0.5368],
         [ 0.8276,  0.7888,  0.7888,  ...,  0.7501,  0.6725,  0.6338],
         [ 0.9051,  0.8664,  0.8858,  ...,  0.8276,  0.7694,  0.7307],
         ...,
         [-0.0447, -0.1610, -0.2580,  ..., -0.1610, -0.5875, -0.1029],
         [-0.1029, -0.1804, -0.2580,  ...,  0.0134, -0.1029, -0.1029],
         [-0.0835, -0.1029, -0.1804,  ..., -0.0253, -0.0253, -0.1029]],

        [[ 1.8101,  1.7511,  1.7511,  ...,  1.7118,  1.6724,  1.5741],
         [ 1.8691,  1.8101,  1.8101,  ...,  1.7511,  1.6921,  1.6134],
         [ 1.9085,  1.8495,  1.8691,  ...,  1.7511,  1.6724,  1.6528],
         ...,
         [ 0.7284,  0.6104,  0.5121,  ...,  0.2564, -0.1566,  0.5121],
         [ 0.6498,  0.5514,  0.4728,  ...,  0.5121,  0.4138,  0.5711],
         [ 0.6104,  0.5711,  0.4924,  ...,  0.5711,  0.6104,  0.6104]],

        [[ 2.5391,  2.4611,  2.4611,  ...,  2.4025,  2.3635,  2.3440],
         [ 2.5586,  2.5001,  2.5001,  ...,  2

In [84]:
train_set[12]

(tensor([[[ 0.7501,  0.7113,  0.7113,  ...,  0.6144,  0.5368,  0.5368],
          [ 0.8276,  0.7888,  0.7888,  ...,  0.7501,  0.6725,  0.6338],
          [ 0.9051,  0.8664,  0.8858,  ...,  0.8276,  0.7694,  0.7307],
          ...,
          [-0.0447, -0.1610, -0.2580,  ..., -0.1610, -0.5875, -0.1029],
          [-0.1029, -0.1804, -0.2580,  ...,  0.0134, -0.1029, -0.1029],
          [-0.0835, -0.1029, -0.1804,  ..., -0.0253, -0.0253, -0.1029]],
 
         [[ 1.8101,  1.7511,  1.7511,  ...,  1.7118,  1.6724,  1.5741],
          [ 1.8691,  1.8101,  1.8101,  ...,  1.7511,  1.6921,  1.6134],
          [ 1.9085,  1.8495,  1.8691,  ...,  1.7511,  1.6724,  1.6528],
          ...,
          [ 0.7284,  0.6104,  0.5121,  ...,  0.2564, -0.1566,  0.5121],
          [ 0.6498,  0.5514,  0.4728,  ...,  0.5121,  0.4138,  0.5711],
          [ 0.6104,  0.5711,  0.4924,  ...,  0.5711,  0.6104,  0.6104]],
 
         [[ 2.5391,  2.4611,  2.4611,  ...,  2.4025,  2.3635,  2.3440],
          [ 2.5586,  2.5001,

In [85]:
train_set.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [86]:
train_set.class_to_idx

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [87]:
test_set

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: cifar10
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
           )

In [88]:
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

In [90]:
x, y = next(iter(train_loader))
print(x.shape, y.shape)

torch.Size([64, 3, 32, 32]) torch.Size([64])
