# Data loading

In [1]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split

cifar10_classes = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck"
}

cifar10_class_to_idx = {v: k for k, v in cifar10_classes.items()}

label_transform = lambda label: cifar10_classes[label]

dataset = torchvision.datasets.CIFAR10(
    root='/home/azureuser/gautijha37/vision/data',      # Directory where the data will be stored/loaded from
    train=True,        # Load the training set
    download=True,     # Download the dataset if it's not already present
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5))
    ]),
    target_transform=label_transform   
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Use random_split to safely create subsets
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [2]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
from transformers import SiglipModel, SiglipImageProcessor, SiglipTokenizer, SiglipConfig, SiglipVisionConfig, SiglipTextConfig, SiglipVisionModel, SiglipTextModel

config = SiglipConfig()
config.vision_config.image_size=32
config.text_config.max_position_embeddings=4 # max length of tokenized classes

model = SiglipModel(config).to(device)

img_processor = SiglipImageProcessor(do_resize=False, do_rescale=False) # transforms.ToTensor() already scales input img to [-1, 1]
tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")


  from .autonotebook import tqdm as notebook_tqdm


# Training

In [12]:
from tqdm import tqdm
from torch.utils.data import DataLoader
import time
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb

num_epochs = 20
# base_lr = 1e-3
# ref_batch_size = 4096
batch_size = 64
# lr = base_lr * torch.sqrt(batch_size/ref_batch_size)
lr = 1e-4

wandb.init(project="siglip-cifar10", config={
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "learning_rate": lr,
    "model": "SigLIP",
    "dataset": "CIFAR10"
})

def collate_fn(batch):
    images, labels = zip(*batch)
    img_inputs = img_processor(images=images, return_tensors="pt")
    text_inputs = tokenizer(text = labels, padding=True, return_tensors="pt")
    
    labels_ids = torch.tensor([cifar10_class_to_idx[label] for label in labels])
    batch_mask = (labels_ids.unsqueeze(1) == labels_ids.unsqueeze(0)).float()
    batch_mask = 2 * batch_mask - 1
    
    return img_inputs, text_inputs, batch_mask

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=collate_fn)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

for epoch in tqdm(range(num_epochs)):
    model.train()
    for batch_id, (img_inputs, text_inputs, batch_mask) in enumerate(train_loader):
        t0 = time.time()
        
        img_inputs['pixel_values']=img_inputs['pixel_values'].to(device)
        text_inputs['input_ids']=text_inputs['input_ids'].to(device)
        batch_mask=batch_mask.to(device)
        
        # loss calculation
        output = model(**text_inputs, **img_inputs)
        loglik = torch.nn.functional.logsigmoid(batch_mask * output.logits_per_text)
        nll = -torch.sum(loglik, dim=-1)
        loss = nll.mean()
        
        optimizer.zero_grad()
        loss.backward()
        
        wandb.log({
            "batch_loss": loss.item(),
            "epoch": epoch,
            "batch": batch_id,
            "step": epoch * len(train_loader) + batch_id
        })
        if (batch_id + 1) % 125 == 0:
            print(f"batched loss: {epoch:}, {batch_id + 1:}, {loss:}")
    scheduler.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}")
        

  0%|          | 0/20 [00:00<?, ?it/s]

batched loss: 0, 125, 8.251632690429688
batched loss: 0, 250, 10.437765121459961
batched loss: 0, 375, 7.887005805969238
batched loss: 0, 500, 8.607136726379395


  5%|▌         | 1/20 [00:32<10:18, 32.55s/it]

batched loss: 0, 625, 11.275980949401855
Epoch [1/20], Loss: 11.2760
batched loss: 1, 125, 12.086872100830078
batched loss: 1, 250, 10.13748550415039
batched loss: 1, 375, 10.029108047485352
batched loss: 1, 500, 9.858497619628906


 10%|█         | 2/20 [01:04<09:39, 32.21s/it]

batched loss: 1, 625, 8.93753433227539
Epoch [2/20], Loss: 8.9375
batched loss: 2, 125, 7.573966026306152
batched loss: 2, 250, 7.16526985168457
batched loss: 2, 375, 8.515593528747559
batched loss: 2, 500, 8.504768371582031


 15%|█▌        | 3/20 [01:37<09:16, 32.76s/it]

batched loss: 2, 625, 8.086873054504395
Epoch [3/20], Loss: 8.0869
batched loss: 3, 125, 7.386260986328125
batched loss: 3, 250, 8.679308891296387
batched loss: 3, 375, 5.694660186767578
batched loss: 3, 500, 10.428244590759277


 20%|██        | 4/20 [02:10<08:43, 32.72s/it]

