# Notebook details

In [None]:
def setup_notebook(fix_python_path=True, reduce_margins=True, plot_inline=True):
    """Prepare the running notebook session.

    Args:
        fix_python_path: When true, add the parent of the current working
            directory to ``sys.path`` (once; re-running the cell does not add
            a duplicate entry).
        reduce_margins: When true, inject a CSS rule that stretches the
            notebook container to the full page width.
        plot_inline: When true, enable the ``matplotlib inline`` backend
            through the IPython API. Nothing happens when no IPython shell
            is active.

    Side effects:
        Always sets the module-level ``__file__`` to ``'Notebook'``.
    """
    # ``global`` has to precede any use of the name inside this function.
    global __file__

    if reduce_margins:
        # Reduce side margins of the notebook.
        # IPython.display is the public import location for these helpers;
        # IPython.core.display is the deprecated path.
        from IPython.display import HTML, display
        display(HTML("<style>.container { width:100% !important; }</style>"))

    if fix_python_path:
        # Make the parent of the working directory importable.
        import os
        import sys
        parent_dir = os.path.dirname(os.path.abspath('.'))
        if parent_dir not in sys.path:
            sys.path.append(parent_dir)

    if plot_inline:
        # Plots inside cells. The ``%matplotlib inline`` magic is IPython-only
        # syntax, so call its programmatic equivalent to keep this valid Python.
        from IPython import get_ipython
        shell = get_ipython()
        if shell is not None:
            shell.run_line_magic('matplotlib', 'inline')

    __file__ = 'Notebook'

# Apply the notebook set-up with every option left at its default (all enabled).
setup_notebook()

# Imports and Constants Definition

In [None]:
# !/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import json
import logging
import os

import pandas as pd

import egosocial
from egosocial import config
from egosocial.core.types import FaceClustering
from egosocial.utils.filesystem import check_directory
from egosocial.utils.filesystem import list_segments
from egosocial.utils.logging import setup_logging

# Social relation domains and the relation labels that belong to each of them.
# The two lists are index-aligned: every relation in RELATIONS[i] maps to
# DOMAINS[i] (relation_to_domain relies on this), so keep them the same length
# and in the same order.
# NOTE(review): 'Attachent', 'Heirarchical Power' and 'colleages' look
# misspelled, but these strings are compared against label text at runtime.
# Presumably they mirror the spelling used in the annotation files; confirm
# against the data before correcting them.
DOMAINS = ['Attachent', 'Reciprocity', 'Mating', 'Heirarchical Power', 'Coalitional Group']
RELATIONS = [
    ['father-child', 'mother-child', 'grandpa-grandchild', 'grandma-grandchild'],
    ['friends', 'siblings', 'classmates'],
    ['lovers/spouses'],
    ['presenter-audience', 'teacher-student', 'trainer-trainee', 'leader-subordinate', 'customer-staff'],
    ['band members', 'dance team members', 'sport team members', 'colleages'],
]

def relation_to_domain(rel_label):
    """Map a relation label to the name of its domain.

    Resolution order:
      1. ``rel_label`` equals a relation in ``RELATIONS`` -> the domain at the
         same index in ``DOMAINS``.
      2. Some domain name occurs inside ``rel_label`` -> the first such domain.
      3. Otherwise ``rel_label`` is returned unchanged.
    """
    # Exact relation name.
    for position, members in enumerate(RELATIONS):
        if rel_label in members:
            return DOMAINS[position]

    # Label that embeds a domain name; fall back to the label itself.
    return next(
        (name for name in DOMAINS if name in rel_label),
        rel_label,
    )

def is_valid_relation(rel_label):
    """Return True when ``rel_label`` is one of the relations in ``RELATIONS``."""
    return any(rel_label in members for members in RELATIONS)

