In [1]:
from transformers import TFViTForImageClassification, ViTFeatureExtractor
from utils import *

  from .autonotebook import tqdm as notebook_tqdm
2025-05-01 02:00:57.986236: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-01 02:01:00.016475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746061260.768337   86068 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746061260.961954   86068 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746061262.684940   86068 computation_placer.cc:177] computation placer already r

In [2]:
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
def model_vit_pretrained(input_size: tuple[int, int], num_classes: int, fine_tune: bool = False, model_name: str = 'google/vit-small-patch16-224'):
    """
    Creates an image classification model using a pre-trained Vision Transformer (ViT).

    Args:
        input_size (tuple[int, int]): Target input image size (height, width) expected by the ViT model.
                                      The feature extractor will handle resizing images to this size.
        num_classes (int): Number of output classes for the classification head.
        fine_tune (bool): If True, unfreezes the weights of the pre-trained
                          ViT base model for fine-tuning. Defaults to False.
        model_name (str): The name of the pre-trained ViT model to load from
                          Hugging Face Hub. Defaults to 'google/vit-small-patch16-224'.

    Returns:
        tf.keras.Model: A Keras model ready for compilation and training.
        ViTFeatureExtractor: The feature extractor for preprocessing images.
    """
    # Load the pre-trained ViT model for image classification.
    # We specify the number of labels and explicitly ignore the mismatch
    # in the classification head dimensions, as we intend to train it.
    print(f"Loading pre-trained model: {model_name}")
    model = TFViTForImageClassification.from_pretrained(
        model_name,
        num_labels=num_classes,
        ignore_mismatched_sizes=True # Allows loading a new classification head
    )

    # Load the feature extractor associated with the model.
    # This handles resizing, normalization, etc.
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

    # --- Freezing/Unfreezing Layers ---
    base_model = model.vit # Access the base ViT model

    if not fine_tune:
        print("Freezing base ViT model weights.")
        base_model.trainable = False
    else:
        print("Unfreezing base ViT model weights for fine-tuning.")
        base_model.trainable = True

    # The classification head added by TFViTForImageClassification is automatically trainable.

    print(f"Model '{model_name}' loaded.")
    # Use the size defined by the feature extractor
    print(f"Input size expected by feature extractor: {feature_extractor.size}")
    print(f"Number of output classes: {num_classes}")
    print(f"Base model trainable: {base_model.trainable}")

    # Note: The model is returned uncompiled. Compilation should happen before training.
    return model, feature_extractor

In [4]:
VIT_MODEL_NAME = 'google/vit-base-patch16-224' # Smaller ViT model
NUM_CLASSES = 6 # cardboard, glass, metal, paper, plastic, trash
BATCH_SIZE = 32 # Adjust based on GPU memory
EPOCHS = 100 # Number of training epochs (adjust as needed)
LEARNING_RATE = 3e-5 # Common learning rate for ViT fine-tuning/transfer
FINE_TUNE_BASE = False # Start with frozen base model (transfer learning)

IMG_HEIGHT = 512
IMG_WIDTH = 384
INPUT_SIZE = (IMG_HEIGHT, IMG_WIDTH)

print(f"Using ViT Model: {VIT_MODEL_NAME}")
print(f"Expected Input Size: {INPUT_SIZE}")

Using ViT Model: google/vit-base-patch16-224
Expected Input Size: (512, 384)


In [5]:
model_vit , feature_extractor = model_vit_pretrained(INPUT_SIZE, NUM_CLASSES, fine_tune=FINE_TUNE_BASE, model_name=VIT_MODEL_NAME)

model_vit1 = LModel("vit1")
model_vit1.set_model(model_vit, feature_extractor)
model_vit1.summary()

Loading pre-trained model: google/vit-base-patch16-224


I0000 00:00:1746061294.516719   86068 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13499 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4080 SUPER, pci bus id: 0000:01:00.0, compute capability: 8.9
2025-05-01 02:01:36.641708: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: INVALID_ARGUMENT: Input to reshape is a tensor with 768000 values, but the requested shape has 4608
2025-05-01 02:01:36.642305: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: INVALID_ARGUMENT: Input to reshape is a tensor with 1000 values, but the requested shape has 6
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFViTForImageClassification: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing TFViTForImageClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing 

Freezing base ViT model weights.
Model 'google/vit-base-patch16-224' loaded.
Input size expected by feature extractor: {'height': 224, 'width': 224}
Number of output classes: 6
Base model trainable: False
Model 'vit1' set with 85803270 parameters.
Feature extractor requires input size: {'height': 224, 'width': 224}
Model: "tf_vi_t_for_image_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vit (TFViTMainLayer)        multiple                  85798656  
                                                                 
 classifier (Dense)          multiple                  4614      
                                                                 
Total params: 85803270 (327.31 MB)
Trainable params: 4614 (18.02 KB)
Non-trainable params: 85798656 (327.30 MB)
_________________________________________________________________




In [6]:
model_vit1.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=['accuracy']
)

Compiling model...
Model compiled.


In [7]:
hist_vit1 = model_vit1.fit(
    train_path="train",
    test_path="test",
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    model_save_dir="models/vit1",
    cache=True,
    continue_training=False
)

Found 2274 images belonging to 6 classes.
Found 253 images belonging to 6 classes.
Class Indices: {'cardboard': 0, 'glass': 1, 'metal': 2, 'paper': 3, 'plastic': 4, 'trash': 5}
Using cached model and history without further training.


In [8]:
if hist_vit1: # Only evaluate if training was successful or loaded
    print("\n--- Starting Evaluation Phase ---")
    model_vit1.evaluate(model_save_dir="models/vit1")
else:
    print("Skipping evaluation as training did not produce a history.")

Skipping evaluation as training did not produce a history.
