In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader



In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [3]:
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.1307,),(0.3081,))])
train_dataset = datasets.MNIST(root = './MNIST_data', train  = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True, generator=torch.Generator(device))
# train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True)

In [4]:
from lib.layers import Residual, UnpackGrid, MultiBatchConv2d
from lib.quantumsearch import FitnessFunction, OneToManyNetwork, QuantumSearch
from lib.quantumsearch import TransitionFunction

In [None]:

class ResNetBlock(nn.Module):
    """Basic redisual block."""

    def __init__(
        self,
        num_input_filters: int,
        num_output_filters: int

    ) -> None:
        super().__init__()

        self.conv_block1 = nn.Sequential(
            MultiBatchConv2d(
                in_channels = num_input_filters,
                out_channels = num_input_filters,
                kernel_size = 3,
                stride = 1,
                padding = 1,
                bias = False,
            ),
            # nn.BatchNorm2d(num_features=num_filters),

            nn.ReLU(),
        )

        self.conv_block2 = nn.Sequential(
            MultiBatchConv2d(
                in_channels = num_input_filters,
                out_channels = num_output_filters,
                kernel_size = 3,
                stride = 1,
                padding = 1,
                bias = False,
            ),
            # nn.BatchNorm2d(num_features=num_filters),
        )
        self.conv_block3 = MultiBatchConv2d(
                in_channels = num_input_filters,
                out_channels = num_output_filters,
                kernel_size = 1,
                stride = 1,
                bias = False,
            )
        self.layer_norm1 = None
        self.layer_norm2 = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.conv_block3(x)
        out = self.conv_block1(x)
        _,_,C,H,W = out.shape
        if self.layer_norm1 is None:
            self.layer_norm1 = nn.LayerNorm([None,None,None])
        out = self.layer_norm1(out)
        out = self.conv_block2(out)
        _,_,C,H,W = out.shape
        if self.layer_norm2 is None:
            self.layer_norm2 = nn.LayerNorm([None,None,None])
        out = self.layer_norm2(out)
        out += residual
        out = F.relu(out)
        return out


In [None]:
encoder = nn.Sequential(
    MultiBatchConv2d(1, 32, 3, 1),
    nn.ReLU(),
)
search = QuantumSearch(
    transition = TransitionFunction(OneToManyNetwork(
            nn.Sequential(
                ResNetBlock(num_input_filters=32, num_output_filters = 3*32),
                UnpackGrid(3) # Batch, ...,  3 * H -> Batch, ..., H, 3
            )
        ),
    ),
    fitness=FitnessFunction(
        OneToManyNetwork(
            nn.Sequential(

               ResNetBlock(num_input_filters=32, num_output_filters = 3),
               UnpackGrid(3) # Batch, ...,  3 * H -> Batch, ..., 1, 3
            )
        ),
    ),
    max_depth=1,
    beam_width=3,
    branching_width=3
)
decoder = nn.Sequential(
    nn.AvgPool2d(3),
   nn.Flatten(1),
   nn.Linear(2048, 10)
)
model = nn.Sequential(encoder,
                      search,
                      decoder)
model.to(device)

Sequential(
  (0): Sequential(
    (0): MultiBatchConv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (1): QuantumSearch(
    (transition): TransitionFunction(
      (one_to_many): OneToManyNetwork(
        (network): Sequential(
          (0): ResNetBlock(
            (conv_block1): Sequential(
              (0): MultiBatchConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU()
            )
            (conv_block2): Sequential(
              (0): MultiBatchConv2d(32, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            (conv_block3): MultiBatchConv2d(32, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (1): UnpackGrid()
        )
      )
    )
    (fitness): FitnessFunction(
      (one_to_many): OneToManyNetwork(
        (network): Sequential(
          (0): ResNetBlock(
            (conv_block1): Sequential(
              (0): MultiBatchConv2d(32, 32, k

In [20]:

total_params = sum(p.numel() for p in model.parameters())
print(f" Total number of parameters: {total_params}")

 Total number of parameters: 70922


In [24]:
learning_rate = 1e-3
lambda_l2 = 1e-5
# nn package also has different loss functions.
# we use cross entropy loss for our classification task
criterion = torch.nn.CrossEntropyLoss()

# we use the optim package to apply
# ADAM for our parameter updates
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=lambda_l2) # built-in L2


temperature = 3.0
gamma = 0.99

with device:


    # Training
    for t in range(100):

        for batch, targets in train_loader:

            # Feed forward to get the logits
            batch, targets = batch.to(device), targets.to(device)
            y_pred = model(batch)

            # loss
            loss = criterion(y_pred, targets)

            # accuracy
            score, predicted = torch.max(y_pred, 1)
            acc = (targets == predicted).sum().float() / len(targets)

            print("[EPOCH]: %i, [LOSS]: %.6f, [ACCURACY]: %.3f" % (t, loss.item(), acc))
            # display.clear_output(wait=False)

            # zero the gradients before running
            # the backward pass.
            optimizer.zero_grad()

            # Backward pass to compute the gradient
            # of loss w.r.t our learnable params.
            loss.backward()

            # # clip gradient
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)

            # Update params
            optimizer.step()

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)


In [None]:
# def hook_fn(module, input, output):
#     print(f"Input shape: {module}, {input[0].shape}")  # input is a tuple; get the shape of the first element
#     print(f"Output shape:{module}, {output.shape}")

In [None]:
# Register the hook on the first layer of conv_block1
# hook_handle = model[2][0].register_forward_hook(hook_fn)

In [None]:
# sample_batch, _ = next(iter(train_loader))  # Get a batch from the dataloader
# sample_batch = sample_batch
# output = model(sample_batch)