In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Choose the compute backend: CUDA when available, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

cuda


In [None]:
class MultimodalBirdDataset(Dataset):
    """Dataset for multimodal learning: images + attribute embeddings.

    Args:
        df: DataFrame with one row per image. Every row needs ``image_path``
            (relative to ``img_dir``) and ``attr_vec`` (array-like, converted to a
            float32 tensor). Train/val rows also need ``label``; test rows need ``id``.
        img_dir: directory prefix joined to ``image_path`` with ``/``.
        transform: optional callable applied to the RGB PIL image.
        is_test: when True, items carry the sample ``id`` instead of the ``label``.

    Items are ``(image, attr_vec, label)``, or ``(image, attr_vec, sample_id)``
    when ``is_test`` is True.
    """

    def __init__(self, df, img_dir, transform=None, is_test=False):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = f"{self.img_dir}/{row['image_path']}"
        # Open through a context manager so the file handle is released as soon as
        # the RGB copy exists, instead of whenever the garbage collector gets to it.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # the per-image embedded attribute vector
        attr_vec = torch.tensor(row['attr_vec'], dtype=torch.float32)

        if self.is_test:
            sample_id = int(row["id"])
            return image, attr_vec, sample_id

        label = int(row["label"])
        return image, attr_vec, label

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block (the ResNet-18/34 building block).

    Main path: 3x3 conv (optionally strided) -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: identity, or a 1x1 strided conv + BN projection whenever the
    spatial stride or the channel count changes. Output = ReLU(main + skip).
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # main path (attribute names and creation order kept so state_dict keys match)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # skip path: a projection is only needed when the output shape differs from the input
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.shortcut(x)

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        return self.relu(y + residual)


