In [None]:
from __future__ import print_function
from collections import Counter

import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchvision.models import resnet18
import time

# Mount Google Drive under /content/drive; the plotting cell at the end of the
# notebook saves its figure below that path.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime, so this cell is expected to fail elsewhere -- confirm before reuse.
from google.colab import drive
drive.mount('/content/drive')

# Names of the supported similarity / attention rules; see SimpleTwoLayerNet.forward.
METHODS = ("fc-relu", "dot", "cosine", "neg-dist", "exp-neg-dist", "inv-dist", "neglog-dist")


class SimpleTwoLayerNet(nn.Module):
    """Key/value prototype network with a selectable similarity function.

    The model stores ``n_protos`` keys (each shaped like one input sample) and
    one value row of ``n_classes`` logits per key.  ``forward`` compares the
    flattened input against the flattened keys according to ``method`` and
    mixes the value rows with the resulting weights.

    Args:
        method: one of ``METHODS``.
        input_shape: shape of a single input sample (without batch dimension).
        n_classes: number of output logits.
        n_protos: number of key/value pairs.
        eps: additive constant that keeps inverse distances finite.
        power: exponent applied to the Euclidean distance.
    """

    def __init__(self, method, input_shape, n_classes, n_protos, eps=1e-3, power=10):
        super(SimpleTwoLayerNet, self).__init__()
        assert method in METHODS
        self.method = method
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.n_protos = n_protos
        self.eps = eps
        self.power = power
        # Keys keep the un-flattened sample layout; they are flattened on use.
        self.keys = nn.Parameter(torch.zeros([n_protos] + list(input_shape)))
        self.values = nn.Parameter(torch.zeros(n_protos, n_classes))

    def forward(self, inputs):
        """Return logits of shape ``(batch, n_classes)`` for a batch of inputs.

        Raises:
            ValueError: if ``self.method`` is not a known method name.
        """
        x = torch.flatten(inputs, start_dim=1)
        keys = torch.flatten(self.keys, start_dim=1)
        values = self.values

        if self.method == "neglog-dist":
            # Weights proportional to 1 / (eps + d^power), normalised to sum to 1.
            dist = self.eps + torch.cdist(x, keys).pow(self.power)
            attn = 1 / dist
            attn = attn / attn.sum(dim=1, keepdim=True)
            logits = attn @ values
        elif self.method == "inv-dist":
            dist = self.eps + torch.cdist(x, keys).pow(self.power)
            attn = 1 / dist
            logits = attn.softmax(dim=1) @ values
        elif self.method == "exp-neg-dist":
            dist = torch.cdist(x, keys).pow(self.power)
            attn = torch.exp(-0.5 * dist)
            logits = attn.softmax(dim=1) @ values
        elif self.method == "neg-dist":
            dist = torch.cdist(x, keys).pow(self.power)
            attn = -1 * dist
            logits = attn.softmax(dim=1) @ values
        elif self.method == "cosine":
            # Previously listed in METHODS but unhandled (crashed with an
            # unbound `logits`).  Mirrors "dot", using cosine similarity.
            attn = F.normalize(x, dim=1) @ F.normalize(keys, dim=1).T
            logits = attn.softmax(dim=1) @ values
        elif self.method == "dot":
            scaling = np.sqrt(1 / self.n_protos)
            attn = scaling * x @ keys.T
            logits = attn.softmax(dim=1) @ values
        elif self.method == "fc-relu":
            logits = (x @ keys.T).relu() @ values
        else:
            raise ValueError(f"Unknown method: {self.method!r}")
        return logits

    def augment(self, inputs, c, idx=None, verbose=False):
        """Insert a key/value pair so that a single input votes for class ``c``.

        Only defined for ``method == "neglog-dist"``.  The current logits are
        computed for ``inputs``; if class ``c`` is not already the (tied)
        maximum, a new pair ``(inputs, eta * one_hot(c))`` is stored.  With
        ``idx=None`` the pair is appended, in which case ``eta`` is large
        enough for the new prototype (distance 0, weight ``1 / eps``) to lift
        class ``c`` above the previous maximum.  With an index, that existing
        pair is overwritten instead.

        Appending replaces ``self.keys`` / ``self.values`` with new Parameter
        objects, so an optimizer built earlier will not track them.

        Args:
            inputs: a batch containing exactly one sample.
            c: target class index, ``c < n_classes``.
            idx: index of the pair to overwrite, or ``None`` to append.
            verbose: print what was done.
        """
        x = torch.flatten(inputs, start_dim=1)
        keys = torch.flatten(self.keys, start_dim=1)
        values = self.values
        assert x.shape[0] == 1
        assert c < self.n_classes
        assert self.method == "neglog-dist"

        # Nothing here needs autograd history.
        with torch.no_grad():
            dist = self.eps + torch.cdist(x, keys).pow(self.power)
            attn = 1 / dist
            attn_norm = attn.sum(dim=1, keepdim=True)
            logits = (attn / attn_norm) @ values
            eta = (1 + self.eps * attn_norm.sum()) * (logits.max() - logits[0, c])
            if not eta > 0:
                return

            one_hot = F.one_hot(torch.tensor([c], device=values.device), self.n_classes)
            newval = eta * one_hot.to(values.dtype)
            # Keys are stored un-flattened, so restore the sample layout before
            # storing (the flattened `x` does not match multi-dimensional keys).
            new_key = x.reshape(1, *self.keys.shape[1:]).to(self.keys.dtype)

            if idx is None:
                if verbose:
                    print(f"Fixing with new key-value:{eta.item():.4f}")
                self.keys = torch.nn.Parameter(torch.cat([self.keys, new_key], dim=0))
                self.values = torch.nn.Parameter(torch.cat([self.values, newval], dim=0))
                self.n_protos = self.keys.shape[0]
            else:
                if verbose:
                    print(f"Replacing {idx} key-value:{eta.item():.4f}")
                self.keys[idx] = new_key[0]
                self.values[idx] = newval[0]


