-
Notifications
You must be signed in to change notification settings - Fork 201
/
BEiT3.py
96 lines (87 loc) · 3.19 KB
/
BEiT3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import (
PositionalEmbedding,
TextEmbedding,
VisionEmbedding,
)
from torchscale.component.multiway_network import MutliwayEmbedding
class BEiT3(nn.Module):
    """Multiway encoder over vision tokens, text tokens, or both.

    Holds a text embedding, a vision embedding and a two-branch positional
    embedding, and feeds the embedded inputs to a torchscale ``Encoder`` as
    ``token_embeddings``.
    """

    def __init__(self, args, **kwargs):
        """Build the embeddings and the encoder from ``args``.

        Args:
            args: Configuration object. This constructor reads ``multiway``,
                ``vocab_size``, ``share_encoder_input_output_embed``,
                ``encoder_embed_dim``, ``img_size``, ``patch_size``,
                ``in_chans`` and ``max_source_positions``; the object is also
                passed through to ``Encoder`` unchanged.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            AssertionError: If ``args.multiway`` is falsy, ``args.vocab_size``
                is not positive, or ``args.share_encoder_input_output_embed``
                is truthy.
        """
        super().__init__()
        self.args = args
        assert args.multiway, "BEiT3 requires args.multiway to be enabled"
        assert args.vocab_size > 0, "BEiT3 requires args.vocab_size > 0"
        assert (
            not args.share_encoder_input_output_embed
        ), "BEiT3 does not support share_encoder_input_output_embed"

        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )
        # being consistent with Fairseq, which starts from 2 for position embedding
        embed_positions = MutliwayEmbedding(
            modules=[
                # First branch: vision positions (+2 for the Fairseq offset).
                PositionalEmbedding(
                    self.vision_embed.num_position_embeddings() + 2,
                    args.encoder_embed_dim,
                ),
                # Second branch: text positions.
                PositionalEmbedding(
                    args.max_source_positions, args.encoder_embed_dim
                ),
            ],
            dim=1,
        )
        # Tokens are embedded in forward(), so the encoder gets no token
        # embedding table and no output projection of its own.
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        attn_mask=None,
        vision_masked_position=None,
        incremental_state=None,
        positions=None,
    ):
        """Embed the given modalities and run them through the encoder.

        At least one of ``textual_tokens`` and ``visual_tokens`` must be
        given. When both are given, the vision embeddings are placed before
        the text embeddings along dim 1.

        Args:
            textual_tokens: Input to ``self.text_embed``, or ``None``.
            visual_tokens: Input to ``self.vision_embed``, or ``None``.
            text_padding_position: Padding mask for the text part, or
                ``None``. Ignored when only ``visual_tokens`` is given.
                NOTE(review): presumably a bool tensor shaped like
                ``textual_tokens`` -- confirm against callers.
            attn_mask: Forwarded to the encoder unchanged.
            vision_masked_position: Forwarded to ``self.vision_embed`` as its
                second argument.
            incremental_state: Forwarded to the encoder unchanged.
            positions: Forwarded to the encoder unchanged.

        Returns:
            The encoder's output mapping, with an added
            ``"multiway_split_position"`` entry: ``-1`` for vision-only
            input, ``0`` for text-only input, otherwise the size of dim 1 of
            the vision embeddings.

        Raises:
            AssertionError: If both ``textual_tokens`` and ``visual_tokens``
                are ``None``.
        """
        assert (
            textual_tokens is not None or visual_tokens is not None
        ), "at least one of textual_tokens / visual_tokens is required"

        if textual_tokens is None:
            # Vision-only input.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only input.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Both modalities: vision embeddings first, then text.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            x = torch.cat([x1, x2], dim=1)

            if text_padding_position is not None:
                # Vision positions are never padding. Build the all-False
                # mask directly as bool on x1's device instead of allocating
                # a float tensor on CPU and then moving and casting it.
                vision_padding = torch.zeros(
                    x1.shape[:-1], dtype=torch.bool, device=x1.device
                )
                encoder_padding_mask = torch.cat(
                    [vision_padding, text_padding_position],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
            incremental_state=incremental_state,
            positions=positions,
        )
        encoder_out["multiway_split_position"] = multiway_split_position

        return encoder_out