```bash
# train on voc
python scripts/dist_clip_voc.py --config your/path/WeCLIP/configs/voc_attn_reg.yaml
```

Three parameters must be modified to match your local paths in [voc_attn_reg.yaml](WeCLIP/configs/voc_attn_reg.yaml):
```yaml
dataset:
  root_dir: /your/path/VOCdevkit/VOC2012
  name_list_dir: /your/path/WeCLIP/datasets/voc
  ...
clip_init:
  clip_pretrain_path: /your/path/WeCLIP/pretrained/ViT-B-16.pt
```

# [dist_clip_voc.py](WeCLIP/scripts/dist_clip_voc.py)

In [None]:
%%script true

# main steps, logging omitted
def train(cfg):
    """Condensed outline of the WeCLIP VOC training loop (main steps only, logging omitted).

    ``...`` marks arguments elided from the real script, so this cell is a
    reading aid rather than runnable code.

    Args:
        cfg: configuration object. The fields read below are
            ``dataset.crop_size``, ``dataset.ignore_index``,
            ``optimizer.learning_rate``, ``optimizer.weight_decay`` and
            ``train.max_iters``.

    Returns:
        ``True`` once ``cfg.train.max_iters`` iterations have run.
    """

    train_dataset = voc.VOC12ClsDataset(...)
    
    val_dataset = voc.VOC12SegDataset(...)

    train_loader = DataLoader(train_dataset, ...)

    val_loader = DataLoader(val_dataset, ...)

    WeCLIP_model = WeCLIP(cfg)

    # Parameter groups are fetched before the model is moved to the GPU.
    param_groups = WeCLIP_model.get_param_groups()
    WeCLIP_model.cuda()

    # Side length of the token grid: the crop size divided by 16.
    mask_size = int(cfg.dataset.crop_size // 16)
    # NOTE(review): `args` is not a parameter of train(); presumably a
    # module-level argparse namespace in the real script — confirm.
    attn_mask = get_mask_by_radius(h=mask_size, w=mask_size, radius=args.radius)

    optimizer = PolyWarmupAdamW(
        params=[
            {
                "params": param_groups[0],
                "lr": cfg.optimizer.learning_rate,
                "weight_decay": cfg.optimizer.weight_decay,
            },
            ...
        ],
        lr = cfg.optimizer.learning_rate,
        weight_decay = cfg.optimizer.weight_decay,
        ...
    )

    # NOTE(review): the iterator is created once; next() below raises
    # StopIteration if the loader runs out before cfg.train.max_iters
    # iterations — presumably handled in the full script.
    train_loader_iter = iter(train_loader)


    for n_iter in range(cfg.train.max_iters):
        
        img_name, inputs, cls_labels, img_box = next(train_loader_iter)

        # Model returns segmentation logits, CAM labels and the predicted affinity map.
        segs, cam, attn_pred = WeCLIP_model(inputs.cuda(), img_name)

        # The CAM labels returned by the model are used directly as segmentation targets.
        pseudo_label = cam

        # Resize the logits to the spatial size of the pseudo labels.
        segs = F.interpolate(segs, size=pseudo_label.shape[1:], mode='bilinear', align_corners=False)

        # Separate copy of the CAM labels for building the affinity targets.
        fts_cam = cam.clone()

            
        aff_label = cams_to_affinity_label(fts_cam, mask=attn_mask, ignore_index=cfg.dataset.ignore_index)
        
        attn_loss, pos_count, neg_count = get_aff_loss(attn_pred, aff_label)

        seg_loss = get_seg_loss(segs, pseudo_label.type(torch.long), ignore_index=cfg.dataset.ignore_index)

        # Total loss: segmentation loss plus the affinity loss weighted by 0.1.
        loss = 1 * seg_loss + 0.1*attn_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return True


## WeCLIP_model = [WeCLIP(cfg)](WeCLIP/WeCLIP_model/model_attn_aff_voc.py#L60)

In [None]:
%%script True

class WeCLIP(nn.Module):
    """CLIP-based weakly-supervised segmentation model.

    A CLIP image encoder provides per-layer token features and attention
    maps. A ``SegFormerHead`` fuses the per-layer features, a
    ``DecoderTransformer`` turns them into segmentation logits, and GradCAM on
    the encoder plus PAR refinement produces the CAM labels that ``forward``
    returns next to the logits.
    """

    def __init__(self, num_classes=None, clip_model=None, embedding_dim=256, in_channels=512,
                 dataset_root_path=None, device='cuda'):
        """Build encoder, decoder, text classifiers and CAM helpers.

        Args:
            num_classes: number of output channels of the decoder.
            clip_model: model identifier/path handed to ``clip.load``.
            embedding_dim: width of the fused decoder features.
            in_channels: forwarded unchanged to ``SegFormerHead``.
                NOTE(review): ``SegFormerHead`` as excerpted in this notebook
                unpacks ``in_channels`` into four values, so callers
                presumably pass a 4-element sequence rather than the int
                default — confirm.
            dataset_root_path: directory that ``forward`` joins with
                ``<img_name>.png`` to locate the per-image label map.
            device: device handed to ``clip.load``.
        """
        super().__init__()

        ## CLIP Encoder
        self.encoder, _ = clip.load(clip_model, device=device)
        # Freeze every encoder parameter whose name does not contain "11".
        # This is a substring test, so it keeps any parameter with "11" in its
        # name trainable (intended: the last, index-11, transformer block).
        for name, param in self.encoder.named_parameters():
            if "11" not in name:
                param.requires_grad = False

        ## Decoder components
        # SegFormerHead: fuses features from the different transformer layers.
        self.decoder_fts_fuse = SegFormerHead(in_channels, embedding_dim, num_classes, index=11)

        # DecoderTransformer: turns the fused features into segmentation logits.
        self.decoder = DecoderTransformer(width=embedding_dim, layers=3, heads=8, output_dim=num_classes)

        ## Text features
        # Zero-shot classifiers for background and foreground categories.
        self.bg_text_features = zeroshot_classifier(BACKGROUND_CATEGORY, ['a clean origami {}.'], self.encoder)
        self.fg_text_features = zeroshot_classifier(new_class_names, ['a clean origami {}.'], self.encoder)

        self.target_layers = [self.encoder.visual.transformer.resblocks[-1].ln_1]

        # GradCAM produces the initial CAMs. Keyword arguments are required:
        # GradCAM's third positional parameter is `use_cuda`, not
        # `reshape_transform`.
        self.grad_cam = GradCAM(model=self.encoder, target_layers=self.target_layers,
                                reshape_transform=reshape_transform)
        # PAR refines the CAMs.
        self.par = PAR(num_iter=20, dilations=[1,2,4,8,12,24]).cuda()

        # forward() reads both attributes below, so they must exist up front.
        # NOTE(review): forward() expects root_path to be the directory that
        # holds the '<img_name>.png' label maps; the full project may append a
        # sub-directory to dataset_root_path here — confirm.
        self.root_path = dataset_root_path
        self.iter_num = 0

        self.encoder.eval()
        self.cam_bg_thres = 1
        self.require_all_fts = True


    def forward(self, img, img_names='2007_000032', mode='train'):
        """Compute segmentation logits, refined CAM labels and the affinity map.

        Args:
            img: image batch; unpacked as ``(b, c, h, w)``.
            img_names: one image name or a sequence of names, one per batch
                element. A single string is treated as a one-element batch.
            mode: ``'train'`` or ``'val'``; ``'val'`` always enables the
                segmentation-attention based CAM refinement.

        Returns:
            ``(seg, all_cam_labels, attn_pred)``: decoder logits, the stacked
            per-image CAM labels and the sigmoid affinity map of shape
            ``(b, hw, hw)``.
        """
        # A bare string would otherwise be enumerated character by character.
        if isinstance(img_names, str):
            img_names = [img_names]

        cam_list = []
        b, c, h, w = img.shape
        self.encoder.eval()  # re-asserted on every call, in addition to __init__
        self.iter_num += 1

        # ------------------- Initial Feature Extraction -------------------
        fts_all, attn_weight_list = generate_clip_fts(img, self.encoder, require_all_fts=True)
        # (layers, 1 + hw, b, c); token 0 is [CLS]. The encoder's Transformer
        # collects layers - 1 blocks for image inputs (11 for a 12-layer ViT).
        fts_all_stack = torch.stack(fts_all, dim=0)
        # (layers, b, L, L) -> (b, layers, L, L) so it can be indexed per image.
        attn_weight_stack = torch.stack(attn_weight_list, dim=0).permute(1, 0, 2, 3)

        # ------------------- Feature Processing -------------------
        # require_all_fts is fixed to True in __init__, so only the last
        # collected layer feeds GradCAM; the else branch would pass every layer.
        if self.require_all_fts==True:
            cam_fts_all = fts_all_stack[-1].unsqueeze(0).permute(2, 1, 0, 3)  # (b, 1 + hw, 1, c)
        else:
            cam_fts_all = fts_all_stack.permute(2, 1, 0, 3)                   # (b, 1 + hw, layers, c)

        all_img_tokens = fts_all_stack[:, 1:, ...]      # [CLS] not needed for segmentation
        img_tokens_channel = all_img_tokens.size(-1)
        # Move the token axis last so the reshape below can unfold it into the
        # 2-D patch grid: (layers, hw, b, c) -> (layers, b, c, hw).
        all_img_tokens = all_img_tokens.permute(0, 2, 3, 1)
        # (layers, b, c, h/16, w/16); 16 is the patch size of the CLIP encoder.
        all_img_tokens = all_img_tokens.reshape(-1, b, img_tokens_channel, h//16, w //16)

        # ------------------- Decoder Pipeline -------------------
        # SegFormerHead for feature fusion.
        fts = self.decoder_fts_fuse(all_img_tokens)
        attn_fts = fts.clone()

        # DecoderTransformer for generating segmentation predictions.
        seg, seg_attn_weight_list = self.decoder(fts)

        # ------------------- Attention Computation -------------------
        # Pairwise similarity between all spatial positions of the fused features.
        f_b, f_c, f_h, f_w = attn_fts.shape
        attn_fts_flatten = attn_fts.reshape(f_b, f_c, f_h*f_w)  # (b, c, hw)
        attn_pred = attn_fts_flatten.transpose(2, 1).bmm(attn_fts_flatten)  # (b, hw, hw)
        attn_pred = torch.sigmoid(attn_pred)    # squashed to (0, 1)

        # Refinement with the learned affinity starts after 15000 iterations
        # (always on in validation); the flag does not depend on the image.
        require_seg_trans = self.iter_num > 15000 or mode == 'val'

        for i, img_name in enumerate(img_names):
            img_path = os.path.join(self.root_path, str(img_name)+'.png')
            img_i = img[i]  # (3, h, w)

            # ------- 1. Extract CAM features -------
            cam_fts = cam_fts_all[i]                        # (1 + hw, 1, c)
            cam_attn = attn_weight_stack[i]                 # (layers, L, L)
            seg_attn = attn_pred.unsqueeze(0)[:, i, :, :]   # (1, hw, hw)

            # ------- 2. Per-class CAMs, refined by attention -------
            # Distinct names keep the batch-level h / w from being overwritten.
            cam_refined_list, keys, cam_w, cam_h = perform_single_voc_cam(
                img_path, img_i, cam_fts, cam_attn, seg_attn,
                self.bg_text_features, self.fg_text_features,
                self.grad_cam,
                mode=mode,
                require_seg_trans=require_seg_trans)

            cam_dict = generate_cam_label(cam_refined_list, keys, cam_w, cam_h)

            cams = cam_dict['refined_cam'].cuda()

            # ------- 3. Background scoring -------
            bg_score = torch.pow(1 - torch.max(cams, dim=0, keepdims=True)[0], self.cam_bg_thres).cuda()
            cams = torch.cat([bg_score, cams], dim=0).cuda()    # background channel first

            # Shift class indices up by one and prepend 0 for the background.
            valid_key = np.pad(cam_dict['keys'] + 1, (1, 0), mode='constant')
            valid_key = torch.from_numpy(valid_key).cuda()

            # ------- 4. Refine CAMs using PAR -------
            with torch.no_grad():
                cam_labels = _refine_cams(self.par, img_i, cams, valid_key)

            cam_list.append(cam_labels)

        all_cam_labels = torch.stack(cam_list, dim=0)

        return seg, all_cam_labels, attn_pred



### fts_all, attn_weight_list = [generate_clip_fts](WeCLIP/clip/clip_tool.py#L31)(img, self.encoder, require_all_fts=True)

In [None]:
%%script True

def generate_clip_fts(image, model, require_all_fts=True):
    """Run the CLIP image encoder on ``image`` and return features plus attention.

    Args:
        image: a single image ``(c, h, w)`` or a batch ``(n, c, h, w)``; a
            single image gets a leading batch axis.
        model: object exposing ``cuda()`` and
            ``encode_image(image, h, w, require_all_fts=...)``.
        require_all_fts: forwarded to ``encode_image``.

    Returns:
        The ``(image_features_all, attn_weight_list)`` pair produced by
        ``model.encode_image``.
    """
    model = model.cuda()

    # Promote an unbatched image to a batch of one.
    batch = image.unsqueeze(0) if len(image.shape) == 3 else image
    height = batch.shape[-2]
    width = batch.shape[-1]
    batch = batch.cuda()

    # `model` is the encoder returned by clip.load(...) in the caller.
    features, attn_weights = model.encode_image(batch, height, width, require_all_fts=require_all_fts)
    return features, attn_weights

#### image_features_all, attn_weight_list = [model.encode_image](WeCLIP/clip/model.py)(image, h, w, require_all_fts=require_all_fts)

In [None]:
%%script True

class CLIP(nn.Module):
    """CLIP model (excerpt): only the visual-backbone selection and ``encode_image`` are shown.

    ``...`` marks code elided from this excerpt.
    """
    def __init__(..., vision_layers, ...):
        ...
        # A tuple or list for `vision_layers` selects the ResNet backbone;
        # any other value selects the Vision Transformer.
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(...)
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(...)

    def encode_image(self, image, H, W, require_all_fts=False):
        """Encode ``image`` with the visual backbone.

        The image is cast to ``self.dtype`` and passed to ``self.visual``
        together with ``H``, ``W`` and ``require_all_fts``.

        Returns:
            The ``(features, attention)`` pair returned by ``self.visual``.
        """
        # NOTE(review): ModifiedResNet.forward as excerpted in this notebook
        # takes (x, H, W) and returns a single tensor, so this call (extra
        # keyword + two-value unpack) only matches the VisionTransformer
        # backbone — confirm.
        f_x, f_attn = self.visual(image.type(self.dtype), H, W, require_all_fts=require_all_fts)
        # f = self.visual(image.type(self.dtype), H, W, require_all_fts=require_all_fts)
        return f_x, f_attn

##### self.visual = [ModifiedResNet($\cdot$)](WeCLIP/clip/model.py#L114)

In [None]:
%%script True

class ModifiedResNet(nn.Module):
    """
    ResNet backbone (excerpt). Differences from torchvision's ResNet:
    - the stem has 3 convolutions instead of 1 and ends in an average pool rather than a max pool;
    - strided convolutions are anti-aliased: an avgpool is prepended to convolutions with stride > 1;
    - the final pooling layer is a QKV attention instead of an average pool.
    """
    ...

    # NOTE: layers = [Bottleneck(self._inplanes, planes, stride)]

    def forward(self, x, H, W):
        """Run the stem, the four residual stages and the attention pooling.

        ``H`` and ``W`` are handed unchanged to ``self.attnpool``.
        """
        # Match the input dtype to the weights of the first convolution.
        x = x.type(self.conv1.weight.dtype)

        # Stem: three conv -> bn -> relu triples followed by an average pool.
        for suffix in ("1", "2", "3"):
            conv = getattr(self, "conv" + suffix)
            bn = getattr(self, "bn" + suffix)
            relu = getattr(self, "relu" + suffix)
            x = relu(bn(conv(x)))
        x = self.avgpool(x)

        # Residual stages, applied in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Original author's note for this point: (1, 2048, 7, 7).

        return self.attnpool(x, H, W)

In [None]:
%%script True

class Bottleneck(nn.Module):
    """Residual bottleneck block (excerpt; ``...`` marks elided layer definitions)."""

    # Channel multiplier between `planes` and the block's output width.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # Every convolution uses stride 1; when stride > 1 the spatial
        # reduction is done by an avgpool after the second convolution.
        ...
    
        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # Shortcut path: avgpool first, then a stride-1 1x1 convolution
            # and batch norm to match the main path's shape.
            shortcut_layers = OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ])
            self.downsample = nn.Sequential(shortcut_layers)

    def forward(self, x: torch.Tensor):
        """Apply the three-convolution main path and add the (optionally projected) input."""
        main = self.conv1(x)
        main = self.relu1(self.bn1(main))
        main = self.conv2(main)
        main = self.relu2(self.bn2(main))
        main = self.avgpool(main)
        main = self.bn3(self.conv3(main))

        # Use the projection only when one was built for this block.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu3(main)

