In [1]:
import tensorflow as tf
import os
import glob
import numpy as np
import csv
import random
from scipy.misc import imread,imresize

import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

from cub_fewshot.image_utils import *

In [119]:
class FewshotBirdsDataGenerator(object):

        def __init__(self, batch_size=100, episode_length=10, episode_width=5, image_dim=(244, 244, 3)):
            self.splits = {
                'train' : '/home/jason/deep-parts-model/src/cub_fewshot/splits/train_img_path_label_size_bbox_parts_split.txt',
                'test'  : '/home/jason/deep-parts-model/src/cub_fewshot/splits/test_img_path_label_size_bbox_parts_split.txt',
                'val'   : '/home/jason/deep-parts-model/src/cub_fewshot/splits/val_img_path_label_size_bbox_parts_split.txt'
            }
            self.batch_size = batch_size
            self.episode_length = episode_length
            self.episode_width = episode_width
            self.image_dim = image_dim
            self.num_classes = 200
            self._cache = {}
            self._load_data()

        def _load_data(self):
            self.train_data = self._data_dict_for_split(self.splits['train'])
            print('finished train')
            self.test_data  = self._data_dict_for_split(self.splits['test'])
            print('finished test')
            self.val_data   = self._data_dict_for_split(self.splits['val'])
            print('finished val')

        def sample_episode_batch(self, data):
            """Generates a random batch for training or validation.

            Structures each element of the batch as an 'episode'.
            Each episode contains episode_length examples and
            episode_width distinct labels.

            Args:
              data: A dictionary mapping label to list of examples.
              episode_length: Number of examples in each episode.
              episode_width: Distinct number of labels in each episode.
              batch_size: Batch size (number of episodes).

            Returns:
              A tuple (x, y) where x is a list of batches of examples
              with size episode_length and y is a list of batches of labels.
            """
            episodes_x = [[] for _ in xrange(self.episode_length)]
            episodes_y = [[] for _ in xrange(self.episode_length)]
            assert len(data) >= self.episode_width
            keys = data.keys()
            for b in xrange(self.batch_size):
                episode_labels = random.sample(keys, self.episode_width)
                remainder = self.episode_length % self.episode_width
                remainders = [0] * (self.episode_width - remainder) + [1] * remainder
                episode_x = [
                  random.sample(data[lab],
                                r + (self.episode_length - remainder) // self.episode_width)
                  for lab, r in zip(episode_labels, remainders)]
                episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
                             for i, xx in enumerate(episode_x)], [])
                random.shuffle(episode)
                # Arrange episode so that each distinct label is seen before moving to
                # 2nd showing
                episode.sort(key=lambda elem: elem[2])
                assert len(episode) == self.episode_length
                for i in xrange(self.episode_length):
                    episodes_x[i].append(episode[i][0])
                    episodes_y[i].append(episode[i][1] + b * self.episode_width)
            episodes_x = self._get_examples_for_image_configs(episodes_x) 
            episodes_y = [np.array(yy).astype('int32') for yy in episodes_y]
            return (episodes_x, episodes_y)

        def _data_dict_for_split(self, split, mode='test'):
            label_to_examples_dict = {}
            with open(split, 'r') as f:
                lines = f.readlines()
            for line in lines:
                # get x, y, bbox, and parts from line
                line = line.strip()
                line = line.split(' ')
                image_path, y, size, bbox, parts = line[0], line[1], line[2:4], line[4:8], line[8:]
                size = [int(s) for s in size]
                y, bbox, parts = int(y), [float(b) for b in bbox], [float(p) for p in parts]
                parts_x, parts_y = parts[0::2], parts[1::2]
                if y not in label_to_examples_dict:
                    label_to_examples_dict[y] = []
                # example is going to be x, p1, p2
                # instead of storing this store args
                label_to_examples_dict[y].append((image_path, size, bbox, parts_x, parts_y, mode))
            return label_to_examples_dict

        def _get_examples_for_image_configs(self, configs):
            '''
            parses the configs of dim: self.batch_size X self.episode_length X 1
            and returns a np.array of image_and_parts of dim: self.batch_size X 3 X self.episode_length X self.image_dim
            '''
            H, W, C = self.image_dim
            examples = [[None] * self.batch_size] * self.episode_length
            for i, config_batch in enumerate(configs):
                for j, c in enumerate(config_batch):
                    image_path, size, bbox, parts_x, parts_y, mode = c
                    if image_path in self._cache:
                        examples[i][j] = self._cache[image_path]
                    else:
                        image_and_parts = self._parser(image_path, size, bbox, parts_x, parts_y, mode)
                        examples[i][j] = image_and_parts
                        self._cache[image_path] = image_and_parts
            return [np.array(xx).astype('uint8') for xx in examples]

        def _parser(self, image_path, size, bbox, parts_x, parts_y, mode='test'):
            # decode the image
            image = imread(image_path) # imread(image_path)
            # get height and width of image to normalize the bounding box and part locations
            height, width = size
            # normalize bbox
            x, y, w, h = bbox
            # extract parts
            breast_x, breast_y = int(parts_x[3]), int(parts_y[3])
            crown_x, crown_y = int(parts_x[4]), int(parts_y[4])
            nape_x, nape_y = int(parts_x[9]), int(parts_y[9])
            tail_x, tail_y = int(parts_x[13]), int(parts_y[13])
            leg_x, leg_y = int(parts_x[7]), int(parts_y[7])
            beak_x, beak_y = int(parts_x[1]), int(parts_y[1])

            new_height, new_width, new_channels = self.image_dim
            # get crop for body
            bxmin, bxmax = min(tail_x, beak_x), max(tail_x, beak_x)
            bymin, bymax = min(leg_y, nape_y, breast_y), max(leg_y, nape_y, breast_y)
            bymin, bymax, bxmin, bxmax = int(bymin), int(bymax), int(bxmin), int(bxmax)
            try:
                body_crop = image[bymin:bymax, bxmin:bxmax, :]
                body_crop = imresize(body_crop, size=(new_height, new_width))
            except:
                h_size = int((0.75 * h) / 2)
                w_size = int((0.75 * w) / 2)
                print('image shape:', image.shape)
                body_crop = image[breast_y-h_size:breast_y+h_size, breast_x-w_size:breast_x+w_size, :]
                body_crop = imresize(body_crop, size=(new_height, new_width))
                print('body crop shape:', body_crop.shape)
                print('bxmin:', bxmin, 'bxmax:', bxmax, 'bymin:', bymin, 'bymax:', bymax)
                print('leg_y', leg_y, 'nape_y', nape_y)
            # get crop for head
            x_len = abs(beak_x - nape_x)
            y_len = abs(crown_x - nape_x)
            bymin, bymax = min(nape_y, crown_y), max(nape_y, crown_y) + y_len
            bxmin, bxmax = max(crown_x - x_len, 0), min(crown_x + x_len, width)
            bymin, bymax, bxmin, bxmax = int(bymin), int(bymax), int(bxmin), int(bxmax)
            
            try:
                head_crop = image[bymin:bymax, bxmin:bxmax, :]
                head_crop = imresize(head_crop, size=(new_height, new_width))
            except:
                h_size = int((0.3 * h) / 2)
                w_size = int((0.3 * w) / 2)
                
                if crown_y == 0.0 or crown_x == 0.0:
                    head_x, head_y = nape_x, nape_y
                else:
                    head_x, head_y = crown_x, crown_y
                plt.imshow(image)
                head_crop = image[head_y-h_size:head_y+h_size, head_x-w_size:head_x+w_size, :]

                print('image shape:', image.shape)
                print('head crop shape:', head_crop.shape)
                print('new height:', new_height, 'new width:', new_width)
                print('bxmin:', bxmin, 'bxmax:', bxmax, 'bymin:', bymin, 'bymax:', bymax)
                print('crown_x:', crown_x, 'x_len:', x_len)

            if mode == 'train':
                # resize the image to 256xS where S is max(largest-image-side, 244)
                # TODO: this seems semi random not sure why STN used this
                clipped_height, clipped_width = max(height, 244), max(width, 244)
                if height > width:
                    image = imresize(image, size=(clipped_height, 256))
                else:
                    image = imresize(image, size=(256, clipped_width))
                image = random_crop(image, new_height)
                image = horizontal_flip(image)
            else:
                image = central_crop(image, central_fraction=0.875)
                image = imresize(image, size=(new_height, new_width, new_channels))
            image_and_parts = np.stack([image, body_crop, head_crop], axis=0)
            return image_and_parts

