# Arcface metric learning.

In [1]:
# Target anomaly label
target = 9
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn
from tqdm import tqdm_notebook as tqdm
import torch.optim.lr_scheduler as lr_scheduler

In [2]:
TRAIN = True

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.13, ), (0.3, ))])

trainset = torchvision.datasets.MNIST(root='./data', 
                                        train=True,
                                        download=True,
                                        transform=transform)

testset = torchvision.datasets.MNIST(root='./data', 
                                        train=False, 
                                        download=True, 
                                        transform=transform)

classes = tuple(np.linspace(0, 9, 10, dtype=np.uint8))

In [3]:
raw_y = np.array([int(y) for y in trainset.targets])
raw_x = np.array([y for y in trainset.data.numpy()])

In [4]:
idx = trainset.train_labels!=target
trainset.targets = trainset.targets[idx]
trainset.data = trainset.data[idx]



In [5]:
idx = testset.train_labels==target
testset.targets = testset.targets[idx]
testset.data = testset.data[idx]

set dataloader

In [6]:
trainloader = torch.utils.data.DataLoader(trainset,
                                            batch_size=100,
                                            shuffle=True,
                                            num_workers=0)
testloader = torch.utils.data.DataLoader(testset, 
                                            batch_size=100,
                                            shuffle=False, 
                                            num_workers=0)

In [7]:
# 正常データをサンプル
testset_norm = torchvision.datasets.MNIST(root='./data', 
                                        train=False, 
                                        download=True, 
                                        transform=transform)
idx = testset_norm.train_labels!=target
testset_norm.targets = testset_norm.targets[idx]
testset_norm.data = testset_norm.data[idx]
testloader1 = torch.utils.data.DataLoader(testset_norm, 
                                            batch_size=100,
                                            shuffle=False, 
                                            num_workers=0)


## define model

In [10]:
# pip install pretrainedmodels
import pretrainedmodels
model_name = "resnet18"
basemodel = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
basemodel = nn.Sequential(*list(basemodel.children())[1:-2])

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\kyosh/.cache\torch\hub\checkpoints\resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [90]:
# tricks from kaggle
# https://www.kaggle.com/tanulsingh077/pytorch-metric-learning-pipeline-only-images
# https://www.kaggle.com/slawekbiel/arcface-explained -- Beautiful kernel explaining ArcFace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math

class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [91]:
class mymodel(nn.Module):
    def __init__(self):
        super(mymodel, self).__init__()
        self.features = basemodel
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        if model_name == "resnet34" or model_name == "resnet18":
            num_ch = 512
        else:
            num_ch = 2048
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 3 channel output for visualization
        self.final = ArcMarginProduct(num_ch, 9,
                                          s=35, m=0.5, easy_margin=False)
        
    def forward(self, x, label):
        # extract features
        x = self.conv1(x)
        x = self.features(x)
        x = self.avgpool(x).squeeze(2).squeeze(2)
        out = self.final(x, label) # pass to arcface layer
        
        return out
    
    def extract(self, x):
        # extract features
        x = self.conv1(x)
        x = self.features(x)
        x = self.avgpool(x).squeeze(2).squeeze(2)
        return x

In [92]:
# define model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = mymodel()
model = model.to(device)

In [93]:
# test forward
model.forward(torch.rand([1,1,32,32],).cuda(), torch.tensor([1]).cuda())

tensor([[  1.0670, -14.7954,  -2.6090,  -0.8609,   0.8688,   1.4414,  -1.0486,
           0.4771,   2.2419]], device='cuda:0', grad_fn=<MulBackward0>)

# define arcface

In [94]:
# crit
criterion = torch.nn.CrossEntropyLoss()

# optimzer nn
optimizer_nn = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)

## train script

In [95]:
def train(epoch):
    model.train()
    print('epochs {}/{} '.format(epoch+1,epochs))
    running_loss = 0.0
    running_acc = 0.0
    acc1 = 0.0
    acc2 = 0.0
    acc3 = 0.0
    t = tqdm(trainloader)
    
    for idx, (inputs,labels1) in enumerate(t):       
        # send to gpu
        inputs = inputs.to(device)
        labels1 = labels1.to(device)
        
        # set opt
        optimizer_nn.zero_grad()
        
        # run model
        outputs = model(inputs.float(), labels1)

        loss = criterion(outputs, labels1)
        loss.backward()
        
        optimizer_nn.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.data, 1)
        accuracy = (labels1.data == predicted).float().mean()
        
        if idx%100==0:
            print("loss:",loss.item())
            print("acc:", accuracy)
        

In [96]:
epochs = 3
if TRAIN:
    for epoch in range(epochs):
        train(epoch)
        torch.save(model.state_dict(), './saved_weights.pth')
else:
    model.load_state_dict(torch.load('./saved_weights.pth'))

