In [21]:
# Install (or upgrade) the Weights & Biases client quietly in this runtime.
!pip install wandb -qU
import wandb
# Authenticate with W&B before any run is started further down.
wandb.login()

In [22]:
# Mount Google Drive into the Colab filesystem (this cell requires a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

# Unzip the dataset archive from Drive into ./my_data
!unzip /content/drive/MyDrive/generated_images_10Kids_cropped.zip -d my_data

In [23]:

# Install torch / torchvision / torchaudio from the PyTorch CUDA 12.1 wheel index.
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://download.pytorch.org/whl/cu121


In [24]:
# Progress bars (tqdm) and the layer-summary helper (torchsummary) used below.
!pip install tqdm torchsummary

Defaulting to user installation because normal site-packages is not writeable


In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
from tqdm import tqdm
import random
from torchsummary import summary

In [26]:
class SiameseNetwork(nn.Module):
    """Twin-branch CNN that scores whether two single-channel images match.

    Both inputs are encoded by the same layers (shared weights) into a
    16-dimensional embedding. The element-wise absolute difference of the two
    embeddings is mapped by ``fc3`` to one raw logit per pair.

    The input width of ``fc1`` (32 * 12 * 12) fixes the expected input size to
    1 x 112 x 112 images.

    NOTE(review): ``bn1``/``bn2``/``bn3`` and ``dropout`` are registered as
    submodules but are never called in the forward pass.
    """

    def __init__(self):
        super().__init__()

        # Convolutional encoder (batch-norm layers are declared next to the
        # convolution they would follow).
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, padding=0)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=0)
        self.bn3 = nn.BatchNorm2d(32)

        # Dense head: 4608 -> 41 -> 32 -> 16 per branch, then 16 -> 1 on the
        # embedding difference.
        self.fc1 = nn.Linear(32 * 12 * 12, 41)
        self.fc1komma5 = nn.Linear(41, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 1)
        self.dropout = nn.Dropout(0.5)

    def forward_one(self, x):
        """Encode one batch of images into 16-dimensional embeddings."""
        # Three rounds of conv -> ReLU -> 2x2 max-pool. For a 112x112 input the
        # spatial size goes 112 -> 56 -> 52 -> 26 -> 24 -> 12.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        # Flatten everything except the batch dimension.
        features = x.view(x.size(0), -1)

        for dense in (self.fc1, self.fc1komma5, self.fc2):
            features = F.relu(dense(features))
        return features

    def forward(self, input1, input2):
        """Return one raw logit per pair, shape (batch, 1)."""
        embedding1 = self.forward_one(input1)
        embedding2 = self.forward_one(input2)
        return self.fc3(torch.abs(embedding1 - embedding2))

In [27]:
class FaceDataset(Dataset):
    """Image-pair dataset for Siamese training.

    For every unordered pair of images belonging to the same person, one
    positive pair (label 1) is stored, immediately followed by one negative
    pair (label 0) that combines the first image with a randomly chosen image
    of a different person.

    Args:
        image_folder: Root directory containing one sub-directory per person.
        people_dirs: Names of the person sub-directories to draw pairs from.
        transform: Optional callable applied independently to both images.

    Attributes are plain lists: ``image_pairs`` holds ``(path1, path2)``
    tuples and ``labels`` holds the matching 0/1 integers.
    """

    def __init__(self, image_folder, people_dirs, transform=None):
        self.image_folder = image_folder
        self.people_dirs = people_dirs
        self.transform = transform
        self.image_pairs = []
        self.labels = []
        self._prepare_data()

    def _prepare_data(self):
        """Populate ``image_pairs`` and ``labels`` from the directory tree."""
        import warnings  # local: only needed for the degenerate-split warning

        # List every person's folder exactly once instead of once per pair.
        images_by_person = {
            person: os.listdir(os.path.join(self.image_folder, person))
            for person in self.people_dirs
        }

        for person_dir in self.people_dirs:
            person_path = os.path.join(self.image_folder, person_dir)
            images = images_by_person[person_dir]

            # Other people that can actually supply a negative image.
            negative_pool = [
                other for other in self.people_dirs
                if other != person_dir and images_by_person[other]
            ]
            if not negative_pool and len(images) > 1:
                # Previously this situation spun forever while searching for a
                # different person; emit positives only and tell the user.
                warnings.warn(
                    f"No other person with images available for '{person_dir}'; "
                    "only positive pairs are generated for it."
                )

            for i in range(len(images)):
                anchor_path = os.path.join(person_path, images[i])
                for j in range(i + 1, len(images)):
                    self.image_pairs.append((anchor_path, os.path.join(person_path, images[j])))
                    self.labels.append(1)

                    if not negative_pool:
                        continue

                    # Add one negative sample for each positive one.
                    neg_person = random.choice(negative_pool)
                    neg_image = random.choice(images_by_person[neg_person])
                    self.image_pairs.append(
                        (anchor_path, os.path.join(self.image_folder, neg_person, neg_image))
                    )
                    self.labels.append(0)

    def __len__(self):
        """Number of stored pairs (positives plus negatives)."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` as grayscale and return ``(img1, img2, label)``.

        The label is a float32 scalar tensor so it can be fed directly to
        ``BCEWithLogitsLoss``.
        """
        img1_path, img2_path = self.image_pairs[idx]
        label = self.labels[idx]

        # convert() returns an independent image, so the file handles can be
        # released as soon as the conversion is done.
        with Image.open(img1_path) as raw1:
            img1 = raw1.convert('L')
        with Image.open(img2_path) as raw2:
            img2 = raw2.convert('L')

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

