
Next speaker prediction for groups using graph representations and role encodings.


mitmedialab/graph-role-nsp


Role-Aware Graph-Based Next Speaker Prediction

This is the official repository for Role-Aware Graph-Based Next Speaker Prediction in Multi-party Human-Robot Interaction.

Overview


This repository contains the training code.

For the purposes of this repository, we assume that the SpeedDating dataset has been downloaded to ../SpeedDating/.
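Because the path is relative, a missing or misplaced dataset folder is an easy mistake. Below is a minimal pre-flight check you could run before training. It is a sketch under stated assumptions: check_dataset is a hypothetical helper, not part of this repository. It only verifies that the folder exists and is non-empty, and it does not validate the dataset's internal layout.

```python
from pathlib import Path

# Location assumed by this README, resolved relative to the current
# working directory (i.e. wherever the training scripts are launched from).
DEFAULT_DATA_DIR = Path("../SpeedDating")


def check_dataset(data_dir: Path = DEFAULT_DATA_DIR) -> Path:
    """Hypothetical helper: return the resolved dataset path or raise.

    Only checks that the directory exists and contains at least one entry;
    the expected file layout of the dataset is not checked.
    """
    path = Path(data_dir).resolve()
    if not path.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {path}")
    if not any(path.iterdir()):
        raise FileNotFoundError(f"Dataset directory is empty: {path}")
    return path
```

For example, calling check_dataset() from the repository root before invoking train.py fails early with a readable message instead of partway through data loading.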

This README is divided into the following sections: Clone, Set up Environment, and Training.

Clone

Clone only the master branch, then move into the repository directory:

git clone --branch master --single-branch https://github.com/mitmedialab/graph-role-nsp.git
cd graph-role-nsp

Set up Environment

  • Create an Anaconda or virtual environment and activate it.
  • Install the dependencies:
pip install -r requirements.txt

Training

To train a model from scratch, run one of the commands below. Currently, 3- and 4-person group settings are supported, selected with --group_num:

Next Speaker Prediction

#Ours
python train.py --task next_speaker --model_name Graph --group_num 3 --time_step 1 --role 1 --epochs 250 --init_seed 0 --cv_seed 0 

#Ours w/o Role Encoder 
python train.py --task next_speaker --model_name Graph --group_num 3 --time_step 1 --role 0 --epochs 250 --init_seed 0 --cv_seed 0 

#XGBoost
python train_xgboost.py --task next_speaker --group_num 3
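To reproduce the full next-speaker comparison, you would run the model with and without the role encoder, plus the XGBoost baseline, for each supported group size. Below is a minimal sketch that builds those command lines with the exact flags shown above. The helper names are hypothetical and not part of the repository. It also assumes the scripts are launched from the repository root with the active environment's Python. Dry-run mode only prints the commands.

```python
import subprocess
import sys


def next_speaker_commands(group_sizes=(3, 4), seed=0):
    """Hypothetical helper: build argv lists for the next-speaker runs.

    For each group size: the full model (--role 1), the ablation without the
    role encoder (--role 0), and the XGBoost baseline. The seed is used for
    both --init_seed and --cv_seed, matching the README's example (0 and 0).
    """
    commands = []
    for group_num in group_sizes:
        for role in (1, 0):
            commands.append([
                sys.executable, "train.py",
                "--task", "next_speaker",
                "--model_name", "Graph",
                "--group_num", str(group_num),
                "--time_step", "1",
                "--role", str(role),
                "--epochs", "250",
                "--init_seed", str(seed),
                "--cv_seed", str(seed),
            ])
        commands.append([
            sys.executable, "train_xgboost.py",
            "--task", "next_speaker",
            "--group_num", str(group_num),
        ])
    return commands


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True)
```

Calling run_all(next_speaker_commands()) lists the six runs. Passing dry_run=False executes them one after another.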

Next Speaker Identification

#Ours
python train.py --task identify_speaker --model_name Graph --group_num 3 --time_step 0 --role 1 --epochs 200 --init_seed 0 --cv_seed 0

#Ours w/o Role Encoder 
python train.py --task identify_speaker --model_name Graph --group_num 3 --time_step 0 --role 0 --epochs 200 --init_seed 0 --cv_seed 0 

#XGBoost
python train_xgboost.py --task identify_speaker --group_num 3
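The identification commands fix both seeds to 0. If you want to gauge run-to-run variance, you could repeat the same configuration over several initialization seeds. This sketch assumes, based only on the flag names, that --init_seed controls model initialization and --cv_seed controls the cross-validation split. The README does not document either flag. The helper identify_seed_sweep is hypothetical. It keeps --cv_seed fixed so that only initialization changes between runs.

```python
import sys

# Settings copied from the identification commands above.
BASE_ARGS = [
    "--task", "identify_speaker",
    "--model_name", "Graph",
    "--time_step", "0",
    "--epochs", "200",
]


def identify_seed_sweep(group_num=3, role=1, init_seeds=(0, 1, 2), cv_seed=0):
    """Hypothetical helper: one train.py argv per init seed, same CV split."""
    return [
        [sys.executable, "train.py", *BASE_ARGS,
         "--group_num", str(group_num),
         "--role", str(role),
         "--init_seed", str(seed),
         "--cv_seed", str(cv_seed)]
        for seed in init_seeds
    ]
```

Each returned list can be passed to subprocess.run(cmd, check=True). Use role=0 for the variant without the role encoder.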
