In [247]:
from dl import Module, Variable
from dl.modules import Convolution, ReLU, Linear, Flatten, MaxPool
from dl.functions import cross_entropy_loss

# Downloading CIFAR-10
import numpy as np
import os
import requests
import tarfile
import pickle

# Training
from dl.data import train_val_split, iterate_batches
from dl.optimizers import SGD

## Download and extract CIFAR-10.

In [248]:
root = './data'
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
filename = 'cifar-10-python.tar.gz'
archive_path = os.path.join(root, filename)
extract_path = os.path.join(root, 'cifar-10-batches-py')

os.makedirs(root, exist_ok=True)

# Download compressed file containing dataset.
if not os.path.exists(archive_path):
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(archive_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

# Extract dataset from compressed file.
if not os.path.exists(extract_path):
    with tarfile.open(archive_path, 'r:gz') as tar:
        tar.extractall(path=os.path.dirname(extract_path))

print(f"\nCIFAR-10 is ready at: {extract_path}")


CIFAR-10 is ready at: ./data/cifar-10-batches-py


## Load CIFAR-10 into ram.

In [249]:
def load_cifar_batch(batch_path):
    with open(batch_path, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
        X = batch[b'data'] # shape (10000, 3072)
        X = X.reshape(-1, 3, 32, 32).astype(np.float32)
        y = np.array(batch[b'labels']) # list of 10000 ints

    return X, y

# Iterate over all 5 batch files.
xs = []
ys = []
for i in range(1, 6): 
    batch_path = os.path.join(extract_path, f'data_batch_{i}')
    with open(batch_path, 'rb') as fo:
        X, y = load_cifar_batch(batch_path)

    xs.append(X)
    ys.append(y)

X_train_full = np.concatenate(xs)  # shape (50000, 3, 32, 32)
y_train_full = np.concatenate(ys)
X_test, y_test = load_cifar_batch(os.path.join(extract_path, 'test_batch'))

# Normalize
X_train_full = (X_train_full / 255.0 - 0.5) / 0.5  # normalize to [-1, 1]
X_test = (X_test / 255.0 - 0.5) / 0.5

In [250]:
# In order to use these data in DLminibox, they must be converted to type Variable.
X_train_full = Variable(X_train_full)
y_train_full = Variable(y_train_full)
X_test = Variable(X_test)
y_test = Variable(y_test)

# Set aside a validation set.
X_train, y_train, X_val, y_val = train_val_split(X_train_full, y_train_full, ratio=0.1, seed=42)

## Define the CNN architecture.

In [251]:
class CNN(Module):
    def __init__(self):
        super().__init__()
        
        # Conv Block 1
        self.conv1 = Convolution(C_in=3, C_out=32, K=3, stride=1, padding=1)
        self.relu1 = ReLU()
        self.pool1 = MaxPool(K=2, stride=2)  # 32x32 → 16x16

        # Conv Block 2
        self.conv2 = Convolution(C_in=32, C_out=64, K=3, stride=1, padding=1)
        self.relu2 = ReLU()
        self.pool2 = MaxPool(K=2, stride=2)  # 16x16 → 8x8

        # Conv Block 3
        self.conv3 = Convolution(C_in=64, C_out=128, K=3, stride=1, padding=1)
        self.relu3 = ReLU()
        self.pool3 = MaxPool(K=2, stride=2)  # 8x8 → 4x4

        self.flat = Flatten()
        self.fc1 = Linear(128 * 4 * 4, 256)
        self.relu4 = ReLU()
        self.fc2 = Linear(256, 10)  # CIFAR-10 output

    def forward(self, X):
        X = self.pool1(self.relu1(self.conv1(X)))
        X = self.pool2(self.relu2(self.conv2(X)))
        X = self.pool3(self.relu3(self.conv3(X)))
        X = self.flat(X)
        X = self.relu4(self.fc1(X))
        X = self.fc2(X)
        return X

In [252]:
model = CNN()
optimizer = SGD(model.parameters(), 0.001)

In [253]:
for _ in range(5):
    for X_batch, y_batch in iterate_batches(X_train, y_train, batch_size=64):
        
        # Compute features and loss.
        features = model(X_batch)
        loss = cross_entropy_loss(features, y_batch)

        print(loss.data)

        # Update model parameters.
        optimizer.clear_grad()
        loss.backward()
        optimizer.update_parameters()

3.065007252117097
2.6299331151636527
2.726307767589162
2.6524641533090323
2.82669218808011
2.4197892874425375
2.547660527415288
2.506274942636041
2.3685608983120323
2.4320421252213635
2.393568447632585
2.250594836128891
2.2935146950818632
2.472582236392525
2.3722568770988413
2.2425362768376402
2.3043211228080343
2.3653298346909892
2.4446872080770325
2.2916556053361377
2.3951466452861627
2.2178302929463483
2.299535094947489
2.2322437576905187
2.258606662211438
2.3775859567532
2.3240808447698695
2.263078669984167
2.3132220762176825
2.338663756097699
2.2728887424210877
2.248596752970936
2.3441862320631452
2.2598310086634372
2.2526034470136986
2.2387005788294414
2.2217775343595108
2.2967181599988695
2.2102436624744026
2.1773688964062305
2.279470898296354
2.307198669982678
2.2020440729237367
2.290105807911301
2.2266997465955605
2.183338266186746
2.128764559179724
2.246882906155438
2.2966746760457877
2.2389232835171446
2.308285488111002
2.257297035523948
2.138447206824582
2.271681301604894
2