label_types.py
"""Interface and implementations of label types for a reward model."""
from abc import ABC, abstractmethod
from typing import Optional, Dict
import tensorflow as tf
from lm_human_preferences.utils.core import Schema, pearson_r


class LabelType(ABC):
    @abstractmethod
    def label_schemas(self) -> Dict[str, Schema]:
        """Schema for the human annotations."""

    @abstractmethod
    def target_scales(self, labels: Dict[str, tf.Tensor]) -> Optional[tf.Tensor]:
        """Extracts scalars out of labels whose scale corresponds to the reward model's output.
        May be None if the labels have no such information."""

    @abstractmethod
    def loss(self, reward_model, labels: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
        """
        :param labels: the questions with their labels
        :returns: a dict of stats, including 'loss' for the actual loss
        """

    @abstractmethod
    def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
        """Schema for the questions associated with this LabelType."""


class PickBest(LabelType):
    """Pick the best response among N."""

    def __init__(self, num_responses):
        self.num_responses = num_responses

    def label_schemas(self):
        return dict(best=Schema(tf.int32, ()))

    def target_scales(self, labels):
        return None

    def loss(self, reward_model, labels):
        # One scalar reward per (query, sample) pair, stacked into
        # [batch, num_responses] logits.
        logits = tf.stack([reward_model(labels['query'], labels[f'sample{i}'])
                           for i in range(self.num_responses)], axis=1)
        # Cross-entropy against the human's pick pushes the chosen response's
        # reward above the alternatives.
        error = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels['best'], logits=logits))
        return dict(loss=error, error=error)

    def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
        return dict(
            query=Schema(tf.int32, (query_length,)),
            **{f"sample{i}": Schema(tf.int32, (response_length,)) for i in range(self.num_responses)}
        )
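
# Shape note: with num_responses=4 and batch size B, `logits` above is
# [B, 4] and labels['best'] holds indices in [0, 4); `get('best_of_4')`
# below constructs PickBest(4).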


class ScalarRating(LabelType):
    """Rate a single response with a scalar score."""

    def __init__(self):
        pass

    def label_schemas(self):
        return dict(
            score=Schema(tf.float32, ()))

    def target_scales(self, labels):
        return labels['score']

    def loss(self, reward_model, labels):
        predicted = reward_model(labels['query'], labels['sample'])
        scores = labels['score']
        # Mean squared error between human scores and predicted rewards.
        error = tf.reduce_mean((scores - predicted) ** 2, axis=0)
        # Label statistics and the Pearson correlation between scores and
        # predictions are reported alongside the loss as diagnostics.
        label_mean, label_var = tf.nn.moments(scores, axes=[0])
        corr = pearson_r(scores, predicted)
        return dict(loss=error, error=error,
                    label_mean=label_mean, label_var=label_var, corr=corr)

    def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
        return dict(
            query=Schema(tf.int32, (query_length,)),
            sample=Schema(tf.int32, (response_length,)),
        )
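
# Because the human scores are on the same scale as the reward model's
# output, target_scales returns them directly; corr (Pearson r) gives a
# scale-free check on fit.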


class ScalarComparison(LabelType):
    """Give a scalar indicating the difference between two responses."""

    def label_schemas(self):
        return dict(difference=Schema(tf.float32, ()))

    def target_scales(self, labels):
        # Divide by two to get something with the same variance as the trained reward model output
        return labels['difference'] / 2

    def loss(self, reward_model, labels):
        outputs0 = reward_model(labels['query'], labels['sample0'])
        outputs1 = reward_model(labels['query'], labels['sample1'])
        differences = labels['difference']
        # The model's implied difference is reward(sample1) - reward(sample0).
        predicted_differences = outputs1 - outputs0
        error = tf.reduce_mean((differences - predicted_differences) ** 2, axis=0)
        return dict(loss=error, error=error)

    def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
        return dict(
            query=Schema(tf.int32, (query_length,)),
            sample0=Schema(tf.int32, (response_length,)),
            sample1=Schema(tf.int32, (response_length,)),
        )
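
# Sign convention implied by the loss: a positive labels['difference'] means
# the labeler rated sample1 above sample0.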


def get(label_type: str) -> LabelType:
    """Map a label type string (e.g. 'scalar_rating', 'best_of_4') to a LabelType."""
    if label_type == 'scalar_rating':
        return ScalarRating()
    if label_type == 'scalar_compare':
        return ScalarComparison()
    if label_type.startswith('best_of_'):
        n = int(label_type[len('best_of_'):])
        return PickBest(n)
    raise ValueError(f"Unexpected label type {label_type}")