In [8]:
"""
Load previously saved ordination coordinates, group them by annotated label and publish 3D scatter plots to Plotly
"""
import os
import pickle
from django.conf import settings
from collections import Counter

import numpy as np
from sklearn.manifold import MDS
from sklearn.metrics import euclidean_distances

from root.models import *
from koe.models import *
from koe.model_utils import get_or_error
from koe.ts_utils import bytes_to_ndarray, get_rawdata_from_binary

import plotly.plotly as py
import plotly.graph_objs as go

import plotly
# Plotly credentials are read from the environment so that API keys are never stored in the notebook.
# Set PLOTLY_USERNAME and PLOTLY_API_KEY before starting the kernel; a missing variable raises KeyError.
# NOTE(review): keys that were previously hard-coded here should be treated as leaked and rotated.
plotly.tools.set_credentials_file(
    username=os.environ['PLOTLY_USERNAME'],
    api_key=os.environ['PLOTLY_API_KEY'],
)


def get_labels_by_sids(sids, label_level, annotator, min_occur):
    """
    Look up the label one annotator gave to each segment.

    Labels are lower-cased. A label that occurs fewer than `min_occur` times among the given segments is
    treated as if it had never been assigned.

    :param sids: segment ids to look up; the returned labels follow their order
    :param label_level: name of the attribute that holds the label
    :param annotator: the user whose labels are read
    :param min_occur: minimum number of occurrences a label needs in order to be kept
    :return: a tuple (labels, no_label_ids). `labels` has one entry per id in `sids` and contains '__NONE__'
             where no label was kept; `no_label_ids` is an int32 array of the ids that received '__NONE__'
    """
    label_rows = ExtraAttrValue.objects.filter(
        attr__name=label_level, owner_id__in=sids, user=annotator
    ).values_list('owner_id', 'value')
    label_by_sid = {owner_id: value.lower() for owner_id, value in label_rows}

    # Keep only the segments whose label is frequent enough
    frequency = Counter(label_by_sid.values())
    kept = {sid: lbl for sid, lbl in label_by_sid.items() if frequency[lbl] >= min_occur}

    labels = [kept.get(sid, '__NONE__') for sid in sids]
    no_label_ids = [sid for sid in sids if sid not in kept]

    return np.array(labels), np.array(no_label_ids, dtype=np.int32)


def exclude_no_labels(sids, tids, labels, no_label_ids):
    """
    Remove the entries of unlabelled segments from three parallel arrays.

    :param sids: array of segment ids
    :param tids: array parallel to `sids`
    :param labels: array of labels parallel to `sids`
    :param no_label_ids: ids of the segments to drop
    :return: a tuple (sids, tids, labels) without the entries whose id appears in `no_label_ids`
    """
    sids = np.asarray(sids)
    tids = np.asarray(tids)
    labels = np.asarray(labels)

    # A membership test is used instead of np.searchsorted so the result does not depend on `sids` being
    # sorted, and an id that is absent from `sids` cannot knock out an unrelated entry. It also avoids the
    # `np.bool` alias, which no longer exists in recent NumPy releases.
    keep = ~np.isin(sids, no_label_ids)

    return sids[keep], tids[keep], labels[keep]


def handle(database_name, population_name, type, normalised, *args, **kwargs):
    """
    Load a saved result from a pickle file located in settings.BASE_DIR.

    The file is expected to be named '<database>_<population>_<type>_<normed|raw>.pkl'.

    :param database_name: name of the database, matched case-insensitively
    :param population_name: population part of the file name
    :param type: type part of the file name
    :param normalised: truthy selects the 'normed' file, falsy the 'raw' one
    :return: a tuple (coordinate, stress, sids, tids) read from the saved dictionary
    :raises Exception: if the expected file does not exist
    """
    # The looked-up object is not used; presumably this call is kept to reject unknown database names
    get_or_error(Database, dict(name__iexact=database_name))

    suffix = 'normed' if normalised else 'raw'
    file_path = os.path.join(
        settings.BASE_DIR,
        '{}_{}_{}_{}.pkl'.format(database_name, population_name, type, suffix)
    )
    if not os.path.isfile(file_path):
        raise Exception('File not found: {}'.format(file_path))

    # NOTE: unpickling can execute arbitrary code, so only files written by this project should be loaded
    with open(file_path, 'rb') as f:
        saved = pickle.load(f)

    return tuple(saved[key] for key in ('coordinate', 'stress', 'sids', 'tids'))


def get_plot_graph(database_name, population, annotator_name, label_level, type, normalised):
    """
    Build a 3D scatter figure of saved coordinates, with one trace per label.

    :param database_name: name of the database the coordinates were saved for
    :param population: population the coordinates were saved for
    :param annotator_name: username (case-insensitive) of the annotator whose labels are used
    :param label_level: name of the attribute that holds the label
    :param type: type of the saved result, forwarded to `handle`
    :param normalised: whether to load the normalised or the raw result
    :return: a tuple (fig, fig_name): the plotly figure and the name built from the arguments
    """
    annotator = get_or_error(User, dict(username__iexact=annotator_name))
    normalised_str = 'normed' if normalised else 'raw'
    fig_name = '{}_{}_{}_{}_{}_{}'.format(database_name, population, annotator_name, label_level, type,
                                          normalised_str)

    # Only the coordinates and the segment ids are needed for plotting
    coordinate, _, sids, _ = handle(database_name, population, type, normalised)

    # min_occur=1 keeps every label; segments without a label come back as '__NONE__' and get their own trace
    labels, _ = get_labels_by_sids(sids, label_level, annotator, 1)

    traces = []
    for label in np.unique(labels):
        ind = np.flatnonzero(labels == label)
        # Assumes `coordinate` has at least three columns -- TODO confirm against the code that saves it
        trace = go.Scatter3d(
            x=coordinate[ind, 0],
            y=coordinate[ind, 1],
            z=coordinate[ind, 2],
            name=label.strip(),
            mode='markers',
            marker=dict(
                size=5,
                line=dict(
                    width=0.5
                ),
                opacity=1
            ),
            # Segment ids are attached as the text of each point
            text=sids[ind]
        )
        traces.append(trace)

    layout = go.Layout(
        title=fig_name,
        margin=dict(
            l=0,
            r=0,
            b=0,
            t=0
        )
    )
    fig = go.Figure(data=traces, layout=layout)

    return fig, fig_name


In [11]:
# What to plot: one figure is produced for every (population, label level, normalisation) combination
database_name = 'Bellbirds'
annotator_name = 'wesley'
label_levels = ['label', 'label_family']
populations = ['LBI']
normaliseds = [True]

# Flattened form of the nested loops; the iteration order is population -> label level -> normalisation
for population, label_level, normalised in (
    (pop, level, norm)
    for pop in populations
    for level in label_levels
    for norm in normaliseds
):
    fig, fig_name = get_plot_graph(
        database_name, population, annotator_name, label_level, 'mdspca', normalised
    )
    # Uploads the figure under `fig_name`, then prints the figure name and the returned resource
    plot = py.iplot(fig, filename=fig_name)
    print('{}: {}'.format(fig_name, plot.resource))

Bellbirds_LBI_wesley_label_mdspca_normed: https://plot.ly/~6X5uQKc1/16
Bellbirds_LBI_wesley_label_family_mdspca_normed: https://plot.ly/~6X5uQKc1/18
