This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
teachers.py
2677 lines (2275 loc) · 96.5 KB
/
teachers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This module provides a set of teachers that deal with dialog.
``FixedDialogTeacher(Teacher)``
Base class for teachers in tasks that have fixed dialog - i.e., dialog
that is not dynamically generated but rather is pulled from set examples.
However, the class can be extended to all tasks involving fixed data.
Implements much of the basic functionality of these teachers, including
``observe()``, ``act()``, ``next_example()``
``DialogTeacher(FixedDialogTeacher)``
Base teacher class for doing dialog specifically with fixed chat logs.
``ParlAIDialogTeacher(DialogTeacher)``
Teacher class that provides access to data in the ParlAI Dialog format.
See the class description for more details.
``ConversationTeacher(DialogTeacher)``
Teacher class that provides access to data in the Conversations format.
See the class description for more details.
``FbDeprecatedDialogTeacher(DialogTeacher)``
Teacher class that provides access to data in the Facebook Dialog format.
See the class description for more details. **This class is deprecated**.
This module also includes ``DataLoader``, a threadpool data loader for
``FixedDialogTeacher``, and ``DialogData``/``StreamDialogData``, data
structures for accessing textual dialog data and utilized by ``DialogTeacher``
"""
from parlai.core.params import ParlaiParser
from parlai.core.agents import Agent, create_agent_from_shared
from parlai.core.image_featurizers import ImageLoader
from parlai.core.loader import load_teacher_module
from parlai.core.loader import register_teacher # noqa: F401
from parlai.core.message import Message
from parlai.core.metrics import TeacherMetrics, aggregate_named_reports
from parlai.core.opt import Opt
from parlai.utils.conversations import Conversations
from parlai.utils.data import DatatypeHelper
from parlai.utils.misc import AttrDict, str_to_msg, warn_once, SimpleCounter
from parlai.utils.distributed import get_rank, num_workers, is_distributed
import parlai.utils.torch as torch_utils
import parlai.utils.logging as logging
from parlai.utils.io import PathManager
from parlai.core.mutators import Mutator
from abc import ABC, abstractmethod
import argparse
from collections import defaultdict
import concurrent.futures
import copy
import json
import os
import queue
import random
import yaml
from threading import Thread
import torch
from typing import List, Tuple, Optional, TypeVar, Any
# Error raised when a teacher subclass forgets to set opt['datafile'] before
# calling super().__init__ (see DialogTeacher / DialogData, which require it).
# Fix: "This will passed" -> "This will be passed" (grammar in user-facing error).
ERROR_MESSAGE_NO_DATAFILE = (
    "{class_name} is expected to set self.opt['datafile'] inside `__init__` "
    "before calling `super().__init__`. This will be passed to setup_data, "
    "indicating what data to load. If you don't know what to use, set "
    "`opt['datafile'] = parlai.utils.data.DatatypeHelper.fold(opt['datatype'])` "
    "to receive the fold name in setup_data."
)

# Generic placeholder type for the output of chunk-based teachers.
ChunkOutput = TypeVar('ChunkOutput')
class DataLoader(Thread):
    """
    A worker thread that provides a threadpool for data loading.

    A teacher may submit a request to the loader, which will return the
    appropriate data.

    To submit a request, a teacher should call ``request_load``.
    """

    def __init__(self, opt):
        # daemon=True so a forgotten loader never blocks interpreter exit
        Thread.__init__(self, daemon=True)
        # size of the worker pool used to service load requests
        self.num_workers = opt.get('num_load_threads', 1)
        self.request_queue = queue.Queue()
        # most recently submitted future; useful for synchronization in tests
        self.last_future = None

    def request_load(self, receive_fn, load_fn, args):
        """
        Queue a request for loading.

        :param receive_fn:
            a receive function (for receiving the data)
        :param load_fn:
            a load function (for loading the data)
        :param args:
            arguments for the load function. args can be either a dictionary of
            arguments for a function, or a list of positional arguments
        """
        self.request_queue.put((receive_fn, load_fn, args))

    def run(self):
        """
        Run the execution loop.
        """
        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix=self.name
        )
        with executor:
            while True:
                receive_fn, load_fn, args = self.request_queue.get()
                # sentinel: a StopIteration receive_fn shuts the loop down
                if receive_fn is StopIteration:
                    return
                try:
                    # dict args are keyword arguments; anything else is
                    # treated as positional args.
                    # (was `type(args) == dict`; isinstance is the idiomatic
                    # and subclass-safe check)
                    if isinstance(args, dict):
                        future = executor.submit(load_fn, **args)
                    else:
                        future = executor.submit(load_fn, *args)
                    self.last_future = future
                    receive_fn(future)
                except RuntimeError:
                    # executor was shut down (e.g. interpreter teardown);
                    # exit quietly rather than crash the thread
                    return
class _ErrorThrowingDataLoader(object):
"""
A fake DataLoader which throws an exception when a work order is placed.
Since threads cannot be mixed with spawn_method='fork', we need to disallow users
from combining --num-workers with teachers that utilize threads. This placeholder
object is only useful for ensuring the user sees a loud error message when they
accidentally use a thread.
"""
def __init__(self, opt):
pass
def request_load(self, receive_fn, load_fn, args):
raise RuntimeError(
'One of your teachers uses a DataLoader or a thread. You may only '
'combine this with --num-workers 0.'
)
def start(self):
pass
class Teacher(Agent):
    """
    Basic Teacher agent that keeps track of how many times it's received messages.

    Teachers provide the ``report()`` method to get back metrics.
    """

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        """
        Add teacher-specific command-line arguments, including mutator args.

        :param parser: the argument parser to extend.
        :param partial_opt: partially-parsed options, used to discover which
            mutators were requested so their own args can be added too.
        """
        parser.add_argument(
            '--mutators',
            '-mut',
            default=None,
            help='Apply one or more mutators to the data.',
        )
        # FIX: partial_opt is declared Optional and defaults to None, but the
        # original called partial_opt.get(...) unguarded, raising
        # AttributeError whenever no partial opt was supplied. Guard it.
        mutators = Mutator.load_mutator_types(
            partial_opt.get('mutators') if partial_opt is not None else None
        )
        for m in mutators:
            m.add_cmdline_args(parser, partial_opt)
        return parser

    def __init__(self, opt: Opt, shared=None):
        # hasattr guards allow subclasses to pre-set these before calling up
        if not hasattr(self, 'opt'):
            self.opt = copy.deepcopy(opt)
        if not hasattr(self, 'id'):
            self.id = opt.get('task', 'teacher')
        if not hasattr(self, 'metrics'):
            self.metrics = TeacherMetrics(
                metrics_list=opt.get('metrics', 'default'),
                shared=shared['metrics'] if shared is not None else None,
            )
        self.epochDone = False

    # return state/action dict based upon passed state
    def act(self):
        """
        Act upon the previous observation.

        Returns a trivial greeting message when an observation with text was
        seen; otherwise implicitly returns None.
        """
        if self.observation is not None and 'text' in self.observation:
            t = Message({'text': 'Hello agent!'})
            return t

    def epoch_done(self):
        """
        Return whether the epoch is done.
        """
        return self.epochDone

    # Default unknown length
    def num_examples(self):
        """
        Return the number of examples (e.g. individual utterances) in the dataset.

        Default implementation returns `None`, indicating an unknown number.
        """
        return None

    def num_episodes(self):
        """
        Return the number of episodes (e.g. conversations) in the dataset.

        Default implementation returns `None`, indicating an unknown number.
        """
        return None

    def report(self):
        """
        Return metrics showing total examples and accuracy if available.
        """
        return self.metrics.report()

    def reset(self):
        """
        Reset the teacher.
        """
        super().reset()
        self.reset_metrics()
        self.epochDone = False

    def reset_metrics(self):
        """
        Reset metrics.
        """
        self.metrics.clear()

    def share(self):
        """
        In addition to default Agent shared parameters, share metrics.
        """
        shared = super().share()
        shared['metrics'] = self.metrics.share()
        return shared

    def __iter__(self):
        """
        Iterate through the examples of the teacher.
        """
        # iterate over a clone so the teacher's own state is untouched
        clone = self.clone()
        while True:
            message = clone.act()
            if not isinstance(message, Message):
                # backwards compatibility with older agents
                message = Message(message)
            if message.is_padding():
                break
            yield message
class FixedDialogTeacher(Teacher):
    """
    A teacher agent for all teachers involved in tasks with fixed data.

    This class provides the following functionality for its subclasses:

    - Resets a teacher
    - Provides an observe method
    - Computes and retrieves the next episode index for a teacher
    - Provides a threadpool option for loading data (especially useful for
      large data, e.g. images)

    In order to take advantage of the first few features, all a subclass has to
    implement is three functions: ``num_episodes``, ``num_examples``, and
    ``get`` (which returns a specific example from a specific episode).

    To utilize the DataLoader for threadpool loading, a teacher should
    implement the ``submit_load_request`` function to send a load request
    to the DataLoader by calling ``self.data_loader.request_load`` with the
    appropriate arguments (``receive_fn, load_fn, args``). The DataLoader then
    returns the data to the teacher's ``data_queue``, which the teacher can
    poll in its ``act`` method.

    The following is an example of the DataLoader usage in the VQA-V1 teacher.

    1. In the teacher's ``init`` function, the teacher calls its
       ``submit_load_request`` function to preload an image.
    2. The ``submit_load_request`` function gets the next ``episode_idx``,
       and computes the image path for the load request.
    3. At the end of ``submit_load_request``, the teacher calls
       ``self.data_loader.request_load`` with three args:

        - ``self.receive_data`` - the function that the DataLoader calls to
          return the loaded object
        - ``self.image_loader.load`` - the function used to load the image
          from the image path
        - ``[img_path]`` - a list of arguments for the load function, which
          in this case is the path of the image.

    4. In the teacher's ``act`` function, the teacher loads the data from
       its data queue.
    5. At the end of the ``act`` function, the teacher calls
       ``submit_load_request`` to preload an image for the next example.

    To see this in action, take a look at this teacher in ``tasks.vqa_v1.agents``.
    """

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        # hasattr guards let subclasses pre-set these before calling up
        if not hasattr(self, 'datatype'):
            self.datatype = opt['datatype']
        if not hasattr(self, 'random'):
            # only plain 'train' (not train:ordered etc.) samples randomly
            self.random = self.datatype == 'train'
        if not hasattr(self, 'training'):
            self.training = DatatypeHelper.is_training(self.datatype)
        if not hasattr(self, 'cycle'):
            self.cycle = DatatypeHelper.should_cycle(self.datatype)
        if not hasattr(self, 'datafile'):
            self.datafile = opt.get('datafile')

        # set up support for multithreaded data loading
        self.data_queue = queue.Queue()
        if shared:
            # copies share the episode index so they advance jointly
            self.index = shared['index']
            if 'data_loader' in shared:
                self.data_loader = shared['data_loader']
            if 'threadindex' in shared:
                self.threadindex = shared['threadindex']
            if 'examples' in shared:
                self.examples = shared['examples']
        else:
            # AttrDict so the index is a mutable box, shareable via share()
            self.index = AttrDict(value=-1)

        if not hasattr(self, 'data_loader'):
            # background_index set means we're in a --num-workers subprocess,
            # where real threads are disallowed; install the error thrower
            if opt.get('background_index') is None:
                self.data_loader = DataLoader(opt)
            else:
                self.data_loader = _ErrorThrowingDataLoader(opt)
            self.data_loader.start()

        # set up batching
        self.bsz = opt.get('batchsize', 1)

        if shared:
            self.mutators = shared.get('mutators', [])
        else:
            mutator_types = Mutator.load_mutator_types(self.opt.get('mutators'))
            self.mutators = [mutator(self.opt) for mutator in mutator_types]

        # True forces next_example() to begin a fresh episode
        self._episode_done = True

    def reset(self):
        """
        Reset the dialog to the start of the epoch, and reset all metrics.
        """
        super().reset()
        self.metrics.clear()
        self.lastY = None
        self.last_act = None
        self._episode_done = True
        self.epochDone = False
        # drop any preloaded-but-unconsumed data
        self.data_queue = queue.Queue()

        self.episode_idx = -1
        self.index.value = -1

    def submit_load_request(self):
        """
        Submit a load request.

        An agent should implement this method to submit requests to the data
        loader. At the end of this method, the agent should call
        ``self.data_loader.request_load()`` with the appropriate args.

        By default, this method does nothing.
        """
        # TODO: mark as abstract
        pass

    def receive_data(self, future: concurrent.futures.Future):
        """
        Receive data from the data loader.

        :param future: result from the load request.
        """
        data = future.result()
        self.data_queue.put(data)

    def share(self):
        """
        Share the data and dataloader.
        """
        shared = super().share()
        if hasattr(self, 'examples'):
            shared['examples'] = self.examples
        if hasattr(self, 'data_loader'):
            shared['data_loader'] = self.data_loader
        if hasattr(self, 'mutators'):
            shared['mutators'] = self.mutators
        # shared mutable index keeps clones iterating in lockstep
        shared['index'] = self.index
        return shared

    def next_episode_idx(self, num_eps=None, loop=None):
        """
        Return the next episode index.

        :param num_eps:
            default None uses ``num_episodes`` value.
        :param loop:
            default None loops during training but not evaluation.
        """
        if num_eps is None:
            num_eps = self.num_episodes()
        if loop is None:
            loop = self.training
        if self.random:
            new_idx = random.randrange(num_eps)
        else:
            # sequential: advance the (possibly shared) counter
            self.index.value += 1
            if loop:
                self.index.value %= num_eps
            new_idx = self.index.value
        return new_idx

    def next_example(self):
        """
        Return the next example.

        If there are multiple examples in the same episode, returns the next one in that
        episode. If that episode is over, gets a new episode index and returns the first
        example of that episode.
        """
        if self._episode_done:
            self.episode_idx = self.next_episode_idx()
            self.entry_idx = 0

            # past the end (non-looping eval): emit padding to signal done
            if self.episode_idx >= self.num_episodes():
                return Message.padding_example(), True

            # buffer the full conversation ahead of time for mutators
            episode_buffer = []
            buffer_entry_idx = 0
            while True:
                entry = self.get(self.episode_idx, buffer_entry_idx)
                if not isinstance(entry, Message):
                    assert isinstance(entry, dict)
                    typ = type(self)
                    warn_once(
                        f"{typ.__module__}.{typ.__name__}' is outputting dicts "
                        "instead of messages. If this is a teacher that is part of "
                        "ParlAI, please file an issue on GitHub. If it is your own "
                        "teacher, please return a Message object instead."
                    )
                    entry = Message(entry)
                episode_buffer.append(entry)
                if entry.get('episode_done'):
                    break
                buffer_entry_idx += 1
            # apply mutators
            if self.mutators:
                # copy first so mutators never modify the underlying data
                episode_buffer = [m.copy() for m in episode_buffer]
                for mutator in self.mutators:
                    episode_buffer = mutator(episode_buffer)
            # force realization in case a mutator returned a generator
            self.episode_buffer = list(episode_buffer)
            if not self.episode_buffer:
                # if we got back an empty episode after mutating, skip it
                return self.next_example()
        else:
            self.entry_idx += 1

        if self.episode_idx >= self.num_episodes():
            return Message.padding_example(), True

        # buffer the entire conversation so we can apply mutators
        ex = self.episode_buffer[self.entry_idx]
        self._episode_done = self.entry_idx == len(self.episode_buffer) - 1

        # epoch is done once the last episode any batch slot could reach
        # has been exhausted (and we're not cycling)
        if (
            not self.cycle
            and self._episode_done
            and self.episode_idx + self.opt.get("batchsize", 1) >= self.num_episodes()
        ):
            epoch_done = True
        else:
            epoch_done = False

        return ex, epoch_done

    def num_episodes(self) -> int:
        """
        Get the number of episodes in this dataset.
        """
        raise RuntimeError('"num_episodes" must be overridden by children.')

    def num_examples(self) -> int:
        """
        Get the total number of examples in this dataset.
        """
        raise RuntimeError('"num_examples" must be overridden by children.')

    def get(self, episode_idx, entry_idx=0):
        """
        Get the specified episode and the specified entry in that episode.

        Children must override this method in order to inherit the
        `next_example` method.

        :param episode_idx:
            which episode to return examples from
        :param entry_idx:
            which example to return from the episode. Many datasets have only
            single-entry episodes, so this defaults to zero.
        """
        # TODO: mark as abstract, get rid of runtime error.
        raise RuntimeError('"Get" method must be overridden by children.')

    def observe(self, observation):
        """
        Process observation for metrics.

        Evaluates the model's response against the last labels sent out, then
        attaches any freshly-accumulated metrics to the observation.
        """
        self.metrics.clear_recent()
        if hasattr(self, 'lastY') and self.lastY is not None:
            self.metrics.evaluate_response(observation, self.lastY)
            self.custom_evaluation(self.last_act, self.lastY, observation)
            # clear so the same labels are never scored twice
            self.lastY = None
        recent_metrics = self.metrics.report_recent()
        if recent_metrics:
            # for display purposes (display_model), take all accumulated
            # metrics back into the original observation. This is an abuse of
            # Messages being pointers
            if 'metrics' in observation:
                # override agent-level metrics if present
                observation.pop('metrics')
            observation['metrics'] = recent_metrics
        return observation

    def custom_evaluation(
        self,
        teacher_action: Message,
        labels: Optional[Tuple[str]],
        model_response: Message,
    ) -> None:
        """
        A method designated for hooking custom evaluations into teachers.

        Generally, a user will want to use `self.metrics.add` to record any
        specialized metrics that only make sense for this one dataset.

        :param teacher_action:
            The message last sent from this teacher.
        :param labels:
            The previous correct labels, if there were any.
        :param model_response:
            The raw response from the model. Generally you want to rely on the
            text field, but others may be necessary in specific situations.
        """
        pass

    def act(self):
        """
        Send new dialog message.
        """
        orig_action = self.get_orig_action()
        processed_action = self.process_action(orig_action)
        return processed_action

    def get_orig_action(self) -> Message:
        """
        Get the unprocessed action and reset if needed.

        This function will return the raw action from `self.next_example()`, before the
        `self.last_act` and `self.lastY` attributes have been defined based on this
        action for metrics or custom evaluations. This is so that wrapper teachers can
        modify the raw action first, such as to change the contents of its 'text' and
        'label' fields, without the action becoming out of sync with `self.last_act` and
        `self.lastY`.
        """
        if not hasattr(self, 'epochDone'):
            # reset if haven't yet
            self.reset()

        # get next example, action is episode_done dict if already out of exs
        action, self.epochDone = self.next_example()
        if not isinstance(action, Message):
            # TODO: all teachers should eventually create messages
            # while setting up the data, so this won't be necessary
            action = Message(action)

        return action

    def process_action(self, action: Message) -> Message:
        """
        Remember the raw action and prepare its fields for passing out of the teacher.
        """
        action.force_set('id', self.getID())

        # remember correct answer if available
        self.last_act = action
        self.lastY = action.get('labels', action.get('eval_labels', None))
        if not DatatypeHelper.is_training(self.datatype) and 'labels' in action:
            # move labels to eval field so not used for training
            # but this way the model can use the labels for perplexity or loss
            action = action.copy()
            labels = action.pop('labels')
            if not self.opt.get('hide_labels', False):
                action['eval_labels'] = labels

        return action
class DialogTeacher(FixedDialogTeacher):
    """
    A base teacher class for doing dialog with fixed chat logs.

    This class provides a set a basic functionality:

    - uses data class to store and query text data
    - generates action tables to send to the student agent from the data

    In order to subclass this class, you must implement ``setup_data()`` in
    your class, which reads your data file as an iterator.
    """

    def __init__(self, opt, shared=None):
        # Check for setup_data
        if not hasattr(self, 'setup_data'):
            raise RuntimeError(
                'Must implement setup_data or subclass a class '
                'which implements it (e.g. FbDeprecatedDialogTeacher) '
                'in order to use this class.'
            )
        super().__init__(opt, shared)

        self.datatype = opt['datatype']
        self.training = DatatypeHelper.is_training(self.datatype)
        self.cycle = DatatypeHelper.should_cycle(self.datatype)
        self.stream = 'stream' in self.datatype

        # first initialize any shared objects
        # streaming mode uses a lazily-iterating data class
        data_class = StreamDialogData if self.stream else DialogData
        kwargs = (
            # never cycle if "ordered" is in the datatype. this is used by
            # build_dict to enumerate through the data exactly once while still
            # marking examples as training examples.
            {'cycle': self.cycle}
            if self.stream
            else {}
        )
        if shared and shared.get('data'):
            # reuse the already-loaded data from another copy of this teacher
            self.data = data_class(opt, shared=shared['data'], **kwargs)
        else:
            if 'datafile' not in self.opt:
                raise KeyError(
                    ERROR_MESSAGE_NO_DATAFILE.format(class_name=self.__class__.__name__)
                )
            self.data = data_class(
                opt,
                data_loader=self.setup_data,
                cands=self.label_candidates(),
                **kwargs,
            )

        self.reset()

    @abstractmethod
    def setup_data(self, datafile: str):
        """
        The core method which the user should override.

        Yields the data, one message at a time, as well as markers indicating
        new episodes.

        :param str datafile:
            If the initializer set a 'datafile' field within the initialization,
            this will be provided here. Otherwise, datafile will be the fold:
            either "train", "valid", or "test".

        :return:
            Yields pairs (message, new_episode) containing a Message object
            and whether the message marks the beginning of a totally new
            episode.
        """
        pass

    def reset(self):
        """
        Reset the dialog to the start of the epoch, reset all metrics.
        """
        super().reset()
        if self.stream:
            # stream data keeps its own cursor; rewind it too
            self.data.reset()
            self.epochDone = False

    def share(self):
        """
        Share the data.
        """
        shared = super().share()
        if hasattr(self, 'data'):
            shared['data'] = self.data.share()
        return shared

    def label_candidates(self):
        """
        Provide consistent label candidates for all examples.

        Default implementation returns ``None`` always, but this may be overridden to
        provide candidates in all areas. See ``FbDialogueTeacher``.
        """
        # TODO DEPRECATIONDAY: FbDialogueTeacher is being deprecated, should we
        # remove this?

        # TODO: mark as optionally abstract?
        return None

    def num_episodes(self) -> int:
        """
        Return the number of episodes in the data.
        """
        try:
            return self.data.num_episodes()
        except AttributeError:
            # self.data not yet initialized; fall back to the base default
            return super().num_episodes()

    def num_examples(self) -> int:
        """
        Return the number of examples in the data.
        """
        if hasattr(self, '_num_examples_cache'):
            return self._num_examples_cache
        try:
            self._num_examples_cache: int = self.data.num_examples()
        except AttributeError:
            self._num_examples_cache = super().num_examples()
        return self._num_examples_cache

    def get(self, episode_idx, entry_idx=0):
        """
        Get a specific example.
        """
        return self.data.get(episode_idx, entry_idx)[0]

    def next_example(self):
        """
        Get the next example.
        """
        if self.stream:
            # unfortunately we need to also do the mutator buffering here.
            # it's difficult to structure it so it's not
            if hasattr(self, 'episode_buffer') and self.episode_buffer:
                action = self.episode_buffer.pop(0)
                epoch_done = (not self.episode_buffer) and self._saw_epoch_done
                return action, epoch_done
            # buffer a whole episode from the stream so mutators can see it
            episode_buffer = []
            while True:
                action, epoch_done = self.data.get()
                episode_buffer.append(action)
                if action['episode_done']:
                    # remember epoch state until the buffer is drained
                    self._saw_epoch_done = epoch_done
                    break
            # perform any mutations there are
            if self.mutators:
                episode_buffer = [m.copy() for m in episode_buffer]
                for mutator in self.mutators:
                    episode_buffer = mutator(episode_buffer)
            # make sure mutations are fully realized (not generators)
            self.episode_buffer = list(episode_buffer)
            # The recursive call has dual purpose:
            # - if we get back an empty episode after mutating, skip it gracefully
            # - pull the first item of the episode w/ epoch_done logic, but DRY
            return self.next_example()
        else:
            action, epoch_done = super().next_example()
            return action, epoch_done
class DialogData(object):
"""
Provides a data structure for accessing textual dialog data.
This can be used whenever the dialog data is a fixed log of chats
(i.e not a simulator setting). The logs can include dialog text and possibly
supervised labels, candidate labels and rewards.
All these are stored in this internal data format which is used by the
``DialogTeacher`` class.
:param opt:
options to initialize the class
:param data_loader:
an iterable with each call returning a tuple in the form
``((x, y, r, c, i), new_episode?)`` where the ``x`` and ``new_episode``
fields are mandatory and other fields may be omitted or ``None``.
:param cands:
can be set to provide a list of candidate labels for every example in
this dataset, which the agent can choose from (the correct answer
should be in this set).
:param random:
tells the data class whether or not to visit episodes sequentially or
randomly when returning examples to the caller.
The contents of the ``((x, y, r, c, i), new_episode?)`` tuples returned by
the data loader is the following:
- ``x`` (str) is a query and possibly context
- ``y`` (iter) is an iterable of label(s) for that query
- ``r`` (str) is the str reward for getting that query correct
- ``c`` (iter) is an iterable of label candidates that the student can choose from
- ``i`` (str) is a str path to an image on disk, which will be loaded by the
data class at request-time. should always point to the raw image file.
- ``new_episode?`` (bool) is a boolean value specifying whether that example
is the start of a new episode. If you don't use episodes set this
to ``True`` every time.
"""
    def __init__(self, opt, data_loader=None, cands=None, shared=None, **kwargs):
        # in case we need to shard the dataset
        self.rank = get_rank()
        self.num_workers = num_workers()
        # sharding is applied only for distributed *evaluation* folds
        # (valid/test/train:evalmode); see _load below
        self.is_distributed_and_is_eval = is_distributed() and any(
            x in opt['datatype'] for x in ('valid', 'test', 'train:evalmode')
        )

        # self.data is a list of episodes
        # each episode is a tuple of entries
        # each entry is a tuple of values for the action/observation table
        if shared:
            # reuse the already-loaded structures from another copy
            self.image_loader = shared.get('image_loader', None)
            self.data = shared.get('data', [])
            self.cands = shared.get('cands', None)
        else:
            self.image_loader = ImageLoader(opt)
            self.data = []

            if 'datafile' not in opt:
                raise KeyError(
                    ERROR_MESSAGE_NO_DATAFILE.format(class_name=self.__class__.__name__)
                )

            self._load(data_loader, opt['datafile'])
            self.cands = None if cands is None else set(c for c in cands)

        # candidates added on-the-fly for specific examples, tracked so they
        # can be removed again before the next example (see build_table)
        self.addedCands = []
        self.copied_cands = False
def share(self):
"""
Share the data.
"""
shared = {
'data': self.data,
'cands': self.cands,
'image_loader': self.image_loader,
}
return shared
def _read_episode(self, data_loader):
"""
Read one episode at a time from the provided iterable over entries.
:param data_loader:
an iterable which returns tuples in the format described in the
class docstring.
"""
episode = []
for entry, new in data_loader:
if new and len(episode) > 0:
yield episode
episode = []
episode.append(entry)
if len(episode) > 0:
yield episode
def _load(self, data_loader, datafile):
"""
Load up data from an iterable over tuples described in the class docs.
:param iter data_loader:
an iterator which returns tuples in the format described in the
class docstring.
:param str datafile:
"""
for i, episode in enumerate(self._read_episode(data_loader(datafile))):
if not self.is_distributed_and_is_eval or i % self.num_workers == self.rank:
self.data.append(episode)
    def num_episodes(self):
        """
        Return number of episodes in the dataset.
        """
        # episodes are loaded eagerly into self.data by _load, so len is exact
        return len(self.data)
def num_examples(self):
"""
Return total number of entries available.
Each episode has at least one entry, but might have many more.
"""
if hasattr(self, '_num_examples_cache'):
return self._num_examples_cache
self._num_examples_cache = sum(len(episode) for episode in self.data)
return self._num_examples_cache
    def get(self, episode_idx, entry_idx=0):
        """
        Get the specified episode and the specified entry in that episode.

        :param episode_idx:
            which episode to return examples from
        :param entry_idx:
            which example to return from the episode. Many datasets have only
            single-entry episodes, so this defaults to zero.
        """
        if episode_idx >= len(self.data):
            # out of range (can happen with sharded eval data); signal the
            # caller with a padding example and end-of-data flag
            return Message.padding_example(), True
        next_episode_idx_for_rank = episode_idx + 1
        # first look up data
        episode = self.data[episode_idx]
        entry = episode[entry_idx]
        episode_done = entry_idx == len(episode) - 1

        end_of_data = episode_done and next_episode_idx_for_rank >= len(self.data)

        # now pack it in a action-observation dictionary
        table = self.build_table(entry)

        # last entry in this episode
        table['episode_done'] = episode_done
        return table, end_of_data
def build_table(self, entry):
"""
Packs an entry into an action-observation dictionary.
:param entry: a tuple in the form described in the class docstring.
"""
if isinstance(entry, (dict, Message)):
# user is already provided things
if 'eval_labels' in entry or 'eval_label' in entry:
raise KeyError(
'Labels are converted to eval_labels automatically. Please do not '
'set them in setup_data.'
)
if 'episode_done' in entry:
raise KeyError(
"episode_done is set automatically for you. Please don't set it "
"in setup_data."
)
if 'label' in entry:
# for convenience, rename to the labels convention automatically
label = entry.pop('label')
assert isinstance(label, str)
entry['labels'] = (label,)
if 'labels' in entry and isinstance(entry['labels'], str):
entry['labels'] = (entry['labels'],)
table = entry.copy()
elif isinstance(entry, (Tuple, List)):
table = {}
if entry[0] is not None:
table['text'] = entry[0]
if len(entry) > 1 and entry[1] is not None:
l = entry[1]
if isinstance(l, str):
l = (l,)
table['labels'] = l
if len(entry) > 2 and entry[2] is not None:
table['reward'] = entry[2]
if len(entry) > 3 and entry[3] is not None:
table['label_candidates'] = entry[3]
if len(entry) > 4 and entry[4] is not None:
img = self.image_loader.load(entry[4])
if img is not None:
table['image'] = img
else:
raise TypeError(
f"items out of setup_data should be dict, Message, list, or tuple. "
f"Got {type(entry)})"
)
if table.get('labels', None) is not None and self.cands is not None:
if self.addedCands:
# remove elements in addedCands
self.cands.difference_update(self.addedCands)
self.addedCands.clear()
for label in table['labels']:
if label not in self.cands:
# add labels, queue them for removal next time
if not self.copied_cands: