# TF Cascade

In [1]:
import torchvision
import torch
import pandas as pd
import numpy as np
from torchvision import transforms
from PIL import Image
import os

In [3]:
with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]

x = np.array([])
directory = os.fsencode('imagenet/')

for root, dirs, files in os.walk("imagenet/"):
    for filename in files:
        x = np.append(x, filename)
df = pd.DataFrame(data=x, columns=["images"])
df['images'][0]

'ILSVRC2012_val_00036091.JPEG'

In [10]:
def resnet_model(img):
    """
    ResNet101 for image classification on ResNet
    """
    transform = transforms.Compose([
    transforms.Resize(256),                    
    transforms.CenterCrop(224),                
    transforms.ToTensor(),                     
    transforms.Normalize(                      
    mean=[0.485, 0.456, 0.406],                
    std=[0.229, 0.224, 0.225]                  
    )])
    
    resnet = torchvision.models.resnet101(pretrained=True)
    resnet.eval()
    img_2 = Image.open('imagenet/'+img[0])
    img_t = transform(img_2)
    batch_t = torch.unsqueeze(img_t, 0)
    out = resnet(batch_t)
    _, indices = torch.sort(out, descending=True)
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    p_2 = percentage.detach().numpy()
    return indices.detach().numpy()[0], p_2, p_2[indices[0][0]]


In [28]:
def inceptionv3_model(img):
    transform = transforms.Compose([
    transforms.Resize(256),                    
    transforms.CenterCrop(224),                
    transforms.ToTensor(),                     
    transforms.Normalize(                      
    mean=[0.485, 0.456, 0.406],                
    std=[0.229, 0.224, 0.225]                  
    )])
    
    resnet = torchvision.models.inception_v3(pretrained=True)
    resnet.eval()
    img_2 = Image.open('imagenet/'+img[0])
    img_t = transform(img_2)
    batch_t = torch.unsqueeze(img_t, 0)
    out = resnet(batch_t)
    _, indices = torch.sort(out, descending=True)
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    p_2 = percentage.detach().numpy()
    return indices.detach().numpy()[0], p_2, p_2[indices[0][0]]

In [38]:
def cascade_predict(row):
    """
    cascade predict based on resnet/alexnet results
    """
    r_index = row[1]
    r_perc = row[2]
    r_max_prob = row[3]
    i_index = row[4]
    i_perc = row[5]
    i_max_prob= row[6]
#     print(np.isnan(i_max_prob))
    
    if np.isnan(i_max_prob):
        # didn't go to inception because resnet prediction was confident enough
        return r_index, r_perc, classes[r_index[0]]
    else:
        #choose the distribution with the higher max_prob
        if r_max_prob > i_max_prob:
            return r_index, r_perc, classes[r_index[0]]
        else:
            return i_index, i_perc, classes[i_index[0]]

    

In [30]:
#Calling the functions

df_s = df.head(7)
resnet_preds = df_s.apply(resnet_model, axis=1, result_type="expand").rename(columns={0: "resnet_indices", 1: "resnet_percentage", 2: "resnet_max_prob"}) 

# I used pandas notation here, but the value inside the join would be a WHERE sql query
# Might want to explore query optimization here? (Between where/join/apply)
inception_preds = df_s.join(resnet_preds[resnet_preds['resnet_max_prob'] < 85], how='right') \
                      .apply(inceptionv3_model, axis=1, result_type="expand").rename(columns={0: "inception_indices", 1: "inception_percentage", 2:"inception_max_prob"}) 

all_preds = df_s.join([resnet_preds, inception_preds])
all_preds 



Unnamed: 0,images,resnet_indices,resnet_percentage,resnet_max_prob,inception_indices,inception_percentage,inception_max_prob
0,ILSVRC2012_val_00036091.JPEG,"[511, 436, 581, 817, 627, 468, 479, 717, 751, ...","[2.2426984e-06, 1.2366584e-07, 1.2042776e-07, ...",93.52153,,,
1,ILSVRC2012_val_00018439.JPEG,"[74, 72, 815, 73, 77, 75, 70, 78, 119, 998, 99...","[1.6091902e-07, 3.3655735e-07, 1.6260454e-07, ...",62.949924,"[74, 815, 72, 73, 77, 989, 119, 70, 656, 492, ...","[3.2356027e-05, 2.1293085e-05, 4.3236323e-05, ...",72.439697
2,ILSVRC2012_val_00033769.JPEG,"[316, 125, 33, 113, 32, 988, 37, 34, 120, 47, ...","[2.4145097e-06, 5.8845744e-05, 6.1371406e-07, ...",44.614311,"[33, 34, 125, 32, 119, 37, 54, 36, 39, 316, 30...","[3.2593138e-07, 1.7814665e-06, 7.6569376e-07, ...",99.132767
3,ILSVRC2012_val_00031754.JPEG,"[332, 153, 154, 283, 258, 204, 259, 257, 338, ...","[5.034303e-06, 4.521411e-05, 2.657164e-05, 9.4...",63.601433,"[332, 338, 153, 154, 330, 585, 259, 257, 680, ...","[7.876797e-09, 4.725366e-08, 1.5149253e-08, 5....",99.998688
4,ILSVRC2012_val_00030846.JPEG,"[928, 924, 960, 927, 936, 415, 935, 923, 933, ...","[1.6053225e-08, 7.6835704e-08, 3.3813205e-08, ...",99.712997,,,
5,ILSVRC2012_val_00006427.JPEG,"[346, 730, 345, 690, 385, 101, 348, 342, 341, ...","[0.00058187195, 1.34180655e-05, 1.587513e-05, ...",95.050987,,,
6,ILSVRC2012_val_00025347.JPEG,"[10, 13, 15, 12, 14, 11, 19, 92, 17, 20, 86, 8...","[9.453798e-09, 3.4297395e-07, 1.7989052e-08, 1...",99.647789,,,


