In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import math

In [2]:
# Choose the compute backend in priority order: first CUDA GPU, then Apple MPS,
# and fall back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


In [3]:
# Input pipeline: PIL image -> tensor, then Normalize(mean, std) with 0.1307 / 0.3081
# (presumably the usual MNIST training-set pixel mean / std).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# MNIST training split, downloaded into ./MNIST_data on first run.
train_dataset = datasets.MNIST(
    root="./MNIST_data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64.
# NOTE(review): the shuffle generator is created on `device`, presumably so that it
# matches the default device set by `with device:` in the training cell (the sampler
# draws its permutation with this generator). Iterating this loader outside that
# context on a non-CPU device may raise a generator/device mismatch — TODO confirm.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator(device),
)


In [None]:
from lib.layers import Residual, UnpackGrid, MultiBatchConv2d
from lib.quantumsearch import FitnessFunction, OneToManyNetwork, QuantumSearch
from lib.quantumsearch import TransitionFunction
# Model = encoder -> QuantumSearch -> decoder.
# NOTE(review): MultiBatchConv2d comes from the project `lib`; the comments below
# assume its positional arguments follow nn.Conv2d
# (in_channels, out_channels, kernel_size, stride, padding) — TODO confirm.

# Encoder: a single conv (1 -> 32 channels) followed by ReLU.
encoder = nn.Sequential(
    MultiBatchConv2d(1, 32, 3, 1),
    nn.ReLU(),
)

# Transition head: two unpadded 3x3 convs. The second emits 3 * 32 channels,
# which UnpackGrid(3) regroups into a trailing axis of size 3.
# Constructed before the fitness head, exactly as in the original nesting, so
# parameter initialisation consumes the RNG in the same order.
transition_fn = TransitionFunction(
    OneToManyNetwork(
        nn.Sequential(
            MultiBatchConv2d(32, 32, 3),
            nn.ReLU(),
            MultiBatchConv2d(32, 3 * 32, 3),  # (batch, 3*H, row_p, col_p) per author's note
            UnpackGrid(3),  # (batch, ..., 3*H) -> (batch, ..., H, 3)
        )
    ),
)

# Fitness head: two 3x3 convs with stride 1 / padding 1; the last one outputs
# 3 channels, regrouped by UnpackGrid(3).
# NOTE(review): the 3 here presumably matches branching_width=3 below — TODO confirm.
fitness_fn = FitnessFunction(
    OneToManyNetwork(
        nn.Sequential(
            MultiBatchConv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            MultiBatchConv2d(32, 3, 3, 1, 1),
            UnpackGrid(3),  # (batch, ..., 3) -> (batch, ..., 1, 3)
        )
    ),
)

search = QuantumSearch(
    transition=transition_fn,
    fitness=fitness_fn,
    max_depth=5,
    beam_width=3,
    branching_width=3,
)

# Decoder: flatten everything after the batch axis, then a linear layer to 10 logits.
# NOTE(review): 1152 is consistent with 32 channels x 6 x 6, i.e. a 26x26 encoder
# output (28x28 MNIST, unpadded 3x3 conv) shrunk by 4 pixels per transition over
# max_depth=5 steps — TODO confirm; it must be updated if any of these change.
decoder = nn.Sequential(
    nn.Flatten(1),
    nn.Linear(1152, 10),
)

model = nn.Sequential(encoder, search, decoder)
model.to(device)

In [5]:
# Optimiser hyper-parameters: step size and L2 weight-decay strength.
learning_rate, lambda_l2 = 1e-3, 1e-5

In [None]:

# Cross-entropy loss for the 10-way classification task (it expects raw logits).
criterion = torch.nn.CrossEntropyLoss()

# RMSprop (not Adam) over all model parameters. `weight_decay` applies the
# optimizer's built-in L2 penalty with strength `lambda_l2`.
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)

# NOTE(review): `temperature` and `gamma` are not passed to the model, loss, or
# optimizer in any visible cell. Confirm whether a later cell uses them; otherwise
# they can be removed.
temperature = 3.0
gamma = 0.99

# Training-loop settings.
num_epochs = 100
log_interval = 100    # print metrics every `log_interval` batches (not every batch)
max_grad_norm = None  # e.g. 1e-2 to enable gradient-norm clipping; None disables it

# `with device:` makes `device` the default for tensors created inside the block
# (PyTorch >= 2.0). The loader's shuffle generator was also created on `device`.
with device:
    # Make sure the model is in training mode, even if another cell left it in
    # eval mode (matters only if some submodule behaves differently between modes).
    model.train()

    for t in range(num_epochs):
        for step, (batch, targets) in enumerate(train_loader):
            batch, targets = batch.to(device), targets.to(device)

            # Forward pass -> logits, then the cross-entropy loss.
            y_pred = model(batch)
            loss = criterion(y_pred, targets)

            # Standard update: clear stale grads, backprop, optional clip, step.
            optimizer.zero_grad()
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            if step % log_interval == 0:
                # Metrics are for reporting only; no autograd graph is needed.
                with torch.no_grad():
                    predicted = y_pred.argmax(dim=1)
                    acc = (predicted == targets).float().mean().item()
                print("[EPOCH]: %i, [BATCH]: %i, [LOSS]: %.6f, [ACCURACY]: %.3f"
                      % (t, step, loss.item(), acc))