In [None]:
# imports
import os

import pandas as pd
import torch
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise fall back to
# the CPU (training on 'cpu' is not recommended).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(device)

cuda


In [None]:
# Load the pretrained BEiT checkpoint and its image processor from Hugging Face.
# NOTE(review): BEiT is presumably a ViT-style encoder with BERT-like
# pre-training, trained on ImageNet -- confirm the details on the model card.
# The checkpoint's classifier has 1000 outputs; asking for 200 labels with
# ignore_mismatched_sizes=True makes the head newly initialised, so the model
# has to be fine-tuned before it can be used for predictions.
checkpoint_name = "microsoft/beit-large-patch16-224"

processor = AutoImageProcessor.from_pretrained(checkpoint_name)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint_name,
    num_labels=200,
    ignore_mismatched_sizes=True,
)

Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/beit-large-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([200, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([200]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
class BirdDataset(Dataset):
    """Dataset that serves processor-encoded images listed in a DataFrame.

    Every row of ``df`` must have an ``image_path`` column. Frames used for
    training/validation also need a ``label`` column; frames used with
    ``is_test=True`` need an ``id`` column instead.

    Args:
        df: DataFrame with one row per image.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            that returns a mapping of tensors.
        is_test: When True, items carry an ``"id"`` entry instead of
            ``"labels"``.
    """

    def __init__(self, df, processor, is_test=False):
        self.df = df
        self.processor = processor
        self.is_test = is_test

    def __len__(self):
        """Return the number of rows (images) in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, convert and encode the image at positional index ``idx``.

        Returns:
            dict: the processor's tensors with their leading dimension
            squeezed, plus ``"labels"`` (train/val) or ``"id"`` (test).
        """
        row = self.df.iloc[idx]

        # Open inside a context manager so the underlying file handle is
        # released deterministically; convert() returns an independent RGB
        # image that stays valid after the file is closed.
        with Image.open(row["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        encoding = self.processor(image, return_tensors="pt")
        # Drop the leading (size-1) dimension so the DataLoader can stack
        # individual items into a batch itself.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        if not self.is_test:
            encoding["labels"] = torch.tensor(int(row["label"]))
        else:
            # only save the id for test set
            encoding["id"] = torch.tensor(int(row["id"]))

        return encoding

In [None]:
def load_data(
    data_dir="/kaggle/input/dataset-aml-feathers/aml-2025-feathers-in-focus",
    val_per_class=1,
    seed=42,
):
    """Read the train/test CSVs and carve a per-class validation split.

    Args:
        data_dir: Dataset root containing ``train_images.csv``,
            ``test_images_path.csv`` and the ``train_images/train_images`` and
            ``test_images/test_images`` folders.
        val_per_class: Number of images held out per label for validation.
        seed: ``random_state`` used when sampling each label's validation rows.

    Returns:
        tuple: ``(train_df, val_df, test_df)`` where the validation rows have
        been removed from the training frame.
    """
    root = str(data_dir).rstrip("/")

    train_df = pd.read_csv(root + "/train_images.csv")
    test_df = pd.read_csv(root + "/test_images_path.csv")

    # Shift labels down by one (presumably the CSV labels are 1-based; the
    # inference cell adds 1 back before writing the submission).
    train_df["label"] = train_df["label"] - 1

    # Keep only the file name from the CSV path and re-root it under the
    # actual image folders.
    train_df["image_path"] = (
        root + "/train_images/train_images/"
        + train_df["image_path"].str.split("/").str[-1]
    )
    test_df["image_path"] = (
        root + "/test_images/test_images/"
        + test_df["image_path"].str.split("/").str[-1]
    )

    # Sample each label group separately instead of groupby(...).apply(...):
    # apply() on the grouping column is deprecated in pandas and would drop
    # the "label" column once grouping columns are excluded. Seeding every
    # group with the same random_state selects the same rows as before.
    sampled_groups = [
        group.sample(val_per_class, random_state=seed)
        for _, group in train_df.groupby("label")
    ]
    val_df = pd.concat(sampled_groups) if sampled_groups else train_df.iloc[0:0]

    # remove validation rows from train
    train_split = train_df.drop(val_df.index)

    print(f"Train: {len(train_split)} | Val: {len(val_df)} | Test: {len(test_df)}")
    return train_split, val_df, test_df

In [None]:
# create datasets and dataloaders
BATCH_SIZE = 32  # shared by all three loaders

train_df, val_df, test_df = load_data()

train_ds = BirdDataset(train_df, processor)
val_ds = BirdDataset(val_df, processor)
test_ds = BirdDataset(test_df, processor, is_test=True)

# Only the training loader shuffles: evaluation metrics do not depend on batch
# order, so validation and test are iterated in a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

Train: 3726 | Val: 200 | Test: 4000


  .apply(lambda x: x.sample(1, random_state=42))  # 1 image per class


In [None]:
# Number of trailing encoder blocks to fine-tune together with the
# classification head. 0 keeps the whole encoder frozen.
N_UNFROZEN_BLOCKS = 0

# freeze the whole BEiT backbone
for param in model.beit.parameters():
    param.requires_grad = False

# unfreeze classification head
for param in model.classifier.parameters():
    param.requires_grad = True

# unfreeze last N encoder blocks
# The guard is required: with N == 0 the slice [-0:] would select *every*
# block and silently unfreeze the entire encoder.
if N_UNFROZEN_BLOCKS > 0:
    for block in model.beit.encoder.layer[-N_UNFROZEN_BLOCKS:]:
        for param in block.parameters():
            param.requires_grad = True

'N = 1\nfor block in model.beit.encoder.layer[-N:]:\n    for param in block.parameters():\n        param.requires_grad = True\n'

The number of frozen layers was one of the fine-tuning experiments: the classifier head is always trainable, and additionally unfreezing the last 1 or 2 encoder blocks was tested, but this led to worse performance than keeping the whole encoder frozen.

In [None]:
def _run_epoch(model, loader, device, optimizer=None, desc=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Model called as ``model(**batch)`` whose output exposes
            ``.loss`` and ``.logits``.
        loader: Iterable of dict batches of tensors that include ``"labels"``.
        device: Device every tensor of the batch is moved to.
        optimizer: When given, the pass trains the model (train mode, backward
            pass and optimizer step per batch); when None, the pass only
            evaluates (eval mode, gradients disabled).
        desc: Label shown on the progress bar.

    Returns:
        tuple: sample-weighted mean loss and accuracy over the pass.

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(is_training):
        for batch in tqdm(loader, desc=desc):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss
            logits = outputs.logits

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight the batch loss by its size so the epoch mean stays exact
            # even when the last batch is smaller.
            batch_size = logits.size(0)
            loss_sum += loss.item() * batch_size
            preds = torch.argmax(logits, dim=1)
            n_correct += (preds == batch["labels"]).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch metrics")

    return loss_sum / n_seen, n_correct / n_seen


model.to(device)

# Hand only trainable parameters to the optimizer; frozen ones never receive
# gradients, so tracking them is pure overhead.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-4)
epochs = 20 # higher and lower number of epochs have also been tested

# save history
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# fine-tuning loop
for epoch in range(epochs):
    epoch_loss, epoch_acc = _run_epoch(
        model, train_loader, device,
        optimizer=optimizer, desc=f"Epoch {epoch+1} [TRAIN]",
    )
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

    # validation for monitoring
    val_loss, val_acc = _run_epoch(
        model, val_loader, device, desc=f"Epoch {epoch+1} [VAL]",
    )
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # monitor progress
    print(
        f"Epoch {epoch+1} | "
        f"Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )


In [None]:
# Write the fine-tuned model and its processor into the same directory.
finetuned_dir = "weights/BEiT_finetuned"
for artifact in (model, processor):
    artifact.save_pretrained(finetuned_dir)

The weights of this model are not included in the repository because they exceed 1 GB in size.

In [None]:
# Checkpoint: reload the fine-tuned model and processor from the directory
# the save cell writes to.
finetuned_dir = "weights/BEiT_finetuned"
model = AutoModelForImageClassification.from_pretrained(finetuned_dir)
processor = AutoImageProcessor.from_pretrained(finetuned_dir)

# Assign the result so the cell does not echo the full module repr as output.
model = model.to(device)

BeitForImageClassification(
  (beit): BeitModel(
    (embeddings): BeitEmbeddings(
      (patch_embeddings): BeitPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BeitEncoder(
      (layer): ModuleList(
        (0): BeitLayer(
          (attention): BeitAttention(
            (attention): BeitSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (relative_position_bias): BeitRelativePositionBias()
            )
            (output): BeitSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
 

In [None]:
# Run inference on the test set and write the submission file.
model.to(device)
model.eval()

submission_path = "submissions/submission_beit.csv"

predictions = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        # don't send ids to model
        batch_ids = batch.pop("id")
        batch = {k: v.to(device) for k, v in batch.items()}

        logits = model(**batch).logits

        # retransform predictions: add back the 1 that load_data subtracted
        # from the training labels
        preds = torch.argmax(logits, dim=1) + 1

        # save id and label; tolist() converts the whole batch in one
        # transfer instead of one .item() call per element
        predictions.extend(
            {"id": sample_id, "label": label}
            for sample_id, label in zip(batch_ids.tolist(), preds.tolist())
        )

pred_df = pd.DataFrame(predictions)

# Create the output folder first so a missing directory cannot discard the
# results of the whole inference pass.
os.makedirs(os.path.dirname(submission_path), exist_ok=True)
pred_df.to_csv(submission_path, index=False)