class MultimodalResNet18Skinny(nn.Module):
    """ResNet18 Skinny adapted for multimodal input (image + attributes).

    Image branch: 7x7/2 conv stem + 3x3/2 max-pool, four stages of two
    ``BasicBlock``s (32 -> 64 -> 128 -> 256 channels), global average pooling,
    then Dropout + Linear(256 -> 128).
    Attribute branch: Linear(attr_dim -> 256) -> ReLU -> Dropout ->
    Linear(256 -> 128) -> ReLU -> Dropout.
    Fusion: concatenate the two 128-d features, then Linear(256 -> 256) -> ReLU ->
    Dropout -> Linear(256 -> num_classes). ``forward`` returns raw logits.
    """

    def __init__(self, num_classes=200, attr_dim=312, dropout_rate=0.6):
        super().__init__()

        # ---- image branch ----
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # (in_channels, out_channels, stride) per residual stage; registered as layer1..layer4
        stage_specs = [(32, 32, 1), (32, 64, 2), (64, 128, 2), (128, 256, 2)]
        for stage_no, (c_in, c_out, stage_stride) in enumerate(stage_specs, start=1):
            setattr(self, f"layer{stage_no}", self._make_layer(c_in, c_out, 2, stride=stage_stride))

        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # project pooled image features to 128-d
        self.image_fc = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(256, 128),
        )

        # ---- attribute branch ----
        self.attr_fc = nn.Sequential(
            nn.Linear(attr_dim, 256), nn.ReLU(inplace=True), nn.Dropout(p=dropout_rate),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(p=dropout_rate),
        )

        # ---- fusion head: 128 (image) + 128 (attributes) -> logits ----
        self.fusion = nn.Sequential(
            nn.Linear(128 + 128, 256), nn.ReLU(inplace=True), nn.Dropout(p=dropout_rate),
            nn.Linear(256, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, relu) for conv/linear weights; BN scale 1, shift 0."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, in_channels, out_channels, blocks, stride):
        """One residual stage: a (possibly strided) entry block plus ``blocks - 1`` more."""
        entry = BasicBlock(in_channels, out_channels, stride)
        followers = [BasicBlock(out_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(entry, *followers)

    def forward(self, x_image, x_attr):
        feats = self.stem(x_image)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        feats = torch.flatten(self.gap(feats), 1)

        # image branch first, then attributes (same dropout RNG order as before)
        fused = torch.cat([self.image_fc(feats), self.attr_fc(x_attr)], dim=1)
        return self.fusion(fused)

In [None]:
train_df = pd.read_csv("/kaggle/input/dataset-aml-feathers/aml-2025-feathers-in-focus/train_images.csv")
test_df = pd.read_csv("/kaggle/input/dataset-aml-feathers/aml-2025-feathers-in-focus/test_images_path.csv")

# shift labels from 1..200 to 0..199 so they index the model's logits directly
train_df["label"] = train_df["label"] - 1

# Hold out one image per class for validation.
# Bug fix: the previous version called reset_index(drop=True) on val_df BEFORE dropping,
# so train_df.drop(val_df.index) removed rows 0..n_classes-1 instead of the sampled rows.
# Validation images leaked into training and an unrelated block of training rows was lost.
# Each group is sampled exactly as before (same seed, so the same validation rows), but the
# rows are now removed from train_df by their ORIGINAL index labels.
val_idx = [
    group.sample(1, random_state=42).index[0]
    for _, group in train_df.groupby("label")
]
val_df = train_df.loc[val_idx].reset_index(drop=True)
train_df = train_df.drop(index=val_idx).reset_index(drop=True)

assert val_df["label"].is_unique, "expected exactly one validation image per class"
assert not set(val_df["image_path"]) & set(train_df["image_path"]), "validation images leaked into train_df"
# NOTE(review): the attribute embeddings are later matched to these frames by row position;
# if train/val embedding files were generated from the old (buggy) split, regenerate them.

In [None]:
# Augmentation recipe from ResNet18_skinny.ipynb, with one fix: RandomErasing now runs AFTER
# Normalize. With value='random' it fills the erased patch with standard-normal noise, which
# only matches the pixel statistics in normalized space; applied before Normalize, the noise
# was treated as [0, 1] pixels and then amplified by the per-channel std division.
train_tfms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
])

# evaluation-time preprocessing: deterministic resize plus the same normalization
test_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [None]:
# load embeddings and align them row-by-row with the dataframes
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

train_embeddings = np.load("/kaggle/input/embeddings-test-csv/train_embeddings.npy")
val_embeddings = np.load("/kaggle/input/embeddings-test-csv/val_embeddings.npy")
test_embeddings = np.load("/kaggle/input/embeddings-test-csv/test_embeddings.npy")

# Embeddings are matched to rows purely by position. Previously a longer array was silently
# truncated (and a shorter one raised a bare IndexError), hiding misalignment; fail loudly.
for _split, _frame, _emb in (
    ("train", train_df, train_embeddings),
    ("val", val_df, val_embeddings),
    ("test", test_df, test_embeddings),
):
    if len(_emb) != len(_frame):
        raise ValueError(
            f"{_split}: {len(_emb)} embeddings for {len(_frame)} dataframe rows; "
            "the embedding file does not match this split"
        )

# one 1-D embedding row per dataframe row
train_df["attr_vec"] = list(train_embeddings)
val_df["attr_vec"] = list(val_embeddings)
test_df["attr_vec"] = list(test_embeddings)

In [None]:
# Wrap each split in a MultimodalBirdDataset + DataLoader. The validation rows come from the
# train image folder but use the deterministic evaluation transforms.
train_ds = MultimodalBirdDataset(
    df=train_df,
    img_dir="/kaggle/input/dataset-aml-feathers/aml-2025-feathers-in-focus/train_images",
    transform=train_tfms,
    is_test=False,
)
val_ds = MultimodalBirdDataset(
    df=val_df,
    img_dir="/kaggle/input/dataset-aml-feathers/aml-2025-feathers-in-focus/train_images",
    transform=test_tfms,
    is_test=False,
)
test_ds = MultimodalBirdDataset(
    df=test_df,
    img_dir="/kaggle/input/dataset-aml-feathers/aml-2025-feathers-in-focus/test_images",
    transform=test_tfms,
    is_test=True,
)

# shared loader settings; only the training split is shuffled
loader_kwargs = {"batch_size": 32, "num_workers": 0}
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
print(f"Test loader: {len(test_loader)} batches")

In [None]:
# model, loss, optimizer
# Seed torch's global RNG first: it drives weight init, DataLoader shuffling (num_workers=0)
# and the random torchvision augmentations, so training is repeatable run-to-run
# (up to non-deterministic GPU kernels).
torch.manual_seed(42)

model = MultimodalResNet18Skinny(num_classes=200, attr_dim=312, dropout_rate=0.6).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # higher lr has been tested

In [None]:
# adapted from ResNet18_skinny.ipynb; batches here are (images, attrs, labels) triples
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: module called as ``model(images, attrs)`` returning logits.
        loader: iterable of ``(images, attrs, labels)`` tensor batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss function; assumed to use ``reduction='mean'`` (the default
            for ``nn.CrossEntropyLoss``).
        device: device the batch tensors are moved to.

    Returns:
        The per-sample average loss for the epoch. Each batch's mean loss is weighted
        by its batch size, so a smaller final batch does not skew the result.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0

    for images, attrs, labels in tqdm(loader, desc="Training"):
        images = images.to(device)
        attrs = attrs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, attrs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # undo the criterion's batch averaging so batches of different sizes weigh correctly
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / n_samples


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and run without gradient tracking. Batches
    are ``(images, attrs, labels)`` triples; the prediction is the argmax over
    dim 1 of the logits. The result is correct predictions divided by samples
    seen, so an empty loader raises ZeroDivisionError.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_images, batch_attrs, batch_labels in tqdm(loader, desc="Evaluating"):
            batch_images = batch_images.to(device)
            batch_attrs = batch_attrs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images, batch_attrs)
            hits = logits.argmax(dim=1) == batch_labels
            n_correct += int(hits.sum())
            n_seen += batch_labels.size(0)

    return n_correct / n_seen

