<a href="https://colab.research.google.com/github/ProfEddie/HypCLIP/blob/perceiver/lab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Hugging Face `datasets` (used to load Flickr30k below) and Salesforce
# LAVIS (provides the BLIP retrieval model wrapped in the last cells).
!pip install datasets
!pip install -U salesforce-lavis

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6
Collecting salesforce-lavis
  Downloading salesforce_lavis-1.0.2-py3-none-any.whl (1.8 MB

In [None]:
import torch
import torch.nn as nn
import numpy as np
import math


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT), over the last dimension.

    Computed with one complex FFT: the even-indexed samples followed by the
    reversed odd-indexed samples are transformed, then every bin k is rotated
    by exp(-i*pi*k/(2N)) and its real part kept.

    :param x: the input signal; any leading shape, transformed along dim -1
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension (same shape as x)
    """
    original_shape = x.shape
    length = original_shape[-1]
    rows = x.contiguous().view(-1, length)

    # Reorder each row as [x0, x2, x4, ..., x5, x3, x1] before the FFT.
    evens = rows[:, ::2]
    odds_reversed = rows[:, 1::2].flip([1])
    spectrum = torch.fft.fft(torch.cat([evens, odds_reversed], dim=1), dim=1)

    # Re(spectrum * exp(i * angle)) with angle_k = -pi * k / (2N).
    angle = -torch.arange(length, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * length)
    coeffs = spectrum.real * torch.cos(angle) - spectrum.imag * torch.sin(angle)

    if norm == 'ortho':
        # Together with the final factor of 2 this gives the orthonormal DCT-II.
        coeffs[:, 0] /= np.sqrt(length) * 2
        coeffs[:, 1:] /= np.sqrt(length / 2) * 2

    return 2 * coeffs.view(*original_shape)


def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III.

    Our definition of idct is that idct(dct(x)) == x (for the same ``norm``).
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input coefficients; any leading shape, inverted along dim -1
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension (same shape as X)
    """
    x_shape = X.shape
    N = x_shape[-1]

    coeffs = X.contiguous().view(-1, N) / 2
    if norm == 'ortho':
        # Undo the orthonormal scaling applied by dct(..., norm='ortho').
        coeffs[:, 0] *= np.sqrt(N) * 2
        coeffs[:, 1:] *= np.sqrt(N / 2) * 2

    angle = torch.arange(N, dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    cos_a, sin_a = torch.cos(angle), torch.sin(angle)

    # Spectrum before rotation: real part = coefficients, imaginary part =
    # [0, -c[N-1], -c[N-2], ..., -c[1]].
    real_part = coeffs
    imag_part = torch.cat([coeffs[:, :1] * 0, -coeffs.flip([1])[:, :-1]], dim=1)

    # Multiply by exp(i * pi * k / (2N)).
    rotated_real = real_part * cos_a - imag_part * sin_a
    rotated_imag = real_part * sin_a + imag_part * cos_a
    spectrum = torch.view_as_complex(torch.stack([rotated_real, rotated_imag], dim=-1))

    reordered = torch.fft.ifft(spectrum, dim=1).real

    # Undo the [evens, reversed odds] ordering used by the forward transform.
    out = reordered.new_zeros(reordered.shape)
    out[:, ::2] += reordered[:, :N - (N // 2)]
    out[:, 1::2] += reordered.flip([1])[:, :N // 2]

    return out.view(*x_shape)



def dc_transform(x, r=0.8):
    """
    Low-pass and shorten dim 0 of ``x`` with an orthonormal DCT.

    An orthonormal DCT-II is taken along dim 0 (``T``). Only the first
    ``ceil(T * r)`` coefficients are kept, and the orthonormal inverse is
    evaluated at that shorter length. With ``r < 1`` the output therefore has
    fewer positions along dim 0 than the input.

    :param x: tensor laid out as ``(T, B, C)``; the transform runs over dim 0
    :param r: fraction of DCT coefficients (and output positions) to keep;
        must be > 0. Values >= 1 keep every coefficient.
    :return: tensor of shape ``(min(T, ceil(T * r)), B, C)`` with ``x``'s dtype
    :raises ValueError: if ``r`` is not positive. Previously a negative ``r``
        silently sliced ``x_dct[:-k]``, and ``r == 0`` crashed inside ``idct``.
    """
    if not r > 0:
        raise ValueError(f"r must be positive, got {r}")

    # cufft doesn't accept fp16 (and CPU FFTs don't support half precision), so
    # half/bfloat16 inputs are transformed in float32 and cast back at the end.
    if x.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = x.dtype

    # DCT along the T dimension: move T to the last axis, transform, move it back.
    x_dct = dct(x.to(compute_dtype).transpose(0, 2), norm='ortho').transpose(0, 2)
    T = x_dct.size(0)

    # Keep only the lowest-frequency coefficients (feel free to play with any method here).
    x_dct = x_dct[:math.ceil(T * r), :, :]

    out = idct(x_dct.transpose(0, 2), norm='ortho').transpose(0, 2)
    return out.to(x.dtype)




# plot_hidden_states(vision_output.hidden_states)

# vision_output.hidden_states[-1][:,0,:].shape

In [None]:
from PIL import Image
import requests

from transformers import AutoProcessor, AutoModel

# model_ckt =  "openai/clip-vit-large-patch14"
# model_ckt =  "openai/clip-vit-base-patch16"
# model_ckt =  "openai/clip-vit-base-patch32"


import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from transformers import (
    BlipPreTrainedModel,
    BlipConfig,
    BlipVisionModel,
    BlipTextModel,
)
from transformers.models.blip.modeling_blip import BlipImageTextMatchingModelOutput
import torch
import matplotlib.pyplot as plt


class DCTBlipForImageTextRetrieval(BlipPreTrainedModel):
    """
    BLIP image-text retrieval model that can also run its vision tower layer by
    layer and record DCT spectra of the patch tokens.

    Submodule names (``vision_model``, ``text_encoder``, ``vision_proj``,
    ``text_proj``, ``itm_head``) are what ``from_pretrained`` matches against
    checkpoint keys.
    """

    config_class = BlipConfig

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        self.vision_model = BlipVisionModel(config.vision_config)

        # Text tower without a pooling layer: features are read at the CLS position (index 0).
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

        # Projections into the shared image-text embedding space (contrastive path).
        self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)
        self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)

        # Two-way image-text-matching head over the cross-attended text CLS token.
        self.itm_head = nn.Linear(config.text_config.hidden_size, 2)

        # Fall back to the text config's pad / BOS ids when the top-level config lacks them.
        self.decoder_pad_token_id = (
            config.text_config.pad_token_id
            if not hasattr(config, "decoder_pad_token_id")
            else config.decoder_pad_token_id
        )
        self.decoder_start_token_id = (
            config.text_config.bos_token_id
            if not hasattr(config, "decoder_start_token_id")
            else config.decoder_start_token_id
        )

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the vision patch-embedding module."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        use_itm_head: Optional[bool] = True,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Score image/text pairs.

        :param input_ids: token ids of the text(s)
        :param pixel_values: preprocessed images
        :param use_itm_head: if True, run the text encoder with cross-attention over
            the image tokens and return the 2-way ITM logits of the text CLS token.
            Otherwise return ``image_feat @ text_feat.T``: the similarity matrix between
            L2-normalised, projected image and text CLS features (images x texts).
        :param attention_mask: text attention mask
        :param output_attentions: forwarded to the vision tower (config default if None)
        :param output_hidden_states: forwarded to the vision tower (config default if None)
        :param return_dict: return a ``BlipImageTextMatchingModelOutput`` instead of a tuple
        :return: ``BlipImageTextMatchingModelOutput``, or a tuple of the non-None items
            ``(score, last_hidden_state, *extra vision outputs, question_embeds)``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[0]
        # Allocate the cross-attention mask on the image tokens' device. Without an
        # explicit device it was always created on the CPU, even with the model on CUDA.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if use_itm_head:
            question_embeds = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=return_dict,
            )
            question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state

            output = self.itm_head(question_embeds[:, 0, :])
        else:
            question_embeds = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=return_dict,
            )
            question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state

            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)

            output = image_feat @ text_feat.t()

        if not return_dict:
            outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,)
            return tuple(item for item in outputs if item is not None)

        return BlipImageTextMatchingModelOutput(
            itm_score=output,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
            question_embeds=question_embeds,
        )

    def get_vision_features(self, pixel_values):
        """
        Run the vision tower layer by layer, recording hidden states and DCT spectra.

        :param pixel_values: preprocessed images
        :return: tuple ``(last_hidden_state, pooled_output, hidden_states, dct_signals)``:
            - last_hidden_state: last encoder layer output after ``post_layernorm``
            - pooled_output: token 0 (CLS) of ``last_hidden_state``, passed through
              ``post_layernorm`` a second time
            - hidden_states: embedding output followed by every layer's output
              (``len(encoder.layers) + 1`` entries)
            - dct_signals: per layer, the unnormalised DCT-II of tokens ``1:``
              along the token axis, laid out as (tokens, batch, hidden). This assumes
              layer outputs are (batch, seq, hidden); TODO confirm.
        """
        state = self.vision_model.embeddings(pixel_values)
        hidden_states = [state]
        dct_signals = []

        for layer in self.vision_model.encoder.layers:
            # Positional args: (hidden_states, attention_mask=None, output_attentions=None).
            state = layer(state, None, None)[0]
            # (B, T, C) -> (C, B, T) so the DCT runs over tokens, then -> (T, B, C).
            dct_signals.append(dct(state[:, 1:, :].permute(2, 0, 1)).transpose(0, 2))
            hidden_states.append(state)

        last_hidden_state = self.vision_model.post_layernorm(state)

        # NOTE(review): post_layernorm is applied to the CLS token twice (once above,
        # once here). This appears to mirror HF's BlipVisionModel; confirm before changing.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.vision_model.post_layernorm(pooled_output)
        return last_hidden_state, pooled_output, hidden_states, dct_signals

    def get_text_features(self, input_ids, attention_mask):
        """Return the text encoder's last hidden state (no cross-attention to images)."""
        question_embeds = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return question_embeds[0]