##### self.visual = [VisionTransformer($ \cdot $)](WeCLIP/clip/model.py#L246)

In [5]:
%%script True

class VisionTransformer(nn.Module):
    """CLIP Vision Transformer, modified to accept arbitrary input sizes.

    The pretrained positional embedding is resized to the patch grid of every
    input, and ``forward`` returns the raw transformer output together with
    the attention weights; the usual ``ln_post`` / ``proj`` tail is not
    applied.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Patch embedding: a convolution whose kernel and stride equal the patch size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        # Kept as attributes although forward() does not use them.
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor, H, W, require_all_fts=False):
        """Embed the image into patch tokens and run the transformer.

        Args:
            x: image batch ``[*, 3, H, W]``.
            H, W: input height and width, used to size the positional embedding.
            require_all_fts: forwarded to the transformer, which then returns
                the outputs of every executed layer instead of only the last.

        Returns:
            ``(x, attn_weight)`` exactly as returned by ``self.transformer``
            (sequence-first layout, no final layer norm or projection).
        """
        # The embedding grid must match conv1's output grid, whose stride is
        # the patch size; a literal 16 here breaks any other patch size.
        grid_size = (H // self.patch_size, W // self.patch_size)
        # NOTE(review): upsample_pos_emb returns an nn.Parameter, so this
        # assignment registers `positional_embedding_new` as a module
        # parameter on every call; kept because existing state dicts may
        # already contain that key.
        self.positional_embedding_new = upsample_pos_emb(self.positional_embedding, grid_size)

        x = self.conv1(x)                               # shape = [*, width, grid_h, grid_w]
        x = x.reshape(x.shape[0], x.shape[1], -1)       # shape = [*, width, grid_h * grid_w]
        x = x.permute(0, 2, 1)                          # shape = [*, grid_h * grid_w, width]

        # Broadcast the class embedding to one [CLS] token per batch element.
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)           # shape = [*, grid_h * grid_w + 1, width]
        x = x + self.positional_embedding_new.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attn_weight = self.transformer(x, require_all_fts=require_all_fts)

        # The standard CLIP tail is disabled in this variant:
        #   x = x.permute(1, 0, 2)  # LND -> NLD
        #   x = self.ln_post(x)
        #   if self.proj is not None:
        #       x = x @ self.proj

        return x, attn_weight

In [None]:
%%script True

def upsample_pos_emb(emb, new_size):
    """Resize a pretrained positional embedding to a new patch grid.

    The first row (class token) is kept as is; the remaining ``N`` rows are
    treated as a square ``sqrt(N) x sqrt(N)`` grid and bilinearly resampled to
    ``new_size``.

    Args:
        emb: tensor of shape ``(1 + N, D)`` with ``N`` a perfect square.
        new_size: target ``(height, width)`` of the patch grid.

    Returns:
        A new half-precision ``nn.Parameter`` of shape
        ``(1 + height * width, D)``.

    Raises:
        AssertionError: if ``N`` is not a perfect square.
    """
    cls_emb = emb[:1, :]
    grid_emb = emb[1:, :]
    num_tokens, dim = grid_emb.size(0), grid_emb.size(1)
    side = int(np.sqrt(num_tokens))
    assert side * side == num_tokens

    # (N, D) -> (1, D, side, side) so the grid can be resampled as an image.
    grid = grid_emb.permute(1, 0).view(1, dim, side, side).contiguous()
    # F.upsample is deprecated; interpolate with align_corners=False is the
    # behaviour the old call resolved to, without the deprecation warning.
    grid = F.interpolate(grid, size=new_size, mode='bilinear', align_corners=False)
    # Back to one row per token: (1, D, h, w) -> (h * w, D).
    grid = grid.view(dim, -1).contiguous().permute(1, 0)

    resized = torch.cat([cls_emb, grid], 0)
    return nn.parameter.Parameter(resized.half())


In [None]:
%%script True

class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers that also reports attention weights."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, require_all_fts=False):
        """Run the blocks without gradient tracking.

        Inputs whose first dimension is 77 (the text context length) go
        through every block; all other inputs stop one block early.

        Returns:
            ``(features, attn_weights)``. ``features`` is the list of all
            per-block outputs when ``require_all_fts`` is ``True``, otherwise
            the output of the last executed block. ``attn_weights`` holds one
            entry per executed block.
        """
        collected_weights = []
        collected_outputs = []
        with torch.no_grad():
            is_text_input = x.shape[0] == 77  # 77 context
            depth = self.layers if is_text_input else self.layers - 1
            for index in range(depth):
                block = self.resblocks[index]
                x, block_weights = block(x)
                collected_outputs.append(x)
                collected_weights.append(block_weights)
        # A disabled snippet in the original ran the remaining last block here
        # (range(self.layers-1, self.layers)) and collected its attention weights.
        features = collected_outputs if require_all_fts == True else x
        return features, collected_weights


In [None]:
%%script True

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block that returns its attention weights alongside the output."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        # Project-local attention module (myAtt) instead of nn.MultiheadAttention.
        self.attn = myAtt.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        feed_forward = OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ])
        self.mlp = nn.Sequential(feed_forward)
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attend over ``x``; returns ``(output, weights)`` from ``self.attn``."""
        mask = self.attn_mask
        if mask is not None:
            # Keep the stored mask on the dtype/device of the current input.
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask
        return self.attn(x, x, x, need_weights=True, attn_mask=mask)

    def forward(self, x: torch.Tensor):
        """Apply attention and MLP sub-layers, each with a residual connection."""
        attended, weights = self.attention(self.ln_1(x))  # (L,N,E)  (N,L,L)
        hidden = x + attended
        hidden = hidden + self.mlp(self.ln_2(hidden))
        return hidden, weights


