Skip to content

Commit

Permalink
ENH: Labelstudio support for MLHandler template
Browse files Browse the repository at this point in the history
The MLHandler template now has a UI for sentiment analysis with
transformers. A labelstudio interface is added for fine-tuning the model
with live annotations
  • Loading branch information
jaidevd committed Jun 6, 2022
1 parent e2ad8d7 commit c7c7b50
Show file tree
Hide file tree
Showing 6 changed files with 541 additions and 329 deletions.
317 changes: 317 additions & 0 deletions gramex/apps/mlhandler/sklearn.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
{% set base = '.' %}
{% set columns = data.columns.tolist() %}
{% set CLASSIFICTION_MODELS = [
'LogisticRegression',
'BernoulliNB',
'Perceptron',
'PassiveAggressiveClassifier',
'SVC',
'NuSVC',
'LinearSVC',
'KNeighborsClassifier',
'GaussianNB',
'DecisionTreeClassifier',
'RandomForestClassifier',
'MLPClassifier'] %}
{% set REGRESSION_MODELS = [
'LinearRegression',
'PassiveAggressiveRegressor',
'SVR',
'NuSVR',
'LinearSVR',
'KNeighborsRegressor',
'DecisionTreeRegressor',
'RandomForestRegressor',
'MLPRegressor'] %}
{% set tcol = handler.store.load('target_col', False) %}
{% set CLASSIFICTION_METRICS = {
'Accuracy': 'accuracy',
'Balanced Accuracy': 'balanced_accuracy',
'ROC AUC': 'roc_auc',
'F1 Score': 'f1_weighted'
}%}
{% set REGRESSION_METRICS = {
'R2': 'r2',
'Explained Variance': 'explained_variance',
'Max Error': 'max_error',
'Negative Mean Absolute Error': 'neg_mean_absolute_error',
'Negative Mean Squared Error': 'neg_mean_squared_error',
'Negative Root Mean Squared Error': 'neg_root_mean_squared_error'
}%}
<div class="row pb-3 pt-3">
<div class="col-8 border-right border-dark">
<h3 class="text-center">Train the Model</h3>
<form class="form" id="train">
<div class="container">
<div class="row pt-3 pb-3">
<div class="formhandler overflow-auto"></div>
</div>
<div class="row">
<div class="col">
<label for="targetcol">Pick a Target Column:</label>
<select id="targetcol" class="form-control" name="target_col">
{% for col in columns %}
{% set selected = "selected" if col == tcol else "" %}
<option value="{{ col }}" {{ selected }}>{{ col }}</option>
{% end %}
</select>
</div>
<div class="col">
<label for="exclude">Columns to Exclude:</label>
<select id="exclude" class="selectpicker form-control" multiple name="exclude">
{% for col in columns %}
{% set selected = "selected" if col in handler.store.load('exclude', []) else "" %}
<option value="{{ col }}" {{ selected }}>{{ col }}</option>
{% end %}
</select>
</div>
</div>
<div class="row pb-3 pt-3">
<div class="col">
<label for="cats">Categorical Columns:</label>
<select id="cats" class="selectpicker form-control" multiple name="cats">
{% for col in columns %}
{% set selected = "selected" if col in handler.store.load('cats', []) else "" %}
<option value="{{ col }}" {{ selected }}>{{ col }}</option>
{% end %}
</select>
</div>
<div class="col">
<label for="nums">Numerical Columns:</label>
<select id="nums" class="selectpicker form-control" multiple name="nums">
{% for col in columns %}
{% set selected = "selected" if col in handler.store.load('nums', []) else "" %}
<option value="{{ col }}" {{ selected }}>{{ col }}</option>
{% end %}
</select>
</div>
</div>
<div class="row pb-3 pt-3">
<div class="col">
<label for="transform">Transform:</label>
<input class="form-control" id="transform" name="data.transform" type="text"
value="{{ handler.store.load('built_transform', '') }}">
</div>
<div class="col">
<label for="metric">Choose a Metric:</label>
<select id="metric" class="form-control selectpicker" name="metric">
{% if handler.store.load('class') in CLASSIFICTION_MODELS %}
{% for i, (mname, metric) in enumerate(CLASSIFICTION_METRICS.items()) %}
{% set selected = "selected" if metric == "accuracy" else "" %}
<option value="{{ metric }}" {{ selected }}>{{ mname }}</option>
{% if i == 0 %}
<option data-divider="true"></option>
{% end %}
{% end %}
{% else %}
{% for i, (mname, metric) in enumerate(REGRESSION_METRICS.items()) %}
{% set selected = "selected" if metric == "r2" else "" %}
<option value="{{ metric }}" {{ selected }}>{{ mname }}</option>
{% if i == 0 %}
<option data-divider="true"></option>
{% end %}
{% end %}
{% end %}
</select>
</div>
<div class="col">
<label for="modelchoice">Pick a Model:</label>
<select id="modelchoice" class="form-control" name="class">
<optgroup label="Classification">
{% for model in CLASSIFICTION_MODELS %}
{% set selected = "selected" if model == handler.store.load('class') else "" %}
<option value="{{ model }}" class="text-dark" {{ selected }}>{{ model }}</option>
{% end %}
</optgroup>
<optgroup label="Regression">
{% for model in REGRESSION_MODELS %}
{% set selected = "selected" if model == handler.store.load('class') else "" %}
<option value="{{ model }}" {{ selected }}>{{ model }}</option>
{% end %}
</optgroup>
</select>
</div>
</div>
<div class="text-right">
<button class="btn btn-primary" type="submit">Train</button>
</div>
</div>
</form>
<div class="text-center divider">Results</div>
<div class="container" id="resultcnt">
<div class="row">
<div class="col">
<div class="text-center">
<strong>Your model scored</strong>
</div>
<div class="text-center">
<svg width="20%" height="20%" viewBox="0 0 40 40" class="donut">
<circle class="donut-hole" cx="20" cy="20" r="15.91549430918954" fill="#fff"></circle>
<circle class="donut-ring" cx="20" cy="20" r="15.91549430918954" fill="transparent" stroke-width="3.5"></circle>
<circle class="donut-segment" cx="20" cy="20" r="15.91549430918954" fill="transparent" stroke-width="3.5" stroke-dasharray="10 90" stroke-dashoffset="25"></circle>
<g class="donut-text">
<text y="50%" transform="translate(0, 2)">
<tspan x="50%" text-anchor="middle" class="donut-percent">40%</tspan>
</text>
</g>
</svg>
</div>
</div>
<div class="col">
<div class="row">
<form id="testform" enctype="multipart/form-data">
<div class="row pb-3">
<label for="testurl">
Happy with the result? <a id="downlink">Download the model.</a> Or get predictions:
</label>
<input name="file" type="file" id="testurl" class="form-control">
</div>
<div class="text-right">
<button id="testbtn" class="btn btn-primary" type="submit">Predict</button>
</div>
</form>
</div>
<div class="row">
<a class="ml-auto" id="downloadbtn">Download Predictions</a>
</div>
</div>
</div>
</div>
</div>
<div class="col-4">
<h3 class="text-center">Make Predictions</h3>
<div class="row py-2">
<div class="col">
<button type="submit" form="predictform" class="btn btn-primary">Predict</button>
</div>
<div class="col">
<h4 id="predResult"></h4>
</div>
</div>
<form id="predictform" class="overflow-auto">
<template id="predicttabtemplate">
<div class="container">
<% COLS.forEach(function(col) { %>
<div class="form-group row">
<label for="<%= col.name %>" class="col-md-6"><%= col.name %></label>
<input class="form-control col-md-6" type="<%= col.type %>" name="<%= col.name %>" value="<%=row[col.name]%>">
</div>
<% }) %>
</div>
</template>
</form>
</div>
</div>
<script>
/* eslint-env browser */
/* globals $, g1 */
$.fn.selectpicker.Constructor.BootstrapVersion = '4'
var fh_meta = null

