In [1]:
from hyptorch.geoopt import PoincareBall   
from hyptorch.lorentz.manifold import CustomLorentz 
from model.modules.utils import ManifoldMapper
import torch


# Scratch check: map random "vision" and "text" features with the Lorentz
# mapper and compare them pairwise with `lorentz.dist_batch` in both directions.

# Both manifolds are constructed with learnable=True and an initial value of 1.0.
poincare = PoincareBall(c=1.0, learnable=True)
lorentz = CustomLorentz(k=1.0, learnable=True)

# One mapper per manifold, each paired with that manifold's own curvature.
lmapper = ManifoldMapper(lorentz, clip_r=1.25, curv=lorentz.k)
# Fixed: this mapper used to wrap `lorentz` while taking its curvature from
# `poincare`, mixing two different manifolds in one mapper.
pmapper = ManifoldMapper(poincare, clip_r=1.25, curv=poincare.c)

vis_embed = torch.rand(100, 12, 256)
text_embed = torch.rand(100, 256)
# Fixed: `Tensor.T` on a tensor that is not 2-D is deprecated; reverse the three
# axes explicitly instead. The printed shape is unchanged: [256, 12, 100].
print(vis_embed.permute(2, 1, 0).shape)


# Map both feature sets with the Lorentz mapper. The recorded output shows the
# last dimension growing from 256 to 257 after this step.
vis_embed = lmapper(vis_embed)
text_embed = lmapper(text_embed)
lorentz.assert_check_point_on_manifold(vis_embed)
print('text_embed', text_embed.shape)
print('vision_embed', vis_embed.shape)

# NOTE(review): `.squeeze()` without a dim removes every size-1 axis, so the
# axis that `.max(1)` reduces depends on what `dist_batch` returns — confirm.
# NOTE(review): the results are named `sim_*` but come from `dist_batch`;
# verify that taking the max is the intended reduction.
sim_q2t,_ = lorentz.dist_batch(
    vis_embed.unsqueeze(1), 
    text_embed.unsqueeze(-1),
).squeeze().max(1)


print(sim_q2t)
sim_t2q,_ = lorentz.dist_batch(
    text_embed.unsqueeze(1).unsqueeze(1),
    vis_embed.permute(0, 2, 1),
).squeeze().max(1)
print(sim_t2q)


  from .autonotebook import tqdm as notebook_tqdm


torch.Size([256, 12, 100])
text_embed torch.Size([100, 257])
vision_embed torch.Size([100, 12, 257])
tensor([[1.1281, 1.1504, 1.1722,  ..., 1.1488, 1.1289, 1.1030],
        [1.1409, 1.0988, 1.1540,  ..., 1.1626, 1.1323, 1.1301],
        [1.1425, 1.1422, 1.1160,  ..., 1.1393, 1.1579, 1.0931],
        ...,
        [1.1313, 1.1173, 1.1613,  ..., 1.1147, 1.1159, 1.1187],
        [1.1250, 1.1361, 1.1387,  ..., 1.1157, 1.1181, 1.1222],
        [1.1356, 1.1506, 1.1449,  ..., 1.1274, 1.1318, 1.1059]],
       grad_fn=<MaxBackward0>)
tensor([[1.1281, 1.1409, 1.1425,  ..., 1.1313, 1.1250, 1.1356],
        [1.1504, 1.0988, 1.1422,  ..., 1.1173, 1.1361, 1.1506],
        [1.1722, 1.1540, 1.1160,  ..., 1.1613, 1.1387, 1.1449],
        ...,
        [1.1488, 1.1626, 1.1393,  ..., 1.1147, 1.1157, 1.1274],
        [1.1289, 1.1323, 1.1579,  ..., 1.1159, 1.1181, 1.1318],
        [1.1030, 1.1301, 1.0931,  ..., 1.1187, 1.1222, 1.1059]],
       grad_fn=<MaxBackward0>)


  print(vis_embed.T.shape)


