Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Accept comments

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
2 people authored and Sebastien Ehrhardt committed May 14, 2024
1 parent 22cde7d commit 38642b5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 34 deletions.
9 changes: 2 additions & 7 deletions src/transformers/models/vit_msn/modeling_vit_msn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
patch_window_width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
- patch_window_height, patch_window_width = (
- patch_window_height + 0.1,
- patch_window_width + 0.1,
- )
+ patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
Expand Down Expand Up @@ -599,9 +596,7 @@ def forward(
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
- pixel_values,
- bool_masked_pos=bool_masked_pos,
- interpolate_pos_encoding=interpolate_pos_encoding,
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)

encoder_outputs = self.encoder(
Expand Down
34 changes: 7 additions & 27 deletions src/transformers/models/yolos/modeling_yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,9 @@ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
patch_pos_embed = patch_pos_embed.view(batch_size, hidden_size, patch_height, patch_width)

height, width = img_size
- new_patch_heigth, new_patch_width = (
- height // self.config.patch_size,
- width // self.config.patch_size,
- )
+ new_patch_heigth, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
patch_pos_embed = nn.functional.interpolate(
- patch_pos_embed,
- size=(new_patch_heigth, new_patch_width),
- mode="bicubic",
- align_corners=False,
+ patch_pos_embed, size=(new_patch_heigth, new_patch_width), mode="bicubic", align_corners=False
)
patch_pos_embed = patch_pos_embed.flatten(2).transpose(1, 2)
scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=1)
Expand All @@ -205,15 +199,9 @@ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
)
patch_pos_embed = patch_pos_embed.view(depth * batch_size, hidden_size, patch_height, patch_width)
height, width = img_size
- new_patch_height, new_patch_width = (
- height // self.config.patch_size,
- width // self.config.patch_size,
- )
+ new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
patch_pos_embed = nn.functional.interpolate(
- patch_pos_embed,
- size=(new_patch_height, new_patch_width),
- mode="bicubic",
- align_corners=False,
+ patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
)
patch_pos_embed = (
patch_pos_embed.flatten(2)
Expand Down Expand Up @@ -756,16 +744,10 @@ def __init__(self, config: YolosConfig):
# Object detection heads
# We add one for the "no object" class
self.class_labels_classifier = YolosMLPPredictionHead(
- input_dim=config.hidden_size,
- hidden_dim=config.hidden_size,
- output_dim=config.num_labels + 1,
- num_layers=3,
+ input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=config.num_labels + 1, num_layers=3
)
self.bbox_predictor = YolosMLPPredictionHead(
- input_dim=config.hidden_size,
- hidden_dim=config.hidden_size,
- output_dim=4,
- num_layers=3,
+ input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=4, num_layers=3
)

# Initialize weights and apply final processing
Expand Down Expand Up @@ -857,9 +839,7 @@ def forward(
if labels is not None:
# First: create the matcher
matcher = YolosHungarianMatcher(
- class_cost=self.config.class_cost,
- bbox_cost=self.config.bbox_cost,
- giou_cost=self.config.giou_cost,
+ class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
)
# Second: create the criterion
losses = ["labels", "boxes", "cardinality"]
Expand Down

0 comments on commit 38642b5

Please sign in to comment.