# Data loading

In [1]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split

cifar10_classes = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck"
}

cifar10_class_to_idx = {v: k for k, v in cifar10_classes.items()}

label_transform = lambda label: cifar10_classes[label]

dataset = torchvision.datasets.CIFAR10(
    root='/home/azureuser/gautijha37/vision/data',      # Directory where the data will be stored/loaded from
    train=True,        # Load the training set
    download=True,     # Download the dataset if it's not already present
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5))
    ]),
    target_transform=label_transform   
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Use random_split to safely create subsets
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [2]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
device

device(type='cuda')

In [None]:
from transformers import SiglipModel, SiglipImageProcessor, SiglipTokenizer, SiglipConfig

config = SiglipConfig()
config.vision_config.image_size=32
config.text_config.max_position_embeddings=4 # max length of tokenized classes

config.text_config.num_hidden_layers=3 # using all 12 sees 4% accuracy increase.
config.vision_config.num_hidden_layers=3

model = SiglipModel(config).to(device)

img_processor = SiglipImageProcessor(do_resize=False, do_rescale=False) # transforms.ToTensor() already scales input img to [-1, 1]
tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
sum(p.numel() for p in model.parameters()) // 1e6

75.0

# Training

In [5]:
from tqdm import tqdm
def accuracy(model, tokenizer, data_loader):
    correct = 0
    all_text_inputs = list(cifar10_classes.values())
    text_inputs = tokenizer(text=all_text_inputs, padding=True, return_tensors="pt")
    text_inputs['input_ids']=text_inputs['input_ids'].to(device)

    for img_inputs, _, _, correct_labels in tqdm(data_loader):
        img_inputs['pixel_values']=img_inputs['pixel_values'].to(device)
        
        output = model(**text_inputs, **img_inputs)
        predicted_indices = output.logits_per_image.argmax(dim=1)
        predicted_labels = [cifar10_classes[k.item()] for k in predicted_indices]
        
        correct += sum(p == c for p, c in zip(predicted_labels, list(correct_labels)))

    total = len(data_loader) * data_loader.batch_size
    return 100 * correct/total

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb

num_epochs = 50
batch_size = 64
lr = 1e-4

wandb.init(project="siglip-cifar10", config={
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "learning_rate": lr,
    "model": "SigLIP",
    "dataset": "CIFAR10",
    "num_hidden_layers": config.text_config.num_hidden_layers,
    "image_size": config.vision_config.image_size,
    "max_position_embeddings": config.text_config.max_position_embeddings
})

# Dataloader
def collate_fn(batch):
    images, labels = zip(*batch)
    img_inputs = img_processor(images=images, return_tensors="pt")
    text_inputs = tokenizer(text = labels, padding=True, return_tensors="pt")
    
    labels_ids = torch.tensor([cifar10_class_to_idx[label] for label in labels])
    batch_mask = (labels_ids.unsqueeze(1) == labels_ids.unsqueeze(0)).float()
    batch_mask = 2 * batch_mask - 1
    
    return img_inputs, text_inputs, batch_mask, labels

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=collate_fn)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6
)

# loss calculation
def calculate_loss(model, batch):
    img_inputs, text_inputs, batch_mask, _ = batch
    img_inputs['pixel_values']=img_inputs['pixel_values'].to(device)
    text_inputs['input_ids']=text_inputs['input_ids'].to(device)
    batch_mask=batch_mask.to(device)
    
    # loss calculation
    output = model(**text_inputs, **img_inputs)
    loglik = torch.nn.functional.logsigmoid(batch_mask * output.logits_per_text)
    nll = -torch.sum(loglik, dim=-1)
    loss = nll.mean()
    
    return loss