In [None]:
def main(*fake_args):
    """Define the command-line interface, set up logging and parse arguments.

    Args:
        *fake_args: Forwarded positionally to ``ArgumentParser.parse_args``.
            Pass one list of argument strings to parse that list (this is how
            the notebook calls it); pass nothing to parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Side effects:
        Creates ``egosocial.config.TMP_DIR`` if it is missing and configures
        logging before the arguments are parsed.
    """
    entry_msg = 'Merge multiple data sources and labels for egosocial photo-streams.'
    parser = argparse.ArgumentParser(description=entry_msg)

    parser.add_argument('--base_images_dir', required=True,
                        help='Directory containing social segments in several splits. Original images.')
    parser.add_argument('--base_subimages_dir', required=True,
                        help='Directory containing social segments in several splits. Body/Face images.')

    parser.add_argument('--base_groups_dir', required=True,
                        help='Directory containing groups information.')
    parser.add_argument('--groups_file_name', default='grouped_faces.json',
                        help='File name containing groups information.')

    parser.add_argument('--base_labels_dir', required=True,
                        help='Directory containing labels information.')
    parser.add_argument('--labels_file_name', default='labels.json',
                        help='File name containing labels information.')

    parser.add_argument('--splits', default='train,test,extended',
                        help="""Split folders containing segment directories.
                                Example: given a heirarchy base_images/{train, test}/1/, 
                                call with --splits train,test""")
    parser.add_argument('--camera_user_attributes_path', required=True,
                        help='Path to file containing the attibutes from camera user.')

    parser.add_argument('--dataset_path', required=True,
                        help='Path to file containing the input data and labels information merged.')

    # makedirs(exist_ok=True) replaces the isdir-then-mkdir pair: it has no
    # check/create race and also creates missing parent directories.
    os.makedirs(egosocial.config.TMP_DIR, exist_ok=True)

    setup_logging(egosocial.config.LOGGING_CONFIG,
                  log_dir=egosocial.config.LOGS_DIR)

    # TODO: implement correctly
    args = parser.parse_args(*fake_args)

    return args

# Main class