class QuickGELU(nn.Module):
    """Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalise in full precision, then cast back to the caller's dtype.
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Only the docstring is reproduced in this excerpt; the implementation is
    elided. ``ResidualAttentionBlock`` constructs this class as
    ``MultiheadAttention(d_model, n_head)``, calls it as
    ``attn(x, x, x, need_weights=True, attn_mask=...)`` and unpacks the result
    into an ``(output, weights)`` pair.
    """

### fts = self.[decoder_fts_fuse(all_img_tokens)](WeCLIP/WeCLIP_model/segformer_head.py#L49)

In [None]:
%%script True

class SegFormerHead(nn.Module):
    """
    SegFormer-style fusion head (SegFormer: Simple and Efficient Design for
    Semantic Segmentation with Transformers).

    Each of the ``index`` input feature maps goes through its own MLP; the
    results are concatenated along the channel axis and fused by a 1x1
    convolution followed by 2-D dropout.
    """
    def __init__(self, in_channels=128, embedding_dim=256, num_classes=20, index=11, **kwargs):
        """
        Args:
            in_channels: channel width of the input feature maps, either a
                single int or a sequence of widths. Only the first width is
                used, for every per-layer MLP.
            embedding_dim: output width of each MLP and of the fused features.
            num_classes: stored on the instance; not used by this head.
            index: number of input feature maps (one MLP is built per map).
            **kwargs: accepted and ignored.
        """
        super(SegFormerHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        self.indexes = index #6 #11

        # The previous 4-way tuple unpack crashed on the int default, and the
        # other three widths were never used anyway.
        c1_in_channels = self._first_in_channels(in_channels)

        linear_layers = [MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) for _ in range(self.indexes)]
        self.linears_modulelist = nn.ModuleList(linear_layers)

        self.linear_fuse = nn.Conv2d(embedding_dim*self.indexes, embedding_dim, kernel_size=1)
        self.dropout = nn.Dropout2d(0.1)

    @staticmethod
    def _first_in_channels(in_channels):
        """Return the input width to use: the int itself, or the first entry of a sequence."""
        if isinstance(in_channels, int):
            return in_channels
        return in_channels[0]

    def forward(self, x_all):
        """Fuse the stacked per-layer feature maps.

        Args:
            x_all: stacked feature maps indexed as ``x_all[layer]`` with each
                slice shaped ``(n, c, h, w)``.

        Returns:
            Fused features of shape ``(n, embedding_dim, h, w)``.
        """
        projected = []
        for ind in range(x_all.shape[0]):
            x = x_all[ind]
            n, _, h, w = x.shape
            # MLP flattens the map to tokens: (n, h*w, embed). Move channels
            # back in front and restore the spatial grid.
            tokens = self.linears_modulelist[ind](x.float())
            projected.append(tokens.permute(0, 2, 1).reshape(n, -1, h, w))
        fused = self.linear_fuse(torch.cat(projected, dim=1))
        return self.dropout(fused)


