In [None]:
import numpy as np
import os
import sys
import cv2
import shutil

In [None]:
class Image_util:
    """Load business photos with multi-hot label vectors and serve training batches.

    After construction, ``self.images`` holds the resized photos stacked into one
    array, and ``self.labels`` holds the row-aligned label vectors (one row per
    image). ``next_batch`` walks through them cyclically.
    """

    def __init__(self, data_dir, biz_label_file_name, photo_biz_file_name,
                 max_images=10, image_size=299):
        """Read photos from ``data_dir`` and attach each one's business label vector.

        Args:
            data_dir: directory containing ``<photo_id>.jpg`` files.
            biz_label_file_name: CSV with a header row followed by
                ``business_id,labels`` lines (see ``read_csv_one_hot``).
            photo_biz_file_name: CSV with a header row followed by
                ``photo_id,business_id`` lines (see ``photo_to_biz_id``).
            max_images: load at most this many photos; ``None`` loads all of them.
                Defaults to 10, the limit this loader has always applied.
            image_size: side length, in pixels, that every image is resized to.

        Raises:
            KeyError: if a photo id has no business mapping, or its business has
                no row in the label file.
        """
        self.batch_index = 0
        # Sort so the subset picked by ``max_images`` is deterministic
        # (os.listdir order is arbitrary). Skip macOS "._" resource-fork files.
        image_paths = [os.path.join(data_dir, name)
                       for name in sorted(os.listdir(data_dir))
                       if name.endswith('.jpg') and not name.startswith('._')]
        image_paths = image_paths[:max_images]  # [:None] keeps everything
        one_hot = self.read_csv_one_hot(biz_label_file_name)
        photo_biz = self.photo_to_biz_id(photo_biz_file_name)

        images, labels = [], []
        for path in image_paths:
            img = cv2.imread(path)
            # cv2.imread signals an unreadable file by returning None. `img == None`
            # would compare element-wise on a real image and make `if` raise.
            if img is None:
                continue
            photo_id = os.path.basename(path).split(".")[0]
            label = one_hot[photo_biz[photo_id]]
            images.append(cv2.resize(img, (image_size, image_size),
                                     interpolation=cv2.INTER_AREA))
            labels.append(label)
        self.labels = np.asarray(labels)
        self.images = np.asarray(images)
        print(self.labels.shape)

    def next_batch(self, batch_size):
        """Return the next ``(images, labels)`` batch and advance the cursor.

        Batches wrap around the end of the data, so every call returns exactly
        ``batch_size`` rows. Samples repeat when ``batch_size`` exceeds the number
        of loaded images.

        Raises:
            ValueError: if no images are loaded or ``batch_size`` is negative.
        """
        num_images = len(self.images)
        if num_images == 0:
            raise ValueError('no images loaded; cannot build a batch')
        if batch_size < 0:
            raise ValueError('batch_size must be non-negative, got %r' % (batch_size,))
        # Modular integer indices handle the wrap-around in one step. Unlike
        # 4-axis slicing, they work for both the image array and the 2-D label array.
        indices = (self.batch_index + np.arange(batch_size)) % num_images
        self.batch_index = (self.batch_index + batch_size) % num_images
        return self.images[indices], self.labels[indices]

    def read_csv_one_hot(self, file_name, num_classes=9):
        """Parse the business label CSV into ``{business_id: multi-hot vector}``.

        The first line is treated as a header and skipped, and blank lines are
        ignored. Every other line is ``business_id,labels``, where ``labels`` is a
        space-separated list of integer class indices. A business whose label
        field is empty maps to an all-zero vector.

        Args:
            file_name: path to the CSV file.
            num_classes: length of each label vector (default 9).

        Returns:
            dict mapping business id (str) to a float ``np.ndarray`` of shape
            ``(num_classes,)`` holding 1.0 at each listed class index.

        Raises:
            ValueError: if a label is not an integer in ``[0, num_classes)``.
        """
        with open(file_name, "r") as f:
            lines = f.readlines()[1:]
        biz_id_to_label = {}
        for line in lines:
            if not line.strip():
                continue
            biz_id, _, raw_labels = line.rstrip().partition(",")
            vector = np.zeros(num_classes)
            for label in raw_labels.split():
                index = int(label)
                if not 0 <= index < num_classes:
                    raise ValueError('label %d for business %s is outside [0, %d)'
                                     % (index, biz_id, num_classes))
                vector[index] = 1
            biz_id_to_label[biz_id] = vector
        return biz_id_to_label

    def photo_to_biz_id(self, file_name):
        """Parse the photo-to-business CSV into ``{photo_id: business_id}``.

        The first line is treated as a header and skipped, and blank lines are
        ignored. Both ids are kept as strings. If a photo id appears on several
        lines, the last one wins.
        """
        with open(file_name, "r") as f:
            lines = f.readlines()[1:]
        photo_to_biz = {}
        for line in lines:
            if not line.strip():
                continue
            fields = line.split(",")
            photo_to_biz[fields[0]] = fields[1].rstrip()
        return photo_to_biz

