In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import cv2
import os
import glob

%matplotlib inline

In [2]:
data_dir = 'data/miniimagenet/'
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'train')

In [69]:
IMG_SIZE = 64
N_CLASSES = len(os.listdir(train_dir))
N_TASK_CLASSES = 10
INNER_EPOCHS = 1
INNER_BATCH_SIZE = 64

## Define Image Loading Functions

In [4]:
img_cache = {}

In [172]:
def preprocess_img(img, size=IMG_SIZE):
    return cv2.resize(img, (size, size))

def load_img(class_dir, img_name, img_path=None):
    if img_path:
        img_name = img_path[img_path.rfind('/')+1:]
    else:
        img_path = os.path.join(class_dir, img_name[:img_name.find('_')], img_name)
    
    # If image is already loaded into cache return it
    if img_name in img_cache:
        return img_cache[img_name]
    else:
        img = cv2.imread(img_path)
        img = preprocess_img(img, size=IMG_SIZE)
        img_cache[img_name] = img
        return img

## Building the DataFrame of Image Names

In [173]:
train_classes = np.asarray(os.listdir(train_dir))

In [174]:
cols = ['class', 'img_name']
df_train = pd.DataFrame(columns=cols)

for cat in train_classes:
    class_list = []
    cat_dir = os.listdir(os.path.join(train_dir, cat))
    for img_name in cat_dir:
        class_list.append([cat, img_name])
    tmp_df = pd.DataFrame(class_list, columns=cols)
    df_train = df_train.append(tmp_df)
    
df_train = df_train.sample(frac=1)
df_train.reset_index(inplace=True)
df_train.drop('index', inplace=True, axis=1)

## Functions for Generating Tasks and Formatting Data

In [199]:
def gen_task(data, n_classes=N_TASK_CLASSES):
    classes = np.random.choice(train_classes, size=n_classes, replace=False)
    task_indices = data['class'].map(lambda x: x in classes)
    task_data = data[task_indices]
    return task_data

def gen_batches(task_data, batch_size=INNER_BATCH_SIZE, data_dir=train_dir):
    task_data = pd.concat([task_data, pd.get_dummies(task_data['class'])], axis=1)
    while True:
        epoch_data = task_data.sample(frac=1)
        for i in range((len(epoch_data) // batch_size) - 1):
            X_names = epoch_data.iloc[i*batch_size:(i+1)*batch_size]['img_name'].values
            X = np.asarray([load_img(data_dir, img_name) for img_name in X_names])
            y = epoch_data.drop(['class', 'img_name'], axis=1).iloc[i*batch_size:(i+1)*batch_size].values
            
            yield X, y

In [200]:
task = gen_task(df_train)

In [201]:
d = gen_batches(task)

In [290]:
next(d)
len(img_cache)

6304

In [291]:
img_cache.keys()

dict_keys(['n01532829_100.JPEG', 'n02120079_9914.JPEG', 'n02113712_3316.JPEG', 'n02113712_4125.JPEG', 'n02120079_24135.JPEG', 'n02120079_782.JPEG', 'n04596742_26395.JPEG', 'n01558993_10224.JPEG', 'n01558993_10809.JPEG', 'n02120079_712.JPEG', 'n03838899_819.JPEG', 'n02795169_29207.JPEG', 'n04596742_28357.JPEG', 'n03207743_23815.JPEG', 'n02606052_2457.JPEG', 'n02795169_31985.JPEG', 'n04596742_6862.JPEG', 'n01558993_2691.JPEG', 'n03838899_35292.JPEG', 'n01558993_1853.JPEG', 'n02120079_5424.JPEG', 'n03854065_54949.JPEG', 'n04296562_52409.JPEG', 'n02795169_4002.JPEG', 'n02606052_5256.JPEG', 'n02795169_21785.JPEG', 'n04296562_13670.JPEG', 'n02606052_1030.JPEG', 'n04296562_10549.JPEG', 'n03838899_38334.JPEG', 'n03207743_20679.JPEG', 'n03854065_10639.JPEG', 'n03207743_2806.JPEG', 'n02795169_57364.JPEG', 'n03838899_32243.JPEG', 'n03207743_28762.JPEG', 'n02113712_2648.JPEG', 'n03838899_4161.JPEG', 'n02795169_13070.JPEG', 'n04296562_47733.JPEG', 'n01558993_14151.JPEG', 'n03854065_7994.JPEG', 'n02

In [None]:
task[0*20:(0+1)*20]['img_name'].values

2137