In [1]:
!pip install transformers
!pip install openai-clip

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
from transformers import AutoTokenizer, BertForMaskedLM
import clip

In [3]:
class MLP(torch.nn.Module):
  def __init__(self, inp_size, hidden_size):
    super().__init__()
    self.inp_size = inp_size
    self.hidden_size = hidden_size
    self.mlp = torch.nn.Sequential(
        torch.nn.Linear(inp_size, hidden_size),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_size, inp_size)
    )

  def forward(self, x):
    return self.mlp(x)

  def backward(self):
    print('hello mlp')

In [110]:
class GumbelSamplerFunction(torch.autograd.Function):
  @staticmethod
  def forward(ctx, logits):
    temperature = 1.0
    gumbels = -torch.log(-torch.log(torch.rand_like(logits)))
    y = (logits + gumbels) / temperature
    y_soft = torch.softmax(y, dim=-1)
    print('y_soft in forward: ', y_soft, 'of shape: ', y_soft.shape)
    # ctx.save_for_backward(y_soft, temperature)
    ctx.y_soft = y_soft
    ctx.tmp = temperature
    token_id = torch.argmax(y, dim=-1)
    print('token_id in gumbel: ', token_id)
    #####lets extract tokens#########



    ################################
    token_id = token_id[:,1]
    print('token_id in gumbel on mask position: ', token_id)
    # tokens = tokenizer.decode(token_id)
    tokens = tokenizer.convert_ids_to_tokens(token_id)
    output = [input.replace("[MASK]", token) for input, token in zip(inputs, tokens)]
    print('gumbel outpt with mask replaced: ', output)
    tokens = tokenizer(output, return_tensors='pt')['input_ids']
    ctx.idxs = tokens
    print('tokens after decoding: ', tokens)
    txt_tokens = clip.tokenize(output).to(device)
    print('gumbel final tokenized: ', txt_tokens)
    embed = clip_model.token_embedding(txt_tokens).type(clip_model.dtype)
    print('gumbel final token embedded: ', embed)
    print('and its shape: ', embed.shape)
    return embed, txt_tokens

  @staticmethod
  def backward(ctx, grad_output, grad_output_):
    print(grad_output.shape)
    print(grad_output_.shape)
    print('grad_output: ', grad_output[0,:6,:])
    print('grad_output: ', grad_output[1,:6,:])
    print('grad_output: ', grad_output[2,:6,:])
    print('grad_output_ sum: ', grad_output_.sum())
    print('hello backpropagation')
    # y_soft, temp, _ = ctx.saved_tensors
    y_soft = ctx.y_soft
    tmp = ctx.tmp
    idxs = ctx.idxs
    print('y_soft in backward: ', y_soft, 'is of shape: ', y_soft.shape)
    print('temperature in backgraound: ', tmp)
    print('softmax output: ', y_soft, 'with shape of: ', y_soft.shape, 'temperature: ', tmp)
    bb = y_soft.reshape(-1, y_soft.shape[-1])[torch.arange(y_soft.shape[0]*y_soft.shape[1]), idxs.reshape(-1)].reshape(idxs.shape[0], idxs.shape[1])
    c = bb.expand(y_soft.shape[-1], y_soft.shape[0], y_soft.shape[1]).transpose(1,0).transpose(2,1)
    d = y_soft*c
    e = torch.zeros(y_soft.shape[0]*y_soft.shape[1], y_soft.shape[-1])
    e[torch.arange(y_soft.shape[0]*y_soft.shape[1]), idxs.reshape(-1)] = bb.reshape(-1)
    e = e.reshape(y_soft.shape[0], y_soft.shape[1], y_soft.shape[-1]) 
    grad_input_ = e - d
    grad_o = grad_output[:,:4,:]
    grad_o = grad_o.reshape(grad_o.shape[0]*grad_o.shape[1], -1).unsqueeze(1)
    grad_input = grad_input_.reshape(grad_input_.shape[0]*grad_input_.shape[1], -1).unsqueeze(-1)
    print(grad_o.shape, grad_input.shape)
    grad = torch.bmm(grad_input.to(device), grad_o.float()).sum(dim=-1).reshape(grad_input_.shape[0], grad_input_.shape[1], grad_input_.shape[-1])
    print('final gradient: ', grad, 'is of shape: ', grad.shape)
    return grad

In [None]:
a = torch.rand((3, 2, 30522))
print(a.shape)
b = a.reshape(-1,30522)
print(b.shape)
b_ = b.unsqueeze(dim=-1)
b__ = b.unsqueeze(dim=-2)
print(b_.shape)
print(b__.shape)
c = torch.bmm(b_, b__)
print(c.shape)