In [None]:
%%script True

class MLP(nn.Module):
    """
    Two-layer linear embedding applied per spatial position.
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)
        self.proj_2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Flatten everything after the channel axis into tokens and embed them.

        ``(n, c, h, w)`` input becomes ``(n, h*w, embed_dim)``.
        """
        tokens = x.flatten(2).transpose(1, 2)
        hidden = F.relu(self.proj(tokens))
        return self.proj_2(hidden)
    
    
class Conv_Linear(nn.Module):
    """
    Per-position embedding built from two 1x1 convolutions (spatial layout is kept).
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(input_dim, embed_dim, kernel_size=1)
        self.proj_2 = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)

    def forward(self, x):
        """Map ``(n, input_dim, h, w)`` to ``(n, embed_dim, h, w)``; no flattening is done."""
        hidden = F.relu(self.proj(x))
        return self.proj_2(hidden)

### seg, seg_attn_weight_list = self.[decoder(fts)](WeCLIP/WeCLIP_model/Decoder/TransDecoder.py#L104)

In [None]:
%%script True

class DecoderTransformer(nn.Module):
    """Transformer over the flattened feature map followed by a 1x1 prediction convolution."""

    def __init__(self, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()

        self.transformer = Transformer(width, layers, heads)
        self.linear_pred = nn.Conv2d(width, output_dim, kernel_size=1)


    def forward(self, x: torch.Tensor):
        """Return ``(logit, attn_weights_list)`` for a ``(b, c, h, w)`` feature map."""
        b, c, h, w = x.shape

        # (b, c, h, w) -> (b, c, hw) -> (hw, b, c): sequence-first layout.
        tokens = x.reshape(b, c, h*w).permute(2, 0, 1)

        tokens, attn_weights_list = self.transformer(tokens)  # L,N,D

        # Back to the image layout for the 1x1 prediction convolution.
        feature_map = tokens.permute(1, 2, 0).reshape(b, c, h, w)

        return self.linear_pred(feature_map), attn_weights_list

### cam_refined_list, keys, w, h = [perform_single_voc_cam(...)](WeCLIP/clip/clip_tool.py#L106)

In [None]:
%%script True

def perform_single_voc_cam(img_path, image, image_features, attn_weight_list, seg_attn, bg_text_features,
                       fg_text_features, cam, mode='train', require_seg_trans=False,
                       cam_threshold=0.4, seg_attn_layers=6, mean_attn_layers=8, patch_size=16):
    """Compute attention-refined CAMs for every class present in one image.

    The classes are read from the label map at ``img_path``. For each class a
    GradCAM map is computed with ``cam``, restricted to the bounding boxes of
    its high-response regions and propagated through a transition matrix
    built from attention weights.

    Args:
        img_path: path of the label image for this sample.
        image: image tensor ``(3, h, w)``.
        image_features: CLIP features handed to ``cam`` unchanged.
        attn_weight_list: stacked encoder attention weights for this image.
        seg_attn: affinity predicted by the decoder, ``(1, hw, hw)``.
        bg_text_features: text embeddings of the background categories.
        fg_text_features: text embeddings of the foreground classes, indexed
            by class id.
        cam: callable invoked as ``cam(input_tensor=..., targets=...,
            target_size=None)`` and returning ``(grayscale_cam,
            logits_per_image, attn_weight_last)``.
        mode: ``'train'`` returns the tensor size, anything else the size of
            the label image.
        require_seg_trans: when true, attention layers are selected by their
            difference to ``seg_attn`` and modulated by it; otherwise the last
            layers are simply averaged.
        cam_threshold: score threshold passed to ``scoremap2bbox``.
        seg_attn_layers: number of trailing attention layers used when
            ``require_seg_trans`` is true (must be positive).
        mean_attn_layers: number of trailing attention layers averaged
            otherwise (must be positive).
        patch_size: down-sampling factor between the image and the CAM grid.

    Returns:
        ``(cam_refined_list, keys, width, height)``: one refined CAM of shape
        ``(h // patch_size, w // patch_size)`` per present class, the class
        indices into ``new_class_names``, and the size described under ``mode``.
    """
    ## ---- Input processing ----
    bg_text_features = bg_text_features.cuda()
    fg_text_features = fg_text_features.cuda()
    # Read the label map once and release the file handle right away.
    with Image.open(img_path) as label_image:
        label_map = np.asarray(label_image)
    ori_height, ori_width = label_map.shape[:2]

    ## ---- Label processing ----
    # Shift label ids down by one so they index the foreground classes.
    # Assumes an 8-bit label map: there 0 - 1 wraps to 255, so background (0)
    # and the ignore label (255 -> 254) both end up in the ids removed below
    # — TODO confirm the label dtype.
    label_id_list = (np.unique(label_map) - 1).tolist()
    for ignore_id in (255, 254):
        if ignore_id in label_id_list:
            label_id_list.remove(ignore_id)
    label_id_list = [int(lid) for lid in label_id_list]
    label_list = [new_class_names[lid] for lid in label_id_list]

    image = image.unsqueeze(0)
    h, w = image.shape[-2], image.shape[-1]

    keys = []
    cam_refined_list = []

    ## ---- Feature preparation ----
    # Text classifier rows: present foreground classes first, then background.
    fg_features_temp = fg_text_features[label_id_list].cuda()
    text_features_temp = torch.cat([fg_features_temp, bg_text_features], dim=0)
    input_tensor = [image_features, text_features_temp.cuda(), h, w]

    ## ---- CAM refinement ----
    for idx, label in enumerate(label_list):

        # ---- grayscale CAM from the CLIP features ----
        keys.append(new_class_names.index(label))
        targets = [ClipOutputTarget(label_list.index(label))]
        grayscale_cam, logits_per_image, attn_weight_last = cam(input_tensor=input_tensor,
                                                                targets=targets,
                                                                target_size=None)

        grayscale_cam = grayscale_cam[0, :]
        # The former high-resolution copy (cv2.resize into an unused list) was
        # never read or returned, so it is no longer computed.

        ## ---- Attention-based transition matrix (built once, on the first class) ----
        if idx == 0:
            attn_weight = torch.cat([attn_weight_list, attn_weight_last], dim=0)
            attn_weight = attn_weight[:, 1:, 1:]  # drop the [CLS] row and column
            if require_seg_trans:
                ## ---- keep layers whose summed difference to seg_attn is at most the mean ----
                attn_weight = attn_weight[-seg_attn_layers:]

                attn_diff = seg_attn - attn_weight
                attn_diff = torch.sum(attn_diff.flatten(1), dim=1)
                diff_th = torch.mean(attn_diff)

                attn_mask = torch.zeros_like(attn_diff)
                attn_mask[attn_diff <= diff_th] = 1
                attn_mask = attn_mask.reshape(-1, 1, 1).expand_as(attn_weight)

                # Masked mean over the selected layers; 1e-5 avoids division by zero.
                attn_weight = torch.sum(attn_mask*attn_weight, dim=0) / (torch.sum(attn_mask, dim=0)+1e-5)

                attn_weight = attn_weight.detach()
                attn_weight = attn_weight * seg_attn.squeeze(0).detach()
            else:
                ## ---- plain mean over the last layers ----
                attn_weight = attn_weight[-mean_attn_layers:]
                attn_weight = torch.mean(attn_weight, dim=0)
                attn_weight = attn_weight.detach()
            _trans_mat = compute_trans_mat(attn_weight).float()

        ## ---- bounding boxes of the high-response regions ----
        box, cnt = scoremap2bbox(scoremap=grayscale_cam, threshold=cam_threshold, multi_contour_eval=True)
        aff_mask = torch.zeros((grayscale_cam.shape[0], grayscale_cam.shape[1])).cuda()
        for i_ in range(cnt):
            x0_, y0_, x1_, y1_ = box[i_]
            aff_mask[y0_:y1_, x0_:x1_] = 1

        ## ---- restrict the transition matrix to positions inside the boxes ----
        aff_mask = aff_mask.view(1, grayscale_cam.shape[0] * grayscale_cam.shape[1])
        trans_mat = _trans_mat*aff_mask

        ## ---- propagate the CAM through the masked transition matrix ----
        cam_to_refine = torch.FloatTensor(grayscale_cam).cuda()
        cam_to_refine = cam_to_refine.view(-1, 1)

        cam_refined = torch.matmul(trans_mat, cam_to_refine).reshape(h // patch_size, w // patch_size)
        cam_refined_list.append(cam_refined)

    if mode == 'train':
        return cam_refined_list, keys, w, h
    else:
        return cam_refined_list, keys, ori_width, ori_height


Potential Improvements: 
- Make thresholds configurable
- Consider batch processing support

#### bg_text_features = [zeroshot_classifier](WeCLIP/WeCLIP_model/model_attn_aff_voc.py#L34)(BACKGROUND_CATEGORY, ['a clean origami {}.'], self.encoder)

fg_text_features = zeroshot_classifier(new_class_names, ['a clean origami {}.'], self.encoder)


In [None]:
%%script True

def zeroshot_classifier(classnames, templates, model):
    """Build one unit-norm text embedding per class name.

    Every template is formatted with the class name, tokenized and encoded
    with ``model.encode_text``; the per-template embeddings are normalised,
    averaged and normalised again.

    Returns:
        Tensor with one row per class name (built as a stack along dim 1 and
        then transposed).
    """
    def embed_class(classname):
        prompts = [template.format(classname) for template in templates]  # fill in the class name
        tokens = clip.tokenize(prompts).cuda()
        embeddings = model.encode_text(tokens)
        # Unit length per prompt, so the mean is taken over directions.
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
        mean_embedding = embeddings.mean(dim=0)
        mean_embedding /= mean_embedding.norm()
        return mean_embedding

    with torch.no_grad():
        per_class = [embed_class(classname) for classname in classnames]
        stacked = torch.stack(per_class, dim=1).cuda()
    return stacked.t()

##### texts = [clip.tokenize(texts)](WeCLIP/clip/clip.py#L205).cuda()

In [None]:
%%script True

def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize one string or a list of strings into fixed-length id rows.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string or a list of input strings

    context_length : int
        Row length of the result; all CLIP models use 77

    truncate: bool
        If True, over-long encodings are cut to ``context_length`` and their
        last token is replaced by the end-of-text token; if False they raise

    Returns
    -------
    A two-dimensional, zero-padded tensor of token ids with shape
    [number of input strings, context_length]. The dtype is long for torch
    versions below 1.8.0 (older index_select needs long indices), int otherwise.

    Raises
    ------
    RuntimeError
        If an encoding exceeds ``context_length`` and ``truncate`` is False
    """
    if isinstance(texts, str):
        texts = [texts]

    encoder = _tokenizer.encoder
    sot_token = encoder["<|startoftext|>"]
    eot_token = encoder["<|endoftext|>"]
    # Each row is wrapped in start-of-text / end-of-text markers.
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]

    needs_long = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    row_dtype = torch.long if needs_long else torch.int
    result = torch.zeros(len(all_tokens), context_length, dtype=row_dtype)

    for i, tokens in enumerate(all_tokens):
        too_long = len(tokens) > context_length
        if too_long and not truncate:
            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        if too_long:
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

