## Setup and Imports

In [None]:
import os
import shutil
import numpy as np
import torch
from datasets import Dataset
from matplotlib import pyplot as plt
from transformers.image_utils import load_image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from pathlib import Path
from functools import partial
import evaluate
from PIL import Image
import json
import time
from transformers.trainer_utils import EvalPrediction
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
SEED = 27

## Dataset

In [None]:
# workdir = "data/fashion_mnist"

# # Define transformations (convert images to tensors and normalize if needed)
# transform = transforms.ToTensor()

# # Download the FashionMNIST dataset
# fashion_mnist = datasets.FashionMNIST(
#     root=".",  # Temporary directory to store the raw dataset
#     train=True,  # Download the training set
#     download=True,  # Download the dataset if not already present
#     transform=transform,
# )

# label_map = {0: "top", 1: "trouser"}

# # Save each image as a separate file
# for idx, (image, label) in enumerate(fashion_mnist):
#     if label > 1:
#         continue
#     # Convert the tensor image to a PIL image
#     pil_image = transforms.ToPILImage()(image)

#     # Create a subdirectory for each label (optional)
#     label_dir = os.path.join(workdir, str(label_map[label]))
#     os.makedirs(label_dir, exist_ok=True)

#     # Save the image to the corresponding label directory
#     image_path = os.path.join(label_dir, f"image_{idx}.png")
#     pil_image.save(image_path)


# print(f"All images saved to: {workdir}")

# shutil.rmtree("FashionMNIST")

In [None]:
# Discover the class names from the per-class sub-folders of the raw dataset.
images_path = Path("./data/fashion_mnist/all")
# Keep directories only (stray files such as .DS_Store are not classes) and
# sort them: glob order depends on the filesystem, and the order of `classes`
# is later used to enumerate label ids. `.name` is used instead of `.stem` so
# that a folder name containing a dot is not truncated.
classes = sorted(fold.name for fold in images_path.glob("*") if fold.is_dir())
print(classes)

In [None]:
label2id = {"top": 0, "trouser": 1}

# Flatten the per-class folders into one "images" folder and record, for every
# image, its new location and its integer label.
all_images = []
labels = []

# shutil.copy2 treats a non-existent destination as a *file* name, so the
# target folder must exist; otherwise every image would overwrite a single
# file called "images".
images_dir = images_path.parent / "images"
images_dir.mkdir(parents=True, exist_ok=True)

with open("./data/fashion_mnist/fashion_mnist.jsonl", "w", encoding="utf-8") as f:
    id_ = 0
    for class_ in classes:
        class_path = images_path / class_
        # Sorted so that ids and dataset row order (and therefore the seeded
        # train/validation/test splits) do not depend on filesystem order.
        for img_path in sorted(class_path.glob("*.png")):
            new_path = shutil.copy2(img_path.as_posix(), images_dir)
            all_images.append(new_path)
            labels.append(label2id[class_])

            # One JSON line per image: running id, file name and label id.
            f.write(json.dumps({"id": id_, "filename": img_path.name, "label": label2id[class_]}) + "\n")
            id_ += 1

# Create a Hugging Face Dataset
dataset = Dataset.from_dict({"image_path": all_images, "label": labels})

In [None]:
# Hold out 25% of the data as the test set.
tv_test_split = dataset.train_test_split(test_size=0.25, seed=SEED)
test_set = tv_test_split["test"]

# Split the remainder again, taking a validation set with exactly as many
# rows as the test set; what is left is the training set.
train_val_split = tv_test_split["train"].train_test_split(test_size=test_set.num_rows, seed=SEED)
train_set = train_val_split["train"]
validation_set = train_val_split["test"]

# File names (without directories) of the images in the held-out splits.
splits_dict = {
    split_name: [Path(row["image_path"]).name for row in split]
    for split_name, split in (("validation", validation_set), ("test", test_set))
}

In [None]:
# Store the validation and test file names as JSON.
splits_file = Path("./data/fashion_mnist/splits.json")
with splits_file.open("w") as handle:
    handle.write(json.dumps(splits_dict))

In [None]:
print(f"Number of training examples: {len(train_set)}")
print(f"Number of validation examples: {len(validation_set)}")
print(f"Number of test examples: {len(test_set)}")

