Permalink
Browse files

add temperature in kd doc classification (#199)

Summary:
Pull Request resolved: #199

Reference implementation from "Distilling the Knowledge in a Neural Network" (Hinton et al.): https://arxiv.org/pdf/1503.02531.pdf
I'm actually not sure if we need to re-train the teacher using the same temperature. Quoting "In the simplest form of distillation, knowledge is transferred to the distilled model by training it on a transfer set and using a soft target distribution for each case in the transfer set that is produced by using the cumbersome model with a high temperature in its softmax. The same high temperature is used when training the distilled model, but after it has been trained it uses a temperature of 1."

Reviewed By: gardenia22

Differential Revision: D13605113

fbshipit-source-id: 04dfb2c857db4d41f1b039df4ecfcb9399926683
  • Loading branch information...
Haoran Li authored and facebook-github-bot committed Jan 11, 2019
1 parent ec4b851 commit 5d0e6af0092b7cb02cbc873e9733e47729b1e982
@@ -33,6 +33,7 @@ class ModelInput:

class Target:
    # Field-name constants for the target columns handled by the KD
    # doc-classification data handler.
    DOC_LABEL = "doc_label"  # gold (hard) document label
    # Raw teacher logits — added by this commit so the distillation losses
    # can re-soften them with a temperature instead of using fixed probs.
    TARGET_LOGITS_FIELD = "target_logit"
    # Teacher soft-target scores; values in the fixture look like
    # log-probabilities — TODO confirm against the teacher export format.
    TARGET_PROB_FIELD = "target_prob"
    # Label names aligned positionally with the prob/logit arrays.
    TARGET_LABEL_FIELD = "target_label"

@@ -35,6 +35,7 @@ class RawData:
TEXT = "text"
DICT_FEAT = "dict_feat"
TARGET_PROBS = "target_probs"
TARGET_LOGITS = "target_logits"
TARGET_LABELS = "target_labels"


@@ -97,6 +98,7 @@ def from_config(
extra_fields: Dict[str, Field] = {ExtraField.RAW_TEXT: RawField()}
if target_config.target_prob:
target_fields[Target.TARGET_PROB_FIELD] = RawField()
target_fields[Target.TARGET_LOGITS_FIELD] = RawField()

if target_config.target_prob:
extra_fields[Target.TARGET_LABEL_FIELD] = RawField()
@@ -138,6 +140,9 @@ def preprocess_row(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
res[Target.TARGET_LABEL_FIELD] = parse_json_array(
row_data[RawData.TARGET_LABELS]
)
res[Target.TARGET_LOGITS_FIELD] = parse_json_array(
row_data[RawData.TARGET_LOGITS]
)

return res

@@ -149,7 +154,7 @@ def init_target_metadata(
):
# build vocabs for label fields
for name, label in self.labels.items():
if name in [Target.TARGET_PROB_FIELD]:
if name in [Target.TARGET_PROB_FIELD, Target.TARGET_LOGITS_FIELD]:
continue
# Need test data to make sure we cover all of the labels in it
# It is particularly important when BIO is enabled as a B-[Label] can
@@ -167,7 +172,7 @@ def init_target_metadata(
self.metadata.target = [
field.get_meta()
for name, field in self.labels.items()
if name not in [Target.TARGET_PROB_FIELD]
if name not in [Target.TARGET_PROB_FIELD, Target.TARGET_LOGITS_FIELD]
]
if len(self.metadata.target) == 1:
self.metadata.target = self.metadata.target[0]
@@ -202,7 +207,7 @@ def _target_from_batch(self, batch):
targets = []
for name in self.labels:
target = getattr(batch, name)
if name in [Target.TARGET_PROB_FIELD]:
if name in [Target.TARGET_PROB_FIELD, Target.TARGET_LOGITS_FIELD]:
target = self._align_target_label(target, label_list, batch_label_list)
targets.append(target)
if len(targets) == 1:
@@ -19,7 +19,13 @@ def setUp(self):
file_name = tests_module.test_file("knowledge_distillation_test_tiny.tsv")
label_config_dict = {"target_prob": True}
data_handler_dict = {
"columns_to_read": ["text", "target_probs", "target_labels", "doc_label"]
"columns_to_read": [
"text",
"target_probs",
"target_logits",
"target_labels",
"doc_label",
]
}
self.data_handler = KDDocClassificationDataHandler.from_config(
KDDocClassificationDataHandler.Config(**data_handler_dict),
@@ -37,21 +43,22 @@ def test_create_from_config(self):
expected_columns = [
RawData.TEXT,
RawData.TARGET_PROBS,
RawData.TARGET_LOGITS,
RawData.TARGET_LABELS,
RawData.DOC_LABEL,
]
# check that the list of columns is as expected
self.assertTrue(self.data_handler.raw_columns == expected_columns)

def test_read_from_file(self):
# Check if the data has 10 rows and 4 columns
# Check if the data has 10 rows and 5 columns
self.assertEqual(len(self.data), 10)
self.assertEqual(len(self.data[0]), 4)
self.assertEqual(len(self.data[0]), 5)

self.assertEqual(self.data[0][RawData.TEXT], "Who R U ?")
self.assertEqual(
self.data[0][RawData.TARGET_PROBS],
"[-0.001166735659353435, -6.743621826171875]",
"[-0.005602254066616297, -5.430975914001465]",
)
self.assertEqual(
self.data[0][RawData.TARGET_LABELS], '["cu:other", "cu:ask_Location"]'
@@ -64,7 +71,7 @@ def test_tokenization(self):
self.assertListEqual(data[0][ModelInput.WORD_FEAT], ["who", "r", "u", "?"])
self.assertListEqual(
data[0][Target.TARGET_PROB_FIELD],
[-0.001166735659353435, -6.743621826171875],
[-0.005602254066616297, -5.430975914001465],
)

def test_align_target_label(self):
@@ -228,20 +228,26 @@ def _prepare_labels_weights(logits, targets, weights=None):


class KLDivergenceBCELoss(Loss):
class Config(ConfigBase):
temperature: float = 1.0

def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
self.ignore_index = ignore_index
self.weight = weight
self.t = config.temperature

def __call__(self, logits, targets, reduce=True):
"""
Computes Kullback-Leibler divergence loss for multiclass classification
probability distribution computed by BinaryCrossEntropyLoss loss
"""
hard_targets, soft_targets = targets
hard_targets, _, soft_targets_logits = targets
# we clamp the probability between (1e-20, 1 - 1e-20) to avoid log(0) problem
# in the calculation of KLDivergence
soft_targets = FloatTensor(soft_targets).exp().clamp(1e-20, 1 - 1e-20)
probs = F.sigmoid(logits).clamp(1e-20, 1 - 1e-20)
soft_targets = F.sigmoid(FloatTensor(soft_targets_logits) / self.t).clamp(
1e-20, 1 - 1e-20
)
probs = F.sigmoid(logits / self.t).clamp(1e-20, 1 - 1e-20)
probs_neg = probs.neg().add(1).clamp(1e-20, 1 - 1e-20)
soft_targets_neg = soft_targets.neg().add(1).clamp(1e-20, 1 - 1e-20)
if self.weight is not None:
@@ -264,18 +270,24 @@ def __call__(self, logits, targets, reduce=True):


class KLDivergenceCELoss(Loss):
class Config(ConfigBase):
temperature: float = 1.0

def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
self.ignore_index = ignore_index
self.weight = weight
self.t = config.temperature

def __call__(self, logits, targets, reduce=True):
"""
Computes Kullback-Leibler divergence loss for multiclass classification
probability distribution computed by CrossEntropyLoss loss
"""
hard_targets, soft_targets = targets
soft_targets = FloatTensor(soft_targets).exp().clamp(1e-20, 1 - 1e-20)
log_probs = F.log_softmax(logits, 1)
hard_targets, _, soft_targets_logits = targets
soft_targets = F.softmax(FloatTensor(soft_targets_logits) / self.t).clamp(
1e-20, 1 - 1e-20
)
log_probs = F.log_softmax(logits / self.t, 1)
if self.weight is not None:
loss = F.kl_div(log_probs, soft_targets, reduction="none") * self.weight
if reduce:
@@ -288,16 +300,24 @@ def __call__(self, logits, targets, reduce=True):


class SoftHardBCELoss(Loss):
""" Reference implementation from Distilling the knowledge in a Neural Network:
https://arxiv.org/pdf/1503.02531.pdf
"""

class Config(ConfigBase):
temperature: float = 1.0

def __init__(self, config, ignore_index=-100, weight=None, *args, **kwargs):
self.ignore_index = ignore_index
self.weight = weight
self.config = config
self.t = config.temperature

def __call__(self, logits, targets, reduce=True):
"""
Computes soft and hard loss for knowledge distillation
"""
hard_targets, prob_targets = targets
hard_targets, _, _ = targets

# hard targets
one_hot_targets = (
@@ -320,4 +340,4 @@ def __call__(self, logits, targets, reduce=True):
hard_loss = F.binary_cross_entropy_with_logits(
logits, one_hot_targets, reduction="mean" if reduce else "none"
)
return prob_loss(logits, targets, reduce=reduce) + hard_loss
return self.t * self.t * prob_loss(logits, targets, reduce=reduce) + hard_loss
@@ -1,10 +1,10 @@
Who R U ? [-0.001166735659353435, -6.743621826171875] ["cu:other", "cu:ask_Location"] cu:other
You in the gym [-1.052109956741333, -0.4391290545463562] ["cu:other", "cu:ask_Location"] cu:ask_Location
look at me [-0.0476444847881794, -3.0614256858825684] ["cu:other", "cu:ask_Location"] cu:other
Move fast and ship love [-0.0006301801186054945, -7.4362711906433105] ["cu:other", "cu:ask_Location"] cu:other
At your house? [-1.0430370569229126, -0.4422380030155182] ["cu:other", "cu:ask_Location"] cu:ask_Location
Lmao [-2.9802276912960224e-06, -12.688275337219238] ["cu:other", "cu:ask_Location"] cu:other
Owwww Sweet [-0.00019667598826345056, -8.894574165344238] ["cu:other", "cu:ask_Location"] cu:other
You home yet lol [-1.4669092893600464, -0.27201464772224426] ["cu:other", "cu:ask_Location"] cu:ask_Location
Are you home? [-3.8916919231414795, -0.02050662972033024] ["cu:other", "cu:ask_Location"] cu:ask_Location
Where you at? [-1.9239834547042847, -0.1610211282968521] ["cu:other", "cu:ask_Location"] cu:ask_Location
Who R U ? [-0.005602254066616297, -5.430975914001465] [5.181787490844727, -5.4265875816345215] ["cu:other", "cu:ask_Location"] cu:other
You in the gym [-1.8341524600982666, -0.1862216591835022] [-1.6600980758666992, 1.5862623453140259] ["cu:other", "cu:ask_Location"] cu:ask_Location
look at me [-0.0028048718813806772, -5.696012496948242] [5.874996662139893, -5.692647457122803] ["cu:other", "cu:ask_Location"] cu:other
Move fast and ship love [-0.0013048476539552212, -6.862490653991699] [6.641023635864258, -6.861443996429443] ["cu:other", "cu:ask_Location"] cu:other
At your house? [-3.073305130004883, -0.038981884717941284] [-3.025932550430298, 3.225104808807373] ["cu:other", "cu:ask_Location"] cu:ask_Location
Lmao [-2.312633478140924e-05, -10.719563484191895] [10.675718307495117, -10.719541549682617] ["cu:other", "cu:ask_Location"] cu:other
Owwww Sweet [-0.038656335324048996, -3.3673951625823975] [3.2336537837982178, -3.3323073387145996] ["cu:other", "cu:ask_Location"] cu:other
You home yet lol [-2.0499677658081055, -0.12795871496200562] [-1.912153959274292, 1.9913861751556396] ["cu:other", "cu:ask_Location"] cu:ask_Location
Are you home? [-5.4780426025390625, -0.003714567981660366] [-5.473856449127197, 5.593633651733398] ["cu:other", "cu:ask_Location"] cu:ask_Location
Where you at? [-5.395582675933838, -0.004174210596829653] [-5.391035556793213, 5.4767537117004395] ["cu:other", "cu:ask_Location"] cu:ask_Location

0 comments on commit 5d0e6af

Please sign in to comment.