In [1]:
import sys
import os
import torch
import numpy as np
import shutil

from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms
from PIL import Image
from PIL import ImageFile

sys.path.insert(0, '..')
from data import crop_upper_part
from model import SqueezeModelSoftmax

import pickle
import pdb

### Setup parameters

In [2]:
# Path to a PyTorch state dict (loaded with torch.load + load_state_dict in the model cell)
MODEL_PATH = '../models/Standard-No-Test/26-no-test_epoch_10-valLoss_0.02410-valF1_0.99846'
# Pickled dict mapping store name -> probability threshold, calculated on the validation set
THRESHOLDS = "../models/thresholds_2std.pkl"
NUM_CLASSES = 26
INPUT_SHAPE = (3, 370, 400)  # C x H x W
BATCH_SIZE = 64

# path to a non-annotated dataset where all images are in same folder with name: <integer_id>.jpg
DATASET_PATH = '../../data/robust_ml_challenge_testset'

# output folder for the annotated dataset (created if it does not exist yet)
OUTPUT_ANNOTATED_PATH = '../../data/annotated_ds'

# predicted store names are written here, one per line
OUTPUT_CSV_PATH = '../output_no_test_focal.csv'

NUM_THREADS = 4  # number of threads to use - should be same as number of virtual CPU cores
# Use CUDA only when a GPU is actually visible; set to False to force CPU inference.
USE_GPU = torch.cuda.is_available()

# allow PIL to load truncated image files
ImageFile.LOAD_TRUNCATED_IMAGES = True

### Class indices

In [3]:
# Store names in class-index order: position i in this tuple is class index i.
store_names = (
    'Albertsons', 'BJs', 'CVSPharmacy', 'Costco', 'FredMeyer', 'Frys', 'HEB',
    'HarrisTeeter', 'HyVee', 'JewelOsco', 'KingSoopers', 'Kroger', 'Meijer',
    'Other', 'Publix', 'Safeway', 'SamsClub', 'ShopRite', 'Smiths', 'StopShop',
    'Target', 'Walgreens', 'Walmart', 'Wegmans', 'WholeFoodsMarket', 'WinCoFoods',
)

# store name -> class index
class_labels = {name: index for index, name in enumerate(store_names)}
# class index -> store name (inverse of class_labels)
class_dict = dict(enumerate(store_names))

### Helper methods