In [None]:
# The dataset's integer labels were assigned with the existing `label2id`
# mapping, so the id <-> name dictionaries are derived from that mapping
# rather than from the enumeration order of `classes` (which follows the
# filesystem listing and could silently swap the class names).
# `int(...)` keeps this cell re-runnable after the ids have become strings.
id2label = dict(sorted((int(idx), name) for name, idx in label2id.items()))

labels = list(id2label.values())
num_labels = len(labels)
print(f"Number of labels: {num_labels}")

# Build the reverse dictionary for easier query (ids stored as strings).
label2id = {name: str(idx) for idx, name in id2label.items()}

In [None]:
label2id

In [None]:
id2label

### Visualize the dataset

In [None]:
def display_image_grid(images, labels, rows=2, cols=5, figsize=(12, 6), target_size=(128, 128), after_aug=False):
    """Plot images on a ``rows`` x ``cols`` grid, using their labels as titles.

    Args:
        images: Images to show. With ``after_aug=False`` each item must offer
            a PIL-style ``resize``; with ``after_aug=True`` each item is a
            tensor whose first axis is moved last before plotting.
        labels: Titles, one per image, in the same order as ``images``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        figsize: Figure size passed to ``plt.subplots``.
        target_size: Size passed to ``resize`` for non-augmented images.
            Ignored when ``after_aug`` is True.
        after_aug: Whether ``images`` are already-transformed tensors.

    If fewer images than grid cells are given, the remaining cells are left
    blank; surplus images are not drawn.
    """
    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid works too.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # zip stops at the shortest sequence instead of indexing past the end.
    for ax, image, label in zip(axes, images, labels):
        if after_aug:
            # Move the first axis last and clamp values to the displayable
            # [0, 1] range.
            image = image.permute(1, 2, 0).clip(min=0.0, max=1.0)
        else:
            image = image.resize(target_size)
        ax.imshow(image)
        ax.set_title(label)

    # Hide the axis decorations on every cell, including unused ones.
    for ax in axes:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Draw ten random training examples and show them with their class names.
samples = train_set.shuffle().select(range(10))
sample_images = list(map(Image.open, samples["image_path"]))
sample_labels = [id2label[label_id] for label_id in samples["label"]]
display_image_grid(sample_images, sample_labels)

## Dataset Transforms for Training

In [None]:
# Load a pre-trained model from Hugging Face Hub
model_path = "./models/resnet18.a1_in1k"

image_processor = AutoImageProcessor.from_pretrained(model_path)


# checkpoint = "timm/resnet18.a1_in1k"
# model = AutoModelForImageClassification.from_pretrained(checkpoint)
# model.save_pretrained(model_path)


In [None]:
train_transforms = image_processor.train_transforms
val_transforms = image_processor.val_transforms


def apply_transforms(examples, train_aug=False):
    """Turn a batch of image paths into a batch of transformed images.

    Every path in ``examples["image_path"]`` is opened, converted to RGB and
    passed through ``train_transforms`` (when ``train_aug`` is true) or
    ``val_transforms`` (otherwise). The results are stored under
    ``"pixel_values"`` and the ``"image_path"`` entry is removed.

    Returns:
        The same ``examples`` mapping, modified in place.
    """
    transform = train_transforms if train_aug else val_transforms
    examples["pixel_values"] = [
        transform(Image.open(path).convert("RGB")) for path in examples["image_path"]
    ]

    del examples["image_path"]
    return examples

In [None]:
train_ds = train_set.with_transform(partial(apply_transforms, train_aug=True))
test_ds = test_set.with_transform(apply_transforms)
val_ds = validation_set.with_transform(apply_transforms)

In [None]:
samples = train_ds.shuffle().select(range(10))
sample_images = [s["pixel_values"] for s in samples]
sample_labels = [id2label[s["label"]] for s in samples]

# After augmentation
display_image_grid(sample_images, sample_labels, after_aug=True)

## Model