In [None]:
class MergeInfoHelper:
    """Merge face groups, relation labels and camera-user attributes.

    For every split and segment, the labelled face groups are joined with the
    face clustering output and the camera-user attributes, and the result is
    written to a single JSON file (one list of face records per labelled
    group).

    Args:
        base_images_dir: Root directory of the original images.
        base_subimages_dir: Root directory of the face/body crops.
        base_groups_dir: Root directory of the face-group (clustering) files.
        groups_file_name: File name of the groups JSON inside a segment folder.
        base_labels_dir: Root directory of the label files.
        labels_file_name: File name of the labels JSON inside a segment folder.
        camera_user_attributes_path: CSV file with the columns ``Split``,
            ``Segment``, ``Gender``, ``Age`` and ``User``.
        splits: Split names, either as an iterable of strings or as one
            comma-separated string such as ``'train,test'``.
        dataset_path: Output path of the merged JSON dataset.
    """

    def __init__(
        self,
        base_images_dir=None,
        base_subimages_dir=None,
        base_groups_dir=None,
        groups_file_name=None,
        base_labels_dir=None,
        labels_file_name=None,
        camera_user_attributes_path=None,
        splits=None,
        dataset_path=None,
    ):
        self._base_images_dir = base_images_dir
        self._base_subimages_dir = base_subimages_dir

        self._base_groups_dir = base_groups_dir
        self._groups_file_name = groups_file_name

        self._base_labels_dir = base_labels_dir
        self._labels_file_name = labels_file_name

        self._camera_user_attributes_path = camera_user_attributes_path
        self._splits = self._normalize_splits(splits)
        self._dataset_path = dataset_path

        # (split, segment) -> camera-user attributes; filled by
        # _load_camera_user_attributes(). Created here so that the attribute
        # always exists.
        self._camera_users_data = {}

        # set up logging
        self._log = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def _normalize_splits(splits):
        """Return ``splits`` as a list of names (``None`` is passed through).

        A string is split on commas; surrounding whitespace and empty items
        are dropped. Any other iterable is copied into a list.
        """
        if splits is None:
            return None
        if isinstance(splits, str):
            return [name.strip() for name in splits.split(',') if name.strip()]
        return list(splits)

    def _load_face_groups(self, split, segment_id):
        """Load the face groups of one segment.

        Returns:
            dict mapping ``(split, int(segment_id), int(group_id))`` to the
            group, where ``group_id`` is read from the group's first element.
        """
        segm_groups_dir = os.path.join(self._base_groups_dir, split, segment_id)
        # check groups directory
        check_directory(segm_groups_dir, 'Groups Segment')

        groups_path = os.path.join(segm_groups_dir, self._groups_file_name)
        with open(groups_path, encoding='utf-8') as json_file:
            groups_asjson = json.load(json_file)

        clusters = FaceClustering.from_json(groups_asjson)
        groups_mapping = {}
        for group in clusters.groups:
            key = split, int(segment_id), int(group[0].group_id)
            groups_mapping[key] = group

        return groups_mapping

    def _load_labels(self, split, segment_id):
        """Load the usable relation labels of one segment.

        Only groups carrying exactly one label, which must be a valid
        relation, are kept. A missing labels file yields an empty list.

        Returns:
            list of ``((split, segment, group_id), relation, domain)`` tuples
            ordered by numeric group id; segment and group id are ints.
        """
        segm_labels_dir = os.path.join(self._base_labels_dir, split, segment_id)
        check_directory(segm_labels_dir, 'Labels')

        labels_file = os.path.join(segm_labels_dir, self._labels_file_name)
        if not os.path.exists(labels_file):
            return []

        with open(labels_file, 'r', encoding='utf-8') as json_file:
            labels_per_segment = json.load(json_file)

        labels_data = []
        for labelled_group in sorted(labels_per_segment, key=lambda x: int(x['group_id'])):
            label_list = labelled_group['labels']
            # filter for valid relation labels: exactly one valid label
            if len(label_list) != 1 or not is_valid_relation(label_list[0]):
                continue

            rel_label = label_list[0]
            dom_label = relation_to_domain(rel_label)
            group_key = split, int(segment_id), int(labelled_group['group_id'])
            labels_data.append((group_key, rel_label, dom_label))

        return labels_data

    def _load_camera_user_attributes(self):
        """Read the camera-user CSV into ``self._camera_users_data``.

        Entries are keyed by ``(Split, str(Segment))``.
        """
        camera_user_df = pd.read_csv(self._camera_user_attributes_path)

        camera_users = {}
        for _, row in camera_user_df.iterrows():
            # NOTE(review): the segment is stringified so it can be looked up
            # with the segment folder name; this presumably assumes folder
            # names carry no zero padding. Confirm against the data.
            key = row['Split'], str(row['Segment'])
            camera_users[key] = dict(
                camera_user_gender=row['Gender'],
                camera_user_age=row['Age'],
                camera_user_name=row['User'],
            )

        self._camera_users_data = camera_users

    def _get_camera_user_info(self, split, segment_id):
        """Return the camera-user attributes of a segment.

        Raises:
            KeyError: no attributes were loaded for ``(split, segment_id)``.
        """
        key = split, segment_id
        try:
            return self._camera_users_data[key]
        except KeyError:
            raise KeyError(
                'No camera user attributes for split %r, segment %r' % key
            ) from None

    def _get_face_info(self, split, segment_id, iface):
        """Build the id and image paths of one face.

        ``iface.image_name`` is expected to look like ``<source>_<face id><ext>``;
        the source image is ``<source><ext>``.

        Raises:
            ValueError: the name has no ``_`` separator or the face id is not
                an integer.
        """
        image_name, ext = os.path.splitext(iface.image_name)
        src_image_name, separator, face_id = image_name.rpartition('_')
        if not separator:
            raise ValueError(
                "Face image name %r does not end in '_<face id>'" % iface.image_name
            )
        src_image_name = src_image_name + ext

        face_info = dict(
            face_id=int(face_id),
            global_image_path=os.path.join(self._base_images_dir, split, segment_id, src_image_name),
            face_image_path=os.path.join(self._base_subimages_dir, split, 'face', segment_id, iface.image_name),
            body_image_path=os.path.join(self._base_subimages_dir, split, 'body', segment_id, iface.image_name),
        )

        return face_info

    def _process_segment_in_split(self, split, segment_id):
        """Merge labels, face groups and camera-user data for one segment.

        Returns:
            list with one entry per labelled group; each entry is the list of
            face records of that group, ordered by image name. A labelled
            group missing from the clustering output yields an empty list.
        """
        groups_data = self._load_face_groups(split, segment_id)
        labels_data = self._load_labels(split, segment_id)

        merged_data = []
        for group_key, relation_label, domain_label in labels_data:
            if group_key not in groups_data:
                self._log.warning(
                    'No face group found for labelled group %s; storing an empty entry',
                    group_key,
                )
            group = groups_data.get(group_key, [])
            k_split, k_segment_id, k_group_id = group_key

            group_merge_data = []
            for iface in sorted(group, key=lambda iface: iface.image_name):
                # Key order is preserved in the JSON output: labels first,
                # then face data, then camera-user data.
                tagged_info = dict(
                    split=k_split,
                    segment_id=k_segment_id,
                    group_id=k_group_id,
                    relation_label=relation_label,
                    domain_label=domain_label,
                )
                tagged_info.update(self._get_face_info(split, segment_id, iface))
                tagged_info.update(self._get_camera_user_info(split, segment_id))

                group_merge_data.append(tagged_info)

            merged_data.append(group_merge_data)

        return merged_data

    def process_all(self):
        """Merge every segment of every split and store the result.

        Raises:
            TypeError: no splits were configured.
        """
        # Normalise again: callers may have assigned ``_splits`` directly
        # after construction.
        splits = self._normalize_splits(self._splits)
        if splits is None:
            raise TypeError(
                'splits must be an iterable of split names or a '
                'comma-separated string, not None'
            )

        self._load_camera_user_attributes()

        all_merged_data = []
        for split in splits:
            labels_dir = os.path.join(self._base_labels_dir, split)
            segments = sorted(list_segments(labels_dir), key=int)

            for segment_id in segments:
                all_merged_data.extend(self._process_segment_in_split(split, segment_id))

        self._store_merge_data(all_merged_data)

    def _store_merge_data(self, merged_data):
        """Write ``merged_data`` as indented JSON to the dataset path."""
        output_path = self._dataset_path
        self._log.debug('Saving merge data to %s', output_path)

        with open(output_path, 'w', encoding='utf-8') as json_file:
            json.dump(merged_data, json_file, indent=4)

