In [None]:
# imports
import pandas as pd
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torch.utils.data import Dataset
from PIL import Image
import torch

In [None]:
# Select the compute backend: prefer CUDA, then Apple's MPS, and only fall
# back to the CPU (training on 'cpu' is not recommended).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(device)

In [None]:
# Load the pretrained BEiT image processor and classifier from Hugging Face.
# The model is requested with a 200-class head; ignore_mismatched_sizes lets
# loading proceed when the checkpoint's head has a different shape.
checkpoint = "microsoft/beit-large-patch16-224"

processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=200,
    ignore_mismatched_sizes=True,
)

In [None]:
class BirdDataset(Dataset):
    """Map-style dataset that yields processor-encoded images from a DataFrame.

    Each row of ``df`` must provide an ``image_path`` column. Training rows
    additionally need a ``label`` column; test rows need an ``id`` column.

    Args:
        df: DataFrame describing the samples, one row per image.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            that returns a mapping of tensors with a leading batch dimension.
        is_test: When False (default) each item carries ``"labels"``; when
            True it carries ``"id"`` instead.
    """

    def __init__(self, df, processor, is_test=False):
        self.df = df
        self.processor = processor
        self.is_test = is_test

    def __len__(self):
        """Return the number of rows (samples) in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, encode and return the sample at positional index ``idx``.

        Returns:
            dict: the processor outputs with the batch dimension squeezed
            away, plus ``"labels"`` (train) or ``"id"`` (test) as a scalar
            tensor.
        """
        row = self.df.iloc[idx]

        # Image.open keeps the underlying file open; convert() returns an
        # independent in-memory copy, so the handle can be released as soon
        # as the conversion is done instead of waiting for garbage collection.
        with Image.open(row["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        encoding = self.processor(image, return_tensors="pt")
        # The processor returns batched tensors; drop the batch axis so the
        # DataLoader can stack individual samples itself.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        if not self.is_test:
            encoding["labels"] = torch.tensor(int(row["label"]))
        else:
            # The id is only attached for the test set
            encoding["id"] = torch.tensor(int(row["id"]))

        return encoding

In [None]:
def load_data(data_dir="/Users/tejaswimadduri/Downloads/aml-2025-feathers-in-focus"):
    """Read the train/test CSV files and normalise labels and image paths.

    Args:
        data_dir: Root folder of the dataset. It must contain
            ``train_images.csv``, ``test_images_path.csv`` and the image
            folders ``train_images/train_images`` and
            ``test_images/test_images``. Defaults to the original hard-coded
            location so existing ``load_data()`` calls keep working.

    Returns:
        tuple: ``(train_df, test_df)``. In ``train_df`` the ``label`` column
        is shifted down by one; in both frames ``image_path`` is rewritten to
        point inside ``data_dir``.
    """
    import os

    root = os.fspath(data_dir)

    train_df = pd.read_csv(os.path.join(root, "train_images.csv"))
    test_df = pd.read_csv(os.path.join(root, "test_images_path.csv"))

    # transform labels from 1..200 to 0..199
    train_df["label"] = train_df["label"] - 1

    # Keep only the file name from the CSV path and prepend the real image
    # folder; the trailing "" makes os.path.join end the prefix with a separator.
    train_prefix = os.path.join(root, "train_images", "train_images", "")
    test_prefix = os.path.join(root, "test_images", "test_images", "")
    train_df["image_path"] = train_prefix + train_df["image_path"].str.split("/").str[-1]
    test_df["image_path"] = test_prefix + test_df["image_path"].str.split("/").str[-1]

    # print sizes
    print(f"Train: {len(train_df)} | Test: {len(test_df)}")
    return train_df, test_df


In [None]:
# Read the CSV metadata, then wrap each split in a dataset and a batched loader.
train_df, test_df = load_data()

# Training split: labelled samples, reshuffled every epoch.
train_ds = BirdDataset(df=train_df, processor=processor)
train_loader = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)

# Test split: samples carry their id and are served in file order.
test_ds = BirdDataset(df=test_df, processor=processor, is_test=True)
test_loader = DataLoader(dataset=test_ds, batch_size=32, shuffle=False)


In [None]:
# Freeze the BEiT backbone and leave only the classification head trainable.
# Each entry pairs a sub-module with the requires_grad value for its parameters.
for submodule, trainable in ((model.beit, False), (model.classifier, True)):
    for weight in submodule.parameters():
        weight.requires_grad = trainable


In [None]:
# Train the model and record per-epoch loss and accuracy on the training set.
model.to(device)

# Hand the optimizer only the parameters that will actually be updated;
# anything with requires_grad=False never receives a gradient, so passing it
# along would just add dead entries to the parameter groups.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=5e-4)
epochs = 20

train_losses = []
train_accuracies = []

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [TRAIN]"):
        batch = {k: v.to(device) for k, v in batch.items()}

        # The batch contains "labels", so the model returns the loss as well
        # as the logits.
        outputs = model(**batch)
        loss = outputs.loss
        logits = outputs.logits

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss and accuracy. The loss is weighted by the batch
        # size so that the final division by `total` yields a per-sample mean.
        batch_size = logits.size(0)
        epoch_loss += loss.item() * batch_size
        preds = torch.argmax(logits, dim=1)
        correct += (preds == batch["labels"]).sum().item()
        total += batch_size

    epoch_loss /= total
    epoch_acc = correct / total

    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

    print(f"Epoch {epoch+1} | Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}")


In [None]:
# Persist the fine-tuned model and its processor side by side in one folder.
for artifact in (model, processor):
    artifact.save_pretrained("saved_model")

In [None]:
# Restore the model and processor from the local "saved_model" folder.
saved_dir = "saved_model"
model = AutoModelForImageClassification.from_pretrained(saved_dir)
processor = AutoImageProcessor.from_pretrained(saved_dir)

# Move the restored weights onto the selected device.
model.to(device)

In [None]:
# Run inference on the test set and write the predictions to a CSV file.
model.to(device)
model.eval()

predictions = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        # Take the ids out of the batch; they must not be sent to the model
        batch_ids = batch.pop("id")
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        logits = outputs.logits

        # Predictions: shift the argmax index by one to get labels 1-200
        preds = torch.argmax(logits, dim=1) + 1

        # Store one {id, label} record per sample; tolist() yields plain
        # Python ints for the whole batch in one call.
        predictions.extend(
            {"id": sample_id, "label": label}
            for sample_id, label in zip(batch_ids.tolist(), preds.tolist())
        )

# Build the DataFrame and save it as CSV. The columns are fixed explicitly so
# the header is written even if there are no predictions.
pred_df = pd.DataFrame(predictions, columns=["id", "label"])
pred_df.to_csv("submission_beit.csv", index=False)

# Report the file that was actually written.
print("Predictions saved to submission_beit.csv")
