In [1]:
import torchvision
import torch
import tensorflow as tf
from torch import nn
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import os
import cv2
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader, random_split
import wandb 
from transformers import TrainingArguments, Trainer

2023-12-20 15:50:12.311742: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-20 15:50:12.312248: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-20 15:50:12.387039: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-20 15:50:12.583875: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [35]:
# Login to wandb
wandb.login()

True

In [36]:
# Initialize wandb for logging
wandb.init(project="caltech-101")

os.environ["WANDB_PROJECT"]="caltech-101"
os.environ["WANDB_LOG_MODEL"]="true"
os.environ["WANDB_WATCH"]="false"


In [37]:
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [38]:
# Initialize the CLIP model
model_id = "openai/clip-vit-base-patch32"

processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)
model.to(device)
model.eval() # Freeze the CLIP model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [39]:
# Preprocess the data
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(3),
    transforms.ToTensor(),
])

In [40]:
# Load in the Caltech 101 dataset
dataset = torchvision.datasets.Caltech101(root="caltech-101/", download=False, transform=transform)

# Split the data into training and validation data
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [41]:
# Define the lightweight neural network
class SimpleClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout_rate):
        super(SimpleClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        # self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        # x = self.dropout(x)
        x = self.fc2(x)
        return x

In [42]:
# Train the classifier and log results with wandb
hidden_size = 256
dropout_rate = 0.2
num_classes = 101
classifier = SimpleClassifier(512, hidden_size, num_classes, dropout_rate) 
classifier.to(device)

criterion = nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

num_epochs = 10

for epoch in range(num_epochs):
    classifier.train()
    for batch in train_dataloader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        text = labels.tolist()
        text = [str(x) for x in text]
        inputs = processor(text=text, images=images, return_tensors="pt", padding=True)
        inputs.to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            embeddings = outputs['text_embeds']

        # Forward pass
        outputs = classifier(embeddings)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation step
    classifier.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for batch in val_dataloader:
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)

            text = labels.tolist()
            text = [str(x) for x in text]
            inputs = processor(text=text, images=images, return_tensors="pt", padding=True)
            inputs.to(device)

            outputs = model(**inputs)
            embeddings = outputs['text_embeds']

            outputs = classifier(embeddings)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    # Log results to wandb
    wandb.log({
        "Epoch": epoch + 1, 
        "Validation Accuracy": accuracy, 
        "Validation Loss": loss.item()
    })

    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {accuracy:.4f}, Loss: {loss.item():.4f}")


Epoch 1/10, Validation Accuracy: 0.3744, Loss: 2.7053
Epoch 2/10, Validation Accuracy: 0.5069, Loss: 1.7720
Epoch 3/10, Validation Accuracy: 0.6970, Loss: 1.1822
Epoch 4/10, Validation Accuracy: 0.7725, Loss: 1.1906
Epoch 5/10, Validation Accuracy: 0.8174, Loss: 0.5864
Epoch 6/10, Validation Accuracy: 0.9026, Loss: 0.6122
Epoch 7/10, Validation Accuracy: 0.9130, Loss: 0.4463
Epoch 8/10, Validation Accuracy: 0.9291, Loss: 0.3485
Epoch 9/10, Validation Accuracy: 0.9724, Loss: 0.3494
Epoch 10/10, Validation Accuracy: 0.9804, Loss: 0.2589


In [43]:
# Finish the wandb run
wandb.finish()



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Epoch,▁▂▃▃▄▅▆▆▇█
Validation Accuracy,▁▃▅▆▆▇▇▇██
Validation Loss,█▅▄▄▂▂▂▁▁▁

0,1
Epoch,10.0
Validation Accuracy,0.98041
Validation Loss,0.25887
