# Cosine-similarity classifier [PyTorch]
- feature extractor: ConvNet
- classifier: cosine classifier (S. Gidaris et al., 2018)

<b>Goal</b>: See if it can learn the representations effectively.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn import preprocessing

import torch
import torch.nn as nn

from torch import relu

from torch.utils.data import DataLoader, Dataset

import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [2]:
# GPU settings
device = torch.device(0)
device

device(type='cuda', index=0)

### Few-shot Learning Settings
K-shot N-ways

In [3]:
K = 10
N = 1
left_class = 7

### Get a dataset

In [4]:
data_transform = transforms.Compose([transforms.ToTensor()])

# import the `MNIST datasets`
mnist_train = dsets.MNIST(root='data',
                          train=True,
                          transform=data_transform,
                          download=True)

mnist_test = dsets.MNIST(root='data',
                          train=False,
                          transform=data_transform,
                          download=True)

# build the `DataLoader`
train_data_loader = DataLoader(mnist_train, batch_size=2**10)
test_data_loader = DataLoader(mnist_test, batch_size=mnist_test.data.shape[0])

In [5]:
# Label Encoder
label_encoder  = preprocessing.LabelEncoder()

targets = list(range(0, 10, 1))
targets.pop(left_class)
targets = np.array(targets).reshape(-1, 1)

label_encoder.fit(targets);

  return f(*args, **kwargs)


In [6]:
for x, y in train_data_loader:
    print(x); print(y); print(x.max()); print(x.shape)
    break

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.

### Build a model

