In [0]:
# Storage locations of the two source datasets (normal / abnormal samples).
mount_path = '/mnt/vision-test'
normal_data = f'{mount_path}/gold_dataset_normal'
abnormal_data = f'{mount_path}/gold_dataset_abnormal'

# Combine both sources into one training frame. unionByName matches columns by
# name, so a different column order in the two Parquet sources cannot silently
# place values under the wrong column (plain union() matches by position only).
df_augmented = (
    spark.read.parquet(normal_data)
    .unionByName(spark.read.parquet(abnormal_data))
)

# Persist the combined data twice: as Parquet files under the mount and as a
# Delta table, which the next cell reads back by name.
# NOTE(review): "augemented" is misspelled; the name is kept unchanged because
# the table/path may already be referenced elsewhere.
dataset_name = "training_dataset_augemented"
df_augmented.write.mode("overwrite").format("parquet").save(f"{mount_path}/{dataset_name}")
df_augmented.write.mode("overwrite").format("delta").saveAsTable(dataset_name)

In [0]:
from datasets import Dataset

# Read the Delta table back through Spark and wrap it as a Hugging Face Dataset.
df = spark.read.table(dataset_name)
dataset = Dataset.from_spark(df)

# The number of distinct label values is used later to size the classifier head.
label_values = dataset.unique('label')
num_labels = len(label_values)
print(f"{num_labels} labels found in the dataset")

# Show the dataset summary first, then its column schema.
for summary in (dataset, dataset.features):
    display(summary)

In [0]:
import io
import time
import pandas as pd
from PIL import Image
import torch
import numpy as np
from sklearn.metrics import accuracy_score
import mlflow.pytorch