In [2]:
import torch.nn.functional as F
def itc_loss(self, image_embeds, text_embeds, image_worlds, text_worlds, sim_i2t_targets, sim_t2i_targets):
    """Weighted two-way contrastive loss between image and text embeddings.

    Scores are produced by ``self.dist_func`` for both directions
    (image -> text and text -> image), divided by ``self.logit_scale``,
    passed through a log-softmax over dim 1 and compared with the supplied
    target distributions.

    Args:
        image_embeds: Image-side queries for the image -> text direction.
        text_embeds: Text-side queries for the text -> image direction.
        image_worlds: Image candidates scored against ``text_embeds``.
        text_worlds: Text candidates scored against ``image_embeds``.
        sim_i2t_targets: Target weights multiplied with the image -> text
            log-probabilities (assumed to match their shape — TODO confirm).
        sim_t2i_targets: Target weights for the text -> image direction.

    Returns:
        ``weight_i2t * loss_i2t + (1 - weight_i2t) * loss_t2i`` where
        ``weight_i2t`` is read from ``self.config``.
    """
    def _soft_target_nll(scores, targets):
        # Negative target-weighted log-likelihood: summed over the last axis,
        # then averaged over whatever axes remain.
        log_probs = F.log_softmax(scores / self.logit_scale, dim=1)
        per_row = torch.sum(log_probs * targets, dim=-1)
        return -per_row.mean()

    # Score both directions first, then turn each score matrix into a loss.
    scores_i2t = self.dist_func(image_embeds, text_worlds)
    scores_t2i = self.dist_func(text_embeds, image_worlds)

    i2t_term = _soft_target_nll(scores_i2t, sim_i2t_targets)
    t2i_term = _soft_target_nll(scores_t2i, sim_t2i_targets)

    # Convex combination of the two directions.
    return self.config.weight_i2t * i2t_term + (1-self.config.weight_i2t) * t2i_term



In [3]:
from lavis.models import load_model_and_preprocess
  
# Load the model and its visual/text processors; called with "blip2", "coco"
# and is_eval=False, and unpacked into three names.
# NOTE(review): the recorded output for this cell is a KeyboardInterrupt, so the
# later cells that read `model` depend on a run that is not captured here.
model, vis_processors, txt_processors = load_model_and_preprocess("blip2", "coco", is_eval=False)

KeyboardInterrupt: 

In [None]:
# Display the model's visual encoder module; the bare last expression renders its repr.
model.visual_encoder

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-38): 39 x Block(
      (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1408, out_features=4224, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1408, out_features=1408, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1408, out_features=6144, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=6144, out_features=1408, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
)

In [None]:
# Display the model's `query_tokens`; the recorded output shows a Parameter
# with requires_grad=True.
model.query_tokens

Parameter containing:
tensor([[[-0.0143,  0.0018,  0.0118,  ..., -0.0255, -0.0029,  0.0191],
         [-0.0155,  0.0102,  0.0205,  ..., -0.0046,  0.0095, -0.0043],
         [-0.0010,  0.0106,  0.0050,  ..., -0.0086,  0.0077,  0.0237],
         ...,
         [ 0.0026,  0.0240,  0.0212,  ...,  0.0037, -0.0483,  0.0041],
         [ 0.0014,  0.0005, -0.0147,  ..., -0.0160, -0.0442,  0.0008],
         [ 0.0140,  0.0054,  0.0262,  ..., -0.0280,  0.0035,  0.0058]]],
       requires_grad=True)

In [None]:
# Ask the visual encoder for its intermediate layer outputs on `pixel_values`.
# NOTE(review): `pixel_values` is not defined in any visible cell, and the
# recorded output is a Blip2Qformer module repr rather than this call's result.
# This cell likely depends on hidden kernel state — confirm it runs on a fresh kernel.
model.visual_encoder.get_intermediate_layers(pixel_values)

Blip2Qformer(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-38): 39 x Block(
        (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1408, out_features=4224, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1408, out_features=1408, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