In [16]:
class Model(nn.Module):
    def __init__(self, in_size=28, embedding_feature_size=2, n_classes=10):
        super().__init__()
        
        # Data properties
        in_channels = 1
        
        # Define layers
        # 1
        self.conv1 = nn.Conv2d(in_channels, 12, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 2
        self.conv2 = nn.Conv2d(12, 24, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(24)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 3
        self.conv3 = nn.Conv2d(24, 48, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(48)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 4
        self.last_embedding_layer = nn.Linear(48*3*3, embedding_feature_size)
        
        # Each `w` is a representation vector of each class.
        self.Wstar_layer = nn.Linear(embedding_feature_size, n_classes, bias=False)  # 'bias=False' allows the weights to be the pure representation vectors (not affected by the bias)
        
    def forward(self, x):
        #============================
        # Feature Extractor
        h1 = self.maxpool1(relu(self.bn1(self.conv1(x))))
        h2 = self.maxpool2(relu(self.bn2(self.conv2(h1))))
        h3 = self.maxpool3(relu(self.bn3(self.conv3(h2))))
        
        batch_size = x.shape[0]
        h3 = h3.view(batch_size, -1)  # 1
        
        self.z = self.last_embedding_layer(h3)
        norm_z = self.z / torch.norm(self.z, p=2, dim=1, keepdim=True)#.detach() # 2)
        
        # 1) flattne layer
        # 2) representation vector of `x`; Note `.detach()`: ; We don't want the l2-norm function to be involved with the gradient descent update.
        
        #============================
        # Cosine-similarity Classifier
        Wstar = self.Wstar_layer.weight.T#.detach()  # Note `.detach()`
        norm_Wstar = Wstar / torch.norm(Wstar, p=2, dim=0, keepdim=True)
        
        cosine_similarities = torch.mm(norm_z, norm_Wstar)
        # Note that `CrossEntropyLoss()` = `LogSoftmax` + `NLLLoss`
        return cosine_similarities
    

In [17]:
model = Model(embedding_feature_size=64, n_classes=9).to(device)

### Compile

In [18]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

### Train

In [19]:
# settings
n_epochs = 6

In [20]:
def leave_out_a_class(x, y, left_class):
    "leave out some class for the few-shot learning"
    indices = (y != left_class)
    return x[indices, :, :, :], y[indices]

In [21]:
train_hist = {"epochs": [], "loss_per_epoch": [], "loss": [], "test_acc": []}

for epoch in range(n_epochs):
    model.train()
    
    loss_per_epoch = 0
    iters = 0
    for x, y in train_data_loader:
        x, y = leave_out_a_class(x, y, left_class)
        y = torch.tensor(label_encoder.transform(y))
        x, y = x.to(device), y.to(device)
        
        optimizer.zero_grad()
        
        yhat = model(x)
        
        loss = criterion(yhat, y)
        loss.backward()
        
        optimizer.step()
        
        # data storage
        loss_per_epoch += loss
        iters += 1
        train_hist["loss"].append(loss)
        #print(round(loss.item(), 2), end=" ")
            
    train_hist["epochs"].append(epoch)
    train_hist["loss_per_epoch"].append(loss_per_epoch)
    
    # validation
    with torch.no_grad():
        model.eval()
        test_acc = 0
        for x_test, y_test in test_data_loader:
            x_test, y_test = leave_out_a_class(x_test, y_test, left_class)
            y_test = torch.tensor(label_encoder.transform(y_test))
            x_test, y_test = x_test.to(device), y_test.to(device)
            yhat = torch.argmax(model(x_test.to(device)), axis=1)
            test_acc += np.mean((yhat.to("cpu") == y_test.to("cpu")).numpy())
        train_hist["test_acc"].append(test_acc)
    
    print("epoch: {}, loss: {:0.3f}, test_acc: {:0.3f}".format(epoch, loss_per_epoch/iters, test_acc))

epoch: 0, loss: 1.572, test_acc: 0.964
epoch: 1, loss: 1.382, test_acc: 0.978
epoch: 2, loss: 1.357, test_acc: 0.983
epoch: 3, loss: 1.345, test_acc: 0.986
epoch: 4, loss: 1.336, test_acc: 0.988
epoch: 5, loss: 1.330, test_acc: 0.987


### Make `support set` and `query set`

In [22]:
indices = (mnist_test.targets == left_class)

class SupportSet(Dataset):
    def __init__(self,):
        super().__init__()
        
        support_set_X = []
        support_set_Y = []
        for c in range(0, 10):
            indices = (mnist_test.targets == c)
            support_set_X.append( mnist_test.data[indices, :, :][:N, :, :].numpy() / 255. )
            support_set_Y.append( mnist_test.targets[indices][:N].numpy() )

        self.support_set_X = torch.from_numpy(np.array(support_set_X)).view(N*10, 1, 28, 28)
        self.support_set_Y = torch.from_numpy(np.array(support_set_Y).flatten())

        self.len = self.support_set_X.shape[0]
        
    def __getitem__(self, idx):
        return self.support_set_X[idx], self.support_set_Y[idx]
    
    def __len__(self):
        return self.len

support_set = SupportSet()
support_set_data_loader = DataLoader(support_set, batch_size=2**10)

## Prototype Approach

### Obtain $w_{left-class}$
Note $w_i \in W^*$

In [25]:
with torch.no_grad():
    model.eval()
    avg_z = 0.
    count = 0
    for x, y in support_set:
        if y.item() == left_class:
            enc_yhat = model(x.view(1, 1, 28, 28).to(device, dtype=torch.float))
            avg_z += model.z
            count += 1
    avg_z /= count

In [26]:
avg_z.shape

torch.Size([1, 64])

### Re-establish $W^*$ 
- `model.Wstar_layer.weight`

In [27]:
model.Wstar_layer.weight.shape  # [in, out]

torch.Size([9, 64])

In [28]:
font_Wstar = model.Wstar_layer.weight[:left_class, :]
back_Wstar = model.Wstar_layer.weight[left_class:, :]

print(font_Wstar.shape); print(back_Wstar.shape)

torch.Size([7, 64])
torch.Size([2, 64])


In [29]:
res_Wstar = torch.cat((font_Wstar, avg_z, back_Wstar), 0).to(device)
print(res_Wstar.shape)

torch.Size([10, 64])


In [30]:
# assign the `res_Wstar` to the existing model
model.Wstar_layer.weight = nn.Parameter(res_Wstar)

### Evaluate the learned representations

`left_class` only

In [31]:
with torch.no_grad():
    model.eval()
    
    indices = (mnist_test.targets == left_class)
    
    test_acc = 0
    count = 0
    for x_test, y_test in test_data_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test, y_test = x_test[indices,:,:], y_test[indices]
        
        yhat = torch.argmax(model(x_test.to(device)), axis=1)
        
        test_acc += np.mean((yhat.to("cpu") == y_test.to("cpu")).numpy())
        count += 1
        
    test_acc /= count

In [33]:
test_acc

0.5223735408560312

over all classes

In [34]:
with torch.no_grad():
    model.eval()
    
    test_accs = []
    for c in range(10):
        indices = (mnist_test.targets == c)
    
        test_acc = 0
        count = 0
        for x_test, y_test in test_data_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            x_test, y_test = x_test[indices,:,:], y_test[indices]

            yhat = torch.argmax(model(x_test.to(device)), axis=1)

            test_acc += np.mean((yhat.to("cpu") == y_test.to("cpu")).numpy())
            count += 1

        test_acc /= count
        test_accs.append(test_acc)
        print(f"class: {c} | acc: {round(test_acc, 3)}")

print("\n overall acc: {:0.3f}".format(np.mean(test_accs)))

class: 0 | acc: 0.991
class: 1 | acc: 0.992
class: 2 | acc: 0.99
class: 3 | acc: 0.99
class: 4 | acc: 0.976
class: 5 | acc: 0.988
class: 6 | acc: 0.983
class: 7 | acc: 0.522
class: 8 | acc: 0.99
class: 9 | acc: 0.979

 overall acc: 0.940


<b>Conclusion</b>: Given that there are 10 classes, the random-chance of getting `left_class` right is 10%. This model is not even trained but just uses the <i>prototype</i>. Hence, The result if is not so bad compared to the <i>linear classifier</i>.

## Fine-tuning Approach

In [24]:
# free all the layers except `Wstar_layer`
for name, param in model.named_parameters():
    if not name == 'Wstar_layer.weight':
        param.requires_grad = False
    print(name, f"| grad:{param.requires_grad}")

conv1.weight | grad:False
conv1.bias | grad:False
bn1.weight | grad:False
bn1.bias | grad:False
conv2.weight | grad:False
conv2.bias | grad:False
bn2.weight | grad:False
bn2.bias | grad:False
conv3.weight | grad:False
conv3.bias | grad:False
bn3.weight | grad:False
bn3.bias | grad:False
last_embedding_layer.weight | grad:False
last_embedding_layer.bias | grad:False
Wstar_layer.weight | grad:True


In [25]:
optimizer = torch.optim.Adam([param for param in model.parameters() if param.requires_grad], 
                              lr=5e-4)

### Fine-Tuning

In [26]:
train_hist = {"epochs": [], "loss_per_epoch": [], "loss": [], "test_acc": []}

In [27]:
n_epochs = 200

In [28]:
for epoch in range(n_epochs):
    model.train()
    
    loss_per_epoch = 0
    iters = 0
    for x, y in support_set_data_loader:
        x, y = x.to(device, dtype=torch.float), y.to(device)
        
        optimizer.zero_grad()
        
        yhat = model(x)
        
        loss = criterion(yhat, y)
        loss.backward()
        
        optimizer.step()
        
        # data storage
        loss_per_epoch += loss
        iters += 1
        train_hist["loss"].append(loss)
        #print(round(loss.item(), 2), end=" ")
            
    train_hist["epochs"].append(epoch)
    train_hist["loss_per_epoch"].append(loss_per_epoch)
    
    # validation
    if (epoch % 10) == 0:
        with torch.no_grad():
            model.eval()
            
            test_acc = 0
            for x_test, y_test in test_data_loader:
                x_test, y_test = x_test.to(device, dtype=torch.float), y_test.to(device)
                yhat = torch.argmax(model(x_test.to(device)), axis=1)
                test_acc += np.mean((yhat.to("cpu") == y_test.to("cpu")).numpy())
            train_hist["test_acc"].append(test_acc)

        print("epoch: {}, loss: {:0.3f}, test_acc: {:0.3f}".format(epoch, loss_per_epoch/iters, test_acc))

epoch: 0, loss: 1.451, test_acc: 0.940
epoch: 10, loss: 1.440, test_acc: 0.940
epoch: 20, loss: 1.432, test_acc: 0.941
epoch: 30, loss: 1.426, test_acc: 0.941
epoch: 40, loss: 1.421, test_acc: 0.942
epoch: 50, loss: 1.417, test_acc: 0.942
epoch: 60, loss: 1.414, test_acc: 0.942
epoch: 70, loss: 1.412, test_acc: 0.942
epoch: 80, loss: 1.410, test_acc: 0.942
epoch: 90, loss: 1.409, test_acc: 0.942
epoch: 100, loss: 1.408, test_acc: 0.942
epoch: 110, loss: 1.407, test_acc: 0.942
epoch: 120, loss: 1.406, test_acc: 0.942
epoch: 130, loss: 1.405, test_acc: 0.941
epoch: 140, loss: 1.405, test_acc: 0.941
epoch: 150, loss: 1.404, test_acc: 0.940
epoch: 160, loss: 1.404, test_acc: 0.939
epoch: 170, loss: 1.404, test_acc: 0.939
epoch: 180, loss: 1.403, test_acc: 0.939
epoch: 190, loss: 1.403, test_acc: 0.938


### Evaluate the learned representations

`left_class` only

In [29]:
with torch.no_grad():
    model.eval()
    
    indices = (mnist_test.targets == left_class)
    
    test_acc = 0
    count = 0
    for x_test, y_test in test_data_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test, y_test = x_test[indices,:,:], y_test[indices]
        
        yhat = torch.argmax(model(x_test.to(device)), axis=1)
        
        test_acc += np.mean((yhat.to("cpu") == y_test.to("cpu")).numpy())
        count += 1
        
    test_acc /= count

In [30]:
test_acc

0.5953307392996109

over all classes

In [31]:
with torch.no_grad():
    model.eval()
    
    test_accs = []
    for c in range(10):
        indices = (mnist_test.targets == c)
    
        test_acc = 0
        count = 0
        for x_test, y_test in test_data_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            x_test, y_test = x_test[indices,:,:], y_test[indices]

            yhat = torch.argmax(model(x_test.to(device)), axis=1)

            test_acc += np.mean((yhat.to("cpu") == y_test.to("cpu")).numpy())
            count += 1

        test_acc /= count
        test_accs.append(test_acc)
        print(f"class: {c} | acc: {round(test_acc, 3)}")

print("\n overall acc: {:0.3f}".format(np.mean(test_accs)))

class: 0 | acc: 0.994
class: 1 | acc: 0.879
class: 2 | acc: 0.978
class: 3 | acc: 0.767
class: 4 | acc: 0.95
class: 5 | acc: 0.929
class: 6 | acc: 0.952
class: 7 | acc: 0.595
class: 8 | acc: 0.939
class: 9 | acc: 0.964

 overall acc: 0.895


<b>Conclusion</b>: The performance degraded than the prior approach.

---