##### class_embeddings = model.[encode_text(texts)](WeCLIP/clip/model.py#L392)

In [None]:
%%script True

def encode_text(self, text):
    """Encode tokenized text into one projected embedding per sequence.

    Args:
        text: token ids of shape ``[batch_size, n_ctx]``.

    Returns:
        The end-of-text token features, multiplied by ``self.text_projection``.
    """
    dtype = self.dtype
    # [batch_size, n_ctx, d_model], with positional information added.
    tokens = self.token_embedding(text).type(dtype) + self.positional_embedding.type(dtype)

    # The transformer works sequence-first: NLD -> LND and back.
    tokens, _attn_weight = self.transformer(tokens.permute(1, 0, 2))
    tokens = self.ln_final(tokens.permute(1, 0, 2)).type(dtype)

    # tokens.shape = [batch_size, n_ctx, transformer.width]
    # The end-of-text token has the highest id in each sequence, so argmax
    # over the ids gives its position.
    batch_index = torch.arange(tokens.shape[0])
    eot_position = text.argmax(dim=-1)
    return tokens[batch_index, eot_position] @ self.text_projection

x, attn_weight = self.[transformer](WeCLIP/clip/model.py#L218)(x)

x = self.[ln_final](WeCLIP/clip/model.py#L177)(x).type(self.dtype)

In [None]:
%%script True

class LayerNorm(nn.LayerNorm):
    """torch LayerNorm wrapper for fp16 inputs: normalises in float32, returns the input dtype."""

    def forward(self, x: torch.Tensor):
        # Up-cast for the normalisation, then restore the caller's dtype.
        as_float32 = x.type(torch.float32)
        return super().forward(as_float32).type(x.dtype)

#### grayscale_cam, logits_per_image, attn_weight_last = [cam(input_tensor, targets)](WeCLIP/pytorch_grad_cam/base_cam.py#L62)

In [None]:
%%script True

class GradCAM(BaseCAM):
    """Grad-CAM: channel weights are the gradients averaged over axes 2 and 3."""

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None):
        # All four arguments are forwarded positionally to BaseCAM.
        super().__init__(model, target_layers, use_cuda, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Return the mean of ``grads`` over axes 2 and 3; the other arguments are unused."""
        channel_weights = np.mean(grads, axis=(2, 3))
        return channel_weights

In [None]:
%%script True

class BaseCAM:
    """Common driver for class-activation-map (CAM) methods.

    Runs the wrapped model, back-propagates the target score, and turns the
    activations/gradients captured at the target layers into one CAM per
    layer. Subclasses supply the per-channel weighting by overriding
    ``get_cam_weights``; this base implementation of that hook only raises.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 use_cuda: bool = False,
                 reshape_transform: Callable = None,
                 compute_input_gradient: bool = False,
                 uses_gradients: bool = True) -> None:
        # The model is put into eval mode before it is stored.
        self.model = model.eval()
        ...
        # Callable wrapper around the model; compute_cam_per_layer later reads
        # its `.activations` and `.gradients` lists.
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)
    ...
    def forward(self,
                input_tensor: torch.Tensor,
                targets: List[torch.nn.Module],
                target_size,
                eigen_smooth: bool = False) -> tuple:
        """Run the model, back-propagate the targets and build the CAM.

        Args:
            input_tensor: model input; a list is also accepted and handled
                by the ``isinstance(input_tensor, list)`` branches below.
            targets: callables applied to the model outputs to obtain the
                score to back-propagate. ``None`` selects the arg-max
                category of each output.
            target_size: forwarded to ``compute_cam_per_layer`` for rescaling.
            eigen_smooth: forwarded to ``compute_cam_per_layer``.

        Returns:
            A tuple. For list input:
            ``(aggregated_cam, outputs[0], outputs[1])``; otherwise
            ``(aggregated_cam, outputs)``.
        """

        ## ---- Input Processing ----
        if self.cuda:
            input_tensor = input_tensor.cuda()
        if self.compute_input_gradient:
            input_tensor = torch.autograd.Variable(input_tensor,
                                                   requires_grad=True)

        W,H = self.get_target_width_height(input_tensor)
        outputs = self.activations_and_grads(input_tensor,H,W)

        ## ---- Target Management ----
        # If no targets are provided, the highest scoring category of each
        # output is selected as target. Both single inputs and lists of
        # inputs are handled.
        if targets is None:
            if isinstance(input_tensor, list):
                target_categories = np.argmax(outputs[0].cpu().data.numpy(), axis=-1)
            else:
                target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
            targets = [ClassifierOutputTarget(category) for category in target_categories]

        ## ---- Gradient Computation ----
        if self.uses_gradients:
            self.model.zero_grad()
            if isinstance(input_tensor, list):
                loss = sum(target(output[0]) for target, output in zip(targets, outputs))
            else:
                loss = sum(target(output) for target, output in zip(targets, outputs))
            # retain_graph=True keeps the graph alive after this backward pass.
            loss.backward(retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer.
        # Commonly it is the last convolutional layer.
        # Here we support passing a list with multiple target layers.
        # It will compute the saliency image for every image,
        # and then aggregate them (with a default mean aggregation).
        # This gives you more flexibility in case you just want to
        # use all conv layers for example, all Batchnorm layers,
        # or something else.
        cam_per_layer = self.compute_cam_per_layer(input_tensor,
                                                   targets,
                                                   target_size,
                                                   eigen_smooth)
        if isinstance(input_tensor, list):
            return self.aggregate_multi_layers(cam_per_layer), outputs[0], outputs[1]
        else:
            return self.aggregate_multi_layers(cam_per_layer), outputs

    ...
    # cam_per_layer = self.compute_cam_per_layer(...)
    def compute_cam_per_layer(
            self,
            input_tensor: torch.Tensor,
            targets: List[torch.nn.Module],
            target_size,
            eigen_smooth: bool) -> list:
        """Build one rescaled, non-negative float32 CAM per target layer.

        Returns:
            A list with one array per entry of ``self.target_layers``; each
            array has an extra axis inserted at position 1.
        """

        ## ---- Data Preparation ----
        # Captured tensors are moved to CPU and converted to NumPy arrays.
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]

        cam_per_target_layer = []

        ## ---- Loop over the saliency image from every layer ----
        for i, target_layer in enumerate(self.target_layers):
            # A layer with no captured activations/gradients gets None.
            layer_activations = activations_list[i] if i < len(activations_list) else None
            layer_grads = grads_list[i] if i < len(grads_list) else None

            cam = self.get_cam_image(input_tensor,
                                     target_layer,
                                     targets,
                                     layer_activations,
                                     layer_grads,
                                     eigen_smooth)
            # apply ReLU, then force 32-bit precision (float16 -> float32)
            cam = np.maximum(cam, 0).astype(np.float32)
            # scale to target size
            scaled = scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    # cam = self.get_cam_image(...)
    def get_cam_image(self,
                      input_tensor: torch.Tensor,
                      target_layer: torch.nn.Module,
                      targets: List[torch.nn.Module],
                      activations: torch.Tensor,
                      grads: torch.Tensor,
                      eigen_smooth: bool = False) -> np.ndarray:
        """Combine activations with their channel weights into a single CAM.

        ``compute_cam_per_layer`` passes NumPy arrays for ``activations`` and
        ``grads`` despite the ``torch.Tensor`` annotations. The weights are
        broadcast over the two trailing axes of ``activations``; the weighted
        activations are then either summed over axis 1 or, when
        ``eigen_smooth`` is set, reduced with ``get_2d_projection``.
        """

        weights = self.get_cam_weights(input_tensor,
                                       target_layer,
                                       targets,
                                       activations,
                                       grads)
        weighted_activations = weights[:, :, None, None] * activations
        if eigen_smooth:
            cam = get_2d_projection(weighted_activations)
        else:
            cam = weighted_activations.sum(axis=1)
        return cam

    # weights = self.get_cam_weights(...)
    def get_cam_weights(self,
                        input_tensor: torch.Tensor,
                        target_layers: List[torch.nn.Module],
                        targets: List[torch.nn.Module],
                        activations: torch.Tensor,
                        grads: torch.Tensor) -> np.ndarray:

        """ Get a vector of weights for every channel in the target layer.
        Methods that return weights channels,
        will typically need to only implement this function.

        This base version is an abstract hook: it always raises, and concrete
        CAM classes are expected to override it.

        Raises:
            NotImplementedError: always, when not overridden.
        """

        # NotImplementedError is still an Exception subclass, so callers that
        # catch Exception are unaffected.
        raise NotImplementedError("Not Implemented")