In [40]:
# Calling the cascading function
cascade_df = all_preds.join(all_preds.apply(cascade_predict, axis=1, result_type="expand") \
                    .rename(columns={0: "cascade_indices", 1: "cascade_percentage", 2:"cascade_prediction"}))
cascade_df


Unnamed: 0,images,resnet_indices,resnet_percentage,resnet_max_prob,inception_indices,inception_percentage,inception_max_prob,cascade_indices,cascade_percentage,cascade_prediction
0,ILSVRC2012_val_00036091.JPEG,"[511, 436, 581, 817, 627, 468, 479, 717, 751, ...","[2.2426984e-06, 1.2366584e-07, 1.2042776e-07, ...",93.52153,,,,"[511, 436, 581, 817, 627, 468, 479, 717, 751, ...","[2.2426984e-06, 1.2366584e-07, 1.2042776e-07, ...",convertible
1,ILSVRC2012_val_00018439.JPEG,"[74, 72, 815, 73, 77, 75, 70, 78, 119, 998, 99...","[1.6091902e-07, 3.3655735e-07, 1.6260454e-07, ...",62.949924,"[74, 815, 72, 73, 77, 989, 119, 70, 656, 492, ...","[3.2356027e-05, 2.1293085e-05, 4.3236323e-05, ...",72.439697,"[74, 815, 72, 73, 77, 989, 119, 70, 656, 492, ...","[3.2356027e-05, 2.1293085e-05, 4.3236323e-05, ...",garden_spider
2,ILSVRC2012_val_00033769.JPEG,"[316, 125, 33, 113, 32, 988, 37, 34, 120, 47, ...","[2.4145097e-06, 5.8845744e-05, 6.1371406e-07, ...",44.614311,"[33, 34, 125, 32, 119, 37, 54, 36, 39, 316, 30...","[3.2593138e-07, 1.7814665e-06, 7.6569376e-07, ...",99.132767,"[33, 34, 125, 32, 119, 37, 54, 36, 39, 316, 30...","[3.2593138e-07, 1.7814665e-06, 7.6569376e-07, ...",loggerhead
3,ILSVRC2012_val_00031754.JPEG,"[332, 153, 154, 283, 258, 204, 259, 257, 338, ...","[5.034303e-06, 4.521411e-05, 2.657164e-05, 9.4...",63.601433,"[332, 338, 153, 154, 330, 585, 259, 257, 680, ...","[7.876797e-09, 4.725366e-08, 1.5149253e-08, 5....",99.998688,"[332, 338, 153, 154, 330, 585, 259, 257, 680, ...","[7.876797e-09, 4.725366e-08, 1.5149253e-08, 5....",Angora
4,ILSVRC2012_val_00030846.JPEG,"[928, 924, 960, 927, 936, 415, 935, 923, 933, ...","[1.6053225e-08, 7.6835704e-08, 3.3813205e-08, ...",99.712997,,,,"[928, 924, 960, 927, 936, 415, 935, 923, 933, ...","[1.6053225e-08, 7.6835704e-08, 3.3813205e-08, ...",ice_cream
5,ILSVRC2012_val_00006427.JPEG,"[346, 730, 345, 690, 385, 101, 348, 342, 341, ...","[0.00058187195, 1.34180655e-05, 1.587513e-05, ...",95.050987,,,,"[346, 730, 345, 690, 385, 101, 348, 342, 341, ...","[0.00058187195, 1.34180655e-05, 1.587513e-05, ...",water_buffalo
6,ILSVRC2012_val_00025347.JPEG,"[10, 13, 15, 12, 14, 11, 19, 92, 17, 20, 86, 8...","[9.453798e-09, 3.4297395e-07, 1.7989052e-08, 1...",99.647789,,,,"[10, 13, 15, 12, 14, 11, 19, 92, 17, 20, 86, 8...","[9.453798e-09, 3.4297395e-07, 1.7989052e-08, 1...",brambling
