In [1]:
# !pip install -U transformers datasets tqdm

In [1]:
import torch, torch.nn as nn
from transformers import AutoImageProcessor, AutoModel
from PIL import Image

In [2]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

model_repo = "facebook/dinov3-vits16-pretrain-lvd1689m"
processor = AutoImageProcessor.from_pretrained(model_repo)
backbone = AutoModel.from_pretrained(model_repo)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
from utils import load_nsfwdataset, test_accuracy

In [4]:
class DinoV3Linear(nn.Module):
    def __init__(self, backbone: AutoModel, hidden_size: int, num_classes: int, freeze_backbone: bool = True):
        super().__init__()
        self.backbone = backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self.backbone.eval()

        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, pixel_values):
        outputs = self.backbone(pixel_values=pixel_values)
        last_hidden = outputs.last_hidden_state
        cls = last_hidden[:, 0]
        logits = self.head(cls)
        return logits

    def count_params(self):
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total_params, trainable_params

    @property
    def device(self):
        return next(self.parameters()).device
        
# Setup Model
hidden_size = getattr(backbone.config, "hidden_size", None)
model = DinoV3Linear(backbone, hidden_size, num_classes=2, freeze_backbone=True).to(device) 
total_params, trainable_params = model.count_params()

# Setup Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)

# Load Dataset
ds_train, ds_test = load_nsfwdataset(processor, batch_size_train = 128)

print(f"classifier model has {total_params/1e6:.2f}M parameters ({trainable_params} trainable)")

classifier model has 21.60M parameters (770 trainable)


In [5]:
# Training Loop
step = 0

for epoch in range(5):
    for images, labels in ds_train:
        images, labels = images.to(device), torch.Tensor(labels).to(device).long()
        logits = model(images)
        loss = nn.functional.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        if step % 10 == 0:
            print(f"step {step} (epoch {epoch}), loss {loss.item():.2f}")
        if step % 100 == 0:
            model.eval()
            acc = test_accuracy(ds_test, model)
            print(f"step {step} (epoch {epoch}), accuracy {acc:.2f}%")
            model.train()
    
        step += 1
        # break

step 0 (epoch 0), loss 0.81


Test accuray: 100%|██████████| 14/14 [00:01<00:00,  8.33it/s]


step 0 (epoch 0), accuracy 31.93%
step 10 (epoch 0), loss 0.63
step 20 (epoch 0), loss 0.54
step 30 (epoch 0), loss 0.43
step 40 (epoch 0), loss 0.39
step 50 (epoch 0), loss 0.33
step 60 (epoch 0), loss 0.28
step 70 (epoch 1), loss 0.25
step 80 (epoch 1), loss 0.23
step 90 (epoch 1), loss 0.21
step 100 (epoch 1), loss 0.19


Test accuray: 100%|██████████| 14/14 [00:01<00:00,  8.33it/s]


step 100 (epoch 1), accuracy 98.48%
step 110 (epoch 1), loss 0.20
step 120 (epoch 1), loss 0.18
step 130 (epoch 2), loss 0.17
step 140 (epoch 2), loss 0.16
step 150 (epoch 2), loss 0.12
step 160 (epoch 2), loss 0.12
step 170 (epoch 2), loss 0.12
step 180 (epoch 2), loss 0.11
step 190 (epoch 3), loss 0.09
step 200 (epoch 3), loss 0.11


Test accuray: 100%|██████████| 14/14 [00:01<00:00,  8.29it/s]


step 200 (epoch 3), accuracy 99.88%
step 210 (epoch 3), loss 0.09
step 220 (epoch 3), loss 0.08
step 230 (epoch 3), loss 0.08
step 240 (epoch 3), loss 0.08
step 250 (epoch 4), loss 0.07
step 260 (epoch 4), loss 0.08
step 270 (epoch 4), loss 0.07
step 280 (epoch 4), loss 0.07
step 290 (epoch 4), loss 0.07
step 300 (epoch 4), loss 0.06


Test accuray: 100%|██████████| 14/14 [00:01<00:00,  8.37it/s]


step 300 (epoch 4), accuracy 100.00%
