# Zero-shot learning for image classification 

original data and code can be found here https://www.mpi-inf.mpg.de/departments/computer-vision-and-multimodal-computing/research/zero-shot-learning/zero-shot-learning-the-good-the-bad-and-the-ugly/)
[Akata, et al. CVPR2015]
[Xian, et al. CVPR2017]

Download prepared data of Animal with attributes from: 
https://drive.google.com/open?id=1ErU12Q2sHhB2Lb7NCQuan0K3qXP78RJj

In [1]:
import numpy as np 

In [24]:
# load prepared data 
data_dict = np.load('data_dict.npz', encoding='latin1')['data'].item()

tr_theta_x = data_dict['tr_theta_x'] # training image features extracted from deep CNN
tr_labels = data_dict['tr_labels'] # training image labels as indices matching class embeddings and names
val_theta_x = data_dict['val_theta_x']# validation image features extracted from deep CNN
val_labels = data_dict['val_labels'] # validation image labels as indices matching class embeddings and names
test_theta_x = data_dict['test_theta_x'] # test image features extracted from deep CNN
test_labels = data_dict['test_labels'] # test image labels as indices matching class embeddings and names

class_embeddings = data_dict['phi_y'] # class attributes vectors provided by the original dataset AWA
class_names = data_dict['class_name'] # class names in the same order as embeddings 
LR = 0.01

In [40]:
# print training, validation, and test class names
# note that class_embeddings and class_names 

# add your implementation   
print('train_theta_x: ', tr_theta_x)
print('shape_train_theta_x: ', tr_theta_x.shape)
print('shape_train_theta_features_x: ', len(tr_theta_x[0]))
print('tr_labels_x: ', tr_labels)
print('tr_labels_x_size: ', tr_labels.shape)
print('class names: ', class_names)
print('size of class_names: ', len(class_names))
print('class_embeddings: ', len(class_embeddings))
print('class_embeddings_size: ', class_embeddings[0].shape[0])

# print(np.unique(tr_labels))
# print(np.unique(val_labels))
# print(np.unique(test_labels))
tr_labels_names = set([class_names[pic] for pic in tr_labels])
val_labels_names = set([class_names[pic] for pic in val_labels])
test_labels_names = set([class_names[pic] for pic in test_labels])

train_theta_x:  [[3.43117459e-03 8.74107845e-02 2.52593458e-02 ... 1.81578525e-02
  8.71511232e-04 1.10101543e-02]
 [0.00000000e+00 7.31776316e-02 7.84721552e-04 ... 3.33776270e-02
  2.02990322e-04 7.60071157e-03]
 [1.23658850e-02 3.26421652e-02 0.00000000e+00 ... 7.26664598e-02
  9.52347521e-04 1.06897570e-03]
 ...
 [6.51087796e-04 4.29655909e-02 4.01792441e-03 ... 7.32434057e-03
  3.18346479e-05 1.37381313e-02]
 [4.50465981e-03 3.89642309e-02 2.56807193e-04 ... 1.87653561e-03
  0.00000000e+00 8.37735772e-04]
 [6.66480603e-03 3.43691225e-02 1.32691994e-03 ... 3.37716448e-03
  4.30432021e-04 1.70248459e-03]]
