# Analyze Caption

In [1]:
from typing import List
from torch.utils.data import Dataset
import os.path as osp
import logging
import torch
from utils import read_image
from utils.simple_tokenizer import SimpleTokenizer
from prettytable import PrettyTable
import random
import regex as re
import copy
import torchvision.transforms as T
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package punkt to /home/k64t/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/k64t/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [5]:
tokenizer = SimpleTokenizer()
def tokenize(caption: str, tokenizer, text_length=77, truncate=True) -> torch.LongTensor:
    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    tokens = [sot_token] + tokenizer.encode(caption) + [eot_token]

    result = torch.zeros(text_length, dtype=torch.long)
    if len(tokens) > text_length:
        if truncate:
            tokens = tokens[:text_length]
            tokens[-1] = eot_token
        else:
            raise RuntimeError(
                f"Input {caption} is too long for context length {text_length}"
            )
    result[:len(tokens)] = torch.tensor(tokens)
    return result
def build_random_masked_tokens_and_labels(tokens):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of int, tokenized sentence.
    :return: (list of int, list of int), masked tokens and related labels for MLM prediction
    """
    mask = tokenizer.encoder["<|mask|>"]
    token_range = list(range(1, len(tokenizer.encoder)-3)) # 1 ~ 49405
    
    labels = []
    for i, token in enumerate(tokens):
        if 0 < token < 49405:
            prob = random.random()
            # mask token with 15% probability
            if prob < 0.15:
                prob /= 0.15

                # 80% randomly change token to mask token
                if prob < 0.8:
                    tokens[i] = mask

                # 10% randomly change token to random token
                elif prob < 0.9:
                    tokens[i] = random.choice(token_range)

                # -> rest 10% randomly keep current token

                # append current token to output (we will predict these later)
                labels.append(token)
            else:
                # no masking token (will be ignored by loss function later)
                labels.append(0)
        else:
            labels.append(0)
    
    if all(l == 0 for l in labels):
        # at least mask 1
        labels[1] = tokens[1]
        tokens[1] = mask

    return torch.tensor(tokens), torch.tensor(labels)

def build_random_masked_tokens_and_labels_2(tokens):
    """
    Masking some (nouns, adjective, verb) tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of int, tokenized sentence.
    :return: (list of int, list of int), masked tokens and related labels for MLM prediction
    """
    mask = tokenizer.encoder["<|mask|>"]
    token_range = list(range(1, len(tokenizer.encoder)-3)) # 1 ~ 49405
    selected_categories= ['DT', "NN", "NNS", "JJ", "CD", "PRP"]

    def __is_core_words_(token):
        try:
            if not token in token_range: return False
            word = tokenizer.decode([int(token)])
            tokens = nltk.word_tokenize(word)
            post_tag = nltk.pos_tag(tokens)[0][1]
            return post_tag in selected_categories
        except:
            print("mlm2 got error in nltk --> ",word, " \t pos_tag = ", nltk.pos_tag([word]))
            return True     


    labels = []
    for i, token in enumerate(tokens):
        label_token = int(token)
        if 0 < token < 49405 and __is_core_words_(token):
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15
                

                # 80% randomly change token to mask token
                if prob < 0.8:
                    tokens[i] = mask

                # 10% randomly change token to random token
                elif prob < 0.9:
                    tokens[i] = random.choice(token_range)

                # -> rest 10% randomly keep current token

                # append current token to output (we will predict these later)
                labels.append(label_token)
            else:
                # no masking token (will be ignored by loss function later)
                labels.append(0)
        else:
            labels.append(0)
    
    if all(l == 0 for l in labels):
        # at least mask 1
        labels[1] = tokens[1]
        tokens[1] = mask

    return torch.tensor(tokens), torch.tensor(labels)



In [6]:
caption = "The man is wearing brown framed glasses. He has short dark hair. He is wearing a gray t shirt with a red pocket. He has a backpack with one yellow strap and one blue strap. He has on khaki pants and black shoes with yellow trim. A person wearing brown eyeglasses holds a silver-and-black object in the left hand while wearing a gray t-shirt with a yellow backpack strap over one shoulder and a blue strap over the other shoulder. The person is wearing brown pants ending above the ankle with black and lime-green running shoes."
tokens_1 = tokenize(caption, tokenizer)
tokens_1

tensor([49406,   518,   786,   533,  3309,  2866, 13135,  6116,   269,   797,
          791,  3005,  3144,  2225,   269,   797,   533,  3309,   320,  7048,
          339,  2523,   593,   320,   736,  8504,   269,   797,   791,   320,
        14894,   593,   637,  4481, 13946,   537,   637,  1746, 13946,   269,
          797,   791,   525, 41646,  5003,   537,  1449,  4079,   593,  4481,
        14182,   269,   320,  2533,  3309,  2866,  5034,  6116,  7286,   320,
         3467,   268,   537,   268,  1449, 14115,   530,   518,  1823,  2463,
         1519,  3309,   320,  7048,   339,   268, 49407])

In [28]:
build_random_masked_tokens_and_labels(tokens_1)

  return torch.tensor(tokens), torch.tensor(labels)


(tensor([49406,   518,   786,   533,  3309,  2866, 49405, 49405,   269,   797,
           791,  3005,  3144,  2225,   269,   797,   533,  3309,   320,  7048,
         49405, 49405,   593,   320,   736,  8504,   269,   797, 49405,   320,
         14894,   593,   637,  4481, 13946,   537,   637,  1746, 13946,   269,
           797,   791,   525, 41646,  5003,   537, 24749,  4079,   593,  4481,
         14182,   269,   320,  2533,  3309,  2866, 49405,  6116,  7286,   320,
         21162,   268,   537,   268,  1449, 14115,   530,   518,  1823,  2463,
          1519,  3309,   320,  7048,   339,   268, 49407]),
 tensor([    0,     0,     0,     0,     0,     0, 49405, 49405,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         49405, 49405,     0,     0,     0,     0,     0,     0, 49405,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 24749,     0,     0,

In [36]:
mlm_ids, mlm_label = build_random_masked_tokens_and_labels_2(tokens_1.clone().detach())
mlm_ids, mlm_label

  return torch.tensor(tokens), torch.tensor(labels)


(tensor([49406,   518, 49405,   533,  3309,  2866, 49405, 49405,   269,   797,
           791,  3005,  3144,  2225,   269,   797,   533,  3309,   320, 12867,
         49405, 49405,   593,   320, 49405,  8504,   269,   797, 49405,   320,
         45485,   593,   637,  4481, 13946,   537,   637,  1746, 13946,   269,
           797,   791,   525, 41646,  5003,   537, 24749,  4079,   593, 49405,
         14182,   269,   320,  2533,  3309,  2866, 49405,  6116,  7286,   320,
         21162,   268,   537,   268,  1449, 14115,   530,   518,  1823,  2463,
          1519,  3309,   320,  7048,   339,   268, 49407]),
 tensor([    0,     0,   786,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,  7048,
             0,     0,     0,     0,   736,     0,     0,     0,     0,     0,
         14894,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,

In [14]:
tokenizer = SimpleTokenizer()
vocal_size = 49408
labels = tokens_1.clone()
probability_matrix = torch.full(labels.shape, 0.25)
def mask(input_ids, vocab_size, targets=None, masked_indices=None, probability_matrix=None):
    if masked_indices is None:
        masked_indices = torch.bernoulli(probability_matrix).bool()
    masked_indices[input_ids > 49405] = False
    print("masked indices:")
    print(masked_indices)
    if targets is not None:
        targets[~masked_indices] = -100  # We only compute loss on masked tokens
    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    input_ids[indices_replaced] = -1 #tokenizer.encoder["<|mask|>"]
    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(input_ids.device)
    input_ids[indices_random] = random_words[indices_random]
    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    if targets is not None:
        return input_ids, targets
    else:
        return input_ids
mask(tokens_1, vocab_size=vocal_size, targets=labels, probability_matrix=probability_matrix)

masked indices:
tensor([False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False,  True,
         True, False, False, False, False, False, False,  True, False, False,
         True, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False,  True, False])


(tensor([49406, 32678,   786,   533,    -1,    -1, 13135,    -1,   269,    -1,
            -1,    -1,    -1,  2225, 49405, 49405,    -1,    -1,    -1,    -1,
           339,  2523,    -1, 24598,    -1,    -1,    -1, 20968,   791,   320,
         49405,    -1,    -1,    -1, 13946, 46302,    -1,  1746,    -1,    -1,
            -1,    -1, 22716,    -1,  5003, 49405, 49405,    -1,   593, 35150,
            -1,   269,    -1,    -1,  3309, 26858,  5034,  6116,    -1,    -1,
          3467,   268,    -1, 49405,    -1, 14115,   530,    -1,    -1,  2463,
            -1,  3309,    -1,  7048,   339,   268, 49407]),
 tensor([ -100,  -100,  -100,   533,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,   533,  -100,  -100,    -1,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,   593,  -100,  -100,  -100,  -100,  -100,  -100,  -100,    -1,
         49405,  -100,  -100,  -100,  -100,  -100,  -100, 49405,  -100,

# Triplet

In [13]:
from __future__ import absolute_import

import torch
from torch import nn
import torch.nn.functional as F


def euclidean_dist(x, y):
	m, n = x.size(0), y.size(0)
	xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
	yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
	dist = xx + yy
	dist.addmm_(1, -2, x, y.t())
	dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
	return dist

def cosine_dist(x, y):
	bs1, bs2 = x.size(0), y.size(0)
	frac_up = torch.matmul(x, y.transpose(0, 1))
	frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
	            (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
	cosine = frac_up / frac_down
	return 1-cosine

def _batch_hard(mat_distance, mat_similarity, indice=False, topK=0):
	#topK should be < than  the number instances /id in a batch
	sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True)
	hard_p = sorted_mat_distance[:, :topK]
	hard_p_indice = positive_indices[:, :topK]
	sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1, descending=False)
	hard_n = sorted_mat_distance[:, :topK]
	hard_n_indice = negative_indices[:, :topK]
	if(indice):
		return hard_p, hard_n, hard_p_indice, hard_n_indice
	return hard_p, hard_n



class TopKTripletLoss(nn.Module):

	def __init__(self, margin=0, normalize_feature=False, skip_mean=False, topK=1):
		super(TopKTripletLoss, self).__init__()
		self.margin = margin
		self.normalize_feature = normalize_feature
		self.skip_mean = skip_mean
		self.topk = topK
	def forward(self, emb1, emb2, label):
		if self.normalize_feature:
			# equal to cosine similarity
			emb1 = F.normalize(emb1)
			emb2 = F.normalize(emb2)

		mat_dist = euclidean_dist(emb1, emb2)
		assert mat_dist.size(0) == mat_dist.size(1)
		N = mat_dist.size(0)
		mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()

		dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True, topK=self.topk)

		assert an_idx.size(0)==ap_idx.size(0)
		dist_group_ap = torch.sum((emb1 - torch.mean(emb2[ap_idx], dim=1)) ** 2, dim=1).sqrt()
		dist_group_an = torch.sum((emb1 - torch.mean(emb2[an_idx], dim=1)) ** 2, dim=1).sqrt()

		# triple_dist = torch.stack((dist_ap, dist_an), dim=1)
		triple_dist = torch.stack((dist_group_ap, dist_group_an), dim=1)
		print(triple_dist.shape)
		triple_dist = F.log_softmax(triple_dist, dim=1)

		loss = (- self.margin * triple_dist[:,0] - (1 - self.margin) * triple_dist[:,1])
		if self.skip_mean:
			return loss
		else:
			return loss.mean()

In [81]:
em = torch.rand(32, 2)
label = torch.randint(low=0, high=3, size=(32,))
print(label)
loss = TopKTripletLoss(margin=0, topK=4)
loss(em, em, label)

tensor([2, 0, 2, 0, 1, 1, 0, 2, 2, 0, 2, 1, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0,
        2, 1, 2, 0, 2, 2, 0, 0])
torch.Size([32, 2])


tensor(0.9422)

In [2]:
#from https://github.com/TinyZeaMays/CircleLoss/blob/master/circle_loss.py 

from typing import Tuple

import torch
from torch import nn, Tensor


def convert_label_to_similarity(normed_feature: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
    similarity_matrix = normed_feature @ normed_feature.transpose(1, 0)
    label_matrix = label.unsqueeze(1) == label.unsqueeze(0)

    positive_matrix = label_matrix.triu(diagonal=1)  #set diagnoal from 1 to 0, exclude itself
    negative_matrix = label_matrix.logical_not().triu(diagonal=1) #set diagnoal from 0 to 1, exclude itself

    similarity_matrix = similarity_matrix.view(-1)
    positive_matrix = positive_matrix.view(-1)
    negative_matrix = negative_matrix.view(-1)
    return similarity_matrix[positive_matrix], similarity_matrix[negative_matrix]


class CircleLoss(nn.Module):
    def __init__(self, m: float, gamma: float) -> None:
        super(CircleLoss, self).__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor) -> Tensor:
        ap = torch.clamp_min(- sp.detach() + 1 + self.m, min=0.)
        an = torch.clamp_min(sn.detach() + self.m, min=0.)

        delta_p = 1 - self.m
        delta_n = self.m

        logit_p = - ap * (sp - delta_p) * self.gamma
        logit_n = an * (sn - delta_n) * self.gamma

        loss = self.soft_plus(torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0))

        return loss




In [9]:
feat = nn.functional.normalize(torch.rand(16, 2, requires_grad=True))
label = torch.randint(high=3, size=(16,))

inp_sp, inp_sn = convert_label_to_similarity(feat, label)

criterion = CircleLoss(m=0.25, gamma=256)
circle_loss = criterion(inp_sp, inp_sn)

print(circle_loss)

tensor(423.2520, grad_fn=<SoftplusBackward0>)


In [7]:
inp_sp.shape

torch.Size([37])

In [10]:
similarity_matrix = feat @ feat.transpose(1, 0)
print(similarity_matrix)
label_matrix = label.unsqueeze(1) == label.unsqueeze(0)
print(label_matrix)


tensor([[1.0000, 0.9697, 0.6768, 0.9319, 0.9713, 0.7001, 0.8127, 0.9014, 0.8106,
         0.9989, 0.9152, 0.9129, 0.9979, 0.9953, 0.9996, 0.8397],
        [0.9697, 1.0000, 0.8361, 0.8151, 0.8838, 0.8533, 0.6457, 0.9799, 0.6430,
         0.9800, 0.7891, 0.9850, 0.9518, 0.9415, 0.9621, 0.9469],
        [0.6768, 0.8361, 1.0000, 0.3638, 0.4824, 0.9995, 0.1210, 0.9288, 0.1176,
         0.7104, 0.3228, 0.9184, 0.6275, 0.6024, 0.6550, 0.9681],
        [0.9319, 0.8151, 0.3638, 1.0000, 0.9914, 0.3935, 0.9687, 0.6830, 0.9678,
         0.9140, 0.9990, 0.7027, 0.9535, 0.9627, 0.9422, 0.5856],
        [0.9713, 0.8838, 0.4824, 0.9914, 1.0000, 0.5103, 0.9279, 0.7726, 0.9266,
         0.9592, 0.9848, 0.7897, 0.9847, 0.9898, 0.9779, 0.6865],
        [0.7001, 0.8533, 0.9995, 0.3935, 0.5103, 1.0000, 0.1528, 0.9402, 0.1494,
         0.7326, 0.3530, 0.9306, 0.6522, 0.6277, 0.6789, 0.9756],
        [0.8127, 0.6457, 0.1210, 0.9687, 0.9279, 0.1528, 1.0000, 0.4802, 1.0000,
         0.7846, 0.9786, 0.5040, 0.84

In [11]:
positive_matrix = label_matrix.triu(diagonal=1)
print(positive_matrix.shape)


torch.Size([16, 16])


In [12]:
positive_matrix

tensor([[False, False, False,  True, False, False, False, False, False, False,
          True,  True,  True,  True, False, False],
        [False, False, False, False,  True, False, False, False,  True,  True,
         False, False, False, False,  True, False],
        [False, False, False, False, False,  True,  True,  True, False, False,
         False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False,
          True,  True,  True,  True, False, False],
        [False, False, False, False, False, False, False, False,  True,  True,
         False, False, False, False,  True, False],
        [False, False, False, False, False, False,  True,  True, False, False,
         False, False, False, False, False,  True],
        [False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False,
    

In [None]:
negative_matrix = label_matrix.logical_not().triu(diagonal=1)

In [None]:

similarity_matrix = similarity_matrix.view(-1)
positive_matrix = positive_matrix.view(-1)
negative_matrix = negative_matrix.view(-1)

# MIM

In [1]:
import torch

# Assuming label_array is your 1x4 array
label_array = torch.tensor([1, 2, 1, 3])  # example label array

# Create a 4x4 matrix where each row is filled with the corresponding label
# label_matrix = torch.zeros((4, 4))

# Fill the matrix
# label_matrix[torch.arange(4), label_array] = 1
label_matrix = (torch.arange(4) == label_array[:, None]).float()
print(label_matrix)


tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.]])


In [3]:
label_array.view(-1, 1)

tensor([[1],
        [2],
        [1],
        [3]])

In [7]:
def compute_mim(pred, target, patch_mask):
    """
    pred: [N, L, p*p*3]  L = patches # masks generated by hogloss
    mask: [N, (W/block size) * (H / block size)], 0 is keep, 1 is remove, 
    """
    #compute loss:
    B, N, C = pred.shape
    H = W = int(N**0.5)
    target = target.permute(0, 2, 3, 1) #BxPxPx(Cxbins)
    print("\n\t pred shape :", pred.shape)
    print("\n\t target shape :", target.shape)
    print("\n\t mask shape :", patch_mask.shape)
    
    mask_size = patch_mask.shape[-1]
    if mask_size > W:

        target_size, target_channel = target.shape[2], target.shape[3]
        target = target.flatten(1, 2) #Bx(h*w)x(C*bin)
        patch_mask = torch.repeat_interleave(patch_mask, target_size//mask_size, dim=2)
        patch_mask = torch.repeat_interleave(patch_mask, target_size//mask_size, dim=3)
        pred = pred.reshape(B, H, W, -1, target_size//H, target_size//W).permute(0, 1, 4, 2, 5, 3).reshape(B, target_size**2, target_channel)
    else:
        
        unfold_size = target.shape[-1] // W
        if unfold_size > 0: 
            target = (
                target.unfold(1, unfold_size, unfold_size)
                .unfold(2, unfold_size, unfold_size)
                .flatten(1, 2).flatten(2)
            )
        else: target = target.flatten(1, 2).flatten(2)
    pred.
    print("after change")
    print("\npred shape :", pred.shape)
    print("\ntarget shape: ", target.shape)
    print("\nmask shape : ", patch_mask.shape)
    mim_loss = (pred[patch_mask] - target[patch_mask]) ** 2
    mim_loss = mim_loss.mean()
    return mim_loss

In [None]:
import torch
x = torch.rand((8, 16*12, 108)) #
target = torch.rand(8, 27, 16*2, 12*2)
mask = torch.rand(8, 16*12*4)
compute_mim(x, target, mask)

In [13]:
x.flatten(1).reshape(target.shape).shape

torch.Size([8, 27, 32, 24])

In [17]:
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# MAE: https://github.com/facebookresearch/mae
# UM-MAE: https://github.com/implus/UM-MAE
# --------------------------------------------------------

from functools import partial

import numpy as np
import timm.models
import torch
import torch.nn as nn
from einops import rearrange

from timm.models.vision_transformer import PatchEmbed, Block, DropPath, Mlp

from util.pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 asymmetric_decoder=False, mask_ratio=0.75, vis_mask_ratio=0.,
                 learning_loss=True):
        super().__init__()

        self.vis_mask_ratio = vis_mask_ratio
        if vis_mask_ratio > 0:
            self.vis_mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.learning_loss = learning_loss

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics (projector, predictor, and loss predictor)
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)  # fixed sin-cos embedding

        # reconstructor (e.g., projector at feature-level)
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # self.decoder_pred = nn.Linear(decoder_embed_dim, embed_dim, bias=True)          # dino/ema
        self.decoder_pred = nn.Linear(decoder_embed_dim, decoder_embed_dim, bias=True)  # clip

        # loss predictor
        if self.learning_loss:
            self.decoder_blocks_losspred = nn.ModuleList([
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None,
                      norm_layer=norm_layer)
                for i in range(decoder_depth)])
            self.decoder_norm_losspred = norm_layer(decoder_embed_dim)
            self.decoder_pred_losspred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        if hasattr(self, 'vis_mask_token'):
            torch.nn.init.normal_(self.vis_mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        # x = rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def forward_encoder(self, x, mask):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        N, _, D = x.shape
        x = x[~mask].reshape(N, -1, D)

        if self.vis_mask_ratio > 0:
            vis_mask_token = self.vis_mask_token + self.pos_embed[:, 1:, :]
            vis_mask_token = vis_mask_token.expand(N, -1, -1)
            vis_mask_token = vis_mask_token[~mask].reshape(N, -1, D)
            L = x.size(1)
            noise = torch.rand(N, L, device=x.device)
            ids_restore = torch.argsort(noise, dim=1)

            len_keep = int(L * (1 - self.vis_mask_ratio))
            vis_mask = torch.ones([N, L], device=x.device)
            vis_mask[:, :len_keep] = 0
            vis_mask = torch.gather(vis_mask, dim=1, index=ids_restore).unsqueeze(-1)

            x = x * (1. - vis_mask) + vis_mask_token * vis_mask

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward_decoder(self, x, mask):
        # embed tokens
        x = self.decoder_embed(x)
        x_vis = x[:, 1:, :]
        N, _, D = x_vis.shape

        # append mask tokens to sequence
        expand_pos_embed = self.decoder_pos_embed[:, 1:, :].expand(N, -1, -1)
        pos_vis = expand_pos_embed[~mask].reshape(N, -1, D)
        pos_mask = expand_pos_embed[mask].reshape(N, -1, D)

        x_ = torch.cat([x_vis + pos_vis, self.mask_token + pos_mask], dim=1)

        # add cls_token + decoder_pos_embed
        x = torch.cat([x[:, :1, :] + self.decoder_pos_embed[:, :1, :], x_], dim=1)
        loss_pred = x.clone()

        # apply reconstructor
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x = x[:, 1:, :]

        if self.learning_loss:
            # apply loss predictor
            for blk in self.decoder_blocks_losspred:
                loss_pred = blk(loss_pred)
            loss_pred = self.decoder_norm_losspred(loss_pred)
            loss_pred = self.decoder_pred_losspred(loss_pred)
            loss_pred = loss_pred[:, 1:, :]  # (N, L, 1)

            return x, pos_mask.shape[1], loss_pred.mean(dim=-1)

        return x, pos_mask.shape[1]

    def forward_loss(self, pred, target, mask):
        """
        pred: [N, mask, D]
        target: [N, L, D]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        N, _, D = target.shape
        target = target[mask].reshape(N, -1, D)

        pred = torch.nn.functional.normalize(pred, p=2, dim=-1)
        target = torch.nn.functional.normalize(target, p=2, dim=-1)
        loss = ((pred - target) ** 2).sum(dim=-1)

        return {'mean': loss.mean(), 'matrix': loss}

    def forward(self, imgs, mask):
        latent = self.forward_encoder(imgs, mask)  # returned mask may change

        if self.learning_loss:
            pred, mask_num, loss_pred = self.forward_decoder(latent, mask)  # [N, L, p*p*3]
        else:
            pred, mask_num = self.forward_decoder(latent, mask)
        # loss = self.forward_loss(imgs, pred[:, -mask_num:], mask)
        # return loss, pred, mask
        out = {
            'pix_pred': pred,
            'mask': mask,
            'mask_num': mask_num,
            'features': latent,
        }

        if self.learning_loss:
            out['loss_pred'] = loss_pred

        return out

    def generate_mask(self, loss_pred, mask_ratio=0.75, images=None,  guide=True, epoch=0, total_epoch=200):
        loss_pred = loss_pred.squeeze()
        N, L = loss_pred.shape
        len_keep = int(L * (1 - mask_ratio))

        ids_shuffle_loss = torch.argsort(loss_pred, dim=1)  # (N, L)

        # keep `keep_ratio` loss and `1 - keep_ratio` random
        keep_ratio = 0.2
        ids_shuffle = torch.zeros_like(ids_shuffle_loss, device=loss_pred.device).int()

        if guide:
            keep_ratio = float((epoch + 1) / total_epoch) * 0.5

        ## top 0 -> 0.5
        if int((L - len_keep) * keep_ratio) <= 0:
            # random
            noise = torch.randn(N, L, device=loss_pred.device)
            ids_shuffle = torch.argsort(noise, dim=1)
        else:
            for i in range(N):
                ## mask top `keep_ratio` loss and `1 - keep_ratio` random
                len_loss = int((L - len_keep) * keep_ratio)
                ids_shuffle[i, -len_loss:] = ids_shuffle_loss[i, -len_loss:]

                temp = torch.arange(L, device=loss_pred.device)
                deleted = np.delete(temp.cpu().numpy(), ids_shuffle[i, -len_loss:].cpu().numpy())
                np.random.shuffle(deleted)
                ids_shuffle[i, :(L - len_loss)] = torch.LongTensor(deleted).to(loss_pred.device)

        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # generate mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=loss_pred.device)
        mask[:, :len_keep] = 0
        # unshuffle to get final mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return mask

    def forward_learning_loss(self, loss_pred, mask, loss_target, relative=False):
        """
        loss_pred: [N, L, 1]
        mask: [N, L], 0 is keep, 1 is remove,
        loss_target: [N, L]
        """
        # N, L = loss_target.shape
        # loss_pred = loss_pred[mask].reshape(N, L)
        assert self.learning_loss

        if relative:
            # binary classification for LxL
            labels_positive = loss_target.unsqueeze(1) > loss_target.unsqueeze(2)
            labels_negative = loss_target.unsqueeze(1) < loss_target.unsqueeze(2)
            labels_valid = labels_positive + labels_negative

            loss_matrix = loss_pred.unsqueeze(1) - loss_pred.unsqueeze(2)
            loss = - labels_positive.int() * torch.log(torch.sigmoid(loss_matrix) + 1e-6) \
                   - labels_negative.int() * torch.log(1 - torch.sigmoid(loss_matrix) + 1e-6)

            return loss.sum() / labels_valid.sum()

        else:
            # normalize by each image
            mean = loss_target.mean(dim=1, keepdim=True)
            var = loss_target.var(dim=1, keepdim=True)
            loss_target = (loss_target - mean) / (var + 1.e-6) ** .5  # [N, L, 1]

            loss = (loss_pred - loss_target) ** 2
            loss = loss.mean()
            return loss


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

