## Prepare Dataset with HuggingFace 🤗


In [None]:
from datasets import IterableDataset, load_dataset
from transformers import ViTImageProcessor


def get_datasets(type: str) -> IterableDataset:
    """Load the reCAPTCHA v2 dataset and preprocess its images for ViT.

    Each image is passed through the ``google/vit-base-patch16-224`` image
    processor, the resulting ``pixel_values`` column is renamed to
    ``inputs``, and the output is restricted to the ``inputs`` and ``labels``
    columns in the requested format.

    Args:
        type: Format name forwarded to ``with_format``; this notebook uses
            ``"jax"``, ``"pt"`` or ``"tf"``.

    Returns:
        The preprocessed dataset object; callers in this notebook index it
        by split name (``"train"``, ``"validation"``).
    """
    raw = load_dataset("nobodyPerfecZ/recaptchav2-dataset")

    image_processor = ViTImageProcessor.from_pretrained(
        pretrained_model_name_or_path="google/vit-base-patch16-224",
    )

    def preprocess(examples):
        # ``map`` is invoked with batched=True, so this receives a batch of
        # examples rather than a single one.
        return image_processor(images=examples["image"])

    processed = raw.map(preprocess, batched=True)
    renamed = processed.rename_columns({"pixel_values": "inputs"})
    return renamed.with_format(type, columns=["inputs", "labels"])

In [None]:
# Build the preprocessed dataset in the tensor format of the framework used
# for fine-tuning below; keep exactly one of these lines active.
# dataset = get_datasets(type="jax")
dataset = get_datasets(type="pt")
# dataset = get_datasets(type="tf")

## Fine-Tuning of Pre-Trained Model with HuggingFace 🤗


### With Flax


In [None]:
from transformers import FlaxViTForImageClassification

# Load the pre-trained ViT checkpoint as a Flax model configured for the five
# reCAPTCHA v2 classes used in this notebook.
model = FlaxViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    num_labels=5,
    id2label=dict(enumerate(["bicycle", "bus", "car", "crosswalk", "hydrant"])),
    label2id=dict(bicycle=0, bus=1, car=2, crosswalk=3, hydrant=4),
    # Allow the 5-label head to differ in size from the checkpoint's head.
    ignore_mismatched_sizes=True,
)

In [None]:
from typing import Dict, Tuple

import chex
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from tqdm import tqdm


def create_train_state(
    model: FlaxViTForImageClassification,
    max_grad_norm: float,
    learning_rate: float,
    b1: float,
    b2: float,
    eps: float,
):
    """Create the Flax ``TrainState`` used to fine-tune ``model``.

    The optimizer first clips gradients by global norm and then applies AdamW.

    Args:
        model: Model whose ``__call__`` and ``params`` seed the state.
        max_grad_norm: Maximum global gradient norm for clipping.
        learning_rate: AdamW learning rate.
        b1: AdamW ``b1`` coefficient.
        b2: AdamW ``b2`` coefficient.
        eps: AdamW ``eps`` term.

    Returns:
        A ``TrainState`` holding the model's apply function, its parameters
        and the chained optimizer.
    """
    clip = optax.clip_by_global_norm(max_norm=max_grad_norm)
    adamw = optax.adamw(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps)
    optimizer = optax.chain(clip, adamw)
    return TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
    )


@jax.jit
def train_step(state: TrainState, batch: Dict[str, chex.Array]):
    """Apply one gradient update computed from ``batch``.

    The loss is the mean sigmoid binary cross-entropy between the model's
    logits for ``batch["inputs"]`` and ``batch["labels"]``.

    Args:
        state: Current train state (apply function, parameters, optimizer).
        batch: Mapping with ``"inputs"`` and ``"labels"`` entries.

    Returns:
        The train state after applying the gradients.
    """
    inputs, labels = batch["inputs"], batch["labels"]

    def mean_bce(params):
        outputs = state.apply_fn(inputs, params=params, train=True)
        per_element = optax.sigmoid_binary_cross_entropy(outputs.logits, labels)
        return per_element.mean()

    gradients = jax.grad(mean_bce)(state.params)
    return state.apply_gradients(grads=gradients)


@jax.jit
def eval_step(state: TrainState, batch: Dict[str, chex.Array]) -> Tuple[float, float]:
    """Compute evaluation metrics for one batch.

    Logits are passed through a sigmoid and thresholded at 0.5 to obtain
    per-label 0/1 predictions.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        batch: Mapping with ``"inputs"`` and ``"labels"`` entries.

    Returns:
        A pair ``(accuracy, hamming_accuracy)``:
        ``accuracy`` is the fraction of rows whose predictions equal the
        labels along the last axis; ``hamming_accuracy`` is one minus the
        mean element-wise disagreement between predictions and labels.
    """
    # Evaluation must run the model in inference mode, hence train=False.
    logits = state.apply_fn(batch["inputs"], params=state.params, train=False).logits
    labels = batch["labels"]

    # Threshold once and reuse the predictions for both metrics.
    predictions = (jax.nn.sigmoid(logits) >= 0.5).astype(jnp.int32)

    exact_match = jnp.all(predictions == labels, axis=-1).mean()
    hamming = 1 - jnp.logical_xor(predictions, labels).mean()
    return exact_match, hamming


