In [None]:
from transformers import TFViTForImageClassification, ViTFeatureExtractor
from utils import *
from tensorflow.keras.callbacks import EarlyStopping

In [None]:
# Ask TensorFlow which GPU devices it can see and print the list, so the cell
# output shows whether an accelerator is available.
# NOTE(review): `tf` is not imported explicitly in this file; presumably it is
# brought in by `from utils import *` -- confirm.
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

In [None]:
# Image size constant for the notebook.
# NOTE(review): presumably (height, width) and intended as the `size` argument
# of the model builder below -- confirm the axis order against the data loader.
IMG_SIZE = (512,384)
# Keras callback that stops training after 20 epochs without improvement of
# the monitored metric.
# NOTE(review): the monitored metric is the *training* 'accuracy', not a
# validation metric such as 'val_accuracy' -- confirm this is intentional.
early_stopping = EarlyStopping(monitor='accuracy', patience=20)

In [None]:
def model_vit_pretrained(
    size: tuple[int, int],
    fine_tune: bool = False,
    model_name: str = 'google/vit-small-patch16-224',
    *,
    num_classes: int = 6,
) -> "tuple[tf.keras.Model, ViTFeatureExtractor]":
    """
    Build an image classifier on top of a pre-trained Vision Transformer (ViT).

    The checkpoint is loaded with a freshly sized classification head
    (`num_classes` outputs), and the ViT backbone is frozen or unfrozen
    according to `fine_tune`. The model is returned uncompiled.

    Args:
        size (tuple[int, int]): Target input image size (height, width).
            NOTE: this argument is currently not used by the function; it is
            kept for interface compatibility. Images are expected to be
            resized by the returned feature extractor, whose own target size
            is printed below.
        fine_tune (bool): If True, the weights of the pre-trained ViT backbone
            are trainable; if False (default), the backbone is frozen and only
            the classification head is trained.
        model_name (str): Name of the pre-trained ViT checkpoint to load from
            the Hugging Face Hub, e.g. 'google/vit-base-patch16-224-in21k'.
            NOTE(review): the default 'google/vit-small-patch16-224' may not
            exist on the Hub under the `google` organisation -- verify, and
            pass an explicit checkpoint name if loading fails.
        num_classes (int): Number of output classes for the classification
            head. Keyword-only. Defaults to 6 (the number of trash classes).

    Returns:
        tuple: `(model, feature_extractor)` where `model` is the
        `TFViTForImageClassification` Keras model (uncompiled) and
        `feature_extractor` is the `ViTFeatureExtractor` loaded from the same
        checkpoint for preprocessing images.

    Raises:
        ValueError: If `num_classes` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # Load the pre-trained ViT classifier. `ignore_mismatched_sizes=True` lets
    # the checkpoint load even though its original classification head has a
    # different number of outputs; the new head is trained from scratch.
    print(f"Loading pre-trained model: {model_name}")
    model = TFViTForImageClassification.from_pretrained(
        model_name,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )

    # Preprocessing (resizing, normalization, ...) configuration that belongs
    # to the same checkpoint.
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

    # The ViT backbone inside the classification model. Only the backbone is
    # toggled here; the classification head is left as loaded.
    base_model = model.vit
    if fine_tune:
        print("Unfreezing base ViT model weights for fine-tuning.")
    else:
        print("Freezing base ViT model weights.")
    # Whole-backbone (un)freezing: trainable exactly when fine-tuning.
    base_model.trainable = fine_tune

    print(f"Model '{model_name}' loaded.")
    # The feature extractor knows the input size it resizes images to.
    print(f"Input size expected by feature extractor: {feature_extractor.size}")
    print(f"Number of output classes: {num_classes}")
    print(f"Base model trainable: {base_model.trainable}")

    # The model is returned uncompiled; compile it before training.
    return model, feature_extractor