shape_train_theta_x:  (20218, 2048)
shape_train_theta_features_x:  2048
tr_labels_x:  [ 0  0  0 ... 37 37 37]
tr_labels_x_size:  (20218,)
class names:  ['antelope', 'grizzly+bear', 'killer+whale', 'beaver', 'dalmatian', 'persian+cat', 'horse', 'german+shepherd', 'blue+whale', 'siamese+cat', 'skunk', 'mole', 'tiger', 'hippopotamus', 'leopard', 'moose', 'spider+monkey', 'humpback+wha

In [19]:
print(tr_labels_names)
print(val_labels_names)
print(test_labels_names)

{'humpback+whale', 'buffalo', 'pig', 'wolf', 'zebra', 'otter', 'grizzly+bear', 'mouse', 'skunk', 'antelope', 'cow', 'siamese+cat', 'persian+cat', 'tiger', 'polar+bear', 'fox', 'killer+whale', 'german+shepherd', 'squirrel', 'lion', 'collie', 'chihuahua', 'hippopotamus', 'weasel', 'spider+monkey', 'elephant', 'rhinoceros'}
{'moose', 'leopard', 'ox', 'dalmatian', 'deer', 'rabbit', 'gorilla', 'chimpanzee', 'beaver', 'hamster', 'raccoon', 'mole', 'giant+panda'}
{'bobcat', 'bat', 'seal', 'horse', 'rat', 'blue+whale', 'walrus', 'giraffe', 'sheep', 'dolphin'}


In [52]:
np.unique(tr_labels)

array([ 0,  1,  2,  5,  7,  9, 10, 12, 13, 16, 17, 18, 21, 26, 27, 31, 32,
       34, 35, 36, 37, 41, 42, 43, 44, 45, 48])

In [89]:
def train(X,Y ,LR,EPOCHS,W=None):
    
    # add your implementation
    
    ####### use SGD to minimize SJE loss ######### 
    # Initialize W (DxE)
    # foreach epoch 
        # Shuffle training samples
        # foreach sample (xi, ytrue)
            # 1. scoretrue =  𝛳(xi) * WT * 𝜙(ytrue)
            # 2. lossmax= -1,  ymax = -1
            # foreach training label ytr  
                # 1. score =  𝛳(xi) * WT *𝜙(ytr)
                # 2. loss = Δ(ytr,ytrue) + scoretrue - score
                # 3. if loss > lossmax --> update lossmax and ymax
            # 4. if ymax ≠ ytrue --> W = W - lr * 𝛳(xi) [𝜙(ytrue) - 𝜙(ymax)]dsa
    d = X.shape[1]
    e = class_embeddings[0].shape[0]
    W = np.zeros((d, e))  # E = 85 features in the embedding #D = features of each class
    for e in range(EPOCHS):
        print("Epoch: ", e)
        s = np.arange(X.shape[0])
        s = np.random.shuffle(s)
        X = X[s]
        Y = Y[s]
        print(s)
        break
        for i in range(X.shape[0]):
            ytrue = Y[i]
            scoretrue = np.dot(np.dot(X[i],W),class_embeddings[ytrue])
            lossmax, ymax = -1, -1
            for m in list(np.unique(Y)):
                ytrain = Y[m]
                score = np.dot(np.dot(X[i] ,W), class_embeddings[ytrain])
                if ytrue == ytrain:
                    delta = 0
                else:
                    delta = 1
                loss = delta + score - scoretrue
                if loss > lossmax:
                    ymax = Y[m]
                    lossmax = loss
                    
            if ymax != ytrue:
                print(f'Updating W on Epoch {e}')
                W = W - np.dot(LR, np.outer(X[i],(class_embeddings[ytrue]-class_embeddings[ymax])))
            
    
    return W

In [70]:
def predict(x,test_class_indices,W):
    
    # add your implementation
    
    #  max_score = -1,  ymax = -1
    # foreach label in test_class_indices for i in test_class_indices: 
        # score =  𝛳(xi) * WT *𝜙(label) 
        # if score > max_score --> update max_score and ymax
    max_score, ymax = -1, -1
    for i in test_class_indices:
        score = np.dot(np.dot(x,W),class_embeddings[i])
        if score > max_score:
            max_score = score
            ymax = i
        
    return ymax

In [82]:
def evaluate(X,Y,test_class_indices,W):
    
    # add your implementation
    
    # correct_pred = 0
    # foreach sample (xi, ytrue)
        # pred_label = predict(xi,test_class_indices,W)
        # if pred_label == ytrue:
            # correct_pred += 1
    # acc = correct_pred / size of test set
    
    correct_pred = 0
    
    for i in range(X.shape[0]):
        ytrue = Y[i]
        pred_label = predict(X[i], test_class_indices, W)
        print(pred_label)
        if pred_label == ytrue:
            correct_pred += 1
            
        acc = correct_pred / Y.shape[0]
    
    return acc

In [90]:
EPOCHS = 10
W = train(tr_theta_x,  tr_labels,LR,EPOCHS,W=None)
W.shape


Epoch:  0
None


(2048, 85)

array([[-0.71406191, -1.17858475, -0.22636969, ...,  0.94671955,
         0.25391736, -1.43513811],
       [-2.85071329, -2.44628989, -0.26824691, ...,  3.04151596,
         0.98638237, -2.55605003],
       [-1.42676955, -0.81164087, -0.12849155, ...,  1.51673991,
         0.26178055, -1.03676848],
       ...,
       [-2.62547196, -3.11864153, -0.2147274 , ...,  2.39319996,
         0.95572225, -5.21240409],
       [-1.07272488, -1.35156881, -0.34268158, ...,  0.61722995,
         0.37175318, -2.59731165],
       [-0.84854171, -0.83077719, -0.07685617, ...,  0.99224313,
         0.26995132, -1.06883337]])

In [84]:
test_class_indices = np.unique(test_labels)
evaluate(test_theta_x,test_labels,test_class_indices,W)

40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
4

40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
4

40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
40
4

0.0796158220649564