In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import cv2
import os
import glob

%matplotlib inline

In [2]:
data_dir = 'data/miniimagenet/'
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'train')

In [69]:
IMG_SIZE = 64
N_CLASSES = len(os.listdir(train_dir))
N_TASK_CLASSES = 10
INNER_EPOCHS = 1
INNER_BATCH_SIZE = 64

## Define Image Loading Functions

In [4]:
img_cache = {}

In [172]:
def preprocess_img(img, size=IMG_SIZE):
    return cv2.resize(img, (size, size))

def load_img(class_dir, img_name, img_path=None):
    if img_path:
        img_name = img_path[img_path.rfind('/')+1:]
    else:
        img_path = os.path.join(class_dir, img_name[:img_name.find('_')], img_name)
    
    # If image is already loaded into cache return it
    if img_name in img_cache:
        return img_cache[img_name]
    else:
        img = cv2.imread(img_path)
        img = preprocess_img(img, size=IMG_SIZE)
        img_cache[img_name] = img
        return img

## Building the DataFrame of Image Names

In [173]:
train_classes = np.asarray(os.listdir(train_dir))

In [174]:
cols = ['class', 'img_name']
df_train = pd.DataFrame(columns=cols)

for cat in train_classes:
    class_list = []
    cat_dir = os.listdir(os.path.join(train_dir, cat))
    for img_name in cat_dir:
        class_list.append([cat, img_name])
    tmp_df = pd.DataFrame(class_list, columns=cols)
    df_train = df_train.append(tmp_df)
    
df_train = df_train.sample(frac=1)
df_train.reset_index(inplace=True)
df_train.drop('index', inplace=True, axis=1)

## Functions for Generating Tasks and Formatting Data

In [188]:
def gen_task(data, n_classes=N_TASK_CLASSES):
    classes = np.random.choice(train_classes, size=n_classes, replace=False)
    task_indices = data['class'].map(lambda x: x in classes)
    task_data = data[task_indices]
    return task_data

def gen_batches(task_data, batch_size=INNER_BATCH_SIZE, data_dir=train_dir):
    task_data = pd.concat([task_data, pd.get_dummies(task_data['class'])], axis=1)
    while True:
        epoch_data = task_data.sample(frac=1)
        for i in range((len(epoch_data) // batch_size) - 1):
            X_names = epoch_data.iloc[i*batch_size:(i+1)*batch_size]['img_name'].values
            X = np.asarray([load_img(data_dir, img_name) for img_name in X_names])
            print(epoch_data.head())
            y = epoch_data.drop(['class', 'img_name'], axis=1).iloc[i*batch_size:(i+1)*batch_size].values
            
            yield X, y

In [189]:
task = gen_task(df_train)

In [190]:
d = gen_batches(task)

In [191]:
next(d) df_train.drop()

          class              img_name  n02108089  n02165456  n02606052  \
5101  n02165456  n02165456_10640.JPEG          0          1          0   
5205  n03047690   n03047690_2214.JPEG          0          0          0   
8361  n02606052   n02606052_5916.JPEG          0          0          1   
4213  n04296562  n04296562_18597.JPEG          0          0          0   
9390  n04296562  n04296562_13790.JPEG          0          0          0   

      n03047690  n03062245  n04251144  n04296562  n04515003  n09246464  \
5101          0          0          0          0          0          0   
5205          1          0          0          0          0          0   
8361          0          0          0          0          0          0   
4213          0          0          0          1          0          0   
9390          0          0          0          1          0          0   

      n13133613  
5101          0  
5205          0  
8361          0  
4213          0  
9390          0  


KeyError: "[('class', 'img_name')] not found in axis"

In [None]:
task[0*20:(0+1)*20]['img_name'].values