In [1]:
import pandas as pd
import numpy as np
import os
import json
import shutil

caption_dir = '/kaggle/input/tagged-anime-illustrations/danbooru-metadata/danbooru-metadata'

id_to_tags = {}

for filename in os.listdir(caption_dir):
    f_path = os.path.join(caption_dir, filename)
    if os.path.isfile(f_path):
        with open(f_path, 'r') as f:
            for line in f:
                data = json.loads(line.rstrip())
                tags = [x["name"] for x in data["tags"]]
                caption = " ".join(tags)
                id_to_tags[data["id"]] = caption 

In [2]:
from collections import Counter

tag_counter = Counter()
for tags in id_to_tags.values():
    for tag in tags.split(" "):
        tag = tag.strip()
        if tag:  # avoid empty strings
            tag_counter[tag] += 1
            
top_tags = [tag for tag, _ in tag_counter.most_common(500)]

top_tags_set = set(top_tags)

top_2000_id_to_tags = {
    img_id: " ".join([t for t in tags.split(" ") if t in top_tags_set])
    for img_id, tags in id_to_tags.items()
}

In [3]:
from collections import Counter

filtered_tag_counter_2000 = Counter()

for tags in top_2000_id_to_tags.values():
    for tag in tags.split(" "):
        tag = tag.strip()
        if tag:  
            filtered_tag_counter_2000[tag] += 1

print("Total unique tags:", len(filtered_tag_counter_2000))

Total unique tags: 500


In [4]:
print(top_2000_id_to_tags["1017000"])

1girl bow brown_hair detached_sleeves frills hair_bow hair_ribbon hair_tubes hakurei_reimu highres midriff navel red_eyes ribbon skirt skirt_set solo standing touhou


In [5]:
unique_tags = list(filtered_tag_counter_2000.keys())

In [6]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class DanbooruMultiLabelDataset(Dataset):
    def __init__(self, root_dir, label_dict, unique_tags, transform=None):
        """
        root_dir: folder with all subfolders of images
        label_dict: dict mapping 'image_id' -> list of tags
        unique_tags: list of all unique tags (defines the multi-label space)
        """
        self.root_dir = root_dir
        self.label_dict = label_dict
        self.tag_to_idx = {tag: i for i, tag in enumerate(unique_tags)}
        self.transform = transform

        # Collect image paths
        self.image_paths = []
        for subdir, _, files in os.walk(root_dir):
            for f in files:
                if f.lower().endswith((".jpg", ".jpeg", ".png")):
                    img_id = os.path.splitext(f)[0]
                    if img_id in label_dict:
                        self.image_paths.append(os.path.join(subdir, f))

    def __len__(self):
        return len(self.image_paths)

    def encode_tags(self, tags):
        vec = torch.zeros(len(self.tag_to_idx), dtype=torch.float32)
        for tag in tags:
            if tag in self.tag_to_idx:
                vec[self.tag_to_idx[tag]] = 1.0
        return vec

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        img = Image.open(path).convert("RGB")
    
        img_id = os.path.splitext(os.path.basename(path))[0]
        tags_str = self.label_dict[img_id]
    
        # Split tags if they are stored as a space-separated string
        if isinstance(tags_str, str):
            tags = tags_str.split()  # split by whitespace
        else:
            tags = tags_str  # already a list
    
        label_vec = self.encode_tags(tags)
    
        if self.transform:
            img = self.transform(img)
    
        return img, label_vec

In [7]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dataset = DanbooruMultiLabelDataset(
    root_dir="/kaggle/input/tagged-anime-illustrations/danbooru-images/danbooru-images",
    label_dict=top_2000_id_to_tags,
    unique_tags=unique_tags,
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for images, label_vecs in dataloader:
    print(images.shape)      # (B, 3, 224, 224)
    print(label_vecs.shape)  # (B, num_tags)
    break

torch.Size([64, 3, 224, 224])
torch.Size([64, 500])


In [8]:
from torch.utils.data import random_split
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size  

train_dataset, val_dataset, test_dataset = random_split(
    dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)  # reproducible splits
)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

Train: 269626 | Val: 33703 | Test: 33704


In [9]:
# Get one batch from the dataloader
imgs, labels = next(iter(train_loader))

# Print 5 label vectors and how many active tags they have
for i in range(5):
    label_vec = labels[i]
    print(f"\nSample {i} — Active tags: {int(label_vec.sum().item())}")
    print(label_vec.nonzero(as_tuple=True)[0].tolist())


Sample 0 — Active tags: 12
[14, 29, 60, 63, 113, 118, 141, 220, 335, 397, 444, 476]

Sample 1 — Active tags: 17
[2, 3, 13, 37, 41, 46, 73, 87, 104, 110, 113, 141, 145, 148, 446, 458, 464]

Sample 2 — Active tags: 15
[0, 3, 5, 38, 61, 96, 131, 135, 148, 175, 226, 339, 354, 416, 425]

