In [0]:
# Locations of the two curated ("gold") parquet datasets under the mount point.
mount_path = '/mnt/vision-test'
normal_data = f'{mount_path}/gold_dataset_normal'
abnormal_data = f'{mount_path}/gold_dataset_abnormal'

df_normal = spark.read.parquet(normal_data)
df_abnormal = spark.read.parquet(abnormal_data)

# Combine both classes into one frame. `unionByName` matches columns by name,
# whereas plain `union` matches by position and would silently mix fields if
# the two sources were ever written with a different column order.
df_augmented = df_normal.unionByName(df_abnormal)

# NOTE: the name is misspelled ("augemented") but is kept unchanged because the
# files and the table are looked up under exactly this name.
dataset_name = "training_dataset_augemented"

# Persist the combined data twice: as parquet files under the mount point and
# as a Delta table registered under `dataset_name`.
df_augmented.write.mode("overwrite").format("parquet").save(f"{mount_path}/{dataset_name}")
df_augmented.write.mode("overwrite").format("delta").saveAsTable(dataset_name)

In [0]:
from datasets import Dataset

# Read the registered table back and convert it to a Hugging Face dataset.
df = spark.read.table(dataset_name)
dataset = Dataset.from_spark(df)

# The number of distinct label values sizes the classification head later on.
distinct_labels = dataset.unique('label')
num_labels = len(distinct_labels)
print(f"{num_labels} labels found in the dataset")

# Show the dataset summary followed by its feature schema.
display(dataset)
display(dataset.features)

2 labels found in the dataset


Dataset({
    features: ['name', 'path', 'label', 'image'],
    num_rows: 960
})

{'name': Value(dtype='string', id=None),
 'path': Value(dtype='string', id=None),
 'label': Value(dtype='int32', id=None),
 'image': Value(dtype='binary', id=None)}

In [0]:
import io
import time
import pandas as pd
from PIL import Image
import torch
import numpy as np
from sklearn.metrics import accuracy_score
import mlflow.pytorch

# Device selection logic
def get_device():
    """Pick the torch device to run on, preferring CUDA, then MPS, then CPU.

    Returns:
        torch.device: ``cuda`` if CUDA is available, otherwise ``mps`` if the
        MPS backend exists and reports itself available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Some PyTorch builds do not expose `torch.backends.mps` at all; look it up
    # defensively so this falls through to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')


device = get_device()
print(f"Using device: {device}")

    


Using device: cpu


In [0]:
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification
)


class WrappedViT(torch.nn.Module):
    """Thin adapter around a model whose forward output carries `.logits`.

    The wrapped model is called with a single positional input and only the
    `logits` attribute of its output is returned, so callers receive one
    plain tensor instead of a structured output object.
    """

    def __init__(self, model):
        """Store `model` as the submodule that `forward` delegates to."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on `x` and return only its logits."""
        return self.model(x).logits


# Pre-trained checkpoint that both the image processor and the model load from.
model_checkpoint = "google/vit-base-patch16-224"

# Image processor: fast implementation, resizing enabled, target size 224.
image_processor = AutoImageProcessor.from_pretrained(
    model_checkpoint,
    use_fast=True,
    do_resize=True,
    size=224,
)

# Classification model sized for this dataset's label count. The checkpoint's
# head has a different shape, so mismatched weights are re-initialised instead
# of raising an error.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

# Wrapper that exposes only the logits tensor of `model`.
wrapped_model = WrappedViT(model)


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [0]:
# Build a fresh classification head: Xavier-uniform weights, zero bias.
classifier_head = torch.nn.Linear(model.config.hidden_size, num_labels)
torch.nn.init.xavier_uniform_(classifier_head.weight)
torch.nn.init.zeros_(classifier_head.bias)
model.classifier = classifier_head

# Move the whole model to the selected device (the cell displays the model).
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [0]:
# Define a preprocessing function to handle binary image data

def preprocess(example):
    """Decode one example's raw image bytes and attach model-ready pixel values.

    Args:
        example: A dataset row (mapping) whose 'image' entry holds the encoded
            image bytes.

    Returns:
        The same mapping with an added 'pixel_values' entry: the output of
        `image_processor` for this image, with the leading batch axis removed.
    """
    # Decode the bytes into an RGB PIL image; the context manager closes the
    # underlying image handle once the converted copy has been made.
    with Image.open(io.BytesIO(example['image'])) as raw_image:
        image = raw_image.convert("RGB")

    # Process the image using the image processor
    processed_image = image_processor(images=image, return_tensors="pt")

    # [1,3,224,224] -> [3,224,224]. Squeeze only the batch axis so that no
    # other size-1 axis can ever be dropped by accident.
    # The tensor is deliberately left on the CPU: `Dataset.map` serialises the
    # returned values into the dataset's storage, so moving it to an
    # accelerator here only adds a pointless copy per image.
    example['pixel_values'] = processed_image['pixel_values'].squeeze(0)
    return example