In [None]:
# Build the dataset wrapper. The class is spelled `Image_util` (capital I);
# a lowercase `image_util` raises NameError.
# TODO: make DATA_DIR configurable instead of an absolute local path.
DATA_DIR = '/home/rendaxuan/Documents/workspace/4032'
img_util = Image_util(
    os.path.join(DATA_DIR, 'train_photos'),
    os.path.join(DATA_DIR, 'train.csv'),
    os.path.join(DATA_DIR, 'train_photo_to_biz_ids.csv'),
)

In [None]:
# Lookup table: photo id -> business id (both kept as strings).
photo_id_dict = img_util.photo_to_biz_id(
    file_name='/home/rendaxuan/Documents/workspace/4032/train_photo_to_biz_ids.csv')

In [None]:
# Lookup table: business id -> 9-element multi-hot label vector.
biz_photo_dict = img_util.read_csv_one_hot(
    file_name='/home/rendaxuan/Documents/workspace/4032/train.csv')

In [None]:
# Sanity check: label vector of the business that owns photo 160233.
# print() already applies str(), so no explicit conversion is needed.
sample_biz_id = photo_id_dict['160233']
print(biz_photo_dict[sample_biz_id])

In [None]:
# Sort the raw listing so the per-class photo lists built later have a
# reproducible order (os.listdir returns entries in arbitrary order).
images = sorted(os.listdir('./train_photos'))
len(images)

In [None]:
# Keep only .jpg photos and drop macOS "._" resource-fork files, matching the
# filter Image_util applies when it loads images. Other entries (e.g. .DS_Store)
# are presumably absent from photo_id_dict and would raise KeyError when labels
# are looked up per photo.
images = [i for i in images if i.endswith('.jpg') and not i.startswith('._')]

In [None]:
# Number of photos left after filtering (shown as the cell's output value).
len(images)

In [None]:
# Group photo ids by the string form of their business's label vector.
# Keys keep first-seen order; each value lists photo ids in listing order.
images_labels = {}
for file_name in images:
    photo_id = file_name.split('.')[0]
    label_key = str(biz_photo_dict[photo_id_dict[photo_id]])
    images_labels.setdefault(label_key, []).append(photo_id)

In [None]:
# Number of distinct label combinations (each becomes one output folder).
len(images_labels.keys())

In [None]:
# Copy every photo into ./class_images/<label-vector string>/, giving one folder
# per label combination. makedirs(exist_ok=True) creates ./class_images if it is
# missing and lets this cell be re-run without FileExistsError.
for label_key, photo_ids in images_labels.items():
    class_dir = os.path.join('./class_images', label_key)
    os.makedirs(class_dir, exist_ok=True)
    for photo_id in photo_ids:
        shutil.copy(os.path.join('./train_photos', photo_id + '.jpg'), class_dir)