Sample 3 — Active tags: 21
[2, 3, 14, 40, 41, 55, 79, 93, 95, 96, 98, 104, 147, 155, 185, 201, 311, 314, 322, 325, 344]

Sample 4 — Active tags: 8
[3, 5, 17, 90, 226, 269, 335, 360]


In [10]:
import torch
# Load the model
model = torch.hub.load('RF5/danbooru-pretrained', 'resnet50')


Downloading: "https://github.com/RF5/danbooru-pretrained/zipball/master" to /root/.cache/torch/hub/master.zip
Downloading: "https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth" to /root/.cache/torch/hub/checkpoints/resnet50-13306192.pth
100%|██████████| 110M/110M [00:00<00:00, 230MB/s] 


In [11]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from torchvision import datasets, transforms, models
# from torch.optim.lr_scheduler import StepLR
# import numpy as np
# import time
# import copy
# import os

# # Load the model
# num_classes = 500
# model = models.resnet50(pretrained=True)

# model.fc = nn.Linear(model.fc.in_features, num_classes)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)

In [12]:
print(model)

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256

In [13]:
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-4)
# scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

In [14]:
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_tags = len(unique_tags)
# Replace final layer
model[1][8] = nn.Linear(in_features=512, out_features=num_tags)
print(f"✅ Replaced final layer with Linear(512, {num_tags})")
model.to(device)

✅ Replaced final layer with Linear(512, 500)


Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256

In [15]:
# # Freeze feature extractor
# for param in model[0].parameters():
#     param.requires_grad = False

In [16]:
# for name, param in model.named_parameters():
#     print(name, param.requires_grad)

In [17]:
import torch.nn as nn
import torch.optim as optim
criterion = nn.BCEWithLogitsLoss()  # multi-label classification
optimizer = optim.Adam(model[1].parameters(), lr=1e-3)

In [18]:
output_path = "tag_to_index_mapping.txt"
with open(output_path, "w", encoding="utf-8") as f:
    for tag, idx in sorted(dataset.tag_to_idx.items(), key=lambda x: x[1]):
        f.write(f"{idx}\t{tag}\n")

print(f"Saved tag-to-index mapping to {output_path}")

Saved tag-to-index mapping to tag_to_index_mapping.txt


In [19]:
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score, precision_recall_curve
import numpy as np
import torch


# ======================
def find_best_thresholds(y_true, y_pred_proba):
    thresholds = []
    for i in range(y_true.shape[1]):
        p, r, t = precision_recall_curve(y_true[:, i], y_pred_proba[:, i])
        f1 = 2 * p * r / (p + r + 1e-8)
        if len(f1) > 0 and not np.all(np.isnan(f1)):
            thresholds.append(t[np.nanargmax(f1)])
        else:
            thresholds.append(0.2)  # Default fallback threshold
    return np.clip(np.array(thresholds, dtype=np.float32), 0.05, 0.95)


# ======================
history_fine_tuned = {
    "train_loss": [],
    "val_loss": [],
    "f1": [],
    "precision": [],
    "recall": [],
}


# ======================
num_epochs_fine_tune = 10
print(f" Starting fine-tuning for {num_epochs_fine_tune} epochs...\n")

for epoch in range(num_epochs_fine_tune):
    model.train()
    train_loss = 0.0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs_fine_tune} [Train]"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * imgs.size(0)

    train_loss /= len(train_loader.dataset)

    # ======================
    # Validation phase
    # ======================
    model.eval()
    all_probs, all_labels = [], []
    val_loss = 0.0

    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs_fine_tune} [Val]"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels.float())
            val_loss += loss.item() * imgs.size(0)

            probs = torch.sigmoid(outputs).cpu()
            all_probs.append(probs)
            all_labels.append(labels.cpu())

    all_probs = torch.cat(all_probs).numpy()
    all_labels = torch.cat(all_labels).numpy()
    val_loss /= len(val_loader.dataset)

    # Use static threshold (0.2) for validation evaluation
    preds = (all_probs > 0.2).astype(int)

    f1 = f1_score(all_labels, preds, average="samples", zero_division=0)
    precision = precision_score(all_labels, preds, average="samples", zero_division=0)
    recall = recall_score(all_labels, preds, average="samples", zero_division=0)

    print(f"Epoch [{epoch+1}/{num_epochs_fine_tune}] | "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"P: {precision:.4f} | R: {recall:.4f} | F1: {f1:.4f}")

    history_fine_tuned["train_loss"].append(train_loss)
    history_fine_tuned["val_loss"].append(val_loss)
    history_fine_tuned["precision"].append(precision)
    history_fine_tuned["recall"].append(recall)
    history_fine_tuned["f1"].append(f1)


print("\n Tuning thresholds on validation set...")
best_thresholds = find_best_thresholds(all_labels, all_probs)
np.save("best_thresholds.npy", best_thresholds)
print(" Best thresholds saved to best_thresholds.npy")


# ======================
model.eval()
all_probs, all_labels = [], []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Testing"):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        probs = torch.sigmoid(outputs).cpu()
        all_probs.append(probs)
        all_labels.append(labels.cpu())