# Device selection: prefer CUDA, then the MPS backend, then fall back to CPU.
def get_device():
    """Pick the torch device to run on.

    Preference order is CUDA, then MPS, then CPU.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # torch.backends.mps is missing on PyTorch builds without that backend;
    # look it up defensively so those installs fall through to CPU instead of
    # failing with AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')


device = get_device()
print(f"Using device: {device}")

    


In [0]:
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification
)

# Pre-trained ViT checkpoint that serves as the starting point for fine-tuning.
model_checkpoint = "google/vit-base-patch16-224"

# Processor options: fast implementation, resizing enabled, target size 224.
processor_options = dict(use_fast=True, do_resize=True, size=224)
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint, **processor_options)

# num_labels comes from the dataset; ignore_mismatched_sizes permits loading the
# checkpoint even when its classification head has a different shape.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    ignore_mismatched_sizes=True,
    num_labels=num_labels,
)


In [0]:
# Build a fresh linear head (hidden_size -> num_labels), give its weights a
# Xavier-uniform initialization, then install it as the model's classifier.
new_head = torch.nn.Linear(model.config.hidden_size, num_labels)
torch.nn.init.xavier_uniform_(new_head.weight)
model.classifier = new_head

# Move the model to the selected device (last expression, so the cell shows it).
model.to(device)

In [0]:
# Define a preprocessing function to handle binary image data

def preprocess(example):
    """Decode one record's image bytes and attach model-ready pixel values.

    Args:
        example: A dataset record whose ``'image'`` field holds encoded image
            bytes that PIL can open.

    Returns:
        The same record with a ``'pixel_values'`` entry added: the image
        processor's output tensor with its leading batch axis removed.
    """
    # Convert binary image data to a PIL image with RGB channels. The context
    # manager closes the source image once the converted copy exists.
    with Image.open(io.BytesIO(example['image'])) as raw_image:
        image = raw_image.convert("RGB")

    # Process the image using the image processor
    processed_image = image_processor(images=image, return_tensors="pt")

    # [1,3,224,224] -> [3,224,224]. Only the batch axis is squeezed so that no
    # other size-1 dimension can be dropped by accident.
    # The tensor deliberately stays on the CPU: Dataset.map serializes results
    # into its Arrow storage, so a per-record device transfer here is not kept
    # and only costs time; batches are moved to the device at training time.
    example['pixel_values'] = processed_image['pixel_values'].squeeze(0)
    return example

# Apply the preprocessing function to every record of the dataset
dataset = dataset.map(preprocess)

# Expose only the model inputs and labels, returned as PyTorch tensors
dataset.set_format(type='torch', columns=['pixel_values', 'label'])

# Split the dataset into training (80%) and validation (20%) sets.
# The fixed seed makes the shuffle, and therefore the split and every metric
# computed on it, reproducible across runs.
# NOTE(review): `dataset` is rebound from a single dataset to the split result,
# so this cell cannot be re-run without re-running the cell that loads `dataset`.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset['train']
eval_dataset = dataset['test']



In [0]:
def _count_binary_labels(split):
    """Return ``(abnormal, normal)`` record counts for one dataset split.

    The label column is read once and compared as plain Python ints, so the
    count does not depend on the column supporting boolean-mask indexing.

    Args:
        split: A dataset split exposing a ``'label'`` column.

    Returns:
        tuple[int, int]: Number of records labelled 0 and labelled 1.
    """
    labels = [int(value) for value in split['label']]
    return labels.count(0), labels.count(1)


# Label 0 is reported as "abnormal" and label 1 as "normal".
# NOTE(review): that mapping is taken from the print statements below; confirm
# it against how the labels were assigned upstream.
train_ds_abnormal, train_ds_normal = _count_binary_labels(train_dataset)
eval_ds_abnormal, eval_ds_normal = _count_binary_labels(eval_dataset)

print(f"Training dataset: {train_ds_abnormal} abnormal, {train_ds_normal} normal")
print(f"Eval dataset: {eval_ds_abnormal} abnormal, {eval_ds_normal} normal")


In [0]:
from transformers import TrainingArguments

import os

# Set the OpenMP / MKL thread-count environment variables to avoid the warning
for thread_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
    os.environ[thread_var] = '16'

# The output directory is named after the last path component of the checkpoint.
model_name = model_checkpoint.rsplit("/", 1)[-1]

training_args = TrainingArguments(
    output_dir=f"/tmp/huggingface/{model_name}-finetuned-dog",
    # Keep every dataset column instead of pruning by the model's signature.
    remove_unused_columns=False,
    ddp_find_unused_parameters=False,
    # Batching and schedule
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    # Optimizer
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Evaluate and checkpoint once per epoch, then restore the best F1 checkpoint.
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    logging_steps=10,
)

In [0]:
import evaluate

# The metric loaded here is F1. `accuracy` is kept as an alias only because it
# is the name this notebook originally bound the metric to.
f1_metric = evaluate.load("f1")
accuracy = f1_metric

def compute_metrics(eval_pred):
    """Compute the F1 score for one evaluation pass.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores with the
            classes on axis 1) and ``label_ids`` (reference labels).

    Returns:
        dict: The result of the F1 metric's ``compute`` call.
    """
    scores = np.asarray(eval_pred.predictions)
    predictions = np.argmax(scores, axis=1)

    # The metric's default averaging is "binary", which rejects targets with
    # more than two classes. Up to two classes keeps that default; a larger
    # label set falls back to the unweighted per-class mean ("macro").
    average = "binary" if scores.shape[1] <= 2 else "macro"
    return f1_metric.compute(
        predictions=predictions,
        references=eval_pred.label_ids,
        average=average,
    )

    

In [0]:
from transformers import Trainer, EarlyStoppingCallback
from mlflow.models.signature import infer_signature

# Start an MLflow run; the timestamp keeps run names unique per launch
run_name = f"vit-classification-{time.strftime('%Y-%m-%d-%H-%M-%S')}"
with mlflow.start_run(run_name=run_name):
    # NOTE(review): training_args uses num_train_epochs=1 with per-epoch
    # evaluation, so a patience of 5 evaluations can never trigger - confirm
    # whether more epochs were intended.
    early_stop = EarlyStoppingCallback(early_stopping_patience=5)

    # Initialize the trainer with the train/validation splits and the F1 metric
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stop]
    )

    # Train the model
    train_result = trainer.train()

    # Log training metrics
    mlflow.log_metrics(train_result.metrics)

    # Evaluate and log metrics
    eval_metrics = trainer.evaluate()
    mlflow.log_metrics(eval_metrics)

    # Get a sample input and prepare it for the signature.
    # Dataset records are read on the CPU while the model sits on whichever
    # device it was trained on, so the sample is moved to the model's device;
    # otherwise the forward pass below fails on GPU/MPS with a device mismatch.
    sample_input = next(iter(eval_dataset))
    model_device = next(model.parameters()).device
    input_tensor = sample_input['pixel_values'].unsqueeze(0).to(model_device)  # [1,3,224,224]

    # Get model prediction for signature (inference mode, no gradients)
    model.eval()
    with torch.no_grad():
        sample_output = model(input_tensor)

    # Convert to numpy arrays for MLflow
    input_array = input_tensor.cpu().numpy()
    output_array = sample_output.logits.cpu().numpy()

    # Create signature from the example input/output pair
    signature = infer_signature(input_array, output_array)

    # Collect the default pip requirements for the model
    reqs = mlflow.transformers.get_default_pip_requirements(model)

    # Log the model with MLflow
    mlflow.pytorch.log_model(
        pytorch_model=model,
        artifact_path='model',
        signature=signature,
        pip_requirements=reqs
    )

    # Log the input dataset for lineage tracking from table to model
    src_dataset = mlflow.data.load_delta(
        table_name=dataset_name
    )
    mlflow.log_input(src_dataset, context='Training-Input')