# WeCLIP

In [4]:
import os
# Move to the parent directory to avoid import errors with other directories.
# The directory the kernel started in is remembered on the first run, so that
# re-running this cell does not climb one more level up on every execution.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_START_DIR))

In [5]:
from WeCLIP.WeCLIP_model.Decoder.TransDecoder import DecoderTransformer
from WeCLIP.clip.clip import load # works pretty much like original CLIP implementation, but returns multi-level features and attention maps

import torch.nn.functional as F
from torch import nn
import torch

from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

## get feature maps and weights

In [6]:
def encode_image(self, image, H, W, require_all_fts=False):
    """Run ``self.visual`` on ``image`` and return its two outputs.

    The image is first cast to ``self.dtype``; ``H``, ``W`` and
    ``require_all_fts`` are forwarded unchanged to ``self.visual``.

    Returns:
        The pair produced by ``self.visual`` (named features and attention
        in this notebook).
    """
    pixels = image.type(self.dtype)
    features, attention = self.visual(pixels, H, W, require_all_fts=require_all_fts)
    return features, attention

def upsample_pos_emb(emb, new_size):
    """Resize a pretrained positional embedding to a new patch grid.

    The first row of ``emb`` is kept untouched; the remaining ``N`` rows are
    interpreted as a square ``sqrt(N) x sqrt(N)`` grid and bilinearly
    interpolated to ``new_size``.

    Args:
        emb: tensor of shape ``(1 + N, D)`` where ``N`` is a perfect square.
        new_size: target grid size, forwarded to ``F.interpolate`` as ``size``.

    Returns:
        ``nn.Parameter`` in half precision holding the first row followed by
        the interpolated grid, flattened back to ``(1 + new_rows * new_cols, D)``.

    Raises:
        AssertionError: if ``N`` is not a perfect square.
    """
    first = emb[:1, :]
    grid = emb[1:, :]
    num_positions, dim = grid.size(0), grid.size(1)
    side = int(np.sqrt(num_positions))
    assert side * side == num_positions, "positional embedding grid must be square"

    # (N, D) -> (1, D, side, side) so the grid can be interpolated like an image
    grid = grid.permute(1, 0).reshape(1, dim, side, side).contiguous()
    # F.upsample is deprecated; F.interpolate with align_corners=False is the
    # behaviour it resolved to for bilinear mode, without the warnings.
    grid = F.interpolate(grid, size=new_size, mode='bilinear', align_corners=False)
    # (1, D, h, w) -> (h * w, D)
    grid = grid.reshape(dim, -1).permute(1, 0)

    resized = torch.cat([first, grid], 0)
    return nn.parameter.Parameter(resized.half())

def generate_clip_fts(image, model, require_all_fts=True):
    """Encode ``image`` with ``model.encode_image`` on the GPU.

    Both the model and the image are moved with ``.cuda()``. An image with
    three dimensions gets a leading batch dimension added first. The last two
    dimensions of the image are passed to ``encode_image`` as height and width.

    Returns:
        The two values returned by ``model.encode_image`` (features and
        attention weights, as named by the caller in this notebook).
    """
    model = model.cuda()

    batch = image.unsqueeze(0) if len(image.shape) == 3 else image
    height = batch.shape[-2]
    width = batch.shape[-1]
    batch = batch.cuda()

    features, attn_weights = model.encode_image(
        batch, height, width, require_all_fts=require_all_fts
    )
    return features, attn_weights

In [7]:
# Name of the pretrained CLIP variant passed to `load`.
vit_b_16_clip_pretrained = 'ViT-B/16'
# `load` returns two values, used in this notebook as the encoder model and the
# image preprocessing callable; the model is requested on the GPU.
encoder, preprocess = load(vit_b_16_clip_pretrained, device="cuda")

In [None]:
# NOTE(review): absolute, machine-specific path — replace it with the location of one of your own images.
dog_pic_path = r"C:\01_Learning\01_Data_science\01_University\01_UniTrento\01_Classes\Semester\3\Advanced_CV\Code\dog_pic.jpg" # path to one of your images
# Preprocess the image, add a batch dimension and move it to the GPU.
dog_pic = preprocess(Image.open(dog_pic_path)).unsqueeze(0).to("cuda")