# Apply the preprocessing function to every row of the dataset
dataset = dataset.map(preprocess)

# Return only the model inputs and labels, as PyTorch tensors
dataset.set_format(type='torch', columns=['pixel_values', 'label'])

# Split into training (80%) and validation (20%) sets. The seed is fixed so
# that the split, and therefore the reported metrics, are reproducible across
# notebook runs.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset['train']
eval_dataset = dataset['test']



Map:   0%|          | 0/960 [00:00<?, ? examples/s]

In [0]:
# Fetch each split's label column once, then count rows per class value
# (0 is reported as "abnormal", 1 as "normal").
train_labels = train_dataset['label']
eval_labels = eval_dataset['label']

train_ds_abnormal = int((train_labels == 0).sum())
train_ds_normal = int((train_labels == 1).sum())
eval_ds_abnormal = int((eval_labels == 0).sum())
eval_ds_normal = int((eval_labels == 1).sum())

print(f"Training dataset: {train_ds_abnormal} abnormal, {train_ds_normal} normal")
print(f"Eval dataset: {eval_ds_abnormal} abnormal, {eval_ds_normal} normal")


Training dataset: 132 abnormal, 636 normal
Eval dataset: 36 abnormal, 156 normal


In [0]:
from transformers import TrainingArguments

import os

# Set environment variables to avoid the warning
os.environ.update({'OMP_NUM_THREADS': '16', 'MKL_NUM_THREADS': '16'})

# Last path component of the checkpoint id, used to name the output directory.
model_name = model_checkpoint.rsplit("/", 1)[-1]

training_args = TrainingArguments(
    # Where checkpoints are written
    output_dir=f"/tmp/huggingface/{model_name}-finetuned-dog",
    # Batching and schedule
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint
    # with the best 'f1' value when training ends
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    # Logging and miscellaneous
    logging_steps=10,
    remove_unused_columns=False,
    ddp_find_unused_parameters=False,
)

In [0]:
import evaluate

# The metric loaded here is F1, not accuracy, so it is named accordingly.
f1_metric = evaluate.load("f1")
# Backward-compatible alias for any code that still refers to the old name.
accuracy = f1_metric

def compute_metrics(eval_pred):
    """Compute the F1 score for one evaluation pass.

    Args:
        eval_pred: Object exposing `predictions` (per-class scores, one row
            per example) and `label_ids` (the reference labels).

    Returns:
        The result of the loaded "f1" metric for the arg-max class predictions
        against the reference labels.
    """
    # Turn per-class scores into a predicted class index per example.
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return f1_metric.compute(predictions=predictions, references=eval_pred.label_ids)

    

Downloading builder script: 0.00B [00:00, ?B/s]

In [0]:
from transformers import Trainer, EarlyStoppingCallback
from mlflow.models.signature import infer_signature

# Start an MLflow run that covers training, evaluation, and model logging.
run_name = f"vit-classification-wrapped-{time.strftime('%Y-%m-%d-%H-%M-%S')}"
with mlflow.start_run(run_name=run_name):
    early_stop = EarlyStoppingCallback(early_stopping_patience=5)

    # Initialize the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stop]
    )

    # Train the model and log the training metrics
    train_result = trainer.train()
    mlflow.log_metrics(train_result.metrics)

    # Evaluate and log the evaluation metrics
    eval_metrics = trainer.evaluate()
    mlflow.log_metrics(eval_metrics)

    # Take one evaluation example as the sample input for the model signature.
    # It must live on the same device as the model's parameters; otherwise the
    # forward pass below fails with a device mismatch whenever the model is on
    # an accelerator.
    sample_input = next(iter(eval_dataset))
    model_device = next(model.parameters()).device
    input_tensor = sample_input['pixel_values'].unsqueeze(0).to(model_device)  # add batch axis

    # Get the model prediction for the signature (eval mode, no gradients)
    model.eval()
    with torch.no_grad():
        sample_output = model(input_tensor)

    # Convert to numpy arrays for MLflow
    input_array = input_tensor.cpu().numpy()
    output_array = sample_output.logits.cpu().numpy()

    # Create the signature from the sample input/output pair
    signature = infer_signature(input_array, output_array)

    # Pip requirements to record alongside the model
    reqs = mlflow.transformers.get_default_pip_requirements(model)

    # Log the logits-only wrapper (it wraps the `model` trained above)
    mlflow.pytorch.log_model(
        pytorch_model=wrapped_model,
        artifact_path='model',
        signature=signature,
        pip_requirements=reqs
    )

    # Log the input dataset for lineage tracking from table to model
    src_dataset = mlflow.data.load_delta(
        table_name=dataset_name
    )
    mlflow.log_input(src_dataset, context='Training-Input')

[2025-09-15 17:33:06,816] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cpu (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


Epoch,Training Loss,Validation Loss,F1
1,0.4071,0.395547,0.925816


Uploading artifacts:   0%|          | 0/10 [00:00<?, ?it/s]