### cam_dict = [generate_cam_label(cam_refined_list, keys, w, h)](WeCLIP/clip/clip_tool.py#L202)

In [None]:
%%script true

def generate_cam_label(cam_refined_list, keys, w, h):
    """Rescale every refined CAM to ``(w, h)`` and bundle them with their keys.

    Args:
        cam_refined_list: iterable of tensors; each one is moved to CPU and
            converted to a float32 NumPy array before rescaling.
        keys: data accepted by ``torch.tensor``.
        w, h: passed to ``scale_cam_image`` as the target size ``(w, h)``.

    Returns:
        A dict with ``'keys'`` (NumPy array) and ``'refined_cam'`` (the
        rescaled CAMs stacked along a new leading dimension).
    """
    # scale_cam_image works on a batch, so each CAM is wrapped in a
    # one-element list and the single result is taken back out.
    highres_cams = [
        torch.tensor(
            scale_cam_image([single.cpu().numpy().astype(np.float32)], (w, h))[0]
        )
        for single in cam_refined_list
    ]

    key_tensor = torch.tensor(keys)
    stacked = torch.stack(highres_cams, dim=0)

    return {'keys': key_tensor.numpy(), 'refined_cam': stacked}

## cam_refined_highres = scale_cam_image([cam_refined], (w, h))[0]
# from WeCLIP/pytorch_grad_cam/utils/image.py#L51
def scale_cam_image(cam, target_size=None):
    """Min-max normalise each map in ``cam`` and optionally resize it.

    Args:
        cam: iterable of NumPy arrays, one per image.
        target_size: when not ``None``, every normalised map is resized with
            ``cv2.resize(img, target_size)``.

    Returns:
        ``np.float32`` array built from the list of processed maps.
    """
    def _normalise(single_map):
        # Shift the minimum to 0, then divide by the new maximum; the 1e-7
        # term avoids dividing by zero on a constant map.
        shifted = single_map - np.min(single_map)
        unit = shifted / (1e-7 + np.max(shifted))
        if target_size is None:
            return unit
        return cv2.resize(unit, target_size)

    return np.float32([_normalise(single_map) for single_map in cam])