In [None]:
# Root folder holding every input and output of the merge.
BASE_DIR = '/media/emasa/OS/Users/Emanuel/Downloads/NO_SYNC/Social Segments'
# Command line handed to main(); flags not listed here keep the parser defaults.
args = [
    "--base_images_dir", os.path.join(BASE_DIR, 'images'),
    "--base_groups_dir", os.path.join(BASE_DIR, 'clustering_output_sync'),
    "--base_subimages_dir", os.path.join(BASE_DIR, 'output_images_6x3'),
    "--base_labels_dir", os.path.join(BASE_DIR, 'labels'),
    "--camera_user_attributes_path", os.path.join(BASE_DIR, 'camera-user-attributes-all-splits.csv'),
    "--splits", "train,test,extended",
    "--dataset_path", os.path.join(BASE_DIR, 'merged_dataset.json'),
]

conf = main(args)

# --splits arrives as one comma-separated string. Strip each name and drop
# empty items so that input such as 'train, test,' still yields clean names.
split_names = [name.strip() for name in conf.splits.split(',') if name.strip()]

# Hand the splits to the constructor instead of assigning the private
# ``_splits`` attribute afterwards.
data_merger = MergeInfoHelper(
    base_images_dir=conf.base_images_dir,
    base_groups_dir=conf.base_groups_dir,
    groups_file_name=conf.groups_file_name,
    base_subimages_dir=conf.base_subimages_dir,
    base_labels_dir=conf.base_labels_dir,
    labels_file_name=conf.labels_file_name,
    camera_user_attributes_path=conf.camera_user_attributes_path,
    splits=split_names,
    dataset_path=conf.dataset_path,
)

In [None]:
# Run the merge over all configured splits and write the dataset file.
data_merger.process_all()