batched loss: 3, 625, 7.595582008361816
Epoch [4/20], Loss: 7.5956
batched loss: 4, 125, 5.7448930740356445
batched loss: 4, 250, 7.888884544372559
batched loss: 4, 375, 8.85267448425293
batched loss: 4, 500, 7.029024124145508


 25%|██▌       | 5/20 [02:43<08:10, 32.69s/it]

batched loss: 4, 625, 6.560457706451416
Epoch [5/20], Loss: 6.5605
batched loss: 5, 125, 6.026599884033203
batched loss: 5, 250, 6.020657062530518
batched loss: 5, 375, 7.212881565093994
batched loss: 5, 500, 6.445387840270996


 30%|███       | 6/20 [03:14<07:30, 32.17s/it]

batched loss: 5, 625, 5.751923561096191
Epoch [6/20], Loss: 5.7519
batched loss: 6, 125, 8.649019241333008
batched loss: 6, 250, 6.059032440185547
batched loss: 6, 375, 4.5904541015625
batched loss: 6, 500, 6.245373725891113


 35%|███▌      | 7/20 [03:45<06:55, 31.95s/it]

batched loss: 6, 625, 7.074521064758301
Epoch [7/20], Loss: 7.0745
batched loss: 7, 125, 3.4181060791015625
batched loss: 7, 250, 3.8116278648376465
batched loss: 7, 375, 4.8702311515808105
batched loss: 7, 500, 4.888028144836426


 40%|████      | 8/20 [04:17<06:22, 31.90s/it]

batched loss: 7, 625, 8.323680877685547
Epoch [8/20], Loss: 8.3237
batched loss: 8, 125, 5.261958122253418
batched loss: 8, 250, 3.921198844909668
batched loss: 8, 375, 3.4627089500427246
batched loss: 8, 500, 5.434945106506348


 45%|████▌     | 9/20 [04:49<05:51, 31.98s/it]

batched loss: 8, 625, 4.468222618103027
Epoch [9/20], Loss: 4.4682
batched loss: 9, 125, 4.3728556632995605
batched loss: 9, 250, 4.545963287353516
batched loss: 9, 375, 4.304239273071289
batched loss: 9, 500, 7.176427841186523


 50%|█████     | 10/20 [05:23<05:24, 32.42s/it]

batched loss: 9, 625, 3.0454506874084473
Epoch [10/20], Loss: 3.0455
batched loss: 10, 125, 2.044349193572998
batched loss: 10, 250, 5.162559509277344
batched loss: 10, 375, 4.393728733062744
batched loss: 10, 500, 2.4485228061676025


 55%|█████▌    | 11/20 [05:55<04:51, 32.37s/it]

batched loss: 10, 625, 2.394949436187744
Epoch [11/20], Loss: 2.3949
batched loss: 11, 125, 2.3177614212036133
batched loss: 11, 250, 1.491883635520935
batched loss: 11, 375, 4.769649505615234
batched loss: 11, 500, 2.6373038291931152


 60%|██████    | 12/20 [06:27<04:18, 32.28s/it]

batched loss: 11, 625, 2.0363616943359375
Epoch [12/20], Loss: 2.0364
batched loss: 12, 125, 1.79353666305542
batched loss: 12, 250, 1.914278268814087
batched loss: 12, 375, 0.9808317422866821
batched loss: 12, 500, 1.2365622520446777


 65%|██████▌   | 13/20 [07:00<03:46, 32.39s/it]

batched loss: 12, 625, 0.8773010969161987
Epoch [13/20], Loss: 0.8773
batched loss: 13, 125, 2.913853645324707
batched loss: 13, 250, 1.869170904159546
batched loss: 13, 375, 1.7536046504974365
batched loss: 13, 500, 0.8908488750457764


 70%|███████   | 14/20 [07:31<03:13, 32.20s/it]

batched loss: 13, 625, 2.1419222354888916
Epoch [14/20], Loss: 2.1419
batched loss: 14, 125, 0.8355830311775208
batched loss: 14, 250, 0.6680774688720703
batched loss: 14, 375, 2.665609121322632
batched loss: 14, 500, 1.632032871246338


 75%|███████▌  | 15/20 [08:05<02:42, 32.47s/it]

batched loss: 14, 625, 1.8114547729492188
Epoch [15/20], Loss: 1.8115
batched loss: 15, 125, 1.2758708000183105
batched loss: 15, 250, 0.9957938194274902
batched loss: 15, 375, 1.9029319286346436
batched loss: 15, 500, 0.5994880795478821


 80%|████████  | 16/20 [08:37<02:10, 32.60s/it]