### cam_labels = [_refine_cams(self.par, img[i], cams, valid_key)](WeCLIP/WeCLIP_model/model_attn_aff_voc.py#L49)

In [None]:
%%script true

def _refine_cams(ref_mod, images, cams, valid_key):
    images = images.unsqueeze(0)
    cams = cams.unsqueeze(0)

    refined_cams = ref_mod(images.float(), cams.float())
    refined_label = refined_cams.argmax(dim=1)
    refined_label = valid_key[refined_label]

    return refined_label.squeeze(0)



#### refined_cams = [ref_mod(images.float(), cams.float())](WeCLIP/WeCLIP_model/PAR.py#L64)

In [None]:
%%script true

class PAR(nn.Module):
    """Refine masks by repeatedly mixing each pixel with its dilated neighbours.

    The mixing weights (affinities) come from two terms: how similar a
    neighbour's image values are to the centre pixel, and a fixed term derived
    from the neighbour's distance value in ``self.pos``.
    """

    def __init__(self, dilations, num_iter,):
        super().__init__()
        self.dilations = dilations
        self.num_iter = num_iter
        self.register_buffer('kernel', get_kernel())
        # Distance values for every neighbour; kept as a plain attribute and
        # moved to the right device inside forward().
        self.pos = self.get_pos()
        # Axis that enumerates the neighbours in the 5-D tensors below.
        self.dim = 2
        # w1 scales both affinity terms; w2 weights the position term.
        self.w1 = 0.3
        self.w2 = 0.01

    ...

    def forward(self, imgs, masks):
        """Return ``masks`` after ``self.num_iter`` affinity-weighted updates.

        Args:
            imgs: 4-D image tensor; resized to the spatial size of ``masks``.
            masks: 4-D mask tensor to refine.
        """
        # Bring the images to the masks' resolution (not the other way round).
        imgs = F.interpolate(imgs, size=masks.size()[-2:], mode="bilinear", align_corners=True)

        b, c, h, w = imgs.shape

        ## ---- Neighbor and Position Processing ----
        # neighbours: (b, c, K, h, w), K = neighbours gathered over all dilations.
        neighbours = self.get_dilated_neighbors(imgs)
        pos = self.pos.to(neighbours.device)

        # Centre pixel repeated K times so it lines up with `neighbours`:
        # (b, c, K, h, w).
        centre = imgs.unsqueeze(self.dim).repeat(1, 1, neighbours.shape[self.dim], 1, 1)
        # pos is (1, 1, K', 1, 1); tiled over batch and space -> (b, 1, K', h, w).
        pos_tiled = pos.repeat(b, 1, 1, h, w)

        ## ---- Affinity Computation ----
        ### --- Appearance-based ---
        abs_diff = torch.abs(neighbours - centre)
        neighbour_std = torch.std(neighbours, dim=self.dim, keepdim=True)
        pos_std = torch.std(pos_tiled, dim=self.dim, keepdim=True)

        # 1e-8 guards against division by zero; it may need adjusting for
        # inputs on a very different scale.
        aff = -(abs_diff / (neighbour_std + 1e-8) / self.w1) ** 2
        # Average the appearance term over the channel axis.
        aff = aff.mean(dim=1, keepdim=True)

        ### --- Position-based ---
        pos_aff = -(pos_tiled / (pos_std + 1e-8) / self.w1) ** 2

        # Softmax over the neighbour axis for each term. w1 and w2 control
        # the balance between appearance and position affinities.
        aff = F.softmax(aff, dim=2) + self.w2 * F.softmax(pos_aff, dim=2)

        ## ---- Iterative Refinement ----
        for _ in range(self.num_iter):
            mask_neighbours = self.get_dilated_neighbors(masks)
            masks = (mask_neighbours * aff).sum(2)

        # TODO: possible improvement - stop early once an update changes the
        # masks by less than a small tolerance (e.g. max abs diff < 1e-6).
        return masks

    # _imgs = self.get_dilated_neighbors(imgs)
    def get_dilated_neighbors(self, x):
        """Gather, for every dilation, the neighbours of each pixel of ``x``.

        ``x`` is replicate-padded by the dilation and convolved with
        ``self.kernel`` at that dilation (presumably one one-hot filter per
        neighbour - confirm against ``get_kernel``). The results for all
        dilations are concatenated along dim 2.

        Returns:
            Tensor of shape ``(b, c, K, h, w)``.
        """
        b, c, h, w = x.shape
        per_dilation = []
        for d in self.dilations:
            padded = F.pad(x, [d] * 4, mode='replicate', value=0)
            # Fold batch and channels together so the same kernel is applied
            # to every channel independently.
            padded = padded.reshape(b * c, -1, padded.shape[-2], padded.shape[-1])
            gathered = F.conv2d(padded, self.kernel, dilation=d)
            per_dilation.append(gathered.view(b, c, -1, h, w))

        return torch.cat(per_dilation, dim=2)

    # _pos = self.pos.to(_imgs.device)
    # self.pos = self.get_pos()
    def get_pos(self):
        """Build the per-neighbour distance tensor used by the position term.

        Returns:
            Tensor of shape ``(1, 1, 8 * len(self.dilations), 1, 1)``.
        """
        # 5-D base kernel with 8 positions:
        # (batch=1, channel=1, positions=8, height=1, width=1).
        ker = torch.ones(1, 1, 8, 1, 1)
        # Positions 0, 2, 5 and 7 get sqrt(2) instead of 1 (presumably the
        # diagonal neighbours in the kernel ordering - confirm against
        # get_kernel).
        for diagonal in (0, 2, 5, 7):
            ker[0, 0, diagonal, 0, 0] = np.sqrt(2)

        # Scale by each dilation and concatenate along dim=2 (position).
        return torch.cat([ker * d for d in self.dilations], dim=2)

