# Loading model

In [1]:
# Automatically reload edited Python modules (e.g. the local siglip_paligemma package)
# before each cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into os.environ.
load_dotenv()

# Prefer the GPU when one is visible; every later cell moves tensors to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
from siglip_paligemma.models import SiglipWithPoolingHead
from siglip_paligemma.configs import Config

# Build the model described by the default config and place it on `device`.
config = Config()
model = SiglipWithPoolingHead(config.model_config)
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


model_params=412.442352M, trainable_params=16.391352M


# Loading Dataset

In [None]:
from siglip_paligemma.data import ImageDataLoader

# Build the train / validation / test loaders from the data section of the config.
data_loader = ImageDataLoader(config.data_config)
splits = data_loader.get_dataloaders()
train_loader, val_loader, test_loader = splits
tuple(map(len, (train_loader, val_loader, test_loader)))

total_train_samples=1024569, total_val_samples/1e6=0.256598




(501, 126, 25)

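# Training bottleneck

First time the input pipeline on its own. Then split a full training step into data-loading, forward, and loss+backward time to see where each batch's time goes.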

In [7]:
import time
from itertools import islice

from tqdm import tqdm

# Time only the input pipeline (fetch a batch + copy it to `device`) to see whether
# data loading alone explains the per-step cost. The first batch likely includes
# DataLoader worker start-up, which inflates the average slightly.
num_batches = 10

n_loaded = 0
start = time.time()
# islice stops after exactly `num_batches` batches. An enumerate/break loop pulls one
# extra batch from the loader before breaking, and that fetch ends up in the timing.
for images, labels in tqdm(islice(train_loader, num_batches), total=num_batches):
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    n_loaded += 1
if device.type == "cuda":
    # non_blocking copies may still be in flight; wait for them before stopping the clock.
    torch.cuda.synchronize()
end = time.time()
# Average seconds per batch actually loaded (NaN if the loader produced nothing).
(end - start) / n_loaded if n_loaded else float("nan")

100%|██████████| 10/10 [00:19<00:00,  1.95s/it]


1.9513267755508423

In [6]:
## Identifying Training bottleneck
# Split a training step into data-loading, forward, and loss+backward wall time over a
# few batches, then extrapolate the per-batch average to one epoch.
import time
from itertools import islice

from tqdm import tqdm
from torch.amp import autocast

data_time = 0
forward_time = 0
backward_time = 0


def _sync_device():
    """Wait for queued CUDA work so wall-clock timings are accurate; no-op on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def time_batch(num_batches=10):
    """Time data loading, forward, and loss+backward over up to ``num_batches`` train batches.

    Resets and then fills the module-level ``data_time``, ``forward_time`` and
    ``backward_time`` totals (seconds), so re-running the cell does not double-count.
    The loss computation is counted in ``backward_time``.

    Returns:
        The number of batches actually timed (fewer than ``num_batches`` only if
        ``train_loader`` runs out first).
    """
    global data_time, forward_time, backward_time
    data_time = forward_time = backward_time = 0.0
    model.train()
    n_timed = 0
    _sync_device()
    start = time.time()
    # islice avoids fetching an extra batch just to break out of the loop.
    for images, labels in tqdm(islice(train_loader, num_batches), total=num_batches):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        _sync_device()
        end = time.time()
        data_time += end - start

        with autocast(device_type=device.type):
            out = model(pixel_values=images)
            _sync_device()
            start = time.time()
            forward_time += start - end
            loss = torch.nn.functional.cross_entropy(out, labels)

        loss.backward()
        _sync_device()
        backward_time += time.time() - start
        n_timed += 1

        # The next iteration's data-loading time is measured from here.
        start = time.time()
    return n_timed


n_timed = time_batch()

# Extrapolate to one epoch using the number of batches actually timed (previously a
# hard-coded "/10"). NOTE(review): len(train_loader) may under-count the batches the
# loader really yields (the Trainer output shows 507 iterations vs len()=501).
total_time = (data_time + forward_time + backward_time) / max(n_timed, 1) * len(train_loader)
data_time, forward_time, backward_time, total_time

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [01:04<00:00,  6.44s/it]


(17.54414463043213, 46.2051842212677, 0.6553661823272705, 3226.6752212047577)

# Training

In [None]:
from siglip_paligemma.train import Trainer

# Hand the model and all three loaders to the project Trainer and run training.
trainer = Trainer(
    config,
    device,
    model,
    train_loader,
    val_loader,
    test_loader,
)
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgauti-jha37[0m ([33mgauti-jha37-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training: 507it [37:57,  4.49s/it]                         
Validation: 132it [09:38,  4.38s/it]
Training:  72%|███████▏  | 363/501 [27:22<10:21,  4.50s/it]

In [6]:
# NOTE(review): the "1" suffix lands after the extension ("...safetensors1"); tools that key
# on the `.safetensors` suffix may not recognise this file — consider "siglip_1.safetensors".
checkpoint_path = "siglip.safetensors1"
trainer.save_checkpoint(checkpoint_path)

In [5]:
# Evaluate the trained model on the test split; the Trainer prints test_loss and test_acc.
trainer.evaluate()

28it [02:51,  6.12s/it]                        

Evaluation: test_loss=0.7305192708969116, test_acc=82.654





# PaliGemma

Repeat the experiment with `model_name` set to `google/paligemma-3b-pt-224`.

In [3]:
from siglip_paligemma.models import SiglipWithPoolingHead
from siglip_paligemma.configs import Config

# Same wrapper as before, but with `model_name` pointed at the PaliGemma 3B checkpoint.
PALIGEMMA_MODEL_NAME = "google/paligemma-3b-pt-224"

config = Config()
model_config = config.model_config
model_config.model_name = PALIGEMMA_MODEL_NAME
model = SiglipWithPoolingHead(model_config).to(device)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  6.82it/s]


model_params=412.442352M, trainable_params=16.391352M


In [None]:
from siglip_paligemma.data import ImageDataLoader

# Rebuild the loaders from the fresh config so this section runs without the cells above.
data_loader = ImageDataLoader(config.data_config)
splits = data_loader.get_dataloaders()
train_loader, val_loader, test_loader = splits
tuple(map(len, (train_loader, val_loader, test_loader)))

In [None]:
from siglip_paligemma.train import Trainer

# Train the PaliGemma-initialised model with the same Trainer setup as the SigLIP run.
trainer1 = Trainer(
    config,
    device,
    model,
    train_loader,
    val_loader,
    test_loader,
)
trainer1.train()

In [None]:
# NOTE(review): the "2" suffix lands after the extension ("...safetensors2"); tools that key
# on the `.safetensors` suffix may not recognise this file — consider "siglip_2.safetensors".
checkpoint_path = "siglip.safetensors2"
trainer1.save_checkpoint(checkpoint_path)
trainer1.evaluate()

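-----

# Legacy: manual training loop

NOTE: the cell below uses `pooling_head`, which no earlier cell defines. It depends on kernel state from a deleted cell and will fail on a fresh kernel.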

In [7]:
import wandb
from tqdm import tqdm
from torch.amp import autocast, GradScaler

# Legacy manual loop: train a separate `pooling_head` on the backbone's last hidden state.
# NOTE(review): `pooling_head` is not defined in any cell above, and `config.vision_use_head`
# may not exist on `Config` (other cells only use `config.model_config` / `config.data_config`).
# This cell relies on state from a deleted cell — confirm both before re-running.

num_epochs = 10
lr = 1e-4

# Previously leftover kernel globals; derive them from the loaders so the cell is self-contained.
batch_size = train_loader.batch_size
train_loader_len = len(train_loader)
val_loader_len = len(val_loader)


def _train_one_epoch(optimizer, scheduler, scaler):
    """Run one mixed-precision pass of `pooling_head` over `train_loader`, logging each batch to wandb."""
    pooling_head.train()
    for images, labels in tqdm(train_loader, total=train_loader_len):
        labels = labels.to(device, non_blocking=True)
        images = images.to(device, non_blocking=True)

        with autocast(device.type):
            out = model(pixel_values=images)
            out = pooling_head(out.last_hidden_state)
            loss = torch.nn.functional.cross_entropy(out, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        wandb.log({
            "batch_loss": loss.item(),
            "lr": scheduler.get_last_lr()[0]
        })


def _validate():
    """Evaluate `pooling_head` on `val_loader`.

    Returns:
        (val_loss, val_acc): mean per-sample cross-entropy and accuracy in percent
        (both NaN if the loader yields no samples).

    Each batch's mean loss is weighted by its size, and the sum is divided by the number of
    samples seen. In the recorded run the loader yielded 132 validation batches while
    len(val_loader) was 126, so some batches are partial. Dividing the sum of per-batch
    means by len(val_loader) therefore inflated val_loss.
    """
    pooling_head.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, total=val_loader_len):
            labels = labels.to(device, non_blocking=True)
            images = images.to(device, non_blocking=True)

            with autocast(device.type):
                out = model(pixel_values=images)
                out = pooling_head(out.last_hidden_state)
                batch_loss = torch.nn.functional.cross_entropy(out, labels)

            loss_sum += batch_loss.item() * len(labels)
            correct += (torch.argmax(out, dim=1) == labels).sum().item()
            total += len(labels)
    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, 100 * correct / total


wandb.init(project="siglip-imagenet-1k", config={
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "learning_rate": lr,
    "num_workers": train_loader.num_workers,
    "model": "SigLIP",
    "dataset": "imagenet-1k",
    "vision_use_head": config.vision_use_head
})

# Only `pooling_head`'s parameters are given to the optimizer, so the backbone is not updated here.
optimizer = torch.optim.AdamW(pooling_head.parameters(), lr=lr)

# Cosine decay from `lr` down to 1e-6, stepped once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6
)

scaler = GradScaler()

try:
    for epoch in range(num_epochs):
        _train_one_epoch(optimizer, scheduler, scaler)
        scheduler.step()
        val_loss, val_acc = _validate()
        wandb.log({
            "epoch": epoch,
            "val_loss": val_loss,
            "val_acc": val_acc
        })
finally:
    # Close the run even if training is interrupted, so a later wandb.init starts cleanly.
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgauti-jha37[0m ([33mgauti-jha37-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


507it [09:08,  1.08s/it]                         
132it [03:11,  1.45s/it]                         
507it [08:57,  1.06s/it]                         
132it [03:09,  1.44s/it]                         
507it [08:42,  1.03s/it]                         
132it [03:05,  1.41s/it]                         
507it [08:35,  1.02s/it]                         
132it [03:07,  1.42s/it]                         
507it [08:34,  1.02s/it]                         
132it [03:09,  1.43s/it]                         
