In [1]:
!pip install git+https://github.com/esgantivar/qmc.git@torch -q

  Building wheel for qmc (setup.py) ... [?25l[?25hdone


In [2]:
!pip install torch



In [3]:
import torch
import torch.nn.functional as F
from qmc.torch.mps.torchmps import MPS
from qmc.torch.layers import QMeasureDensityEig
from torchvision import transforms, datasets

In [55]:
class MPSLayer(torch.nn.Module):
  def __init__(self, input_dim, output_dim, bond_dim):
    super(MPSLayer, self).__init__()
    self.mps = MPS(
        input_dim=input_dim,
        output_dim=output_dim,
        bond_dim=bond_dim,
        adaptive_mode=False,
        periodic_bc=False
    )
  
  def forward(self, inputs):
    #return torch.sqrt(F.softmax(self.mps(inputs), dim=1))
    return self.mps(inputs)

In [72]:
class DMKDClassifierMPS(torch.nn.Module):
  def __init__(
      self,
      input_dim, 
      bond_dim, 
      n_output, 
      num_eig, 
      num_classes):
    super(DMKDClassifierMPS, self).__init__()
    self.mps = MPSLayer(
        input_dim=input_dim,
        output_dim=n_output,
        bond_dim=bond_dim
    )
    self.num_classes = num_classes
    self.qmd = []
    for _ in range(num_classes):
      self.qmd.append(QMeasureDensityEig(n_output, num_eig))
    
  def forward(self, inputs):
    psi_x = self.mps(inputs)
    probs = []
    for i in range(self.num_classes):
        probs.append(self.qmd[i](psi_x))
    posteriors = torch.stack(probs, dim=-1)
    posteriors = (posteriors / torch.unsqueeze(torch.sum(posteriors, dim=-1), dim=-1))
    return posteriors

In [73]:
num_train = 2000
num_test = 1000
batch_size = 100
num_epochs = 20

In [74]:
model = DMKDClassifierMPS(
    input_dim=28 ** 2,
    n_output=50,
    bond_dim=16,
    num_eig=5,
    num_classes=10
)

In [80]:
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

In [81]:
transform = transforms.ToTensor()
train_set = datasets.MNIST("./mnist", download=True, transform=transform)
test_set = datasets.MNIST("./mnist", download=True, transform=transform, train=False)

In [82]:
loss_list = []
accuracy_list = []

In [83]:
# Put MNIST data into dataloaders
samplers = {
    "train": torch.utils.data.SubsetRandomSampler(range(num_train)),
    "test": torch.utils.data.SubsetRandomSampler(range(num_test)),
}

loaders = {
    name: torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=samplers[name], drop_last=True
    )
    for (name, dataset) in [("train", train_set), ("test", test_set)]
}

num_batches = {
    name: total_num // batch_size
    for (name, total_num) in [("train", num_train), ("test", num_test)]
}

In [84]:
for epoch_num in range(1, num_epochs + 1):
    running_loss = 0.0
    running_acc = 0.0

    for inputs, labels in loaders["train"]:
        inputs, labels = inputs.view([batch_size, 28 ** 2]), labels.data

        # Call our MPS to get logit scores and predictions
        scores = model(inputs)
        _, preds = torch.max(scores, 1)

        # Compute the loss and accuracy, add them to the running totals
        loss = loss_fun(scores, labels)
        with torch.no_grad():
            accuracy = torch.sum(preds == labels).item() / batch_size
            running_loss += loss
            running_acc += accuracy

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_list.append(running_loss / num_batches['train'])
    accuracy_list.append(running_acc / num_batches['train'])
    print(f"### Epoch {epoch_num} ###")
    print(f"Average loss:           {running_loss / num_batches['train']:.4f}")
    print(f"Average train accuracy: {running_acc / num_batches['train']:.4f}")

    # Evaluate accuracy of MPS classifier on the test set
    with torch.no_grad():
        running_acc = 0.0

        for inputs, labels in loaders["test"]:
            inputs, labels = inputs.view([batch_size, 28 ** 2]), labels.data

            # Call our MPS to get logit scores and predictions
            scores = model(inputs)
            _, preds = torch.max(scores, 1)
            running_acc += torch.sum(preds == labels).item() / batch_size

### Epoch 1 ###
Average loss:           nan
Average train accuracy: 0.0935
### Epoch 2 ###
Average loss:           nan
Average train accuracy: 0.0955
### Epoch 3 ###
Average loss:           nan
Average train accuracy: 0.0955


KeyboardInterrupt: ignored