In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from datasets import Dataset, ClassLabel
from transformers import AutoModelForImageClassification, AutoFeatureExtractor, TrainingArguments, Trainer, ViTForImageClassification
from huggingface_hub import hf_hub_download

In [6]:
# Define a transform to preprocess the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images
    transforms.ToTensor(),          # Convert images to PyTorch tensors
])

num_classes = 5

# Load the dataset
dataset_torch = datasets.ImageFolder(root='objects', transform=transform)
class_names = dataset_torch.classes

dataset = Dataset.from_dict({
    "pixel_values": [img.numpy() for img, _ in dataset_torch],  # Convert tensors to numpy
    "label": [label for _, label in dataset_torch],
})
dataset = dataset.cast_column(
    "label", 
    ClassLabel(names=class_names)  # Map the class names from torchvision dataset
)

split_datasets = dataset.train_test_split(test_size=0.2)
train_dataset = split_datasets["train"]
test_dataset = split_datasets["test"]

Casting the dataset:   0%|          | 0/176 [00:00<?, ? examples/s]

In [3]:
repo_id = "facebook/sparsh-mae-small"
filename = "mae_vitsmall.safetensors"

file_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir="pretrained_models/facebook/sparsh-mae-small")
print(f"Downloaded file path: {file_path}")

Downloaded file path: pretrained_models/facebook/sparsh-mae-small/mae_vitsmall.safetensors


In [4]:
model = ViTForImageClassification.from_pretrained("pretrained_models/facebook/sparsh-mae-small", num_labels=num_classes, use_safetensors=True)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at pretrained_models/facebook/sparsh-mae-small and are newly initialized: ['classifier.bias', 'classifier.weight', 'embeddings.cls_token', 'embeddings.patch_embeddings.projection.bias', 'embeddings.patch_embeddings.projection.weight', 'embeddings.position_embeddings', 'encoder.layer.0.attention.attention.key.bias', 'encoder.layer.0.attention.attention.key.weight', 'encoder.layer.0.attention.attention.query.bias', 'encoder.layer.0.attention.attention.query.weight', 'encoder.layer.0.attention.attention.value.bias', 'encoder.layer.0.attention.attention.value.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.layernorm_after.bias', 'encoder.layer.0.layernorm_after.weight', 'encoder.layer.0.layernorm_before.bias', 'encoder.layer.0.layernorm_before

In [7]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=5,
    per_device_eval_batch_size=5,
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=10
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

trainer.train()



Epoch,Training Loss,Validation Loss
1,1.9837,1.666364
2,1.6528,1.624359
3,1.6428,1.703058
4,1.6406,1.577521
5,1.6362,1.607208


TrainOutput(global_step=140, training_loss=1.7958145482199532, metrics={'train_runtime': 798.7864, 'train_samples_per_second': 0.876, 'train_steps_per_second': 0.175, 'total_flos': 5.42458512562176e+16, 'train_loss': 1.7958145482199532, 'epoch': 5.0})

In [8]:
results = trainer.evaluate()
print(results)

{'eval_loss': 1.6072077751159668, 'eval_runtime': 15.5824, 'eval_samples_per_second': 2.31, 'eval_steps_per_second': 0.513, 'epoch': 5.0}


In [None]:
# Get predictions on the test set
predictions = trainer.predict(test_dataset)

# Get the predicted class labels (argmax over the logits)
preds = predictions.predictions.argmax(-1)

labels = predictions.label_ids

In [15]:
print(dataset_torch.classes)

['BNC', 'hdmi_cable', 'hdmi_port', 'usb_cable', 'usb_port']