In [120]:

data_generator = FewshotBirdsDataGenerator()
xs, ys = data_generator.sample_episode_batch(data_generator.train_data)

# see what the parts and images look like
#f, ax = plt.subplots(data_generator.batch_size, 3, figsize=(3*3, data_generator.batch_size*3))
'''
for i in range(data_generator.batch_size):
    ax[i, 0].imshow(xs[0][i, 0])
    ax[i, 0].set_title('original')
    ax[i, 1].imshow(xs[0][i, 1])
    ax[i, 1].set_title('breast')
    ax[i, 2].imshow(xs[0][i, 2])
    ax[i, 2].set_title('head')
plt.tight_layout()
plt.show()
'''


finished train
finished test
finished val
('image shape:', (500, 333, 3))


ValueError: tile cannot extend outside image

In [None]:
        def _parser(self, image_path, size, bbox, parts_x, parts_y, mode='test'):
            # decode the image
            image_file = tf.read_file(image_path) # imread(image_path)
            image = tf.image.decode_jpeg(image_file, channels=self.image_dim[-1])
            # get height and width of image to normalize the bounding box and part locations
            height, width = size
            # normalize bbox
            x, y, w, h = bbox
            # normalize parts
            parts_x = [max(px / width, 0) for px in parts_x]
            parts_y = [max(py / height, 0) for py in parts_y]
            # extract parts
            breast_x, breast_y = parts_x[3], parts_y[3]
            crown_x, crown_y = parts_x[4], parts_y[4]
            nape_x, nape_y = parts_x[9], parts_y[9]
            tail_x, tail_y = parts_x[13], parts_y[13]
            leg_x, leg_y = parts_x[7], parts_y[7]
            beak_x, beak_y = parts_x[1], parts_y[1]

            new_height, new_width, new_channels = self.image_dim
            # get crop for body
            bxmin, bxmax = tf.minimum(tail_x, beak_x), tf.maximum(tail_x, beak_x)
            bymin, bymax = tf.minimum(leg_y, nape_y), tf.maximum(leg_y, nape_y)
            boxes = tf.expand_dims(tf.stack([bymin, bxmin, bymax, bxmax], axis=0), 0)
            box_ind = tf.constant([0])
            # imresize, 
            body_crop = tf.image.crop_and_resize(tf.expand_dims(image, 0), boxes, box_ind, [new_height, new_width], method='bilinear', extrapolation_value=0, name=None)
            body_crop = tf.squeeze(body_crop, [0])

            # get crop for head
            x_len = tf.abs(beak_x - nape_x)
            y_len = tf.abs(crown_x - nape_x)
            bymin, bymax = tf.minimum(nape_y, crown_y), tf.maximum(nape_y, crown_y) + y_len
            bxmin, bxmax = crown_x - x_len, crown_x + x_len
            boxes = tf.expand_dims(tf.stack([bymin, bxmin, bymax, bxmax], axis=0), 0)
            head_crop = tf.image.crop_and_resize(tf.expand_dims(image, 0), boxes, box_ind, [new_height, new_width], method='bilinear', extrapolation_value=0, name=None)
            head_crop = tf.squeeze(head_crop, [0])

            if mode == 'train':
                # resize the image to 256xS where S is max(largest-image-side, 244)
                image = tf.expand_dims(image, 0)
                clipped_height, clipped_width = max(height, 244), max(width, 244)
                if height > width:
                    image = tf.image.resize_bilinear(image, [clipped_height, 256], align_corners=False)
                else:
                    image = tf.image.resize_bilinear(image, [256, clipped_width], align_corners=False)
                # TODO: get rid of this
                #image = tf.image.resize_bilinear(image, [new_height, new_width], align_corners=False)
                # TODO: ^
                image = tf.squeeze(image, [0])
                # preprocess with random crops and horizontal flipping
                image = tf.random_crop(image, size=[new_height, new_width, new_channels])
                image = tf.image.random_flip_left_right(image)
            else:
                image = tf.image.central_crop(image, central_fraction=0.875)
                image = tf.expand_dims(image, 0)
                image = tf.image.resize_bilinear(image, [new_height, new_width])
                image = tf.squeeze(image, [0])
                image = tf.cast(image, tf.uint8)
                session = tf.Session()
                plt.imshow(session.run(image))
            image_and_parts = [image, body_crop, head_crop]
            return image_and_parts