In [None]:
from types import SimpleNamespace
from models import build_clip, TwoEncoderVLM
from peft import LoraModel, LoraConfig

def count_trainable_parameters(model):
    """Return how many scalar values in ``model`` are trainable.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy; parameters with the
    flag unset are ignored. Returns 0 when nothing is trainable.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Build the CLIP components. build_clip receives its options as an
# attribute-style namespace and returns the two encoders together with their
# matching preprocessing objects (image processor and tokenizer).
vision_model, image_processor, text_model, tokenizer = build_clip(SimpleNamespace(
    # NOTE(review): "B32" does not appear in the list of names in the inline
    # comment below, yet the recorded run resolved it to
    # openai/clip-vit-base-patch32 — confirm which aliases build_clip accepts.
    clip_model_name="B32",      # one of: base32, base, large, huge, giga, meta-large, meta-huge
    mixed_precision="fp16",      # or "fp32"
    cache_dir=".cache"
))

# Combine both encoders (plus tokenizer / image processor) into one module.
# NOTE(review): the meaning of logit_scale, trainable_temp and proj_dim is
# defined by TwoEncoderVLM, which is not visible here — verify against models.py.
model = TwoEncoderVLM(
    vision_model=vision_model,
    text_model=text_model,
    logit_scale=0.01,
    trainable_temp=True,
    proj_dim=512,
    tokenizer=tokenizer,
    image_processor=image_processor
)

# The helper counts only parameters with requires_grad=True, so this figure is
# the trainable-parameter count before LoRA is applied.
print("Base model's number of parameters: ", count_trainable_parameters(model))

# LoRA settings: adapters of rank r=8 are attached to every module whose name
# matches an entry in target_modules (attention projections, MLP layers,
# projection heads and the embedding layers).
config = LoraConfig(
    task_type="FEATURE_EXTRACTION",
    r=8,
    lora_alpha=32,
    lora_dropout=0.01,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2", "text_projection", "visual_projection", "position_embedding", "token_embedding", "patch_embedding"],
)

# Wrap the model with an adapter named "lora_adapter". The recorded
# display(model) output in the following cell shows lora.* layers inside
# `model`, i.e. the wrapping replaces the targeted layers of `model` in place;
# `model` and `model_lora` therefore share the same underlying modules.
model_lora = LoraModel(model, config, "lora_adapter")
# Same trainable-only count, now taken after the LoRA wrapping.
print("LoRA model's number of parameters: ", count_trainable_parameters(model_lora))

  from .autonotebook import tqdm as notebook_tqdm


namespace(clip_model_name='B32', mixed_precision='fp16', cache_dir='.cache')
Building CLIP model: openai/clip-vit-base-patch32
Base model's number of parameters:  151277313
LoRA model's number of parameters:  2671608


In [2]:
# Render the module tree of `model`. The recorded output contains lora.* layers
# (e.g. lora.Conv2d for patch_embedding) because the earlier LoraModel(...) call
# injected its adapters into `model` itself rather than into a copy.
display(model)

TwoEncoderVLM(
  (vision): CLIPVisionModelWithProjection(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): lora.Conv2d(
          (base_layer): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
          (lora_dropout): ModuleDict(
            (lora_adapter): Dropout(p=0.01, inplace=False)
          )
          (lora_A): ModuleDict(
            (lora_adapter): Conv2d(3, 8, kernel_size=(32, 32), stride=(32, 32), bias=False)
          )
          (lora_B): ModuleDict(
            (lora_adapter): Conv2d(8, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
          (lora_magnitude_vector): ModuleDict()
        )
        (position_embedding): lora.Embedding(
          (base_layer): Embedding(50, 768)
          (lora_dropout): ModuleDict(
            (lora_adapter): Dropout(p=0.01, inplace=False)
          )


In [3]:
import sys

# Evict the cached `datasets` package and its `mscoco` submodule so the import
# below re-executes the source currently on disk (poor man's hot reload while
# iterating on the dataset code). pop(..., None) is a no-op when the module has
# not been imported yet, e.g. on a fresh kernel.
# NOTE(review): only these two entries are evicted; any other already-imported
# submodule of the package (the training traceback mentions datasets/macir.py)
# stays cached and will not pick up edits — confirm that is acceptable.
sys.modules.pop('datasets.mscoco', None)
sys.modules.pop('datasets', None)

from datasets.mscoco import MSCOCOCaptions

# Training split: images are passed through the CLIP image processor and
# captions through the CLIP tokenizer created in the model-building cell.
train_dataset = MSCOCOCaptions(
    root="data/mscoco/images/train2017",
    annotations_file="data/mscoco/annotations/captions_train2017.json",
    image_transform=image_processor,
    caption_transform=tokenizer,
)

# Validation split, same transforms.
# NOTE(review): only this split passes resize_dataset=True; its effect is
# defined inside MSCOCOCaptions — confirm the asymmetry with the training
# split is intended.
eval_dataset = MSCOCOCaptions(
    root="data/mscoco/images/val2017",
    annotations_file="data/mscoco/annotations/captions_val2017.json",
    image_transform=image_processor,
    caption_transform=tokenizer,
    resize_dataset=True,
)


In [4]:
import torch

def loss_fn(outputs, inputs, num_items_in_batch, temperature=0.07, **kwargs):
    """Symmetric image<->text contrastive loss over one batch.

    The i-th vision embedding is treated as the positive match of the i-th
    text embedding; every other pairing in the batch acts as a negative.

    Args:
        outputs: Mapping with keys ``'vision_embeds'`` and ``'text_embeds'``,
            each a 2-D tensor with one row per sample. Both must have the
            same number of rows. The embeddings are used exactly as given —
            no normalisation is applied here.
        inputs: Accepted for call-signature compatibility; not used.
        num_items_in_batch: Accepted for call-signature compatibility; not used.
        temperature: Divisor applied to the similarity matrix before the
            softmax (smaller values sharpen the distribution).
        **kwargs: Ignored.

    Returns:
        Scalar tensor: the mean of the image-to-text and text-to-image
        cross-entropy losses.
    """
    vision_embeds = outputs['vision_embeds']
    text_embeds = outputs['text_embeds']
    batch_size = vision_embeds.size(0)

    # Pairwise similarities: row i holds image i scored against every caption.
    logits = (vision_embeds @ text_embeds.t()) / temperature

    # Targets are the diagonal. Create them directly on the embeddings' device
    # instead of building them on the CPU and copying them over afterwards.
    labels = torch.arange(batch_size, device=vision_embeds.device)

    loss_i2t = torch.nn.functional.cross_entropy(logits, labels)
    # Transposing swaps the roles: row j holds caption j scored against images.
    loss_t2i = torch.nn.functional.cross_entropy(logits.t(), labels)
    return (loss_i2t + loss_t2i) / 2

def loss_fn_debug(*args, **kwargs):
    """Tracing wrapper around ``loss_fn``.

    Prints how many positional arguments were received together with the
    keyword-argument dict, then forwards everything unchanged to ``loss_fn``
    and returns its result. ``loss_fn`` is looked up at call time, so the
    wrapper delegates to whatever that name is bound to when it runs.
    """
    positional_count = len(args)
    print("Debug: loss_fn_debug called with args:", positional_count, "and kwargs:", kwargs)
    result = loss_fn(*args, **kwargs)
    return result

In [5]:
from transformers import get_constant_schedule

# Drop any cached copy of the local `train` / `losses` modules so the imports
# below re-execute the source on disk. `sys` is the module imported in the
# dataset-loading cell; pop(..., None) does nothing if the module is not cached.
sys.modules.pop('train', None)
import train

sys.modules.pop("losses", None)
import losses

# NOTE: this rebinds the notebook-level name `loss_fn`, replacing the function
# defined in the earlier loss cell; `loss_fn_debug` resolves `loss_fn` at call
# time and will therefore delegate to this object from now on.
loss_fn = losses.build_loss_fn("ma_bi_sw")

# Short smoke-test run: tiny batch, 4 optimisation steps, logging and
# checkpointing every 2 steps.
train.train(SimpleNamespace(
    # data
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.collate_fn,
    # model and objective
    model=model,
    loss_fn=loss_fn,
    # optimisation
    batch_size=2,
    num_epochs=1,
    lr=1e-4,
    # NOTE(review): a warmup ratio is supplied together with a constant
    # schedule (see `scheduler` below) — confirm how train.train combines them.
    warmup_ratio=0.1,
    # output / logging
    output_dir="checkpoints",
    tqdm=True,
    logging_steps=2,
    save_steps=2,
    save_strategy="steps",
    debug=True,
    max_steps=4,
    scheduler=get_constant_schedule,
))

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /nfs/home/magnanini/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mmarcomag416[0m ([33mmarco-magnanini[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  super().__init__(loader)


Step,Training Loss,Validation Loss,Modality Gap,Alignment,Xsc-sr
0,No log,0.692032,0.94873,1.460938,1.28125


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/nfs/home/magnanini/ma_cir_paper/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 358, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/magnanini/ma_cir_paper/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/nfs/home/magnanini/ma_cir_paper/datasets/macir.py", line 116, in __getitem__
    image = self.preprocess(im, return_tensors='pt')['pixel_values'][0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not callable