In [None]:
model = AutoModelForImageClassification.from_pretrained(
    model_path,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

## Training

In [None]:
def compute_metrics(pred: EvalPrediction):
    """Compute accuracy and weighted precision/recall/F1 for an evaluation run.

    Args:
        pred: Evaluation output; ``pred.predictions`` holds per-class scores
            (the arg-max over axis 1 is taken as the predicted class) and
            ``pred.label_ids`` the reference labels.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    y_true = pred.label_ids
    y_pred = np.argmax(pred.predictions, axis=1)

    # Precision, recall and F1 are averaged over classes, weighted by support.
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")

    metrics = dict(
        accuracy=accuracy_score(y_true, y_pred),
        precision=precision,
        recall=recall,
        f1=f1,
    )
    return metrics

In [None]:
learning_rate = 1e-5
batch_size = 128
num_epochs = 5
output_dir = "models/checkpoints"

In [None]:
# Collate individual examples into batches with the library's default collator.
data_collator = DefaultDataCollator()

# Fine-tuning configuration: evaluation and checkpointing both run once per
# epoch, and the checkpoint with the best "f1" is reloaded when training ends.
training_args = TrainingArguments(
    output_dir=output_dir,
    # NOTE(review): presumably required so the `image_path` column is kept for
    # the on-the-fly `apply_transforms` — confirm against the Trainer docs.
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    # Gradients are accumulated over two batches per optimizer update.
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    warmup_ratio=0.2,
    logging_steps=1,
    load_best_model_at_end=True,
    # "f1" is one of the keys returned by `compute_metrics`.
    metric_for_best_model="f1",
)

# Train on the augmented training split and evaluate on the validation split.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [None]:
trainer.evaluate()

In [None]:
# Save the trained model (there is no tokenizer in this pipeline).
# NOTE(review): only the model is written here; later cells load the image
# processor from this same directory — confirm it can be restored from the
# files the model saves.
output_dir = "./models/trained-model"
trainer.model.save_pretrained(output_dir)  # Save the model


## Inference on Test Data

In [None]:
image = load_image("data/fashion_mnist/images/image_1.png")
inputs = image_processor(image, return_tensors="pt")

In [None]:
# After training the model may live on an accelerator while `inputs` was
# created on the CPU; move the inputs to wherever the model is so the forward
# pass does not fail with a device mismatch.
model.eval()
inputs = inputs.to(model.device)

with torch.no_grad():
    logits = model(**inputs).logits
    # Predicted class id for the single image (name kept for the plotting cell).
    labels = logits.argmax(-1).item()

In [None]:
plt.imshow(image)
plt.axis("off")
plt.title(f"Prediction: {id2label[labels]}")
plt.show()

## Test optimised inference

In [None]:
output_dir

In [None]:
def measure_image_classification_inference_time_separate(
    image_paths: list[Path],
    model_name: str,  # Replace with your model
    batch_size=1,
    device=torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu",
):
    """Measure the average forward-pass time per image for a saved classifier.

    Only the model call is timed; image loading and preprocessing are excluded.

    Args:
        image_paths: Paths of the images to run through the model.
        model_name: Name or path passed to ``from_pretrained`` for both the
            image processor and the model.
        batch_size: Number of images per forward pass.
        device: Device the model and inputs are moved to. The default is
            evaluated once, when the function is defined.

    Returns:
        Total time spent in the model divided by the number of images, in
        seconds.

    Raises:
        ValueError: If ``image_paths`` is empty or ``batch_size`` is not positive.
    """
    num_images = len(image_paths)
    if num_images == 0:
        raise ValueError("image_paths must contain at least one image")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    image_processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    model.eval()  # Set the model to evaluation mode

    # Accelerator kernels are launched asynchronously, so the clock must only
    # be read once the device has finished its queued work.
    on_accelerator = torch.device(device).type != "cpu"

    def _wait_for_device():
        if on_accelerator:
            torch.accelerator.synchronize(device)

    total_inference_time = 0.0

    with torch.no_grad():  # Disable gradient calculation for inference
        for start in range(0, num_images, batch_size):
            images = []
            for path in image_paths[start : start + batch_size]:
                # convert() returns an independent copy, so the file can be closed.
                with Image.open(path) as img_file:
                    images.append(img_file.convert("RGB"))

            # Preprocess images
            inputs = image_processor(images=images, return_tensors="pt").to(device)

            _wait_for_device()
            # perf_counter is monotonic and has a higher resolution than time.time.
            start_time = time.perf_counter()
            _ = model(**inputs)  # Perform inference
            _wait_for_device()
            total_inference_time += time.perf_counter() - start_time

    return total_inference_time / num_images

In [None]:
N_IMAGES = 1000
# Sort before truncating: glob order depends on the filesystem, and the first
# entries of `image_paths` are reused as example images further down.
image_paths = sorted(Path("data/fashion_mnist/images").glob("*"))[:N_IMAGES]

print(len(image_paths))

time_b1 = measure_image_classification_inference_time_separate(image_paths, output_dir, batch_size=1)
time_b2 = measure_image_classification_inference_time_separate(image_paths, output_dir, batch_size=2)

# Two decimals of a second cannot resolve millisecond-scale latencies, so the
# per-image time is reported in milliseconds.
print(f"Time per image with batch=1: {time_b1 * 1000:.2f} ms")
print(f"Time per image with batch=2: {time_b2 * 1000:.2f} ms")

## Batched inputs to outputs

In [None]:
image1 = load_image(image_paths[0].as_posix())
image2 = load_image(image_paths[1].as_posix())
image1

In [None]:
image_processor = AutoImageProcessor.from_pretrained(output_dir)
# Preprocess both images
inputs1 = image_processor(image1, return_tensors="pt")
inputs2 = image_processor(image2, return_tensors="pt")

In [None]:
inputs2.pixel_values.shape

In [None]:
# Concatenate the pixel values along the batch dimension
batched_inputs = {"pixel_values": torch.cat((inputs1.pixel_values, inputs2.pixel_values), dim=0)}
batched_inputs["pixel_values"].shape

In [None]:
# Fall back to the CPU when no accelerator is present (current_accelerator()
# returns None in that case).
target_device = torch.accelerator.current_accelerator()
if target_device is None:
    target_device = "cpu"

model = AutoModelForImageClassification.from_pretrained(output_dir).to(target_device)
model.eval()  # Set the model to evaluation mode

# The batch was assembled on the CPU; move it to the model's device so the
# forward pass does not fail with a device mismatch.
batched_inputs = {name: tensor.to(model.device) for name, tensor in batched_inputs.items()}

# Perform inference
with torch.no_grad():
    logits = model(**batched_inputs).logits
    predicted_labels = torch.argmax(logits, dim=-1).tolist()  # Get list of predictions

In [None]:
logits

In [None]:
logits.shape

In [None]:
torch.argmax(logits, dim=-1)

In [None]:
predicted_labels

In [None]:
# predicted_labels now contains the predicted label for each image
print(f"Prediction for image 1: {id2label[predicted_labels[0]]}")
print(f"Prediction for image 2: {id2label[predicted_labels[1]]}")

## Optimization strategies

In [None]:
def _benchmark_forward(processor, model, img_path, device, dtype):
    """Load one image, preprocess it and run a single forward pass."""
    img = load_image(img_path.as_posix())
    # Floating-point inputs are cast to the model's parameter dtype so that a
    # half-precision model is not fed float32 tensors.
    input_ = processor(images=img, return_tensors="pt").to(device, dtype=dtype)
    with torch.no_grad():
        return model(**input_)


def benchmark(processor, model, inputs, device):
    """Time single-image inference and report CPU and memory usage.

    The first two entries of ``inputs`` are used for warm-up and are not
    measured; every remaining entry is loaded, preprocessed and passed through
    the model one at a time. Results are printed, nothing is returned.

    Args:
        processor: Image processor applied to each loaded image.
        model: Model to benchmark.
        inputs: Image paths (objects offering ``as_posix()``).
        device: Device the inputs are moved to; ``None`` is treated as "cpu".

    Raises:
        ValueError: If ``inputs`` has no entries left after the warm-up.
    """
    import time
    import psutil

    warmup_paths, bench_paths = inputs[:2], inputs[2:]
    if not bench_paths:
        raise ValueError("benchmark needs more than two inputs: the first two are only used for warm-up")

    if device is None:
        device = "cpu"
    first_param = next(model.parameters(), None)
    dtype = first_param.dtype if first_param is not None else None
    on_accelerator = torch.device(device).type != "cpu"

    # Warm up
    print("WARMING UP")
    for img_path in warmup_paths:
        _benchmark_forward(processor, model, img_path, device, dtype)
    if on_accelerator:
        torch.accelerator.synchronize(device)

    process = psutil.Process()

    # Priming call: with interval=None the first result is meaningless, and
    # the next call reports utilisation since this point.
    process.cpu_percent(interval=None)
    start_memory = process.memory_info().rss
    start_time = time.perf_counter()

    track_memory = []
    print("RUNNING BENCH")
    for img_path in bench_paths:
        _benchmark_forward(processor, model, img_path, device, dtype)
        # Resident memory growth since the start of the benchmark, in MB.
        track_memory.append((process.memory_info().rss - start_memory) / (1024 * 1024))

    # Accelerator work is asynchronous: wait for it before stopping the clock.
    if on_accelerator:
        torch.accelerator.synchronize(device)

    elapsed_time = time.perf_counter() - start_time
    # Utilisation over the whole benchmark loop (since the priming call).
    cpu_usage = process.cpu_percent(interval=None)
    memory_usage = process.memory_info().rss - start_memory

    print(f"Elapsed time: {elapsed_time:.2f} seconds ({elapsed_time / len(bench_paths):.2f} seconds per image)")
    print(f"CPU Usage: {cpu_usage:.2f}%")
    print(
        f"Memory Usage: {memory_usage / (1024 * 1024):.2f} MB (Min: {min(track_memory):.2f} | Max: {max(track_memory):.2f} | Average: {np.mean(track_memory):.2f})"
    )


In [None]:
device = torch.accelerator.current_accelerator()

In [None]:
def reload_model(model_path=output_dir, device=torch.accelerator.current_accelerator()):
    """Load a fresh image processor and classification model from ``model_path``.

    Both defaults are evaluated once, when the function is defined.

    Returns:
        A ``(image_processor, model)`` tuple, with the model moved to ``device``.
    """
    processor = AutoImageProcessor.from_pretrained(model_path)
    classifier = AutoModelForImageClassification.from_pretrained(model_path)
    return processor, classifier.to(device)

In [None]:
def model_dtypes(model):
    """Return the set of distinct dtypes found among the model's parameters."""
    return {param.dtype for _, param in model.named_parameters()}

In [None]:
image_processor, model = reload_model()
model_dtypes(model)

In [None]:
model.eval()  # Set to evaluation mode

# Simply half the model float16
model = model.half()  # Convert to float16

model_dtypes(model)

In [None]:
benchmark(image_processor, model, image_paths, device)

In [None]:
dtype = torch.float16
example_input = image_processor(image1, return_tensors="pt")
input_tensor = example_input["pixel_values"].to(device).to(dtype)

# Perform inference
with torch.no_grad():
    output = model(input_tensor)
    predicted_class = torch.argmax(output.logits, dim=-1)
    print(f"Predicted class: {predicted_class}")

output

In [None]:
image_processor, model = reload_model()
model.eval()  # Set to evaluation mode
model_dtypes(model)

In [None]:
benchmark(image_processor, model, image_paths, device)

In [None]:
dtype = torch.float16
example_input = image_processor(image1, return_tensors="pt")
input_tensor = example_input["pixel_values"].to(device)  # .to(dtype)

# Perform inference
with torch.no_grad():
    output = model(input_tensor)
    predicted_class = torch.argmax(output.logits, dim=-1)
    print(f"Predicted class: {predicted_class}")

In [None]:
output

In [None]:
input_tensor.shape

### Quantization

Quantization converts the model's weights from float32 to int8. This drastically reduces the model size and can speed up inference on CPUs that support int8 operations. PyTorch has built-in quantization tools (e.g., `torch.quantization`).

In [None]:
# The calibration inputs below are CPU tensors, so the model must be on the
# CPU as well (it may have been loaded onto an accelerator).
model.to("cpu")
model.eval()  # Important to set to eval mode

# Static Quantization (requires calibration)
# NOTE(review): eager-mode static quantization usually also expects quant/dequant
# stubs in the model — confirm the converted model actually runs.
model.qconfig = torch.quantization.get_default_qconfig()  # Choose appropriate qconfig for your CPU
torch.quantization.prepare(model, inplace=True)

# Calibration (run some data through the model to collect statistics)
# A handful of images is used here; a representative subset of the training
# data gives better statistics.
with torch.no_grad():
    for img_path in image_paths[:10]:
        # convert() returns an independent copy, so the file can be closed.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        # Request tensors explicitly, as the other inference cells do.
        input_ = image_processor(img, return_tensors="pt")
        model(**input_)

torch.quantization.convert(model, inplace=True)