torch.Size([3, 2, 30522])
torch.Size([6, 30522])
torch.Size([6, 30522, 1])
torch.Size([6, 1, 30522])


In [72]:
class GumbelSampler(torch.nn.Module):
  def __init__(self, tokenizer, temperature):
    super().__init__()
    self.tokenizer = tokenizer
    self.temperature = temperature

  def forward(self, logits, inputs):
    return GumbelSamplerFunction.apply(logits)

In [7]:
class ModelWrapper(torch.nn.Module):
  def __init__(self, model, tokenizer, hidden_size, num_tokens, temperature):
    super().__init__()
    self.model = model
    self.tokenizer = tokenizer
    # self.clip_model = clip_model
    self.hidden_size = hidden_size
    self.num_tokens = num_tokens
    self.temperature = temperature
    self.mlp = MLP(self.model.config.hidden_size, hidden_size)
    self.gumbel = GumbelSampler(self.tokenizer, self.temperature)
    # self.mlps = [MLP(self.model.config.hidden_size, hidden_size) for _ in range(self.num_tokens)]
    # self.gumbels = [GumbelSampler(self.tokenizer, self.temperature) for _ in range(self.num_tokens)]

  def forward(self, inputs):
    tokens = self.tokenizer(inputs, return_tensors='pt')
    print('tokens: ',tokens)
    hidden_states = self.model.bert(**tokens)[0]
    print('hidden_states: ', hidden_states.shape)
    logits = self.mlp(hidden_states)
    print('logits: ',logits.shape)
    logits = self.model.cls(logits)
    print('logits: ',logits.shape)
    clip_embed, txt_tokens = self.gumbel(logits, inputs)
    print('clip_embed: ', clip_embed, clip_embed.shape)
    return clip_embed, txt_tokens
    # for i in range(self.num_tokens):
    #   inputs = ["[MASK] " + input for input in inputs]
    #   print('masked inputs ', inputs)
    #   tokens = self.tokenizer(inputs, return_tensors='pt')
    #   print('tokens: ',tokens)
    #   hidden_states = self.model.bert(**tokens)[0]
    #   print('hidden_states: ', hidden_states.shape)
    #   logits = self.mlps[i](hidden_states)
    #   print('logits: ',logits.shape)
    #   logits = self.model.cls(logits)
    #   print('logits: ', logits.shape)
    #   inputs = self.gumbels[i](logits, inputs)

    # return inputs

  def backward(self):
    print('hello wrapper')

In [8]:
class WholeModel(torch.nn.Module):
  def __init__(self, wrapper, clip_model):
    super().__init__()
    self.wrapper = wrapper
    self.clip_model = clip_model
    self.clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

  def forward(self, inputs):
    # clip_model.encode_text(feed)
    # self.output = self.wrapper(inputs)
    # self.txt_tokens = clip.tokenize(self.output).to(device)
    # self.embed = self.clip_model.token_embedding(self.txt_tokens).type(self.clip_model.dtype)
    # self.embed.retain_grad()
    self.embed, feed = self.wrapper(inputs)
    print('faile here: ', self.embed[0,:,:])

    x = self.embed + self.clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = clip_model.ln_final(x).type(clip_model.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), feed.argmax(dim=-1)] @ clip_model.text_projection
    self.txt_encodings = x / x.norm(dim=-1, keepdim=True)
    # print('inputs: ', inputs)
    # self.output = self.wrapper(inputs)
    # print('outputs: ', self.output)
    # self.tmp = torch.tensor([self.clip_tokenizer.encode(input, add_special_tokens=True, padding='max_length') for input in self.output]).float()
    # self.tmp.requires_grad_(True)
    # self.tmp.retain_grad()
    # self.txt_encodings_ = self.clip_model.encode_text(self.tmp) 
    # self.txt_tokens = clip.tokenize(self.output).to(device)
    # self.txt_tokens.requires_grad = True
    # self.txt_encodings_ = self.clip_model.encode_text(self.txt_tokens)
    # self.txt_encodings_.requires_grad_(True)
    # self.txt_encodings_.retain_grad()
    # self.txt_encodings = self.txt_encodings_ / self.txt_encodings_.norm(dim=-1, keepdim=True)
    # img_features = model.encode_image(imgs.to(device))
    self.img_encodings = torch.randn((30458, 512))
    # self.txt_encodings.requires_grad_(True)
    # self.txt_encodings.retain_grad()
    # self.img_encodings.requires_grad_(True)
    self.similarity = self.txt_encodings @ self.img_encodings.T
    self.similarity.requires_grad_(True)
    self.similarity.retain_grad()
    loss = self.similarity.sum()
    return loss

  def backward(self, grad_output):
    print('hello whole')