def create_nnnet(X, y, method, n_protos, eps, power):
    """Build a SimpleTwoLayerNet sized for ``X``/``y`` and initialise it.

    Args:
        X: tensor of samples; ``X[0]`` fixes the per-sample input shape and
            the per-feature mean/std over axis 0 seed the keys.
        y: integer label tensor; the number of classes is ``max(y) + 1``.
        method: one of ``METHODS``.
        n_protos: number of key/value pairs.
        eps, power: forwarded to ``SimpleTwoLayerNet``.

    Returns:
        The initialised model.

    Initialisation:
        * ``"fc-relu"``: Kaiming-uniform keys (ReLU gain) and values.
        * ``"dot"``: Kaiming-uniform keys and values.
        * all other methods: keys drawn from ``N(mean, (0.1 * std)^2)`` per
          feature using NumPy's global RNG, values set to zero.
    """
    input_shape = X[0, :].shape
    n_classes = int(1 + torch.max(y))

    model = SimpleTwoLayerNet(
        method=method, input_shape=input_shape, n_classes=n_classes,
        n_protos=n_protos, eps=eps, power=power)

    with torch.no_grad():
        if method == "fc-relu":
            nn.init.kaiming_uniform_(model.keys, nonlinearity="relu")
            nn.init.kaiming_uniform_(model.values)
        elif method == "dot":
            # Earlier xavier/zeros calls were overwritten by these two and
            # have been dropped; the resulting distribution is unchanged.
            nn.init.kaiming_uniform_(model.keys)
            nn.init.kaiming_uniform_(model.values)
        else:
            # Data-dependent init is only needed for the distance-style
            # methods, so the statistics are computed in this branch only.
            mean = torch.mean(X, axis=0).detach().cpu().numpy()
            std = torch.std(X, axis=0).detach().cpu().numpy()
            protos_size = [n_protos] + list(mean.shape)
            keys_np = np.random.normal(
                mean, 0.1 * std, size=protos_size).astype(np.float32)
            model.keys.copy_(torch.from_numpy(keys_np))
            model.values.zero_()
    return model

In [None]:
# --- Experiment configuration ---
seed=12345
batch_size=4
test_batch_size=1000
lr=0.001
epochs=50
dry_run=False
log_interval=3000
save_model=False
use_cuda = False
use_mps = False
method = "neglog-dist"

