-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: Labelstudio support for MLHandler template
The MLHandler template now has a UI for sentiment analysis with transformers. A labelstudio interface is added for fine-tuning the model with live annotations
- Loading branch information
Showing
6 changed files
with
541 additions
and
329 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,317 @@ | ||
{% set base = '.' %} | ||
{% set columns = data.columns.tolist() %} | ||
{% set CLASSIFICTION_MODELS = [ | ||
'LogisticRegression', | ||
'BernoulliNB', | ||
'Perceptron', | ||
'PassiveAggressiveClassifier', | ||
'SVC', | ||
'NuSVC', | ||
'LinearSVC', | ||
'KNeighborsClassifier', | ||
'GaussianNB', | ||
'DecisionTreeClassifier', | ||
'RandomForestClassifier', | ||
'MLPClassifier'] %} | ||
{% set REGRESSION_MODELS = [ | ||
'LinearRegression', | ||
'PassiveAggressiveRegressor', | ||
'SVR', | ||
'NuSVR', | ||
'LinearSVR', | ||
'KNeighborsRegressor', | ||
'DecisionTreeRegressor', | ||
'RandomForestRegressor', | ||
'MLPRegressor'] %} | ||
{% set tcol = handler.store.load('target_col', False) %} | ||
{% set CLASSIFICTION_METRICS = { | ||
'Accuracy': 'accuracy', | ||
'Balanced Accuracy': 'balanced_accuracy', | ||
'ROC AUC': 'roc_auc', | ||
'F1 Score': 'f1_weighted' | ||
}%} | ||
{% set REGRESSION_METRICS = { | ||
'R2': 'r2', | ||
'Explained Variance': 'explained_variance', | ||
'Max Error': 'max_error', | ||
'Negative Mean Absolute Error': 'neg_mean_absolute_error', | ||
'Negative Mean Squared Error': 'neg_mean_squared_error', | ||
'Negative Root Mean Squared Error': 'neg_root_mean_squared_error' | ||
}%} | ||
<div class="row pb-3 pt-3"> | ||
<div class="col-8 border-right border-dark"> | ||
<h3 class="text-center">Train the Model</h3> | ||
<form class="form" id="train"> | ||
<div class="container"> | ||
<div class="row pt-3 pb-3"> | ||
<div class="formhandler overflow-auto"></div> | ||
</div> | ||
<div class="row"> | ||
<div class="col"> | ||
<label for="targetcol">Pick a Target Column:</label> | ||
<select id="targetcol" class="form-control" name="target_col"> | ||
{% for col in columns %} | ||
{% set selected = "selected" if col == tcol else "" %} | ||
<option value="{{ col }}" {{ selected }}>{{ col }}</option> | ||
{% end %} | ||
</select> | ||
</div> | ||
<div class="col"> | ||
<label for="exclude">Columns to Exclude:</label> | ||
<select id="exclude" class="selectpicker form-control" multiple name="exclude"> | ||
{% for col in columns %} | ||
{% set selected = "selected" if col in handler.store.load('exclude', []) else "" %} | ||
<option value="{{ col }}" {{ selected }}>{{ col }}</option> | ||
{% end %} | ||
</select> | ||
</div> | ||
</div> | ||
<div class="row pb-3 pt-3"> | ||
<div class="col"> | ||
<label for="cats">Categorical Columns:</label> | ||
<select id="cats" class="selectpicker form-control" multiple name="cats"> | ||
{% for col in columns %} | ||
{% set selected = "selected" if col in handler.store.load('cats', []) else "" %} | ||
<option value="{{ col }}" {{ selected }}>{{ col }}</option> | ||
{% end %} | ||
</select> | ||
</div> | ||
<div class="col"> | ||
<label for="nums">Numerical Columns:</label> | ||
<select id="nums" class="selectpicker form-control" multiple name="nums"> | ||
{% for col in columns %} | ||
{% set selected = "selected" if col in handler.store.load('nums', []) else "" %} | ||
<option value="{{ col }}" {{ selected }}>{{ col }}</option> | ||
{% end %} | ||
</select> | ||
</div> | ||
</div> | ||
<div class="row pb-3 pt-3"> | ||
<div class="col"> | ||
<label for="transform">Transform:</label> | ||
<input class="form-control" id="transform" name="data.transform" type="text" | ||
value="{{ handler.store.load('built_transform', '') }}"> | ||
</div> | ||
<div class="col"> | ||
<label for="metric">Choose a Metric:</label> | ||
<select id="metric" class="form-control selectpicker" name="metric"> | ||
{% if handler.store.load('class') in CLASSIFICTION_MODELS %} | ||
{% for i, (mname, metric) in enumerate(CLASSIFICTION_METRICS.items()) %} | ||
{% set selected = "selected" if metric == "accuracy" else "" %} | ||
<option value="{{ metric }}" {{ selected }}>{{ mname }}</option> | ||
{% if i == 0 %} | ||
<option data-divider="true"></option> | ||
{% end %} | ||
{% end %} | ||
{% else %} | ||
{% for i, (mname, metric) in enumerate(REGRESSION_METRICS.items()) %} | ||
{% set selected = "selected" if metric == "r2" else "" %} | ||
<option value="{{ metric }}" {{ selected }}>{{ mname }}</option> | ||
{% if i == 0 %} | ||
<option data-divider="true"></option> | ||
{% end %} | ||
{% end %} | ||
{% end %} | ||
</select> | ||
</div> | ||
<div class="col"> | ||
<label for="modelchoice">Pick a Model:</label> | ||
<select id="modelchoice" class="form-control" name="class"> | ||
<optgroup label="Classification"> | ||
{% for model in CLASSIFICTION_MODELS %} | ||
{% set selected = "selected" if model == handler.store.load('class') else "" %} | ||
<option value="{{ model }}" class="text-dark" {{ selected }}>{{ model }}</option> | ||
{% end %} | ||
</optgroup> | ||
<optgroup label="Regression"> | ||
{% for model in REGRESSION_MODELS %} | ||
{% set selected = "selected" if model == handler.store.load('class') else "" %} | ||
<option value="{{ model }}" {{ selected }}>{{ model }}</option> | ||
{% end %} | ||
</optgroup> | ||
</select> | ||
</div> | ||
</div> | ||
<div class="text-right"> | ||
<button class="btn btn-primary" type="submit">Train</button> | ||
</div> | ||
</div> | ||
</form> | ||
<div class="text-center divider">Results</div> | ||
<div class="container" id="resultcnt"> | ||
<div class="row"> | ||
<div class="col"> | ||
<div class="text-center"> | ||
<strong>Your model scored</strong> | ||
</div> | ||
<div class="text-center"> | ||
<svg width="20%" height="20%" viewBox="0 0 40 40" class="donut"> | ||
<circle class="donut-hole" cx="20" cy="20" r="15.91549430918954" fill="#fff"></circle> | ||
<circle class="donut-ring" cx="20" cy="20" r="15.91549430918954" fill="transparent" stroke-width="3.5"></circle> | ||
<circle class="donut-segment" cx="20" cy="20" r="15.91549430918954" fill="transparent" stroke-width="3.5" stroke-dasharray="10 90" stroke-dashoffset="25"></circle> | ||
<g class="donut-text"> | ||
<text y="50%" transform="translate(0, 2)"> | ||
<tspan x="50%" text-anchor="middle" class="donut-percent">40%</tspan> | ||
</text> | ||
</g> | ||
</svg> | ||
</div> | ||
</div> | ||
<div class="col"> | ||
<div class="row"> | ||
<form id="testform" enctype="multipart/form-data"> | ||
<div class="row pb-3"> | ||
<label for="testurl"> | ||
Happy with the result? <a id="downlink">Download the model.</a> Or get predictions: | ||
</label> | ||
<input name="file" type="file" id="testurl" class="form-control"> | ||
</div> | ||
<div class="text-right"> | ||
<button id="testbtn" class="btn btn-primary" type="submit">Predict</button> | ||
</div> | ||
</form> | ||
</div> | ||
<div class="row"> | ||
<a class="ml-auto" id="downloadbtn">Download Predictions</a> | ||
</div> | ||
</div> | ||
</div> | ||
</div> | ||
</div> | ||
<div class="col-4"> | ||
<h3 class="text-center">Make Predictions</h3> | ||
<div class="row py-2"> | ||
<div class="col"> | ||
<button type="submit" form="predictform" class="btn btn-primary">Predict</button> | ||
</div> | ||
<div class="col"> | ||
<h4 id="predResult"></h4> | ||
</div> | ||
</div> | ||
<form id="predictform" class="overflow-auto"> | ||
<template id="predicttabtemplate"> | ||
<div class="container"> | ||
<% COLS.forEach(function(col) { %> | ||
<div class="form-group row"> | ||
<label for="<%= col.name %>" class="col-md-6"><%= col.name %></label> | ||
<input class="form-control col-md-6" type="<%= col.type %>" name="<%= col.name %>" value="<%=row[col.name]%>"> | ||
</div> | ||
<% }) %> | ||
</div> | ||
</template> | ||
</form> | ||
</div> | ||
</div> | ||
<script> | ||
/* eslint-env browser */ | ||
/* globals $, g1 */ | ||
$.fn.selectpicker.Constructor.BootstrapVersion = '4' | ||
var fh_meta = null | ||
|
||
const get_score_color = function(s) { | ||
let color = '#ff0000' | ||
if (s > 50) { color = '#f7b100'} | ||
if (s > 90) { color = '#00f700' } | ||
return color | ||
} | ||
const get_score = function() { | ||
let url = g1.url.parse(window.location) | ||
$.ajax({ | ||
url: url + '?_action=score&_metric=' + encodeURIComponent($('#metric').val()), | ||
method: 'POST', | ||
success: function(resp) { | ||
let score = Number.parseFloat(JSON.parse(resp).score * 100).toPrecision(4) | ||
score = Number.parseFloat(score) | ||
$('.donut-segment').attr('stroke-dasharray', `${score} ${100 - score}`) | ||
$('tspan').html(`${score}%`) | ||
$('.donut-segment').attr('stroke', get_score_color(score)) | ||
$('#resultcnt').show() | ||
let inputcols = fh_meta.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name)) | ||
$('#predicttabtemplate').template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: fh_meta.formdata[0]}) | ||
} | ||
}) | ||
} | ||
const post_train = function (target_col) { | ||
$('#resultcnt').hide() | ||
let url = g1.url.parse(window.location) | ||
url.hash = '' | ||
$.ajax({ | ||
url: url + '?_action=retrain&target_col=' + encodeURIComponent(target_col) + '&_metric=' + encodeURIComponent($('#metric').val()), | ||
method: 'POST', | ||
success: get_score | ||
}) | ||
} | ||
$(document).ready(function() { | ||
$('#downloadbtn').hide() | ||
$('#resultcnt').hide() | ||
let url = g1.url.parse(window.location) | ||
url.search = '_cache' | ||
$('.formhandler').attr('data-src', url.toString()) | ||
$('.formhandler').on('load', function(obj) { | ||
fh_meta = obj | ||
let inputcols = obj.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name)) | ||
$('#predicttabtemplate').template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: obj.formdata[0]}) | ||
}).formhandler({ | ||
pageSize: 5, | ||
export: false | ||
}) | ||
url.search = '_cache&_opts' | ||
$.getJSON(url.toString()).done(function (e) { | ||
// TODO: Jaidev, please check if & why the next line is required. Anand | ||
let opts = e // eslint-disable-line no-unused-vars | ||
}) | ||
|
||
url.search = '_model&_download' | ||
$('#downlink').attr('href', url.toString()) | ||
|
||
// Select | ||
$('.selectpicker').selectpicker({ | ||
actionsBox: true | ||
}) | ||
|
||
$('#train').submit(function(e) { | ||
e.preventDefault() | ||
let fd = new FormData(this) | ||
let trainUrl = g1.url.parse(window.location + '?_model') | ||
trainUrl.update({class: fd.get('class')}) | ||
trainUrl.update({exclude: fd.getAll('exclude')}) | ||
trainUrl.update({cats: fd.getAll('cats')}) | ||
$.ajax({ | ||
url: trainUrl.toString(), | ||
method: 'PUT', | ||
success: function() { | ||
post_train(fd.get('target_col')) | ||
} | ||
}) | ||
}) | ||
$('#predictform').submit(function(e) { | ||
e.preventDefault() | ||
url = g1.url.parse(window.location) | ||
url.search = $(this).serialize() | ||
$.getJSON(url.toString()).done(function(pred) { | ||
let tcol = $('#targetcol').val() | ||
$('#predResult').text(`${tcol}: ${pred[0][tcol]}`) | ||
}) | ||
}) | ||
$('#testform').submit(function(e) { | ||
e.preventDefault() | ||
let fd = new FormData(this) | ||
let testUrl = g1.url.parse(window.location) | ||
testUrl.hash = '' | ||
testUrl.search = '_action=predict' | ||
$.ajax({ | ||
url: testUrl.toString(), | ||
method: 'POST', | ||
data: fd, | ||
processData: false, | ||
contentType: false, | ||
success: function(resp) { | ||
$('#downloadbtn').show() | ||
$('#downloadbtn').attr('href', 'data:text/json;charset=utf-8,' + encodeURIComponent(resp)) | ||
$('#downloadbtn').attr('download', 'predictions.json') | ||
$('#downloadbtn').attr('_target', '_blank') | ||
} | ||
}) | ||
}) | ||
}) | ||
</script> |
Oops, something went wrong.