In [1]:
from transformers import CLIPProcessor, CLIPModel
import torch
# The model and its processor are both loaded from the same checkpoint, so the
# identifier is spelled once and reused for both calls.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)



In [4]:
import sys
sys.path.append('../')
from mit_states_dataset import MIT_states
from torch.utils.data import DataLoader
# Options forwarded verbatim to MIT_states. The dataset receives the whole
# mapping; how each key is interpreted is decided inside mit_states_dataset
# (presumably crop / colour-jitter / flip / blur augmentation settings — verify there).
dataset_args = dict(
    root='../../data',
    transformation=True,
    crop_size=224,
    brightness=0.4,
    contrast=0.4,
    saturation=0.2,
    hue=0.1,
    color_jitter_prob=0.4,
    gray_scale_prob=0.2,
    horizontal_flip_prob=0.5,
    gaussian_prob=0.5,
    min_scale=0.6,
    max_scale=0.95,
)

# The dataset gets the root path, the full option mapping, and the
# transformation flag again as an explicit keyword.
train_dataset = MIT_states(
    dataset_args['root'],
    dataset_args,
    download=False,
    transformation=dataset_args['transformation'],
)

# Small batches with a single persistent worker; incomplete final batches are dropped.
dataloader = DataLoader(
    train_dataset, batch_size=3, num_workers=1,
    pin_memory=False, drop_last=True, persistent_workers=True,
)

Transforming in training loop True


In [6]:
# Display the index -> adjective mapping exposed by the dataset.
# NOTE(review): the rendered mapping starts with 0: 'adj', which looks like a
# header row leaking into the vocabulary rather than a real adjective —
# confirm in mit_states_dataset before relying on these indices.
train_dataset.adjectives

