In [18]:
from hyptorch.geoopt import PoincareBall   
from hyptorch.lorentz.manifold import CustomLorentz 
from model.modules.utils import ManifoldMapper
import torch


# Seed the RNG so the random stand-in embeddings (and therefore the printed
# distances) are the same on every Restart & Run All.
torch.manual_seed(0)

# Two manifolds, each constructed with learnable=True.
poincare = PoincareBall(c=1.0, learnable=True)
lorentz = CustomLorentz(k=1.0, learnable=True)

# One mapper per manifold. `pmapper` used to be built on `lorentz` while being
# handed the Poincare curvature; it now receives the manifold that its
# curvature actually belongs to.
lmapper = ManifoldMapper(lorentz, clip_r=1.25, curv=lorentz.k)
pmapper = ManifoldMapper(poincare, clip_r=1.25, curv=poincare.c)

# Random stand-ins for the embeddings: vis_embed is (100, 12, 256),
# text_embed is (100, 256).
vis_embed = torch.rand(100, 12, 256)
text_embed = torch.rand(100, 256)
# `Tensor.T` is deprecated for tensors that are not 2-D, so reverse the
# dimensions explicitly; the printed shape is the same (256, 12, 100).
print(vis_embed.permute(2, 1, 0).shape)


# Map both embeddings with the Lorentz mapper and verify the result with the
# manifold's own check. The recorded output shows the last dimension growing
# from 256 to 257 after mapping.
vis_embed = lmapper(vis_embed)
text_embed = lmapper(text_embed)
lorentz.assert_check_point_on_manifold(vis_embed)
print('text_embed', text_embed.shape)
print('vision_embed', vis_embed.shape)

# Query-to-text scores: batched Lorentz distance, then the max over dim 1.
# NOTE(review): the bare .squeeze() drops *every* size-1 dimension, so the
# following .max(1) would reduce a different axis if any other dimension
# (e.g. the batch) happened to be 1 -- confirm against dist_batch's output shape.
# NOTE(review): the printed values are positive distances; if these are meant
# as similarities (as the names suggest), confirm that max (farthest) rather
# than min (closest) is the intended reduction.
sim_q2t, _ = lorentz.dist_batch(
    vis_embed.unsqueeze(1),
    text_embed.unsqueeze(-1),
).squeeze().max(1)


print(sim_q2t)
# Text-to-query scores: same computation with the roles of the operands
# swapped. In the recorded output this matrix is the transpose of sim_q2t.
sim_t2q, _ = lorentz.dist_batch(
    text_embed.unsqueeze(1).unsqueeze(1),
    vis_embed.permute(0, 2, 1),
).squeeze().max(1)
print(sim_t2q)


torch.Size([256, 12, 100])
text_embed torch.Size([100, 257])
vision_embed torch.Size([100, 12, 257])
tensor([[1.1687, 1.1688, 1.1689,  ..., 1.1458, 1.1415, 1.1293],
        [1.1008, 1.1727, 1.1338,  ..., 1.1007, 1.1767, 1.1207],
        [1.1291, 1.1587, 1.1833,  ..., 1.1119, 1.1322, 1.1551],
        ...,
        [1.1232, 1.1703, 1.1322,  ..., 1.1180, 1.1537, 1.1263],
        [1.1470, 1.1715, 1.1758,  ..., 1.1159, 1.1572, 1.1315],
        [1.1334, 1.1306, 1.1603,  ..., 1.1238, 1.1147, 1.1224]],
       grad_fn=<MaxBackward0>)
tensor([[1.1687, 1.1008, 1.1291,  ..., 1.1232, 1.1470, 1.1334],
        [1.1688, 1.1727, 1.1587,  ..., 1.1703, 1.1715, 1.1306],
        [1.1689, 1.1338, 1.1833,  ..., 1.1322, 1.1758, 1.1603],
        ...,
        [1.1458, 1.1007, 1.1119,  ..., 1.1180, 1.1159, 1.1238],
        [1.1415, 1.1767, 1.1322,  ..., 1.1537, 1.1572, 1.1147],
        [1.1293, 1.1207, 1.1551,  ..., 1.1263, 1.1315, 1.1224]],
       grad_fn=<MaxBackward0>)