epochs 1/3 


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


  0%|          | 0/541 [00:00<?, ?it/s]

loss: 19.669275283813477
acc: tensor(0., device='cuda:0')
loss: 0.41272032260894775
acc: tensor(0.9500, device='cuda:0')
loss: 0.06240810453891754
acc: tensor(0.9900, device='cuda:0')
loss: 0.561215341091156
acc: tensor(0.9700, device='cuda:0')
loss: 0.0029340938199311495
acc: tensor(1., device='cuda:0')
loss: 0.013818695209920406
acc: tensor(1., device='cuda:0')
epochs 2/3 


  0%|          | 0/541 [00:00<?, ?it/s]

loss: 0.0024203194770962
acc: tensor(1., device='cuda:0')
loss: 0.0016447328962385654
acc: tensor(1., device='cuda:0')
loss: 0.001489203074015677
acc: tensor(1., device='cuda:0')
loss: 0.0019049153197556734
acc: tensor(1., device='cuda:0')
loss: 0.0016667357413098216
acc: tensor(1., device='cuda:0')
loss: 0.004393452778458595
acc: tensor(1., device='cuda:0')
epochs 3/3 


  0%|          | 0/541 [00:00<?, ?it/s]

loss: 0.0013250508345663548
acc: tensor(1., device='cuda:0')
loss: 0.24397392570972443
acc: tensor(0.9900, device='cuda:0')
loss: 0.0010622605914250016
acc: tensor(1., device='cuda:0')
loss: 0.0009237535996362567
acc: tensor(1., device='cuda:0')
loss: 0.0011531260097399354
acc: tensor(1., device='cuda:0')
loss: 0.0008271292317658663
acc: tensor(1., device='cuda:0')


# Anomaly detection

In [97]:
from scipy import spatial
model.eval()
def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

In [98]:
with torch.no_grad():
    train = trainset[0][0]
    out1 = model.extract(train.unsqueeze(0).to(device)).cpu().numpy()
    test = testset_norm[0][0]
    out2 = model.extract(test.unsqueeze(0).to(device)).cpu().numpy()
    
print("same number:", cosin_metric(out1, out2.T))

same number: [[0.22485602]]


In [99]:
with torch.no_grad():
    train = trainset[0][0]
    out1 = model.extract(train.unsqueeze(0).to(device)).cpu().numpy()
    test = testset[0][0]
    out2 = model.extract(test.unsqueeze(0).to(device)).cpu().numpy()
    
print("abnormal number:", cosin_metric(out1, out2.T))

abnormal number: [[0.10142373]]


In [100]:
torch.cuda.empty_cache()

## inference and evaluate distance.

In [101]:
t = tqdm(trainloader)  
for idx, (inputs,labels1) in enumerate(t):       
    # send to gpu
    inputs = inputs.to(device)
    labels1 = torch.zeros(inputs.size()[0]).to(device).long() #.unsqueeze(1)
    # run model
    with torch.no_grad():
        out = model.extract(inputs.float())
    if idx  == 0:
        outs = out
    else:
        outs = torch.cat((outs, out))

normals = outs.cpu().numpy()

t = tqdm(testloader1)
norm_labels = []
for idx, (inputs,labels1) in enumerate(t):       
    # send to gpu
    inputs = inputs.to(device)
    norm_labels.extend(labels1.numpy())
    labels1 = torch.zeros(inputs.size()[0]).to(device).long() #.unsqueeze(1)
    # run model
    with torch.no_grad():
        out = model.extract(inputs.float())
    if idx  == 0:
        outs = out
    else:
        outs = torch.cat((outs, out))

normals1 = outs.cpu().numpy()

t = tqdm(testloader)  
for idx, (inputs,labels1) in enumerate(t):       
    # send to gpu
    inputs = inputs.to(device)
    labels1 = torch.zeros(inputs.size()[0]).to(device).long() #.unsqueeze(1)
    # run model
    with torch.no_grad():
        out = model.extract(inputs.float())
    if idx  == 0:
        outs = out
    else:
        outs = torch.cat((outs, out))
        
abnormals = outs.cpu().numpy()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


  0%|          | 0/541 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  app.launch_new_instance()


  0%|          | 0/90 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/11 [00:00<?, ?it/s]

# Calculate cosin distance

In [102]:
from sklearn.preprocessing import normalize
# compute cosine distance with my function
with torch.no_grad():
    test = testset[2][0]
    out2 = model.extract(test.unsqueeze(0).to(device)).cpu().numpy()

# normalize features and convert to torch
out2_torch = torch.from_numpy(normalize(out2)).cuda()
normals_torch = torch.from_numpy(normalize(normals)).cuda()

# compute cosine distance
distances = 1 - torch.matmul(out2_torch, normals_torch.T).cpu().numpy().T
print("different number:", np.mean(distances))