batched loss: 15, 625, 0.7992912530899048
Epoch [16/20], Loss: 0.7993
batched loss: 16, 125, 0.597846508026123
batched loss: 16, 250, 0.5750250220298767
batched loss: 16, 375, 1.4463105201721191
batched loss: 16, 500, 0.5460751056671143


 85%|████████▌ | 17/20 [09:10<01:37, 32.58s/it]

batched loss: 16, 625, 3.1886606216430664
Epoch [17/20], Loss: 3.1887
batched loss: 17, 125, 0.5369805097579956
batched loss: 17, 250, 0.5398163199424744
batched loss: 17, 375, 0.5343838930130005
batched loss: 17, 500, 2.0929360389709473


 90%|█████████ | 18/20 [09:42<01:04, 32.29s/it]

batched loss: 17, 625, 1.461958408355713
Epoch [18/20], Loss: 1.4620
batched loss: 18, 125, 0.5359780192375183
batched loss: 18, 250, 0.5427296757698059
batched loss: 18, 375, 0.5220982432365417
batched loss: 18, 500, 0.5435538291931152


 95%|█████████▌| 19/20 [10:14<00:32, 32.33s/it]

batched loss: 18, 625, 0.5519874095916748
Epoch [19/20], Loss: 0.5520
batched loss: 19, 125, 0.5661091804504395
batched loss: 19, 250, 2.934039354324341
batched loss: 19, 375, 1.2906641960144043
batched loss: 19, 500, 0.5394741296768188


100%|██████████| 20/20 [10:46<00:00, 32.31s/it]

batched loss: 19, 625, 0.5377059578895569
Epoch [20/20], Loss: 0.5377





# Validation

In [13]:
from safetensors.torch import save_file
state_dict = model.state_dict()
save_file(state_dict, "/home/azureuser/gautijha37/vision/siglip.safetensors5")

In [14]:
model.eval()

SiglipModel(
  (text_model): SiglipTextTransformer(
    (embeddings): SiglipTextEmbeddings(
      (token_embedding): Embedding(32000, 768)
      (position_embedding): Embedding(4, 768)
    )
    (encoder): SiglipEncoder(
      (layers): ModuleList(
        (0-11): 12 x SiglipEncoderLayer(
          (self_attn): SiglipSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): PytorchGELUTanh()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,

In [15]:
from torch.utils.data import DataLoader
from tqdm import tqdm

val_loader = DataLoader(val_dataset, batch_size=len(cifar10_classes), num_workers=16, shuffle=True)
correct = 0

all_text_inputs = list(cifar10_classes.values())
text_inputs = tokenizer(text=all_text_inputs, padding=True, return_tensors="pt")
text_inputs['input_ids']=text_inputs['input_ids'].to(device)

for images, correct_labels in tqdm(val_loader):
    img_inputs = img_processor(images=images, return_tensors="pt")
    img_inputs['pixel_values']=img_inputs['pixel_values'].to(device)
    
    output = model(**text_inputs, **img_inputs)
    predicted_indices = output.logits_per_image.argmax(dim=1)
    predicted_labels = [cifar10_classes[k.item()] for k in predicted_indices]
    
    correct += sum(p == c for p, c in zip(predicted_labels, list(correct_labels)))

total = len(val_loader) * val_loader.batch_size
100 * correct/total

100%|██████████| 1000/1000 [00:15<00:00, 66.49it/s]


45.33

# Testing

In [16]:
test_dataset = torchvision.datasets.CIFAR10(
    root='/home/azureuser/gautijha37/vision/data',
    train=False,        # Load the test set
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5))
    ]),
    target_transform=label_transform
)

# Testing on test dataset
test_loader = DataLoader(test_dataset, batch_size=len(cifar10_classes), num_workers=16, shuffle=True)
correct = 0

all_text_inputs = list(cifar10_classes.values())
text_inputs = tokenizer(text=all_text_inputs, padding=True, return_tensors="pt")
text_inputs['input_ids']=text_inputs['input_ids'].to(device)

for images, correct_labels in tqdm(test_loader):
    img_inputs = img_processor(images=images, return_tensors="pt")
    img_inputs['pixel_values']=img_inputs['pixel_values'].to(device)
    
    output = model(**text_inputs, **img_inputs)
    predicted_indices = output.logits_per_image.argmax(dim=1)
    predicted_labels = [cifar10_classes[k.item()] for k in predicted_indices]
    
    correct += sum(p == c for p, c in zip(predicted_labels, list(correct_labels)))

total = len(test_loader) * test_loader.batch_size
test_accuracy = 100 * correct/total
print(f"Test Accuracy: {test_accuracy:.2f}%")

100%|██████████| 1000/1000 [00:15<00:00, 65.54it/s]

Test Accuracy: 46.87%