# Patch-grid size: image height and width divided by 16 (the 16 presumably matches the 'ViT-B/16' patch size — TODO confirm).
b, c, h, w = dog_pic.shape
new_size = (h//16,w//16)

In [9]:
# Resize the encoder's positional embedding to the patch grid computed for this image.
positional_embedding_new = upsample_pos_emb(encoder.visual.positional_embedding, new_size)
positional_embedding_new.shape



torch.Size([197, 768])

In [10]:
# Apply the visual encoder's first convolution to the image (cast to the encoder dtype);
# the output below shows a 768-channel 14x14 grid, i.e. one embedding per patch.
x = encoder.visual.conv1(dog_pic.type(encoder.dtype)) # patchify
x.shape

torch.Size([1, 768, 14, 14])

In [11]:
# Collapse the spatial grid into one patch axis, then put that axis in front
# of the channels: (N, Patches, Embedding_dim)
x = x.flatten(start_dim=2).transpose(1, 2)
x.shape

torch.Size([1, 196, 768])

In [12]:
# Prepend the class token to every sequence in the batch.
# Adding the 1-D class embedding to a zero tensor of shape (batch, 1, dim) broadcasts it;
# the values are the same as
#   encoder.visual.class_embedding.to(x.dtype).expand(x.shape[0], 1, -1)
# so no information is lost either way.
cls_zeros = torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
cls_tokens = encoder.visual.class_embedding.to(x.dtype) + cls_zeros
x = torch.cat([cls_tokens, x], dim=1)
x.shape

torch.Size([1, 197, 768])

In [13]:
# Add the resized positional embedding; its (tokens, dim) shape broadcasts over the batch axis.
x = x + positional_embedding_new # in the original implementation encoder.visual.positional_embedding_new
x = encoder.visual.ln_pre(x) # layer norm
x.shape

torch.Size([1, 197, 768])

In [14]:
# Swap the first two axes: (batch, tokens, dim) -> (tokens, batch, dim).
x = x.permute(1,0,2) # needed to pass to encoder.visual.transformer
x.shape

torch.Size([197, 1, 768])

encoder.visual.transformer.resblocks.forward
```
attn_output, attn_weight = self.attention(self.ln_1(x))#(L,N,E)  (N,L,L)
        x = x + attn_output
        x = x + self.mlp(self.ln_2(x)) # linear 768 -> 3072, QuickGELU(), 3072 -> 768
        return x, attn_weight
```
self.attention() (since VisualTransformer self.attn_mask is None):
```
def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 
        return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
```
attn_weights: set of probabilities that indicate the "importance" or "relevance" of each input token to every other token in the sequence (row-wise sum to 1) \
attn_output: feature map from a transformer block

In [15]:
# Manual forward pass through the residual attention blocks of encoder.visual.transformer.
# Counterpart of the single call:
# x_all, attn_weights = encoder.visual(dog_pic.type(encoder.dtype), dog_pic.shape[-2], dog_pic.shape[-1], require_all_fts=False)

visual_transformer = encoder.visual.transformer
# All blocks are run only for sequences of length 77; otherwise the last block is skipped.
layers = visual_transformer.layers if x.shape[0] == 77 else visual_transformer.layers - 1

x_all, attn_weights = [], []
for block_idx in range(layers):
    # every block returns its output tokens together with its attention weights
    x, attn_weight = visual_transformer.resblocks[block_idx](x)
    x_all.append(x)
    attn_weights.append(attn_weight)

In [16]:
# all the previous steps are equivalent to:
# (generate_clip_fts moves model and image to the GPU and calls encoder.encode_image)
image_features_all, attn_weight_list = generate_clip_fts(dog_pic, encoder, require_all_fts=True)

In [17]:
# Both lists hold one entry per executed block: feature tokens and attention weights.
# Their length is the number of residual attention blocks minus one, consistent with the
# manual loop above, which runs only `layers - 1` blocks unless the sequence length is 77.
# Each block: multihead attention + layer norm + MLP + layer norm
print(
    attn_weight_list[0].shape,
    image_features_all[0].shape,
    len(image_features_all),
    len(attn_weight_list),
)

torch.Size([1, 197, 197]) torch.Size([197, 1, 768]) 11 11


In [18]:
# Stack the per-block feature tensors along a new leading axis (one slice per block).
fts_all_stack = torch.stack(image_features_all,dim=0)
fts_all_stack.shape

torch.Size([11, 197, 1, 768])

## decoder

In [19]:
# Drop index 0 along the token axis (axis 1); the remaining 196 entries are the patch tokens.
all_img_tokens = fts_all_stack[:, 1:, ...] # remove the class token
all_img_tokens.shape

torch.Size([11, 196, 1, 768])

In [20]:
# Size of the last axis of all_img_tokens, used below when reshaping back to a grid.
img_tokens_channel = all_img_tokens.size(-1)
img_tokens_channel # get embedding dimension

768

In [21]:
# Reorder axes (blocks, tokens, batch, dim) -> (blocks, batch, dim, tokens) so the token axis comes last.
all_img_tokens = all_img_tokens.permute(0, 2, 3, 1)
all_img_tokens.shape

torch.Size([11, 1, 768, 196])

In [22]:
# Unfold the flat token axis back into the (h // 16, w // 16) patch grid;
# `b`, `h`, `w` are the batch size and image size unpacked from dog_pic.shape.
all_img_tokens = all_img_tokens.reshape(-1, b, img_tokens_channel, h // 16, w // 16)
all_img_tokens.shape # get back patches

torch.Size([11, 1, 768, 14, 14])

### fusion step of image features
- number of trainable parameters of ``SegFormerHead``
    - ``linear``: 2.89 M
    - ``Conv2d``: 0.72 M
    - ``total`` : 3.61 M

In [23]:
class MLP(nn.Module):
    """
    Per-token embedding: Linear -> ReLU -> Linear.

    ``forward`` flattens every dimension after the channel axis into a single
    token axis, moves it in front of the channels and maps each token from
    ``input_dim`` to ``embed_dim`` features.
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super(MLP, self).__init__()
        self.proj = nn.Linear(input_dim, embed_dim)
        self.proj_2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # (n, c, ...) -> (n, tokens, c)
        tokens = x.flatten(2).transpose(1, 2)
        hidden = F.relu(self.proj(tokens))
        return self.proj_2(hidden)

In [24]:
class SegFormerHead(nn.Module):
    """
    Multi-level feature fusion head, after
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

    Each of the ``index`` input feature levels gets its own ``MLP`` projection
    to ``embedding_dim`` channels. The projections are concatenated along the
    channel axis, fused by a 1x1 convolution and passed through 2-D dropout.

    Args:
        in_channels: sequence that must unpack into exactly four channel
            counts; only the first one is used (input width of every MLP).
        embedding_dim: channel count of each projection and of the fused output.
        num_classes: stored on the instance; not used by ``forward``.
        index: number of feature levels (one MLP per level).
        **kwargs: accepted and ignored.
    """
    def __init__(self, in_channels=128, embedding_dim=256, num_classes=20, index=11, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.indexes = index

        # four entries are required, but only the first is consumed
        first_level_channels, _, _, _ = self.in_channels

        self.linears_modulelist = nn.ModuleList(
            [MLP(input_dim=first_level_channels, embed_dim=embedding_dim) for _ in range(self.indexes)]
        )
        self.linear_fuse = nn.Conv2d(embedding_dim * self.indexes, embedding_dim, kernel_size=1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, x_all):
        """Fuse ``x_all`` (levels, n, c, h, w) into one (n, embedding_dim, h, w) map."""
        projected_levels = []
        for level in range(x_all.shape[0]):
            level_fts = x_all[level, :, :, :, :]
            batch, _, height, width = level_fts.shape
            # MLP output is (n, h*w, embedding_dim); put it back on the spatial grid
            tokens = self.linears_modulelist[level](level_fts.float())
            projected_levels.append(tokens.permute(0, 2, 1).reshape(batch, -1, height, width))
        stacked = torch.cat(projected_levels, dim=1)
        return self.dropout(self.linear_fuse(stacked))

In [25]:
# Dimensions used to build the per-level MLPs and the fusion convolution.
all_img_tokens_emb_dim = all_img_tokens.shape[2] # axis 2 of the (levels, batch, channels, h, w) tensor
num_feature_maps = all_img_tokens.shape[0] # axis 0: one feature map per stacked encoder block
output_embedding_dim = 256 # will also be used to define the width of DecoderTransformer

In [26]:
# One independent MLP per encoder feature map, registered in a ModuleList on the GPU.
linear_layers = []
for _ in range(num_feature_maps):
    linear_layers.append(MLP(input_dim=all_img_tokens_emb_dim, embed_dim=output_embedding_dim))
linears_modulelist = nn.ModuleList(linear_layers).to("cuda")
linears_modulelist

ModuleList(
  (0-10): 11 x MLP(
    (proj): Linear(in_features=768, out_features=256, bias=True)
    (proj_2): Linear(in_features=256, out_features=256, bias=True)
  )
)

In [27]:
# merge the output of the MLPs
# 1x1 convolution mapping the concatenated channels (output_embedding_dim * num_feature_maps) to output_embedding_dim
linear_fuse = nn.Conv2d(output_embedding_dim*num_feature_maps, output_embedding_dim, kernel_size=1).to("cuda")
linear_fuse

Conv2d(2816, 256, kernel_size=(1, 1), stride=(1, 1))

In [28]:
# Shape of the feature map taken from the first stacked encoder block.
all_img_tokens[0,...].shape

torch.Size([1, 768, 14, 14])

In [29]:
# performed internally by MLP forward:
# the spatial grid is flattened into one token axis, which is then moved in front of the channels
all_img_tokens[0,...].float().flatten(2).transpose(1,2).shape

torch.Size([1, 196, 768])

In [30]:
# This operation is performed independently for each of the 196 vectors in the sequence
# (the MLP's linear layers act on the last axis only).
linears_modulelist[0](all_img_tokens[0,...].float()).shape

torch.Size([1, 196, 256])

In [31]:
# Project every encoder level with its own MLP and concatenate the results along channels.
# Loop-local names are deliberately distinct from the notebook-level `x`, `h` and `w`
# (transformer tokens and image height/width): overwriting them here would make earlier
# cells that compute `h // 16` and `w // 16` yield a 0x0 grid when re-run.
x_list = []
for level_idx in range(all_img_tokens.shape[0]):
    level_fts = all_img_tokens[level_idx, :, :, :, :]
    n_imgs, _, grid_h, grid_w = level_fts.shape
    # MLP output is (n, grid_h * grid_w, embed_dim); put it back on the spatial grid
    level_proj = linears_modulelist[level_idx](level_fts.float())
    x_list.append(level_proj.permute(0, 2, 1).reshape(n_imgs, -1, grid_h, grid_w))
x_list = torch.cat(x_list, dim=1)
x_list.shape

torch.Size([1, 2816, 14, 14])

In [32]:
# Fuse the concatenated per-level features with the 1x1 convolution defined earlier.
fused_x = linear_fuse(x_list)
fused_x.shape

torch.Size([1, 256, 14, 14])

The previous steps are equivalent to the following:

In [33]:
# Equivalent to the manual steps, wrapped in SegFormerHead
# (same operations; this instance creates its own layers, so the values differ from `fused_x`).
all_img_tokens_emb_dim = all_img_tokens.shape[2]
num_feature_maps = all_img_tokens.shape[0]
num_classes = 20
decorder_fts_fuse = SegFormerHead(
    in_channels=[all_img_tokens_emb_dim] * 4,  # four entries required, only the first is used
    embedding_dim=output_embedding_dim,  # output embedding dimension
    num_classes=num_classes,  # doesn't have any influence on the output here
    index=num_feature_maps,
)
decorder_fts_fuse = decorder_fts_fuse.to("cuda")
# fuse the features from the encoder
fts = decorder_fts_fuse(all_img_tokens)
fts.shape

torch.Size([1, 256, 14, 14])

### decode segmentation

The decoder takes the compressed feature maps, flattens their spatial dimensions to ``(b, output_embedding_dim, h*w)``, feeds them through 3 transformer blocks, reshapes the output back to ``(b, output_embedding_dim, h, w)`` and passes it to a ``Conv2d`` that produces ``[b, num_classes, h, w]``.
- heads: Number of parallel attention heads. each head will have dimension embed_dim // num_heads -> \
``256 // 8 = 32 == decoder_transform.transformer.resblocks[0].attn.head_dim``
- ``decoder_transform.transformer``: transformer layers just like the ones in CLIP
- ``decoder_transform.linear_pred``: convolution maps ``embedding -> num_classes``
- total number of trainable parameters: 2.37 M

In [34]:
# Segmentation decoder, created and then moved to the GPU.
decoder_transform = DecoderTransformer(
    width=output_embedding_dim,  # 256
    layers=3,  # as suggested in the paper, >3 overfits
    heads=8,
    output_dim=num_classes,  # same as for SegFormerHead
)
decoder_transform = decoder_transform.to("cuda")

In [35]:
# The decoder returns two values: the segmentation output and a second value kept as seg_attn_weight_list.
# seg_attn_weight_list not used in the repo 
seg, seg_attn_weight_list = decoder_transform(fts)

In [36]:
# This is the decoder output that gets compared with the pseudo label, after bilinear
# interpolation (upsampling) has aligned the spatial dimensions.
# The upsampling is performed in the training loop, e.g. see WeCLIP/scripts/dist_clip_coco.py line 251
seg.shape 

torch.Size([1, 20, 14, 14])

#### visualize seg patches (not in the repo)

In [37]:
# `seg` is the raw decoder output; its values are presumably not confined to [0, 1]
# (TODO confirm), so multiplying by 255 and casting to uint8 here would wrap around.
# The grid cell below already scales by 255 and casts to uint8, so each class map is
# only min-max normalised to [0, 1] at this point.
seg_maps = seg[0].detach().cpu().float().numpy()
seg_min = seg_maps.min(axis=(1, 2), keepdims=True)
seg_span = seg_maps.max(axis=(1, 2), keepdims=True) - seg_min
# guard against constant maps (zero span)
seg_plots = (seg_maps - seg_min) / np.maximum(seg_span, 1e-8)
seg_plots.shape

(20, 14, 14)

In [None]:
# Grid layout: at most rows * cols tiles, one tile per class map
rows, cols = 5, 5
img_height, img_width = seg_plots.shape[1:3]

# Blank grayscale canvas large enough for the whole grid
grid_image = Image.new('L', (cols * img_width, rows * img_height))  # 'L' mode for grayscale

# Paste the maps row by row; maps beyond the grid capacity are ignored
for idx, img_array in enumerate(seg_plots[:rows * cols]):
    row, col = divmod(idx, cols)
    # expects values in [0, 1]; scaled to 0-255 here
    tile = Image.fromarray((img_array * 255).astype('uint8'))
    grid_image.paste(tile, (col * img_width, row * img_height))

# To open the grid in an external viewer:
# grid_image.show()

### get affinity map

In [39]:
# Work on a copy of the fused features so `fts` itself is not modified.
attn_fts = fts.clone()
attn_fts.shape

torch.Size([1, 256, 14, 14])

In [40]:
# Unpack the four dimensions of the fused feature map (batch, channels, height, width).
f_b, f_c, f_h, f_w = attn_fts.shape

In [41]:
# Flatten the spatial grid into a single axis of f_h * f_w positions.
attn_fts_flatten = attn_fts.reshape(f_b, f_c, f_h*f_w) 
attn_fts_flatten.shape

torch.Size([1, 256, 196])

In [42]:
# batch matrix-matrix product of attn_fts_flatten and its transpose:
# entry (i, j) is the dot product between the feature vectors of spatial positions i and j
attn_pred = attn_fts_flatten.transpose(2, 1).bmm(attn_fts_flatten) 
attn_pred.shape

torch.Size([1, 196, 196])

In [43]:
# Squash the pairwise scores with a sigmoid.
# will be used with attn_weight_list in RFM
attn_pred = torch.sigmoid(attn_pred) 

## CAM computation

In [44]:
# Stack the per-block attention maps, then swap the first two axes so the batch axis comes first.
attn_weight_stack = torch.stack(attn_weight_list, dim=0).permute(1, 0, 2, 3)
attn_weight_stack.shape

torch.Size([1, 11, 197, 197])

In [45]:
require_all_fts = True

# When all features are available, always take only the last layer for the CAM;
# otherwise the stacked features are used as they are.
cam_source = fts_all_stack[-1].unsqueeze(0) if require_all_fts else fts_all_stack
# axes 0 and 2 are swapped: for the last-layer case the result is (batch, tokens, 1, channels)
cam_fts_all = cam_source.permute(2, 1, 0, 3)
cam_fts_all.shape

torch.Size([1, 197, 1, 768])

In [46]:
# CAM features of the first image in the batch.
cam_fts_all[0].shape

torch.Size([197, 1, 768])

In [47]:
# Attention maps of the first image in the batch, one per encoder block.
attn_weight_stack[0].shape

torch.Size([11, 197, 197])

In [None]:
# Affinity map of the first image only, kept with a leading axis of size 1.
seg_attn = attn_pred[0].unsqueeze(0)
seg_attn.shape

torch.Size([1, 196, 196])