## test with same numbers
with torch.no_grad():
    test = testset_norm[2][0]
    out2 = model.extract(test.unsqueeze(0).to(device)).cpu().numpy()

# normalize features and convert to torch
out2_torch = torch.from_numpy(normalize(out2)).cuda()
normals_torch = torch.from_numpy(normalize(normals)).cuda()

# compute cosine distance
distances = 1 - torch.matmul(out2_torch, normals_torch.T).cpu().numpy().T
print("same number:", np.mean(distances))

different number: 0.7652784
same number: 0.68906516


## Test anomaly detection

In [103]:
# normalize features and convert to torch
out2_torch = torch.from_numpy(normalize(normals1)).cuda()
normals_torch = torch.from_numpy(normalize(normals)).cuda()

# compute cosine distance
distances = 1 - torch.matmul(out2_torch, normals_torch.T).cpu().numpy().T
print("normal numbers:", np.mean(distances))

normal numbers: 0.7222915


In [104]:
np.mean(distances, axis=0)

array([0.6960141 , 0.7638719 , 0.6890651 , ..., 0.71422666, 0.7258415 ,
       0.7198604 ], dtype=float32)

In [105]:
# normalize features and convert to torch
out2_torch = torch.from_numpy(normalize(abnormals)).cuda()
normals_torch = torch.from_numpy(normalize(normals)).cuda()

# compute cosine distance
distances = 1 - torch.matmul(normals_torch, out2_torch.T).cpu().numpy().T
print("abnormal numbers:", np.mean(distances))

abnormal numbers: 0.766923


In [106]:
np.mean(distances, axis=0)

array([0.8503306 , 0.86097825, 0.5911175 , ..., 0.42061535, 0.83533645,
       0.42637256], dtype=float32)

# 推論結果を3D可視化

In [None]:
# https://github.com/egcode/pytorch-losses/blob/master/mnist-visualize-arcface6_fc7-loss.ipynb
f3d = []
lbls = []
for i in range(len(testset_norm)):
    image_tensor, label_tensor = testloader1.dataset[i]
    label_tensor = torch.tensor(testloader1.dataset[i][1])
    image_tensor = image_tensor.reshape(1,1,28,28)

    image_tensor, label_tensor = image_tensor.to(device), label_tensor.to(device)

    features3d  = model(image_tensor)
    logits = metric_fc(features3d, torch.unsqueeze(label_tensor, dim=-1))
    _, prediction = torch.max(logits.data, 1)

    f3d.append(features3d[0].cpu().detach().numpy())
    
    prediction = prediction.cpu().detach().numpy()[0]
    lbls.append(prediction)

feat3d = np.array(f3d)
print("3d features shape" + str(feat3d.shape))

lbls = np.array(lbls)
print("labels shape" + str(lbls.shape))

In [None]:
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(16,9))
ax = plt.axes(projection='3d')

for i in range(10):
    # Data for three-dimensional scattered points
    xdata = feat3d[lbls==i,2].flatten()
    ydata = feat3d[lbls==i,0].flatten()
    zdata = feat3d[lbls==i,1].flatten()
    ax.scatter3D(xdata, ydata, zdata);
ax.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],loc='center left', bbox_to_anchor=(1, 0.5))

plt.show()

# Visualize by TSNE

## 普通の分類結果

In [None]:
from sklearn.manifold import TSNE
# 特徴量抽出
feats = normals1
# ラベル情報追加

# TSNEをfit
tSNE_metrics = TSNE(n_components=2, random_state=0).fit_transform(feats)
plt.scatter(tSNE_metrics[:, 0], tSNE_metrics[:, 1], c=norm_labels)
plt.colorbar()
plt.savefig('tsne_classification.png')
plt.show()

## 異常ラベルとの比較

In [None]:
predict_y = np.concatenate([normals1, abnormals])
diflabel = np.ones_like(different)*9
test_y = np.concatenate([norm_labels, diflabel])

tSNE_metrics = TSNE(n_components=2, random_state=0).fit_transform(predict_y)
plt.scatter(tSNE_metrics[:, 0], tSNE_metrics[:, 1], c=test_y)
plt.colorbar()
plt.savefig('tsne_abnormals.png')
plt.show()

## plot roc curve

In [None]:
predict_y = np.concatenate([same, different])
samelabel = np.ones_like(same)
diflabel = np.zeros_like(different)
test_y = np.concatenate([samelabel, diflabel])

In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np

# FPR, TPR(, しきい値) を算出
fpr, tpr, thresholds = metrics.roc_curve(test_y, predict_y)

In [None]:
# ついでにAUCも
auc = metrics.auc(fpr, tpr)

# ROC曲線をプロット
plt.plot(fpr, tpr, label='ROC curve (area = %.2f)'%auc)
plt.legend()
plt.title('ROC curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid(True)