# In [5]:
import numpy as np

#%load_ext autoreload
#%autoreload 1
import augmentation


def prep_test_split(all_persons, all_personlabels, num_labels=18):
    """
    Function that returns some variables that help to split the data.

    :param all_persons: iterable with the person id of every sample (ids must be hashable)
    :param all_personlabels: iterable of entries where entry[0] is the label and entry[1] the person id
    :param num_labels: number of distinct labels; labels are used as indices 0..num_labels-1
    :return a list with the unique persons (in order of first appearance) and, aligned with it,
            a list that holds for each person how many samples that person has made for each label
    """
    # A dict keeps insertion order, so it deduplicates the persons in a single
    # pass and at the same time remembers which row belongs to which person.
    row_of_person = {}
    for person in all_persons:
        if person not in row_of_person:
            row_of_person[person] = len(row_of_person)
    persons = list(row_of_person)

    label_list_all_persons = [[0] * num_labels for _ in persons]
    for sample in all_personlabels:
        row = row_of_person.get(sample[1])
        if row is None:
            # Person does not occur in all_persons, so there is no row to count in.
            continue
        label_list_all_persons[row][int(sample[0])] += 1

    return persons, label_list_all_persons

def get_tot_label_list(all_persons, all_personlabels):
    """
    Build the total label list that is used when upsampling the data.

    The per-person label counts from prep_test_split are added up over all persons.

    :return a list with 18 entries, entry i being the number of samples with label i
    """
    _, per_person_counts = prep_test_split(all_persons, all_personlabels)
    # Column-wise sum; with no persons at all every entry is simply 0.
    return [
        sum(person_counts[label] for person_counts in per_person_counts)
        for label in range(18)
    ]

def upsample_help(all_samples, all_labels, all_persons, all_personlabels):
    """
    HELP function for the total upsampling function. This function adds augmented copies of
    samples whose label occurs less often than the most common label.

    Two passes are made over the original samples: the first one adds copies with the left
    hand moved, the second one copies with the right hand moved. In each pass a sample only
    gets a copy while its label still has fewer samples than the most common label, so a
    label gains at most two extra samples per original sample and can still end up below
    the maximum.

    :param all_samples: samplearray
    :param all_labels: label of every sample, aligned with all_samples
    :param all_persons: person id of every sample (forwarded to get_tot_label_list)
    :param all_personlabels: (label, person) entries (forwarded to get_tot_label_list)
    :return a tuple with the total new samples, new labels and a new tot_label_list
    """
    # The label counts have to be computed from the person data; calling
    # get_tot_label_list without arguments raises a TypeError.
    tot_label_list = get_tot_label_list(all_persons, all_personlabels)
    upsampled_samples = list(all_samples)
    upsampled_labels = list(all_labels)
    max_labels = max(tot_label_list)
    cur_label_list = list(tot_label_list)  # running count of samples per label

    # (augmentation function, arguments for even sample indices, arguments for odd ones).
    # NOTE(review): the two numbers are presumably shifts of the hand position - confirm
    # against the augmentation module.
    passes = (
        (augmentation.move_left_hand, (-11, -10), (17, 15)),
        (augmentation.move_right_hand, (18, 13), (-16, -18)),
    )
    for move_hand, even_args, odd_args in passes:
        for i, sample in enumerate(all_samples):
            label = all_labels[i]
            label_index = int(label)
            if cur_label_list[label_index] >= max_labels:
                continue  # there are already enough samples for that label
            cur_label_list[label_index] += 1  # count the label

            # Alternate the direction of the shift between consecutive samples.
            shift_args = even_args if i % 2 == 0 else odd_args
            upsampled_samples.append(move_hand(sample, *shift_args))
            upsampled_labels.append(label)
    return (upsampled_samples, upsampled_labels, cur_label_list)

def upsample(all_samples, all_labels, all_persons, all_personlabels):
    """
    This function upsamples the lowest labels and afterwards does augmentation on the new
    sampleset: for every sample returned by upsample_help a scaled copy is added as well.

    :param all_samples: samplearray
    :param all_labels: label of every sample, aligned with all_samples
    :param all_persons: person id of every sample (forwarded to upsample_help)
    :param all_personlabels: (label, person) entries (forwarded to upsample_help)
    :return a tuple with the total new samples, new labels and the new label counts
    """
    # upsample_help needs the person data as well to compute the label counts.
    upsampled_samples, upsampled_labels, upsampled_label_list = upsample_help(
        all_samples, all_labels, all_persons, all_personlabels
    )  # upsampling the lowest ones

    # Fix the number of samples before the loop: the lists grow while scaled copies are
    # appended and those copies must not be scaled again.
    num_upsampled = len(upsampled_samples)
    for i in range(num_upsampled):  # add the scaling to every sample too
        label = upsampled_labels[i]
        # Scale the i-th sample itself.
        # NOTE(review): assumes scale_person returns a new sample instead of changing
        # its argument in place - confirm against the augmentation module.
        upsampled_samples.append(augmentation.scale_person(upsampled_samples[i], 1.56))
        upsampled_labels.append(label)
        upsampled_label_list[int(label)] += 1
    return (upsampled_samples, upsampled_labels, upsampled_label_list)