def train(
    state: TrainState,
    train_dataset: IterableDataset,
    eval_dataset: IterableDataset,
    num_train_epochs: int,
    batch_size: int,
):
    """Train for ``num_train_epochs`` epochs, evaluating after each one.

    Both datasets are consumed in batches of ``batch_size`` with the last
    incomplete batch dropped.

    Args:
        state: Initial train state.
        train_dataset: Dataset iterated with ``train_step``.
        eval_dataset: Dataset iterated with ``eval_step``.
        num_train_epochs: Number of train/eval rounds.
        batch_size: Number of examples per batch.

    Returns:
        ``(state, metrics)`` where ``metrics`` maps ``"accuracy"`` and
        ``"hamming_accuracy"`` to arrays of length ``num_train_epochs``
        holding the per-epoch mean over the evaluated batches.
    """
    metrics = {
        name: jnp.zeros((num_train_epochs,))
        for name in ("accuracy", "hamming_accuracy")
    }

    for epoch in range(num_train_epochs):
        # --- training pass ---
        num_train_batches = len(train_dataset) // batch_size
        with tqdm(
            total=num_train_batches * batch_size,
            desc=f"Training Epoch {epoch+1}",
        ) as pbar:
            for batch in train_dataset.iter(
                batch_size=batch_size, drop_last_batch=True
            ):
                state = train_step(state, batch)
                pbar.update(batch_size)

        # --- evaluation pass: sum the per-batch metrics into slot `epoch` ---
        num_eval_batches = len(eval_dataset) // batch_size
        with tqdm(
            total=num_eval_batches * batch_size,
            desc=f"Evaluating Epoch {epoch+1}",
        ) as pbar:
            for batch in eval_dataset.iter(
                batch_size=batch_size, drop_last_batch=True
            ):
                batch_scores = dict(
                    zip(("accuracy", "hamming_accuracy"), eval_step(state, batch))
                )
                for name, score in batch_scores.items():
                    metrics[name] = metrics[name].at[epoch].add(score)
                pbar.update(batch_size)

        # Turn the per-epoch sums into means over the evaluated batches.
        for name in ("accuracy", "hamming_accuracy"):
            epoch_mean = metrics[name][epoch] / num_eval_batches
            metrics[name] = metrics[name].at[epoch].set(epoch_mean)

    return state, metrics

In [None]:
# Optimizer hyperparameters (gradient clipping threshold and AdamW settings)
max_grad_norm = 1.0
learning_rate = 5e-5
b1 = 0.9
b2 = 0.999
eps = 1e-8

# Training hyperparameters
num_train_epochs = 2
batch_size = 32

# Build the initial train state from the Flax model loaded above.
train_state = create_train_state(
    model=model,
    max_grad_norm=max_grad_norm,
    learning_rate=learning_rate,
    b1=b1,
    b2=b2,
    eps=eps,
)
# Fine-tune on the "train" split and evaluate on "validation" after each epoch.
final_state, metrics = train(
    state=train_state,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    num_train_epochs=num_train_epochs,
    batch_size=batch_size,
)
# Copy the trained parameters back into the model before saving it to disk.
model.params = final_state.params
model.save_pretrained("./vit-finetuned-patch16-224-recaptchav2")

### With PyTorch


In [None]:
from transformers import ViTForImageClassification

# Load the pre-trained ViT checkpoint as a PyTorch model configured for the
# five reCAPTCHA v2 classes used in this notebook, then move it to the GPU.
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    num_labels=5,
    id2label=dict(enumerate(["bicycle", "bus", "car", "crosswalk", "hydrant"])),
    label2id=dict(bicycle=0, bus=1, car=2, crosswalk=3, hydrant=4),
    # Allow the 5-label head to differ in size from the checkpoint's head.
    ignore_mismatched_sizes=True,
).to(device="cuda")

In [None]:
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from flax.training.train_state import TrainState
from torch.optim import AdamW, Optimizer
from tqdm import tqdm