## attn_mask = [get_mask_by_radius](WeCLIP/scripts/dist_clip_voc.py#L116)(h=mask_size, w=mask_size, radius=args.radius)

In [None]:
%%script true

def get_mask_by_radius(h=20, w=20, radius=8):
    '''
    creates a square binary matrix of size (h*w, h*w) where each row/column represents a pixel position, and 1's indicate connected pixels within the specified radius.
    '''
    # create square binary, each pixel : position
    hw = h * w
    mask  = np.zeros((hw, hw))

    for i in range(hw):
        _h = i // w
        _w = i % w
        
        # neighbors definition
        _h0 = max(0, _h - radius)
        _h1 = min(h, _h + radius+1)
        _w0 = max(0, _w - radius)
        _w1 = min(w, _w + radius+1)

        for i1 in range(_h0, _h1):
            for i2 in range(_w0, _w1):
                _i2 = i1 * w + i2   # Convert 2D coordinates to linear index
                mask[i, _i2] = 1    # forward connection
                mask[_i2, i] = 1    # backward connection   #NOTE: symmetric matrix

    return mask

## aff_label = [cams_to_affinity_label](WeCLIP/utils/camutils.py#L226)(fts_cam, mask=attn_mask, ignore_index=cfg.dataset.ignore_index)


In [None]:
%%script true

def cams_to_affinity_label(cam_label, mask=None, ignore_index=255):
    """Turn per-pixel labels into pairwise same-class (affinity) labels.

    Args:
        cam_label: label tensor of shape ``(b, h, w)``.
        mask: optional ``(n, n)`` array/tensor with ``n = (h//16) * (w//16)``;
            pairs where it equals 0 are set to ``ignore_index``.
        ignore_index: value marking pairs that should be ignored.

    Returns:
        Long tensor of shape ``(b, n, n)``: 1 where the two positions share a
        label, 0 where they differ, ``ignore_index`` where the pair is masked
        out or either position itself carries ``ignore_index``.
    """
    batch, height, width = cam_label.shape

    # Nearest-neighbour downsampling by 16 keeps the labels discrete.
    small = F.interpolate(
        cam_label.unsqueeze(1).type(torch.float32),
        size=[height // 16, width // 16],
        mode="nearest",
    )
    flat = small.reshape(batch, 1, -1)  # (b, 1, n)

    # Broadcasting (b, 1, n) against (b, n, 1) compares every pair of positions.
    aff_label = (flat == flat.permute(0, 2, 1)).type(torch.long)

    # The mask is the same for every sample, so compare it to 0 only once.
    outside = None if mask is None else (mask == 0)

    for sample in range(batch):
        if outside is not None:
            aff_label[sample, outside] = ignore_index

        # Positions labelled ignore_index invalidate their whole row and column.
        ignored = flat[sample, 0, :] == ignore_index
        aff_label[sample, :, ignored] = ignore_index
        aff_label[sample, ignored, :] = ignore_index

    return aff_label

## attn_loss, pos_count, neg_count = [get_aff_loss](WeCLIP/utils/losses.py#L11)(attn_pred, aff_label)


In [None]:
%%script true

def get_aff_loss(inputs, targets):
    """Balanced affinity loss over positive (1) and negative (0) targets.

    Positive pairs are penalised by ``1 - inputs`` and negative pairs by
    ``inputs``; each sum is divided by its pair count plus one. Targets with
    any other value contribute to neither term.

    Returns:
        ``(loss, pos_count, neg_count)`` where ``loss`` is the equally
        weighted sum of both terms and the counts already include the +1.
    """
    is_pos = (targets == 1).type(torch.int16)
    is_neg = (targets == 0).type(torch.int16)

    # The +1 keeps both denominators non-zero when a class is absent.
    pos_count = is_pos.sum() + 1
    neg_count = is_neg.sum() + 1

    pos_term = (is_pos * (1 - inputs)).sum() / pos_count
    neg_term = (is_neg * inputs).sum() / neg_count

    return 0.5 * pos_term + 0.5 * neg_term, pos_count, neg_count

## seg_loss = [get_seg_loss](WeCLIP/utils/losses.py#L24)(segs, pseudo_label.type(torch.long), ignore_index=cfg.dataset.ignore_index)

In [None]:
%%script true

def get_seg_loss(pred, label, ignore_index=255):
    """Average of separate background and foreground cross-entropy losses.

    Background is label 0; every other label counts as foreground. Each loss
    sees only its own pixels because the remaining ones are replaced by
    ``ignore_index``. ``label`` itself is left unmodified.

    Returns:
        ``0.5 * (background_loss + foreground_loss)``.
    """
    is_background = label == 0

    ## ---- Background ----
    # Keep only label-0 pixels as targets.
    bg_target = label.masked_fill(~is_background, ignore_index)
    bg_loss = F.cross_entropy(pred, bg_target.type(torch.long), ignore_index=ignore_index)

    ## ---- Foreground ----
    # Keep only the non-zero pixels as targets.
    fg_target = label.masked_fill(is_background, ignore_index)
    fg_loss = F.cross_entropy(pred, fg_target.type(torch.long), ignore_index=ignore_index)

    return (bg_loss + fg_loss) * 0.5

## loss = 1 * seg_loss + 0.1*attn_loss