Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files (browse the repository at this point in the history)
  • Loading branch information
NielsRogge committed Oct 28, 2021
1 parent dfae213 commit e61b580
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
24 changes: 12 additions & 12 deletions src/transformers/models/beit/configuration_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ class BeitConfig(PretrainedConfig):
Pooling scales used in Pooling Pyramid Module applied on the last feature map.
use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to use an auxiliary head during training.
loss_weight (:obj:`float`, `optional`, defaults to 0.4):
auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
Weight of the cross-entropy loss of the auxiliary head.
channels (:obj:`int`, `optional`, defaults to 256):
auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
Number of channels to use in the auxiliary head.
num_convs (:obj:`int`, `optional`, defaults to 1):
auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
Number of convolutional layers to use in the auxiliary head.
concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
Example::
Expand Down Expand Up @@ -134,10 +134,10 @@ def __init__(
out_indices=[3, 5, 7, 11],
pool_scales=[1, 2, 3, 6],
use_auxiliary_head=True,
loss_weight=0.4,
channels=256,
num_convs=1,
concat_input=False,
auxiliary_loss_weight=0.4,
auxiliary_channels=256,
auxiliary_num_convs=1,
auxiliary_concat_input=False,
**kwargs
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(
self.pool_scales = pool_scales
# auxiliary head attributes (semantic segmentation)
self.use_auxiliary_head = use_auxiliary_head
self.loss_weight = loss_weight
self.channels = channels
self.num_convs = num_convs
self.concat_input = concat_input
self.auxiliary_loss_weight = auxiliary_loss_weight
self.auxiliary_channels = auxiliary_channels
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
8 changes: 4 additions & 4 deletions src/transformers/models/beit/modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,9 +1025,9 @@ class BeitFCNHead(nn.Module):
def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
super().__init__()
self.in_channels = config.hidden_size
self.channels = config.channels
self.num_convs = config.num_convs
self.concat_input = config.concat_input
self.channels = config.auxiliary_channels
self.num_convs = config.auxiliary_num_convs
self.concat_input = config.auxiliary_concat_input
self.in_index = in_index

conv_padding = (kernel_size // 2) * dilation
Expand Down Expand Up @@ -1109,7 +1109,7 @@ def compute_loss(self, logits, auxiliary_logits, labels):
loss_fct = CrossEntropyLoss(ignore_index=255)
main_loss = loss_fct(upsampled_logits, labels)
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
loss = main_loss + self.config.loss_weight * auxiliary_loss
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss

return loss

Expand Down
2 changes: 0 additions & 2 deletions tests/test_modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""

maxDiff = None

all_model_classes = (
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
if is_torch_available()
Expand Down

0 comments on commit e61b580

Please sign in to comment.