In [None]:
from transformers import BlipModel
# Use the first GPU when available and load the `blip-itm-base-flickr` checkpoint twice.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_ckt =  "Salesforce/blip-itm-base-flickr"

processor = AutoProcessor.from_pretrained(model_ckt)
# NOTE(review): the load warning for this call reports the checkpoint's
# `text_encoder.*` / `text_proj.*` weights as unused. BlipModel's text tower and
# projection heads are therefore presumably freshly initialised, and only its
# vision tower carries checkpoint weights. Confirm before relying on
# `model.text_model`, `model.text_projection` or `model.visual_projection`.
model = BlipModel.from_pretrained(model_ckt).to(device)
# Custom wrapper whose submodule names (text_encoder / text_proj / ...) match the
# checkpoint keys listed as unused above.
dct_model = DCTBlipForImageTextRetrieval.from_pretrained(model_ckt).to(device)

vision_model = model.vision_model
text_model = model.text_model

# outputs = model(**inputs, output_hidden_states=True)

Some weights of the model checkpoint at Salesforce/blip-itm-base-flickr were not used when initializing BlipModel: ['text_encoder.encoder.layer.0.crossattention.self.value.bias', 'text_encoder.encoder.layer.0.output.dense.bias', 'text_encoder.encoder.layer.5.attention.self.key.bias', 'text_encoder.encoder.layer.1.crossattention.self.value.bias', 'text_encoder.encoder.layer.11.attention.self.value.weight', 'text_proj.bias', 'text_encoder.encoder.layer.2.crossattention.self.query.weight', 'text_encoder.encoder.layer.6.output.LayerNorm.weight', 'text_encoder.encoder.layer.11.crossattention.self.query.weight', 'text_encoder.encoder.layer.11.output.dense.weight', 'text_encoder.encoder.layer.10.attention.output.dense.weight', 'text_encoder.encoder.layer.8.crossattention.output.LayerNorm.bias', 'text_encoder.encoder.layer.1.crossattention.self.key.bias', 'text_encoder.encoder.layer.3.intermediate.dense.bias', 'text_encoder.encoder.layer.3.attention.self.value.weight', 'text_encoder.encoder.la

In [None]:
from datasets import load_dataset
# Load the Flickr30k 'test' split and show the reference captions of the first two
# images ('caption' holds one list of captions per image).
flickr = load_dataset("nlphuji/flickr30k")['test']
print(flickr[:2]['caption'])

[['Two young guys with shaggy hair look at their hands while hanging out in the yard.', 'Two young, White males are outside near many bushes.', 'Two men in green shirts are standing in a yard.', 'A man in a blue shirt standing in a garden.', 'Two friends enjoy time spent together.'], ['Several men in hard hats are operating a giant pulley system.', 'Workers look down from up above on a piece of equipment.', 'Two men working on a machine wearing hard hats.', 'Four men on top of a tall structure.', 'Three men on a large rig.']]




In [None]:
# Tokenise five captions and preprocess the first two Flickr30k images into one batch.
# Per the caption printout, the first three captions describe image 0 and the last
# two describe image 1 (some were copied without their trailing period).
inputs = processor(text=[
    "Two young guys with shaggy hair look at their hands while hanging out in the yard.",
    "Two men in green shirts are standing in a yard",
    "Two young, White males are outside near many bushes.",
    "Workers look down from up above on a piece of equipment.",
    "Four men on top of a tall structure"
],images=flickr[:2]['image'], return_tensors="pt", padding=True)

In [None]:
# Image-text cosine similarity for the batch in `inputs` (images x texts).
# Bug fix: BlipModel's text tower and projection heads are not loaded from
# blip-itm-base-flickr (its `text_encoder.*` / `text_proj.*` weights are reported
# unused at load time), so scores from `text_model` / `visual_projection` /
# `text_projection` came from untrained weights. This uses `dct_model`'s
# checkpoint-backed towers and heads instead, mirroring the use_itm_head=False
# path of its forward: CLS token -> projection -> L2-normalise -> dot product.
with torch.no_grad():
    vision_output = dct_model.vision_model(inputs['pixel_values'].to(device), output_hidden_states=True)
    text_output = dct_model.text_encoder(
        input_ids=inputs['input_ids'].to(device),
        attention_mask=inputs['attention_mask'].to(device),
        output_hidden_states=True,
    )

    vis_embed = F.normalize(dct_model.vision_proj(vision_output[0][:, 0, :]), dim=-1)
    text_embed = F.normalize(dct_model.text_proj(text_output[0][:, 0, :]), dim=-1)

vis_embed @ text_embed.T

tensor([[0.0059, 0.0259, 0.0263, 0.0231, 0.0124],
        [0.0146, 0.0241, 0.0317, 0.0280, 0.0317]], grad_fn=<MmBackward0>)

In [None]:

def plot_hidden_states(hidden_states):
    """
    Plot one token-axis DCT curve per hidden state.

    Each state is indexed as (batch, tokens, channels). An orthonormal DCT-II is
    taken along the token axis, the coefficients are averaged over batch and then
    channels, and the squared magnitude of that mean is plotted (one value per token
    position, one figure per state).

    NOTE(review): averaging signed coefficients before squaring lets them cancel.
    A mean power spectrum would square first; confirm which one is intended.
    """
    for state in hidden_states:
        tokens_first = state.permute(1, 0, 2)
        spectrum = dct(tokens_first.transpose(0, 2), norm='ortho').transpose(0, 2)

        batch_first = spectrum.permute(1, 0, 2)
        mean_coeff = batch_first.mean(0).mean(1)
        curve = (torch.abs(mean_coeff) ** 2).cpu().detach().numpy()

        plt.figure(figsize=(10, 2))
        plt.plot(curve)
        plt.show()



In [None]:
# Inspect the layout of the first hidden state (printed as (2, 577, 768) for two images).
vision_output.hidden_states[0].shape

torch.Size([2, 577, 768])

In [None]:
# One token-axis DCT plot per hidden state of the vision tower.
plot_hidden_states(vision_output.hidden_states)

In [None]:
from lavis import BlipRetrieval

class DCTLAVISBlip(nn.Module):
    """
    Wrap a LAVIS ``BlipRetrieval`` model and shorten the vision token sequence
    after every ViT block with ``dc_transform``.

    After block ``i`` the non-CLS tokens are DCT-truncated with ratio
    ``r_list[i]`` while training, or in eval mode when ``apply_fourier`` is set.
    This leaves ``1 + ceil(T * r_list[i])`` tokens; the CLS token is kept as-is.
    """

    config_class = BlipConfig

    def __init__(self, model: BlipRetrieval):
        super().__init__()

        # Bug fix: this attribute was misspelled `vision_modejl`, while
        # get_vision_features reads `self.vision_model` (AttributeError on every call).
        self.vision_model = model.visual_encoder

        self.text_encoder = model.text_encoder

        self.vision_proj = model.vision_proj

        self.text_proj = model.text_proj

        # Per-block keep ratios for dc_transform, indexed by ViT block.
        # These are fixed hyperparameters: dc_transform turns r into an integer slice
        # length via math.ceil, so they cannot be learned. A plain list replaces the
        # previous nn.ParameterList, which only wraps tensors and rejects floats on
        # older PyTorch.
        self.r_list = [
            1.0, 1.0, 1.0, 1.0, 1.0, 0.8,
            0.8, 0.9, 0.9, 0.9, 1.0, 1.0,
        ]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        apply_fourier: Optional[bool] = True,
    ):
        """
        Dispatch to the text or the vision branch.

        :return: ``get_text_features(...)`` when ``input_ids`` is given, otherwise
            ``get_vision_features(...)``; both return ``(last_hidden_state, projected_cls)``
        """
        if input_ids is not None:
            return self.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return self.get_vision_features(pixel_values=pixel_values, apply_fourier=apply_fourier)

    def get_vision_features(self, pixel_values, apply_fourier=True):
        """
        Encode images with per-block DCT token truncation.

        :param pixel_values: preprocessed images, batch first
        :param apply_fourier: in eval mode, truncate only when True (training always truncates)
        :return: ``(x, vision_embed)``: the normed final tokens and ``vision_proj`` of the CLS token
        """
        B = pixel_values.shape[0]
        x = self.vision_model.patch_embed(pixel_values)

        cls_tokens = self.vision_model.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.vision_model.pos_embed[:, : x.size(1), :]
        x = self.vision_model.pos_drop(x)

        truncate = self.training or apply_fourier
        for i, blk in enumerate(self.vision_model.blocks):
            x = blk(x)
            cls = x[:, 0, :].unsqueeze(1)
            r = self.r_list[i] if truncate else 1.0
            # dc_transform works over dim 0, so move tokens first and back afterwards.
            patches = dc_transform(x[:, 1:, :].permute(1, 0, 2), r=r).permute(1, 0, 2)
            x = torch.cat([cls, patches], dim=1)
        x = self.vision_model.norm(x)

        vision_embed = self.vision_proj(x[:, 0, :])
        return x, vision_embed

    def get_text_features(self, input_ids, attention_mask):
        """
        Encode text without gradients.

        :return: ``(last_hidden_state, text_embed)``, where ``text_embed`` is ``text_proj`` of token 0
        """
        from types import SimpleNamespace

        # Attribute bag carrying the same two fields the previous ad-hoc `Text` class set.
        text = SimpleNamespace(input_ids=input_ids, attention_mask=attention_mask)
        with torch.no_grad():
            question_embeds = self.text_encoder.forward_text(text)
            last_hidden_state = question_embeds[0]
            text_embed = self.text_proj(last_hidden_state[:, 0, :])

        return last_hidden_state, text_embed


In [None]:
from lavis.models import load_model_and_preprocess
# Load LAVIS's BLIP retrieval model ("flickr" variant) and wrap it with per-block
# DCT token truncation. This rebinds `model`, replacing the HF BlipModel loaded earlier.
# NOTE(review): the DCTLAVISBlip wrapper is a new nn.Module and presumably starts in
# training mode, so get_vision_features truncates regardless of `apply_fourier`
# until `model.eval()` is called. `vis_processors` / `txt_processors` are unused so far.
model, vis_processors, txt_processors = load_model_and_preprocess("blip_retrieval", "flickr", is_eval=False)
model = DCTLAVISBlip(model)