In [50]:
import os
import re
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from facenet_pytorch import InceptionResnetV1
from PIL import Image


In [51]:
class FaceDataset(Dataset):
    def __init__(self, root_dir, transform=None, max_classes=10):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        self.idx_to_class = {}
        label_names = []

        for fname in sorted(os.listdir(root_dir)):
            if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                match = re.search(r'@([^@]+)@', fname)
                if match:
                    label = match.group(1)
                    if label not in label_names:
                        label_names.append(label)

        label_names = label_names[:max_classes]
        self.class_to_idx = {name: idx for idx, name in enumerate(label_names)}
        self.idx_to_class = {idx: name for name, idx in self.class_to_idx.items()}

        for fname in sorted(os.listdir(root_dir)):
            if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                match = re.search(r'@([^@]+)@', fname)
                if match:
                    label = match.group(1)
                    if label in self.class_to_idx:
                        path = os.path.join(root_dir, fname)
                        self.samples.append((path, self.class_to_idx[label]))
        self.labels = label_names

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [52]:
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
])

# limit do 5 klas - długo zajmuje trening :c
dataset = FaceDataset(root_dir='data/', transform=transform, max_classes=5)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [53]:
model = InceptionResnetV1(pretrained='vggface2', classify=True, num_classes=len(dataset.class_to_idx))

for param in model.parameters():
    param.requires_grad = False

for param in model.logits.parameters():
    param.requires_grad = True
for param in model.last_linear.parameters():
    param.requires_grad = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [54]:
dataset.labels

['N08_identity_4',
 'N00_identity_14',
 'N00_identity_11',
 'N00_identity_0',
 'N04_identity_5']

In [55]:
len(dataset.samples)

9096

In [56]:
for epoch in range(5):
    total_loss = 0
    correct = 0
    model.train()
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()

    accuracy = correct / len(dataset)
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}")

Epoch 1, Loss: 703.5246, Accuracy: 0.6266
Epoch 2, Loss: 413.7717, Accuracy: 0.8674
Epoch 3, Loss: 294.0270, Accuracy: 0.9101
Epoch 4, Loss: 246.8806, Accuracy: 0.9140
Epoch 5, Loss: 219.9734, Accuracy: 0.9171


In [57]:
torch.save(model.state_dict(), 'facenet_classifier.pth')


In [59]:
model.load_state_dict(torch.load('facenet_classifier.pth'))
model.eval()

def predict(image_path, model, transform, class_to_idx):
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
        pred = output.argmax(1).item()
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return idx_to_class[pred]

predicted_label = predict('data/12078789@N00_identity_0@261065930_1.jpg', model, transform, dataset.class_to_idx)
print(f'Predicted label: {predicted_label}')


Predicted label: N00_identity_0