const get_score_color = function(s) {
let color = '#ff0000'
if (s > 50) { color = '#f7b100'}
if (s > 90) { color = '#00f700' }
return color
}
const get_score = function() {
let url = g1.url.parse(window.location)
$.ajax({
url: url + '?_action=score&_metric=' + encodeURIComponent($('#metric').val()),
method: 'POST',
success: function(resp) {
let score = Number.parseFloat(JSON.parse(resp).score * 100).toPrecision(4)
score = Number.parseFloat(score)
$('.donut-segment').attr('stroke-dasharray', `${score} ${100 - score}`)
$('tspan').html(`${score}%`)
$('.donut-segment').attr('stroke', get_score_color(score))
$('#resultcnt').show()
let inputcols = fh_meta.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name))
$('#predicttabtemplate').template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: fh_meta.formdata[0]})
}
})
}
const post_train = function (target_col) {
$('#resultcnt').hide()
let url = g1.url.parse(window.location)
url.hash = ''
$.ajax({
url: url + '?_action=retrain&target_col=' + encodeURIComponent(target_col) + '&_metric=' + encodeURIComponent($('#metric').val()),
method: 'POST',
success: get_score
})
}
$(document).ready(function() {
$('#downloadbtn').hide()
$('#resultcnt').hide()
let url = g1.url.parse(window.location)
url.search = '_cache'
$('.formhandler').attr('data-src', url.toString())
$('.formhandler').on('load', function(obj) {
fh_meta = obj
let inputcols = obj.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name))
$('#predicttabtemplate').template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: obj.formdata[0]})
}).formhandler({
pageSize: 5,
export: false
})
url.search = '_cache&_opts'
$.getJSON(url.toString()).done(function (e) {
// TODO: Jaidev, please check if & why the next line is required. Anand
let opts = e // eslint-disable-line no-unused-vars
})

url.search = '_model&_download'
$('#downlink').attr('href', url.toString())

// Select
$('.selectpicker').selectpicker({
actionsBox: true
})

$('#train').submit(function(e) {
e.preventDefault()
let fd = new FormData(this)
let trainUrl = g1.url.parse(window.location + '?_model')
trainUrl.update({class: fd.get('class')})
trainUrl.update({exclude: fd.getAll('exclude')})
trainUrl.update({cats: fd.getAll('cats')})
$.ajax({
url: trainUrl.toString(),
method: 'PUT',
success: function() {
post_train(fd.get('target_col'))
}
})
})
$('#predictform').submit(function(e) {
e.preventDefault()
url = g1.url.parse(window.location)
url.search = $(this).serialize()
$.getJSON(url.toString()).done(function(pred) {
let tcol = $('#targetcol').val()
$('#predResult').text(`${tcol}: ${pred[0][tcol]}`)
})
})
$('#testform').submit(function(e) {
e.preventDefault()
let fd = new FormData(this)
let testUrl = g1.url.parse(window.location)
testUrl.hash = ''
testUrl.search = '_action=predict'
$.ajax({
url: testUrl.toString(),
method: 'POST',
data: fd,
processData: false,
contentType: false,
success: function(resp) {
$('#downloadbtn').show()
$('#downloadbtn').attr('href', 'data:text/json;charset=utf-8,' + encodeURIComponent(resp))
$('#downloadbtn').attr('download', 'predictions.json')
$('#downloadbtn').attr('_target', '_blank')
}
})
})
})
</script>

0 comments on commit c7c7b50

Please sign in to comment.