In [1]:
%pip install transformers datasets torch torchvision

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
import requests 

# Human-readable ImageNet class names; the list is indexed by class id
# elsewhere in this notebook (e.g. imagenet_labels[example['label']]).
labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

# Fail fast: a timeout avoids hanging the kernel on a stalled connection, and
# raise_for_status() avoids parsing an HTTP error page as if it were the labels.
labels_response = requests.get(labels_url, timeout=30)
labels_response.raise_for_status()
imagenet_labels = labels_response.json()

# Domain-specific subset: the furniture class names this notebook focuses on.
# Each entry must match an entry of `imagenet_labels` exactly to be selected.
furniture_labels = [
    "barber chair",
    "bookcase",
    "china cabinet",
    "chiffonier",
    "chest",
    "cradle",
    "desk",
    "dining table",
    "filing cabinet",
    "folding chair",
    "four-poster bed",
    "infant bed",
    "medicine chest",
    "rocking chair",
    "sofa",
    "wardrobe",
]

print(furniture_labels)

['barber chair', 'bookcase', 'china cabinet', 'chiffonier', 'chest', 'cradle', 'desk', 'dining table', 'filing cabinet', 'folding chair', 'four-poster bed', 'infant bed', 'medicine chest', 'rocking chair', 'sofa', 'wardrobe']


In [3]:
from huggingface_hub import notebook_login
from datasets import load_dataset
from transformers import ViTForImageClassification, AutoImageProcessor
from torch.utils.data import DataLoader
from PIL import Image
import torch

# Log in to the Hugging Face Hub; the call renders an interactive login widget
# as this cell's output.
# NOTE(review): presumably needed because the dataset streamed later in the
# notebook requires an authenticated account — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [9]:
# Stream both splits (streaming=True) instead of materialising them locally.
train_ds = load_dataset("imagenet-1k", split="train", streaming=True)
val_ds = load_dataset("imagenet-1k", split="validation", streaming=True)

# Resolve the furniture class names to class ids once. The per-example filter
# then becomes a single set lookup instead of a list index plus a linear scan
# of `furniture_labels` for every streamed example.
_furniture_label_names = set(furniture_labels)
furniture_label_ids = frozenset(
    class_id
    for class_id, class_name in enumerate(imagenet_labels)
    if class_name in _furniture_label_names
)

def filter_furniture(example):
    """Return True when the example belongs to one of the furniture classes.

    Args:
        example: a dataset row; only its integer ``'label'`` entry is read.

    Returns:
        bool: whether ``example['label']`` is the class id of a name listed
        in ``furniture_labels``. Ids outside the label list (including
        negative ones) are simply not matched, rather than wrapping around
        through Python's negative indexing or raising IndexError.
    """
    return example['label'] in furniture_label_ids

# Lazily keep only the furniture examples of each split.
furniture_train_ds = train_ds.filter(filter_furniture)
furniture_val_ds = val_ds.filter(filter_furniture)

In [10]:
# Preview the first five furniture examples: print "<class id> <class name>"
# for each one and call .show() on its image.
for sample in furniture_train_ds.take(5):
    label_id, img = sample['label'], sample['image']
    print(label_id, imagenet_labels[label_id])
    img.show()

765 rocking chair
559 folding chair
559 folding chair
765 rocking chair
492 chest


In [6]:
# Pretrained ViT checkpoint shared by the classifier and its image processor.
model_name = "google/vit-base-patch16-224"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
model = ViTForImageClassification.from_pretrained(model_name)

# Kept as the cell's final expression: moves the model to `device`.
model.to(device)

In [11]:
# test ViT on sample image

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Bound the request and surface HTTP errors instead of handing an error body
# to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# The model may live on the GPU, so the inputs must be moved to the same
# device; otherwise the forward pass fails with a device-mismatch error.
inputs = processor(images=image, return_tensors="pt").to(model.device)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()

model.config.id2label[predicted_class_idx]

'Egyptian cat'

In [12]:
# Get initial accuracy on furniture data
# NOTE: this is a quick sanity check on the first few *training* examples,
# not a proper validation metric.

NUM_EVAL_SAMPLES = 10  # how many streamed examples to score

model.eval()  # make sure dropout etc. are disabled even if this cell is re-run after training
num_samples = 0
num_correct = 0
for sample in furniture_train_ds.take(NUM_EVAL_SAMPLES):
    img = sample['image']
    # Force three channels so non-RGB files do not break preprocessing, and
    # move the inputs to the model's device (required when it runs on CUDA).
    inputs = processor(img.convert("RGB"), return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = logits.argmax(-1).item()
    if predicted_class_id == sample['label']:
        num_correct += 1
    num_samples += 1

    # predicted name, then ground-truth name
    print(imagenet_labels[predicted_class_id], imagenet_labels[sample['label']])
    img.show()

# Guard against an empty stream instead of raising ZeroDivisionError.
accuracy = num_correct / num_samples if num_samples else float("nan")

accuracy

rocking chair rocking chair
folding chair folding chair
folding chair folding chair
rocking chair rocking chair
chest chest
desktop computer desk
chiffonier chiffonier
cradle cradle
window shade dining table
medicine chest medicine chest


0.8

In [8]:


# Create preprocessing function for batched data
def preprocess_images(examples):
    """Turn a batch of raw examples into model inputs.

    Args:
        examples: a batched mapping with an ``'image'`` list (objects exposing
            ``.convert``, i.e. PIL images — assumed from how the dataset is
            used elsewhere in this notebook) and a parallel ``'label'`` list.

    Returns:
        The processor output for the batch, with the labels attached under
        the ``'labels'`` key.
    """
    # Force three channels: grayscale or CMYK files would otherwise produce
    # tensors with a different channel count (or fail in the processor).
    rgb_images = [image.convert("RGB") for image in examples['image']]
    inputs = processor(rgb_images, return_tensors="pt")
    inputs['labels'] = examples['label']
    return inputs

# Create dataloaders with preprocessing
def _make_furniture_dataloader(dataset, batch_size=16):
    """Wrap a streamed furniture split in a DataLoader of model-ready batches.

    Args:
        dataset: a streaming dataset whose rows have 'image' and 'label' columns.
        batch_size: number of examples per yielded batch.

    Returns:
        A DataLoader whose batches contain only the keys produced by
        ``preprocess_images`` ('pixel_values' and 'labels').
    """
    processed = dataset.map(
        preprocess_images,
        batched=True,
        # Drop the raw columns: the default collate function cannot batch the
        # raw image objects, and 'label' is duplicated by 'labels'.
        remove_columns=["image", "label"],
    ).with_format("torch")  # yield tensors so batches collate into stacked tensors
    return DataLoader(processed, batch_size=batch_size)


furniture_train_dataloader = _make_furniture_dataloader(furniture_train_ds)
furniture_val_dataloader = _make_furniture_dataloader(furniture_val_ds)

In [None]:
# code for finetuning

# Set up training parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 3

# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    # Count batches by hand: the dataloader wraps a streaming dataset, which
    # has no length, so len(furniture_train_dataloader) would raise TypeError
    # only after the whole first epoch had already been computed.
    num_batches = 0
    for batch in furniture_train_dataloader:
        # Move batch to device
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass (passing labels makes the model return the loss)
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard against an empty stream instead of raising ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")