@dataclass
class TrainState:
    """Minimal PyTorch train state bundling model call, parameters and optimizer.

    Note that this class shadows the Flax ``TrainState`` imported at the top
    of this cell; from here on the name refers to this dataclass.
    """

    # Callable that runs the forward pass; create_train_state stores the
    # model's bound ``__call__`` here.
    apply_fn: Callable
    # Parameters handed to ``torch.nn.utils.clip_grad_norm_`` in train_step.
    # This is an iterable of parameters, not a name -> tensor mapping.
    params: Iterable[torch.nn.Parameter]
    # Optimizer that train_step zeroes and steps.
    tx: Optimizer
    # Maximum total gradient norm passed to ``clip_grad_norm_``.
    max_grad_norm: float


def create_train_state(
    model: ViTForImageClassification,
    max_grad_norm: float,
    learning_rate: float,
    b1: float,
    b2: float,
    eps: float,
):
    """Create the PyTorch ``TrainState`` used to fine-tune ``model``.

    Args:
        model: Model whose ``__call__`` and parameters seed the state.
        max_grad_norm: Maximum gradient norm stored for clipping.
        learning_rate: AdamW learning rate.
        b1: First AdamW beta coefficient.
        b2: Second AdamW beta coefficient.
        eps: AdamW ``eps`` term.

    Returns:
        A ``TrainState`` holding the model's call, its parameters, an AdamW
        optimizer over those parameters, and ``max_grad_norm``.
    """
    # ``model.parameters()`` yields a one-shot generator. Storing it directly
    # would leave nothing to iterate after the first gradient-clipping call,
    # so later steps would silently skip clipping. Materialise it once.
    params = list(model.parameters())
    return TrainState(
        apply_fn=model.__call__,
        params=params,
        tx=AdamW(params, lr=learning_rate, betas=(b1, b2), eps=eps),
        max_grad_norm=max_grad_norm,
    )


def train_step(state: TrainState, batch: Dict[str, torch.Tensor]):
    """Run one optimization step on ``batch``.

    The loss is the binary cross-entropy (with logits) between the model's
    logits for ``batch["inputs"]`` and ``batch["labels"]`` cast to float32.
    Gradients are clipped to ``state.max_grad_norm`` before the optimizer
    step.

    Args:
        state: Train state providing ``apply_fn``, ``params``, ``tx`` and
            ``max_grad_norm``.
        batch: Mapping with ``"inputs"`` and ``"labels"`` tensors.

    Returns:
        The same ``state`` object; the update happens in place through the
        optimizer.
    """
    optimizer = state.tx
    optimizer.zero_grad()

    # Forward pass and loss.
    outputs = state.apply_fn(batch["inputs"])
    targets = batch["labels"].to(torch.float32)
    loss = F.binary_cross_entropy_with_logits(outputs.logits, targets).mean()

    # Backward pass, gradient clipping, parameter update.
    loss.backward()
    torch.nn.utils.clip_grad_norm_(state.params, state.max_grad_norm)
    optimizer.step()
    return state


def eval_step(state: TrainState, batch: Dict[str, torch.Tensor]) -> Tuple[float, float]:
    """Compute evaluation metrics for one batch.

    Logits are passed through a sigmoid and thresholded at 0.5 to obtain
    per-label 0/1 predictions.

    Args:
        state: Train state providing ``apply_fn``.
        batch: Mapping with ``"inputs"`` and ``"labels"`` tensors.

    Returns:
        A pair ``(accuracy, hamming_accuracy)`` as NumPy scalars:
        ``accuracy`` is the fraction of rows whose predictions equal the
        labels along the last dimension; ``hamming_accuracy`` is one minus
        the mean element-wise disagreement between predictions and labels.
    """
    labels = batch["labels"]

    # Inference only: disable autograd so no graph is built or kept alive
    # for the forward pass.
    with torch.no_grad():
        logits = state.apply_fn(batch["inputs"]).logits
        # Threshold once and reuse the predictions for both metrics.
        predictions = (torch.sigmoid(logits) >= 0.5).to(torch.int32)
        exact_match = torch.all(predictions == labels, dim=-1).to(torch.float32).mean()
        mismatch = torch.logical_xor(predictions, labels).to(torch.float32).mean()

    accuracy = exact_match.cpu().numpy()
    hamming_accuracy = 1 - mismatch.cpu().numpy()
    return accuracy, hamming_accuracy