In [None]:
import os

train_losses, val_losses = [], []
train_accs, val_accs = [], []

epochs = 300
patience = 10
best_val_loss = float('inf')
best_state = None  # in-memory copy of the best weights seen so far
epochs_no_improve = 0

weights_dir = "../weights"
os.makedirs(weights_dir, exist_ok=True)  # torch.save does not create missing directories

for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    # NOTE: train_loader applies the training augmentations, so this is accuracy on augmented
    # images, and it costs one extra full pass over the training set per epoch.
    train_acc = evaluate(model, train_loader, device)

    # validation loss and accuracy in ONE pass, both averaged per sample
    # (previously: batch-mean loss average plus a second pass just for accuracy)
    model.eval()
    val_loss_sum, val_correct, val_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, attrs, labels in val_loader:
            images = images.to(device)
            attrs = attrs.to(device)
            labels = labels.to(device)
            outputs = model(images, attrs)
            batch_size = labels.size(0)
            val_loss_sum += criterion(outputs, labels).item() * batch_size
            val_correct += (outputs.argmax(1) == labels).sum().item()
            val_seen += batch_size
    val_loss = val_loss_sum / val_seen
    val_acc = val_correct / val_seen

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    # monitor progress
    print(f"Epoch {epoch+1}/{epochs} | "
          f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f} | "
          f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

    # early stopping on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(best_state, f"{weights_dir}/multimodal_resnet18_skinny_best.pth")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping after epoch {epoch+1}: "
              f"no val-loss improvement for {patience} epochs")
        break

# Keep the last-epoch weights on disk, then restore the best checkpoint into `model`.
# Previously the best weights were saved but never reloaded, so later cells predicted
# with weights up to `patience` epochs past the best validation loss.
torch.save(model.state_dict(), f"{weights_dir}/multimodal_resnet18_skinny.pth")
if best_state is not None:
    model.load_state_dict(best_state)

In [None]:
model.to(device)
model.eval()

# collect ids and predicted labels batch by batch, then assemble the submission once
all_ids, all_labels = [], []
with torch.no_grad():
    for images, attrs, ids in tqdm(test_loader, desc="Predicting"):
        logits = model(images.to(device), attrs.to(device))
        # the model predicts 0-based class indices; the submission uses the original 1-200 labels
        class_labels = torch.argmax(logits, dim=1) + 1
        all_ids.extend(int(i) for i in ids.tolist())
        all_labels.extend(int(c) for c in class_labels.tolist())

predictions = [{"id": i, "label": c} for i, c in zip(all_ids, all_labels)]

pred_df = (
    pd.DataFrame(predictions)
    .sort_values('id')
    .reset_index(drop=True)
)
pred_df.to_csv("../submissions/submission_multimodal.csv", index=False)

Predicting: 100%|██████████| 125/125 [00:56<00:00,  2.22it/s]


In [None]:
def _plot_train_val_curve(train_values, val_values, train_label, val_label,
                          ylabel, title, filename):
    """Plot a per-epoch train/val metric pair, save it to ``filename`` and display it."""
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.plot(train_values, label=train_label)
    ax.plot(val_values, label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()


_plot_train_val_curve(train_losses, val_losses, "Train Loss", "Val Loss",
                      "Loss", "Loss Curve", "loss_curve.png")
_plot_train_val_curve(train_accs, val_accs, "Train Accuracy", "Val Accuracy",
                      "Accuracy", "Accuracy Curve", "accuracy_curve.png")