(torch.Size([8, 768, 27]), torch.Size([8, 768]), torch.Size([8, 27, 32, 24]))

In [None]:
import numpy as np
def generate_mask( loss_pred, mask_ratio=0.75, images=None,  guide=True, epoch=0, total_epoch=200):
    loss_pred = loss_pred.squeeze()
    N, L = loss_pred.shape
    len_keep = int(L * (1 - mask_ratio))

    ids_shuffle_loss = torch.argsort(loss_pred, dim=1)  # (N, L)
    print("ids_shuffle_loss")
    print(ids_shuffle_loss)
    # keep `keep_ratio` loss and `1 - keep_ratio` random
    keep_ratio = 0.2
    ids_shuffle = torch.zeros_like(ids_shuffle_loss, device=loss_pred.device).int()
    print("ids_shuffle")
    print(ids_shuffle)

    if guide:
        keep_ratio = float((epoch + 1) / total_epoch) * 0.5

    ## top 0 -> 0.5
    if False and int((L - len_keep) * keep_ratio) <= 0:
        print("random")
        # random
        noise = torch.randn(N, L, device=loss_pred.device)
        ids_shuffle = torch.argsort(noise, dim=1)
    else:
        for i in range(N):
            ## mask top `keep_ratio` loss and `1 - keep_ratio` random
            len_loss = int((L - len_keep) * keep_ratio)
            print("len losss ==", len_loss, "keep ratio ==", keep_ratio )
            ids_shuffle[i, -len_loss:] = ids_shuffle_loss[i, -len_loss:]

            temp = torch.arange(L, device=loss_pred.device)
            deleted = np.delete(temp.cpu().numpy(), ids_shuffle[i, -len_loss:].cpu().numpy())
            print("delete Tensor")
            print(deleted)
            np.random.shuffle(deleted)
            ids_shuffle[i, :(L - len_loss)] = torch.LongTensor(deleted).to(loss_pred.device)

    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # generate mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=loss_pred.device)
    mask[:, :len_keep] = 0
    # unshuffle to get final mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return mask