def train(
    state: TrainState,
    train_dataset: IterableDataset,
    eval_dataset: IterableDataset,
    num_train_epochs: int,
    batch_size: int,
):
    """Train for ``num_train_epochs`` epochs, evaluating after each one.

    Both datasets are consumed in batches of ``batch_size`` with the last
    incomplete batch dropped. Batches are moved to the device of the model's
    parameters before each step.

    Args:
        state: Train state; ``state.apply_fn`` is normally the model's bound
            ``__call__`` as stored by ``create_train_state``.
        train_dataset: Dataset iterated with ``train_step``.
        eval_dataset: Dataset iterated with ``eval_step``.
        num_train_epochs: Number of train/eval rounds.
        batch_size: Number of examples per batch.

    Returns:
        ``(state, metrics)`` where ``metrics`` maps ``"accuracy"`` and
        ``"hamming_accuracy"`` to NumPy arrays of length ``num_train_epochs``
        holding the per-epoch mean over the evaluated batches.
    """
    # Resolve the module that owns ``state.apply_fn`` so train/eval mode and
    # the target device follow the model actually being trained. Fall back to
    # the notebook-level ``model`` only when apply_fn is not a bound method of
    # an nn.Module.
    net = getattr(state.apply_fn, "__self__", None)
    if not isinstance(net, torch.nn.Module):
        net = model
    device = next(net.parameters()).device

    metrics = {
        "accuracy": np.zeros((num_train_epochs,)),
        "hamming_accuracy": np.zeros((num_train_epochs,)),
    }

    for epoch in range(num_train_epochs):
        net.train()
        num_train_batches = len(train_dataset) // batch_size
        train_total = num_train_batches * batch_size
        with tqdm(total=train_total, desc=f"Training Epoch {epoch+1}") as pbar:
            for batch in train_dataset.iter(
                batch_size=batch_size, drop_last_batch=True
            ):
                batch["inputs"] = batch["inputs"].to(device=device)
                batch["labels"] = batch["labels"].to(device=device)
                state = train_step(state, batch)
                pbar.update(batch_size)

        net.eval()
        num_eval_batches = len(eval_dataset) // batch_size
        eval_total = num_eval_batches * batch_size
        # Evaluation needs no gradients; no_grad avoids building autograd graphs.
        with torch.no_grad(), tqdm(
            total=eval_total, desc=f"Evaluating Epoch {epoch+1}"
        ) as pbar:
            for batch in eval_dataset.iter(batch_size=batch_size, drop_last_batch=True):
                batch["inputs"] = batch["inputs"].to(device=device)
                batch["labels"] = batch["labels"].to(device=device)
                accuracy, hamming_accuracy = eval_step(state, batch)
                metrics["accuracy"][epoch] += accuracy
                metrics["hamming_accuracy"][epoch] += hamming_accuracy
                pbar.update(batch_size)

        # Turn the per-epoch sums into means over the evaluated batches.
        metrics["accuracy"][epoch] /= num_eval_batches
        metrics["hamming_accuracy"][epoch] /= num_eval_batches

    return state, metrics

In [None]:
# Optimizer hyperparameters (gradient clipping threshold and AdamW settings)
max_grad_norm = 1.0
learning_rate = 5e-5
b1 = 0.9
b2 = 0.999
eps = 1e-8

# Training hyperparameters
num_train_epochs = 2
batch_size = 32

# Build the initial train state from the PyTorch model loaded above.
train_state = create_train_state(
    model=model,
    max_grad_norm=max_grad_norm,
    learning_rate=learning_rate,
    b1=b1,
    b2=b2,
    eps=eps,
)
# Fine-tune on the "train" split and evaluate on "validation" after each epoch.
final_state, metrics = train(
    state=train_state,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    num_train_epochs=num_train_epochs,
    batch_size=batch_size,
)
# The optimizer updated the model in place, so it can be saved directly.
model.save_pretrained("./vit-finetuned-patch16-224-recaptchav2")
# Display the per-epoch metrics as the cell output.
metrics

## Save Model from PyTorch to Flax with HuggingFace 🤗


In [None]:
from transformers import FlaxViTForImageClassification

# Load the PyTorch checkpoint into the Flax model class (from_pt=True) and
# write the result back into the same directory.
# NOTE(review): this path ends in "-v1", whereas the training cells save to
# "./vit-finetuned-patch16-224-recaptchav2" — confirm the directory exists
# (e.g. was renamed or copied) before running this cell.
model = FlaxViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./vit-finetuned-patch16-224-recaptchav2-v1",
    from_pt=True,
)
model.save_pretrained("./vit-finetuned-patch16-224-recaptchav2-v1")

## Save Model from PyTorch to TensorFlow with HuggingFace 🤗


In [None]:
from transformers import TFViTForImageClassification

# Load the PyTorch checkpoint into the TensorFlow model class (from_pt=True)
# and write the result back into the same directory.
# NOTE(review): this path ends in "-v1", whereas the training cells save to
# "./vit-finetuned-patch16-224-recaptchav2" — confirm the directory exists
# (e.g. was renamed or copied) before running this cell.
model = TFViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./vit-finetuned-patch16-224-recaptchav2-v1",
    from_pt=True,
)
model.save_pretrained("./vit-finetuned-patch16-224-recaptchav2-v1")