In [35]:
import torch.nn.functional as F
def itc_loss(self, image_embeds, text_embeds, image_worlds, text_worlds, sim_i2t_targets, sim_t2i_targets):
    """Compute the weighted image-text contrastive loss.

    Scores are obtained from ``self.dist_func`` in both directions
    (``image_embeds`` vs. ``text_worlds`` and ``text_embeds`` vs.
    ``image_worlds``), divided by ``self.logit_scale``, turned into
    log-probabilities with a softmax over dim 1, and combined with the given
    target distributions as a soft-label cross-entropy.

    Args:
        image_embeds: First argument to ``self.dist_func`` for image-to-text.
        text_embeds: First argument to ``self.dist_func`` for text-to-image.
        image_worlds: Second argument to ``self.dist_func`` for text-to-image.
        text_worlds: Second argument to ``self.dist_func`` for image-to-text.
        sim_i2t_targets: Target weights multiplied with the image-to-text
            log-probabilities.
        sim_t2i_targets: Target weights multiplied with the text-to-image
            log-probabilities.

    Returns:
        ``w * loss_i2t + (1 - w) * loss_t2i`` with ``w = self.config.weight_i2t``.
    """

    def _soft_target_ce(scores, targets):
        # Softmax runs over dim 1 while the target-weighted sum runs over the
        # last dim; for 2-D score matrices these are the same axis.
        log_probs = F.log_softmax(scores / self.logit_scale, dim=1)
        weighted = log_probs * targets
        return -(weighted.sum(dim=-1).mean())

    # Both score matrices are computed before either loss, image-to-text first.
    scores_i2t = self.dist_func(image_embeds, text_worlds)
    scores_t2i = self.dist_func(text_embeds, image_worlds)

    i2t_term = _soft_target_ce(scores_i2t, sim_i2t_targets)
    t2i_term = _soft_target_ce(scores_t2i, sim_t2i_targets)

    w_i2t = self.config.weight_i2t
    return w_i2t * i2t_term + (1 - w_i2t) * t2i_term



In [7]:
from lavis.models import load_model_and_preprocess
  
# Load the "blip2" model with the "coco" model type, together with its
# image and text preprocessors. The recorded output below shows position
# embeddings being interpolated from 16x16 to 26x26 during loading.
# NOTE(review): is_eval=False presumably leaves the model in training mode --
# confirm against the LAVIS documentation.
model, vis_processors, txt_processors = load_model_and_preprocess("blip2", "coco", is_eval=False)

Position interpolate from 16x16 to 26x26


In [9]:
# Display the model's vision encoder; the cell output is the module's repr.
model.visual_encoder

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-38): 39 x Block(
      (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1408, out_features=4224, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1408, out_features=1408, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1408, out_features=6144, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=6144, out_features=1408, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
)

In [12]:
# Display the model's query-token parameter (shown as a Parameter with requires_grad=True).
model.query_tokens

Parameter containing:
tensor([[[-0.0143,  0.0018,  0.0118,  ..., -0.0255, -0.0029,  0.0191],
         [-0.0155,  0.0102,  0.0205,  ..., -0.0046,  0.0095, -0.0043],
         [-0.0010,  0.0106,  0.0050,  ..., -0.0086,  0.0077,  0.0237],
         ...,
         [ 0.0026,  0.0240,  0.0212,  ...,  0.0037, -0.0483,  0.0041],
         [ 0.0014,  0.0005, -0.0147,  ..., -0.0160, -0.0442,  0.0008],
         [ 0.0140,  0.0054,  0.0262,  ..., -0.0280,  0.0035,  0.0058]]],
       requires_grad=True)

In [13]:
# NOTE(review): `pixel_values` is not defined in any visible cell, so this
# line presumably relies on leftover kernel state and may fail on a fresh
# kernel -- define it explicitly before this cell.
# NOTE(review): the recorded output below is a Blip2Qformer module repr, which
# does not look like the result of this call; the output may be stale.
model.visual_encoder.get_intermediate_layers(pixel_values)

Blip2Qformer(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-38): 39 x Block(
        (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1408, out_features=4224, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1408, out_features=1408, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
