In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import time
from collections import Counter

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device Set is {device}")

Device Set is cpu


In [9]:
DATA_DIR = "/mnt/data"
BATCH_SIZE = 32
EPOCH = 10
LR = 1e-3
NUM_CLASSES = 10
MODEL_PATH = '/mnt/tmp/simple_cfar10.pth'
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
MEAN = (0.4194,0.4852,0.4465)
STD = (0.2470,0.2435,0.2616)

In [10]:
raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=T.ToTensor())
test_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=T.ToTensor())

img, lbl = raw[0]
print(f"Train samples:{len(raw)}")
print(f"Test Samples :{len(test_raw)}")
print(f"Image Shape : {img.shape}")

100%|██████████| 170M/170M [00:42<00:00, 4.02MB/s] 
  entry = pickle.load(f, encoding="latin1")


Train samples:50000
Test Samples :10000
Image Shape : torch.Size([3, 32, 32])


In [11]:
counts = Counter(raw.targets)
for k,v in counts.items():
    print(f"{CLASSES[k]} : {v}")

frog : 5000
truck : 5000
deer : 5000
automobile : 5000
bird : 5000
horse : 5000
ship : 5000
cat : 5000
dog : 5000
airplane : 5000


In [12]:
train_tfm = T.Compose([
      T.RandomCrop(32, padding=4, padding_mode="reflect"),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.RandomRotation(10),
    T.ToTensor(),
    T.Normalize(MEAN, STD),
    T.RandomErasing(p=0.15, scale=(0.02, 0.2)),
])

val_tfm = T.Compose([
        T.ToTensor(),
    T.Normalize(MEAN, STD),
])

train_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_tfm)
val_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=val_tfm)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

imgs, lbls = next(iter(train_loader))

print(f"Train")


Train