pred = torch.rand((1, 4 , 8))
print(pred.shape, pred)
generate_mask(pred, epoch=50)

In [45]:
import torch
loss = torch.rand(2, 16)
loss[0, 10:]=0
loss[1, 8:]=0
print(loss)

tensor([[0.9162, 0.4749, 0.3362, 0.6069, 0.3912, 0.7985, 0.7731, 0.3706, 0.4915,
         0.9046, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4569, 0.0728, 0.8442, 0.7757, 0.5863, 0.6298, 0.5279, 0.8872, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])


In [46]:
L = torch.sum(loss > 0, axis=1)
L

tensor([10,  8])

In [47]:
P = L
mlm_prob=0.6
mmm_mask=0.5
num_masked_paths = torch.round(P * mlm_prob + 0.5).clamp_max(P).long()
num_hard_masked_paths = torch.round(num_masked_paths * mmm_mask +0.5).clamp_max(num_masked_paths).long()
num_masked_paths, num_hard_masked_paths

(tensor([6, 5]), tensor([4, 3]))

In [48]:
idx_tensor = torch.argsort(loss, dim=1, descending=True)
idx_tensor

tensor([[ 0,  9,  5,  6,  3,  8,  1,  4,  7,  2, 10, 11, 12, 13, 14, 15],
        [ 7,  2,  3,  5,  4,  6,  0,  1,  8,  9, 10, 11, 12, 13, 14, 15]])

In [42]:
loss[idx_tensor[:, num_hard_masked_paths:]] * torch.rand(loss[idx_tensor[:, num_hard_masked_paths:]].shape)


TypeError: only integer tensors of a single element can be converted to an index

In [26]:
idx_tensor[:, num_hard_masked_paths:] = idx_tensor[:, num_hard_masked_paths:][:, torch.randperm(P-num_hard_masked_paths)]
idx_tensor

tensor([[6, 3, 4, 0, 1, 2, 5, 7],
        [0, 3, 4, 1, 7, 6, 2, 5]])

In [6]:
import torch

# Your tensors
index_tensor = torch.tensor([[1,2,3,0],[3,0,2,1]])
value_tensor = torch.tensor([[0.1, 0.2, 0.3, 0.4],[0.1, 0.2, 0.3, 0.4]])
P = torch.tensor([2, 3])

# Create a mask tensor
mask = torch.arange(index_tensor.shape[1]).expand(*index_tensor.shape) < P.unsqueeze(1)
print(mask)
# Apply the mask to the index tensor
masked_index = mask * index_tensor + (mask * 1 - 1)
print(masked_index)
# Create a tensor of zeros with the same shape as value_tensor
result = torch.zeros_like(value_tensor)

# Use scatter to place the values from value_tensor into result at the positions specified by masked_index
result.scatter_(1, masked_index, value_tensor)

print(result)


tensor([[ True,  True, False, False],
        [ True,  True,  True, False]])
tensor([[ 1,  2, -1, -1],
        [ 3,  0,  2, -1]])


RuntimeError: index -1 is out of bounds for dimension 1 with size 4

In [8]:
import torch

# Your tensors
index_tensor = torch.tensor([[1,2,3,0],[3,0,2,1]])
value_tensor = torch.tensor([[0.1, 0.2, 0.3, 0.4],[0.1, 0.2, 0.3, 0.4]])
P = torch.tensor([2, 3])

# Create a mask tensor
mask = torch.arange(index_tensor.shape[1]).expand(*index_tensor.shape) < P.unsqueeze(1)

# Apply the mask to the index tensor
masked_index = mask * index_tensor
print(masked_index)
# Create a tensor of zeros with the same shape as value_tensor
result = torch.zeros_like(value_tensor)

# Use scatter to place the values from value_tensor into result at the positions specified by masked_index
result.scatter_(1, masked_index, value_tensor)

# Replace 0 with -1 in the result tensor
result[result == 0] = -1

print(result)


tensor([[1, 2, 0, 0],
        [3, 0, 2, 0]])
tensor([[ 0.4000,  0.1000,  0.2000, -1.0000],
        [ 0.4000, -1.0000,  0.3000,  0.1000]])
