Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Pull out settings for model chat onboarding (#4274)
Browse files Browse the repository at this point in the history
* Layer changes

* Documentation
  • Loading branch information
EricMichaelSmith committed Jan 4, 2022
1 parent 22fd071 commit e686a02
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 60 deletions.
4 changes: 3 additions & 1 deletion parlai/crowdsourcing/tasks/model_chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ The following flags can be passed in to specify filepaths for overriding the tex

## Onboarding

In `worlds.py`, modify `ModelChatOnboardWorld.check_onboarding_answers()` to change the worker selection criteria.
Set the `"min_correct"`, `"max_incorrect"`, and `"max_failures_allowed"` fields in the JSON file passed to `mephisto.blueprint.onboard_task_data_path` in order to specify how many onboarding questions workers can pass/fail on while still passing onboarding, as well as how many times they are allowed to re-take the onboarding before being soft-blocked. (See `task_config/onboard_task_data.json` for an example.)

You can further modify the worker selection criteria in `handleOnboardingSubmit` in `frontend/components/onboarding_components.jsx`.

## Human+model image chat

Expand Down
2 changes: 0 additions & 2 deletions parlai/crowdsourcing/tasks/model_chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,3 @@

ONBOARD_FAIL = '[ONBOARD_FAIL]'
ONBOARD_SUCCESS = '[ONBOARD_SUCCESS]'

ONBOARD_CONFIG = {'min_correct': 4, 'max_incorrect': 3}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import React from "react";
import { ErrorBoundary } from './error_boundary.jsx';
import { Checkboxes } from './checkboxes.jsx';
const ONBOARDING_MIN_CORRECT = 4;
const ONBOARDING_MAX_INCORRECT = 3;
const ONBOARDING_MAX_FAILURES_ALLOWED = 1;
const DEFAULT_MIN_CORRECT = 4;
const DEFAULT_MAX_INCORRECT = 3;
const DEFAULT_MAX_FAILURES_ALLOWED = 1;
var onboardingFailuresCount = 0;

var renderOnboardingFail = function () {
Expand All @@ -35,7 +35,6 @@ function arraysEqual(_arr1, _arr2) {
}

var handleOnboardingSubmit = function ({ onboardingData, currentTurnAnnotations, onSubmit }) {
// OVERRIDE: Re-implement this to change onboarding success criteria
console.log('handleOnboardingSubmit');
var countCorrect = 0;
var countIncorrect = 0;
Expand All @@ -60,10 +59,13 @@ var handleOnboardingSubmit = function ({ onboardingData, currentTurnAnnotations,
}
}
console.log('correct: ' + countCorrect + ', incorrect: ' + countIncorrect);
if (countCorrect >= ONBOARDING_MIN_CORRECT && countIncorrect <= ONBOARDING_MAX_INCORRECT) {
const min_correct = onboardingData.hasOwnProperty("min_correct") ? onboardingData.min_correct : DEFAULT_MIN_CORRECT;
const max_incorrect = onboardingData.hasOwnProperty("max_incorrect") ? onboardingData.max_incorrect : DEFAULT_MAX_INCORRECT;
const max_failures_allowed = onboardingData.hasOwnProperty("max_failures_allowed") ? onboardingData.max_failures_allowed : DEFAULT_MAX_FAILURES_ALLOWED;
if (countCorrect >= min_correct && countIncorrect <= max_incorrect) {
onSubmit({ annotations: currentTurnAnnotations, success: true });
} else {
if (onboardingFailuresCount < ONBOARDING_MAX_FAILURES_ALLOWED) {
if (onboardingFailuresCount < max_failures_allowed) {
onboardingFailuresCount += 1;
alert('You did not label the sample conversation well enough. Please try one more time!');
} else {
Expand Down Expand Up @@ -177,7 +179,7 @@ function OnboardingComponent({ onboardingData, annotationBuckets, annotationQues
className="button is-link btn-lg"
onClick={() => handleOnboardingSubmit({
onboardingData,
currentTurnAnnotations,
currentTurnAnnotations,
onSubmit,
})}
>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,8 @@
]
}
]
]
],
"min_correct": 4,
"max_incorrect": 3,
"max_failures_allowed": 1
}
57 changes: 8 additions & 49 deletions parlai/crowdsourcing/tasks/model_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from parlai.crowdsourcing.utils.worlds import CrowdOnboardWorld, CrowdTaskWorld
from parlai.crowdsourcing.tasks.model_chat.bot_agent import TurkLikeAgent
from parlai.crowdsourcing.tasks.model_chat.constants import (
ONBOARD_CONFIG,
ONBOARD_FAIL,
ONBOARD_SUCCESS,
)
Expand Down Expand Up @@ -50,58 +49,14 @@ def __init__(self, opt, agent: "MephistoAgentWrapper"):

self.skip_onboarding = opt['skip_onboarding']

self.min_correct = ONBOARD_CONFIG['min_correct']
self.max_incorrect = ONBOARD_CONFIG['max_incorrect']
self.onboard_task_data = opt['onboard_task_data']
self.status = 'DISCONNECT'
self.onboard_statistics = opt['onboard_statistics']
self.statistics_condition = opt['statistics_condition']
self.max_onboard_time = opt['max_onboard_time']
self.onboarding_qualification = opt['onboarding_qualification']
self.worker_id = get_mturk_id_from_mephisto_wrapper(self.agent)

def has_same_answer(self, ans1, ans2):
if len(ans1) != len(ans2):
return False

ans1_sort = sorted(ans1)
ans2_sort = sorted(ans2)

for x in range(len(ans1_sort)):
if ans1_sort[x] != ans2_sort[x]:
return False
return True

def check_onboarding_answers(self, worker_answers) -> bool:
"""
Calculate how many correct answers the user gave.
`worker_answers` is a list of dicts containing mappings between an annotation
value and whether it was selected for each bucket. We return a boolean as to
whether the worker passed or failed the task.
"""
given_turns = self.onboard_task_data['dialog']
correct_answers = [t[1]['answers'] for t in given_turns]
number_correct = 0
number_incorrect = 0
for worker_answer, correct_answer in zip(worker_answers, correct_answers):
worker_only_selected = [
key for key, selected in worker_answer.items() if selected
]
if self.has_same_answer(worker_only_selected, correct_answer):
number_correct += 1
else:
number_incorrect += 1

print(
f'Worker {self.worker_id} got {number_correct} annotations correct and {number_incorrect} incorrect in onboarding.'
)
if (
number_correct >= self.min_correct
and number_incorrect <= self.max_incorrect
):
return True
return False
self.annotations = None

def parley(self):

Expand Down Expand Up @@ -140,10 +95,11 @@ def _handle_act(self, act):
print(f'{self.__class__.__name__}: {self.worker_id} had no data submitted')
return ONBOARD_FAIL

worker_answers = act['task_data']['annotations']
self.annotations = act['task_data'].get('annotations')
print('Onboarding annotation results: ', self.annotations)

if self.check_onboarding_answers(worker_answers):
print(f'Worker {self.worker_id} successfully passed the onboard task.')
if act['task_data']['success']:
print(f'Worker {self.worker_id} successfully passed the onboarding task.')

# This will end the onboarding and send them directly to the HIT
self.episodeDone = True
Expand All @@ -164,6 +120,9 @@ def shutdown(self):
self.onboard_statistics[self.status] = 0
self.onboard_statistics[self.status] += 1

def get_custom_task_data(self):
return self.annotations


class BaseModelChatWorld(CrowdTaskWorld, ABC):
def __init__(self, opt, agent, bot):
Expand Down

0 comments on commit e686a02

Please sign in to comment.