{0: 'adj',
 1: 'ancient',
 2: 'barren',
 3: 'bent',
 4: 'blunt',
 5: 'bright',
 6: 'broken',
 7: 'browned',
 8: 'brushed',
 9: 'burnt',
 10: 'caramelized',
 11: 'chipped',
 12: 'clean',
 13: 'clear',
 14: 'closed',
 15: 'cloudy',
 16: 'cluttered',
 17: 'coiled',
 18: 'cooked',
 19: 'cored',
 20: 'cracked',
 21: 'creased',
 22: 'crinkled',
 23: 'crumpled',
 24: 'crushed',
 25: 'curved',
 26: 'cut',
 27: 'damp',
 28: 'dark',
 29: 'deflated',
 30: 'dented',
 31: 'diced',
 32: 'dirty',
 33: 'draped',
 34: 'dry',
 35: 'dull',
 36: 'empty',
 37: 'engraved',
 38: 'eroded',
 39: 'fallen',
 40: 'filled',
 41: 'foggy',
 42: 'folded',
 43: 'frayed',
 44: 'fresh',
 45: 'frozen',
 46: 'full',
 47: 'grimy',
 48: 'heavy',
 49: 'huge',
 50: 'inflated',
 51: 'large',
 52: 'lightweight',
 53: 'loose',
 54: 'mashed',
 55: 'melted',
 56: 'modern',
 57: 'moldy',
 58: 'molten',
 59: 'mossy',
 60: 'muddy',
 61: 'murky',
 62: 'narrow',
 63: 'new',
 64: 'old',
 65: 'open',
 66: 'painted',
 67: 'peeled',
 68: '

In [2]:
from transformers.models.clip.modeling_clip import *
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

class CLIPTextTransformer(nn.Module):
    """Notebook-local copy of the CLIP text tower.

    Pipeline: `CLIPTextEmbeddings` -> `CLIPEncoder` (run with a causal mask and an
    optional padding mask) -> final `LayerNorm`. The pooled output is the hidden
    state at the position holding the largest token id in each sequence.
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    # NOTE: `replace_return_docstrings` rewrites the bare "Returns:" line of the
    # docstring below, so that docstring is intentionally left in its minimal form.
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        """
        # Fall back to the config defaults for any flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify either input_ids")

        # Flatten any leading dimensions so the rest of the method sees [bsz, seq_len].
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # Read the sizes from the flattened ids (not from the raw input shape) so
        # inputs that were not 2-D to begin with unpack correctly as well.
        bsz, seq_len = input_ids.shape
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(
            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        # The row index is created on the same device as the hidden states to avoid a host/device mix.
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(torch.int).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
        """Return an additive causal mask of shape `[bsz, 1, seq_len, seq_len]`.

        Entries strictly above the diagonal hold the most negative finite value of
        `dtype`; the diagonal and everything below it are zero, so each position can
        attend only to itself and earlier positions.

        Args:
            bsz: batch size.
            seq_len: sequence length.
            dtype: floating-point dtype of the mask.
            device: device to allocate the mask on; `None` keeps the previous
                behaviour of allocating on the default device.
        """
        mask = torch.full((bsz, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device)
        mask.triu_(1)  # keep only the strictly-upper triangle; diagonal and below become 0
        return mask.unsqueeze(1)  # add the broadcastable head axis

In [19]:

class newCLIPTextTransformer(CLIPTextTransformer):
    """Text tower that feeds learned prompt vectors through an existing CLIP text encoder.

    Instead of embedding token ids, the module owns a table of
    `first_dim * second_dim` learned vectors, reshapes it to
    `[first_dim, second_dim, hidden_size]` (prompts x tokens-per-prompt x width) and
    runs it through the encoder and final layer norm of `old_text_transformer`.

    Args:
        old_text_transformer: module providing `config`, `encoder` and
            `final_layer_norm` (e.g. a `CLIPTextTransformer`).
        first_dim: number of prompts (batch dimension of the learned table).
        second_dim: number of learned tokens per prompt.
    """

    def __init__(self, old_text_transformer, first_dim=10, second_dim=10):
        # Deliberately bypass CLIPTextTransformer.__init__: it would allocate a second,
        # randomly initialised embeddings/encoder/layer-norm stack that this wrapper
        # never runs but that would still show up in `parameters()`.
        torch.nn.Module.__init__(self)
        self.config = old_text_transformer.config
        self.first_dim = first_dim
        self.second_dim = second_dim

        self.old_text_transformer = old_text_transformer
        # Keep the attribute names of the parent class, but point them at the wrapped
        # (shared) modules instead of unused copies.
        self.encoder = old_text_transformer.encoder
        self.final_layer_norm = old_text_transformer.final_layer_norm

        # One learned vector per (prompt, token) slot, sized to the encoder width.
        self.embeddings = torch.nn.Embedding(first_dim * second_dim, self.config.hidden_size)
        # Registered as a (non-persistent) buffer so it follows `.to()` / `.cuda()`
        # with the module instead of being pinned to CUDA at construction time.
        self.register_buffer(
            "get_all", torch.arange(first_dim * second_dim, dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode the learned prompt table.

        `input_ids` must be supplied for interface compatibility, but its contents, as
        well as `attention_mask` and `position_ids`, are ignored: the hidden states come
        entirely from the learned embedding table.

        Returns:
            `BaseModelOutputWithPooling` (or a tuple when `return_dict` is false) whose
            `last_hidden_state` is `[first_dim, second_dim, hidden_size]` and whose
            `pooler_output` is the hidden state of the last token of each prompt.

        Raises:
            ValueError: if `input_ids` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify either input_ids")

        bsz, seq_len = self.first_dim, self.second_dim
        hidden_states = self.embeddings(self.get_all).view(bsz, seq_len, -1)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )

        # Every slot of the learned table is a real token, so there is nothing to pad
        # out: an all-ones padding mask would expand to an all-zero additive mask.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=None,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # Pool at the LAST position. The previous argmax over an all-ones dummy
        # `input_ids` always selected position 0, which under the causal mask only
        # sees the first learned vector and ignores the rest of the prompt.
        pooled_output = last_hidden_state[:, -1]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Smoke test: run the prompt-based text transformer once on a dummy batch.
# Use the GPU when there is one instead of hard-coding `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy image batch. It is not fed to the text transformer, so it is no longer
# copied to the device, attached to `inputs`, and deleted again.
images = torch.randn(2, 3, 224, 224)

text = ['g '*10,] * 10
inputs = processor(text=text, return_tensors="pt", padding=True)
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

text_transformer = CLIPTextTransformer(clip.config.text_config)
# Start from the pretrained CLIP text weights; a transformer built only from the
# config is randomly initialised and the checkpoint loaded above would go unused.
text_transformer.load_state_dict(clip.text_model.state_dict())
text_transformer = text_transformer.to(device)

new_transformer = newCLIPTextTransformer(text_transformer)
new_transformer = new_transformer.to(device)
out = new_transformer(**inputs)
