In [1]:
from contrastive_loss import ContrastiveLoss
from resnet_siamese import ResNet34
from make_dataset import SignaturePairsDataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import torch
from torch.optim import Adam
import torch.nn.functional as F
from tqdm.notebook import tqdm
import wandb
from torch.optim.lr_scheduler import StepLR
import numpy as np
from PIL import ImageOps

In [2]:
run = wandb.init(
    # Set the project where this run will be logged
    project="signet-reimplement",
    )

[34m[1mwandb[0m: Currently logged in as: [33mdieplstks[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.BILINEAR),
    ImageOps.invert,
    transforms.ToTensor(),
    # transforms.Lambda(lambda x: 1.0 - x),  # Invert the colors
    # transforms.Normalize((0,), (.078,)),  # Calculated
])

x_train_range = range(1, 51)
x_val_range = range(51, 56)
train_dataset = SignaturePairsDataset(
    originals_dir='signatures/signatures/full_org',
    forgeries_dir='signatures/signatures/full_forg',
    x_train_range=x_train_range,
    x_val_range=[],  # Empty range for validation data
    transform=transform
)

val_dataset = SignaturePairsDataset(
    originals_dir='signatures/signatures/full_org',
    forgeries_dir='signatures/signatures/full_forg',
    x_train_range=[],  # Empty range for training data
    x_val_range=x_val_range,
    transform=transform
)

# Create the DataLoaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)


In [5]:
net = ResNet34(in_channels=1, emb_size=128).to(device)
loss = ContrastiveLoss()
optimizer = Adam(net.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=4, gamma=0.5)

In [6]:
epochs = 100

for e in range(epochs):
    train_loader_with_tqdm = tqdm(train_loader, desc=f"Epoch {e+1}/{epochs} (Training)")
    for img1, img2, ys in train_loader_with_tqdm:
        img1 = img1.to(device)
        img2 = img2.to(device)
        ys = ys.to(device)
        o1, o2 = net(img1), net(img2)
        
        optimizer.zero_grad()
        l = loss(o1, o2, ys)
        l.backward()
        optimizer.step()
        
        wandb.log({"batch_loss": l.item()})

    
    with torch.no_grad():
        val_loader_with_tqdm = tqdm(val_loader, desc=f"Epoch {e+1}/{epochs} (Validation)")
        distances = []
        labels = []
    
        for img1, img2, ys in val_loader_with_tqdm:
            img1 = img1.to(device)
            img2 = img2.to(device)
            ys = ys.to(device)
            o1, o2 = net(img1), net(img2)
            distance_batch = F.pairwise_distance(o1, o2).cpu().numpy()
            distances.append(distance_batch)
            labels.append(ys.cpu().numpy())
    
        distances = np.concatenate(distances)
        labels = np.concatenate(labels)
    
        # Scan for the best threshold
        thresholds = np.linspace(start=0, stop=1, num=100)  # Example range, adjust as needed
        best_accuracy = 0
        best_threshold = 0

    for t in thresholds:
        predictions = (distances < t).astype(int)
        accuracy = (predictions == labels).mean()
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_threshold = t

    if accuracy < 0.51:
        break
    wandb.log({"best_validation_accuracy": best_accuracy, "best_threshold": best_threshold})
    print(f'Epoch {e+1}/{epochs}, Best Validation Accuracy: {best_accuracy:.4f} at Threshold: {best_threshold}')

    

Epoch 1/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 1/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 1/100, Best Validation Accuracy: 0.7522 at Threshold: 0.05050505050505051


Epoch 2/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 2/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 2/100, Best Validation Accuracy: 0.8880 at Threshold: 0.13131313131313133


Epoch 3/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 3/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 3/100, Best Validation Accuracy: 0.9156 at Threshold: 0.08080808080808081


Epoch 4/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 4/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 4/100, Best Validation Accuracy: 0.9210 at Threshold: 0.07070707070707072


Epoch 5/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 5/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 5/100, Best Validation Accuracy: 0.9279 at Threshold: 0.04040404040404041


Epoch 6/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 6/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 6/100, Best Validation Accuracy: 0.9431 at Threshold: 0.030303030303030304


Epoch 7/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 7/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 7/100, Best Validation Accuracy: 0.9337 at Threshold: 0.030303030303030304


Epoch 8/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 8/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 8/100, Best Validation Accuracy: 0.9391 at Threshold: 0.030303030303030304


Epoch 9/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 9/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 9/100, Best Validation Accuracy: 0.9326 at Threshold: 0.030303030303030304


Epoch 10/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 10/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 10/100, Best Validation Accuracy: 0.9417 at Threshold: 0.030303030303030304


Epoch 11/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 11/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 11/100, Best Validation Accuracy: 0.9413 at Threshold: 0.020202020202020204


Epoch 12/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 12/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 12/100, Best Validation Accuracy: 0.9341 at Threshold: 0.030303030303030304


Epoch 13/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 13/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 13/100, Best Validation Accuracy: 0.9395 at Threshold: 0.030303030303030304


Epoch 14/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 14/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 14/100, Best Validation Accuracy: 0.9366 at Threshold: 0.05050505050505051


Epoch 15/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 15/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 15/100, Best Validation Accuracy: 0.9420 at Threshold: 0.030303030303030304


Epoch 16/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 16/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 16/100, Best Validation Accuracy: 0.9558 at Threshold: 0.030303030303030304


Epoch 17/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 17/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 17/100, Best Validation Accuracy: 0.9486 at Threshold: 0.020202020202020204


Epoch 18/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 18/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 18/100, Best Validation Accuracy: 0.9399 at Threshold: 0.04040404040404041


Epoch 19/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 19/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 19/100, Best Validation Accuracy: 0.9409 at Threshold: 0.020202020202020204


Epoch 20/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 20/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 20/100, Best Validation Accuracy: 0.9902 at Threshold: 0.09090909090909091


Epoch 21/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 21/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 21/100, Best Validation Accuracy: 0.9732 at Threshold: 0.020202020202020204


Epoch 22/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 22/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 22/100, Best Validation Accuracy: 0.9717 at Threshold: 0.04040404040404041


Epoch 23/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 23/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 23/100, Best Validation Accuracy: 0.9667 at Threshold: 0.020202020202020204


Epoch 24/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 24/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 24/100, Best Validation Accuracy: 0.9667 at Threshold: 0.030303030303030304


Epoch 25/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 25/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 25/100, Best Validation Accuracy: 0.9688 at Threshold: 0.020202020202020204


Epoch 26/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 26/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 26/100, Best Validation Accuracy: 0.9623 at Threshold: 0.020202020202020204


Epoch 27/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 27/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 27/100, Best Validation Accuracy: 0.9645 at Threshold: 0.020202020202020204


Epoch 28/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 28/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 28/100, Best Validation Accuracy: 0.5681 at Threshold: 0.494949494949495


Epoch 29/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 29/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 29/100, Best Validation Accuracy: 0.5913 at Threshold: 0.04040404040404041


Epoch 30/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 30/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 30/100, Best Validation Accuracy: 0.6178 at Threshold: 0.030303030303030304


Epoch 31/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 31/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 31/100, Best Validation Accuracy: 0.6033 at Threshold: 0.04040404040404041


Epoch 32/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 32/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 32/100, Best Validation Accuracy: 0.6228 at Threshold: 0.020202020202020204


Epoch 33/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 33/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 33/100, Best Validation Accuracy: 0.6025 at Threshold: 0.13131313131313133


Epoch 34/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 34/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 34/100, Best Validation Accuracy: 0.6149 at Threshold: 0.26262626262626265


Epoch 35/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 35/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 35/100, Best Validation Accuracy: 0.6159 at Threshold: 0.21212121212121213


Epoch 36/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

Epoch 36/100 (Validation):   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 36/100, Best Validation Accuracy: 0.6123 at Threshold: 0.31313131313131315


Epoch 37/100 (Training):   0%|          | 0/216 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [7]:
wandb.finish()

VBox(children=(Label(value='0.004 MB of 0.020 MB uploaded\r'), FloatProgress(value=0.22026816945200156, max=1.…

0,1
batch_loss,▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁
best_threshold,▁▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁█▁▁▁▁▃▅▄▅
best_validation_accuracy,▄▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████▁▁▂▂▂▂▂▂▂

0,1
batch_loss,0.00068
best_threshold,0.31313
best_validation_accuracy,0.61232


In [8]:
print(accuracy)

0.5224637681159421