In [4]:
def data_preprocess_transformations(input_shape, crop_perc=0.5):
    """Build the transform pipeline that turns a loaded image into a model input tensor.

    Args:
        input_shape: model input shape as (channels, height, width).
        crop_perc: fraction forwarded to ``crop_upper_part`` (default 0.5); see
            ``data.crop_upper_part`` for its exact meaning.

    Returns:
        transforms.Compose applying: crop_upper_part -> PIL image -> grayscale with
        ``channels`` output channels -> resize to (height, width) -> tensor.
    """
    channels, target_height, target_width = input_shape

    def crop_top(pil_image):
        # crop_upper_part works on arrays, so convert the PIL image first
        return crop_upper_part(np.array(pil_image), crop_perc)

    pipeline_steps = [
        transforms.Lambda(crop_top),
        transforms.ToPILImage(),
        transforms.Grayscale(channels),
        transforms.Resize((target_height, target_width)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline_steps)

def list_input_images(images_folder):
    """List the JPEG images in a folder whose names follow ``<int_id>.jpg``.

    Only files whose last extension is ``.jpg`` or ``.jpeg`` (any case) are considered.
    Files without an extension, or with another one (e.g. ``1.jpg.bak``), are ignored.
    So are JPEG files whose name without the extension is not an integer
    (e.g. ``._1.jpg`` metadata files).

    Args:
        images_folder: Folder with input images with name template: <int_id>.jpg

    Returns:
        List of tuples ``(image_file_name, image_int_id)`` sorted by ``image_int_id``.
    """
    images = []

    for file_name in os.listdir(images_folder):
        # splitext takes only the last suffix and yields '' when there is none,
        # so extensionless files no longer raise IndexError.
        stem, extension = os.path.splitext(file_name)
        if extension.lower() not in ('.jpg', '.jpeg'):
            continue
        try:
            image_id = int(stem)
        except ValueError:
            # Not named <int_id>.jpg - not part of the dataset template.
            continue
        images.append((file_name, image_id))

    return sorted(images, key=lambda entry: entry[1])


def open_image(image_path):
    """Read the image at ``image_path`` and return it as an RGB PIL image.

    The conversion happens while the file is still open, so the returned image
    does not depend on the (closed) file handle.
    """
    with open(image_path, 'rb') as image_file:
        return Image.open(image_file).convert("RGB")


def predicted_store(prediction):
    """Return the store name of the highest-scoring class (argmax) in ``prediction``.

    Args:
        prediction: per-class model output for a single image.
    """
    return class_dict[np.argmax(prediction)]


def print_predicted_store_stats(predicted_stores):
    """Print how often each distinct store appears in ``predicted_stores``.

    One line per store, in the (sorted) order returned by ``np.unique``,
    formatted as ``<store left-padded to 16> => <count>``.
    """
    unique_stores, counts_per_store = np.unique(predicted_stores, return_counts=True)
    for store_name, store_count in zip(unique_stores, counts_per_store):
        print("{:<16} => {}".format(store_name, store_count))

In [5]:
class TestDataset(Dataset):
    """Store receipts test dataset: unlabeled images from one folder, ordered by integer id."""

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Transform applied to each loaded RGB PIL
                image. When None, the PIL image itself is returned.
        """
        # (file_name, int_id) pairs sorted by id; index i of the dataset is images[i]
        self.images = list_input_images(root_dir)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` when one was given."""
        image_name = self.images[idx][0]
        image = open_image(os.path.join(self.root_dir, image_name))
        if self.transform is not None:
            image = self.transform(image)
        return image

### Load model and make predictions (softmax)

In [6]:
model = SqueezeModelSoftmax(num_classes=NUM_CLASSES)
# map_location="cpu": deserialize every tensor onto the CPU so the checkpoint loads on any machine
model_state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(model_state_dict)
model.eval()  # switch to evaluation mode for inference
torch.set_num_threads(NUM_THREADS)

# Move the weights to GPU 0 only after they are loaded
if USE_GPU:
    model.cuda(0)

In [48]:
def var(tensor):
    """Move ``tensor`` to GPU 0 when ``USE_GPU`` is set; otherwise return it unchanged."""
    if USE_GPU:
        tensor = tensor.cuda(0)
    return tensor


predictions = []
preprocess_transformations = data_preprocess_transformations(INPUT_SHAPE)
test_set = TestDataset(DATASET_PATH, preprocess_transformations)
loader = DataLoader(test_set,
                    batch_size=BATCH_SIZE,
                    shuffle=False,  # keep dataset order so predictions line up with test_set.images
                    pin_memory=USE_GPU,  # pinned host memory only helps host->GPU copies
                    num_workers=NUM_THREADS)

num_batches = len(loader)
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for batch_index, test_batch in enumerate(loader):
        batch_input_tensors = var(test_batch)

        batch_predictions = model(batch_input_tensors).cpu().numpy()
        predictions.extend(batch_predictions)

        print('Batch {}/{}'.format(batch_index + 1, num_batches), end="\r", flush=True)

assert len(predictions) == len(test_set), "every test image should get exactly one prediction"

Batch 157/157

### Process predictions

In [95]:
# TODO: Heuristics

# Per-class thresholds calculated on the validation set (store name -> threshold).
# NOTE: pickle.load can execute arbitrary code - only load threshold files you trust.
with open(THRESHOLDS, "rb") as thresholds_file:
    thr_dict = pickle.load(thresholds_file)

def threshold_heuristic(predictions, thr_dict):
    """Per-class threshold heuristic.

    For each prediction vector, find the argmax store. If it is not "Other" and the
    vector's maximum value is below ``thr_dict[store]``, replace the vector with a
    float32 one-hot vector for "Other". Otherwise keep the original vector object.

    Args:
        predictions: iterable of per-class probability vectors.
        thr_dict: store name -> minimum max-probability needed to keep the prediction.

    Returns:
        list with one vector per input prediction, in input order.
    """
    other_index = class_labels["Other"]
    adjusted = []

    for probs in predictions:
        top_store = class_dict[np.argmax(probs)]
        below_threshold = top_store != "Other" and np.max(probs) < thr_dict[top_store]

        if below_threshold:
            # Not confident enough for this store: treat as if "Other" was predicted
            one_hot_other = np.zeros(NUM_CLASSES, dtype=np.float32)
            one_hot_other[other_index] = 1
            adjusted.append(one_hot_other)
        else:
            adjusted.append(probs)

    return adjusted

def identity_heuristic(predictions):
    """Baseline heuristic: hand back ``predictions`` untouched (the very same object)."""
    unchanged = predictions
    return unchanged

# Raw argmax predictions (identity) and thresholded predictions, kept side by side.
predictions_id = identity_heuristic(predictions)
predictions_new = threshold_heuristic(predictions, thr_dict)

In [96]:
# Let's turn the prediction vectors into store names (thresholded and raw)

predicted_stores = list(map(predicted_store, predictions_new))
predicted_stores_id = list(map(predicted_store, predictions_id))

### Save and annotate outputs

In [98]:
# Save predicted stores to CSV: one store name per line, same order as predicted_stores,
# no trailing newline after the last entry.

with open(OUTPUT_CSV_PATH, "w") as csv_file:
    csv_file.write("\n".join(predicted_stores))

In [97]:
# Print current statistics: store counts with the threshold heuristic vs. raw argmax.
# Label each table so the two outputs can be told apart.

print("With per-class thresholds:")
print_predicted_store_stats(predicted_stores)
print()
print("Without thresholds (identity):")
print_predicted_store_stats(predicted_stores_id)

WholeFoodsMarket => 189
Publix           => 175
CVSPharmacy      => 262
Meijer           => 158
Other            => 5070
Walmart          => 191
HEB              => 192
WinCoFoods       => 188
Wegmans          => 191
Kroger           => 141
BJs              => 200
JewelOsco        => 450
Safeway          => 165
Costco           => 191
ShopRite         => 200
Target           => 186
FredMeyer        => 198
Smiths           => 181
SamsClub         => 193
HyVee            => 177
KingSoopers      => 177
Frys             => 187
Albertsons       => 138
HarrisTeeter     => 134
Walgreens        => 165
StopShop         => 301

WholeFoodsMarket => 202
Publix           => 175
CVSPharmacy      => 323
Meijer           => 324
Other            => 3551
Walmart          => 323
HEB              => 208
WinCoFoods       => 235
Wegmans          => 201
Kroger           => 570
BJs              => 200
JewelOsco        => 450
Safeway          => 189
Costco           => 311
ShopRite         => 207
Target       

In [None]:
# Annotate - copy images to new folder - separated by class:
# each image goes to OUTPUT_ANNOTATED_PATH/<predicted store>/

os.makedirs(OUTPUT_ANNOTATED_PATH, exist_ok=True)

num_stores = len(predicted_stores)
# predicted_stores follows test_set order; zip below must not silently truncate.
assert len(test_set.images) == num_stores, "predictions and dataset images are out of sync"

# Use the real file names from the dataset instead of rebuilding "<index>.jpg",
# which breaks for non-contiguous ids and for .jpeg / upper-case extensions.
for index, ((filename, _image_id), store) in enumerate(zip(test_set.images, predicted_stores)):
    target_folder_path = os.path.join(OUTPUT_ANNOTATED_PATH, store)
    os.makedirs(target_folder_path, exist_ok=True)

    src_file_path = os.path.join(DATASET_PATH, filename)

    shutil.copy2(src_file_path, target_folder_path)
    print('Image {}/{}'.format(index + 1, num_stores), end="\r", flush=True)