# Seed both RNGs: the notebook samples with torch and with NumPy's global RNG.
torch.manual_seed(seed)
np.random.seed(seed)

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size}
if use_cuda:
    # Shuffling is already requested for training above; the test loader
    # does not need it, so it is no longer forced on here.
    cuda_kwargs = {'num_workers': 1, 'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

# --- Data ---
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset1 = datasets.MNIST(
    '../data', train=True, download=True, transform=train_transform)
dataset2 = datasets.MNIST(
    '../data', train=False, transform=test_transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

# Materialise the full training set once, in a single unshuffled pass, so
# that row i of train_X belongs to label i of train_y.  (Two separate passes
# over the shuffled train_loader yield differently ordered images and labels.)
stats_loader = torch.utils.data.DataLoader(
    dataset1, batch_size=test_batch_size, shuffle=False)
stats_batches = list(stats_loader)
train_X = torch.concat([images for images, _ in stats_batches], axis=0)
train_y = torch.concat([labels for _, labels in stats_batches], axis=0)
del stats_batches


In [None]:
def train(model, device, train_loader, optimizer, epoch, log_interval, dry_run):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``criterion`` as the loss.  Every ``log_interval``
    batches a progress line is printed; when ``dry_run`` is true the loop
    stops right after the first such line.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Only logging steps do anything further.
        if step % log_interval != 0:
            continue
        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{:04d}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, loss.item()))
        if dry_run:
            break

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Uses the module-level ``criterion`` as the loss.

    Returns:
        The average per-sample loss over the whole dataset.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion is assumed to return the batch *mean* (the default
            # reduction of nn.CrossEntropyLoss).  Weight it by the batch size
            # so that dividing by the dataset size below gives a true
            # per-sample average, including for a smaller final batch.
            total_loss += criterion(output, target).item() * len(data)
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss = total_loss / n_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples,
        100. * correct / n_samples))
    return test_loss

def print_digit_key_counts(model):
    """Print how many value rows have their largest entry at each digit 0-9.

    One value row is inspected per key in ``model.keys``; digits that never
    win are still listed with a count of zero.
    """
    tally = Counter(dict.fromkeys(range(10), 0))
    n_keys = model.keys.shape[0]
    tally.update(
        torch.argmax(model.values[row, :]).item() for row in range(n_keys))
    print(tally)


# Build the model, then move it to the configured device *before* the
# optimizer is created so the optimizer tracks the parameters that train()
# and test() actually use (they move every batch to `device`).
model = create_nnnet(X=train_X, y=train_y, n_protos=20, method=method, eps=1e-3, power=2)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, amsgrad=True)
criterion = torch.nn.CrossEntropyLoss()

# Cosine learning-rate decay over the whole run, stepped once per epoch.
cos_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval, dry_run)
    print_digit_key_counts(model)
    test_loss = test(model, device, test_loader)
    cos_scheduler.step()

In [None]:

#torch.save(model.state_dict(), '/content/drive/My Drive/inverse-distance-mnist-plot-20protos.pth')

In [None]:
# Work on detached CPU copies so this cell also runs when the model lives on
# an accelerator.
keys_cpu = model.keys.detach().cpu()
values_cpu = model.values.detach().cpu()

# Pair every prototype index with the digit whose value entry is largest.
key_to_digit = [
    (i, torch.argmax(values_cpu[i, :]).item()) for i in range(keys_cpu.shape[0])
]
cnt = Counter(digit for _, digit in key_to_digit)
print(cnt)

# Ten prototypes per row; the row count follows the number of prototypes
# (two rows for the default 20) instead of being fixed, so a model with more
# prototypes no longer indexes past the grid.
n_cols = 10
n_rows = max(1, -(-len(key_to_digit) // n_cols))
fig, axes = plt.subplots(
    nrows=n_rows, ncols=n_cols, figsize=(7, 0.7 * n_rows), dpi=150, squeeze=False);
sorted_keys = sorted(key_to_digit, key=lambda kd: kd[1])
for plotix, ax in enumerate(axes.flat):
    if plotix >= len(sorted_keys):
        ax.axis("off")  # unused cell in the last row
        continue
    k, d = sorted_keys[plotix]
    ax.imshow(keys_cpu[k, 0, :, :].numpy())
    ax.set_xticks([])
    ax.set_yticks([])
plt.savefig(f'/content/drive/My Drive/inverse-distance-mnist-{method}-plot-20protos.pdf')