all_probs = torch.cat(all_probs).numpy()
all_labels = torch.cat(all_labels).numpy()

# Use tuned thresholds for test predictions
preds = (all_probs > best_thresholds).astype(int)

precision = precision_score(all_labels, preds, average="samples", zero_division=0)
recall = recall_score(all_labels, preds, average="samples", zero_division=0)
f1 = f1_score(all_labels, preds, average="samples", zero_division=0)

print(f"\n Final Test Results (with tuned thresholds)")
print(f"F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


# ======================
torch.save(model.state_dict(), "model_danboruu_resnet50_finetuned.pth")
print("\n Model saved after fine-tuning: model_danboruu_resnet50_finetuned.pth")


 Starting fine-tuning for 10 epochs...



Epoch 1/10 [Train]: 100%|██████████| 4213/4213 [51:58<00:00,  1.35it/s]
Epoch 1/10 [Val]: 100%|██████████| 527/527 [02:16<00:00,  3.87it/s]


Epoch [1/10] | Train Loss: 0.0964 | Val Loss: 0.0894 | P: 0.5001 | R: 0.5430 | F1: 0.5000


Epoch 2/10 [Train]: 100%|██████████| 4213/4213 [52:11<00:00,  1.35it/s]
Epoch 2/10 [Val]: 100%|██████████| 527/527 [02:10<00:00,  4.05it/s]


Epoch [2/10] | Train Loss: 0.0893 | Val Loss: 0.0886 | P: 0.5062 | R: 0.5421 | F1: 0.5026


Epoch 3/10 [Train]: 100%|██████████| 4213/4213 [52:13<00:00,  1.34it/s]
Epoch 3/10 [Val]: 100%|██████████| 527/527 [02:15<00:00,  3.90it/s]


Epoch [3/10] | Train Loss: 0.0887 | Val Loss: 0.0866 | P: 0.5156 | R: 0.5384 | F1: 0.5060


Epoch 4/10 [Train]: 100%|██████████| 4213/4213 [52:12<00:00,  1.35it/s]
Epoch 4/10 [Val]: 100%|██████████| 527/527 [02:11<00:00,  4.00it/s]


Epoch [4/10] | Train Loss: 0.0883 | Val Loss: 0.0844 | P: 0.5028 | R: 0.5516 | F1: 0.5060


Epoch 5/10 [Train]: 100%|██████████| 4213/4213 [52:14<00:00,  1.34it/s]
Epoch 5/10 [Val]: 100%|██████████| 527/527 [02:14<00:00,  3.91it/s]


Epoch [5/10] | Train Loss: 0.0880 | Val Loss: 0.0859 | P: 0.5038 | R: 0.5524 | F1: 0.5071


Epoch 6/10 [Train]: 100%|██████████| 4213/4213 [52:20<00:00,  1.34it/s]
Epoch 6/10 [Val]: 100%|██████████| 527/527 [02:12<00:00,  3.97it/s]


Epoch [6/10] | Train Loss: 0.0877 | Val Loss: 0.0857 | P: 0.5092 | R: 0.5516 | F1: 0.5091


Epoch 7/10 [Train]: 100%|██████████| 4213/4213 [52:17<00:00,  1.34it/s]
Epoch 7/10 [Val]: 100%|██████████| 527/527 [02:19<00:00,  3.79it/s]


Epoch [7/10] | Train Loss: 0.0875 | Val Loss: 0.0874 | P: 0.5059 | R: 0.5523 | F1: 0.5075


Epoch 8/10 [Train]: 100%|██████████| 4213/4213 [52:15<00:00,  1.34it/s]
Epoch 8/10 [Val]: 100%|██████████| 527/527 [02:11<00:00,  4.00it/s]


Epoch [8/10] | Train Loss: 0.0873 | Val Loss: 0.0845 | P: 0.5074 | R: 0.5552 | F1: 0.5100


Epoch 9/10 [Train]: 100%|██████████| 4213/4213 [52:14<00:00,  1.34it/s]
Epoch 9/10 [Val]: 100%|██████████| 527/527 [02:12<00:00,  3.96it/s]


Epoch [9/10] | Train Loss: 0.0872 | Val Loss: 0.0844 | P: 0.5126 | R: 0.5508 | F1: 0.5104


Epoch 10/10 [Train]: 100%|██████████| 4213/4213 [52:19<00:00,  1.34it/s]
Epoch 10/10 [Val]: 100%|██████████| 527/527 [02:11<00:00,  3.99it/s]


Epoch [10/10] | Train Loss: 0.0870 | Val Loss: 0.0841 | P: 0.4916 | R: 0.5706 | F1: 0.5083

 Tuning thresholds on validation set...




 Best thresholds saved to best_thresholds.npy


Testing: 100%|██████████| 527/527 [02:26<00:00,  3.60it/s]



 Final Test Results (with tuned thresholds)
F1 Score: 0.4774
Precision: 0.4487
Recall: 0.5667

 Model saved after fine-tuning: model_danboruu_resnet50_finetuned.pth
