/
model.py
207 lines (177 loc) · 9.15 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
'''
Adapted from https://github.com/huggingface/transformers
'''
from transformers import T5Config, T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG
import copy
import os
import warnings
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutput,
Seq2SeqLMOutput,
)
class T5ForMultimodalGenerationMCCoT(T5ForConditionalGeneration):
_keys_to_ignore_on_load_missing = [
r"encoder.embed_tokens.weight",
r"decoder.embed_tokens.weight",
r"lm_head.weight",
]
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config, patch_size, padding_idx, save_dir,vot_num,alpha):
super().__init__(config)
self.model_dim = config.d_model
self.vot_num=vot_num
self.alpha=alpha
self.padding_idx = padding_idx
self.shared = nn.Embedding(config.vocab_size, config.d_model)
self.patch_num, self.patch_dim = patch_size
self.image_dense = nn.Linear(self.patch_dim, config.d_model)
self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size, vdim=config.hidden_size, num_heads=1, batch_first=True)
self.gate_dense = nn.Linear(2*config.hidden_size, config.hidden_size)
self.sigmoid = nn.Sigmoid()
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
self.decoder = T5Stack(decoder_config, self.shared)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
# Model parallel
self.model_parallel = False
self.device_map = None
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
image_ids=None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
decoder_input_ids_base=decoder_input_ids
# decoder_attention_mask_base=decoder_attention_mask
decoder_inputs_embeds_base=decoder_inputs_embeds
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=True, # 设置这里 output_attentions
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
image_embedding = self.image_dense(image_ids)
image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
merge = torch.cat([hidden_states, image_att], dim=-1)
gate = self.sigmoid(self.gate_dense(merge))
hidden_states = (1 - gate) * hidden_states + gate * image_att
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
all_logits = []
for _ in range(self.vot_num):
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
all_logits.append(lm_logits)
# voting
stacked_logits = torch.stack(all_logits, dim=0)
mean_logits = torch.mean(stacked_logits, dim=0)
stddev_logits = torch.std(stacked_logits, dim=0)
weights = 1 / (1 + stddev_logits)
weighted_mean_logits = torch.sum(weights * stacked_logits, dim=0) / torch.sum(weights, dim=0)
alpha = self.alpha
lm_logits = alpha * mean_logits + (1 - alpha) * weighted_mean_logits
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)