# Training loop
for epoch in tqdm(range(num_epochs)):
    # Training
    model.train()
    for batch_id, batch in enumerate(train_loader):
        loss = calculate_loss(model, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        wandb.log({
            "batch_loss": loss.item(),
            "learning_rate": optimizer.param_groups[0]['lr']
        })
    
    scheduler.step()
    
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += calculate_loss(model, batch)
    
    val_loss /= len(val_loader)
    val_accuracy = accuracy(model, tokenizer, val_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], train_loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_accuracy: {val_accuracy:.4f}")
    wandb.log({
        "epoch_train_loss": loss,
        "epoch_val_loss": val_loss,
        "epoch_val_accuracy": val_accuracy
    })
    

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgauti-jha37[0m ([33mgauti-jha37-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 157/157 [00:01<00:00, 104.55it/s]


Epoch [1/50], train_loss: 26.0088, val_loss: 26.1353, val_accuracy: 9.8627


100%|██████████| 157/157 [00:01<00:00, 103.59it/s]
  4%|▍         | 2/50 [00:30<12:13, 15.28s/it]

Epoch [2/50], train_loss: 25.1383, val_loss: 25.2227, val_accuracy: 11.0768


100%|██████████| 157/157 [00:01<00:00, 103.82it/s]
  6%|▌         | 3/50 [00:46<12:02, 15.36s/it]

Epoch [3/50], train_loss: 24.4650, val_loss: 24.5965, val_accuracy: 9.8029


100%|██████████| 157/157 [00:01<00:00, 104.47it/s]
  8%|▊         | 4/50 [01:01<11:39, 15.22s/it]

Epoch [4/50], train_loss: 25.1556, val_loss: 23.8149, val_accuracy: 13.7938


100%|██████████| 157/157 [00:01<00:00, 104.91it/s]
 10%|█         | 5/50 [01:16<11:26, 15.26s/it]

Epoch [5/50], train_loss: 23.3216, val_loss: 23.3786, val_accuracy: 12.2910


100%|██████████| 157/157 [00:01<00:00, 103.94it/s]
 12%|█▏        | 6/50 [01:32<11:15, 15.36s/it]

Epoch [6/50], train_loss: 23.4012, val_loss: 23.2427, val_accuracy: 16.2719


100%|██████████| 157/157 [00:01<00:00, 103.64it/s]
 14%|█▍        | 7/50 [01:47<11:02, 15.41s/it]

Epoch [7/50], train_loss: 23.5171, val_loss: 22.7798, val_accuracy: 16.9686


100%|██████████| 157/157 [00:01<00:00, 102.68it/s]
 16%|█▌        | 8/50 [02:03<10:51, 15.50s/it]

Epoch [8/50], train_loss: 22.2421, val_loss: 22.5570, val_accuracy: 11.1664


100%|██████████| 157/157 [00:01<00:00, 104.54it/s]
 18%|█▊        | 9/50 [02:18<10:33, 15.45s/it]

Epoch [9/50], train_loss: 21.9588, val_loss: 22.6275, val_accuracy: 14.6596


100%|██████████| 157/157 [00:01<00:00, 104.95it/s]
 20%|██        | 10/50 [02:34<10:17, 15.45s/it]

Epoch [10/50], train_loss: 21.9991, val_loss: 22.2400, val_accuracy: 16.0231


100%|██████████| 157/157 [00:01<00:00, 103.53it/s]
 22%|██▏       | 11/50 [02:49<09:56, 15.31s/it]

Epoch [11/50], train_loss: 20.8036, val_loss: 21.6226, val_accuracy: 18.0633


100%|██████████| 157/157 [00:01<00:00, 103.64it/s]
 24%|██▍       | 12/50 [03:04<09:44, 15.39s/it]

Epoch [12/50], train_loss: 22.2839, val_loss: 21.2578, val_accuracy: 17.9240


100%|██████████| 157/157 [00:01<00:00, 104.44it/s]
 26%|██▌       | 13/50 [03:19<09:24, 15.25s/it]

Epoch [13/50], train_loss: 21.4356, val_loss: 20.9978, val_accuracy: 17.9837


100%|██████████| 157/157 [00:01<00:00, 105.30it/s]
 28%|██▊       | 14/50 [03:34<09:10, 15.28s/it]

Epoch [14/50], train_loss: 21.9368, val_loss: 20.7510, val_accuracy: 18.0633


100%|██████████| 157/157 [00:01<00:00, 104.16it/s]
 30%|███       | 15/50 [03:50<08:56, 15.33s/it]

Epoch [15/50], train_loss: 20.5483, val_loss: 20.6277, val_accuracy: 18.5908


100%|██████████| 157/157 [00:01<00:00, 106.05it/s]
 32%|███▏      | 16/50 [04:05<08:42, 15.36s/it]

Epoch [16/50], train_loss: 20.0663, val_loss: 20.5129, val_accuracy: 20.4717


100%|██████████| 157/157 [00:01<00:00, 102.87it/s]
 34%|███▍      | 17/50 [04:21<08:27, 15.39s/it]

Epoch [17/50], train_loss: 20.1625, val_loss: 20.0411, val_accuracy: 24.0147


100%|██████████| 157/157 [00:01<00:00, 104.08it/s]
 36%|███▌      | 18/50 [04:36<08:14, 15.47s/it]

Epoch [18/50], train_loss: 19.2837, val_loss: 19.4839, val_accuracy: 29.7771


100%|██████████| 157/157 [00:01<00:00, 104.50it/s]
 38%|███▊      | 19/50 [04:52<07:59, 15.46s/it]

Epoch [19/50], train_loss: 19.2164, val_loss: 19.2899, val_accuracy: 30.0756


100%|██████████| 157/157 [00:01<00:00, 104.54it/s]
 40%|████      | 20/50 [05:07<07:43, 15.46s/it]

Epoch [20/50], train_loss: 22.1218, val_loss: 19.1033, val_accuracy: 32.3746


100%|██████████| 157/157 [00:01<00:00, 104.75it/s]
 42%|████▏     | 21/50 [05:23<07:29, 15.50s/it]

Epoch [21/50], train_loss: 20.5284, val_loss: 18.8708, val_accuracy: 33.6584


100%|██████████| 157/157 [00:01<00:00, 103.88it/s]
 44%|████▍     | 22/50 [05:38<07:14, 15.50s/it]

Epoch [22/50], train_loss: 18.7551, val_loss: 18.4695, val_accuracy: 34.5840


100%|██████████| 157/157 [00:01<00:00, 104.92it/s]
 46%|████▌     | 23/50 [05:54<06:58, 15.51s/it]

Epoch [23/50], train_loss: 17.9504, val_loss: 18.3148, val_accuracy: 37.1915


100%|██████████| 157/157 [00:01<00:00, 103.71it/s]
 48%|████▊     | 24/50 [06:09<06:43, 15.52s/it]

Epoch [24/50], train_loss: 15.7095, val_loss: 18.2777, val_accuracy: 37.0920


100%|██████████| 157/157 [00:01<00:00, 103.13it/s]
 50%|█████     | 25/50 [06:25<06:26, 15.46s/it]

Epoch [25/50], train_loss: 16.4187, val_loss: 17.8142, val_accuracy: 40.1473


100%|██████████| 157/157 [00:01<00:00, 104.17it/s]
 52%|█████▏    | 26/50 [06:40<06:11, 15.48s/it]

Epoch [26/50], train_loss: 17.1085, val_loss: 17.6667, val_accuracy: 39.8288


100%|██████████| 157/157 [00:01<00:00, 104.98it/s]
 54%|█████▍    | 27/50 [06:56<05:55, 15.45s/it]

Epoch [27/50], train_loss: 15.5486, val_loss: 17.7487, val_accuracy: 39.3611


100%|██████████| 157/157 [00:01<00:00, 102.34it/s]
 56%|█████▌    | 28/50 [07:11<05:39, 15.41s/it]

Epoch [28/50], train_loss: 16.8219, val_loss: 17.4909, val_accuracy: 41.0032


100%|██████████| 157/157 [00:01<00:00, 103.61it/s]
 58%|█████▊    | 29/50 [07:26<05:21, 15.33s/it]

Epoch [29/50], train_loss: 16.6237, val_loss: 17.3074, val_accuracy: 41.7994


100%|██████████| 157/157 [00:01<00:00, 103.99it/s]
 60%|██████    | 30/50 [07:41<05:06, 15.33s/it]

Epoch [30/50], train_loss: 13.2278, val_loss: 17.3827, val_accuracy: 42.0084


100%|██████████| 157/157 [00:01<00:00, 105.09it/s]
 62%|██████▏   | 31/50 [07:57<04:51, 15.35s/it]

Epoch [31/50], train_loss: 13.4633, val_loss: 17.3114, val_accuracy: 41.8690


100%|██████████| 157/157 [00:01<00:00, 104.17it/s]
 64%|██████▍   | 32/50 [08:12<04:36, 15.37s/it]

Epoch [32/50], train_loss: 15.1591, val_loss: 17.2452, val_accuracy: 41.9088


100%|██████████| 157/157 [00:01<00:00, 103.56it/s]
 66%|██████▌   | 33/50 [08:28<04:21, 15.39s/it]

Epoch [33/50], train_loss: 14.1658, val_loss: 17.3836, val_accuracy: 42.4861


100%|██████████| 157/157 [00:01<00:00, 104.88it/s]
 68%|██████▊   | 34/50 [08:43<04:04, 15.27s/it]

Epoch [34/50], train_loss: 14.1160, val_loss: 17.4987, val_accuracy: 41.7994


100%|██████████| 157/157 [00:01<00:00, 104.58it/s]
 70%|███████   | 35/50 [08:58<03:48, 15.26s/it]

Epoch [35/50], train_loss: 14.6254, val_loss: 17.5311, val_accuracy: 42.1576


100%|██████████| 157/157 [00:01<00:00, 103.48it/s]
 72%|███████▏  | 36/50 [09:13<03:32, 15.18s/it]

Epoch [36/50], train_loss: 12.4944, val_loss: 17.5681, val_accuracy: 42.3467


100%|██████████| 157/157 [00:01<00:00, 105.13it/s]
 74%|███████▍  | 37/50 [09:29<03:19, 15.34s/it]

Epoch [37/50], train_loss: 11.0346, val_loss: 17.6050, val_accuracy: 42.0183


100%|██████████| 157/157 [00:01<00:00, 102.78it/s]
 76%|███████▌  | 38/50 [09:44<03:02, 15.25s/it]

Epoch [38/50], train_loss: 10.6454, val_loss: 17.8368, val_accuracy: 42.2671


100%|██████████| 157/157 [00:01<00:00, 103.93it/s]
 78%|███████▊  | 39/50 [09:59<02:48, 15.29s/it]

Epoch [39/50], train_loss: 10.6916, val_loss: 18.0355, val_accuracy: 42.1975


100%|██████████| 157/157 [00:01<00:00, 104.62it/s]
 80%|████████  | 40/50 [10:14<02:31, 15.16s/it]

Epoch [40/50], train_loss: 12.2909, val_loss: 17.9816, val_accuracy: 42.1875


100%|██████████| 157/157 [00:01<00:00, 103.33it/s]
 82%|████████▏ | 41/50 [10:29<02:17, 15.25s/it]

Epoch [41/50], train_loss: 11.0858, val_loss: 17.9350, val_accuracy: 42.1676


100%|██████████| 157/157 [00:01<00:00, 103.46it/s]
 84%|████████▍ | 42/50 [10:45<02:02, 15.30s/it]

Epoch [42/50], train_loss: 11.7832, val_loss: 18.0157, val_accuracy: 42.2074


100%|██████████| 157/157 [00:01<00:00, 104.66it/s]
 86%|████████▌ | 43/50 [11:00<01:47, 15.37s/it]

Epoch [43/50], train_loss: 10.3816, val_loss: 18.2915, val_accuracy: 42.0183


100%|██████████| 157/157 [00:01<00:00, 105.06it/s]
 88%|████████▊ | 44/50 [11:15<01:31, 15.30s/it]

Epoch [44/50], train_loss: 10.0250, val_loss: 18.4042, val_accuracy: 42.0482


100%|██████████| 157/157 [00:01<00:00, 104.28it/s]
 90%|█████████ | 45/50 [11:31<01:16, 15.34s/it]

Epoch [45/50], train_loss: 10.1612, val_loss: 18.3246, val_accuracy: 42.1079


100%|██████████| 157/157 [00:01<00:00, 103.73it/s]
 92%|█████████▏| 46/50 [11:47<01:01, 15.41s/it]

Epoch [46/50], train_loss: 10.7350, val_loss: 18.4070, val_accuracy: 42.1477


100%|██████████| 157/157 [00:01<00:00, 104.54it/s]
 94%|█████████▍| 47/50 [12:02<00:45, 15.32s/it]

Epoch [47/50], train_loss: 10.7656, val_loss: 18.4504, val_accuracy: 41.7795


100%|██████████| 157/157 [00:01<00:00, 102.80it/s]
 96%|█████████▌| 48/50 [12:17<00:30, 15.22s/it]

Epoch [48/50], train_loss: 9.6865, val_loss: 18.4388, val_accuracy: 42.0084


100%|██████████| 157/157 [00:01<00:00, 103.36it/s]
 98%|█████████▊| 49/50 [12:32<00:15, 15.27s/it]

Epoch [49/50], train_loss: 10.5721, val_loss: 18.4961, val_accuracy: 41.8889


100%|██████████| 157/157 [00:01<00:00, 104.87it/s]
100%|██████████| 50/50 [12:47<00:00, 15.36s/it]

Epoch [50/50], train_loss: 11.3788, val_loss: 18.5073, val_accuracy: 41.8690





# Testing

In [8]:
from safetensors.torch import save_file
state_dict = model.state_dict()
save_file(state_dict, "/home/azureuser/gautijha37/vision/siglip.safetensors6")

In [7]:
test_dataset = torchvision.datasets.CIFAR10(
    root='/home/azureuser/gautijha37/vision/data',
    train=False,        # Load the test set
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5))
    ]),
    target_transform=label_transform
)

test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=16, shuffle=True, collate_fn=collate_fn)
print(accuracy(model, tokenizer, test_loader))

100%|██████████| 157/157 [00:01<00:00, 105.65it/s]

41.590366242038215