# Function to split dataset
def split_dataset(image_folder, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Randomly partition the person folders into train / val / test lists.

    The split is done per person (not per image), so no identity appears in
    more than one partition.

    Args:
        image_folder: Root directory containing one sub-directory per person.
        train_ratio: Fraction of people assigned to the training list.
        val_ratio: Fraction of people assigned to the validation list.
        test_ratio: Kept for interface compatibility; the test list simply
            receives every person left over after train and val.

    Returns:
        Tuple ``(train_dirs, val_dirs, test_dirs)`` of directory names.
    """
    # Keep only real directories (stray files would break os.listdir later)
    # and sort first so that a seeded shuffle is reproducible regardless of
    # the filesystem's listing order.
    people_dirs = sorted(
        entry for entry in os.listdir(image_folder)
        if os.path.isdir(os.path.join(image_folder, entry))
    )
    random.shuffle(people_dirs)

    train_end = int(train_ratio * len(people_dirs))
    val_end = train_end + int(val_ratio * len(people_dirs))

    train_dirs = people_dirs[:train_end]
    val_dirs = people_dirs[train_end:val_end]
    test_dirs = people_dirs[val_end:]

    return train_dirs, val_dirs, test_dirs

# Initialize wandb
wandb.init(project='face-recognition-steffen')

# Hyperparameters and setup
batch_size = 512
learning_rate = 0.02
epochs = 5
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before splitting so the person split and the sampled negative pairs
# are the same on every Restart & Run All.
random.seed(seed)
torch.manual_seed(seed)

wandb.config.update({
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "epochs": epochs,
    "seed": seed,
    "device": str(device)
})

print(f'Batch size: {batch_size}')
print(f'LR: {learning_rate}')
print(f'Epochs: {epochs}')
print(f'Device: {device}')

# Data augmentation (training only).
# NOTE(review): images are loaded as single-channel 'L', so the saturation/hue
# jitter presumably has no effect - confirm against torchvision.
train_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor()
])

# Validation/test images get the deterministic part only, so the reported
# metrics do not vary with random augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor()
])

# Original name kept as an alias for the training pipeline.
transform = train_transform

# Load dataset
image_folder = 'my_data/generated_images_10Kids_cropped'  # Update with the path to your dataset
train_dirs, val_dirs, test_dirs = split_dataset(image_folder)

train_dataset = FaceDataset(image_folder, train_dirs, transform=train_transform)
val_dataset = FaceDataset(image_folder, val_dirs, transform=eval_transform)
test_dataset = FaceDataset(image_folder, test_dirs, transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Model, loss, and optimizer
model = SiameseNetwork().to(device)
summary(model, [(1, 112, 112), (1, 112, 112)])
criterion = nn.BCEWithLogitsLoss()  # the model outputs raw logits
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
# Loss scaling is only meaningful for CUDA mixed precision; disable it
# explicitly on CPU instead of relying on the scaler to switch itself off.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

# Training script with validation
def train(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    """Train ``model`` with gradient accumulation and validate after each epoch.

    Relies on the notebook-level ``scaler`` (GradScaler), ``scheduler``
    (LR scheduler), ``evaluate`` and ``wandb`` objects. A checkpoint of the
    model's ``state_dict`` is written to ``networks/`` after every epoch.

    Args:
        model: Network called as ``model(img1, img2)`` returning one logit per pair.
        train_loader: Iterable of ``(img1, img2, label)`` training batches.
        val_loader: Loader passed to ``evaluate`` after every epoch.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
    """
    accumulation_steps = 4
    num_batches = len(train_loader)

    # Run on whatever device the model lives on; autocast only helps on CUDA.
    train_device = next(model.parameters()).device
    use_amp = train_device.type == "cuda"

    # torch.save does not create missing directories.
    os.makedirs('networks', exist_ok=True)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        # Start each epoch from clean gradients.
        optimizer.zero_grad()
        with tqdm(total=num_batches, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
            for i, (img1, img2, label) in enumerate(train_loader):
                img1, img2, label = img1.to(train_device), img2.to(train_device), label.to(train_device)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    # Flatten (B, 1) -> (B,). A bare squeeze() would also drop
                    # the batch dimension when the last batch has one sample.
                    outputs = model(img1, img2).reshape(-1)
                    loss = criterion(outputs, label)
                    loss = loss / accumulation_steps
                scaler.scale(loss).backward()

                # Step every `accumulation_steps` batches, and also on the
                # final batch so a partial group is not silently carried over
                # into the next epoch.
                is_last_batch = (i + 1) == num_batches
                if (i + 1) % accumulation_steps == 0 or is_last_batch:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                # Undo the accumulation scaling for reporting.
                running_loss += loss.item() * accumulation_steps
                pbar.set_postfix(loss=running_loss / (i + 1))
                pbar.update(1)
        scheduler.step()

        train_loss = running_loss / max(num_batches, 1)
        val_loss, val_accuracy = evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy
        })

        # Save the model (file names use the 0-based epoch index).
        torch.save(model.state_dict(), f'networks/network_epoch{epoch}.pth')

# Evaluation function
def evaluate(model, data_loader, criterion):
    """Compute mean batch loss and pair accuracy over ``data_loader``.

    The model is switched to eval mode and gradients are disabled. A pair is
    predicted as "same person" when its logit is greater than 0.

    Args:
        model: Network called as ``model(img1, img2)`` returning one logit per pair.
        data_loader: Iterable of ``(img1, img2, label)`` batches.
        criterion: Loss taking ``(logits, labels)``.

    Returns:
        ``(mean_loss, accuracy)``. Both are ``nan`` when the loader yields no
        samples, instead of raising ``ZeroDivisionError``.
    """
    model.eval()

    # Evaluate on the device the model already lives on rather than on a
    # notebook global; parameter-free models fall back to CPU.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")
    use_amp = eval_device.type == "cuda"

    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for img1, img2, label in data_loader:
            img1, img2, label = img1.to(eval_device), img2.to(eval_device), label.to(eval_device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                # Flatten (B, 1) -> (B,). A bare squeeze() would also drop the
                # batch dimension for a batch of one and break the loss.
                outputs = model(img1, img2).reshape(-1)
                loss = criterion(outputs, label)
            running_loss += loss.item()
            predicted = (outputs > 0).float()
            correct += (predicted == label).sum().item()
            total += label.size(0)
            num_batches += 1

    if total == 0:
        return float("nan"), float("nan")
    return running_loss / num_batches, correct / total

# Train, test, and always close the wandb run - including when the cell is
# interrupted or fails part-way through a long training run.
try:
    # Train the model
    train(model, train_loader, val_loader, criterion, optimizer, epochs=epochs)

    # Evaluate on test set
    test_loss, test_accuracy = evaluate(model, test_loader, criterion)
    print(f'Test Loss: {test_loss}, Test Accuracy: {test_accuracy}')

    # Log final test metrics to wandb
    wandb.log({
        "test_loss": test_loss,
        "test_accuracy": test_accuracy
    })
except BaseException:
    # Covers KeyboardInterrupt as well: mark the run as failed, then re-raise.
    wandb.finish(exit_code=1)
    raise
else:
    # Finish wandb run
    wandb.finish()

Batch size: 128
LR: 0.01
Epochs: 5
Device: cpu




----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 112, 112]              80
            Conv2d-2           [-1, 16, 52, 52]           3,216
            Conv2d-3           [-1, 32, 24, 24]           4,640
            Linear-4                   [-1, 41]         188,969
            Linear-5                   [-1, 32]           1,344
            Linear-6                   [-1, 16]             528
            Conv2d-7          [-1, 8, 112, 112]              80
            Conv2d-8           [-1, 16, 52, 52]           3,216
            Conv2d-9           [-1, 32, 24, 24]           4,640
           Linear-10                   [-1, 41]         188,969
           Linear-11                   [-1, 32]           1,344
           Linear-12                   [-1, 16]             528
           Linear-13                    [-1, 1]              17
Total params: 397,571
Trainable params:

Epoch 1/5:   1%|▏         | 285/20781 [01:10<1:25:02,  4.02batch/s, loss=0.687]


KeyboardInterrupt: 