In [9]:
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model, processor = clip.load('ViT-B/32', device=device)

100%|████████████████████████████████████████| 338M/338M [00:03<00:00, 101MiB/s]


In [111]:
wrapper = ModelWrapper(model, tokenizer, 256, 1, 1.0)
whole_model = WholeModel(wrapper, clip_model)

In [116]:
global inputs 
inputs = ["[MASK] person", "[MASK] bicycle", "[MASK] car"]
loss = whole_model(inputs)
print(loss)

tokens:  {'input_ids': tensor([[  101,   103,  2711,   102],
        [  101,   103, 10165,   102],
        [  101,   103,  2482,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])}
hidden_states:  torch.Size([3, 4, 768])
logits:  torch.Size([3, 4, 768])
logits:  torch.Size([3, 4, 30522])
y_soft in forward:  tensor([[[2.7031e-07, 5.2699e-07, 1.7985e-07,  ..., 8.1199e-08,
          7.3751e-08, 4.8735e-05],
         [7.0732e-07, 2.7526e-08, 1.9080e-08,  ..., 3.9186e-08,
          3.7138e-08, 1.7693e-05],
         [3.6295e-08, 1.0578e-07, 1.8867e-08,  ..., 2.0783e-08,
          2.2870e-08, 1.4402e-06],
         [6.1302e-07, 1.2641e-07, 6.9573e-08,  ..., 3.2119e-07,
          3.8192e-07, 3.3424e-07]],

        [[4.7906e-08, 1.7311e-08, 3.3159e-08,  ..., 1.7142e-06,
          3.5751e-08, 4.2431e-06],
         [3.0684e-08, 6.1585e-09, 4.4270e-08,  ..., 1.5841e-08,


In [113]:
loss.backward()

torch.Size([3, 77, 512])
torch.Size([3, 77])
grad_output:  tensor([[-42.4949, -44.1479,  47.0883,  ...,   8.6873,  11.8495, -20.6549],
        [-51.8301,   6.8017,  16.2583,  ...,   0.3932, -36.0103,  14.3313],
        [ 85.5589, -84.1684,  19.0699,  ..., -29.8966,  50.9443, -82.3116],
        [ 82.7132, -33.1042, 114.5548,  ...,  90.5907,  33.3180, 136.3680],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]])
grad_output:  tensor([[ -98.4914,   21.9524,   25.6109,  ...,  -21.3790,  102.2291,
           72.0058],
        [  50.8416, -148.2797,  121.4769,  ...,    7.4888,  -88.9292,
           68.4372],
        [ -82.5315,  100.9339,  -62.0025,  ...,   47.8599,  192.2266,
           10.2641],
        [  23.3634,   53.7243,  146.3239,  ...,  -30.6457,  227.4390,
           35.7940],
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000],
        [   0.0

In [117]:
# print(whole_model.similarity.grad.shape)
# print(whole_model.img_encodings.grad.shape)
# print(whole_model.txt_encodings.grad.shape)
# print(whole_model.txt_encodings_.grad.shape)
# print(torch.autograd.grad(loss, whole_model.similarity, retain_graph=True))
# print(torch.autograd.grad(loss, whole_model.txt_encodings, retain_graph=True))
# print(torch.autograd.grad(loss, whole_model.img_encodings, retain_graph=True))
print(torch.autograd.grad(loss, whole_model.embed, retain_graph=True))
# print(whole_model.output.grad.shape)
# print(whole_model.tmp.grad)

(tensor([[[ -41.5390,   75.1322,   53.0962,  ...,   83.7070,  -23.6668,
           -26.9198],
         [   2.8780, -114.4467,   79.9060,  ...,   77.3000,  106.5010,
            -7.5556],
         [  51.0725, -104.9996,  104.5557,  ...,  -46.5451,   40.8801,
             1.9088],
         ...,
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000]],

        [[   4.7402,   -5.2702,  -18.8308,  ...,   28.6637,   65.2339,
            70.0304],
         [ -60.8085,  -35.9729,   -8.6126,  ...,   88.6485,  239.0194,
           207.3170],
         [ 132.9128, -243.3581,   25.1346,  ...,  -25.7362,  307.1031,
            37.7087],
         ...,
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [   0.0000,    0.0000,    0.00